package org.nd4j.remote.clients;

import com.mashape.unirest.http.HttpResponse;
import com.mashape.unirest.http.Unirest;
import com.mashape.unirest.http.exceptions.UnirestException;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import lombok.NonNull;
import org.json.JSONObject;
import org.nd4j.remote.clients.serde.BinaryDeserializer;
import org.nd4j.remote.clients.serde.BinarySerializer;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/remote/clients/JsonRemoteInference.class */
/**
 * Client for a remote inference endpoint. It POSTs a serialized input with Unirest and turns the
 * HTTP response body back into an output object.
 *
 * <p>Exactly one of a JSON serializer or a binary serializer must be supplied; the constructor
 * enforces this. The wire format is chosen from the configured serde objects:
 * <ul>
 *   <li>binary serializer + binary deserializer: {@code application/octet-stream} request; the
 *       raw response bytes are decoded with the binary deserializer;</li>
 *   <li>binary serializer, no binary deserializer: {@code application/octet-stream} request; the
 *       response body is read as a string and decoded with the JSON deserializer;</li>
 *   <li>otherwise: {@code application/json} request built from the JSON serializer's output; the
 *       response is decoded with the JSON deserializer.</li>
 * </ul>
 *
 * @param <I> input type
 * @param <O> output type
 */
public class JsonRemoteInference<I, O> {
    private static final Logger log = LoggerFactory.getLogger(JsonRemoteInference.class);
    private final String endpointAddress;
    private final JsonSerializer<I> serializer;
    private final JsonDeserializer<O> deserializer;
    private final BinarySerializer<I> binarySerializer;
    private final BinaryDeserializer<O> binaryDeserializer;
    private static final String APPLICATION_JSON = "application/json";
    private static final String APPLICATION_OCTET_STREAM = "application/octet-stream";

    /**
     * {@link Future} adapter over a Unirest response future. Lifecycle calls are delegated to the
     * underlying future. A completed HTTP response is converted to the output via
     * {@link #process}, and an {@link IOException} from conversion is rethrown as
     * {@link ExecutionException}.
     *
     * @param <R> body type of the underlying HTTP response
     */
    private abstract class InferenceFuture<R> implements Future<O> {
        private final Future<HttpResponse<R>> unirestFuture;

        InferenceFuture(@NonNull Future<HttpResponse<R>> future) {
            if (future == null) {
                throw new NullPointerException("future is marked non-null but is null");
            }
            this.unirestFuture = future;
        }

        /** Converts a completed HTTP response into the inference output. */
        protected abstract O process(HttpResponse<R> response) throws IOException;

        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            return unirestFuture.cancel(mayInterruptIfRunning);
        }

        @Override
        public boolean isCancelled() {
            return unirestFuture.isCancelled();
        }

        @Override
        public boolean isDone() {
            return unirestFuture.isDone();
        }

        @Override
        public O get() throws InterruptedException, ExecutionException {
            try {
                return process(unirestFuture.get());
            } catch (IOException e) {
                throw new ExecutionException(e);
            }
        }

        @Override
        public O get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
            try {
                return process(unirestFuture.get(timeout, unit));
            } catch (IOException e) {
                throw new ExecutionException(e);
            }
        }
    }

    /**
     * Builder for {@link JsonRemoteInference}. All values are passed straight to the public
     * constructor, which performs validation.
     */
    public static class JsonRemoteInferenceBuilder<I, O> {
        private String endpointAddress;
        private JsonSerializer<I> inputSerializer;
        private JsonDeserializer<O> outputDeserializer;
        private BinarySerializer<I> inputBinarySerializer;
        private BinaryDeserializer<O> outputBinaryDeserializer;

        JsonRemoteInferenceBuilder() {
        }

        public JsonRemoteInferenceBuilder<I, O> endpointAddress(@NonNull String str) {
            if (str == null) {
                throw new NullPointerException("endpointAddress is marked non-null but is null");
            }
            this.endpointAddress = str;
            return this;
        }

        public JsonRemoteInferenceBuilder<I, O> inputSerializer(JsonSerializer<I> jsonSerializer) {
            this.inputSerializer = jsonSerializer;
            return this;
        }

        public JsonRemoteInferenceBuilder<I, O> outputDeserializer(JsonDeserializer<O> jsonDeserializer) {
            this.outputDeserializer = jsonDeserializer;
            return this;
        }

        public JsonRemoteInferenceBuilder<I, O> inputBinarySerializer(BinarySerializer<I> binarySerializer) {
            this.inputBinarySerializer = binarySerializer;
            return this;
        }

        public JsonRemoteInferenceBuilder<I, O> outputBinaryDeserializer(BinaryDeserializer<O> binaryDeserializer) {
            this.outputBinaryDeserializer = binaryDeserializer;
            return this;
        }

        public JsonRemoteInference<I, O> build() {
            return new JsonRemoteInference<>(this.endpointAddress, this.inputSerializer, this.outputDeserializer, this.inputBinarySerializer, this.outputBinaryDeserializer);
        }

        @Override
        public String toString() {
            return "JsonRemoteInference.JsonRemoteInferenceBuilder(endpointAddress=" + this.endpointAddress + ", inputSerializer=" + this.inputSerializer + ", outputDeserializer=" + this.outputDeserializer + ", inputBinarySerializer=" + this.inputBinarySerializer + ", outputBinaryDeserializer=" + this.outputBinaryDeserializer + ")";
        }
    }

    /**
     * Creates a client for the given endpoint.
     *
     * @param str                 endpoint URL that requests are POSTed to; must not be null
     * @param jsonSerializer      JSON input serializer (mutually exclusive with {@code binarySerializer})
     * @param jsonDeserializer    JSON output deserializer
     * @param binarySerializer    binary input serializer (mutually exclusive with {@code jsonSerializer})
     * @param binaryDeserializer  binary output deserializer
     * @throws NullPointerException  if {@code str} is null
     * @throws IllegalStateException if both or neither of the two serializers are supplied
     */
    public JsonRemoteInference(@NonNull String str, JsonSerializer<I> jsonSerializer, JsonDeserializer<O> jsonDeserializer, BinarySerializer<I> binarySerializer, BinaryDeserializer<O> binaryDeserializer) {
        if (str == null) {
            throw new NullPointerException("endpointAddress is marked non-null but is null");
        }
        this.endpointAddress = str;
        this.serializer = jsonSerializer;
        this.deserializer = jsonDeserializer;
        this.binarySerializer = binarySerializer;
        this.binaryDeserializer = binaryDeserializer;
        // Exactly one serializer must be present: "both null" and "both set" are rejected alike.
        if ((this.serializer == null) == (this.binarySerializer == null)) {
            throw new IllegalStateException("Binary and JSON serializers/deserializers are mutually exclusive and mandatory.");
        }
    }

    /**
     * Decodes a string response with the JSON deserializer.
     *
     * <p>NOTE(review): this was private in the original source. It is public only as a decompiler
     * artifact and is kept public to avoid breaking any existing callers.
     *
     * @param httpResponse completed HTTP response with a string body
     * @return the deserialized output, never null
     * @throws IOException if the status is not 200 or deserialization returns null
     */
    public O processResponse(HttpResponse<String> httpResponse) throws IOException {
        if (httpResponse.getStatus() != 200) {
            throw new IOException("Inference request returned bad error code: " + httpResponse.getStatus());
        }
        O result = this.deserializer.deserialize(httpResponse.getBody());
        if (result == null) {
            throw new IOException("Deserialization failed!");
        }
        return result;
    }

    /**
     * Decodes a binary response with the binary deserializer. Exactly Content-Length bytes are read
     * from the body; a single {@code read} call may legally return fewer bytes.
     *
     * @param httpResponse completed HTTP response with a raw byte body
     * @return the deserialized output, never null
     * @throws IOException if the status is not 200, Content-Length is missing or invalid, the body
     *                     ends early, or deserialization returns null
     */
    private O processResponseBinary(HttpResponse<InputStream> httpResponse) throws IOException {
        if (httpResponse.getStatus() != 200) {
            throw new IOException("Inference request returned bad error code: " + httpResponse.getStatus());
        }
        List<String> contentLength = headerValues(httpResponse, "Content-Length");
        if (contentLength == null || contentLength.isEmpty()) {
            throw new IOException("Content-Length is required for binary data");
        }
        String rawLength = contentLength.get(0);
        int length;
        try {
            length = Integer.parseInt(rawLength == null ? "" : rawLength.trim());
        } catch (NumberFormatException e) {
            throw new IOException("Invalid Content-Length: " + rawLength, e);
        }
        if (length < 0) {
            throw new IOException("Invalid Content-Length: " + rawLength);
        }
        byte[] data = new byte[length];
        try (InputStream in = httpResponse.getBody()) {
            if (in == null && length > 0) {
                throw new IOException("Response body is missing but Content-Length is " + length);
            }
            int offset = 0;
            // Loop until the buffer is full: InputStream.read may return a partial chunk.
            while (offset < length) {
                int read = in.read(data, offset, length - offset);
                if (read < 0) {
                    throw new IOException("Unexpected end of response body: read " + offset + " of " + length + " bytes");
                }
                offset += read;
            }
        }
        O result = this.binaryDeserializer.deserialize(data);
        if (result == null) {
            throw new IOException("Deserialization failed!");
        }
        return result;
    }

    /**
     * Looks up a header's values. The exact name is tried first, then a case-insensitive match,
     * because HTTP header names are case-insensitive.
     */
    private static List<String> headerValues(HttpResponse<?> response, String name) {
        List<String> values = response.getHeaders().get(name);
        if (values != null) {
            return values;
        }
        for (String key : response.getHeaders().keySet()) {
            if (key != null && key.equalsIgnoreCase(name)) {
                return response.getHeaders().get(key);
            }
        }
        return null;
    }

    /**
     * Performs a blocking inference request.
     *
     * @param i input to serialize and send
     * @return the deserialized output
     * @throws IOException on transport failure (wrapping {@link UnirestException}), non-200
     *                     status, or failed deserialization
     */
    public O predict(I i) throws IOException {
        try {
            if (this.binarySerializer != null && this.binaryDeserializer != null) {
                HttpResponse<InputStream> response = Unirest.post(this.endpointAddress)
                        .header("Content-Type", APPLICATION_OCTET_STREAM)
                        .header("Accept", APPLICATION_OCTET_STREAM)
                        .body(this.binarySerializer.serialize(i))
                        .asBinary();
                return processResponseBinary(response);
            }
            if (this.binarySerializer != null) {
                // Binary request, JSON-decoded string response.
                // NOTE(review): Accept is octet-stream even though the body is parsed as JSON; kept as-is.
                HttpResponse<String> response = Unirest.post(this.endpointAddress)
                        .header("Content-Type", APPLICATION_OCTET_STREAM)
                        .header("Accept", APPLICATION_OCTET_STREAM)
                        .body(this.binarySerializer.serialize(i))
                        .asString();
                return processResponse(response);
            }
            HttpResponse<String> response = Unirest.post(this.endpointAddress)
                    .header("Content-Type", APPLICATION_JSON)
                    .header("Accept", APPLICATION_JSON)
                    .body(new JSONObject(this.serializer.serialize(i)))
                    .asString();
            return processResponse(response);
        } catch (UnirestException e) {
            throw new IOException(e);
        }
    }

    /**
     * Performs a non-blocking inference request. Request and response encodings follow the same
     * rules as {@link #predict}. Conversion failures surface from {@link Future#get()} as
     * {@link ExecutionException}.
     *
     * @param i input to serialize and send
     * @return a future producing the deserialized output
     */
    public Future<O> predictAsync(I i) {
        if (this.binarySerializer != null && this.binaryDeserializer != null) {
            Future<HttpResponse<InputStream>> binaryFuture = Unirest.post(this.endpointAddress)
                    .header("Content-Type", APPLICATION_OCTET_STREAM)
                    .header("Accept", APPLICATION_OCTET_STREAM)
                    .body(this.binarySerializer.serialize(i))
                    .asBinaryAsync();
            return new InferenceFuture<InputStream>(binaryFuture) {
                @Override
                protected O process(HttpResponse<InputStream> response) throws IOException {
                    return processResponseBinary(response);
                }
            };
        }
        Future<HttpResponse<String>> stringFuture;
        if (this.binarySerializer != null) {
            stringFuture = Unirest.post(this.endpointAddress)
                    .header("Content-Type", APPLICATION_OCTET_STREAM)
                    .header("Accept", APPLICATION_OCTET_STREAM)
                    .body(this.binarySerializer.serialize(i))
                    .asStringAsync();
        } else {
            stringFuture = Unirest.post(this.endpointAddress)
                    .header("Content-Type", APPLICATION_JSON)
                    .header("Accept", APPLICATION_JSON)
                    .body(new JSONObject(this.serializer.serialize(i)))
                    .asStringAsync();
        }
        return new InferenceFuture<String>(stringFuture) {
            @Override
            protected O process(HttpResponse<String> response) throws IOException {
                return processResponse(response);
            }
        };
    }

    /** Returns a new, empty builder. */
    public static <I, O> JsonRemoteInferenceBuilder<I, O> builder() {
        return new JsonRemoteInferenceBuilder<>();
    }
}
