/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.sphinx.linguist.lextree;

import edu.cmu.sphinx.linguist.WordSequence;
import edu.cmu.sphinx.linguist.acoustic.HMM;
import edu.cmu.sphinx.linguist.acoustic.HMMPool;
import edu.cmu.sphinx.linguist.acoustic.HMMPosition;
import edu.cmu.sphinx.linguist.acoustic.Unit;
import edu.cmu.sphinx.linguist.dictionary.Dictionary;
import edu.cmu.sphinx.linguist.dictionary.Pronunciation;
import edu.cmu.sphinx.linguist.dictionary.Word;
import edu.cmu.sphinx.linguist.language.ngram.LanguageModel;
import edu.cmu.sphinx.linguist.lextree.EndNode;
import edu.cmu.sphinx.linguist.lextree.HMMNode;
import edu.cmu.sphinx.linguist.lextree.InitialWordNode;
import edu.cmu.sphinx.linguist.lextree.Node;
import edu.cmu.sphinx.linguist.lextree.UnitNode;
import edu.cmu.sphinx.linguist.lextree.WordNode;
import edu.cmu.sphinx.util.Utilities;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;

/**
 * A lexical prefix tree of HMM nodes.
 *
 * <p>The tree covers every word whose spelling is in the language model vocabulary and
 * is found in the dictionary. It also includes either all dictionary filler words or,
 * when fillers are not requested, only the silence word.
 *
 * <p>Construction is eager. The constructor:
 * <ol>
 *   <li>collects the first and last unit of every pronunciation;</li>
 *   <li>inserts every pronunciation into the tree;</li>
 *   <li>builds, for each entry unit, one entry node per possible left context;</li>
 *   <li>freezes the result and drops the references only needed while building.</li>
 * </ol>
 *
 * <p>The word-end HMM fan-out returned by {@link #getHMMNodes(EndNode)} is computed
 * lazily and cached. That method mutates an unsynchronized cache.
 */
class HMMTree {
    /**
     * Smallest representable score. Used as the initial "nothing seen yet" value;
     * presumably the log-domain zero (same value as the former -3.4028235E38f literal).
     */
    private static final float LOG_ZERO = -Float.MAX_VALUE;
    /** Score given to filler words and word-end HMM nodes (presumably log-domain one). */
    private static final float LOG_ONE = 0.0f;

    private final HMMPool hmmPool;
    /** Set while wiring single-unit words when the sentence-start word is encountered. */
    private InitialWordNode initialNode;
    /** Only needed while building; released by {@code freeze()}. */
    private Dictionary dictionary;
    /** Only needed while building; released by {@code freeze()}. */
    private LanguageModel lm;
    private final boolean addFillerWords;
    private final boolean addSilenceWord = true;
    /** First units of all pronunciations; also every possible right context of a word end. */
    private final Set<Unit> entryPoints = new HashSet<Unit>();
    /** Last units of all pronunciations; every possible left context of a word start. */
    private Set<Unit> exitPoints = new HashSet<Unit>();
    /** Lazily computed word list; released by {@code freeze()}. */
    private Set<Word> allWords;
    private EntryPointTable entryPointTable;
    /** Never assigned in this class, so the debug output below is effectively disabled. */
    private boolean debug;
    private final float languageWeight;
    /** Cache of word-end HMM nodes keyed by {@link EndNode#getKey()}. */
    private final Map<Object, HMMNode[]> endNodeMap;
    private final Map<Pronunciation, WordNode> wordNodeMap;
    private WordNode sentenceEndWordNode;
    private final Logger logger;

    /**
     * Builds and freezes the tree.
     *
     * @param pool           source of context-dependent HMMs
     * @param dictionary     dictionary used to look up words and pronunciations
     * @param lm             language model whose vocabulary selects the words
     * @param addFillerWords if true all filler words are added, otherwise only silence
     * @param languageWeight multiplier applied to each word's unigram score
     */
    HMMTree(HMMPool pool, Dictionary dictionary, LanguageModel lm, boolean addFillerWords, float languageWeight) {
        this.hmmPool = pool;
        this.dictionary = dictionary;
        this.lm = lm;
        this.endNodeMap = new HashMap<Object, HMMNode[]>();
        this.wordNodeMap = new HashMap<Pronunciation, WordNode>();
        this.addFillerWords = addFillerWords;
        this.languageWeight = languageWeight;
        this.logger = Logger.getLogger(HMMTree.class.getSimpleName());
        compile();
    }

    /**
     * Returns the first-unit HMM nodes for words beginning with {@code base} when the
     * previous word ended in {@code lc}.
     *
     * <p>Both units must be known to the tree (i.e. {@code base} is the first unit and
     * {@code lc} the last unit of some pronunciation); otherwise a
     * {@link NullPointerException} results.
     *
     * @param lc   left context (last unit of the previous word)
     * @param base first unit of the next word
     * @return the successors of the matching entry node
     */
    public Node[] getEntryPoint(Unit lc, Unit base) {
        EntryPoint ep = entryPointTable.getEntryPoint(base);
        return ep.getEntryPointsFromLeftContext(lc).getSuccessors();
    }

    /**
     * Returns the word-end HMM nodes for {@code endNode}, one per distinct END-position
     * HMM over all possible right contexts.
     *
     * <p>Each returned node records the right contexts it serves. It is linked to every
     * word node that follows {@code endNode}. Results are cached by the end node's key.
     *
     * @param endNode the end node whose HMM fan-out is wanted
     * @return the (cached) array of HMM nodes
     */
    public HMMNode[] getHMMNodes(EndNode endNode) {
        Object key = endNode.getKey();
        HMMNode[] results = endNodeMap.get(key);
        if (results == null) {
            Map<HMM, HMMNode> resultMap = new HashMap<HMM, HMMNode>();
            Unit baseUnit = endNode.getBaseUnit();
            Unit lc = endNode.getLeftContext();
            // The word successors do not depend on the right context, so fetch them once.
            Node[] wordSuccessors = endNode.getSuccessors();
            for (Unit rc : entryPoints) {
                HMM hmm = hmmPool.getHMM(baseUnit, lc, rc, HMMPosition.END);
                // Several right contexts may share one HMM; they share one node too.
                HMMNode hmmNode = resultMap.get(hmm);
                if (hmmNode == null) {
                    hmmNode = new HMMNode(hmm, LOG_ONE);
                    resultMap.put(hmm, hmmNode);
                }
                hmmNode.addRC(rc);
                for (Node node : wordSuccessors) {
                    hmmNode.addSuccessor((WordNode) node);
                }
            }
            results = resultMap.values().toArray(new HMMNode[resultMap.size()]);
            endNodeMap.put(key, results);
        }
        return results;
    }

    /**
     * Returns the word node of the sentence-end word.
     * The node is recorded during construction.
     */
    public WordNode getSentenceEndWordNode() {
        assert sentenceEndWordNode != null;
        return sentenceEndWordNode;
    }

    /** Runs the full build: entry/exit units, tree insertion, entry maps, freeze. */
    private void compile() {
        collectEntryAndExitUnits();
        entryPointTable = new EntryPointTable(entryPoints);
        addWords();
        entryPointTable.createEntryPointMaps();
        freeze();
    }

    /** Prints the tree to standard output, visiting each node once. */
    void dumpTree() {
        System.out.println("Dumping Tree ...");
        dumpTree(0, getInitialNode(), new HashSet<Node>());
        System.out.println("... done Dumping Tree");
    }

    /**
     * Prints {@code node} indented by {@code level} and recurses into its successors.
     * Recursion stops at word nodes and at nodes already in {@code visited}.
     */
    private void dumpTree(int level, Node node, Set<Node> visited) {
        if (!visited.add(node)) {
            return;
        }
        System.out.println(Utilities.pad(level) + node);
        if (node instanceof WordNode) {
            return;
        }
        for (Node nextNode : node.getSuccessors()) {
            dumpTree(level + 1, nextNode, visited);
        }
    }

    /**
     * Records the first unit of every pronunciation as an entry point and the last unit
     * as an exit point.
     */
    private void collectEntryAndExitUnits() {
        for (Word word : getAllWords()) {
            for (Pronunciation p : word.getPronunciations()) {
                Unit[] units = p.getUnits();
                entryPoints.add(units[0]);
                exitPoints.add(units[units.length - 1]);
            }
        }
        if (debug) {
            System.out.println("Entry Points: " + entryPoints.size());
            System.out.println("Exit Points: " + exitPoints.size());
        }
    }

    /**
     * Freezes the entry points and releases build-only state.
     * Released state: dictionary, language model, exit points, word list, and the word
     * node / end node maps.
     */
    private void freeze() {
        entryPointTable.freeze();
        dictionary = null;
        lm = null;
        exitPoints = null;
        allWords = null;
        wordNodeMap.clear();
        endNodeMap.clear();
    }

    /** Inserts every selected word into the tree. */
    private void addWords() {
        for (Word word : getAllWords()) {
            addWord(word);
        }
    }

    /** Inserts all pronunciations of {@code word}, each scored with the word's unigram score. */
    private void addWord(Word word) {
        float prob = getWordUnigramProbability(word);
        for (Pronunciation pronunciation : word.getPronunciations()) {
            addPronunciation(pronunciation, prob);
        }
    }

    /**
     * Inserts one pronunciation into the tree.
     *
     * <p>A multi-unit pronunciation becomes a chain hanging off the entry point of its
     * first unit. The chain holds one INTERNAL HMM node per middle unit, an
     * {@link EndNode} for the last unit, and a {@link WordNode} for the pronunciation.
     * A single-unit pronunciation is only recorded on its entry point and is wired in
     * later by {@link EntryPoint#createEntryPointMap()}.
     *
     * @param pronunciation the pronunciation to add
     * @param probability   the language-weighted unigram score of its word
     */
    private void addPronunciation(Pronunciation pronunciation, float probability) {
        Unit[] units = pronunciation.getUnits();
        Unit firstUnit = units[0];
        EntryPoint ep = entryPointTable.getEntryPoint(firstUnit);
        ep.addProbability(probability);

        if (units.length < 2) {
            ep.addSingleUnitWord(pronunciation);
            return;
        }

        Node curNode = ep.getNode();
        Unit lc = firstUnit;
        for (int i = 1; i < units.length - 1; i++) {
            Unit baseUnit = units[i];
            Unit rc = units[i + 1];
            HMM hmm = hmmPool.getHMM(baseUnit, lc, rc, HMMPosition.INTERNAL);
            if (hmm == null) {
                // Skip the missing unit but keep building the rest of the chain.
                logger.severe("Missing HMM for unit " + baseUnit.getName() + " with lc=" + lc.getName() + " rc=" + rc.getName());
            } else {
                curNode = curNode.addSuccessor(hmm, probability);
            }
            lc = baseUnit;
        }

        EndNode endNode = new EndNode(units[units.length - 1], lc, probability);
        curNode = curNode.addSuccessor(endNode, probability);
        WordNode wordNode = curNode.addSuccessor(pronunciation, probability, wordNodeMap);
        if (wordNode.getWord().isSentenceEndWord()) {
            sentenceEndWordNode = wordNode;
        }
    }

    /**
     * Returns the language-weighted unigram score of {@code word}.
     * Filler words get {@code LOG_ONE}.
     */
    private float getWordUnigramProbability(Word word) {
        if (word.isFiller()) {
            return LOG_ONE;
        }
        float prob = lm.getProbability(new WordSequence(new Word[]{word}));
        return prob * languageWeight;
    }

    /**
     * Lazily collects the words to place in the tree.
     * These are the LM vocabulary entries the dictionary knows, plus either all filler
     * words or just the silence word.
     */
    private Set<Word> getAllWords() {
        if (allWords == null) {
            allWords = new HashSet<Word>();
            for (String spelling : lm.getVocabulary()) {
                Word word = dictionary.getWord(spelling);
                if (word != null) {
                    allWords.add(word);
                }
            }
            if (addFillerWords) {
                allWords.addAll(Arrays.asList(dictionary.getFillerWords()));
            } else if (addSilenceWord) {
                allWords.add(dictionary.getSilenceWord());
            }
        }
        return allWords;
    }

    /** Returns the sentence-start node, or null if no sentence-start word was wired in. */
    InitialWordNode getInitialNode() {
        return initialNode;
    }

    /** All tree entries that begin with one particular unit. */
    class EntryPoint {
        final Unit baseUnit;
        /** Root of the subtree of words starting with {@link #baseUnit}. */
        final Node baseNode;
        /** Left context -> entry node; filled by {@link #createEntryPointMap()}. */
        final Map<Unit, Node> unitToEntryPointMap;
        /** Pronunciations consisting of {@link #baseUnit} alone; released on freeze. */
        List<Pronunciation> singleUnitWords;
        int nodeCount;
        /** Lazily computed right contexts following {@link #baseUnit}. */
        Set<Unit> rcSet;
        /** Best (maximum) score of any word entering here. */
        float totalProbability;

        EntryPoint(Unit baseUnit) {
            this.baseUnit = baseUnit;
            this.baseNode = new Node(LOG_ZERO);
            this.unitToEntryPointMap = new HashMap<Unit, Node>();
            this.singleUnitWords = new ArrayList<Pronunciation>();
            this.totalProbability = LOG_ZERO;
        }

        /** Returns the entry node for {@code leftContext}, or null if there is none. */
        Node getEntryPointsFromLeftContext(Unit leftContext) {
            return unitToEntryPointMap.get(leftContext);
        }

        /** Keeps the maximum of the scores seen so far. */
        void addProbability(float probability) {
            if (probability > totalProbability) {
                totalProbability = probability;
            }
        }

        float getProbability() {
            return totalProbability;
        }

        /** Freezes every entry node and drops build-only state. */
        void freeze() {
            for (Node node : unitToEntryPointMap.values()) {
                node.freeze();
            }
            singleUnitWords = null;
            rcSet = null;
        }

        Node getNode() {
            return baseNode;
        }

        void addSingleUnitWord(Pronunciation p) {
            singleUnitWords.add(p);
        }

        /** Returns the base units of {@link #baseNode}'s successors; cached after the first call. */
        private Collection<Unit> getEntryPointRC() {
            if (rcSet == null) {
                rcSet = new HashSet<Unit>();
                for (Node node : baseNode.getSuccessorMap().values()) {
                    rcSet.add(((UnitNode) node).getBaseUnit());
                }
            }
            return rcSet;
        }

        /**
         * Builds one entry node per possible left context (every exit unit).
         *
         * <p>Each entry node gets a BEGIN-position HMM node per right context. Nodes are
         * shared whenever two contexts map to the same HMM. Single-unit words are then
         * attached to the entry node as well.
         */
        void createEntryPointMap() {
            Map<HMM, Node> sharedBeginNodes = new HashMap<HMM, Node>();
            Map<HMM, HMMNode> sharedSingleNodes = new HashMap<HMM, HMMNode>();
            for (Unit lc : HMMTree.this.exitPoints) {
                Node epNode = new Node(LOG_ZERO);
                for (Unit rc : getEntryPointRC()) {
                    HMM hmm = HMMTree.this.hmmPool.getHMM(baseUnit, lc, rc, HMMPosition.BEGIN);
                    Node addedNode = sharedBeginNodes.get(hmm);
                    if (addedNode == null) {
                        addedNode = epNode.addSuccessor(hmm, getProbability());
                        sharedBeginNodes.put(hmm, addedNode);
                    } else {
                        epNode.putSuccessor(hmm, addedNode);
                    }
                    ++nodeCount;
                    connectEntryPointNode(addedNode, rc);
                }
                connectSingleUnitWords(lc, epNode, sharedSingleNodes);
                unitToEntryPointMap.put(lc, epNode);
            }
        }

        /**
         * Attaches single-unit words to {@code epNode}.
         *
         * <p>One SINGLE-position HMM node is created per right context, shared through
         * {@code map}. Every single-unit pronunciation hangs off that node. The
         * sentence-start word instead becomes the tree's initial node.
         */
        private void connectSingleUnitWords(Unit lc, Node epNode, Map<HMM, HMMNode> map) {
            if (singleUnitWords.isEmpty()) {
                return;
            }
            // Loop-invariant dictionary lookups.
            Word sentenceStart = HMMTree.this.dictionary.getSentenceStartWord();
            Word sentenceEnd = HMMTree.this.dictionary.getSentenceEndWord();
            for (Unit rc : HMMTree.this.entryPoints) {
                HMM hmm = HMMTree.this.hmmPool.getHMM(baseUnit, lc, rc, HMMPosition.SINGLE);
                HMMNode tailNode = map.get(hmm);
                if (tailNode == null) {
                    tailNode = (HMMNode) epNode.addSuccessor(hmm, getProbability());
                    map.put(hmm, tailNode);
                } else {
                    epNode.putSuccessor(hmm, tailNode);
                }
                tailNode.addRC(rc);
                ++nodeCount;
                for (Pronunciation p : singleUnitWords) {
                    Word word = p.getWord();
                    if (word == sentenceStart) {
                        HMMTree.this.initialNode = new InitialWordNode(p, tailNode);
                    } else {
                        float prob = HMMTree.this.getWordUnigramProbability(word);
                        WordNode wordNode = tailNode.addSuccessor(p, prob, HMMTree.this.wordNodeMap);
                        if (word == sentenceEnd) {
                            HMMTree.this.sentenceEndWordNode = wordNode;
                        }
                    }
                    ++nodeCount;
                }
            }
        }

        /** Links {@code epNode} to every child of {@link #baseNode} whose unit is {@code rc}. */
        private void connectEntryPointNode(Node epNode, Unit rc) {
            for (Node node : baseNode.getSuccessors()) {
                UnitNode successor = (UnitNode) node;
                if (successor.getBaseUnit() == rc) {
                    epNode.addSuccessor(successor);
                }
            }
        }

        /** Prints this entry point's right-context units, 13 per line. */
        void dump() {
            Collection<Unit> rcs = getEntryPointRC();
            System.out.println("EntryPoint " + baseUnit + " RC Followers: " + rcs.size());
            int count = 0;
            System.out.print("    ");
            for (Unit rc : rcs) {
                System.out.print(Utilities.pad(rc.getName(), 4));
                if (count++ >= 12) {
                    count = 0;
                    System.out.println();
                    System.out.print("    ");
                }
            }
            System.out.println();
        }
    }

    /** Maps each entry unit to its {@link EntryPoint}. */
    class EntryPointTable {
        private final Map<Unit, EntryPoint> entryPoints = new HashMap<Unit, EntryPoint>();

        EntryPointTable(Collection<Unit> entryPointCollection) {
            for (Unit unit : entryPointCollection) {
                entryPoints.put(unit, new EntryPoint(unit));
            }
        }

        /** Returns the entry point for {@code baseUnit}, or null if it is not an entry unit. */
        EntryPoint getEntryPoint(Unit baseUnit) {
            return entryPoints.get(baseUnit);
        }

        void createEntryPointMaps() {
            for (EntryPoint ep : entryPoints.values()) {
                ep.createEntryPointMap();
            }
        }

        void freeze() {
            for (EntryPoint ep : entryPoints.values()) {
                ep.freeze();
            }
        }

        void dump() {
            for (EntryPoint ep : entryPoints.values()) {
                ep.dump();
            }
        }
    }
}

