/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.sphinx.linguist.acoustic.tiedstate.HTK;

import edu.cmu.sphinx.util.LogMath;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;

public class GMMDiag {
    /**
     * Count written by {@link #save(String)} and read back by {@link #load(String)};
     * {@link #loadScaleKMeans(String)} sets it to the (truncated) sum of the per-Gaussian counts.
     * NOTE(review): presumably the number of training frames — confirm with callers.
     */
    public int nT;

    /** Free-form name of this mixture, set through {@link #setNom(String)}. */
    public String nom;

    /** Log-arithmetic helper; every log-domain value held by this object uses its base. */
    public LogMath logMath;

    private int ncoefs;
    private int ngauss;

    /** Mixture weights as log values in the base of {@link #logMath}. */
    protected float[] weights;

    /** Means, indexed {@code [gaussian][coefficient]}. */
    protected float[][] means;

    /**
     * Per-dimension scoring factors {@code -1 / (2 * variance)}, indexed
     * {@code [gaussian][coefficient]}. Use {@link #getVar}/{@link #setVar} to work with variances.
     */
    protected float[][] covar;

    /**
     * Per-Gaussian normalization term {@code 0.5 * (sum_j log var_j + ncoefs * log(2*pi))}, in the
     * base of {@link #logMath}; filled by {@link #precomputeDistance()}.
     */
    private float[] logPreComputedGaussianFactor;

    /** Per-Gaussian {@code log weight + log density} of the last frame given to {@link #computeLogLikes}. */
    protected float[] loglikes;

    /** Floor applied to a single Gaussian's log density (the most negative finite float). */
    private static final float DIST_FLOOR = -Float.MAX_VALUE;

    /** Creates an empty mixture; sizes and storage are set later by one of the load methods. */
    public GMMDiag() {
    }

    /**
     * Creates a mixture with uniform weights and zero-filled mean/covariance storage.
     *
     * @param ng number of Gaussians
     * @param nc number of coefficients (dimensions) per Gaussian
     */
    public GMMDiag(int ng, int nc) {
        this.ngauss = ng;
        this.ncoefs = nc;
        this.allocate();
    }

    /** @return the number of Gaussians in the mixture */
    public int getNgauss() {
        return this.ngauss;
    }

    /**
     * @param i Gaussian index
     * @return the linear-domain mixture weight of Gaussian {@code i}
     */
    public float getWeight(int i) {
        return (float) this.logMath.logToLinear(this.weights[i]);
    }

    /**
     * @param i Gaussian index
     * @param j coefficient index
     * @return the variance, recovered from the stored factor {@code -1 / (2 * variance)}
     */
    public float getVar(int i, int j) {
        return -1.0f / (2.0f * this.covar[i][j]);
    }

    /**
     * Sets the linear-domain weight of Gaussian {@code i}; it is stored as a log value.
     *
     * @param i Gaussian index
     * @param w linear-domain weight
     */
    public void setWeight(int i, float w) {
        if (this.weights == null) {
            this.weights = new float[this.ngauss];
        }
        if (this.logMath == null) {
            // Instances built with the no-arg constructor have no LogMath yet.
            this.logMath = LogMath.getLogMath();
        }
        this.weights[i] = this.logMath.linearToLog(w);
    }

    /**
     * Sets a variance; it is stored as {@code -1 / (2 * v)}. Non-positive values are stored
     * anyway after printing a warning to stderr.
     *
     * @param i Gaussian index
     * @param j coefficient index
     * @param v variance
     */
    public void setVar(int i, int j, float v) {
        if (v <= 0.0f) {
            System.err.println("WARNING: setVar " + v);
        }
        this.covar[i][j] = -1.0f / (2.0f * v);
    }

    /** Sets the mean of coefficient {@code j} of Gaussian {@code i}. */
    public void setMean(int i, int j, float v) {
        this.means[i][j] = v;
    }

    /** @return the mean of coefficient {@code j} of Gaussian {@code i} */
    public float getMean(int i, int j) {
        return this.means[i][j];
    }

    /**
     * Writes the mixture in the plain-text format read by {@link #load(String)}:
     * a line {@code "ngauss ncoefs"}, then for each Gaussian a line {@code "gauss i weight"},
     * a line of means and a line of variances, and finally a line holding {@link #nT}.
     * I/O errors are printed and otherwise ignored.
     *
     * @param name output file path
     */
    public void save(String name) {
        try {
            PrintWriter fout = new PrintWriter(new FileWriter(name));
            try {
                fout.println(this.ngauss + " " + this.ncoefs);
                for (int i = 0; i < this.ngauss; i++) {
                    fout.println("gauss " + i + ' ' + this.getWeight(i));
                    for (int j = 0; j < this.ncoefs; j++) {
                        fout.print(this.means[i][j] + " ");
                    }
                    fout.println();
                    for (int j = 0; j < this.ncoefs; j++) {
                        fout.print(this.getVar(i, j) + " ");
                    }
                    fout.println();
                }
                fout.println(this.nT);
            } finally {
                fout.close();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Reads a mixture written by {@link #save(String)}, then precomputes the normalizers.
     * A malformed {@code gauss} header line terminates the JVM (kept from the original
     * behavior). Other I/O errors are printed and otherwise ignored.
     *
     * @param name input file path
     */
    public void load(String name) {
        try {
            BufferedReader fin = new BufferedReader(new FileReader(name));
            try {
                String[] ss = fin.readLine().split(" ");
                this.ngauss = Integer.parseInt(ss[0]);
                this.ncoefs = Integer.parseInt(ss[1]);
                this.allocate();
                for (int i = 0; i < this.ngauss; i++) {
                    String header = fin.readLine();
                    ss = header.split(" ");
                    if (!ss[0].equals("gauss") || Integer.parseInt(ss[1]) != i) {
                        System.err.println("Error loading GMM " + header + ' ' + i);
                        System.exit(1);
                    }
                    this.setWeight(i, Float.parseFloat(ss[2]));
                    ss = fin.readLine().split(" ");
                    for (int j = 0; j < this.ncoefs; j++) {
                        this.setMean(i, j, Float.parseFloat(ss[j]));
                    }
                    ss = fin.readLine().split(" ");
                    for (int j = 0; j < this.ncoefs; j++) {
                        this.setVar(i, j, Float.parseFloat(ss[j]));
                    }
                }
                // The trailing count is optional; tolerate surrounding whitespace or a blank line.
                String tail = fin.readLine();
                if (tail != null && tail.trim().length() > 0) {
                    this.nT = Integer.parseInt(tail.trim());
                }
            } finally {
                fin.close();
            }
            this.precomputeDistance();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Writes this mixture as a single-state HTK model using parameter kind {@code <USER>}. */
    public void saveHTK(String nomFich, String nomHMM) {
        this.saveHTK(nomFich, nomHMM, "<USER>");
    }

    /**
     * Opens {@code nomFich} and writes the HTK global-options/regression-tree header to it.
     * The returned writer is left open for the caller to append states and close.
     *
     * @param nomFich output file path
     * @param parmKind HTK parameter kind, written verbatim (e.g. {@code "<USER>"})
     * @return the open writer, or {@code null} if the file could not be opened
     */
    public PrintWriter saveHTKheader(String nomFich, String parmKind) {
        try {
            PrintWriter fout = new PrintWriter(new FileWriter(nomFich));
            this.writeHTKHeader(fout, parmKind);
            return fout;
        } catch (IOException e) {
            e.printStackTrace();
            return null;
        }
    }

    /** Writes this mixture's {@code <NUMMIXES>} and per-mixture {@code <MEAN>}/{@code <VARIANCE>} blocks. */
    public void saveHTKState(PrintWriter fout) {
        this.writeHTKState(fout);
    }

    /**
     * Writes the {@code <TRANSP>} matrix of a left-to-right model with {@code nstates} states,
     * followed by {@code <ENDHMM>}. Each matrix row is written on its own line. Row 0 jumps to
     * state 1 with probability 1. Every row strictly between the first and last loops or advances
     * with probability 0.5 each. The last row is all zeros. This matches the layout
     * {@link #saveHTK(String, String, String)} uses for its 3-state model.
     *
     * @param nstates total number of states, including the first and last rows
     * @param fout destination writer
     */
    public void saveHTKtailer(int nstates, PrintWriter fout) {
        fout.println("<TRANSP> " + nstates);
        for (int row = 0; row < nstates; row++) {
            StringBuilder line = new StringBuilder();
            for (int col = 0; col < nstates; col++) {
                if (col > 0) {
                    line.append(' ');
                }
                line.append(leftToRightTransition(row, col, nstates));
            }
            fout.println(line);
        }
        fout.println("<ENDHMM>");
    }

    /** Returns entry {@code (row, col)} of the transition matrix described in {@link #saveHTKtailer}. */
    private static String leftToRightTransition(int row, int col, int nstates) {
        if (row == 0) {
            return col == 1 ? "1" : "0";
        }
        if (row < nstates - 1 && (col == row || col == row + 1)) {
            return "0.5";
        }
        return "0";
    }

    /**
     * Writes this mixture as a complete 3-state HTK model named {@code nomHMM}. The single
     * emitting state is state 2, with fixed transitions. I/O errors are printed and otherwise
     * ignored.
     *
     * @param nomFich output file path
     * @param nomHMM model name written after {@code ~h}
     * @param parmKind HTK parameter kind, written verbatim
     */
    public void saveHTK(String nomFich, String nomHMM, String parmKind) {
        try {
            PrintWriter fout = new PrintWriter(new FileWriter(nomFich));
            try {
                this.writeHTKHeader(fout, parmKind);
                fout.println("~h \"" + nomHMM + '"');
                fout.println("<BEGINHMM>");
                fout.println("<NUMSTATES> 3");
                fout.println("<STATE> 2");
                this.writeHTKState(fout);
                fout.println("<TRANSP> 3");
                fout.println("0 1 0");
                fout.println("0 0.7 0.3");
                fout.println("0 0 0");
                fout.println("<ENDHMM>");
            } finally {
                fout.close();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Writes the HTK header lines shared by {@link #saveHTKheader} and {@link #saveHTK}. */
    private void writeHTKHeader(PrintWriter fout, String parmKind) {
        fout.println("~o");
        fout.println("<HMMSETID> tree");
        fout.println("<STREAMINFO> 1 " + this.getNcoefs());
        fout.println("<VECSIZE> " + this.getNcoefs() + "<NULLD>" + parmKind + "<DIAGC>");
        fout.println("~r \"rtree_1\"");
        fout.println("<REGTREE> 1");
        fout.println("<TNODE> 1 " + this.getNgauss());
    }

    /** Writes the mixture components; HTK numbers mixtures from 1. */
    private void writeHTKState(PrintWriter fout) {
        fout.println("<NUMMIXES> " + this.getNgauss());
        for (int i = 1; i <= this.getNgauss(); i++) {
            fout.println("<MIXTURE> " + i + ' ' + this.getWeight(i - 1));
            fout.println("<RCLASS> 1");
            fout.println("<MEAN> " + this.getNcoefs());
            for (int j = 0; j < this.getNcoefs(); j++) {
                fout.print(this.getMean(i - 1, j) + " ");
            }
            fout.println();
            fout.println("<VARIANCE> " + this.getNcoefs());
            for (int j = 0; j < this.getNcoefs(); j++) {
                fout.print(this.getVar(i - 1, j) + " ");
            }
            fout.println();
        }
    }

    /**
     * Loads every {@code <MEAN>}/{@code <VARIANCE>} pair found in an HTK text model as one
     * Gaussian, then precomputes the normalizers. The dimension is taken from the first
     * {@code <MEAN>} line. Mixture weights are NOT read from the file. They are reset to uniform
     * when storage is (re)allocated and otherwise left unchanged.
     * I/O errors (including a {@code <MEAN>} not followed by {@code <VARIANCE>}) are printed and
     * otherwise ignored.
     *
     * @param nom input file path
     */
    public void loadHTK(String nom) {
        try {
            // Pass 1: count Gaussians and find the dimension.
            BufferedReader fin = new BufferedReader(new FileReader(nom));
            try {
                this.ngauss = 0;
                this.ncoefs = 0;
                String line;
                while ((line = fin.readLine()) != null) {
                    if (!line.contains("<MEAN>")) {
                        continue;
                    }
                    ++this.ngauss;
                    if (this.ncoefs == 0) {
                        StringTokenizer st = new StringTokenizer(line);
                        st.nextToken(); // first token: expected to be the <MEAN> tag
                        this.ncoefs = Integer.parseInt(st.nextToken());
                    }
                }
            } finally {
                fin.close();
            }
            this.allocate();

            // Pass 2: read the vectors following each tag.
            fin = new BufferedReader(new FileReader(nom));
            try {
                int g = 0;
                String line;
                while ((line = fin.readLine()) != null) {
                    if (!line.contains("<MEAN>")) {
                        continue;
                    }
                    float[] mean = parseFloats(fin.readLine());
                    for (int c = 0; c < mean.length; c++) {
                        this.setMean(g, c, mean[c]);
                    }
                    line = fin.readLine();
                    if (!line.contains("<VARIANCE>")) {
                        throw new IOException("expected <VARIANCE> after the mean of Gaussian "
                                + g + " in " + nom + ", found: " + line);
                    }
                    float[] var = parseFloats(fin.readLine());
                    for (int c = 0; c < var.length; c++) {
                        this.setVar(g, c, var[c]);
                    }
                    ++g;
                }
            } finally {
                fin.close();
            }
            this.precomputeDistance();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Parses every whitespace-separated token of {@code line} as a float. */
    private static float[] parseFloats(String line) {
        StringTokenizer st = new StringTokenizer(line);
        float[] values = new float[st.countTokens()];
        for (int k = 0; k < values.length; k++) {
            values[k] = Float.parseFloat(st.nextToken());
        }
        return values;
    }

    /**
     * Loads a mixture from a file of line pairs. The first line of each pair is a count followed
     * by {@code ncoefs} means. The second line holds {@code ncoefs} variances. Fields are
     * separated by single spaces. Weights are the counts normalized by their sum.
     * {@link #nT} receives that sum, truncated to an int. I/O errors are printed and otherwise
     * ignored.
     *
     * @param nom input file path
     */
    public void loadScaleKMeans(String nom) {
        try {
            List<String> lines = readLines(nom);
            this.ngauss = lines.size() / 2;
            this.ncoefs = lines.get(0).split(" ").length - 1;
            this.allocate();
            float[] counts = new float[this.ngauss];
            double total = 0.0;
            for (int i = 0; i < this.ngauss; i++) {
                String[] meanTokens = lines.get(2 * i).split(" ");
                counts[i] = Float.parseFloat(meanTokens[0]);
                total += counts[i];
                for (int j = 0; j < this.ncoefs; j++) {
                    this.setMean(i, j, Float.parseFloat(meanTokens[j + 1]));
                }
                String[] varTokens = lines.get(2 * i + 1).split(" ");
                for (int j = 0; j < this.ncoefs; j++) {
                    this.setVar(i, j, Float.parseFloat(varTokens[j]));
                }
            }
            // Sum first, truncate once: truncating per step loses fractional counts.
            this.nT = (int) total;
            for (int i = 0; i < this.ngauss; i++) {
                this.setWeight(i, (float) (counts[i] / total));
            }
            this.precomputeDistance();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Reads all lines of a text file (platform default charset, like the other loaders). */
    private static List<String> readLines(String path) throws IOException {
        List<String> lines = new ArrayList<String>();
        BufferedReader fin = new BufferedReader(new FileReader(path));
        try {
            String line;
            while ((line = fin.readLine()) != null) {
                lines.add(line);
            }
        } finally {
            fin.close();
        }
        return lines;
    }

    /** Initializes {@link #logMath} and sets all {@code ngauss} weights to {@code 1/ngauss}. */
    private void allocateWeights() {
        this.logMath = LogMath.getLogMath();
        this.weights = new float[this.ngauss];
        for (int i = 0; i < this.ngauss; i++) {
            this.setWeight(i, 1.0f / (float) this.ngauss);
        }
    }

    /**
     * Computes each Gaussian's log normalizer {@code 0.5 * (sum_j log var_j + ncoefs * log(2*pi))}
     * in the base of {@link #logMath}. Must be called after the variances change.
     */
    public void precomputeDistance() {
        if (this.ngauss <= 0) {
            return;
        }
        float logTwoPiTerm = this.logMath.linearToLog(Math.PI * 2) * (float) this.ncoefs;
        for (int gidx = 0; gidx < this.ngauss; gidx++) {
            float fact = 0.0f;
            for (int i = 0; i < this.ncoefs; i++) {
                fact += this.logMath.linearToLog(this.getVar(gidx, i));
            }
            fact += logTwoPiTerm;
            this.logPreComputedGaussianFactor[gidx] = fact * 0.5f;
        }
    }

    /**
     * Ensures parameter storage matches the current {@code ngauss}/{@code ncoefs}. Storage is
     * (re)allocated when missing or wrongly sized, so the load methods can be called on an
     * instance that already holds a mixture of a different size.
     */
    private void allocate() {
        if (this.weights == null || this.weights.length != this.ngauss) {
            this.allocateWeights();
        } else if (this.logMath == null) {
            this.logMath = LogMath.getLogMath();
        }
        boolean meansFit = this.means != null && this.means.length == this.ngauss
                && (this.ngauss == 0 || this.means[0].length == this.ncoefs);
        if (!meansFit) {
            this.loglikes = new float[this.ngauss];
            this.means = new float[this.ngauss][this.ncoefs];
            this.covar = new float[this.ngauss][this.ncoefs];
            this.logPreComputedGaussianFactor = new float[this.ngauss];
        }
    }

    /**
     * Scores one observation against every Gaussian. It stores
     * {@code log weight + log N(data; mean, var)} in {@link #loglikes}, in the base of
     * {@link #logMath}. Each log density is floored at the most negative finite float, and a NaN
     * density is reported on stderr and replaced by that floor.
     *
     * @param data observation; its length must not exceed {@link #getNcoefs()}
     */
    public void computeLogLikes(float[] data) {
        for (int gidx = 0; gidx < this.ngauss; gidx++) {
            float[] mean = this.means[gidx];
            float[] factor = this.covar[gidx];
            // Quadratic term -0.5 * sum (x - mu)^2 / var, in natural-log units.
            float quad = 0.0f;
            for (int i = 0; i < data.length; i++) {
                float diff = data[i] - mean[i];
                quad += diff * diff * factor[i];
            }
            // The normalizer and weights are in logMath's base, so convert before combining.
            float logDval1gauss = this.logMath.lnToLog(quad) - this.logPreComputedGaussianFactor[gidx];
            if (Float.isNaN(logDval1gauss)) {
                System.err.println("gs2 is Nan, converting to 0 debug " + gidx + ' ' + this.logPreComputedGaussianFactor[gidx] + ' ' + this.means[gidx][0] + ' ' + this.covar[gidx][0]);
                logDval1gauss = DIST_FLOOR;
            }
            if (logDval1gauss < DIST_FLOOR) {
                logDval1gauss = DIST_FLOOR;
            }
            this.loglikes[gidx] = this.weights[gidx] + logDval1gauss;
        }
    }

    /**
     * @return the mixture log-likelihood of the last scored frame: the log of the sum of the
     *     linear-domain values in {@link #loglikes}. Requires at least one Gaussian.
     */
    public float getLogLike() {
        float sc = this.loglikes[0];
        for (int i = 1; i < this.ngauss; i++) {
            sc = this.logMath.addAsLinear(sc, this.loglikes[i]);
        }
        return sc;
    }

    /** @return the index of the Gaussian with the highest score for the last frame (first on ties) */
    public int getWinningGauss() {
        int imax = 0;
        for (int i = 1; i < this.ngauss; i++) {
            if (this.loglikes[i] > this.loglikes[imax]) {
                imax = i;
            }
        }
        return imax;
    }

    /** @return the number of coefficients (dimensions) per Gaussian */
    public int getNcoefs() {
        return this.ncoefs;
    }

    /**
     * Returns the marginal mixture over the coefficients selected by {@code mask}. Weights are
     * unchanged, and means/covariances of the kept dimensions are copied.
     *
     * @param mask one flag per coefficient; must have at least {@link #getNcoefs()} entries
     * @return a new mixture with one coefficient per {@code true} entry among the first
     *     {@code ncoefs} flags
     */
    public GMMDiag getMarginal(boolean[] mask) {
        // Only the first ncoefs flags select dimensions; counting further would create
        // dimensions that are never filled.
        int nc = 0;
        for (int j = 0; j < this.ncoefs; j++) {
            if (mask[j]) {
                ++nc;
            }
        }
        GMMDiag g = new GMMDiag(this.getNgauss(), nc);
        int curc = 0;
        for (int j = 0; j < this.ncoefs; j++) {
            if (!mask[j]) {
                continue;
            }
            for (int i = 0; i < this.ngauss; i++) {
                g.means[i][curc] = this.means[i][j];
                g.covar[i][curc] = this.covar[i][j];
            }
            ++curc;
        }
        for (int i = 0; i < this.ngauss; i++) {
            g.setWeight(i, this.getWeight(i));
        }
        g.precomputeDistance();
        return g;
    }

    /**
     * Concatenates this mixture and {@code g} into a new mixture. This mixture's weights are
     * scaled by {@code w1} and {@code g}'s weights by {@code 1 - w1}.
     *
     * @param g mixture to append; must have the same number of coefficients
     * @param w1 share of the total weight given to this mixture
     * @return the combined mixture (this mixture's Gaussians first)
     * @throws IllegalArgumentException if the dimensions differ
     */
    public GMMDiag merge(GMMDiag g, float w1) {
        if (g.getNcoefs() != this.getNcoefs()) {
            throw new IllegalArgumentException("cannot merge GMMs of dimension "
                    + this.getNcoefs() + " and " + g.getNcoefs());
        }
        int offset = this.getNgauss();
        GMMDiag res = new GMMDiag(offset + g.getNgauss(), this.getNcoefs());
        for (int i = 0; i < offset; i++) {
            System.arraycopy(this.means[i], 0, res.means[i], 0, this.getNcoefs());
            System.arraycopy(this.covar[i], 0, res.covar[i], 0, this.getNcoefs());
            res.setWeight(i, this.getWeight(i) * w1);
        }
        for (int i = 0; i < g.getNgauss(); i++) {
            System.arraycopy(g.means[i], 0, res.means[offset + i], 0, this.getNcoefs());
            System.arraycopy(g.covar[i], 0, res.covar[offset + i], 0, this.getNcoefs());
            res.setWeight(offset + i, g.getWeight(i) * (1.0f - w1));
        }
        res.precomputeDistance();
        return res;
    }

    /** @return a new single-Gaussian mixture (weight 1) holding a copy of Gaussian {@code i} */
    public GMMDiag getGauss(int i) {
        GMMDiag res = new GMMDiag(1, this.getNcoefs());
        System.arraycopy(this.means[i], 0, res.means[0], 0, this.getNcoefs());
        System.arraycopy(this.covar[i], 0, res.covar[0], 0, this.getNcoefs());
        res.setWeight(0, 1.0f);
        res.precomputeDistance();
        return res;
    }

    /** Sets {@link #nom}. */
    public void setNom(String s) {
        this.nom = s;
    }

    /**
     * Tests whether two mixtures have the same sizes and, Gaussian by Gaussian, weights, means
     * and variances equal within a relative tolerance of 1%.
     *
     * @param g mixture to compare with
     * @return {@code true} if they match
     */
    public boolean isEqual(GMMDiag g) {
        if (this.getNgauss() != g.getNgauss() || this.getNcoefs() != g.getNcoefs()) {
            return false;
        }
        for (int i = 0; i < this.getNgauss(); i++) {
            if (this.isDiff(this.getWeight(i), g.getWeight(i))) {
                return false;
            }
            for (int j = 0; j < this.getNcoefs(); j++) {
                if (this.isDiff(this.getMean(i, j), g.getMean(i, j))
                        || this.isDiff(this.getVar(i, j), g.getVar(i, j))) {
                    return false;
                }
            }
        }
        return true;
    }

    /** @return whether {@code b} differs from {@code a} by more than 1% of {@code a} */
    private boolean isDiff(float a, float b) {
        if (a == b) {
            // Also covers a == b == 0 and equal infinities, where b / a would be NaN.
            return false;
        }
        return Math.abs(1.0f - b / a) > 0.01;
    }

    /** @return one line per Gaussian with the mean and variance of its first coefficient */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < this.getNgauss(); i++) {
            sb.append(this.getMean(i, 0)).append(' ').append(this.getVar(i, 0)).append('\n');
        }
        return sb.toString();
    }
}

