/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.client;

import java.util.function.Function;
import lombok.Generated;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionResponse;
import org.opensearch.action.ActionType;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.node.NodeClient;
import org.opensearch.ml.client.MachineLearningClient;
import org.opensearch.ml.common.parameter.MLInput;
import org.opensearch.ml.common.parameter.MLModel;
import org.opensearch.ml.common.parameter.MLOutput;
import org.opensearch.ml.common.parameter.MLTask;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
import org.opensearch.ml.common.transport.model.MLModelGetResponse;
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.transport.task.MLTaskDeleteAction;
import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest;
import org.opensearch.ml.common.transport.task.MLTaskGetAction;
import org.opensearch.ml.common.transport.task.MLTaskGetRequest;
import org.opensearch.ml.common.transport.task.MLTaskGetResponse;
import org.opensearch.ml.common.transport.task.MLTaskSearchAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;

/**
 * {@link MachineLearningClient} implementation that dispatches every operation as a transport
 * action through a {@link NodeClient}.
 *
 * <p>All operations are asynchronous: results and failures are delivered to the supplied
 * {@link ActionListener}. The train/predict methods validate their {@link MLInput} first and throw
 * {@link IllegalArgumentException} synchronously, before any action is dispatched.
 */
public class MachineLearningNodeClient implements MachineLearningClient {
    private final NodeClient client;

    /**
     * Runs a prediction with an existing model.
     *
     * @param modelId  id of the model to predict with
     * @param mlInput  prediction input; must be non-null and carry an input data set
     * @param listener receives the prediction output, or the failure
     * @throws IllegalArgumentException if {@code mlInput} or its input data set is null
     */
    @Override
    public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
        validateMLInput(mlInput, true);
        MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.builder()
                .mlInput(mlInput)
                .modelId(modelId)
                .build();
        client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest,
                getMlPredictionTaskResponseActionListener(listener));
    }

    /**
     * Trains a model and predicts with it in a single action.
     *
     * @param mlInput  training/prediction input; must be non-null and carry an input data set
     * @param listener receives the prediction output, or the failure
     * @throws IllegalArgumentException if {@code mlInput} or its input data set is null
     */
    @Override
    public void trainAndPredict(MLInput mlInput, ActionListener<MLOutput> listener) {
        validateMLInput(mlInput, true);
        MLTrainingTaskRequest request = MLTrainingTaskRequest.builder().mlInput(mlInput).build();
        client.execute(MLTrainAndPredictionTaskAction.INSTANCE, request,
                getMlPredictionTaskResponseActionListener(listener));
    }

    /**
     * Trains a model.
     *
     * @param mlInput   training input; must be non-null and carry an input data set
     * @param asyncTask forwarded to the training request's {@code async} flag
     * @param listener  receives the output carried by the task response, or the failure
     * @throws IllegalArgumentException if {@code mlInput} or its input data set is null
     */
    @Override
    public void train(MLInput mlInput, boolean asyncTask, ActionListener<MLOutput> listener) {
        validateMLInput(mlInput, true);
        MLTrainingTaskRequest trainingTaskRequest = MLTrainingTaskRequest.builder()
                .mlInput(mlInput)
                .async(asyncTask)
                .build();
        client.execute(MLTrainingTaskAction.INSTANCE, trainingTaskRequest,
                getMlPredictionTaskResponseActionListener(listener));
    }

    /**
     * Fetches a model by id.
     *
     * @param modelId  id of the model to fetch
     * @param listener receives the {@link MLModel}, or the failure
     */
    @Override
    public void getModel(String modelId, ActionListener<MLModel> listener) {
        MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().modelId(modelId).build();
        ActionListener<MLModelGetResponse> internalListener = ActionListener.wrap(
                getResponse -> listener.onResponse(getResponse.getMlModel()), listener::onFailure);
        // Convert via fromActionResponse inside wrapActionListener so the raw response is only
        // treated as an ActionResponse until the converter has produced an MLModelGetResponse.
        client.execute(MLModelGetAction.INSTANCE, mlModelGetRequest,
                wrapActionListener(internalListener, MLModelGetResponse::fromActionResponse));
    }

    /**
     * Deletes a model by id.
     *
     * @param modelId  id of the model to delete
     * @param listener receives the {@link DeleteResponse}, or the failure
     */
    @Override
    public void deleteModel(String modelId, ActionListener<DeleteResponse> listener) {
        MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).build();
        // wrap() (rather than passing listener directly) routes exceptions thrown by
        // listener.onResponse into listener.onFailure.
        client.execute(MLModelDeleteAction.INSTANCE, mlModelDeleteRequest,
                ActionListener.wrap(listener::onResponse, listener::onFailure));
    }

    /**
     * Searches models.
     *
     * @param searchRequest search to run against models
     * @param listener      receives the {@link SearchResponse}, or the failure
     */
    @Override
    public void searchModel(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
        client.execute(MLModelSearchAction.INSTANCE, searchRequest,
                ActionListener.wrap(listener::onResponse, listener::onFailure));
    }

    /**
     * Fetches a task by id.
     *
     * @param taskId   id of the task to fetch
     * @param listener receives the {@link MLTask}, or the failure
     */
    @Override
    public void getTask(String taskId, ActionListener<MLTask> listener) {
        MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder().taskId(taskId).build();
        ActionListener<MLTaskGetResponse> internalListener = ActionListener.wrap(
                getResponse -> listener.onResponse(getResponse.getMlTask()), listener::onFailure);
        // Same conversion strategy as getModel: fromActionResponse runs before any typed access.
        client.execute(MLTaskGetAction.INSTANCE, mlTaskGetRequest,
                wrapActionListener(internalListener, MLTaskGetResponse::fromActionResponse));
    }

    /**
     * Deletes a task by id.
     *
     * @param taskId   id of the task to delete
     * @param listener receives the {@link DeleteResponse}, or the failure
     */
    @Override
    public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
        MLTaskDeleteRequest mlTaskDeleteRequest = MLTaskDeleteRequest.builder().taskId(taskId).build();
        client.execute(MLTaskDeleteAction.INSTANCE, mlTaskDeleteRequest,
                ActionListener.wrap(listener::onResponse, listener::onFailure));
    }

    /**
     * Searches tasks.
     *
     * @param searchRequest search to run against tasks
     * @param listener      receives the {@link SearchResponse}, or the failure
     */
    @Override
    public void searchTask(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
        client.execute(MLTaskSearchAction.INSTANCE, searchRequest,
                ActionListener.wrap(listener::onResponse, listener::onFailure));
    }

    /**
     * Adapts an {@code MLOutput} listener to the {@link MLTaskResponse} listener expected by the
     * predict/train actions.
     *
     * <p>The task response is first rebuilt with {@link MLTaskResponse#fromActionResponse}; its
     * {@code getOutput()} is then delivered to {@code listener}. Failures at any stage reach
     * {@code listener.onFailure}.
     */
    private ActionListener<MLTaskResponse> getMlPredictionTaskResponseActionListener(ActionListener<MLOutput> listener) {
        ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(
                predictionResponse -> listener.onResponse(predictionResponse.getOutput()), listener::onFailure);
        return wrapActionListener(internalListener, MLTaskResponse::fromActionResponse);
    }

    /**
     * Returns a listener that converts each incoming response with {@code recreate} before handing it
     * to {@code listener}.
     *
     * <p>The lambda parameter here erases to {@link ActionResponse}, so the incoming object is only
     * checked against {@code ActionResponse} — never against the concrete {@code T} — before
     * {@code recreate} runs. NOTE(review): presumably this is what lets the {@code fromActionResponse}
     * converters handle responses that are not directly instances of {@code T}; confirm against their
     * implementations.
     *
     * @param listener downstream listener receiving converted responses and failures
     * @param recreate converts a raw {@link ActionResponse} into {@code T}; exceptions it throws are
     *                 routed to {@code listener.onFailure}
     */
    private <T extends ActionResponse> ActionListener<T> wrapActionListener(ActionListener<T> listener,
                                                                            Function<ActionResponse, T> recreate) {
        return ActionListener.wrap(raw -> listener.onResponse(recreate.apply(raw)), listener::onFailure);
    }

    /**
     * Validates an {@link MLInput} before dispatching a train/predict action.
     *
     * @param mlInput      input to validate
     * @param requireInput when true, the input data set must also be non-null
     * @throws IllegalArgumentException if {@code mlInput} is null, or if {@code requireInput} is true
     *                                  and its input data set is null
     */
    private void validateMLInput(MLInput mlInput, boolean requireInput) {
        if (mlInput == null) {
            throw new IllegalArgumentException("ML Input can't be null");
        }
        if (requireInput && mlInput.getInputDataset() == null) {
            throw new IllegalArgumentException("input data set can't be null");
        }
    }

    /**
     * Creates a client that dispatches all ML actions through {@code client}.
     *
     * @param client node client used to execute transport actions
     */
    @Generated
    public MachineLearningNodeClient(NodeClient client) {
        this.client = client;
    }
}

