/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.task;

import java.time.Instant;
import java.util.UUID;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.support.ThreadedActionListener;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.DeprecationHandler;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.common.xcontent.XContentParserUtils;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.commons.authuser.User;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.breaker.MLCircuitBreakerService;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.indices.MLInputDatasetHandler;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.permission.AccessController;
import org.opensearch.ml.stats.ActionName;
import org.opensearch.ml.stats.MLActionLevelStat;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.task.MLTaskRunner;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportRequest;
import org.opensearch.transport.TransportResponseHandler;
import org.opensearch.transport.TransportService;

/**
 * Task runner for ML prediction requests.
 *
 * <p>{@link #dispatchTask} picks a target node (the model's worker nodes as reported by
 * {@link MLModelManager}, falling back to the eligible nodes from {@link DiscoveryNodeHelper})
 * and either runs the request locally or forwards it over the transport layer.
 * {@link #executeTask} runs the prediction on the current node. When the input is a search
 * query, it first turns that query into a dataset.
 */
public class MLPredictTaskRunner
extends MLTaskRunner<MLPredictionTaskRequest, MLTaskResponse> {
    /** Executor name used for prediction work and for forking async callbacks. */
    private static final String PREDICT_THREAD_POOL = "opensearch_ml_predict";
    /** Index the model document is read from when no in-memory predictor is available. */
    private static final String ML_MODEL_INDEX = ".plugins-ml-model";

    @Generated
    private static final Logger log = LogManager.getLogger(MLPredictTaskRunner.class);
    private final ThreadPool threadPool;
    private final ClusterService clusterService;
    private final Client client;
    private final MLInputDatasetHandler mlInputDatasetHandler;
    private final NamedXContentRegistry xContentRegistry;
    private final MLModelManager mlModelManager;
    private final DiscoveryNodeHelper nodeHelper;
    private final MLEngine mlEngine;

    public MLPredictTaskRunner(ThreadPool threadPool, ClusterService clusterService, Client client, MLTaskManager mlTaskManager, MLStats mlStats, MLInputDatasetHandler mlInputDatasetHandler, MLTaskDispatcher mlTaskDispatcher, MLCircuitBreakerService mlCircuitBreakerService, NamedXContentRegistry xContentRegistry, MLModelManager mlModelManager, DiscoveryNodeHelper nodeHelper, MLEngine mlEngine) {
        super(mlTaskManager, mlStats, nodeHelper, mlTaskDispatcher, mlCircuitBreakerService, clusterService);
        this.threadPool = threadPool;
        this.clusterService = clusterService;
        this.client = client;
        this.mlInputDatasetHandler = mlInputDatasetHandler;
        this.xContentRegistry = xContentRegistry;
        this.mlModelManager = mlModelManager;
        this.nodeHelper = nodeHelper;
        this.mlEngine = mlEngine;
    }

    @Override
    protected String getTransportActionName() {
        return "cluster:admin/opensearch/ml/predict";
    }

    @Override
    protected TransportResponseHandler<MLTaskResponse> getResponseHandler(ActionListener<MLTaskResponse> listener) {
        return new ActionListenerResponseHandler<>(listener, MLTaskResponse::new);
    }

    /**
     * Chooses a node for the prediction and runs it there.
     *
     * <p>If the model has no worker nodes, TEXT_EMBEDDING requests fail immediately with
     * {@code "model not loaded"}. Other algorithms fall back to any eligible node.
     *
     * @param request          the prediction request
     * @param transportService used to forward the request when the chosen node is remote
     * @param listener         notified with the prediction result or failure
     */
    @Override
    public void dispatchTask(MLPredictionTaskRequest request, TransportService transportService, ActionListener<MLTaskResponse> listener) {
        String modelId = request.getModelId();
        MLInput input = request.getMlInput();
        FunctionName algorithm = input.getAlgorithm();
        try {
            ActionListener<DiscoveryNode> actionListener = ActionListener.wrap(node -> {
                if (this.clusterService.localNode().getId().equals(node.getId())) {
                    log.debug("Execute ML predict request {} locally on node {}", request.getRequestID(), node.getId());
                    this.executeTask(request, listener);
                } else {
                    log.debug("Execute ML predict request {} remotely on node {}", request.getRequestID(), node.getId());
                    // NOTE(review): presumably tells the receiving node to execute rather than re-dispatch — confirm.
                    request.setDispatchTask(false);
                    transportService.sendRequest(node, this.getTransportActionName(), request, this.getResponseHandler(listener));
                }
            }, listener::onFailure);
            String[] workerNodes = this.mlModelManager.getWorkerNodes(modelId);
            if (workerNodes == null || workerNodes.length == 0) {
                if (algorithm == FunctionName.TEXT_EMBEDDING) {
                    listener.onFailure(new MLException("model not loaded"));
                    return;
                }
                workerNodes = this.nodeHelper.getEligibleNodeIds();
            }
            this.mlTaskDispatcher.dispatchPredictTask(workerNodes, actionListener);
        } catch (Exception e) {
            log.error("Failed to predict model " + modelId, e);
            listener.onFailure(e);
        }
    }

    /**
     * Builds a synchronous PREDICTION task for this node and runs the prediction.
     *
     * <p>A SEARCH_QUERY input is first resolved into a dataset by {@link MLInputDatasetHandler}.
     * Every other input type runs directly on the {@value #PREDICT_THREAD_POOL} executor.
     *
     * @param request  the prediction request
     * @param listener notified with the prediction result or failure
     */
    @Override
    protected void executeTask(MLPredictionTaskRequest request, ActionListener<MLTaskResponse> listener) {
        MLInput mlInput = request.getMlInput();
        MLInputDataType inputDataType = mlInput.getInputDataset().getInputDataType();
        Instant now = Instant.now();
        String modelId = request.getModelId();
        MLTask mlTask = MLTask.builder()
            .taskId(UUID.randomUUID().toString())
            .modelId(modelId)
            .taskType(MLTaskType.PREDICTION)
            .inputType(inputDataType)
            .functionName(mlInput.getFunctionName())
            .state(MLTaskState.CREATED)
            .workerNode(this.clusterService.localNode().getId())
            .createTime(now)
            .lastUpdateTime(now)
            .async(false)
            .build();
        switch (inputDataType) {
            case SEARCH_QUERY: {
                ActionListener<MLInputDataset> dataFrameActionListener = ActionListener.wrap(dataSet -> {
                    MLInput newInput = mlInput.toBuilder().inputDataset(dataSet).build();
                    this.predict(modelId, mlTask, newInput, listener);
                }, e -> {
                    log.error("Failed to generate DataFrame from search query", e);
                    this.handleAsyncMLTaskFailure(mlTask, e);
                    listener.onFailure(e);
                });
                this.mlInputDatasetHandler.parseSearchQueryInput(mlInput.getInputDataset(), this.threadedActionListener(dataFrameActionListener));
                break;
            }
            default: {
                this.threadPool.executor(PREDICT_THREAD_POOL).execute(() -> this.predict(modelId, mlTask, mlInput, listener));
            }
        }
    }

    /**
     * Runs the prediction for {@code mlTask}.
     *
     * <p>Request counters are incremented and the task is registered with the task manager first.
     * If {@link MLModelManager} has an in-memory predictor for the model, that predictor is used.
     * Otherwise the model document is loaded from {@value #ML_MODEL_INDEX} and run through
     * {@link MLEngine}. In either case the listener is notified exactly once.
     */
    private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> listener) {
        ActionListener<MLTaskResponse> internalListener = this.wrappedCleanupListener(listener, mlTask.getTaskId());
        this.mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
        this.mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
        this.mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.PREDICT, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment();
        mlTask.setState(MLTaskState.RUNNING);
        this.mlTaskManager.add(mlTask);
        FunctionName algorithm = mlInput.getAlgorithm();
        if (modelId == null) {
            IllegalArgumentException e = new IllegalArgumentException("ModelId is invalid");
            log.error("ModelId is invalid", e);
            this.handlePredictFailure(mlTask, internalListener, e, false);
            return;
        }
        try {
            Predictable predictor = this.mlModelManager.getPredictor(modelId);
            if (predictor != null) {
                MLOutput output = this.mlModelManager.trackPredictDuration(modelId, () -> predictor.predict(mlInput));
                if (output instanceof MLPredictionOutput) {
                    ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
                }
                this.handleAsyncMLTaskComplete(mlTask);
                MLTaskResponse response = MLTaskResponse.builder().output(output).build();
                internalListener.onResponse(response);
                return;
            }
            if (algorithm == FunctionName.TEXT_EMBEDDING) {
                throw new MLException("model not loaded");
            }
        } catch (Exception e) {
            this.handlePredictFailure(mlTask, internalListener, e, false);
            // The listener has already been notified. Falling through would run the index-based
            // prediction as well and notify the listener a second time.
            return;
        }
        this.predictWithModelFromIndex(modelId, mlTask, mlInput, internalListener);
    }

    /** Fetches the model document from the model index and hands it to {@link #predictWithModelDocument}. */
    private void predictWithModelFromIndex(String modelId, MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> internalListener) {
        // The caller's thread context is stashed for the index read. runBefore restores it before the
        // response/failure handler runs. NOTE(review): presumably this lets the system index be read
        // while keeping the requesting user for the permission check — confirm.
        try (ThreadContext.StoredContext context = this.threadPool.getThreadContext().stashContext()) {
            ActionListener<GetResponse> getModelListener = ActionListener.wrap(
                r -> this.predictWithModelDocument(r, modelId, mlTask, mlInput, internalListener),
                e -> {
                    log.error("Failed to predict " + mlInput.getAlgorithm() + ", modelId: " + mlTask.getModelId(), e);
                    this.handlePredictFailure(mlTask, internalListener, e, true);
                }
            );
            GetRequest getRequest = new GetRequest(ML_MODEL_INDEX, mlTask.getModelId());
            this.client.get(getRequest, this.threadedActionListener(ActionListener.runBefore(getModelListener, context::restore)));
        } catch (Exception e) {
            log.error("Failed to get model " + mlTask.getModelId(), e);
            this.handlePredictFailure(mlTask, internalListener, e, true);
        }
    }

    /** Parses the fetched model, checks the requesting user's access to it, and runs the prediction via {@link MLEngine}. */
    private void predictWithModelDocument(GetResponse r, String modelId, MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> internalListener) {
        if (r == null || !r.isExists()) {
            internalListener.onFailure(new ResourceNotFoundException("No model found, please check the modelId."));
            return;
        }
        try (XContentParser xContentParser = XContentType.JSON.xContent().createParser(this.xContentRegistry, LoggingDeprecationHandler.INSTANCE, r.getSourceAsString())) {
            XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xContentParser.nextToken(), xContentParser);
            MLModel mlModel = MLModel.parse(xContentParser);
            User resourceUser = mlModel.getUser();
            User requestUser = AccessController.getUserContext(this.client);
            if (!AccessController.checkUserPermissions(requestUser, resourceUser, modelId)) {
                // Read the name null-safely so a missing request user still yields the permission error, not an NPE.
                String requestUserName = requestUser == null ? null : requestUser.getName();
                OpenSearchException e = new OpenSearchException("User: " + requestUserName + " does not have permissions to run predict by model: " + modelId);
                this.handlePredictFailure(mlTask, internalListener, e, false);
                return;
            }
            this.mlTaskManager.updateTaskStateAsRunning(mlTask.getTaskId(), mlTask.isAsync());
            MLOutput output = this.mlEngine.predict(mlInput, mlModel);
            if (output instanceof MLPredictionOutput) {
                ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
            }
            this.handleAsyncMLTaskComplete(mlTask);
            MLTaskResponse response = MLTaskResponse.builder().output(output).build();
            internalListener.onResponse(response);
        } catch (Exception e) {
            log.error("Failed to predict model " + modelId, e);
            internalListener.onFailure(e);
        }
    }

    /** Wraps {@code listener} so its callbacks run on the {@value #PREDICT_THREAD_POOL} executor. */
    private <T> ThreadedActionListener<T> threadedActionListener(ActionListener<T> listener) {
        return new ThreadedActionListener<>(log, this.threadPool, PREDICT_THREAD_POOL, listener, false);
    }

    /**
     * Marks the task as failed and notifies {@code listener}.
     *
     * @param trackFailure when true, also increments the action-level and node-level failure counters
     */
    private void handlePredictFailure(MLTask mlTask, ActionListener<MLTaskResponse> listener, Exception e, boolean trackFailure) {
        if (trackFailure) {
            this.mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.PREDICT, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment();
            this.mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT).increment();
        }
        this.handleAsyncMLTaskFailure(mlTask, e);
        listener.onFailure(e);
    }
}

