package org.nd4j.linalg.cpu.nativecpu.blas;

import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.openblas.global.openblas;
import org.nd4j.linalg.api.blas.BlasException;
import org.nd4j.linalg.api.blas.impl.BaseLapack;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/nd4j/linalg/cpu/nativecpu/blas/CpuLapack.class */
/**
 * CPU (OpenBLAS-backed) implementation of the ND4J LAPACK routines.
 *
 * <p>Every routine hands the backing buffer of its {@link INDArray} arguments directly to the
 * matching {@code LAPACKE_*} function. The matrix layout given to LAPACKE follows the array's
 * ordering ({@code 'f'} means column-major, anything else row-major), and the leading dimension
 * is the row count (column-major) or the column count (row-major).
 *
 * <p>Where a routine takes an {@code info} array, the raw LAPACKE status is written to its first
 * element (when {@code info} is non-null) before any error is raised.
 *
 * <p>NOTE(review): {@code data().addressPointer()} is the start of the backing buffer, so a view
 * offset on any argument is not applied here. Presumably callers pass contiguous, offset-free
 * arrays; confirm against {@code BaseLapack}.
 */
public class CpuLapack extends BaseLapack {

    /** {@code LAPACK_ROW_MAJOR} as defined in lapacke.h. */
    private static final int LAPACK_ROW_MAJOR = 101;

    /** {@code LAPACK_COL_MAJOR} as defined in lapacke.h. */
    private static final int LAPACK_COL_MAJOR = 102;

    /** Fill value used when clearing the unused triangle of a factor. */
    private static final Number ZERO = 0;

    /**
     * Returns the LAPACKE matrix-layout constant for an array.
     *
     * @param arr matrix that will be passed to LAPACKE
     * @return {@code LAPACK_COL_MAJOR} (102) for {@code 'f'}-ordered arrays, otherwise
     *     {@code LAPACK_ROW_MAJOR} (101)
     */
    protected static int getColumnOrder(INDArray arr) {
        return arr.ordering() == 'f' ? LAPACK_COL_MAJOR : LAPACK_ROW_MAJOR;
    }

    /**
     * Returns the leading dimension LAPACKE expects for an array.
     *
     * @param arr matrix that will be passed to LAPACKE
     * @return the row count for {@code 'f'}-ordered arrays, otherwise the column count
     * @throws ND4JArraySizeException if a dimension does not fit in an {@code int}
     */
    protected static int getLda(INDArray arr) {
        // NOTE(review): this guard can only fire if rows()/columns() are long-valued in the nd4j
        // version in use; it is kept so the int casts below stay safe either way.
        if (arr.rows() > Integer.MAX_VALUE || arr.columns() > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        return arr.ordering() == 'f' ? (int) arr.rows() : (int) arr.columns();
    }

    /**
     * LU factorization in single precision via {@code LAPACKE_sgetrf}.
     *
     * <p>{@code a} is overwritten with the L and U factors and {@code ipiv} receives the pivot
     * indices. A positive status (an exactly-zero diagonal entry of U) is not treated as an error;
     * it is only recorded in {@code info} so the caller can detect a singular matrix.
     *
     * @param m number of rows of {@code a}
     * @param n number of columns of {@code a}
     * @param a matrix to factor (float data), overwritten in place
     * @param ipiv output pivot indices (int data)
     * @param info optional output; receives the LAPACKE status
     * @throws BlasException if LAPACKE reports an invalid argument (negative status)
     */
    @Override
    public void sgetrf(int m, int n, INDArray a, INDArray ipiv, INDArray info) {
        int status = openblas.LAPACKE_sgetrf(getColumnOrder(a), m, n,
                (FloatPointer) a.data().addressPointer(), getLda(a),
                (IntPointer) ipiv.data().addressPointer());
        recordInfo(info, status);
        if (status < 0) {
            throw new BlasException("Failed to execute sgetrf", status);
        }
    }

    /**
     * LU factorization in double precision via {@code LAPACKE_dgetrf}.
     *
     * <p>Same contract as {@link #sgetrf}: a positive (singular) status is recorded in
     * {@code info} but not thrown.
     *
     * @param m number of rows of {@code a}
     * @param n number of columns of {@code a}
     * @param a matrix to factor (double data), overwritten in place
     * @param ipiv output pivot indices (int data)
     * @param info optional output; receives the LAPACKE status
     * @throws BlasException if LAPACKE reports an invalid argument (negative status)
     */
    @Override
    public void dgetrf(int m, int n, INDArray a, INDArray ipiv, INDArray info) {
        int status = openblas.LAPACKE_dgetrf(getColumnOrder(a), m, n,
                (DoublePointer) a.data().addressPointer(), getLda(a),
                (IntPointer) ipiv.data().addressPointer());
        recordInfo(info, status);
        if (status < 0) {
            throw new BlasException("Failed to execute dgetrf", status);
        }
    }

    /**
     * QR decomposition in single precision: {@code LAPACKE_sgeqrf} followed by
     * {@code LAPACKE_sorgqr}.
     *
     * <p>After the call {@code a} holds the explicit Q factor. If {@code r} is non-null, it
     * receives the first {@code a.columns()} rows of the {@code geqrf} output, with everything
     * below the diagonal zeroed (the R factor).
     *
     * @param m number of rows of {@code a}
     * @param n number of columns of {@code a}; also the length of the tau workspace
     * @param a input matrix (float data); overwritten with Q
     * @param r optional output for R
     * @param info optional output; receives the status of the last LAPACKE call made
     * @throws BlasException if either LAPACKE call returns a non-zero status
     */
    @Override
    public void sgeqrf(int m, int n, INDArray a, INDArray r, INDArray info) {
        INDArray tau = Nd4j.create(DataType.FLOAT, n);
        int status = openblas.LAPACKE_sgeqrf(getColumnOrder(a), m, n,
                (FloatPointer) a.data().addressPointer(), getLda(a),
                (FloatPointer) tau.data().addressPointer());
        recordInfo(info, status);
        if (status != 0) {
            throw new BlasException("Failed to execute sgeqrf", status);
        }
        // R must be captured before sorgqr overwrites a with Q.
        if (r != null) {
            extractR(a, r);
        }
        status = openblas.LAPACKE_sorgqr(getColumnOrder(a), m, n, n,
                (FloatPointer) a.data().addressPointer(), getLda(a),
                (FloatPointer) tau.data().addressPointer());
        recordInfo(info, status);
        if (status != 0) {
            throw new BlasException("Failed to execute sorgqr", status);
        }
    }

    /**
     * QR decomposition in double precision: {@code LAPACKE_dgeqrf} followed by
     * {@code LAPACKE_dorgqr}. Same contract as {@link #sgeqrf}.
     *
     * @param m number of rows of {@code a}
     * @param n number of columns of {@code a}; also the length of the tau workspace
     * @param a input matrix (double data); overwritten with Q
     * @param r optional output for R
     * @param info optional output; receives the status of the last LAPACKE call made
     * @throws BlasException if either LAPACKE call returns a non-zero status
     */
    @Override
    public void dgeqrf(int m, int n, INDArray a, INDArray r, INDArray info) {
        INDArray tau = Nd4j.create(DataType.DOUBLE, n);
        int status = openblas.LAPACKE_dgeqrf(getColumnOrder(a), m, n,
                (DoublePointer) a.data().addressPointer(), getLda(a),
                (DoublePointer) tau.data().addressPointer());
        recordInfo(info, status);
        if (status != 0) {
            throw new BlasException("Failed to execute dgeqrf", status);
        }
        // R must be captured before dorgqr overwrites a with Q.
        if (r != null) {
            extractR(a, r);
        }
        status = openblas.LAPACKE_dorgqr(getColumnOrder(a), m, n, n,
                (DoublePointer) a.data().addressPointer(), getLda(a),
                (DoublePointer) tau.data().addressPointer());
        recordInfo(info, status);
        if (status != 0) {
            throw new BlasException("Failed to execute dorgqr", status);
        }
    }

    /**
     * Cholesky factorization in single precision via {@code LAPACKE_spotrf}.
     *
     * <p>After a successful factorization the triangle not selected by {@code uplo} is zeroed.
     * For {@code 'U'} that is everything below the diagonal; for anything else it is everything
     * above the diagonal.
     *
     * @param uplo {@code 'U'} or {@code 'L'}, forwarded to LAPACKE
     * @param n order of the matrix
     * @param a matrix to factor (float data), overwritten in place
     * @param info optional output; receives the LAPACKE status
     * @throws BlasException if LAPACKE returns a non-zero status
     */
    @Override
    public void spotrf(byte uplo, int n, INDArray a, INDArray info) {
        int status = openblas.LAPACKE_spotrf(getColumnOrder(a), uplo, n,
                (FloatPointer) a.data().addressPointer(), getLda(a));
        recordInfo(info, status);
        if (status != 0) {
            throw new BlasException("Failed to execute spotrf", status);
        }
        clearUnusedTriangle(uplo, a);
    }

    /**
     * Cholesky factorization in double precision via {@code LAPACKE_dpotrf}. Same contract as
     * {@link #spotrf}.
     *
     * @param uplo {@code 'U'} or {@code 'L'}, forwarded to LAPACKE
     * @param n order of the matrix
     * @param a matrix to factor (double data), overwritten in place
     * @param info optional output; receives the LAPACKE status
     * @throws BlasException if LAPACKE returns a non-zero status
     */
    @Override
    public void dpotrf(byte uplo, int n, INDArray a, INDArray info) {
        int status = openblas.LAPACKE_dpotrf(getColumnOrder(a), uplo, n,
                (DoublePointer) a.data().addressPointer(), getLda(a));
        recordInfo(info, status);
        if (status != 0) {
            throw new BlasException("Failed to execute dpotrf", status);
        }
        clearUnusedTriangle(uplo, a);
    }

    /**
     * Singular value decomposition in single precision via {@code LAPACKE_sgesvd}.
     *
     * @param jobu forwarded to LAPACKE as {@code jobu}
     * @param jobvt forwarded to LAPACKE as {@code jobvt}
     * @param m number of rows of {@code a}
     * @param n number of columns of {@code a}
     * @param a input matrix (float data)
     * @param s output for the singular values
     * @param u optional output for U; when null a null pointer and leading dimension 1 are passed
     * @param vt optional output for V^T; when null a null pointer and leading dimension 1 are
     *     passed
     * @param info optional output; receives the LAPACKE status
     * @throws BlasException if LAPACKE returns a non-zero status
     */
    @Override
    public void sgesvd(byte jobu, byte jobvt, int m, int n, INDArray a, INDArray s, INDArray u,
            INDArray vt, INDArray info) {
        // LAPACKE needs min(m, n) - 1 entries of superb; min(m, n) leaves one spare.
        INDArray superb = Nd4j.create(DataType.FLOAT, new long[] {Math.min(m, n)});
        FloatPointer uPtr = u == null ? null : (FloatPointer) u.data().addressPointer();
        FloatPointer vtPtr = vt == null ? null : (FloatPointer) vt.data().addressPointer();
        int status = openblas.LAPACKE_sgesvd(getColumnOrder(a), jobu, jobvt, m, n,
                (FloatPointer) a.data().addressPointer(), getLda(a),
                (FloatPointer) s.data().addressPointer(),
                uPtr, u == null ? 1 : getLda(u),
                vtPtr, vt == null ? 1 : getLda(vt),
                (FloatPointer) superb.data().addressPointer());
        recordInfo(info, status);
        if (status != 0) {
            throw new BlasException("Failed to execute sgesvd", status);
        }
    }

    /**
     * Singular value decomposition in double precision via {@code LAPACKE_dgesvd}. Same contract
     * as {@link #sgesvd}.
     *
     * @param jobu forwarded to LAPACKE as {@code jobu}
     * @param jobvt forwarded to LAPACKE as {@code jobvt}
     * @param m number of rows of {@code a}
     * @param n number of columns of {@code a}
     * @param a input matrix (double data)
     * @param s output for the singular values
     * @param u optional output for U
     * @param vt optional output for V^T
     * @param info optional output; receives the LAPACKE status
     * @throws BlasException if LAPACKE returns a non-zero status
     */
    @Override
    public void dgesvd(byte jobu, byte jobvt, int m, int n, INDArray a, INDArray s, INDArray u,
            INDArray vt, INDArray info) {
        // LAPACKE needs min(m, n) - 1 entries of superb; min(m, n) leaves one spare.
        INDArray superb = Nd4j.create(DataType.DOUBLE, new long[] {Math.min(m, n)});
        DoublePointer uPtr = u == null ? null : (DoublePointer) u.data().addressPointer();
        DoublePointer vtPtr = vt == null ? null : (DoublePointer) vt.data().addressPointer();
        int status = openblas.LAPACKE_dgesvd(getColumnOrder(a), jobu, jobvt, m, n,
                (DoublePointer) a.data().addressPointer(), getLda(a),
                (DoublePointer) s.data().addressPointer(),
                uPtr, u == null ? 1 : getLda(u),
                vtPtr, vt == null ? 1 : getLda(vt),
                (DoublePointer) superb.data().addressPointer());
        recordInfo(info, status);
        if (status != 0) {
            throw new BlasException("Failed to execute dgesvd", status);
        }
    }

    /**
     * Symmetric eigen-decomposition in single precision.
     *
     * <p>The method first makes a workspace-size query ({@code LAPACKE_ssyev_work} with
     * {@code lwork = -1}). It then allocates a buffer of that size and passes it as the
     * eigenvalue output of {@code LAPACKE_ssyev}. On success, the first {@code n} values are
     * copied into {@code r}.
     *
     * @param jobz forwarded to LAPACKE as {@code jobz}
     * @param uplo forwarded to LAPACKE as {@code uplo}
     * @param n order of the matrix
     * @param a symmetric input matrix (float data), passed to LAPACKE in place
     * @param r output for the eigenvalues
     * @return the LAPACKE status (0 on success); errors are returned, not thrown
     */
    @Override
    public int ssyev(char jobz, char uplo, int n, INDArray a, INDArray r) {
        int lwork;
        // The one-element query pointer is native memory; release it as soon as it is read.
        try (FloatPointer optimalWork = new FloatPointer(1L)) {
            int status = openblas.LAPACKE_ssyev_work(getColumnOrder(a), (byte) jobz, (byte) uplo, n,
                    (FloatPointer) a.data().addressPointer(), getLda(a),
                    (FloatPointer) r.data().addressPointer(), optimalWork, -1);
            if (status != 0) {
                return status;
            }
            lwork = (int) optimalWork.get();
        }
        INDArray work = Nd4j.createArrayFromShapeBuffer(
                Nd4j.getDataBufferFactory().createFloat(lwork),
                Nd4j.getShapeInfoProvider().createShapeInformation(new long[] {lwork}, a.dataType()).getFirst());
        int status = openblas.LAPACKE_ssyev(getColumnOrder(a), (byte) jobz, (byte) uplo, n,
                (FloatPointer) a.data().addressPointer(), getLda(a),
                (FloatPointer) work.data().addressPointer());
        if (status == 0) {
            r.assign(work.get(NDArrayIndex.interval(0, n)));
        }
        return status;
    }

    /**
     * Symmetric eigen-decomposition in double precision. Same contract as {@link #ssyev}.
     *
     * @param jobz forwarded to LAPACKE as {@code jobz}
     * @param uplo forwarded to LAPACKE as {@code uplo}
     * @param n order of the matrix
     * @param a symmetric input matrix (double data), passed to LAPACKE in place
     * @param r output for the eigenvalues
     * @return the LAPACKE status (0 on success); errors are returned, not thrown
     */
    @Override
    public int dsyev(char jobz, char uplo, int n, INDArray a, INDArray r) {
        int lwork;
        // The one-element query pointer is native memory; release it as soon as it is read.
        try (DoublePointer optimalWork = new DoublePointer(1L)) {
            int status = openblas.LAPACKE_dsyev_work(getColumnOrder(a), (byte) jobz, (byte) uplo, n,
                    (DoublePointer) a.data().addressPointer(), getLda(a),
                    (DoublePointer) r.data().addressPointer(), optimalWork, -1);
            if (status != 0) {
                return status;
            }
            lwork = (int) optimalWork.get();
        }
        INDArray work = Nd4j.createArrayFromShapeBuffer(
                Nd4j.getDataBufferFactory().createDouble(lwork),
                Nd4j.getShapeInfoProvider().createShapeInformation(new long[] {lwork}, a.dataType()).getFirst());
        int status = openblas.LAPACKE_dsyev(getColumnOrder(a), (byte) jobz, (byte) uplo, n,
                (DoublePointer) a.data().addressPointer(), getLda(a),
                (DoublePointer) work.data().addressPointer());
        if (status == 0) {
            r.assign(work.get(NDArrayIndex.interval(0, n)));
        }
        return status;
    }

    /**
     * Presumably intended to mirror LAPACK {@code ?getri} (inverse from an LU factorization).
     *
     * <p>NOTE(review): not implemented on this backend; the method returns without reading or
     * writing any argument. Consider throwing {@link UnsupportedOperationException} once it is
     * confirmed that no caller relies on the silent no-op.
     */
    @Override
    public void getri(int N, INDArray A, int lda, int[] IPIV, INDArray WORK, int lwork, int INFO) {
    }

    /** Stores a LAPACKE status into element 0 of {@code info}, if an INFO array was supplied. */
    private static void recordInfo(INDArray info, int status) {
        if (info != null) {
            info.putScalar(0, status);
        }
    }

    /**
     * Copies the first {@code a.columns()} rows of {@code a} into {@code r}, then zeroes the
     * strict lower triangle of {@code r} up to {@code min(a.rows(), a.columns())}.
     */
    private static void extractR(INDArray a, INDArray r) {
        r.assign(a.get(NDArrayIndex.interval(0, a.columns()), NDArrayIndex.all()));
        zeroStrictLower(r, Math.min(a.rows(), a.columns()));
    }

    /** Zeroes the triangle a Cholesky factor does not use: below the diagonal for 'U', above it otherwise. */
    private static void clearUnusedTriangle(byte uplo, INDArray a) {
        if (uplo == 'U') {
            zeroStrictLower(a, Math.min(a.rows(), a.columns()));
        } else {
            zeroStrictUpper(a);
        }
    }

    /** Sets {@code target[i, 0..i)} to zero for every row {@code 1 <= i < limit}. */
    private static void zeroStrictLower(INDArray target, long limit) {
        for (int i = 1; i < limit; i++) {
            target.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.interval(0, i)}, ZERO);
        }
    }

    /** Sets {@code target[i, i+1..columns)} to zero for every row {@code 0 <= i < min(rows, columns - 1)}. */
    private static void zeroStrictUpper(INDArray target) {
        for (int i = 0; i < Math.min(target.rows(), target.columns() - 1); i++) {
            target.put(new INDArrayIndex[] {
                    NDArrayIndex.point(i), NDArrayIndex.interval(i + 1, target.columns())}, ZERO);
        }
    }
}
