package org.renjin.primitives.matrix;

import com.github.fommil.netlib.BLAS;
import org.jetbrains.kotlin.com.intellij.psi.PsiReferenceRegistrar;
import org.renjin.eval.EvalException;
import org.renjin.primitives.sequence.RepDoubleVector;
import org.renjin.sexp.AtomicVector;
import org.renjin.sexp.AttributeMap;
import org.renjin.sexp.DoubleArrayVector;
import org.renjin.sexp.ListVector;
import org.renjin.sexp.Null;
import org.renjin.sexp.SEXP;
import org.renjin.sexp.Vector;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:WEB-INF/lib/renjin-core-0.9.2726.jar:org/renjin/primitives/matrix/MatrixProduct.class */
/**
 * Evaluates one of three matrix-product operations on a pair of atomic vectors:
 * {@code %*%} ({@link #PROD}), {@code crossprod} ({@link #CROSSPROD}) or
 * {@code tcrossprod} ({@link #TCROSSPROD}).
 *
 * <p>Operands carrying a two-element {@code dim} attribute are used as matrices; operands
 * without one are promoted to a row or column vector according to the operation and the shape
 * of the other operand. Elements are addressed as {@code row + column * rowCount}.
 *
 * <p>If {@code y} is {@link Null#INSTANCE} the product is flagged as symmetrical; for
 * {@code crossprod} and {@code tcrossprod} the {@code x} operand is then used on both sides.
 */
public class MatrixProduct {
    public static final int PROD = 0;
    public static final int CROSSPROD = 1;
    public static final int TCROSSPROD = 2;

    /** Index of the row extent / row names within a dim or dimnames vector. */
    private static final int ROWS = 0;
    /** Index of the column extent / column names within a dim or dimnames vector. */
    private static final int COLS = 1;

    /** Fill value for empty results and the BLAS {@code beta} scaling factor. */
    private static final double ZERO = 0.0d;

    /** Results with more elements than this are returned as a {@code DeferredMatrixProduct}. */
    private static final int EAGER_LENGTH_LIMIT = 500;

    private final int operation;
    private final boolean symmetrical;
    private final AtomicVector x;
    private final AtomicVector y;
    private int nrx = 0;
    private int ncx = 0;
    private int nry = 0;
    private int ncy = 0;
    /** Length of the {@code dim} attribute of {@code x}; 2 means {@code x} is a matrix. */
    private int ldx;
    /** Length of the {@code dim} attribute of {@code y}; 2 means {@code y} is a matrix. */
    private int ldy;
    private final Vector[] operands;

    /**
     * Creates a product and resolves the effective dimensions of both operands.
     *
     * @param operation one of {@link #PROD}, {@link #CROSSPROD} or {@link #TCROSSPROD}
     * @param x the left operand
     * @param y the right operand, or {@link Null#INSTANCE} to request the symmetrical form
     * @throws EvalException with message "non-conformable arguments" if the operand shapes
     *     do not match for the requested operation
     */
    public MatrixProduct(int operation, AtomicVector x, AtomicVector y) {
        this.x = x;
        this.symmetrical = y == Null.INSTANCE;
        // Only crossprod/tcrossprod substitute x for a missing y; for PROD the Null operand is kept.
        this.y = (this.symmetrical && operation > PROD) ? x : y;
        this.operation = operation;
        if (this.symmetrical) {
            this.operands = new Vector[]{x};
        } else {
            this.operands = new Vector[]{x, y};
        }
        computeMatrixDims();
    }

    /**
     * Determines {@code nrx}, {@code ncx}, {@code nry} and {@code ncy} from the operands'
     * {@code dim} attributes (or lengths, for plain vectors) and verifies conformability.
     */
    private void computeMatrixDims() {
        Vector xDim = this.x.getAttributes().getDim();
        Vector yDim = this.y.getAttributes().getDim();
        this.ldx = xDim.length();
        this.ldy = yDim.length();

        if (this.ldx != 2 && this.ldy != 2) {
            // Neither operand is a matrix: y is a column; x is a row for %*% and a column otherwise.
            if (this.operation == PROD) {
                this.nrx = 1;
                this.ncx = this.x.length();
            } else {
                this.nrx = this.x.length();
                this.ncx = 1;
            }
            this.nry = this.y.length();
            this.ncy = 1;
        } else if (this.ldx != 2) {
            this.nry = yDim.getElementAsInt(ROWS);
            this.ncy = yDim.getElementAsInt(COLS);
            promoteXToMatrix();
        } else if (this.ldy != 2) {
            this.nrx = xDim.getElementAsInt(ROWS);
            this.ncx = xDim.getElementAsInt(COLS);
            promoteYToMatrix();
        } else {
            this.nrx = xDim.getElementAsInt(ROWS);
            this.ncx = xDim.getElementAsInt(COLS);
            this.nry = yDim.getElementAsInt(ROWS);
            this.ncy = yDim.getElementAsInt(COLS);
        }

        boolean mismatch = (this.operation == PROD && this.ncx != this.nry)
                || (this.operation == CROSSPROD && this.nrx != this.nry)
                || (this.operation == TCROSSPROD && this.ncx != this.ncy);
        if (mismatch) {
            throw new EvalException("non-conformable arguments", new Object[0]);
        }
    }

    /**
     * Chooses a row- or column-vector shape for a non-matrix {@code x} given the already known
     * dimensions of matrix {@code y}. Leaves {@code x} as 0 x 0 when no shape fits, which the
     * conformability check then rejects unless the matching extent of {@code y} is also zero.
     */
    private void promoteXToMatrix() {
        this.nrx = 0;
        this.ncx = 0;
        int xLength = this.x.length();
        switch (this.operation) {
            case PROD:
                if (xLength == this.nry) {
                    // x as a row vector
                    this.nrx = 1;
                    this.ncx = this.nry;
                } else if (this.nry == 1) {
                    // x as a column vector
                    this.nrx = xLength;
                    this.ncx = 1;
                }
                break;
            case CROSSPROD:
                if (xLength == this.nry) {
                    // x as a column vector
                    this.nrx = this.nry;
                    this.ncx = 1;
                }
                break;
            case TCROSSPROD:
                if (xLength == this.ncy) {
                    // x as a row vector
                    this.nrx = 1;
                    this.ncx = this.ncy;
                } else if (this.ncy == 1) {
                    // x as a column vector
                    this.nrx = xLength;
                    this.ncx = 1;
                }
                break;
            default:
                break;
        }
    }

    /**
     * Chooses a row- or column-vector shape for a non-matrix {@code y} given the already known
     * dimensions of matrix {@code x}. Leaves {@code y} as 0 x 0 when no shape fits.
     */
    private void promoteYToMatrix() {
        this.nry = 0;
        this.ncy = 0;
        int yLength = this.y.length();
        switch (this.operation) {
            case PROD:
                if (yLength == this.ncx) {
                    // y as a column vector
                    this.nry = this.ncx;
                    this.ncy = 1;
                } else if (this.ncx == 1) {
                    // y as a row vector
                    this.nry = 1;
                    this.ncy = yLength;
                }
                break;
            case CROSSPROD:
                if (yLength == this.nrx) {
                    // y as a column vector
                    this.nry = this.nrx;
                    this.ncy = 1;
                }
                break;
            case TCROSSPROD:
                // y is always taken as a column vector here.
                this.nry = yLength;
                this.ncy = 1;
                break;
            default:
                break;
        }
    }

    /**
     * @return the name of the operation: {@code "crossprod"}, {@code "tcrossprod"}, or
     *     {@code "%*%"} for {@link #PROD} and any other operation code
     */
    public String getName() {
        switch (this.operation) {
            case CROSSPROD:
                return "crossprod";
            case TCROSSPROD:
                return "tcrossprod";
            case PROD:
            default:
                return "%*%";
        }
    }

    /**
     * @return the operand array built at construction: {@code {x}} when symmetrical, otherwise
     *     {@code {x, y}}; the internal array itself is returned, not a copy
     */
    public Vector[] getOperands() {
        return this.operands;
    }

    /**
     * @return {@code true} if all four effective dimensions are strictly positive
     */
    public boolean isNonZero() {
        return this.nrx > 0 && this.ncx > 0 && this.nry > 0 && this.ncy > 0;
    }

    /**
     * @return the number of elements in the result matrix for the configured operation
     */
    public int computeLength() {
        switch (this.operation) {
            case CROSSPROD:
                return this.ncx * this.ncy;
            case TCROSSPROD:
                return this.nrx * this.nry;
            case PROD:
            default:
                return this.nrx * this.ncy;
        }
    }

    /**
     * Builds the {@code dim} and {@code dimnames} attributes of the result. For an operation
     * code other than the three known constants an empty attribute map is built.
     *
     * @return the attributes to attach to the result vector
     */
    public AttributeMap computeAttributes() {
        AttributeMap.Builder attributes = new AttributeMap.Builder();
        switch (this.operation) {
            case PROD:
                attributes.setDim(this.nrx, this.ncy);
                attributes.setDimNames(computeDimensionNames(ROWS, COLS));
                break;
            case CROSSPROD:
                attributes.setDim(this.ncx, this.ncy);
                attributes.setDimNames(computeDimensionNames(COLS, COLS));
                break;
            case TCROSSPROD:
                attributes.setDim(this.nrx, this.nry);
                attributes.setDimNames(computeDimensionNames(ROWS, ROWS));
                break;
            default:
                break;
        }
        return attributes.build();
    }

    /**
     * Produces the result vector.
     *
     * <p>If any effective dimension is zero, a constant vector of zeros with the result
     * attributes is returned. If either operand is deferred or the result has more than
     * {@value #EAGER_LENGTH_LIMIT} elements, a {@code DeferredMatrixProduct} wrapping this
     * instance is returned. Otherwise the product is computed immediately.
     *
     * @return the product, or a deferred placeholder for it
     */
    public Vector compute() {
        if (!isNonZero()) {
            return (Vector) RepDoubleVector.createConstantVector(ZERO, computeLength()).setAttributes(computeAttributes());
        }
        if (this.x.isDeferred() || this.y.isDeferred() || computeLength() > EAGER_LENGTH_LIMIT) {
            return new DeferredMatrixProduct(this);
        }
        return DoubleArrayVector.unsafe(computeResult(), computeAttributes());
    }

    /**
     * Assembles the result's {@code dimnames} from the operands' {@code dimnames}.
     *
     * @param xIndex which entry of {@code x}'s dimnames supplies the result's row names
     * @param yIndex which entry of {@code y}'s dimnames supplies the result's column names
     * @return a two-element list, or {@link Null#INSTANCE} if neither selected entry is non-null
     */
    private Vector computeDimensionNames(int xIndex, int yIndex) {
        Vector xDimNames = this.x.getAttributes().getDimNames();
        Vector yDimNames = this.y.getAttributes().getDimNames();
        ListVector.NamedBuilder result = new ListVector.NamedBuilder(2);
        // NOTE(review): mo22871set looks like a decompiler-assigned alias of the builder's
        // element setter; the name is kept as-is so it resolves against the same sources.
        result.mo22871set(ROWS, (SEXP) Null.INSTANCE);
        result.mo22871set(COLS, (SEXP) Null.INSTANCE);
        boolean anyNames = false;

        if (xDimNames != Null.INSTANCE && this.ldx == 2) {
            SEXP rowNames = xDimNames.getElementAsSEXP(xIndex);
            if (rowNames != Null.INSTANCE) {
                anyNames = true;
            }
            // A null entry is still copied when the dimnames list itself is named,
            // so that the name of the dimension is carried over.
            if (rowNames != Null.INSTANCE || xDimNames.hasNames()) {
                result.mo22871set(ROWS, rowNames);
                result.setName(ROWS, xDimNames.getName(xIndex));
            }
        }

        if (yDimNames != Null.INSTANCE && this.ldy == 2) {
            SEXP colNames = yDimNames.getElementAsSEXP(yIndex);
            if (colNames != Null.INSTANCE) {
                anyNames = true;
            }
            if (colNames != Null.INSTANCE || yDimNames.hasNames()) {
                result.mo22871set(COLS, colNames);
                result.setName(COLS, yDimNames.getName(yIndex));
            }
        }

        if (anyNames) {
            return result.build();
        }
        return Null.INSTANCE;
    }

    /** @return the {@code x} operand as a double array. */
    private double[] getXArray() {
        return this.x.toDoubleArray();
    }

    /** @return the {@code y} operand as a double array. */
    private double[] getYArray() {
        return this.y.toDoubleArray();
    }

    /**
     * Computes the product immediately and wraps it with the supplied attributes.
     *
     * @param attributeMap attributes to attach to the result
     * @return the computed product
     */
    public Vector computeResultVector(AttributeMap attributeMap) {
        return DoubleArrayVector.unsafe(computeResult(), attributeMap);
    }

    /**
     * Dispatches to the routine for the configured operation; unknown operation codes are
     * treated as {@link #PROD}.
     *
     * @return the result elements
     */
    double[] computeResult() {
        switch (this.operation) {
            case CROSSPROD:
                return this.symmetrical ? computeSymmetricalCrossProduct() : computeCrossProduct();
            case TCROSSPROD:
                return this.symmetrical ? computeTransposeSymmetricalCrossProduct() : computeTransposeCrossProduct();
            case PROD:
            default:
                return computeMatrixProduct();
        }
    }

    /**
     * Computes {@code x %*%  y}. When either operand contains a NaN within its effective
     * extent the product is accumulated with plain nested loops; otherwise BLAS {@code dgemm}
     * is used.
     */
    private double[] computeMatrixProduct() {
        double[] xValues = getXArray();
        double[] yValues = getYArray();
        double[] result = new double[this.nrx * this.ncy];
        if (!isNonZero()) {
            return result;
        }

        boolean hasNaN = containsNaN(xValues, this.nrx * this.ncx)
                || containsNaN(yValues, this.nry * this.ncy);
        if (hasNaN) {
            for (int row = 0; row < this.nrx; row++) {
                for (int col = 0; col < this.ncy; col++) {
                    double sum = 0.0d;
                    for (int k = 0; k < this.ncx; k++) {
                        sum += xValues[row + (k * this.nrx)] * yValues[k + (col * this.nry)];
                    }
                    result[row + (col * this.nrx)] = sum;
                }
            }
        } else {
            BLAS.getInstance().dgemm("N", "N", this.nrx, this.ncy, this.ncx, 1.0d, xValues, this.nrx, yValues, this.nry, ZERO, result, this.nrx);
        }
        return result;
    }

    /** Computes the non-symmetrical {@code crossprod} via {@code dgemm} with {@code x} transposed. */
    private double[] computeCrossProduct() {
        double[] xValues = getXArray();
        double[] yValues = getYArray();
        double[] result = new double[this.ncx * this.ncy];
        if (isNonZero()) {
            BLAS.getInstance().dgemm("T", "N", this.ncx, this.ncy, this.nrx, 1.0d, xValues, this.nrx, yValues, this.nry, ZERO, result, this.ncx);
        }
        return result;
    }

    /** Computes the non-symmetrical {@code tcrossprod} via {@code dgemm} with {@code y} transposed. */
    private double[] computeTransposeCrossProduct() {
        double[] xValues = getXArray();
        double[] yValues = getYArray();
        double[] result = new double[this.nrx * this.nry];
        if (isNonZero()) {
            BLAS.getInstance().dgemm("N", "T", this.nrx, this.nry, this.ncx, 1.0d, xValues, this.nrx, yValues, this.nry, ZERO, result, this.nrx);
        }
        return result;
    }

    /**
     * Computes the single-operand {@code crossprod} with {@code dsyrk}, requesting the upper
     * triangle, and then copies it into the lower triangle.
     */
    private double[] computeSymmetricalCrossProduct() {
        double[] xValues = getXArray();
        double[] result = new double[this.ncx * this.ncy];
        if (this.nrx > 0 && this.ncx > 0) {
            BLAS.getInstance().dsyrk("U", "T", this.ncx, this.nrx, 1.0d, xValues, this.nrx, ZERO, result, this.ncx);
            mirrorUpperTriangle(result, this.ncx);
        }
        return result;
    }

    /**
     * Computes the single-operand {@code tcrossprod} with {@code dsyrk}, requesting the upper
     * triangle, and then copies it into the lower triangle.
     */
    private double[] computeTransposeSymmetricalCrossProduct() {
        double[] xValues = getXArray();
        double[] result = new double[this.nrx * this.nry];
        if (this.nrx > 0 && this.ncx > 0) {
            BLAS.getInstance().dsyrk("U", "N", this.nrx, this.ncx, 1.0d, xValues, this.nrx, ZERO, result, this.nrx);
            mirrorUpperTriangle(result, this.nrx);
        }
        return result;
    }

    /** Returns whether any of the first {@code count} entries of {@code values} is NaN. */
    private static boolean containsNaN(double[] values, int count) {
        for (int i = 0; i < count; i++) {
            if (Double.isNaN(values[i])) {
                return true;
            }
        }
        return false;
    }

    /** Copies the strict upper triangle of an {@code n} x {@code n} matrix onto its lower triangle. */
    private static void mirrorUpperTriangle(double[] matrix, int n) {
        for (int row = 1; row < n; row++) {
            for (int col = 0; col < row; col++) {
                matrix[row + (n * col)] = matrix[col + (n * row)];
            }
        }
    }
}
