/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.load;

import com.google.common.collect.ImmutableMap;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.nodes.TransportNodesAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.breaker.MLCircuitBreakerService;
import org.opensearch.ml.common.exception.MLLimitExceededException;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.forward.MLForwardInput;
import org.opensearch.ml.common.transport.forward.MLForwardRequest;
import org.opensearch.ml.common.transport.forward.MLForwardRequestType;
import org.opensearch.ml.common.transport.forward.MLForwardResponse;
import org.opensearch.ml.common.transport.load.LoadModelInput;
import org.opensearch.ml.common.transport.load.LoadModelNodeRequest;
import org.opensearch.ml.common.transport.load.LoadModelNodeResponse;
import org.opensearch.ml.common.transport.load.LoadModelNodesRequest;
import org.opensearch.ml.common.transport.load.LoadModelNodesResponse;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLCommonsSettings;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportRequest;
import org.opensearch.transport.TransportResponseHandler;
import org.opensearch.transport.TransportService;

/**
 * Node-level transport action that loads an ML model on every node targeted by a
 * {@link LoadModelNodesRequest}.
 *
 * <p>Each node answers the nodes request with a per-model status of {@code "received"}
 * without waiting for the load listener. Once the load listener fires, the node reports
 * the outcome back to the coordinating node with a {@code LOAD_MODEL_DONE} forward request.
 * The outcome is either success, or failure together with the error's stack trace.
 */
public class TransportLoadModelOnNodeAction
extends TransportNodesAction<LoadModelNodesRequest, LoadModelNodesResponse, LoadModelNodeRequest, LoadModelNodeResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportLoadModelOnNodeAction.class);
    /** Name of the internal forward action used to report load completion to the coordinating node. */
    private static final String ML_FORWARD_ACTION_NAME = "cluster:admin/opensearch/mlinternal/forward";
    TransportService transportService;
    ModelHelper modelHelper;
    MLTaskManager mlTaskManager;
    MLModelManager mlModelManager;
    ClusterService clusterService;
    ThreadPool threadPool;
    Client client;
    NamedXContentRegistry xContentRegistry;
    MLCircuitBreakerService mlCircuitBreakerService;
    MLStats mlStats;
    /** Per-node limit of concurrent load tasks; refreshed whenever the cluster setting is updated. */
    volatile Integer maxLoadTasksPerNode;

    @Inject
    public TransportLoadModelOnNodeAction(TransportService transportService, ActionFilters actionFilters, ModelHelper modelHelper, MLTaskManager mlTaskManager, MLModelManager mlModelManager, ClusterService clusterService, ThreadPool threadPool, Client client, NamedXContentRegistry xContentRegistry, MLCircuitBreakerService mlCircuitBreakerService, MLStats mlStats, Settings settings) {
        super("cluster:admin/opensearch/ml/load_model_on_nodes", threadPool, clusterService, transportService, actionFilters, LoadModelNodesRequest::new, LoadModelNodeRequest::new, "management", LoadModelNodeResponse.class);
        this.transportService = transportService;
        this.modelHelper = modelHelper;
        this.mlTaskManager = mlTaskManager;
        this.mlModelManager = mlModelManager;
        this.clusterService = clusterService;
        this.threadPool = threadPool;
        this.client = client;
        this.xContentRegistry = xContentRegistry;
        this.mlCircuitBreakerService = mlCircuitBreakerService;
        this.mlStats = mlStats;
        this.maxLoadTasksPerNode = MLCommonsSettings.ML_COMMONS_MAX_LOAD_MODEL_TASKS_PER_NODE.get(settings);
        // Keep the limit in sync with dynamic cluster-setting updates.
        clusterService.getClusterSettings().addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_MAX_LOAD_MODEL_TASKS_PER_NODE, it -> this.maxLoadTasksPerNode = it);
    }

    /** Aggregates the per-node responses and failures into a single nodes response. */
    protected LoadModelNodesResponse newResponse(LoadModelNodesRequest nodesRequest, List<LoadModelNodeResponse> responses, List<FailedNodeException> failures) {
        return new LoadModelNodesResponse(this.clusterService.getClusterName(), responses, failures);
    }

    /** Wraps the nodes-level request into the request sent to each individual node. */
    protected LoadModelNodeRequest newNodeRequest(LoadModelNodesRequest request) {
        return new LoadModelNodeRequest(request);
    }

    /** Deserializes a single node's response from the transport stream. */
    protected LoadModelNodeResponse newNodeResponse(StreamInput in) throws IOException {
        return new LoadModelNodeResponse(in);
    }

    /** Runs on each target node: kicks off the model load and returns the node's acknowledgement. */
    protected LoadModelNodeResponse nodeOperation(LoadModelNodeRequest request) {
        return this.createLoadModelNodeResponse(request.getLoadModelNodesRequest());
    }

    /**
     * Starts loading the requested model on this node and builds the node's response.
     *
     * <p>On success the task is evicted from this node's task cache when this node is not the
     * coordinating node. On failure the ML task is marked FAILED. The model state is then reset:
     * UPLOADED for text-embedding models, TRAINED for all others. In both cases a
     * {@code LOAD_MODEL_DONE} message is sent to the coordinating node.
     *
     * @param loadModelNodesRequest the request carrying the {@link LoadModelInput}
     * @return a response mapping the model id to {@code "received"}
     */
    private LoadModelNodeResponse createLoadModelNodeResponse(LoadModelNodesRequest loadModelNodesRequest) {
        LoadModelInput loadModelInput = loadModelNodesRequest.getLoadModelInput();
        String modelId = loadModelInput.getModelId();
        String taskId = loadModelInput.getTaskId();
        String coordinatingNodeId = loadModelInput.getCoordinatingNodeId();
        MLTask mlTask = loadModelInput.getMlTask();
        String modelContentHash = loadModelInput.getModelContentHash();

        Map<String, String> modelLoadStatus = new HashMap<>();
        modelLoadStatus.put(modelId, "received");

        String localNodeId = this.clusterService.localNode().getId();
        // Worker nodes drop their cached copy of the task once their part is done;
        // the coordinating node keeps it.
        boolean removeTaskCache = !coordinatingNodeId.equals(localNodeId);

        ActionListener<MLForwardResponse> taskDoneListener = ActionListener.wrap(
            res -> log.info("load model done {}", res),
            ex -> log.error("Failed to report load model done to coordinating node " + coordinatingNodeId + " for model " + modelId, ex)
        );

        ActionListener<String> loadListener = ActionListener.wrap(result -> {
            if (removeTaskCache) {
                this.mlTaskManager.remove(taskId);
            }
            this.sendLoadModelDoneMessage(coordinatingNodeId, taskId, modelId, null, taskDoneListener);
        }, e -> {
            String stackTrace = ExceptionUtils.getStackTrace(e);
            if (e instanceof MLLimitExceededException) {
                this.mlTaskManager.updateMLTaskDirectly(mlTask.getTaskId(), ImmutableMap.<String, Object>of("state", MLTaskState.FAILED, "error", e.getMessage()));
            } else {
                this.mlTaskManager.updateMLTask(taskId, ImmutableMap.<String, Object>of("error", stackTrace, "state", MLTaskState.FAILED), MLTaskManager.TASK_SEMAPHORE_TIMEOUT, removeTaskCache);
            }
            MLModelState state = mlTask.getFunctionName() == FunctionName.TEXT_EMBEDDING ? MLModelState.UPLOADED : MLModelState.TRAINED;
            this.mlModelManager.updateModel(modelId, ImmutableMap.<String, Object>of("model_state", state));
            this.sendLoadModelDoneMessage(coordinatingNodeId, taskId, modelId, stackTrace, taskDoneListener);
        });

        this.loadModel(modelId, modelContentHash, mlTask.getFunctionName(), mlTask, loadListener);
        return new LoadModelNodeResponse(this.clusterService.localNode(), modelLoadStatus);
    }

    /**
     * Sends a {@code LOAD_MODEL_DONE} forward request describing this node's load outcome
     * to the coordinating node.
     *
     * @param coordinatingNodeId id of the node that coordinates the load task
     * @param taskId             id of the load task
     * @param modelId            id of the model that was loaded
     * @param error              stack trace of the failure, or {@code null} on success
     * @param listener           notified with the coordinating node's response
     */
    private void sendLoadModelDoneMessage(String coordinatingNodeId, String taskId, String modelId, String error, ActionListener<MLForwardResponse> listener) {
        DiscoveryNode coordinatingNode = this.getNodeById(coordinatingNodeId);
        if (coordinatingNode == null) {
            // Without a target node there is no one to notify. Sending anyway would throw from
            // inside the load listener. ActionListener.wrap would then turn a successful load
            // into the failure path.
            log.warn("Coordinating node {} not found in cluster state; cannot report load model done for model {} (task {})", coordinatingNodeId, modelId, taskId);
            return;
        }
        MLForwardInput mlForwardInput = MLForwardInput.builder()
            .requestType(MLForwardRequestType.LOAD_MODEL_DONE)
            .taskId(taskId)
            .modelId(modelId)
            .workerNodeId(this.clusterService.localNode().getId())
            .error(error)
            .build();
        MLForwardRequest loadModelDoneMessage = new MLForwardRequest(mlForwardInput);
        this.transportService.sendRequest(coordinatingNode, ML_FORWARD_ACTION_NAME, loadModelDoneMessage, new ActionListenerResponseHandler<MLForwardResponse>(listener, MLForwardResponse::new));
    }

    /**
     * Looks up a node in the current cluster state.
     *
     * @param nodeId the node id; may be {@code null}
     * @return the matching node, or {@code null} if the id is null or the node is not in the cluster
     */
    private DiscoveryNode getNodeById(String nodeId) {
        if (nodeId == null) {
            return null;
        }
        DiscoveryNodes nodes = this.clusterService.state().getNodes();
        return nodes.get(nodeId);
    }

    /**
     * Reserves a load-task slot on this node, then delegates the actual load to the model manager.
     *
     * <p>If the per-node task limit is reached, the listener fails with
     * {@link MLLimitExceededException}. Any exception thrown while starting the load is logged
     * and forwarded to the listener.
     */
    private void loadModel(String modelId, String modelContentHash, FunctionName functionName, MLTask mlTask, ActionListener<String> listener) {
        try {
            String errorMsg = this.mlModelManager.checkAndAddRunningTask(mlTask, this.maxLoadTasksPerNode);
            if (errorMsg != null) {
                listener.onFailure(new MLLimitExceededException(errorMsg));
                return;
            }
            log.debug("start loading model {}", modelId);
            this.mlModelManager.loadModel(modelId, modelContentHash, functionName, listener);
        } catch (Exception e) {
            log.error("Failed to load model " + modelId, e);
            listener.onFailure(e);
        }
    }
}

