/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.sphinx.decoder.adaptation;

import edu.cmu.sphinx.api.SpeechResult;
import edu.cmu.sphinx.decoder.adaptation.ClusteredDensityFileData;
import edu.cmu.sphinx.decoder.adaptation.Transform;
import edu.cmu.sphinx.decoder.search.Token;
import edu.cmu.sphinx.frontend.FloatData;
import edu.cmu.sphinx.linguist.HMMSearchState;
import edu.cmu.sphinx.linguist.SearchState;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.Loader;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.Sphinx3Loader;
import edu.cmu.sphinx.util.LogMath;

/**
 * Accumulates the regression statistics needed to estimate an adaptation
 * {@link Transform} from decoded speech.
 *
 * <p>Statistics are kept per regression cluster (as assigned by
 * {@link ClusteredDensityFileData#getClassIndex(int)} for each Gaussian) and
 * per feature stream. For a stream of vector length {@code len}, each cluster
 * holds {@code len} matrices of size {@code (len + 1) x (len + 1)} (see
 * {@link #getRegLs()}) and {@code len} vectors of size {@code len + 1} (see
 * {@link #getRegRs()}). Index {@code len} is the slot for a constant term: it
 * is accumulated as if each Gaussian mean were extended with a trailing 1.
 *
 * <p><b>Side effect:</b> the constructor rewrites the loader's variance pool
 * in place, replacing every variance with its (floored) inverse. Each new
 * instance created over the same loader inverts the pool again.
 */
public class Stats {
    /** Minimum number of collected frames per cluster before {@link #createTransform()} builds a transform. */
    private static final int MIN_FRAMES = 300;
    /** Positive variances smaller than this are clamped to it before being inverted. */
    private static final float VARIANCE_FLOOR = 1.0E-5f;

    private final ClusteredDensityFileData means;
    /**
     * [cluster][stream][dimension][(len+1)][(len+1)] accumulators; {@link #collect}
     * fills only the upper triangle, {@link #fillRegLowerPart()} mirrors it.
     */
    private double[][][][][] regLs;
    /** [cluster][stream][dimension][(len+1)] accumulators. */
    private double[][][][] regRs;
    private final int nClusters;
    private final Sphinx3Loader loader;
    private final LogMath logMath = LogMath.getLogMath();
    /** Number of emitting frames accumulated by {@link #collect(SpeechResult)}. */
    private int nFrames;

    /**
     * Creates an empty accumulator and inverts the loader's variances in place.
     *
     * @param loader the acoustic model loader; it is cast to {@link Sphinx3Loader},
     *        so any other implementation fails with {@link ClassCastException}
     * @param means the assignment of Gaussians to regression clusters
     */
    public Stats(Loader loader, ClusteredDensityFileData means) {
        this.loader = (Sphinx3Loader) loader;
        this.nClusters = means.getNumberOfClusters();
        this.means = means;
        invertVariances();
        init();
    }

    /** Allocates zeroed accumulators for every cluster and stream. */
    private void init() {
        int numStreams = loader.getNumStreams();
        int[] vectorLength = loader.getVectorLength();
        regLs = new double[nClusters][][][][];
        regRs = new double[nClusters][][][];
        for (int cluster = 0; cluster < nClusters; cluster++) {
            regLs[cluster] = new double[numStreams][][][];
            regRs[cluster] = new double[numStreams][][];
            for (int stream = 0; stream < numStreams; stream++) {
                int len = vectorLength[stream];
                regLs[cluster][stream] = new double[len][len + 1][len + 1];
                regRs[cluster][stream] = new double[len][len + 1];
            }
        }
    }

    /** @return the Gaussian-to-cluster assignment this instance was built with */
    public ClusteredDensityFileData getClusteredData() {
        return this.means;
    }

    /** @return the live (not copied) left-hand-side accumulators, indexed [cluster][stream][dim][row][col] */
    public double[][][][][] getRegLs() {
        return this.regLs;
    }

    /** @return the live (not copied) right-hand-side accumulators, indexed [cluster][stream][dim][col] */
    public double[][][][] getRegRs() {
        return this.regRs;
    }

    /**
     * Returns the variance/mean pool index of a Gaussian, using the layout
     * {@code mixture * numStreams * gauPerState + stream * gauPerState + gau}.
     */
    private static int poolIndex(int mixture, int stream, int gau, int numStreams, int gauPerState) {
        return mixture * numStreams * gauPerState + stream * gauPerState + gau;
    }

    /**
     * Replaces every variance in the loader's variance pool with its inverse, in place.
     *
     * <p>All streams of all states are processed, using the same pool layout
     * that {@link #collect(SpeechResult)} reads, so every variance it consumes
     * is an inverse variance.
     */
    private void invertVariances() {
        int numStates = loader.getNumStates();
        int numStreams = loader.getNumStreams();
        int gauPerState = loader.getNumGaussiansPerState();
        int[] vectorLength = loader.getVectorLength();
        for (int state = 0; state < numStates; state++) {
            for (int stream = 0; stream < numStreams; stream++) {
                for (int gau = 0; gau < gauPerState; gau++) {
                    float[] variance = loader.getVariancePool()
                            .get(poolIndex(state, stream, gau, numStreams, gauPerState));
                    for (int d = 0; d < vectorLength[stream]; d++) {
                        variance[d] = invertVariance(variance[d]);
                    }
                }
            }
        }
    }

    /**
     * Inverts a single variance value.
     *
     * @return {@code 0.5f} for non-positive input, {@code 1 / VARIANCE_FLOOR}
     *         for input below the floor, otherwise {@code 1 / variance}
     *         (computed in double precision, then narrowed to float)
     */
    private static float invertVariance(float variance) {
        if (variance <= 0.0f) {
            // NOTE(review): the 0.5 fallback for non-positive variances is
            // inherited behavior; its rationale is not visible here.
            return 0.5f;
        }
        if (variance < VARIANCE_FLOOR) {
            return (float) (1.0 / VARIANCE_FLOOR);
        }
        return (float) (1.0 / variance);
    }

    /**
     * Converts log-domain component scores to linear-domain weights, in place.
     *
     * <p>The array is split into {@code numStreams} equal consecutive slices.
     * Within each slice, every score has the slice maximum subtracted and is
     * then mapped through {@link LogMath#logToLinear}. The results are
     * relative to the best component of the slice and are NOT normalized to
     * sum to one.
     *
     * @param componentScores log scores; overwritten with the result
     * @param numStreams number of streams the scores are split into
     * @return {@code componentScores} itself
     */
    private float[] computePosteriors(float[] componentScores, int numStreams) {
        int step = componentScores.length / numStreams;
        for (int stream = 0, start = 0; stream < numStreams; stream++, start += step) {
            int end = start + step;
            float max = componentScores[start];
            for (int j = start + 1; j < end; j++) {
                if (componentScores[j] > max) {
                    max = componentScores[j];
                }
            }
            for (int j = start; j < end; j++) {
                componentScores[j] = (float) logMath.logToLinear(componentScores[j] - max);
            }
        }
        return componentScores;
    }

    /**
     * Adds the statistics of one recognition result.
     *
     * <p>Walks the predecessor chain of the result's best token. Tokens whose
     * search state is not an emitting {@link HMMSearchState} are skipped; each
     * remaining token counts as one frame and contributes to the accumulators
     * of every Gaussian component with a positive weight.
     *
     * @param result the decoded result to learn from
     * @throws Exception if the result has no best token
     */
    public void collect(SpeechResult result) throws Exception {
        Token token = result.getResult().getBestToken();
        if (token == null) {
            throw new Exception("Best token not found!");
        }
        for (; token != null; token = token.getPredecessor()) {
            SearchState ss = token.getSearchState();
            if (!(ss instanceof HMMSearchState) || !ss.isEmitting()) {
                continue;
            }
            // Only emitting tokens are scored, so only their data is read.
            FloatData feature = (FloatData) token.getData();
            ++nFrames;
            float[] componentScore = token.calculateComponentScore(feature);
            float[] featureVector = FloatData.toFloatData(feature).getValues();
            int mId = (int) ((HMMSearchState) ss).getHMMState().getMixtureId();
            if (loader.hasTiedMixtures()) {
                // Tied-mixture models share densities; map the senone to its density index.
                mId = loader.getSenone2Ci()[mId];
            }
            float[] posteriors = computePosteriors(componentScore, loader.getNumStreams());
            accumulateFrame(mId, featureVector, posteriors);
        }
    }

    /**
     * Adds one frame's contribution for every Gaussian of mixture {@code mId}.
     *
     * @param mId index of the mixture in the means/variance pools
     * @param featureVector the frame's features, all streams concatenated
     * @param posteriors per-component weights from {@link #computePosteriors}
     */
    private void accumulateFrame(int mId, float[] featureVector, float[] posteriors) {
        int[] len = loader.getVectorLength();
        int numStreams = loader.getNumStreams();
        int gauPerState = loader.getNumGaussiansPerState();
        int featStart = 0;
        for (int stream = 0; stream < numStreams; stream++) {
            for (int gau = 0; gau < gauPerState; gau++) {
                int gaussianId = poolIndex(mId, stream, gau, numStreams, gauPerState);
                int cluster = means.getClassIndex(gaussianId);
                float dnom = posteriors[stream * gauPerState + gau];
                if (dnom > 0.0f) {
                    accumulateGaussian(regLs[cluster][stream], regRs[cluster][stream], len[stream],
                            featureVector, featStart, dnom,
                            loader.getMeansPool().get(gaussianId),
                            loader.getVariancePool().get(gaussianId));
                }
            }
            featStart += len[stream];
        }
    }

    /**
     * Adds one Gaussian's weighted contribution to a stream's accumulators.
     *
     * <p>Only entries with column {@code >= row} of each {@code regL[k]} are
     * updated; {@link #fillRegLowerPart()} mirrors them later. Products are
     * formed in float and widened to double on accumulation.
     *
     * @param regL the stream's left-hand accumulators, one (len+1)x(len+1) matrix per dimension
     * @param regR the stream's right-hand accumulators, one (len+1) vector per dimension
     * @param len the stream's vector length
     * @param featureVector the frame's features, all streams concatenated
     * @param featStart offset of this stream inside {@code featureVector}
     * @param dnom the component's weight for this frame
     * @param tmean the Gaussian's mean
     * @param invVar the Gaussian's inverse variance (see {@link #invertVariances()})
     */
    private static void accumulateGaussian(double[][][] regL, double[][] regR, int len,
            float[] featureVector, int featStart, float dnom, float[] tmean, float[] invVar) {
        for (int k = 0; k < len; k++) {
            float mean = dnom * featureVector[featStart + k];
            float wtMeanVar = mean * invVar[k];
            float wtDcountVar = dnom * invVar[k];
            double[][] lk = regL[k];
            double[] rk = regR[k];
            for (int p = 0; p < len; p++) {
                float wtDcountVarMean = wtDcountVar * tmean[p];
                double[] row = lk[p];
                for (int q = p; q < len; q++) {
                    row[q] += wtDcountVarMean * tmean[q];
                }
                // Column len: the constant-1 extension of the mean.
                row[len] += wtDcountVarMean;
                rk[p] += wtMeanVar * tmean[p];
            }
            lk[len][len] += wtDcountVar;
            rk[len] += wtMeanVar;
        }
    }

    /**
     * Copies the upper triangle of every left-hand-side matrix into its lower
     * triangle, making each matrix symmetric. Intended to run after all
     * {@link #collect(SpeechResult)} calls.
     */
    public void fillRegLowerPart() {
        int numStreams = loader.getNumStreams();
        int[] vectorLength = loader.getVectorLength();
        for (int cluster = 0; cluster < nClusters; cluster++) {
            for (int stream = 0; stream < numStreams; stream++) {
                int len = vectorLength[stream];
                for (int dim = 0; dim < len; dim++) {
                    double[][] m = regLs[cluster][stream][dim];
                    for (int p = 0; p <= len; p++) {
                        for (int q = p + 1; q <= len; q++) {
                            m[q][p] = m[p][q];
                        }
                    }
                }
            }
        }
    }

    /**
     * Builds a transform from the collected statistics.
     *
     * @return the new transform, or {@code null} if fewer than
     *         {@code MIN_FRAMES * numberOfClusters} frames have been collected
     */
    public Transform createTransform() {
        if (nFrames < MIN_FRAMES * nClusters) {
            return null;
        }
        Transform transform = new Transform(loader, nClusters);
        transform.update(this);
        return transform;
    }

    /** @return the number of emitting frames accumulated so far */
    public int getFrames() {
        return this.nFrames;
    }
}

