/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.sphinx.decoder.adaptation;

import edu.cmu.sphinx.decoder.adaptation.Stats;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.Sphinx3Loader;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Locale;
import java.util.Scanner;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;

/**
 * Holds a set of affine transforms (a matrix {@code A} and a vector {@code B}
 * per cluster and per stream) and knows how to estimate them from collected
 * statistics, write one of them to a text file, and read one back.
 *
 * <p>Array layout: {@code As[cluster][stream][row][col]} and
 * {@code Bs[cluster][stream][row]}. Both are {@code null} until
 * {@link #update(Stats)} or {@link #load(String)} has completed successfully.
 *
 * <p>Text file layout, as produced by {@link #store(String, int)} and consumed
 * by {@link #load(String)}: the number of transform classes (always {@code 1}),
 * the number of streams, then for each stream its vector length {@code n},
 * {@code n} rows of {@code n} values (matrix A), one row of {@code n} values
 * (vector B) and one row of {@code n} values that {@code store} writes as
 * {@code 1.0} and {@code load} discards.
 */
public class Transform {
    /** Text encoding shared by {@link #store} and {@link #load}. */
    private static final String FILE_ENCODING = "UTF-8";

    private float[][][][] As;
    private float[][][] Bs;
    private final Sphinx3Loader loader;
    private final int nrOfClusters;

    /**
     * @param loader       source of the number of streams and per-stream vector lengths
     * @param nrOfClusters number of clusters for which {@link #update(Stats)} estimates a transform
     */
    public Transform(Sphinx3Loader loader, int nrOfClusters) {
        this.loader = loader;
        this.nrOfClusters = nrOfClusters;
    }

    /**
     * @return the internal {@code A} matrices, indexed {@code [cluster][stream][row][col]};
     *         not a copy, and {@code null} before the first update/load
     */
    public float[][][][] getAs() {
        return this.As;
    }

    /**
     * @return the internal {@code B} vectors, indexed {@code [cluster][stream][row]};
     *         not a copy, and {@code null} before the first update/load
     */
    public float[][][] getBs() {
        return this.Bs;
    }

    /**
     * Writes the transform of one cluster to a text file.
     *
     * @param filePath destination file; created or truncated
     * @param index    cluster whose transform is written
     * @throws IllegalStateException          if no transform has been computed or loaded yet
     * @throws ArrayIndexOutOfBoundsException if {@code index} does not name an existing cluster
     * @throws IOException                    if the writer reported an I/O error
     * @throws Exception                      if the file cannot be opened
     */
    public void store(String filePath, int index) throws Exception {
        // Validate before opening the file so that a bad call does not leave a
        // truncated, half-written file behind.
        if (this.As == null || this.Bs == null) {
            throw new IllegalStateException("No transform available: call update() or load() first");
        }
        if (index < 0 || index >= this.As.length || index >= this.Bs.length) {
            throw new ArrayIndexOutOfBoundsException("No transform for cluster index " + index);
        }

        int numStreams = this.loader.getNumStreams();
        int[] vectorLength = this.loader.getVectorLength();

        try (PrintWriter writer = new PrintWriter(filePath, FILE_ENCODING)) {
            // Only a single transform class is ever written.
            writer.println("1");
            writer.println(numStreams);
            for (int stream = 0; stream < numStreams; stream++) {
                int len = vectorLength[stream];
                writer.println(len);

                // Matrix A, one row per line.
                for (int row = 0; row < len; row++) {
                    for (int col = 0; col < len; col++) {
                        writer.print(this.As[index][stream][row][col]);
                        writer.print(" ");
                    }
                    writer.println();
                }

                // Vector B on a single line.
                for (int row = 0; row < len; row++) {
                    writer.print(this.Bs[index][stream][row]);
                    writer.print(" ");
                }
                writer.println();

                // Third per-stream row is a constant 1.0 for every dimension
                // (presumably a variance scale kept for file-format
                // compatibility; load() skips it).
                for (int row = 0; row < len; row++) {
                    writer.print("1.0 ");
                }
                writer.println();
            }

            // PrintWriter swallows IOExceptions; surface them instead of
            // silently producing a corrupt file.
            if (writer.checkError()) {
                throw new IOException("Failed to write transform to " + filePath);
            }
        }
    }

    /**
     * Fills {@link #As} and {@link #Bs} for every cluster and stream.
     *
     * <p>For each cluster {@code c}, stream {@code i} and row {@code j} the
     * linear system {@code regLs[c][i][j] * x = regRs[c][i][j]} is solved by LU
     * decomposition. The first {@code len} entries of {@code x} become row
     * {@code j} of {@code A}; entry {@code len} becomes {@code B[j]}.
     *
     * @param regLs coefficient matrices, indexed {@code [cluster][stream][row]}
     * @param regRs right-hand-side vectors, indexed {@code [cluster][stream][row]}
     */
    private void computeMllrTransforms(double[][][][][] regLs, double[][][][] regRs) {
        int numStreams = this.loader.getNumStreams();
        int[] vectorLength = this.loader.getVectorLength();

        for (int c = 0; c < this.nrOfClusters; c++) {
            this.As[c] = new float[numStreams][][];
            this.Bs[c] = new float[numStreams][];

            for (int i = 0; i < numStreams; i++) {
                int len = vectorLength[i];
                this.As[c][i] = new float[len][len];
                this.Bs[c][i] = new float[len];

                for (int j = 0; j < len; j++) {
                    // copyArray == false: wrap the statistics arrays instead of duplicating them.
                    RealMatrix coef = new Array2DRowRealMatrix(regLs[c][i][j], false);
                    RealVector rhs = new ArrayRealVector(regRs[c][i][j], false);
                    DecompositionSolver solver = new LUDecomposition(coef).getSolver();
                    RealVector solution = solver.solve(rhs);

                    for (int k = 0; k < len; k++) {
                        this.As[c][i][j][k] = (float) solution.getEntry(k);
                    }
                    // The extra, last unknown of the system is the bias term.
                    this.Bs[c][i][j] = (float) solution.getEntry(len);
                }
            }
        }
    }

    /**
     * Reads a transform from a text file in the layout written by
     * {@link #store(String, int)} and installs it as cluster {@code 0}.
     *
     * <p>Numbers are parsed with {@link Locale#US} so that the {@code '.'}
     * decimal separator written by {@code store} is understood regardless of
     * the platform's default locale. The current transform is replaced only
     * after the whole file has been parsed successfully.
     *
     * @param filePath file to read
     * @throws Exception if the file cannot be opened or its content is malformed
     */
    public void load(String filePath) throws Exception {
        float[][][][] loadedAs;
        float[][][] loadedBs;

        try (Scanner input = new Scanner(new File(filePath), FILE_ENCODING)) {
            input.useLocale(Locale.US);

            int nMllrClass = input.nextInt();
            assert (nMllrClass == 1);
            int numStreams = input.nextInt();
            loadedAs = new float[nMllrClass][numStreams][][];
            loadedBs = new float[nMllrClass][numStreams][];

            for (int stream = 0; stream < numStreams; stream++) {
                int length = input.nextInt();
                loadedAs[0][stream] = new float[length][length];
                loadedBs[0][stream] = new float[length];

                for (int row = 0; row < length; row++) {
                    for (int col = 0; col < length; col++) {
                        loadedAs[0][stream][row][col] = input.nextFloat();
                    }
                }
                for (int row = 0; row < length; row++) {
                    loadedBs[0][stream][row] = input.nextFloat();
                }
                // Third per-stream row is read only to advance past it.
                for (int row = 0; row < length; row++) {
                    input.nextFloat();
                }
            }
        }

        this.As = loadedAs;
        this.Bs = loadedBs;
    }

    /**
     * Re-estimates the transforms of all clusters from the given statistics,
     * replacing whatever transform was held before.
     *
     * @param stats accumulated regression statistics; {@code fillRegLowerPart()}
     *              is invoked on it before its matrices are read
     */
    public void update(Stats stats) {
        // NOTE(review): presumably completes the symmetric regression matrices
        // from their upper triangle - confirm against Stats.
        stats.fillRegLowerPart();
        this.As = new float[this.nrOfClusters][][][];
        this.Bs = new float[this.nrOfClusters][][];
        this.computeMllrTransforms(stats.getRegLs(), stats.getRegRs());
    }
}

