package org.tensorflow;

import java.util.concurrent.atomic.AtomicReferenceArray;
import org.tensorflow.EagerSession;

/**
 * An operation that has been executed eagerly within an {@code EagerSession}.
 *
 * <p>Each instance wraps a native operation handle plus one native tensor handle per output.
 * The {@code Tensor} for a given output is created lazily on first access through
 * {@link #tensor(int)} and cached in {@link #outputTensors}. Once cached, that tensor is also
 * consulted by {@link #shape(int)} and {@link #dtype(int)} instead of querying the native layer.
 *
 * <p>Note: this source was recovered by a decompiler (JADX), which reported that the original
 * access modifier of this class was package-private.
 */
public class EagerOperation extends AbstractOperation {
    private final EagerSession session;
    private final NativeReference nativeRef;
    private final String type;
    private final String name;
    /** Cached output tensors; slot {@code i} stays {@code null} until output {@code i} is resolved. */
    private final AtomicReferenceArray<Tensor<?>> outputTensors;

    /**
     * Holds the native handles of an {@link EagerOperation}: the operation handle and the
     * handles of its output tensors.
     *
     * <p>The session and the owning operation are forwarded to the
     * {@code EagerSession.NativeReference} constructor. Presumably this lets the session release
     * the handles when the operation becomes unreachable or the session closes.
     * NOTE(review): confirm against EagerSession.
     *
     * <p>JADX reported the original access modifier of this class as private.
     */
    public static class NativeReference extends EagerSession.NativeReference {
        /** Native operation handle, or {@code 0} once released. */
        private long opHandle;
        /** Native output tensor handles; each entry is zeroed once released. */
        private final long[] outputHandles;

        NativeReference(
                EagerSession owningSession,
                EagerOperation owningOperation,
                long opNativeHandle,
                long[] outputNativeHandles) {
            super(owningSession, owningOperation);
            opHandle = opNativeHandle;
            outputHandles = outputNativeHandles;
        }

        /**
         * Releases every non-zero output tensor handle, then the operation handle itself.
         *
         * <p>Repeated calls are harmless. Once {@link #opHandle} is zero, the method returns
         * immediately without touching native code.
         */
        @Override // org.tensorflow.EagerSession.NativeReference
        void delete() {
            if (opHandle == 0) {
                return; // already released
            }
            for (int idx = 0; idx < outputHandles.length; idx++) {
                long handle = outputHandles[idx];
                if (handle == 0) {
                    continue;
                }
                EagerOperation.deleteTensorHandle(handle);
                outputHandles[idx] = 0;
            }
            EagerOperation.delete(opHandle);
            opHandle = 0L;
        }
    }

    /**
     * Wraps an eagerly executed operation.
     *
     * <p>JADX reported the original access modifier of this constructor as package-private.
     *
     * @param eagerSession session that executed the operation
     * @param opNativeHandle native handle of the operation
     * @param outputNativeHandles native handles of the operation's outputs. The array is retained
     *     rather than copied, and its entries are zeroed when the native reference is deleted.
     * @param type the operation type
     * @param name the operation name
     */
    public EagerOperation(
            EagerSession eagerSession,
            long opNativeHandle,
            long[] outputNativeHandles,
            String type,
            String name) {
        session = eagerSession;
        this.type = type;
        this.name = name;
        nativeRef = new NativeReference(eagerSession, this, opNativeHandle, outputNativeHandles);
        outputTensors = new AtomicReferenceArray<>(outputNativeHandles.length);
    }

    @Override // org.tensorflow.Operation
    public String name() {
        return name;
    }

    @Override // org.tensorflow.Operation
    public String type() {
        return type;
    }

    /** Returns the number of outputs, i.e. the length of the native output handle array. */
    @Override // org.tensorflow.Operation
    public int numOutputs() {
        return nativeRef.outputHandles.length;
    }

    /** Asks native code, via this operation's handle, for the length of output list {@code listName}. */
    @Override // org.tensorflow.Operation
    public int outputListLength(String listName) {
        return outputListLength(nativeRef.opHandle, listName);
    }

    /** Asks native code, via this operation's handle, for the length of input list {@code listName}. */
    @Override // org.tensorflow.Operation
    public int inputListLength(String listName) {
        return inputListLength(nativeRef.opHandle, listName);
    }

    /**
     * Returns the raw native handle of output {@code outputIndex}.
     *
     * <p>The value is {@code 0} after the native reference has been deleted.
     */
    @Override // org.tensorflow.AbstractOperation
    public long getUnsafeNativeHandle(int outputIndex) {
        return nativeRef.outputHandles[outputIndex];
    }

    /**
     * Returns the shape of output {@code outputIndex}.
     *
     * <p>If the output tensor has already been resolved, its shape is used. Otherwise the
     * dimensions are read one by one from the native output handle.
     */
    @Override // org.tensorflow.AbstractOperation
    public long[] shape(int outputIndex) {
        Tensor<?> resolved = outputTensors.get(outputIndex);
        if (resolved != null) {
            return resolved.shape();
        }
        long handle = getUnsafeNativeHandle(outputIndex);
        int rank = numDims(handle);
        long[] dims = new long[rank];
        for (int axis = 0; axis < rank; axis++) {
            dims[axis] = dim(handle, axis);
        }
        return dims;
    }

    /**
     * Returns the data type of output {@code outputIndex}.
     *
     * <p>The type comes from the resolved tensor when one is cached, and from the native output
     * handle otherwise.
     */
    @Override // org.tensorflow.AbstractOperation
    public DataType dtype(int outputIndex) {
        Tensor<?> resolved = outputTensors.get(outputIndex);
        if (resolved == null) {
            return DataType.fromC(dataType(getUnsafeNativeHandle(outputIndex)));
        }
        return resolved.dataType();
    }

    /** Returns the tensor for output {@code outputIndex}, resolving and caching it on first use. */
    @Override // org.tensorflow.AbstractOperation
    public Tensor<?> tensor(int outputIndex) {
        Tensor<?> cached = outputTensors.get(outputIndex);
        return cached != null ? cached : resolveTensor(outputIndex);
    }

    /**
     * Materializes the tensor for output {@code outputIndex} and tries to publish it into the
     * cache without locking.
     *
     * <p>If another caller has already published a tensor for the same slot, the freshly created
     * one is closed and the previously cached tensor is returned instead.
     */
    private Tensor<?> resolveTensor(int outputIndex) {
        long tensorHandle = resolveTensorHandle(getUnsafeNativeHandle(outputIndex));
        Tensor<?> candidate = Tensor.fromHandle(tensorHandle, session);
        if (outputTensors.compareAndSet(outputIndex, null, candidate)) {
            return candidate;
        }
        // Lost the race: the slot was already filled. Release our copy and reuse the cached one.
        candidate.close();
        return outputTensors.get(outputIndex);
    }

    // JNI entry points. Method names and signatures must stay unchanged because the native
    // library binds to them by name.

    /* JADX reported the original access modifier as private. */
    public static native void delete(long opHandle);

    /* JADX reported the original access modifier as private. */
    public static native void deleteTensorHandle(long tensorHandle);

    private static native long resolveTensorHandle(long tensorHandle);

    private static native int outputListLength(long opHandle, String listName);

    private static native int inputListLength(long opHandle, String listName);

    private static native int dataType(long tensorHandle);

    private static native int numDims(long tensorHandle);

    private static native long dim(long tensorHandle, int axis);
}
