package org.renjin.primitives.matrix;

import com.github.fommil.netlib.BLAS;
import org.renjin.eval.EvalException;
import org.renjin.sexp.AtomicVector;
import org.renjin.sexp.AttributeMap;
import org.renjin.sexp.DoubleArrayVector;
import org.renjin.sexp.DoubleVector;
import org.renjin.sexp.ListVector;
import org.renjin.sexp.Null;
import org.renjin.sexp.Symbols;
import org.renjin.sexp.Vector;

/* loaded from: input_file:WEB-INF/lib/renjin-core-0.8.2415.jar:org/renjin/primitives/matrix/MatrixProduct.class */
/**
 * Computes one of three dense matrix products of two atomic vectors, using BLAS
 * where possible: the plain product ({@link #PROD}), the cross product
 * ({@link #CROSSPROD}, i.e. {@code t(x) * y}) or the transposed cross product
 * ({@link #TCROSSPROD}, i.e. {@code x * t(y)}).
 *
 * <p>Operands are read as column-major {@code double} arrays. An operand without a
 * two-element {@code dim} attribute is treated as a row or column vector, whichever
 * makes it conformable with the other operand (see {@link #computeMatrixDims()}).
 * The constructor validates conformability and throws {@link EvalException} when the
 * shapes do not match.
 */
class MatrixProduct {
    /** Operation code: plain product of x and y. */
    public static final int PROD = 0;
    /** Operation code: cross product, transpose of x times y. */
    public static final int CROSSPROD = 1;
    /** Operation code: transposed cross product, x times transpose of y. */
    public static final int TCROSSPROD = 2;

    private AtomicVector x;
    private AtomicVector y;
    private int primop;
    private AtomicVector xdims;
    private AtomicVector ydims;
    /** Length of x's dim attribute; 2 means x is a matrix. */
    private int ldx;
    /** Length of y's dim attribute; 2 means y is a matrix. */
    private int ldy;
    /** True when no second operand was supplied (y was Null). */
    private boolean sym;
    private int nrx = 0;
    private int ncx = 0;
    private int nry = 0;
    private int ncy = 0;
    /** Row and column names of the result; both slots are filled only by {@link #matprod()}. */
    private ListVector.Builder dimnames = new ListVector.Builder(2);

    /**
     * Prepares a product of {@code x} and {@code y}.
     *
     * @param primop one of {@link #PROD}, {@link #CROSSPROD} or {@link #TCROSSPROD}
     * @param x the left operand
     * @param y the right operand; when it is {@code Null.INSTANCE} and {@code primop}
     *          is a cross product, {@code x} is used for both operands
     * @throws EvalException if the operand shapes are not conformable for {@code primop}
     */
    public MatrixProduct(int primop, AtomicVector x, AtomicVector y) {
        this.x = x;
        this.sym = (y == Null.INSTANCE);
        // Only the cross products fall back to x; a plain product keeps the Null operand.
        this.y = (this.sym && primop > PROD) ? x : y;
        this.primop = primop;
        computeMatrixDims();
    }

    /**
     * Derives the row/column counts of both operands from their {@code dim}
     * attributes and checks that they are conformable for the requested operation.
     *
     * <p>A non-matrix operand is given whichever vector orientation fits the other
     * operand; when no orientation fits, its dimensions stay 0 x 0 and the
     * conformability check at the end decides whether that is an error.
     *
     * @throws EvalException with message "non-conformable arguments" on a shape mismatch
     */
    private void computeMatrixDims() {
        this.xdims = (AtomicVector) this.x.getAttribute(Symbols.DIM);
        this.ydims = (AtomicVector) this.y.getAttribute(Symbols.DIM);
        this.ldx = this.xdims.length();
        this.ldy = this.ydims.length();

        boolean xIsMatrix = this.ldx == 2;
        boolean yIsMatrix = this.ldy == 2;

        if (!xIsMatrix && !yIsMatrix) {
            // Neither is a matrix: x is a row vector for PROD, a column vector otherwise;
            // y is always a column vector.
            if (this.primop == PROD) {
                this.nrx = 1;
                this.ncx = this.x.length();
            } else {
                this.nrx = this.x.length();
                this.ncx = 1;
            }
            this.nry = this.y.length();
            this.ncy = 1;
        } else if (!xIsMatrix) {
            // Only y is a matrix: orient x to fit y.
            this.nry = this.ydims.getElementAsInt(0);
            this.ncy = this.ydims.getElementAsInt(1);
            this.nrx = 0;
            this.ncx = 0;
            int xLength = this.x.length();
            if (this.primop == PROD) {
                if (xLength == this.nry) {
                    // x as a row vector
                    this.nrx = 1;
                    this.ncx = this.nry;
                } else if (this.nry == 1) {
                    // x as a column vector
                    this.nrx = xLength;
                    this.ncx = 1;
                }
            } else if (this.primop == CROSSPROD) {
                if (xLength == this.nry) {
                    // x as a column vector
                    this.nrx = this.nry;
                    this.ncx = 1;
                }
            } else if (xLength == this.ncy) {
                // remaining operation codes (TCROSSPROD): x as a row vector
                this.nrx = 1;
                this.ncx = this.ncy;
            } else if (this.ncy == 1) {
                // remaining operation codes (TCROSSPROD): x as a column vector
                this.nrx = xLength;
                this.ncx = 1;
            }
        } else if (!yIsMatrix) {
            // Only x is a matrix: orient y to fit x.
            this.nrx = this.xdims.getElementAsInt(0);
            this.ncx = this.xdims.getElementAsInt(1);
            this.nry = 0;
            this.ncy = 0;
            int yLength = this.y.length();
            if (this.primop == PROD) {
                if (yLength == this.ncx) {
                    // y as a column vector
                    this.nry = this.ncx;
                    this.ncy = 1;
                } else if (this.ncx == 1) {
                    // y as a row vector
                    this.nry = 1;
                    this.ncy = yLength;
                }
            } else if (this.primop == CROSSPROD) {
                if (yLength == this.nrx) {
                    // y as a column vector
                    this.nry = this.nrx;
                    this.ncy = 1;
                }
            } else {
                // remaining operation codes (TCROSSPROD): y is always a column vector
                this.nry = yLength;
                this.ncy = 1;
            }
        } else {
            this.nrx = this.xdims.getElementAsInt(0);
            this.ncx = this.xdims.getElementAsInt(1);
            this.nry = this.ydims.getElementAsInt(0);
            this.ncy = this.ydims.getElementAsInt(1);
        }

        boolean mismatch = (this.primop == PROD && this.ncx != this.nry)
                || (this.primop == CROSSPROD && this.nrx != this.nry)
                || (this.primop == TCROSSPROD && this.ncx != this.ncy);
        if (mismatch) {
            throw new EvalException("non-conformable arguments", new Object[0]);
        }
    }

    /**
     * Computes the plain product of x and y.
     *
     * @return an {@code nrx} x {@code ncy} matrix; row names are taken from x and
     *         column names from y where the operand shapes allow it
     * @throws EvalException if the result would have more elements than a Java array can hold
     */
    public Vector matprod() {
        double[] result = new double[resultLength(this.nrx, this.ncy)];
        matprod(getXArray(), this.nrx, this.ncx, getYArray(), this.nry, this.ncy, result);

        Vector xDimnames = (Vector) this.x.getAttribute(Symbols.DIMNAMES);
        if (xDimnames != Null.INSTANCE && (this.ldx == 2 || this.ncx == 1)) {
            this.dimnames.mo9046set(0, xDimnames.getElementAsSEXP(0));
        }
        ydimsEtcetera();
        return makeMatrix(result, this.nrx, this.ncy);
    }

    /**
     * Wraps a column-major array as a matrix with the given dimensions and the
     * dimnames collected so far.
     */
    private DoubleVector makeMatrix(double[] values, int rows, int cols) {
        AttributeMap.Builder attributes = AttributeMap.builder();
        attributes.setDim(rows, cols);
        attributes.set(Symbols.DIMNAMES, buildDimnames());
        return new DoubleArrayVector(values, attributes.build());
    }

    /**
     * Returns the collected dimnames list, or {@code Null.INSTANCE} when neither the
     * row names nor the column names were set.
     */
    private Vector buildDimnames() {
        ListVector names = this.dimnames.build();
        boolean noRowNames = names.getElementAsSEXP(0) == Null.INSTANCE;
        boolean noColNames = names.getElementAsSEXP(1) == Null.INSTANCE;
        if (noRowNames && noColNames) {
            return Null.INSTANCE;
        }
        return names;
    }

    /**
     * Computes the cross product, transpose of x times y. When no second operand
     * was supplied, the symmetric routine is used on x alone.
     *
     * @return an {@code ncx} x {@code ncy} matrix (no dimnames are propagated)
     * @throws EvalException if the result would have more elements than a Java array can hold
     */
    public DoubleVector crossprod() {
        double[] result = new double[resultLength(this.ncx, this.ncy)];
        if (this.sym) {
            symcrossprod(getXArray(), this.nrx, this.ncx, result);
        } else {
            crossprod(getXArray(), this.nrx, this.ncx, getYArray(), this.nry, this.ncy, result);
        }
        return makeMatrix(result, this.ncx, this.ncy);
    }

    /**
     * Copies column names for the result from y's dimnames: the second dimnames
     * entry when y is a matrix, or the first entry when y is a one-row vector.
     */
    private void ydimsEtcetera() {
        Vector yDimnames = (Vector) this.y.getAttribute(Symbols.DIMNAMES);
        if (yDimnames == Null.INSTANCE) {
            return;
        }
        if (this.ldy == 2) {
            this.dimnames.mo9046set(1, yDimnames.getElementAsSEXP(1));
        } else if (this.nry == 1) {
            this.dimnames.mo9046set(1, yDimnames.getElementAsSEXP(0));
        }
    }

    /**
     * Computes the transposed cross product, x times transpose of y. When no second
     * operand was supplied, the symmetric routine is used on x alone.
     *
     * @return an {@code nrx} x {@code nry} matrix (no dimnames are propagated)
     * @throws EvalException if the result would have more elements than a Java array can hold
     */
    public DoubleVector tcrossprod() {
        double[] result = new double[resultLength(this.nrx, this.nry)];
        if (this.sym) {
            symtcrossprod(getXArray(), this.nrx, this.ncx, result);
        } else {
            tcrossprod(getXArray(), this.nrx, this.ncx, getYArray(), this.nry, this.ncy, result);
        }
        return makeMatrix(result, this.nrx, this.nry);
    }

    /**
     * Returns {@code rows * cols} as an array length.
     *
     * <p>The multiplication is done in {@code long}: a plain {@code int} product
     * wraps silently on overflow, which would allocate a buffer that is too small
     * (or negative-sized) for the BLAS routines to write into.
     *
     * @throws EvalException if the product does not fit in an {@code int}
     */
    private static int resultLength(int rows, int cols) {
        long length = (long) rows * (long) cols;
        if (length > Integer.MAX_VALUE) {
            throw new EvalException(
                    "result of matrix product (" + rows + " x " + cols + ") is too large",
                    new Object[0]);
        }
        return (int) length;
    }

    /** Sets the first {@code count} elements of {@code out} to zero. */
    private static void zero(double[] out, int count) {
        for (int k = 0; k < count; k++) {
            out[k] = 0.0d;
        }
    }

    /** Returns true if any of the first {@code count} elements of {@code values} is NaN. */
    private static boolean containsNaN(double[] values, int count) {
        for (int k = 0; k < count; k++) {
            if (Double.isNaN(values[k])) {
                return true;
            }
        }
        return false;
    }

    /** Copies the strict upper triangle of a column-major {@code n} x {@code n} matrix into its lower triangle. */
    private static void mirrorUpperToLower(double[] matrix, int n) {
        for (int row = 1; row < n; row++) {
            for (int col = 0; col < row; col++) {
                matrix[row + (n * col)] = matrix[col + (n * row)];
            }
        }
    }

    /**
     * Symmetric cross product of the {@code nrow} x {@code ncol} matrix {@code a} with
     * itself, written to the {@code ncol} x {@code ncol} matrix {@code out}. The result
     * is all zeros when either dimension is zero.
     */
    private void symcrossprod(double[] a, int nrow, int ncol, double[] out) {
        if (nrow <= 0 || ncol <= 0) {
            zero(out, ncol * ncol);
            return;
        }
        // dsyrk is asked for the upper triangle only ("U"); the lower one is mirrored afterwards.
        BLAS.getInstance().dsyrk("U", "T", ncol, nrow, 1.0d, a, nrow, 0.0d, out, ncol);
        mirrorUpperToLower(out, ncol);
    }

    private double[] getXArray() {
        return this.x.toDoubleArray();
    }

    private double[] getYArray() {
        return this.y.toDoubleArray();
    }

    /**
     * Plain product of the {@code nrx} x {@code ncx} matrix {@code a} and the
     * {@code nry} x {@code ncy} matrix {@code b}, written column-major to {@code out}.
     *
     * <p>BLAS {@code dgemm} is used unless either operand contains a NaN, in which
     * case a straightforward triple loop is used instead (presumably so that NaN/NA
     * propagation does not depend on the BLAS implementation). The result is all
     * zeros when any dimension is zero.
     */
    private void matprod(double[] a, int nrx, int ncx, double[] b, int nry, int ncy, double[] out) {
        if (nrx <= 0 || ncx <= 0 || nry <= 0 || ncy <= 0) {
            zero(out, nrx * ncy);
            return;
        }

        boolean hasNaN = containsNaN(a, nrx * ncx) || containsNaN(b, nry * ncy);
        if (!hasNaN) {
            BLAS.getInstance().dgemm("N", "N", nrx, ncy, ncx, 1.0d, a, nrx, b, nry, 0.0d, out, nrx);
            return;
        }

        for (int row = 0; row < nrx; row++) {
            for (int col = 0; col < ncy; col++) {
                double sum = 0.0d;
                for (int k = 0; k < ncx; k++) {
                    sum += a[row + (k * nrx)] * b[k + (col * nry)];
                }
                out[row + (col * nrx)] = sum;
            }
        }
    }

    /**
     * Symmetric transposed cross product of the {@code nrow} x {@code ncol} matrix
     * {@code a} with itself, written to the {@code nrow} x {@code nrow} matrix
     * {@code out}. The result is all zeros when either dimension is zero.
     */
    private void symtcrossprod(double[] a, int nrow, int ncol, double[] out) {
        if (nrow <= 0 || ncol <= 0) {
            zero(out, nrow * nrow);
            return;
        }
        // dsyrk is asked for the upper triangle only ("U"); the lower one is mirrored afterwards.
        BLAS.getInstance().dsyrk("U", "N", nrow, ncol, 1.0d, a, nrow, 0.0d, out, nrow);
        mirrorUpperToLower(out, nrow);
    }

    /**
     * Transposed cross product, {@code a} times transpose of {@code b}, written to the
     * {@code nrx} x {@code nry} matrix {@code out}. The result is all zeros when any
     * dimension is zero.
     */
    private void tcrossprod(double[] a, int nrx, int ncx, double[] b, int nry, int ncy, double[] out) {
        if (nrx <= 0 || ncx <= 0 || nry <= 0 || ncy <= 0) {
            zero(out, nrx * nry);
            return;
        }
        BLAS.getInstance().dgemm("N", "T", nrx, nry, ncx, 1.0d, a, nrx, b, nry, 0.0d, out, nrx);
    }

    /**
     * Cross product, transpose of {@code a} times {@code b}, written to the
     * {@code ncx} x {@code ncy} matrix {@code out}. The result is all zeros when any
     * dimension is zero.
     */
    private void crossprod(double[] a, int nrx, int ncx, double[] b, int nry, int ncy, double[] out) {
        if (nrx <= 0 || ncx <= 0 || nry <= 0 || ncy <= 0) {
            zero(out, ncx * ncy);
            return;
        }
        BLAS.getInstance().dgemm("T", "N", ncx, ncy, nrx, 1.0d, a, nrx, b, nry, 0.0d, out, ncx);
    }
}
