/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.load;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.load.LoadModelInput;
import org.opensearch.ml.common.transport.load.LoadModelNodesRequest;
import org.opensearch.ml.common.transport.load.LoadModelResponse;
import org.opensearch.ml.common.transport.load.MLLoadModelOnNodeAction;
import org.opensearch.ml.common.transport.load.MLLoadModelRequest;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

/**
 * Transport handler for the {@code cluster:admin/opensearch/ml/load_model} action.
 *
 * <p>For each request it resolves the nodes the model should be loaded on, records an
 * {@link MLTask} describing the load, answers the caller with the task id, and then
 * triggers the per-node load on the {@code opensearch_ml_load} executor.
 */
public class TransportLoadModelAction
extends HandledTransportAction<ActionRequest, LoadModelResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportLoadModelAction.class);
    TransportService transportService;
    ModelHelper modelHelper;
    MLTaskManager mlTaskManager;
    ClusterService clusterService;
    ThreadPool threadPool;
    Client client;
    NamedXContentRegistry xContentRegistry;
    DiscoveryNodeHelper nodeFilter;
    MLTaskDispatcher mlTaskDispatcher;
    MLModelManager mlModelManager;
    MLStats mlStats;

    /**
     * Creates the action and stores its collaborators; all arguments are supplied by injection.
     */
    @Inject
    public TransportLoadModelAction(TransportService transportService, ActionFilters actionFilters, ModelHelper modelHelper, MLTaskManager mlTaskManager, ClusterService clusterService, ThreadPool threadPool, Client client, NamedXContentRegistry xContentRegistry, DiscoveryNodeHelper nodeFilter, MLTaskDispatcher mlTaskDispatcher, MLModelManager mlModelManager, MLStats mlStats) {
        super("cluster:admin/opensearch/ml/load_model", transportService, actionFilters, MLLoadModelRequest::new);
        this.transportService = transportService;
        this.modelHelper = modelHelper;
        this.mlTaskManager = mlTaskManager;
        this.clusterService = clusterService;
        this.threadPool = threadPool;
        this.client = client;
        this.xContentRegistry = xContentRegistry;
        this.nodeFilter = nodeFilter;
        this.mlTaskDispatcher = mlTaskDispatcher;
        this.mlModelManager = mlModelManager;
        this.mlStats = mlStats;
    }

    /**
     * Handles a load-model request.
     *
     * <p>The target nodes are the requested node ids that are also eligible, or every eligible
     * node when the request names none. If no node remains the listener fails with
     * {@link MLResourceNotFoundException}. Otherwise the model metadata is fetched (without its
     * content fields), a {@code LOAD_MODEL} task is created, the listener is answered with the
     * task id in state {@code CREATED}, and the actual load is scheduled asynchronously.
     *
     * @param task     the transport task (not used by this method)
     * @param request  the incoming request, converted via {@link MLLoadModelRequest#fromActionRequest}
     * @param listener receives the {@link LoadModelResponse} or the failure
     */
    @Override
    protected void doExecute(Task task, ActionRequest request, ActionListener<LoadModelResponse> listener) {
        MLLoadModelRequest deployModelRequest = MLLoadModelRequest.fromActionRequest(request);
        String modelId = deployModelRequest.getModelId();
        String[] targetNodeIds = deployModelRequest.getModelNodeIds();
        this.mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();

        DiscoveryNode[] allEligibleNodes = this.nodeFilter.getEligibleNodes();
        Map<String, DiscoveryNode> nodeMapping = new HashMap<>();
        for (DiscoveryNode node : allEligibleNodes) {
            nodeMapping.put(node.getId(), node);
        }
        Set<String> allEligibleNodeIds = Arrays.stream(allEligibleNodes).map(DiscoveryNode::getId).collect(Collectors.toSet());

        List<DiscoveryNode> eligibleNodes = new ArrayList<>();
        List<String> nodeIds = new ArrayList<>();
        if (targetNodeIds != null && targetNodeIds.length > 0) {
            // Keep only the requested nodes that are eligible; others are silently dropped.
            for (String nodeId : targetNodeIds) {
                if (allEligibleNodeIds.contains(nodeId)) {
                    eligibleNodes.add(nodeMapping.get(nodeId));
                    nodeIds.add(nodeId);
                }
            }
        } else {
            nodeIds.addAll(allEligibleNodeIds);
            eligibleNodes.addAll(Arrays.asList(allEligibleNodes));
        }
        if (nodeIds.isEmpty()) {
            listener.onFailure(new MLResourceNotFoundException("no eligible node found"));
            return;
        }

        // Comma-joined node ids; logged here and stored on the task as its worker node value.
        String workerNodes = String.join(",", nodeIds);
        log.warn("Will load model on these nodes: {}", workerNodes);
        String localNodeId = this.clusterService.localNode().getId();
        // The model content fields are excluded from the fetched model document.
        String[] excludes = new String[]{"model_content", "content"};

        // The stashed thread context is restored when the try block exits.
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            this.mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> {
                FunctionName algorithm = mlModel.getAlgorithm();
                MLTask mlTask = MLTask.builder()
                    .async(true)
                    .modelId(modelId)
                    .taskType(MLTaskType.LOAD_MODEL)
                    .functionName(algorithm)
                    .createTime(Instant.now())
                    .lastUpdateTime(Instant.now())
                    .state(MLTaskState.CREATED)
                    .workerNode(workerNodes)
                    .build();
                this.mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
                    String taskId = response.getId();
                    mlTask.setTaskId(taskId);
                    try {
                        this.mlTaskManager.add(mlTask, nodeIds);
                        // Answer the caller first; the load itself continues in the background.
                        listener.onResponse(new LoadModelResponse(taskId, MLTaskState.CREATED.name()));
                        this.threadPool.executor("opensearch_ml_load").execute(() -> this.updateModelLoadStatusAndTriggerOnNodesAction(modelId, taskId, mlModel, localNodeId, mlTask, eligibleNodes, algorithm));
                    }
                    catch (Exception ex) {
                        // NOTE(review): if the failure happens after onResponse above, the listener
                        // is notified twice -- confirm whether that is acceptable for callers.
                        log.error("Failed to load model", ex);
                        this.mlTaskManager.updateMLTask(taskId, ImmutableMap.of("state", MLTaskState.FAILED, "error", ExceptionUtils.getStackTrace(ex)), MLTaskManager.TASK_SEMAPHORE_TIMEOUT, true);
                        listener.onFailure(ex);
                    }
                }, exception -> {
                    log.error("Failed to create upload model task for " + modelId, exception);
                    listener.onFailure(exception);
                }));
            }, e -> {
                log.error("Failed to get model " + modelId, e);
                listener.onFailure(e);
            }));
        }
        catch (Exception e) {
            log.error("Failed to load model " + modelId, e);
            listener.onFailure(e);
        }
    }

    /**
     * Marks the model as {@code LOADING} and, once that update succeeds, dispatches the
     * per-node load request to the given nodes.
     *
     * <p>When the node-level action responds, the task is moved to {@code RUNNING} if the task
     * manager still contains it. If either the model-state update or the node-level action fails,
     * the task is marked {@code FAILED} and the model state is reset (see
     * {@link #handleLoadFailure}).
     *
     * @param modelId       id of the model being loaded
     * @param taskId        id of the task tracking this load
     * @param mlModel       model metadata; its content hash is forwarded to the nodes
     * @param localNodeId   id of the node coordinating the load
     * @param mlTask        the task tracking this load
     * @param eligibleNodes nodes that receive the load request
     * @param algorithm     the model's function name, used to pick the state restored on failure
     */
    @VisibleForTesting
    void updateModelLoadStatusAndTriggerOnNodesAction(String modelId, String taskId, MLModel mlModel, String localNodeId, MLTask mlTask, List<DiscoveryNode> eligibleNodes, FunctionName algorithm) {
        LoadModelInput loadModelInput = new LoadModelInput(modelId, taskId, mlModel.getModelContentHash(), eligibleNodes.size(), localNodeId, mlTask);
        LoadModelNodesRequest loadModelRequest = new LoadModelNodesRequest(eligibleNodes.toArray(new DiscoveryNode[0]), loadModelInput);
        this.mlModelManager.updateModel(modelId, ImmutableMap.of("model_state", MLModelState.LOADING), ActionListener.wrap(
            updated -> this.client.execute(MLLoadModelOnNodeAction.INSTANCE, loadModelRequest, ActionListener.wrap(nodesResponse -> {
                if (this.mlTaskManager.contains(taskId)) {
                    this.mlTaskManager.updateMLTask(taskId, ImmutableMap.of("state", MLTaskState.RUNNING), MLTaskManager.TASK_SEMAPHORE_TIMEOUT, false);
                }
            }, e -> this.handleLoadFailure(modelId, taskId, algorithm, e))),
            e -> this.handleLoadFailure(modelId, taskId, algorithm, e)));
    }

    /**
     * Records a failed load: logs the error, marks the task {@code FAILED} with the stack trace,
     * and sets the model state to {@code UPLOADED} for text-embedding models or {@code TRAINED}
     * for every other function.
     */
    private void handleLoadFailure(String modelId, String taskId, FunctionName algorithm, Exception e) {
        log.error("Failed to load model " + modelId, e);
        this.mlTaskManager.updateMLTask(taskId, ImmutableMap.of("error", ExceptionUtils.getStackTrace(e), "state", MLTaskState.FAILED), MLTaskManager.TASK_SEMAPHORE_TIMEOUT, true);
        MLModelState state = algorithm == FunctionName.TEXT_EMBEDDING ? MLModelState.UPLOADED : MLModelState.TRAINED;
        this.mlModelManager.updateModel(modelId, ImmutableMap.of("model_state", state));
    }
}

