/*
 * Decompiled with CFR 0.152.
 */
package org.renjin.primitives.matrix;

import com.github.fommil.netlib.BLAS;
import org.renjin.eval.EvalException;
import org.renjin.primitives.matrix.DeferredMatrixProduct;
import org.renjin.primitives.sequence.RepDoubleVector;
import org.renjin.sexp.AtomicVector;
import org.renjin.sexp.AttributeMap;
import org.renjin.sexp.DoubleArrayVector;
import org.renjin.sexp.ListVector;
import org.renjin.sexp.Null;
import org.renjin.sexp.SEXP;
import org.renjin.sexp.Vector;

class MatrixProduct {
    public static final int PROD = 0;
    public static final int CROSSPROD = 1;
    public static final int TCROSSPROD = 2;
    private static final int ROWS = 0;
    private static final int COLS = 1;
    private int operation;
    private boolean symmetrical;
    private AtomicVector x;
    private AtomicVector y;
    private int nrx = 0;
    private int ncx = 0;
    private int nry = 0;
    private int ncy = 0;
    private int ldx;
    private int ldy;
    private Vector[] operands;

    public MatrixProduct(int operation, AtomicVector x, AtomicVector y) {
        this.x = x;
        this.y = y;
        boolean bl = this.symmetrical = y == Null.INSTANCE;
        if (this.symmetrical && operation > 0) {
            this.y = x;
        }
        this.operation = operation;
        this.operands = this.symmetrical ? new Vector[]{x} : new Vector[]{x, y};
        this.computeMatrixDims();
    }

    private void computeMatrixDims() {
        Vector xdims = this.x.getAttributes().getDim();
        Vector ydims = this.y.getAttributes().getDim();
        this.ldx = xdims.length();
        this.ldy = ydims.length();
        if (this.ldx != 2 && this.ldy != 2) {
            if (this.operation == 0) {
                this.nrx = 1;
                this.ncx = this.x.length();
            } else {
                this.nrx = this.x.length();
                this.ncx = 1;
            }
            this.nry = this.y.length();
            this.ncy = 1;
        } else if (this.ldx != 2) {
            this.nry = ydims.getElementAsInt(0);
            this.ncy = ydims.getElementAsInt(1);
            this.nrx = 0;
            this.ncx = 0;
            switch (this.operation) {
                case 0: {
                    if (this.x.length() == this.nry) {
                        this.nrx = 1;
                        this.ncx = this.nry;
                        break;
                    }
                    if (this.nry != 1) break;
                    this.nrx = this.x.length();
                    this.ncx = 1;
                    break;
                }
                case 1: {
                    if (this.x.length() != this.nry) break;
                    this.nrx = this.nry;
                    this.ncx = 1;
                    break;
                }
                case 2: {
                    if (this.x.length() == this.ncy) {
                        this.nrx = 1;
                        this.ncx = this.ncy;
                        break;
                    }
                    if (this.ncy != 1) break;
                    this.nrx = this.x.length();
                    this.ncx = 1;
                }
            }
        } else if (this.ldy != 2) {
            this.nrx = xdims.getElementAsInt(0);
            this.ncx = xdims.getElementAsInt(1);
            this.nry = 0;
            this.ncy = 0;
            switch (this.operation) {
                case 0: {
                    if (this.y.length() == this.ncx) {
                        this.nry = this.ncx;
                        this.ncy = 1;
                        break;
                    }
                    if (this.ncx != 1) break;
                    this.nry = 1;
                    this.ncy = this.y.length();
                    break;
                }
                case 1: {
                    if (this.y.length() != this.nrx) break;
                    this.nry = this.nrx;
                    this.ncy = 1;
                    break;
                }
                case 2: {
                    this.nry = this.y.length();
                    this.ncy = 1;
                }
            }
        } else {
            this.nrx = xdims.getElementAsInt(0);
            this.ncx = xdims.getElementAsInt(1);
            this.nry = ydims.getElementAsInt(0);
            this.ncy = ydims.getElementAsInt(1);
        }
        if (this.operation == 0 && this.ncx != this.nry || this.operation == 1 && this.nrx != this.nry || this.operation == 2 && this.ncx != this.ncy) {
            throw new EvalException("non-conformable arguments", new Object[0]);
        }
    }

    public String getName() {
        switch (this.operation) {
            default: {
                return "%*%";
            }
            case 1: {
                return "crossprod";
            }
            case 2: 
        }
        return "tcrossprod";
    }

    public Vector[] getOperands() {
        return this.operands;
    }

    public boolean isNonZero() {
        return this.nrx > 0 && this.ncx > 0 && this.nry > 0 && this.ncy > 0;
    }

    public int computeLength() {
        switch (this.operation) {
            default: {
                return this.nrx * this.ncy;
            }
            case 1: {
                return this.ncx * this.ncy;
            }
            case 2: 
        }
        return this.nrx * this.nry;
    }

    public AttributeMap computeAttributes() {
        AttributeMap.Builder attributes2 = new AttributeMap.Builder();
        switch (this.operation) {
            case 0: {
                attributes2.setDim(this.nrx, this.ncy);
                attributes2.setDimNames(this.computeDimensionNames(0, 1));
                break;
            }
            case 1: {
                attributes2.setDim(this.ncx, this.ncy);
                attributes2.setDimNames(this.computeDimensionNames(1, 1));
                break;
            }
            case 2: {
                attributes2.setDim(this.nrx, this.nry);
                attributes2.setDimNames(this.computeDimensionNames(0, 0));
            }
        }
        return attributes2.build();
    }

    public Vector compute() {
        if (!this.isNonZero()) {
            return (Vector)RepDoubleVector.createConstantVector(0.0, this.computeLength()).setAttributes(this.computeAttributes());
        }
        if (this.x.isDeferred() || this.y.isDeferred() || this.computeLength() > 500) {
            return new DeferredMatrixProduct(this);
        }
        return DoubleArrayVector.unsafe(this.computeResult(), this.computeAttributes());
    }

    private Vector computeDimensionNames(int rowNamesDim, int colNamesDim) {
        Object rowNames;
        Vector xdims = this.x.getAttributes().getDimNames();
        Vector ydims = this.y.getAttributes().getDimNames();
        ListVector.NamedBuilder dimNames = new ListVector.NamedBuilder(2);
        dimNames.set(0, Null.INSTANCE);
        dimNames.set(1, Null.INSTANCE);
        boolean hasNames = false;
        if (xdims != Null.INSTANCE && this.ldx == 2) {
            rowNames = xdims.getElementAsSEXP(rowNamesDim);
            if (rowNames != Null.INSTANCE) {
                hasNames = true;
            }
            if (rowNames != Null.INSTANCE || xdims.hasNames()) {
                dimNames.set(0, (SEXP)rowNames);
                dimNames.setName(0, xdims.getName(rowNamesDim));
            }
        }
        if (ydims != Null.INSTANCE && this.ldy == 2) {
            rowNames = ydims.getElementAsSEXP(colNamesDim);
            if (rowNames != Null.INSTANCE) {
                hasNames = true;
            }
            if (rowNames != Null.INSTANCE || ydims.hasNames()) {
                dimNames.set(1, (SEXP)rowNames);
                dimNames.setName(1, ydims.getName(colNamesDim));
            }
        }
        if (hasNames) {
            return dimNames.build();
        }
        return Null.INSTANCE;
    }

    private double[] getXArray() {
        return this.x.toDoubleArray();
    }

    private double[] getYArray() {
        return this.y.toDoubleArray();
    }

    Vector computeResultVector(AttributeMap attributes2) {
        return DoubleArrayVector.unsafe(this.computeResult(), attributes2);
    }

    double[] computeResult() {
        switch (this.operation) {
            case 1: {
                if (this.symmetrical) {
                    return this.computeSymmetricalCrossProduct();
                }
                return this.computeCrossProduct();
            }
            case 2: {
                if (this.symmetrical) {
                    return this.computeTransposeSymmetricalCrossProduct();
                }
                return this.computeTransposeCrossProduct();
            }
        }
        return this.computeMatrixProduct();
    }

    private double[] computeMatrixProduct() {
        String transa = "N";
        String transb = "N";
        double one = 1.0;
        double zero = 0.0;
        boolean haveNA = false;
        double[] x = this.getXArray();
        double[] y = this.getYArray();
        double[] z = new double[this.nrx * this.ncy];
        if (this.nrx > 0 && this.ncx > 0 && this.nry > 0 && this.ncy > 0) {
            int i;
            for (i = 0; i < this.nrx * this.ncx; ++i) {
                if (!Double.isNaN(x[i])) continue;
                haveNA = true;
                break;
            }
            if (!haveNA) {
                for (i = 0; i < this.nry * this.ncy; ++i) {
                    if (!Double.isNaN(y[i])) continue;
                    haveNA = true;
                    break;
                }
            }
            if (haveNA) {
                for (i = 0; i < this.nrx; ++i) {
                    for (int k = 0; k < this.ncy; ++k) {
                        double sum2 = 0.0;
                        for (int j = 0; j < this.ncx; ++j) {
                            sum2 += x[i + j * this.nrx] * y[j + k * this.nry];
                        }
                        z[i + k * this.nrx] = sum2;
                    }
                }
            } else {
                BLAS.getInstance().dgemm(transa, transb, this.nrx, this.ncy, this.ncx, one, x, this.nrx, y, this.nry, zero, z, this.nrx);
            }
        }
        return z;
    }

    private double[] computeCrossProduct() {
        double[] x = this.getXArray();
        double[] y = this.getYArray();
        double[] z = new double[this.ncx * this.ncy];
        String transa = "T";
        String transb = "N";
        double one = 1.0;
        double zero = 0.0;
        if (this.nrx > 0 && this.ncx > 0 && this.nry > 0 && this.ncy > 0) {
            BLAS.getInstance().dgemm(transa, transb, this.ncx, this.ncy, this.nrx, one, x, this.nrx, y, this.nry, zero, z, this.ncx);
        }
        return z;
    }

    private double[] computeTransposeCrossProduct() {
        double[] x = this.getXArray();
        double[] y = this.getYArray();
        double[] z = new double[this.nrx * this.nry];
        String transa = "N";
        String transb = "T";
        double one = 1.0;
        double zero = 0.0;
        if (this.nrx > 0 && this.ncx > 0 && this.nry > 0 && this.ncy > 0) {
            BLAS.getInstance().dgemm(transa, transb, this.nrx, this.nry, this.ncx, one, x, this.nrx, y, this.nry, zero, z, this.nrx);
        }
        return z;
    }

    private double[] computeSymmetricalCrossProduct() {
        String trans = "T";
        String uplo = "U";
        double one = 1.0;
        double zero = 0.0;
        double[] x = this.getXArray();
        double[] z = new double[this.ncx * this.ncy];
        if (this.nrx > 0 && this.ncx > 0) {
            BLAS.getInstance().dsyrk(uplo, trans, this.ncx, this.nrx, one, x, this.nrx, zero, z, this.ncx);
            for (int i = 1; i < this.ncx; ++i) {
                for (int j = 0; j < i; ++j) {
                    z[i + this.ncx * j] = z[j + this.ncx * i];
                }
            }
        }
        return z;
    }

    private double[] computeTransposeSymmetricalCrossProduct() {
        double[] x = this.getXArray();
        double[] z = new double[this.nrx * this.nry];
        String trans = "N";
        String uplo = "U";
        double one = 1.0;
        double zero = 0.0;
        if (this.nrx > 0 && this.ncx > 0) {
            BLAS.getInstance().dsyrk(uplo, trans, this.nrx, this.ncx, one, x, this.nrx, zero, z, this.nrx);
            for (int i = 1; i < this.nrx; ++i) {
                for (int j = 0; j < i; ++j) {
                    z[i + this.nrx * j] = z[j + this.nrx * i];
                }
            }
        }
        return z;
    }
}

