/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.forward;

import com.google.common.collect.ImmutableMap;
import java.time.Instant;
import java.util.Arrays;
import java.util.Map;
import java.util.Set;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.inject.Inject;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.forward.MLForwardInput;
import org.opensearch.ml.common.transport.forward.MLForwardRequest;
import org.opensearch.ml.common.transport.forward.MLForwardRequestType;
import org.opensearch.ml.common.transport.forward.MLForwardResponse;
import org.opensearch.ml.common.transport.sync.MLSyncUpAction;
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest;
import org.opensearch.ml.common.transport.upload.MLUploadInput;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.task.MLTaskCache;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

public class TransportForwardAction
extends HandledTransportAction<ActionRequest, MLForwardResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportForwardAction.class);
    private MLTaskManager mlTaskManager;
    private Client client;
    private MLModelManager mlModelManager;
    private DiscoveryNodeHelper nodeHelper;

    @Inject
    public TransportForwardAction(TransportService transportService, ActionFilters actionFilters, MLTaskManager mlTaskManager, Client client, MLModelManager mlModelManager, DiscoveryNodeHelper nodeHelper) {
        super("cluster:admin/opensearch/mlinternal/forward", transportService, actionFilters, MLForwardRequest::new);
        this.mlTaskManager = mlTaskManager;
        this.client = client;
        this.mlModelManager = mlModelManager;
        this.nodeHelper = nodeHelper;
    }

    protected void doExecute(Task task, ActionRequest request, ActionListener<MLForwardResponse> listener) {
        MLForwardRequest mlForwardRequest = MLForwardRequest.fromActionRequest((ActionRequest)request);
        MLForwardInput forwardInput = mlForwardRequest.getForwardInput();
        String modelId = forwardInput.getModelId();
        String taskId = forwardInput.getTaskId();
        MLUploadInput uploadInput = forwardInput.getUploadInput();
        MLTask mlTask = forwardInput.getMlTask();
        String workerNodeId = forwardInput.getWorkerNodeId();
        MLForwardRequestType requestType = forwardInput.getRequestType();
        String error = forwardInput.getError();
        log.debug("receive forward request: {}", (Object)forwardInput.getRequestType());
        try {
            switch (requestType) {
                case LOAD_MODEL_DONE: {
                    Set<String> workNodes = this.mlTaskManager.getWorkNodes(taskId);
                    if (workNodes != null) {
                        workNodes.remove(workerNodeId);
                    }
                    if (error != null) {
                        this.mlTaskManager.addNodeError(taskId, workerNodeId, error);
                    } else {
                        this.mlModelManager.addModelWorkerNode(modelId, workerNodeId);
                        this.syncModelWorkerNodes(modelId);
                    }
                    if (workNodes == null || workNodes.size() == 0) {
                        MLTaskState taskState;
                        MLTaskCache mlTaskCache = this.mlTaskManager.getMLTaskCache(taskId);
                        MLTaskState mLTaskState = taskState = mlTaskCache.hasError() ? MLTaskState.COMPLETED_WITH_ERROR : MLTaskState.COMPLETED;
                        if (mlTaskCache.allNodeFailed()) {
                            taskState = MLTaskState.FAILED;
                        } else {
                            this.syncModelWorkerNodes(modelId);
                        }
                        ImmutableMap.Builder builder = ImmutableMap.builder();
                        builder.put((Object)"state", (Object)taskState);
                        if (mlTaskCache.hasError()) {
                            builder.put((Object)"error", (Object)mlTaskCache.getErrors().toString());
                        }
                        this.mlTaskManager.updateMLTask(taskId, (Map<String, Object>)builder.build(), MLTaskManager.TASK_SEMAPHORE_TIMEOUT, true);
                        if (!mlTaskCache.allNodeFailed()) {
                            MLModelState modelState = mlTaskCache.hasError() ? MLModelState.PARTIALLY_LOADED : MLModelState.LOADED;
                            log.info("load model done with state: {}, model id: {}", (Object)modelState, (Object)modelId);
                            this.mlModelManager.updateModel(modelId, (ImmutableMap<String, Object>)ImmutableMap.of((Object)"model_state", (Object)modelState, (Object)"last_loaded_time", (Object)Instant.now().toEpochMilli()));
                        } else {
                            log.error("load model failed on all nodes, model id: {}", (Object)modelId);
                        }
                    }
                    listener.onResponse((Object)new MLForwardResponse("ok", null));
                    break;
                }
                case UPLOAD_MODEL: {
                    this.mlModelManager.uploadMLModel(uploadInput, mlTask);
                    listener.onResponse((Object)new MLForwardResponse("ok", null));
                    break;
                }
                default: {
                    throw new IllegalArgumentException("unsupported request type");
                }
            }
        }
        catch (Exception e) {
            log.error("Failed to execute forward action", (Throwable)e);
            listener.onFailure(e);
        }
    }

    private void syncModelWorkerNodes(String modelId) {
        DiscoveryNode[] allNodes = this.nodeHelper.getAllNodes();
        Object[] workerNodes = this.mlModelManager.getWorkerNodes(modelId);
        if (allNodes.length > 1 && workerNodes.length > 0) {
            log.debug("Sync to other nodes about worker nodes of model {}: {}", (Object)modelId, (Object)Arrays.toString(workerNodes));
            MLSyncUpInput syncUpInput = MLSyncUpInput.builder().addedWorkerNodes((Map)ImmutableMap.of((Object)modelId, (Object)workerNodes)).build();
            MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(allNodes, syncUpInput);
            this.client.execute((ActionType)MLSyncUpAction.INSTANCE, (ActionRequest)syncUpRequest, ActionListener.wrap(r -> log.debug("Sync up successfully"), e -> log.error("Failed to sync up", (Throwable)e)));
        }
    }
}

