package org.tribuo.interop.tensorflow;

import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.GraphOperationBuilder;
import org.tensorflow.Operation;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

/* loaded from: input_file:org/tribuo/interop/tensorflow/TensorflowUtil.class */
public class TensorflowUtil {
    /** Class-wide logger; used by {@code deserialise} to trace each variable as it is loaded. */
    private static final Logger logger = Logger.getLogger(TensorflowUtil.class.getName());
    /** Operation type string that {@code annotateGraph} and {@code serialise} match to find variable operations. */
    public static final String VARIABLE_V2 = "VariableV2";
    /** Operation type passed to the graph's op builder when {@code annotateGraph} creates an assignment operation. */
    public static final String ASSIGN_OP = "Assign";
    /** Suffix appended (after a "/") to a variable's name to form the name of its assignment operation. */
    public static final String ASSIGN_PLACEHOLDER = "Assign_from_Placeholder";
    /** Operation type of the placeholder operations, and the suffix used by {@code generatePlaceholderName}. */
    public static final String PLACEHOLDER = "Placeholder";
    /** Name of the attribute set on each placeholder operation to record the variable's data type. */
    public static final String DTYPE = "dtype";

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.tribuo.interop.tensorflow.TensorflowUtil$1, reason: invalid class name */
    /* loaded from: input_file:org/tribuo/interop/tensorflow/TensorflowUtil$1.class */
    /**
     * Lookup table that translates a {@link DataType} ordinal into the small integer
     * case labels (1-7) used by the {@code switch} statements in
     * {@code convertTensorToArray} and {@code convertTensorToScalar}.
     * <p>
     * NOTE(review): the {@code synthetic} markers and the {@code $SwitchMap$} name suggest
     * this is the decompiled form of the switch map the Java compiler emits for a
     * switch over an enum, rather than hand-written code - confirm before editing it.
     */
    public static /* synthetic */ class AnonymousClass1 {
        // One slot per DataType constant, indexed by ordinal(). Slots that are never
        // assigned below keep the int default of 0, which matches none of the case
        // labels and therefore lands in the default branch of the consuming switches.
        static final /* synthetic */ int[] $SwitchMap$org$tensorflow$DataType = new int[DataType.values().length];

        static {
            // Each assignment is wrapped on its own so that a NoSuchFieldError for one
            // constant is ignored and only leaves that constant unmapped; the remaining
            // assignments still run.
            try {
                $SwitchMap$org$tensorflow$DataType[DataType.FLOAT.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$tensorflow$DataType[DataType.DOUBLE.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$tensorflow$DataType[DataType.INT32.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$org$tensorflow$DataType[DataType.UINT8.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$org$tensorflow$DataType[DataType.STRING.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$org$tensorflow$DataType[DataType.INT64.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
            try {
                $SwitchMap$org$tensorflow$DataType[DataType.BOOL.ordinal()] = 7;
            } catch (NoSuchFieldError e7) {
            }
        }
    }

    /**
     * Allocates a zero-initialised {@code boolean} array with one Java array dimension
     * per element of {@code jArr}.
     * <p>
     * The result is a {@code boolean[]} for one extent, a {@code boolean[][]} for two,
     * and so on up to eight dimensions.
     *
     * @param jArr The extent of each dimension, outermost first; must hold between 1 and 8 values.
     * @return The newly allocated array.
     * @throws IllegalArgumentException If {@code jArr} holds fewer than 1 or more than 8 extents,
     *         or if an extent cannot be represented as an {@code int}.
     * @throws NegativeArraySizeException If an extent is negative.
     */
    public static Object newBooleanArray(long[] jArr) {
        final int rank = jArr.length;
        if (rank < 1 || rank > 8) {
            throw new IllegalArgumentException("Arrays with less than 1 and more than 8 dimensions are not supported.");
        }
        final int[] extents = new int[rank];
        for (int i = 0; i < rank; i++) {
            final long extent = jArr[i];
            // A bare (int) cast would silently wrap an oversized extent and allocate an
            // array of the wrong length, so reject anything outside the int range.
            if (extent != (int) extent) {
                throw new IllegalArgumentException("Dimension " + i + " has size " + extent + ", which does not fit in a Java array.");
            }
            extents[i] = (int) extent;
        }
        return Array.newInstance(boolean.class, extents);
    }

    /**
     * Allocates a zero-initialised {@code byte} array with one Java array dimension
     * per element of {@code jArr}.
     * <p>
     * The result is a {@code byte[]} for one extent, a {@code byte[][]} for two,
     * and so on up to eight dimensions.
     *
     * @param jArr The extent of each dimension, outermost first; must hold between 1 and 8 values.
     * @return The newly allocated array.
     * @throws IllegalArgumentException If {@code jArr} holds fewer than 1 or more than 8 extents,
     *         or if an extent cannot be represented as an {@code int}.
     * @throws NegativeArraySizeException If an extent is negative.
     */
    public static Object newByteArray(long[] jArr) {
        final int rank = jArr.length;
        if (rank < 1 || rank > 8) {
            throw new IllegalArgumentException("Arrays with less than 1 and more than 8 dimensions are not supported.");
        }
        final int[] extents = new int[rank];
        for (int i = 0; i < rank; i++) {
            final long extent = jArr[i];
            // A bare (int) cast would silently wrap an oversized extent and allocate an
            // array of the wrong length, so reject anything outside the int range.
            if (extent != (int) extent) {
                throw new IllegalArgumentException("Dimension " + i + " has size " + extent + ", which does not fit in a Java array.");
            }
            extents[i] = (int) extent;
        }
        return Array.newInstance(byte.class, extents);
    }

    /**
     * Allocates a zero-initialised {@code int} array with one Java array dimension
     * per element of {@code jArr}.
     * <p>
     * The result is an {@code int[]} for one extent, an {@code int[][]} for two,
     * and so on up to eight dimensions.
     *
     * @param jArr The extent of each dimension, outermost first; must hold between 1 and 8 values.
     * @return The newly allocated array.
     * @throws IllegalArgumentException If {@code jArr} holds fewer than 1 or more than 8 extents,
     *         or if an extent cannot be represented as an {@code int}.
     * @throws NegativeArraySizeException If an extent is negative.
     */
    public static Object newIntArray(long[] jArr) {
        final int rank = jArr.length;
        if (rank < 1 || rank > 8) {
            throw new IllegalArgumentException("Arrays with less than 1 and more than 8 dimensions are not supported.");
        }
        final int[] extents = new int[rank];
        for (int i = 0; i < rank; i++) {
            final long extent = jArr[i];
            // A bare (int) cast would silently wrap an oversized extent and allocate an
            // array of the wrong length, so reject anything outside the int range.
            if (extent != (int) extent) {
                throw new IllegalArgumentException("Dimension " + i + " has size " + extent + ", which does not fit in a Java array.");
            }
            extents[i] = (int) extent;
        }
        return Array.newInstance(int.class, extents);
    }

    /**
     * Allocates a zero-initialised {@code long} array with one Java array dimension
     * per element of {@code jArr}.
     * <p>
     * The result is a {@code long[]} for one extent, a {@code long[][]} for two,
     * and so on up to eight dimensions.
     *
     * @param jArr The extent of each dimension, outermost first; must hold between 1 and 8 values.
     * @return The newly allocated array.
     * @throws IllegalArgumentException If {@code jArr} holds fewer than 1 or more than 8 extents,
     *         or if an extent cannot be represented as an {@code int}.
     * @throws NegativeArraySizeException If an extent is negative.
     */
    public static Object newLongArray(long[] jArr) {
        final int rank = jArr.length;
        if (rank < 1 || rank > 8) {
            throw new IllegalArgumentException("Arrays with less than 1 and more than 8 dimensions are not supported.");
        }
        final int[] extents = new int[rank];
        for (int i = 0; i < rank; i++) {
            final long extent = jArr[i];
            // A bare (int) cast would silently wrap an oversized extent and allocate an
            // array of the wrong length, so reject anything outside the int range.
            if (extent != (int) extent) {
                throw new IllegalArgumentException("Dimension " + i + " has size " + extent + ", which does not fit in a Java array.");
            }
            extents[i] = (int) extent;
        }
        return Array.newInstance(long.class, extents);
    }

    /**
     * Allocates a zero-initialised {@code float} array with one Java array dimension
     * per element of {@code jArr}.
     * <p>
     * The result is a {@code float[]} for one extent, a {@code float[][]} for two,
     * and so on up to eight dimensions.
     *
     * @param jArr The extent of each dimension, outermost first; must hold between 1 and 8 values.
     * @return The newly allocated array.
     * @throws IllegalArgumentException If {@code jArr} holds fewer than 1 or more than 8 extents,
     *         or if an extent cannot be represented as an {@code int}.
     * @throws NegativeArraySizeException If an extent is negative.
     */
    public static Object newFloatArray(long[] jArr) {
        final int rank = jArr.length;
        if (rank < 1 || rank > 8) {
            throw new IllegalArgumentException("Arrays with less than 1 and more than 8 dimensions are not supported.");
        }
        final int[] extents = new int[rank];
        for (int i = 0; i < rank; i++) {
            final long extent = jArr[i];
            // A bare (int) cast would silently wrap an oversized extent and allocate an
            // array of the wrong length, so reject anything outside the int range.
            if (extent != (int) extent) {
                throw new IllegalArgumentException("Dimension " + i + " has size " + extent + ", which does not fit in a Java array.");
            }
            extents[i] = (int) extent;
        }
        return Array.newInstance(float.class, extents);
    }

    /**
     * Allocates a zero-initialised {@code double} array with one Java array dimension
     * per element of {@code jArr}.
     * <p>
     * The result is a {@code double[]} for one extent, a {@code double[][]} for two,
     * and so on up to eight dimensions.
     *
     * @param jArr The extent of each dimension, outermost first; must hold between 1 and 8 values.
     * @return The newly allocated array.
     * @throws IllegalArgumentException If {@code jArr} holds fewer than 1 or more than 8 extents,
     *         or if an extent cannot be represented as an {@code int}.
     * @throws NegativeArraySizeException If an extent is negative.
     */
    public static Object newDoubleArray(long[] jArr) {
        final int rank = jArr.length;
        if (rank < 1 || rank > 8) {
            throw new IllegalArgumentException("Arrays with less than 1 and more than 8 dimensions are not supported.");
        }
        final int[] extents = new int[rank];
        for (int i = 0; i < rank; i++) {
            final long extent = jArr[i];
            // A bare (int) cast would silently wrap an oversized extent and allocate an
            // array of the wrong length, so reject anything outside the int range.
            if (extent != (int) extent) {
                throw new IllegalArgumentException("Dimension " + i + " has size " + extent + ", which does not fit in a Java array.");
            }
            extents[i] = (int) extent;
        }
        return Array.newInstance(double.class, extents);
    }

    /**
     * Closes every tensor in the supplied list, in iteration order.
     *
     * @param list The tensors to close.
     */
    public static void closeTensorList(List<Tensor<?>> list) {
        for (Tensor<?> tensor : list) {
            tensor.close();
        }
    }

    /**
     * Copies the contents of a tensor into a newly allocated Java primitive array
     * whose dimensions match the tensor's shape.
     * <p>
     * The element type of the array is chosen from the tensor's data type:
     * FLOAT to {@code float}, DOUBLE to {@code double}, INT32 to {@code int},
     * UINT8 and STRING to {@code byte}, INT64 to {@code long} and BOOL to {@code boolean}.
     *
     * @param tensor The tensor to copy; its shape must have between 1 and 8 dimensions.
     * @return The populated array.
     * @throws IllegalArgumentException If the tensor's data type is not one of those listed
     *         above, or if its shape is rejected by the array factory methods.
     */
    public static Object convertTensorToArray(Tensor<?> tensor) {
        final long[] shape = tensor.shape();
        final Object array;
        // Switch on the enum itself instead of indexing a lookup table by ordinal().
        switch (tensor.dataType()) {
            case FLOAT:
                array = newFloatArray(shape);
                break;
            case DOUBLE:
                array = newDoubleArray(shape);
                break;
            case INT32:
                array = newIntArray(shape);
                break;
            case UINT8:
            case STRING:
                // NOTE(review): STRING shares the byte-array path with UINT8 - confirm that
                // Tensor.copyTo accepts a byte array of this rank for string tensors.
                array = newByteArray(shape);
                break;
            case INT64:
                array = newLongArray(shape);
                break;
            case BOOL:
                array = newBooleanArray(shape);
                break;
            default:
                throw new IllegalArgumentException("Tribuo can't serialise Tensors with type " + tensor.dataType());
        }
        tensor.copyTo(array);
        return array;
    }

    /**
     * Extracts the single value held by a tensor as a boxed Java object.
     * <p>
     * The returned type depends on the tensor's data type: FLOAT gives a {@link Float},
     * DOUBLE a {@link Double}, INT32 an {@link Integer}, UINT8 a {@link Byte},
     * STRING the result of {@code bytesValue()}, INT64 a {@link Long} and
     * BOOL a {@link Boolean}.
     *
     * @param tensor The tensor to read.
     * @return The tensor's value.
     * @throws IllegalArgumentException If the tensor's data type is not one of those listed above.
     */
    public static Object convertTensorToScalar(Tensor<?> tensor) {
        // Switch on the enum itself instead of indexing a lookup table by ordinal().
        switch (tensor.dataType()) {
            case FLOAT:
                return Float.valueOf(tensor.floatValue());
            case DOUBLE:
                return Double.valueOf(tensor.doubleValue());
            case INT32:
                return Integer.valueOf(tensor.intValue());
            case UINT8:
                // Keep only the low eight bits of the int before narrowing to a byte.
                return Byte.valueOf((byte) (tensor.intValue() & 0xFF));
            case STRING:
                return tensor.bytesValue();
            case INT64:
                return Long.valueOf(tensor.longValue());
            case BOOL:
                return Boolean.valueOf(tensor.booleanValue());
            default:
                throw new IllegalArgumentException("Tribuo can't serialise Tensors with type " + tensor.dataType());
        }
    }

    /**
     * Adds, for every variable operation in the graph, a placeholder operation and an
     * assignment operation that writes the placeholder's value into the variable.
     * <p>
     * Variables are the operations whose type equals {@link #VARIABLE_V2}. Each one is
     * fetched through the session so the placeholder can be given the variable's data
     * type. The placeholder is named by {@link #generatePlaceholderName(String)} and the
     * assignment operation is named {@code <variable name>/Assign_from_Placeholder};
     * {@code deserialise} feeds and targets operations with these names.
     *
     * @param graph The graph to scan and extend.
     * @param session The session used to fetch the current value of each variable.
     * @throws IllegalStateException If the session returns a different number of tensors
     *         than the number of variables requested.
     */
    public static void annotateGraph(Graph graph, Session session) {
        // Keep the operations themselves: both the name and the output handle are needed below.
        List<Operation> variables = new ArrayList<>();
        Iterator<Operation> graphOps = graph.operations();
        while (graphOps.hasNext()) {
            Operation candidate = graphOps.next();
            if (candidate.type().equals(VARIABLE_V2)) {
                variables.add(candidate);
            }
        }

        Session.Runner runner = session.runner();
        for (Operation variable : variables) {
            runner.fetch(variable.name());
        }
        List<Tensor<?>> values = runner.run();
        // The fetched tensors must be released on every path, including when building
        // one of the new operations throws part-way through the loop.
        try {
            if (values.size() != variables.size()) {
                throw new IllegalStateException("Failed to annotate all requested variables. Requested " + variables.size() + ", found " + values.size());
            }
            for (int i = 0; i < values.size(); i++) {
                Operation variable = variables.get(i);
                String variableName = variable.name();

                GraphOperationBuilder placeholderBuilder = graph.opBuilder(PLACEHOLDER, generatePlaceholderName(variableName));
                placeholderBuilder.setAttr(DTYPE, values.get(i).dataType());
                Operation placeholder = placeholderBuilder.build();

                GraphOperationBuilder assignBuilder = graph.opBuilder(ASSIGN_OP, variableName + "/" + ASSIGN_PLACEHOLDER);
                assignBuilder.addInput(variable.output(0));
                assignBuilder.addInput(placeholder.output(0));
                assignBuilder.build();
            }
        } finally {
            closeTensorList(values);
        }
    }

    /**
     * Builds the name of the placeholder operation paired with a variable, formed as
     * the variable name followed by a hyphen and {@link #PLACEHOLDER}.
     *
     * @param str The variable name.
     * @return The placeholder operation name.
     */
    public static String generatePlaceholderName(String str) {
        final String suffix = "-" + PLACEHOLDER;
        return str + suffix;
    }

    /**
     * Reads the current value of every variable operation in the graph and returns
     * the values keyed by variable name.
     * <p>
     * Variables are the operations whose type equals {@link #VARIABLE_V2}. A tensor with
     * zero dimensions is converted with {@link #convertTensorToScalar(Tensor)}; any other
     * tensor is converted with {@link #convertTensorToArray(Tensor)}.
     *
     * @param graph The graph to scan for variables.
     * @param session The session used to fetch the variable values.
     * @return A map from variable name to its value as a boxed scalar or primitive array.
     * @throws IllegalStateException If the session returns a different number of tensors
     *         than the number of variables requested.
     * @throws IllegalArgumentException If a variable's tensor cannot be converted.
     */
    public static Map<String, Object> serialise(Graph graph, Session session) {
        List<String> variableNames = new ArrayList<>();
        Iterator<Operation> graphOps = graph.operations();
        while (graphOps.hasNext()) {
            Operation candidate = graphOps.next();
            if (candidate.type().equals(VARIABLE_V2)) {
                variableNames.add(candidate.name());
            }
        }

        Session.Runner runner = session.runner();
        for (String variableName : variableNames) {
            runner.fetch(variableName);
        }
        List<Tensor<?>> values = runner.run();
        // The fetched tensors must be released on every path, including when one of
        // the conversions rejects a tensor's type or shape.
        try {
            if (values.size() != variableNames.size()) {
                throw new IllegalStateException("Failed to serialise all requested variables. Requested " + variableNames.size() + ", found " + values.size());
            }
            Map<String, Object> serialised = new HashMap<>();
            for (int i = 0; i < variableNames.size(); i++) {
                Tensor<?> value = values.get(i);
                Object converted = value.numDimensions() == 0 ? convertTensorToScalar(value) : convertTensorToArray(value);
                serialised.put(variableNames.get(i), converted);
            }
            return serialised;
        } finally {
            closeTensorList(values);
        }
    }

    /**
     * Writes the supplied values into the session's variables.
     * <p>
     * For each map entry a tensor is created from the value, fed to the placeholder
     * named by {@link #generatePlaceholderName(String)}, and the operation named
     * {@code <key>/Assign_from_Placeholder} is added as a target. All assignments are
     * then executed in a single run. These are the operation names that
     * {@code annotateGraph} creates.
     *
     * @param session The session whose variables are to be set.
     * @param map The values to load, keyed by variable name.
     */
    public static void deserialise(Session session, Map<String, Object> map) {
        Session.Runner runner = session.runner();
        List<Tensor<?>> created = new ArrayList<>();
        // Every tensor created here must be released, even if a later entry cannot be
        // converted to a tensor or the run itself fails.
        try {
            for (Map.Entry<String, Object> entry : map.entrySet()) {
                logger.log(Level.FINEST, "Loading " + entry.getKey() + " of type " + entry.getValue().getClass().getName());
                Tensor<?> value = Tensor.create(entry.getValue());
                // Track the tensor before handing it to the runner so it is closed on failure too.
                created.add(value);
                runner.feed(generatePlaceholderName(entry.getKey()), value);
                runner.addTarget(entry.getKey() + "/" + ASSIGN_PLACEHOLDER);
            }
            runner.run();
        } finally {
            closeTensorList(created);
        }
    }
}
