package org.nd4j.linalg.api.ops.impl.accum;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import onnx.OnnxProto3;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/accum/Mmul.class */
/**
 * Matrix-multiplication custom op (op name {@code "mmul"}, imported from TensorFlow/ONNX nodes
 * named {@code "MatMul"}).
 *
 * <p>The optional {@link MMulTranspose} configuration says whether the first input, the second
 * input and/or the result are transposed. When a configuration is supplied at construction time
 * its three flags are also registered as integer arguments, in the order
 * transposeA, transposeB, transposeResult.
 */
public class Mmul extends DynamicCustomOp {
    /**
     * Transpose configuration. May be {@code null} (no-arg constructor, or the array constructor
     * called with a {@code null} configuration); a {@code null} value is treated as
     * {@link MMulTranspose#allFalse()} by {@link #calculateOutputShape()} and {@link #doDiff(List)}.
     */
    protected MMulTranspose mt;

    /**
     * Creates the op inside a SameDiff graph.
     *
     * @param sameDiff      graph that owns the op
     * @param sDVariable    first (left) input
     * @param sDVariable2   second (right) input
     * @param mMulTranspose transpose flags; {@code null} is treated as {@link MMulTranspose#allFalse()}
     */
    public Mmul(SameDiff sameDiff, SDVariable sDVariable, SDVariable sDVariable2, MMulTranspose mMulTranspose) {
        super((String) null, sameDiff, new SDVariable[]{sDVariable, sDVariable2});
        // Default instead of failing with a NullPointerException, mirroring the 3-argument constructor.
        this.mt = mMulTranspose != null ? mMulTranspose : MMulTranspose.allFalse();
        addTransposeArguments(this.mt);
    }

    /**
     * Creates the op inside a SameDiff graph with no transposition.
     *
     * @param sameDiff    graph that owns the op
     * @param sDVariable  first (left) input
     * @param sDVariable2 second (right) input
     */
    public Mmul(SameDiff sameDiff, SDVariable sDVariable, SDVariable sDVariable2) {
        this(sameDiff, sDVariable, sDVariable2, MMulTranspose.allFalse());
    }

    /**
     * Creates the op directly on arrays.
     *
     * @param iNDArray      first (left) input
     * @param iNDArray2     second (right) input
     * @param iNDArray3     output array, or {@code null} to pass no output arrays to the super constructor
     * @param mMulTranspose transpose flags; when {@code null}, {@link #mt} stays {@code null} and
     *                      no integer arguments are registered
     */
    public Mmul(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, MMulTranspose mMulTranspose) {
        super((String) null, new INDArray[]{iNDArray, iNDArray2}, iNDArray3 == null ? null : new INDArray[]{iNDArray3});
        if (mMulTranspose != null) {
            this.mt = mMulTranspose;
            addTransposeArguments(mMulTranspose);
        }
    }

    /** Creates an unconfigured op; {@link #mt} is left {@code null}. */
    public Mmul() {
    }

    /** Registers the three transpose flags as integer arguments (transposeA, transposeB, transposeResult). */
    private void addTransposeArguments(MMulTranspose transpose) {
        addIArgument(ArrayUtil.fromBoolean(transpose.isTransposeA()), ArrayUtil.fromBoolean(transpose.isTransposeB()), ArrayUtil.fromBoolean(transpose.isTransposeResult()));
    }

    /**
     * Returns the shape obtained by swapping the last two dimensions of a rank-2 or rank-3 shape.
     *
     * @param jArr shape of length 2 or 3
     * @return a new array: the reversed shape for length 2, or {@code {d0, d2, d1}} for length 3
     * @throws IllegalArgumentException if the shape length is neither 2 nor 3
     */
    public long[] transposeShapeArray(long[] jArr) {
        if (jArr.length == 2) {
            return ArrayUtil.reverseCopy(jArr);
        }
        if (jArr.length == 3) {
            // The leading dimension is kept in place; only the trailing two are swapped.
            return new long[]{jArr[0], jArr[2], jArr[1]};
        }
        throw new IllegalArgumentException("Matrix input has to be of length 2 or 3, got: " + jArr.length);
    }

    /**
     * Computes the output shape from the shapes of the two inputs and the transpose flags.
     *
     * <p>If {@link #mt} is {@code null} it is set to {@link MMulTranspose#allFalse()} first.
     *
     * @return an empty list if either input shape is a placeholder shape, otherwise a
     *         single-element list holding the result shape
     * @throws ND4JIllegalStateException if any dimension of the computed shape is smaller than 1
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public List<long[]> calculateOutputShape() {
        if (this.mt == null) {
            this.mt = MMulTranspose.allFalse();
        }
        long[] leftShape = larg().getShape();
        long[] rightShape = rarg().getShape();
        if (Shape.isPlaceholderShape(leftShape) || Shape.isPlaceholderShape(rightShape)) {
            return Collections.emptyList();
        }
        long[] effectiveLeft = this.mt.isTransposeA() ? transposeShapeArray(leftShape) : leftShape;
        long[] effectiveRight = this.mt.isTransposeB() ? transposeShapeArray(rightShape) : rightShape;
        long[] resultShape = Shape.getMatrixMultiplyShape(effectiveLeft, effectiveRight);
        if (this.mt.isTransposeResult()) {
            resultShape = transposeShapeArray(resultShape);
        }
        for (int i = 0; i < resultShape.length; i++) {
            if (resultShape[i] < 1) {
                throw new ND4JIllegalStateException("Invalid shape computed at index " + i + ": shape " + Arrays.toString(resultShape));
            }
        }
        return Collections.singletonList(resultShape);
    }

    /** @return the ONNX node name this op is mapped from, {@code "MatMul"} */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public String onnxName() {
        return "MatMul";
    }

    /** @return the TensorFlow node name this op is mapped from, {@code "MatMul"} */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public String tensorflowName() {
        return "MatMul";
    }

    /** @return the op name, {@code "mmul"} */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public String opName() {
        return "mmul";
    }

    /**
     * Initializes the op from a TensorFlow node: reads the {@code transpose_a} and
     * {@code transpose_b} attributes into {@link #mt} and registers every argument that is a
     * placeholder or has no shape yet as a property to resolve later.
     *
     * <p>An attribute that is absent from the map is treated as {@code false}, the same way
     * {@link #initFromOnnx} treats missing attributes.
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff sameDiff, Map<String, AttrValue> map, GraphDef graphDef) {
        super.initFromTensorFlow(nodeDef, sameDiff, map, graphDef);
        this.mt = MMulTranspose.builder()
                .transposeA(tfBooleanAttr(map, "transpose_a"))
                .transposeB(tfBooleanAttr(map, "transpose_b"))
                .build();
        for (SDVariable arg : args()) {
            if (this.sameDiff.isPlaceHolder(arg.getVarName()) || arg.getShape() == null) {
                this.sameDiff.addPropertyToResolve(this, arg.getVarName());
            }
        }
    }

    /** Reads a boolean TensorFlow attribute, returning {@code false} when the attribute is absent. */
    private static boolean tfBooleanAttr(Map<String, AttrValue> attrs, String name) {
        AttrValue attr = attrs.get(name);
        return attr != null && attr.getB();
    }

    /**
     * Initializes the op from an ONNX node: reads the {@code transA} and {@code transB}
     * attributes into {@link #mt}. A missing attribute is treated as {@code false}; a present
     * attribute is {@code true} when its integer value is greater than zero.
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public void initFromOnnx(OnnxProto3.NodeProto nodeProto, SameDiff sameDiff, Map<String, OnnxProto3.AttributeProto> map, OnnxProto3.GraphProto graphProto) {
        this.mt = MMulTranspose.builder()
                .transposeA(onnxFlagAttr(map, "transA"))
                .transposeB(onnxFlagAttr(map, "transB"))
                .build();
    }

    /** Reads an integer ONNX attribute as a flag: {@code true} only when present and greater than zero. */
    private static boolean onnxFlagAttr(Map<String, OnnxProto3.AttributeProto> attrs, String name) {
        OnnxProto3.AttributeProto attr = attrs.get(name);
        return attr != null && attr.getI() > 0;
    }

    /**
     * Builds the gradient variables for both inputs.
     *
     * @param list incoming gradients; only the first element is used
     * @return a two-element list: the gradient for the left input (an mmul of the incoming
     *         gradient with the right input) followed by the gradient for the right input (an
     *         mmul of the left input with the incoming gradient)
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public List<SDVariable> doDiff(List<SDVariable> list) {
        // Same null handling as calculateOutputShape(), but without mutating the field.
        MMulTranspose transpose = this.mt != null ? this.mt : MMulTranspose.allFalse();
        SDVariable gradient = list.get(0);

        MMulTranspose leftConfig = MMulTranspose.builder()
                .transposeA(transpose.isTransposeResult())
                .transposeB(!transpose.isTransposeB())
                .transposeResult(transpose.isTransposeA())
                .build();
        SDVariable leftGradient = this.sameDiff.mmul(gradient, rarg(), leftConfig);

        MMulTranspose rightConfig = MMulTranspose.builder()
                .transposeA(!transpose.isTransposeA())
                .transposeB(transpose.isTransposeResult())
                .transposeResult(transpose.isTransposeB())
                .build();
        SDVariable rightGradient = this.sameDiff.mmul(larg(), gradient, rightConfig);

        List<SDVariable> gradients = new ArrayList<>(2);
        gradients.add(leftGradient);
        gradients.add(rightGradient);
        return gradients;
    }

    /**
     * Describes how the {@code transposeA} / {@code transposeB} properties map to framework
     * attributes ({@code transA}/{@code transB} for ONNX, {@code transpose_a}/{@code transpose_b}
     * for TensorFlow).
     *
     * @return the property mappings keyed by {@link #tensorflowName()} and {@link #onnxName()}
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
        PropertyMapping transposeA = PropertyMapping.builder().onnxAttrName("transA").tfAttrName("transpose_a").propertyNames(new String[]{"transposeA"}).build();
        PropertyMapping transposeB = PropertyMapping.builder().onnxAttrName("transB").tfAttrName("transpose_b").propertyNames(new String[]{"transposeB"}).build();

        Map<String, PropertyMapping> byProperty = new HashMap<>();
        byProperty.put("transposeA", transposeA);
        byProperty.put("transposeB", transposeB);

        Map<String, Map<String, PropertyMapping>> byOpName = new HashMap<>();
        byOpName.put(tensorflowName(), byProperty);
        byOpName.put(onnxName(), byProperty);
        return byOpName;
    }

    /**
     * Two {@code Mmul} instances are equal when their transpose configurations are equal
     * (both {@code null}, or equal according to {@link MMulTranspose#equals(Object)}).
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof Mmul)) {
            return false;
        }
        Mmul other = (Mmul) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        if (this.mt == null) {
            return other.mt == null;
        }
        return this.mt.equals(other.mt);
    }

    /**
     * Hook used by {@link #equals(Object)} so the other object can decide whether it may be
     * compared with this one.
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof Mmul;
    }

    /** Hash code derived solely from the transpose configuration, consistent with {@link #equals(Object)}. */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public int hashCode() {
        return 59 + (this.mt == null ? 43 : this.mt.hashCode());
    }
}
