package org.nd4j.imports.graphmapper.onnx;

import com.github.os72.protobuf351.Message;
import com.google.common.primitives.Floats;
import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import java.io.BufferedInputStream;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import onnx.OnnxProto3;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.graphmapper.BaseGraphMapper;
import org.nd4j.imports.graphmapper.ImportState;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/nd4j/imports/graphmapper/onnx/OnnxGraphMapper.class */
/**
 * Graph mapper that imports ONNX protobuf graphs ({@link OnnxProto3.GraphProto}) into SameDiff.
 *
 * <p>A shared instance is available through {@link #getInstance()}.
 */
public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, OnnxProto3.NodeProto, OnnxProto3.AttributeProto, OnnxProto3.TypeProto.Tensor> {
    private static final OnnxGraphMapper INSTANCE = new OnnxGraphMapper();

    /** Returns the shared mapper instance. */
    public static OnnxGraphMapper getInstance() {
        return INSTANCE;
    }

    /**
     * Parses an ONNX {@code ModelProto} from {@code inputStream} and appends the text form of every graph
     * node to {@code file}, each followed by a newline. The input stream is not closed by this method.
     * I/O failures are printed and otherwise ignored.
     */
    @Override
    public void dumpBinaryProtoAsText(InputStream inputStream, File file) {
        try {
            appendNodesAsText(OnnxProto3.ModelProto.parseFrom(inputStream), file, "\n");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Maps the ONNX attributes of a node onto the fields of {@code differentialFunction}, using the property
     * mappings the function declares under the key {@code str}.
     *
     * <p>For each mapping, the attribute is looked up by its ONNX attribute name (falling back to the TF
     * attribute name when no ONNX name is declared). If an {@link AttributeAdapter} is registered under
     * {@code str} for the property, the raw attribute value is handed to it; otherwise the value is converted
     * to the field's numeric type where applicable and set via {@code setValueFor}. Mappings without a name,
     * without a matching field, with an absent attribute, or with an unsupported/empty attribute are skipped.
     *
     * @param str                  key into {@code mappingsForFunction()} / {@code attributeAdaptersForFunction()}
     * @param differentialFunction function whose fields are populated
     * @param map                  attributes of the node, keyed by attribute name
     * @param nodeProto            the node being imported (unused here)
     * @param graphProto           the enclosing graph (unused here)
     */
    public void initFunctionFromProperties(String str, DifferentialFunction differentialFunction, Map<String, OnnxProto3.AttributeProto> map, OnnxProto3.NodeProto nodeProto, OnnxProto3.GraphProto graphProto) {
        Map<String, Map<String, PropertyMapping>> allMappings = differentialFunction.mappingsForFunction();
        Map<String, PropertyMapping> propertyMappings = allMappings == null ? null : allMappings.get(str);
        if (propertyMappings == null || propertyMappings.isEmpty()) {
            return;
        }
        Map<String, Field> fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(differentialFunction);
        Map<String, Map<String, AttributeAdapter>> adaptersByOp = differentialFunction.attributeAdaptersForFunction();
        // Adapters are looked up under the same key as the mappings; a missing entry simply means "no adapters".
        Map<String, AttributeAdapter> adapters = adaptersByOp == null ? null : adaptersByOp.get(str);

        for (Map.Entry<String, PropertyMapping> entry : propertyMappings.entrySet()) {
            PropertyMapping mapping = entry.getValue();
            String attrName = mapping.getOnnxAttrName() != null ? mapping.getOnnxAttrName() : mapping.getTfAttrName();
            Field field = fields.get(entry.getKey());
            if (attrName == null || field == null) {
                continue;
            }
            OnnxProto3.AttributeProto attributeProto = map.get(attrName);
            if (attributeProto == null) {
                continue;
            }
            Object value = attributeValue(attributeProto);
            if (value == null) {
                continue;
            }
            AttributeAdapter adapter = adapters == null ? null : adapters.get(entry.getKey());
            if (adapter != null) {
                adapter.mapAttributeFor(value, field, differentialFunction);
            } else {
                differentialFunction.setValueFor(field, coerceToFieldType(value, field.getType()));
            }
        }
    }

    /**
     * Converts an ONNX attribute to the Java value used for property mapping: {@code String} for STRING,
     * {@code Integer} (truncated from the int64 value) for INT, {@code Float} for FLOAT, {@code int[]} for INTS,
     * {@code float[]} for FLOATS and an {@link INDArray} for TENSOR. Returns {@code null} for empty lists and
     * unsupported types.
     */
    private Object attributeValue(OnnxProto3.AttributeProto attributeProto) {
        switch (attributeProto.getType()) {
            case STRING:
                return attributeProto.getS().toStringUtf8();
            case INT:
                return Integer.valueOf((int) attributeProto.getI());
            case FLOAT:
                return Float.valueOf(attributeProto.getF());
            case INTS:
                List<Long> ints = attributeProto.getIntsList();
                return ints.isEmpty() ? null : Ints.toArray(ints);
            case FLOATS:
                List<Float> floats = attributeProto.getFloatsList();
                return floats.isEmpty() ? null : Floats.toArray(floats);
            case TENSOR:
                return mapTensorProto(attributeProto.getT());
            default:
                return null;
        }
    }

    /**
     * Converts a boxed number to the boxed type matching {@code fieldType} (int/long/float/double, primitive
     * or boxed), because reflective field assignment rejects e.g. an {@code Integer} for a {@code Long} field
     * or a {@code Double} for a {@code float} field. Non-numeric values and other field types pass through.
     */
    private static Object coerceToFieldType(Object value, Class<?> fieldType) {
        if (!(value instanceof Number)) {
            return value;
        }
        Number number = (Number) value;
        if (fieldType == int.class || fieldType == Integer.class) {
            return Integer.valueOf(number.intValue());
        }
        if (fieldType == long.class || fieldType == Long.class) {
            return Long.valueOf(number.longValue());
        }
        if (fieldType == float.class || fieldType == Float.class) {
            return Float.valueOf(number.floatValue());
        }
        if (fieldType == double.class || fieldType == Double.class) {
            return Double.valueOf(number.doubleValue());
        }
        return value;
    }

    @Override
    public boolean isOpIgnoreException(OnnxProto3.NodeProto nodeProto) {
        return false;
    }

    /** Returns the nd4j op name of {@code differentialFunction}; the node is not consulted. */
    @Override
    public String getTargetMappingForOp(DifferentialFunction differentialFunction, OnnxProto3.NodeProto nodeProto) {
        return differentialFunction.opName();
    }

    /**
     * Sets one property of {@code differentialFunction}, taken either from an ONNX node attribute (when the
     * mapping declares an ONNX attribute name) or from the array bound to one of the node's inputs (using the
     * mapping's input position; negative positions count from the end of the input list).
     *
     * <p>Supported attribute types are STRING, INT and FLOAT. Input arrays are read as {@code int[]} for
     * {@code int[]} fields, as the first element for integral and floating-point fields. Missing mappings,
     * fields, attributes, inputs or arrays and unsupported types leave the function unchanged.
     *
     * @param str                  outer key into {@code map}
     * @param differentialFunction function whose field is set
     * @param nodeProto            node providing attributes / inputs
     * @param graphProto           enclosing graph (unused here)
     * @param sameDiff             SameDiff instance holding input arrays
     * @param map                  property mappings; inner key is {@link #getTargetMappingForOp}
     */
    public void mapProperty(String str, DifferentialFunction differentialFunction, OnnxProto3.NodeProto nodeProto, OnnxProto3.GraphProto graphProto, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> map) {
        Map<String, PropertyMapping> mappings = map.get(str);
        if (mappings == null) {
            return;
        }
        // NOTE(review): outer key is `str`, inner key is the nd4j op name — confirm this matches how callers build `map`.
        PropertyMapping propertyMapping = mappings.get(getTargetMappingForOp(differentialFunction, nodeProto));
        if (propertyMapping == null) {
            return;
        }
        String[] propertyNames = propertyMapping.getPropertyNames();
        if (propertyNames == null || propertyNames.length == 0) {
            return;
        }
        Field field = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(differentialFunction).get(propertyNames[0]);
        if (field == null) {
            return;
        }

        Object value;
        String onnxAttrName = propertyMapping.getOnnxAttrName();
        if (onnxAttrName != null) {
            OnnxProto3.AttributeProto attributeProto = getAttrMap(nodeProto).get(onnxAttrName);
            if (attributeProto == null) {
                return;
            }
            switch (attributeProto.getType()) {
                case STRING:
                    value = attributeProto.getS().toStringUtf8();
                    break;
                case INT:
                    value = Long.valueOf(attributeProto.getI());
                    break;
                case FLOAT:
                    value = Float.valueOf(attributeProto.getF());
                    break;
                default:
                    // Other attribute types have no scalar property representation here.
                    return;
            }
        } else {
            value = valueFromInput(propertyMapping, field.getType(), nodeProto, sameDiff);
            if (value == null) {
                return;
            }
        }
        differentialFunction.setValueFor(field, coerceToFieldType(value, field.getType()));
    }

    /**
     * Reads a property value from the array bound to the node input selected by the mapping's input position.
     * Returns {@code null} when the position is absent or out of range, the array is not available, or the
     * field type is not {@code int[]}, integral or floating-point.
     */
    private static Object valueFromInput(PropertyMapping propertyMapping, Class<?> fieldType, OnnxProto3.NodeProto nodeProto, SameDiff sameDiff) {
        Integer position = propertyMapping.getTfInputPosition();
        if (position == null) {
            return null;
        }
        int index = position.intValue();
        if (index < 0) {
            // Negative positions are relative to the end of the input list.
            index += nodeProto.getInputCount();
        }
        if (index < 0 || index >= nodeProto.getInputCount()) {
            return null;
        }
        INDArray array = sameDiff.getArrForVarName(nodeProto.getInput(index));
        if (array == null) {
            return null;
        }
        if (fieldType.equals(int[].class)) {
            return array.data().asInt();
        }
        if (fieldType.equals(Integer.TYPE) || fieldType.equals(Long.TYPE) || fieldType.equals(Long.class) || fieldType.equals(Integer.class)) {
            return Integer.valueOf(array.getInt(0));
        }
        if (fieldType.equals(Float.TYPE) || fieldType.equals(Double.TYPE) || fieldType.equals(Float.class) || fieldType.equals(Double.class)) {
            return Double.valueOf(array.getDouble(0L));
        }
        return null;
    }

    /** Returns the first node in {@code graphProto} whose name equals {@code str}, or {@code null}. */
    @Override
    public OnnxProto3.NodeProto getNodeWithNameFromGraph(OnnxProto3.GraphProto graphProto, String str) {
        for (OnnxProto3.NodeProto node : graphProto.getNodeList()) {
            if (node.getName().equals(str)) {
                return node;
            }
        }
        return null;
    }

    @Override
    public boolean isPlaceHolderNode(OnnxProto3.TypeProto.Tensor tensor) {
        return false;
    }

    /**
     * Parses an ONNX {@code ModelProto} from {@code file} and appends the text form of every graph node to
     * {@code file2} (no separator is added between nodes). The input file is closed afterwards.
     * I/O failures are printed and otherwise ignored.
     */
    @Override
    public void dumpBinaryProtoAsText(File file, File file2) {
        try (InputStream in = new BufferedInputStream(new FileInputStream(file))) {
            appendNodesAsText(OnnxProto3.ModelProto.parseFrom(in), file2, "");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Appends {@code node.toString()} followed by {@code separator} for every graph node to {@code outputFile}. */
    private static void appendNodesAsText(OnnxProto3.ModelProto model, File outputFile, String separator) throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(outputFile, true))) {
            for (OnnxProto3.NodeProto node : model.getGraph().getNodeList()) {
                writer.write(node.toString());
                writer.write(separator);
            }
        }
    }

    @Override
    public DifferentialFunction getMappedOp(String str) {
        return DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(str);
    }

    /**
     * Collects a tensor type for every value name in the graph: declared graph inputs and outputs (outputs
     * overwrite inputs with the same name), then node names (index used for unnamed nodes), node inputs and
     * node outputs, each of which gets a dummy {@code [-1, -1]} tensor if not already present.
     */
    @Override
    public Map<String, OnnxProto3.TypeProto.Tensor> variablesForGraph(OnnxProto3.GraphProto graphProto) {
        Map<String, OnnxProto3.TypeProto.Tensor> result = new HashMap<>();
        for (int i = 0; i < graphProto.getInputCount(); i++) {
            result.put(graphProto.getInput(i).getName(), graphProto.getInput(i).getType().getTensorType());
        }
        for (int i = 0; i < graphProto.getOutputCount(); i++) {
            result.put(graphProto.getOutput(i).getName(), graphProto.getOutput(i).getType().getTensorType());
        }
        for (int i = 0; i < graphProto.getNodeCount(); i++) {
            OnnxProto3.NodeProto node = graphProto.getNode(i);
            String nodeName = node.getName().isEmpty() ? String.valueOf(i) : node.getName();
            if (!result.containsKey(nodeName)) {
                addDummyTensor(nodeName, result);
            }
            for (int j = 0; j < node.getInputCount(); j++) {
                if (!result.containsKey(node.getInput(j))) {
                    addDummyTensor(node.getInput(j), result);
                }
            }
            for (int j = 0; j < node.getOutputCount(); j++) {
                if (!result.containsKey(node.getOutput(j))) {
                    addDummyTensor(node.getOutput(j), result);
                }
            }
        }
        return result;
    }

    @Override
    public String translateToSameDiffName(String str, OnnxProto3.NodeProto nodeProto) {
        return null;
    }

    /** Registers {@code str} in {@code map} with a placeholder tensor type of shape {@code [-1, -1]}. */
    protected void addDummyTensor(String str, Map<String, OnnxProto3.TypeProto.Tensor> map) {
        OnnxProto3.TensorShapeProto.Dimension unknown = OnnxProto3.TensorShapeProto.Dimension.newBuilder().setDimValue(-1L).build();
        map.put(str, OnnxProto3.TypeProto.Tensor.newBuilder().setShape(OnnxProto3.TensorShapeProto.newBuilder().addDim(unknown).addDim(unknown).build()).build());
    }

    @Override
    public Message.Builder getNewGraphBuilder() {
        return OnnxProto3.GraphProto.newBuilder();
    }

    /** Parses a serialized ONNX {@code ModelProto} and returns its graph. */
    @Override
    public OnnxProto3.GraphProto parseGraphFrom(byte[] bArr) throws IOException {
        return OnnxProto3.ModelProto.parseFrom(bArr).getGraph();
    }

    /** Parses a serialized ONNX {@code ModelProto} from the stream and returns its graph. */
    @Override
    public OnnxProto3.GraphProto parseGraphFrom(InputStream inputStream) throws IOException {
        return OnnxProto3.ModelProto.parseFrom(inputStream).getGraph();
    }

    /**
     * Instantiates the nd4j op registered for the node's ONNX op type, initializes it from the node and
     * registers it with the import's SameDiff instance. Unnamed nodes are registered under their index in
     * the graph's node list (the same fallback {@link #variablesForGraph} uses).
     *
     * @throws NoOpNameFoundException if no op is registered for the node's op type
     */
    public void mapNodeType(OnnxProto3.NodeProto nodeProto, ImportState<OnnxProto3.GraphProto, OnnxProto3.TypeProto.Tensor> importState) {
        DifferentialFunction template = DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(nodeProto.getOpType());
        if (template == null) {
            throw new NoOpNameFoundException("No op name found " + nodeProto.getOpType());
        }
        SameDiff sameDiff = importState.getSameDiff();
        String name = !nodeProto.getName().isEmpty() ? nodeProto.getName() : String.valueOf(importState.getGraph().getNodeList().indexOf(nodeProto));
        try {
            DifferentialFunction op = template.getClass().newInstance();
            op.setSameDiff(sameDiff);
            op.initFromOnnx(nodeProto, sameDiff, getAttrMap(nodeProto), importState.getGraph());
            sameDiff.putFunctionForId(op.getOwnName(), op);
            sameDiff.setBaseNameForFunctionInstanceId(name, op);
            sameDiff.addVarNameForImport(name);
        } catch (Exception e) {
            // Best-effort import: failures for a single node are reported but do not abort the import.
            e.printStackTrace();
        }
    }

    @Override
    public DataBuffer.Type dataTypeForTensor(OnnxProto3.TypeProto.Tensor tensor) {
        return nd4jTypeFromOnnxType(tensor.getElemType());
    }

    /**
     * Maps an ONNX element type to an nd4j buffer type. INT64 is mapped to {@code INT} (32-bit);
     * types without a mapping yield {@code UNKNOWN}.
     */
    public DataBuffer.Type nd4jTypeFromOnnxType(OnnxProto3.TensorProto.DataType dataType) {
        switch (dataType) {
            case DOUBLE:
                return DataBuffer.Type.DOUBLE;
            case FLOAT:
                return DataBuffer.Type.FLOAT;
            case FLOAT16:
                return DataBuffer.Type.HALF;
            case INT32:
            case INT64:
                return DataBuffer.Type.INT;
            default:
                return DataBuffer.Type.UNKNOWN;
        }
    }

    /**
     * Returns the string value (decoded as UTF-8) of the node attribute named {@code str}.
     *
     * @throws ND4JIllegalStateException if the node has no attribute with that name
     */
    @Override
    public String getAttrValueFromNode(OnnxProto3.NodeProto nodeProto, String str) {
        for (OnnxProto3.AttributeProto attributeProto : nodeProto.getAttributeList()) {
            if (attributeProto.getName().equals(str)) {
                // ByteString.toString() is a debug descriptor, not the content.
                return attributeProto.getS().toStringUtf8();
            }
        }
        throw new ND4JIllegalStateException("No key found for " + str);
    }

    /** Returns the dims of the tensor stored in the attribute. */
    @Override
    public long[] getShapeFromAttribute(OnnxProto3.AttributeProto attributeProto) {
        return Longs.toArray(attributeProto.getT().getDimsList());
    }

    @Override
    public boolean isPlaceHolder(OnnxProto3.TypeProto.Tensor tensor) {
        return false;
    }

    /**
     * Builds an array for the graph initializer named {@code str} from its {@code raw_data} bytes, using the
     * element type and shape declared by {@code tensor}.
     *
     * <p>Only {@code raw_data} is read; initializers that store values in typed fields are not handled here.
     *
     * @return the array, or {@code null} if the graph has no initializer with that name
     * @throws ND4JIllegalStateException if {@code tensor} is not initialized
     */
    @Override
    public INDArray getNDArrayFromTensor(String str, OnnxProto3.TypeProto.Tensor tensor, OnnxProto3.GraphProto graphProto) {
        DataBuffer.Type dataType = dataTypeForTensor(tensor);
        if (!tensor.isInitialized()) {
            throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized");
        }
        OnnxProto3.TensorProto initializer = null;
        for (OnnxProto3.TensorProto candidate : graphProto.getInitializerList()) {
            if (candidate.getName().equals(str)) {
                initializer = candidate;
                break;
            }
        }
        if (initializer == null) {
            return null;
        }
        long[] shape = getShapeFromTensor(tensor);
        return Nd4j.create(Nd4j.createBuffer(rawDataAsDirectBuffer(initializer), dataType, ArrayUtil.prod(shape))).reshape(shape);
    }

    /**
     * Builds an array from a tensor proto's {@code raw_data}, using its own data type and dims.
     *
     * @return the array, or {@code null} if {@code tensorProto} is {@code null}
     */
    public INDArray mapTensorProto(OnnxProto3.TensorProto tensorProto) {
        if (tensorProto == null) {
            return null;
        }
        DataBuffer.Type dataType = nd4jTypeFromOnnxType(tensorProto.getDataType());
        long[] shape = getShapeFromTensor(tensorProto);
        return Nd4j.create(Nd4j.createBuffer(rawDataAsDirectBuffer(tensorProto), dataType, ArrayUtil.prod(shape))).reshape(shape);
    }

    /**
     * Copies the tensor's {@code raw_data} bytes into a new direct buffer in native byte order, positioned at 0.
     * Sized by the source's remaining bytes so a sliced/offset view is copied exactly.
     */
    private static ByteBuffer rawDataAsDirectBuffer(OnnxProto3.TensorProto tensorProto) {
        ByteBuffer source = tensorProto.getRawData().asReadOnlyByteBuffer();
        ByteBuffer direct = ByteBuffer.allocateDirect(source.remaining()).order(ByteOrder.nativeOrder());
        direct.put(source);
        direct.rewind();
        return direct;
    }

    /**
     * Returns the declared shape of {@code tensor}, left-padded with 1s to at least rank 2
     * (rank 1 {@code [d]} becomes {@code [1, d]}, rank 0 becomes {@code [1, 1]}).
     */
    @Override
    public long[] getShapeFromTensor(OnnxProto3.TypeProto.Tensor tensor) {
        OnnxProto3.TensorShapeProto shape = tensor.getShape();
        long[] dims = new long[shape.getDimCount()];
        for (int i = 0; i < dims.length; i++) {
            dims[i] = shape.getDim(i).getDimValue();
        }
        return padToRank2(dims);
    }

    /**
     * Returns the dims of {@code tensorProto}, left-padded with 1s to at least rank 2
     * (rank 1 {@code [d]} becomes {@code [1, d]}, rank 0 becomes {@code [1, 1]}).
     */
    public long[] getShapeFromTensor(OnnxProto3.TensorProto tensorProto) {
        return padToRank2(Longs.toArray(tensorProto.getDimsList()));
    }

    /** Left-pads {@code dims} with 1s to length 2; arrays of length >= 2 are returned unchanged. */
    private static long[] padToRank2(long[] dims) {
        if (dims.length >= 2) {
            return dims;
        }
        long[] padded = new long[2];
        int offset = padded.length - dims.length;
        for (int i = 0; i < padded.length; i++) {
            padded[i] = i < offset ? 1L : dims[i - offset];
        }
        return padded;
    }

    @Override
    public Set<String> opsToIgnore() {
        return Collections.emptySet();
    }

    @Override
    public String getInputFromNode(OnnxProto3.NodeProto nodeProto, int i) {
        return nodeProto.getInput(i);
    }

    @Override
    public int numInputsFor(OnnxProto3.NodeProto nodeProto) {
        return nodeProto.getInputCount();
    }

    /** Same as {@link #getShapeFromAttribute}: the dims of the tensor stored in the attribute. */
    @Override
    public long[] getShapeFromAttr(OnnxProto3.AttributeProto attributeProto) {
        return getShapeFromAttribute(attributeProto);
    }

    /** Returns the node's attributes keyed by name; later duplicates overwrite earlier ones. */
    @Override
    public Map<String, OnnxProto3.AttributeProto> getAttrMap(OnnxProto3.NodeProto nodeProto) {
        Map<String, OnnxProto3.AttributeProto> attributes = new HashMap<>();
        for (OnnxProto3.AttributeProto attribute : nodeProto.getAttributeList()) {
            attributes.put(attribute.getName(), attribute);
        }
        return attributes;
    }

    @Override
    public String getName(OnnxProto3.NodeProto nodeProto) {
        return nodeProto.getName();
    }

    @Override
    public boolean alreadySeen(OnnxProto3.NodeProto nodeProto) {
        return false;
    }

    /** Treats any node whose op type contains {@code "Var"} as a variable node. */
    @Override
    public boolean isVariableNode(OnnxProto3.NodeProto nodeProto) {
        return nodeProto.getOpType().contains("Var");
    }

    @Override
    public boolean shouldSkip(OnnxProto3.NodeProto nodeProto) {
        return false;
    }

    @Override
    public boolean hasShape(OnnxProto3.NodeProto nodeProto) {
        return false;
    }

    @Override
    public long[] getShape(OnnxProto3.NodeProto nodeProto) {
        return null;
    }

    @Override
    public INDArray getArrayFrom(OnnxProto3.NodeProto nodeProto, OnnxProto3.GraphProto graphProto) {
        return null;
    }

    @Override
    public String getOpType(OnnxProto3.NodeProto nodeProto) {
        return nodeProto.getOpType();
    }

    @Override
    public List<OnnxProto3.NodeProto> getNodeList(OnnxProto3.GraphProto graphProto) {
        return graphProto.getNodeList();
    }

    /** Erased-signature override that forwards to the typed {@code mapNodeType}. */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void mapNodeType(Object obj, ImportState importState) {
        mapNodeType((OnnxProto3.NodeProto) obj, (ImportState<OnnxProto3.GraphProto, OnnxProto3.TypeProto.Tensor>) importState);
    }

    /** Erased-signature override that forwards to the typed {@code mapProperty}. */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void mapProperty(String str, DifferentialFunction differentialFunction, Object obj, Object obj2, SameDiff sameDiff, Map map) {
        mapProperty(str, differentialFunction, (OnnxProto3.NodeProto) obj, (OnnxProto3.GraphProto) obj2, sameDiff, (Map<String, Map<String, PropertyMapping>>) map);
    }
}
