/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.sphinx.linguist.language.ngram.trie;

import edu.cmu.sphinx.linguist.WordSequence;
import edu.cmu.sphinx.linguist.dictionary.Dictionary;
import edu.cmu.sphinx.linguist.dictionary.Word;
import edu.cmu.sphinx.linguist.language.ngram.LanguageModel;
import edu.cmu.sphinx.linguist.language.ngram.trie.BinaryLoader;
import edu.cmu.sphinx.linguist.language.ngram.trie.NgramTrie;
import edu.cmu.sphinx.linguist.language.ngram.trie.NgramTrieQuant;
import edu.cmu.sphinx.linguist.util.LRUCache;
import edu.cmu.sphinx.util.LogMath;
import edu.cmu.sphinx.util.TimerPool;
import edu.cmu.sphinx.util.props.ConfigurationManagerUtils;
import edu.cmu.sphinx.util.props.PropertyException;
import edu.cmu.sphinx.util.props.PropertySheet;
import edu.cmu.sphinx.util.props.S4Boolean;
import edu.cmu.sphinx.util.props.S4Double;
import edu.cmu.sphinx.util.props.S4Integer;
import edu.cmu.sphinx.util.props.S4String;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintWriter;
import java.net.URL;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Language model that answers n-gram probability queries from a binary trie model.
 *
 * <p>{@link #allocate()} reads the model with {@link BinaryLoader}, in this order: header, n-gram
 * counts, quantization tables (only when {@code maxDepth > 1}), unigrams, the trie byte array
 * (only when {@code maxDepth > 1}) and the word list. A query first looks up the longest n-gram
 * ending in the last word that is present in the trie. If that is shorter than the query, the
 * backoff weights of the unmatched contexts are added to its score.
 *
 * <p>Scores for queries of exactly {@code maxDepth} words are memoized in an LRU cache. Per-query
 * state ({@code curDepth}), the cache and the hit/miss counters are plain instance fields with no
 * synchronization, so one instance should not be queried from several threads at once.
 */
public class NgramTrieModel implements LanguageModel {
    /** Optional path of a file to which each computed probability query and its score are written. */
    @S4String(mandatory=false)
    public static final String PROP_QUERY_LOG_FILE = "queryLogFile";
    /** Maximum number of entries in the cache of full-order n-gram scores. */
    @S4Integer(defaultValue=100000)
    public static final String PROP_NGRAM_CACHE_SIZE = "ngramCacheSize";
    /** When true, the n-gram score cache is replaced by an empty one at the end of each utterance. */
    @S4Boolean(defaultValue=false)
    public static final String PROP_CLEAR_CACHES_AFTER_UTTERANCE = "clearCachesAfterUtterance";
    /** Multiplier applied to the raw score when {@link #PROP_APPLY_LANGUAGE_WEIGHT_AND_WIP} is true. */
    @S4Double(defaultValue=1.0)
    public static final String PROP_LANGUAGE_WEIGHT = "languageWeight";
    /** When true, scores are multiplied by the language weight and offset by the log word insertion probability. */
    @S4Boolean(defaultValue=false)
    public static final String PROP_APPLY_LANGUAGE_WEIGHT_AND_WIP = "applyLanguageWeightAndWip";
    /** Linear word insertion probability; converted with {@code LogMath.linearToLog} when configured. */
    @S4Double(defaultValue=1.0)
    public static final String PROP_WORD_INSERTION_PROBABILITY = "wordInsertionProbability";

    /** Location of the binary model file. */
    URL location;
    protected Logger logger;
    protected LogMath logMath;
    /** Maximum n-gram order served; clamped to the model's order in {@link #allocate()}. */
    protected int maxDepth;
    /** Length of the longest n-gram matched by the query currently being scored. */
    protected int curDepth;
    /** Number of n-grams of each order, as read from the model. */
    protected int[] counts;
    protected int ngramCacheSize;
    protected boolean clearCacheAfterUtterance;
    protected Dictionary dictionary;
    protected String format;
    protected boolean applyLanguageWeightAndWip;
    protected float languageWeight;
    /** Read from configuration but not applied anywhere in this class. */
    protected float unigramWeight;
    /** Word insertion probability converted with {@code LogMath.linearToLog}. */
    protected float logWip;
    protected String ngramLogFile;
    private int ngramMisses;
    private int ngramHits;
    private PrintWriter logFile;
    /** Unigram records; indexed at {@code wordId + 1}, so a trailing sentinel entry is assumed. */
    protected TrieUnigram[] unigrams;
    /** Model vocabulary; the array index is the word id used in the trie. */
    protected String[] words;
    protected NgramTrieQuant quant;
    protected NgramTrie trie;
    /** Maps dictionary words to their model word ids. */
    protected Map<Word, Integer> unigramIDMap;
    private LRUCache<WordSequence, Float> ngramProbCache;

    /**
     * Creates a fully configured model. {@link #allocate()} must still be called before use.
     *
     * @param format model format name (stored only)
     * @param location location of the binary model
     * @param ngramLogFile optional query log path, or {@code null}
     * @param maxNGramCacheSize capacity of the full-order n-gram score cache
     * @param clearCacheAfterUtterance whether to drop the cache at the end of each utterance
     * @param maxDepth maximum order to serve; values {@code <= 0} or above the model order mean "model order"
     * @param dictionary dictionary used to map model words to {@link Word} objects
     * @param applyLanguageWeightAndWip whether to apply the language weight and word insertion penalty
     * @param languageWeight multiplier applied to raw scores
     * @param wip linear word insertion probability
     * @param unigramWeight stored but not used by this class
     */
    public NgramTrieModel(String format, URL location, String ngramLogFile, int maxNGramCacheSize, boolean clearCacheAfterUtterance, int maxDepth, Dictionary dictionary, boolean applyLanguageWeightAndWip, float languageWeight, double wip, float unigramWeight) {
        this.logger = Logger.getLogger(this.getClass().getName());
        this.format = format;
        this.location = location;
        this.ngramLogFile = ngramLogFile;
        this.ngramCacheSize = maxNGramCacheSize;
        this.clearCacheAfterUtterance = clearCacheAfterUtterance;
        this.maxDepth = maxDepth;
        this.logMath = LogMath.getLogMath();
        this.dictionary = dictionary;
        this.applyLanguageWeightAndWip = applyLanguageWeightAndWip;
        this.languageWeight = languageWeight;
        this.logWip = this.logMath.linearToLog(wip);
        this.unigramWeight = unigramWeight;
    }

    /** Creates an unconfigured model; {@link #newProperties(PropertySheet)} is expected to be called next. */
    public NgramTrieModel() {
    }

    /** Reads this component's configuration from the property sheet. */
    @Override
    public void newProperties(PropertySheet ps) throws PropertyException {
        this.logger = ps.getLogger();
        this.logMath = LogMath.getLogMath();
        this.location = ConfigurationManagerUtils.getResource("location", ps);
        this.ngramLogFile = ps.getString(PROP_QUERY_LOG_FILE);
        this.maxDepth = ps.getInt("maxDepth");
        this.ngramCacheSize = ps.getInt(PROP_NGRAM_CACHE_SIZE);
        this.clearCacheAfterUtterance = ps.getBoolean(PROP_CLEAR_CACHES_AFTER_UTTERANCE);
        this.dictionary = (Dictionary)ps.getComponent("dictionary");
        this.applyLanguageWeightAndWip = ps.getBoolean(PROP_APPLY_LANGUAGE_WEIGHT_AND_WIP);
        this.languageWeight = ps.getFloat(PROP_LANGUAGE_WEIGHT);
        this.logWip = this.logMath.linearToLog(ps.getDouble(PROP_WORD_INSERTION_PROBABILITY));
        this.unigramWeight = ps.getFloat("unigramWeight");
    }

    /**
     * Maps each model word that the dictionary knows to its model word id.
     * Words the dictionary lacks are logged and counted but not mapped.
     */
    private void buildUnigramIDMap() {
        int missingWords = 0;
        if (this.unigramIDMap == null) {
            this.unigramIDMap = new HashMap<Word, Integer>();
        }
        for (int i = 0; i < this.words.length; ++i) {
            Word word = this.dictionary.getWord(this.words[i]);
            if (word == null) {
                this.logger.warning("The dictionary is missing a phonetic transcription for the word '" + this.words[i] + "'");
                ++missingWords;
                // Mapping a null key would alias it to whichever missing word came last.
                continue;
            }
            this.unigramIDMap.put(word, i);
            if (this.logger.isLoggable(Level.FINE)) {
                this.logger.fine("Word: " + word);
            }
        }
        if (missingWords > 0) {
            this.logger.warning("Dictionary is missing " + missingWords + " words that are contained in the language model.");
        }
    }

    /**
     * Loads the binary model and prepares the score cache.
     *
     * <p>The loader is always closed and the "Load LM" timer always stopped, even when reading fails.
     *
     * @throws IOException if the model or the query log file cannot be opened or read
     */
    @Override
    public void allocate() throws IOException {
        TimerPool.getTimer(this, "Load LM").start();
        try {
            this.logger.info("Loading n-gram language model from: " + this.location);
            if (this.ngramLogFile != null) {
                if (this.logFile != null) {
                    // Re-allocation without deallocate: do not leak the previous writer.
                    this.logFile.close();
                }
                this.logFile = new PrintWriter(new FileOutputStream(this.ngramLogFile));
            }
            BinaryLoader loader = this.openLoader();
            try {
                this.loadModel(loader);
            } finally {
                loader.close();
            }
        } finally {
            TimerPool.getTimer(this, "Load LM").stop();
        }
    }

    /** Opens a loader for {@link #location}, reading local files directly and other URLs as streams. */
    private BinaryLoader openLoader() throws IOException {
        if (this.location.getProtocol() == null || this.location.getProtocol().equals("file")) {
            try {
                return new BinaryLoader(new File(this.location.toURI()));
            }
            catch (Exception ex) {
                // The URL may not be a valid URI (e.g. unescaped spaces); fall back to its raw path.
                return new BinaryLoader(new File(this.location.getPath()));
            }
        }
        return new BinaryLoader(this.location);
    }

    /** Reads every model section from {@code loader} in file order and builds the lookup structures. */
    private void loadModel(BinaryLoader loader) throws IOException {
        loader.verifyHeader();
        this.counts = loader.readCounts();
        if (this.maxDepth <= 0 || this.maxDepth > this.counts.length) {
            this.maxDepth = this.counts.length;
        }
        if (this.maxDepth > 1) {
            this.quant = loader.readQuant(this.maxDepth);
        }
        this.unigrams = loader.readUnigrams(this.counts[0]);
        if (this.maxDepth > 1) {
            this.trie = new NgramTrie(this.counts, this.quant.getProbBoSize(), this.quant.getProbSize());
            loader.readTrieByteArr(this.trie.getMem());
        }
        this.words = loader.readWords(this.counts[0]);
        this.buildUnigramIDMap();
        this.ngramProbCache = new LRUCache<WordSequence, Float>(this.ngramCacheSize);
    }

    /** Closes the query log, if one is open. {@code close()} flushes pending output first. */
    @Override
    public void deallocate() throws IOException {
        if (this.logFile != null) {
            this.logFile.close();
            this.logFile = null;
        }
    }

    /**
     * Walks the trie from the last word back through its context, extending the matched n-gram
     * one word at a time while it is present.
     *
     * <p>Each successful step increments {@link #curDepth}.
     *
     * @param wordSequence the query
     * @param range trie range of the last word's successors; narrowed in place by the trie reads
     * @param prob score of the unigram of the last word
     * @return score of the longest matched n-gram
     */
    private float getAvailableProb(WordSequence wordSequence, TrieRange range, float prob) {
        if (!range.isSearchable()) {
            return prob;
        }
        int lastContextIdx = wordSequence.size() - 2;
        for (int sequenceIdx = lastContextIdx; sequenceIdx >= 0; --sequenceIdx) {
            int orderMinusTwo = lastContextIdx - sequenceIdx;
            if (orderMinusTwo + 1 == this.maxDepth) {
                break;
            }
            // Throws NullPointerException if the word has no model id (see buildUnigramIDMap).
            int wordId = this.unigramIDMap.get(wordSequence.getWord(sequenceIdx));
            float updatedProb = this.trie.readNgramProb(wordId, orderMinusTwo, range, this.quant);
            if (!range.getFound()) {
                break;
            }
            prob = updatedProb;
            ++this.curDepth;
            if (!range.isSearchable()) {
                break;
            }
        }
        return prob;
    }

    /**
     * Sums the backoff weights of the context n-grams that the matched n-gram did not cover.
     *
     * <p>If the longest matched n-gram has {@code curDepth} words, only backoffs of contexts with
     * at least {@code curDepth} words contribute. Shorter contexts are still walked in the trie,
     * because the walk is needed to reach the longer ones, but their weights are not added.
     * Previously every context of length 2 and up was summed, which over-counted backoff for
     * 4-gram and larger models whenever {@code curDepth >= 3}.
     *
     * @param wordSequence the query; must contain at least two words
     * @return total backoff weight to add to the matched n-gram's score
     */
    private float getAvailableBackoff(WordSequence wordSequence) {
        float backoff = 0.0f;
        int wordsNum = wordSequence.size();
        int wordId = this.unigramIDMap.get(wordSequence.getWord(wordsNum - 2));
        TrieRange range = new TrieRange(this.unigrams[wordId].next, this.unigrams[wordId + 1].next);
        if (this.curDepth == 1) {
            // Only a unigram matched, so the one-word context's backoff applies.
            backoff += this.unigrams[wordId].backoff;
        }
        int orderMinusTwo = 0;
        for (int sequenceIdx = wordsNum - 3; sequenceIdx >= 0; --sequenceIdx, ++orderMinusTwo) {
            int tmpWordId = this.unigramIDMap.get(wordSequence.getWord(sequenceIdx));
            float tmpBackoff = this.trie.readNgramBackoff(tmpWordId, orderMinusTwo, range, this.quant);
            if (!range.getFound()) {
                // This context is absent, so no longer context can exist either.
                break;
            }
            // The context just read has orderMinusTwo + 2 words.
            if (orderMinusTwo + 2 >= this.curDepth) {
                backoff += tmpBackoff;
            }
            if (!range.isSearchable()) {
                break;
            }
        }
        return backoff;
    }

    /** Computes the unweighted score of the query: longest matched n-gram plus any needed backoff. */
    private float getProbabilityRaw(WordSequence wordSequence) {
        int wordsNum = wordSequence.size();
        int wordId = this.unigramIDMap.get(wordSequence.getWord(wordsNum - 1));
        TrieRange range = new TrieRange(this.unigrams[wordId].next, this.unigrams[wordId + 1].next);
        float prob = this.unigrams[wordId].prob;
        this.curDepth = 1;
        if (wordsNum == 1) {
            return prob;
        }
        prob = this.getAvailableProb(wordSequence, range, prob);
        if (this.curDepth < wordsNum) {
            prob += this.getAvailableBackoff(wordSequence);
        }
        return prob;
    }

    /** Applies the language weight and word insertion penalty to {@code score} when that is enabled. */
    private float applyWeights(float score) {
        if (this.applyLanguageWeightAndWip) {
            return score * this.languageWeight + this.logWip;
        }
        return score;
    }

    /**
     * Returns the (optionally weighted) score of the last word of {@code wordSequence} given the
     * preceding words.
     *
     * <p>Queries of exactly {@code maxDepth} words are served from and stored into the LRU cache.
     * Computed (non-cached) results are also written to the query log, if one is open.
     *
     * @throws Error if the sequence is longer than {@code maxDepth}
     */
    @Override
    public float getProbability(WordSequence wordSequence) {
        int numberWords = wordSequence.size();
        if (numberWords > this.maxDepth) {
            throw new Error("Unsupported NGram: " + wordSequence.size());
        }
        boolean cacheable = numberWords == this.maxDepth;
        if (cacheable) {
            Float cached = this.ngramProbCache.get(wordSequence);
            if (cached != null) {
                ++this.ngramHits;
                return cached.floatValue();
            }
            ++this.ngramMisses;
        }
        float probability = this.applyWeights(this.getProbabilityRaw(wordSequence));
        if (cacheable) {
            this.ngramProbCache.put(wordSequence, Float.valueOf(probability));
        }
        if (this.logFile != null) {
            this.logFile.println(wordSequence.toString().replace("][", " ") + " : " + Float.toString(probability));
        }
        return probability;
    }

    /** Smearing is not supported by this model; always returns {@code 0}. */
    @Override
    public float getSmear(WordSequence wordSequence) {
        return 0.0f;
    }

    /** Returns a fresh, unmodifiable set of the model's vocabulary words. */
    @Override
    public Set<String> getVocabulary() {
        HashSet<String> vocabulary = new HashSet<String>(Arrays.asList(this.words));
        return Collections.unmodifiableSet(vocabulary);
    }

    /** Returns how many full-order queries missed the cache. */
    public int getNGramMisses() {
        return this.ngramMisses;
    }

    /** Returns how many full-order queries were served from the cache. */
    public int getNGramHits() {
        return this.ngramHits;
    }

    /** Returns the maximum n-gram order this model serves. */
    @Override
    public int getMaxDepth() {
        return this.maxDepth;
    }

    /** Logs cache statistics and, if configured, replaces the cache with an empty one. */
    private void clearCache() {
        this.logger.info("LM Cache Size: " + this.ngramProbCache.size() + " Hits: " + this.ngramHits + " Misses: " + this.ngramMisses);
        if (this.clearCacheAfterUtterance) {
            this.ngramProbCache = new LRUCache<WordSequence, Float>(this.ngramCacheSize);
        }
    }

    /** Handles the end of an utterance: cache maintenance and an end marker in the query log. */
    @Override
    public void onUtteranceEnd() {
        this.clearCache();
        if (this.logFile != null) {
            this.logFile.println("<END_UTT>");
            this.logFile.flush();
        }
    }

    /**
     * Half-open index range {@code [begin, end)} into one trie level.
     *
     * <p>{@link NgramTrie} reads narrow it in place and set {@code found}.
     */
    public static class TrieRange {
        int begin;
        int end;
        boolean found;

        TrieRange(int begin, int end) {
            this.begin = begin;
            this.end = end;
            this.found = true;
        }

        /** Returns the number of entries in the range. */
        int getWidth() {
            return this.end - this.begin;
        }

        void setFound(boolean found) {
            this.found = found;
        }

        /** Returns whether the last lookup into this range found its word. */
        boolean getFound() {
            return this.found;
        }

        /** Returns whether the range is non-empty and can therefore be searched. */
        boolean isSearchable() {
            return this.getWidth() > 0;
        }
    }

    /** Unigram record: score, backoff weight, and start index of its successors in the next trie level. */
    public static class TrieUnigram {
        public float prob;
        public float backoff;
        public int next;
    }
}

