/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.task;

import com.google.common.collect.ImmutableMap;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.common.Strings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.ml.task.MLTaskCache;
import org.opensearch.rest.RestStatus;
import org.opensearch.threadpool.ThreadPool;

/**
 * Tracks ML tasks on the local node and persists task state to the {@code .plugins-ml-task} index.
 *
 * <p>Each cached task is wrapped in an {@link MLTaskCache} keyed by task id. A per-{@link MLTaskType}
 * counter ({@code runningTasksCount}) is incremented by
 * {@link #checkLimitAndAddRunningTask(MLTask, Integer)} and decremented by {@link #remove(String)}
 * for tasks that had left the CREATED state; it is used to enforce a per-type running-task limit.
 *
 * <p>Both internal maps are {@link ConcurrentHashMap}s; only the add/limit-check paths are
 * {@code synchronized}.
 */
public class MLTaskManager {
    @Generated
    private static final Logger log = LogManager.getLogger(MLTaskManager.class);
    /** Time, in milliseconds, to wait for a task's update semaphore when updating task state. */
    public static int TASK_SEMAPHORE_TIMEOUT = 5000;
    private static final String ML_TASK_INDEX = ".plugins-ml-task";
    private static final String GENERAL_THREAD_POOL = "opensearch_ml_general";
    private static final String LAST_UPDATE_TIME_FIELD = "last_update_time";
    private static final String STATE_FIELD = "state";
    private final Map<String, MLTaskCache> taskCaches;
    private final Client client;
    private final ThreadPool threadPool;
    private final MLIndicesHandler mlIndicesHandler;
    private final Map<MLTaskType, AtomicInteger> runningTasksCount;

    /**
     * Creates a task manager with empty task caches and running-task counters.
     *
     * @param client client used to index and update task documents
     * @param threadPool thread pool whose general executor runs semaphore-guarded updates
     * @param mlIndicesHandler handler used to make sure the ML task index exists
     */
    public MLTaskManager(Client client, ThreadPool threadPool, MLIndicesHandler mlIndicesHandler) {
        this.client = client;
        this.threadPool = threadPool;
        this.mlIndicesHandler = mlIndicesHandler;
        this.taskCaches = new ConcurrentHashMap<>();
        this.runningTasksCount = new ConcurrentHashMap<>();
    }

    /**
     * Marks a task as RUNNING and counts it against its task type, unless that type is at its limit.
     *
     * <p>If a task with the same id is already cached, the cached instance's state is set to RUNNING;
     * otherwise the given task is set to RUNNING and added to the cache.
     *
     * @param mlTask the task to start
     * @param limit maximum number of running tasks allowed for the task's type
     * @return {@code null} on success, or an error message if the limit has been reached
     */
    public synchronized String checkLimitAndAddRunningTask(MLTask mlTask, Integer limit) {
        AtomicInteger runningTaskCount = runningTasksCount.computeIfAbsent(mlTask.getTaskType(), type -> new AtomicInteger(0));
        // remove() decrements this counter without holding this lock; clamp it if it ever went negative.
        if (runningTaskCount.get() < 0) {
            runningTaskCount.set(0);
        }
        log.debug("Task id: {}, current running task {}: {}", mlTask.getTaskId(), mlTask.getTaskType(), runningTaskCount.get());
        if (runningTaskCount.get() >= limit) {
            String error = "exceed max running task limit";
            log.info(error + " for task " + mlTask.getTaskId());
            return error;
        }
        MLTask cachedTask = getMLTask(mlTask.getTaskId());
        if (cachedTask != null) {
            cachedTask.setState(MLTaskState.RUNNING);
        } else {
            mlTask.setState(MLTaskState.RUNNING);
            add(mlTask);
        }
        runningTaskCount.incrementAndGet();
        return null;
    }

    /**
     * Adds a task to the cache with no worker nodes.
     *
     * @param mlTask the task to cache
     * @throws IllegalArgumentException if a task with the same id is already cached
     */
    public synchronized void add(MLTask mlTask) {
        add(mlTask, null);
    }

    /**
     * Adds a task to the cache together with the nodes it runs on.
     *
     * @param mlTask the task to cache
     * @param workerNodes worker node ids passed to the {@link MLTaskCache}; may be {@code null}
     * @throws IllegalArgumentException if a task with the same id is already cached
     */
    public synchronized void add(MLTask mlTask, List<String> workerNodes) {
        String taskId = mlTask.getTaskId();
        if (taskCaches.containsKey(taskId)) {
            throw new IllegalArgumentException("Duplicate taskId");
        }
        taskCaches.put(taskId, new MLTaskCache(mlTask, workerNodes));
        log.debug("add ML task to cache " + taskId);
    }

    /** Returns whether a task with the given id is cached. */
    public boolean contains(String taskId) {
        return taskCaches.containsKey(taskId);
    }

    /**
     * Removes a task from the cache. If the removed task was no longer in the CREATED state, the
     * running-task counter of its type is decremented. Does nothing if the task is not cached.
     *
     * @param taskId id of the task to remove
     */
    public void remove(String taskId) {
        // Single atomic remove instead of contains()+remove(), so a concurrent removal cannot yield null.
        MLTaskCache taskCache = taskCaches.remove(taskId);
        if (taskCache == null) {
            return;
        }
        MLTask mlTask = taskCache.getMlTask();
        if (mlTask.getState() != MLTaskState.CREATED) {
            AtomicInteger runningTaskCount = runningTasksCount.get(mlTask.getTaskType());
            if (runningTaskCount != null) {
                runningTaskCount.decrementAndGet();
            }
        }
        log.debug("remove ML task from cache " + taskId);
    }

    /**
     * Returns the cached task with the given id.
     *
     * @return the task, or {@code null} if it is not cached
     */
    public MLTask getMLTask(String taskId) {
        MLTaskCache taskCache = taskCaches.get(taskId);
        return taskCache == null ? null : taskCache.getMlTask();
    }

    /**
     * Returns the cache entry for the given task id.
     *
     * @return the cache entry, or {@code null} if the task is not cached
     */
    public MLTaskCache getMLTaskCache(String taskId) {
        return taskCaches.get(taskId);
    }

    /**
     * Returns the worker nodes recorded for a task.
     *
     * @return the worker node ids, or {@code null} if the task is not cached
     */
    public Set<String> getWorkNodes(String taskId) {
        MLTaskCache taskCache = taskCaches.get(taskId);
        return taskCache == null ? null : taskCache.getWorkerNodes();
    }

    /**
     * Records an error reported by a worker node for a task. Ignored if the task is not cached.
     *
     * @param taskId the task id
     * @param workerNodeId the node that reported the error
     * @param error the error message
     */
    public void addNodeError(String taskId, String workerNodeId, String error) {
        log.debug("add task error: taskId: {}, workerNodeId: {}, error: {}", taskId, workerNodeId, error);
        MLTaskCache taskCache = taskCaches.get(taskId);
        if (taskCache != null) {
            taskCache.addError(workerNodeId, error);
        }
    }

    /** Returns the ids of all cached tasks. */
    public String[] getAllTaskIds() {
        return Strings.toStringArray(taskCaches.keySet());
    }

    /** Returns how many cached tasks are currently in the RUNNING state. */
    public int getRunningTaskCount() {
        return (int) taskCaches.values()
            .stream()
            .map(MLTaskCache::getMlTask)
            .filter(task -> task.getState() == MLTaskState.RUNNING)
            .count();
    }

    /**
     * Removes every task from the cache.
     *
     * <p>Note: only the task cache is cleared; the per-type counters in {@code runningTasksCount}
     * are left untouched.
     */
    public void clear() {
        taskCaches.clear();
    }

    /**
     * Ensures the ML task index exists, then indexes the given task document with an IMMEDIATE
     * refresh policy. The request runs with a stashed thread context, which is restored before
     * {@code listener} is notified.
     *
     * @param mlTask the task to persist
     * @param listener notified with the index response, or with the failure
     */
    public void createMLTask(MLTask mlTask, ActionListener<IndexResponse> listener) {
        mlIndicesHandler.initMLTaskIndex(ActionListener.wrap(indexCreated -> {
            if (!indexCreated) {
                listener.onFailure(new RuntimeException("No response to create ML task index"));
                return;
            }
            IndexRequest request = new IndexRequest(ML_TASK_INDEX);
            try (
                XContentBuilder builder = XContentFactory.jsonBuilder();
                ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()
            ) {
                request.source(mlTask.toXContent(builder, ToXContent.EMPTY_PARAMS)).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
                client.index(request, ActionListener.runBefore(listener, () -> context.restore()));
            } catch (Exception e) {
                log.error("Failed to create ML task for " + mlTask.getFunctionName() + ", " + mlTask.getTaskType(), e);
                listener.onFailure(e);
            }
        }, e -> {
            log.error("Failed to create ML index", e);
            listener.onFailure(e);
        }));
    }

    /**
     * Sets a cached task's state to RUNNING and, for async tasks, also persists the new state.
     *
     * @param taskId id of a cached task
     * @param isAsyncTask whether to also write the state change to the task index
     * @throws IllegalArgumentException if the task is not cached
     */
    public void updateTaskStateAsRunning(String taskId, boolean isAsyncTask) {
        MLTask task = getMLTask(taskId);
        if (task == null) {
            throw new IllegalArgumentException("Task not found");
        }
        task.setState(MLTaskState.RUNNING);
        if (isAsyncTask) {
            updateMLTask(taskId, ImmutableMap.of(STATE_FIELD, MLTaskState.RUNNING), TASK_SEMAPHORE_TIMEOUT, false);
        }
    }

    /**
     * Updates a cached task's index document and only logs the outcome.
     *
     * @param taskId id of a cached task
     * @param updatedFields fields to write; {@code last_update_time} is added automatically
     * @param timeoutInMillis how long to wait for the task's update semaphore
     * @param removeFromCache if {@code true}, the task is removed from the cache after the update
     *     completes, whether it succeeded or failed
     */
    public void updateMLTask(String taskId, Map<String, Object> updatedFields, long timeoutInMillis, boolean removeFromCache) {
        ActionListener<UpdateResponse> internalListener = ActionListener.wrap(response -> {
            if (response.status() == RestStatus.OK) {
                log.debug("Updated ML task successfully: {}, task id: {}", response.status(), taskId);
            } else {
                log.error("Failed to update ML task {}, status: {}", taskId, response.status());
            }
        }, e -> log.error("Failed to update ML task: " + taskId, e));
        updateMLTask(taskId, updatedFields, ActionListener.runAfter(internalListener, () -> {
            if (removeFromCache) {
                remove(taskId);
            }
        }), timeoutInMillis);
    }

    /**
     * Updates a cached task's index document on the general ML thread pool, serialized by the
     * task's update semaphore (when it has one).
     *
     * @param taskId id of a cached task
     * @param updatedFields fields to write; {@code last_update_time} is added automatically
     * @param listener notified with the update response or the failure
     * @param timeoutInMillis how long to wait for the task's update semaphore
     */
    public void updateMLTask(String taskId, Map<String, Object> updatedFields, ActionListener<UpdateResponse> listener, long timeoutInMillis) {
        MLTaskCache taskCache = taskCaches.get(taskId);
        if (taskCache == null) {
            listener.onFailure(new MLResourceNotFoundException("Can't find task"));
            return;
        }
        threadPool.executor(GENERAL_THREAD_POOL).execute(() -> {
            // Validate before taking the permit so this early return can never leak it.
            if (updatedFields == null || updatedFields.isEmpty()) {
                listener.onFailure(new IllegalArgumentException("Updated fields is null or empty"));
                return;
            }
            Semaphore semaphore = taskCache.getUpdateTaskIndexSemaphore();
            try {
                if (semaphore != null && !semaphore.tryAcquire(timeoutInMillis, TimeUnit.MILLISECONDS)) {
                    listener.onFailure(new MLException("Other updating request not finished yet"));
                    return;
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                log.error("Failed to acquire semaphore for ML task " + taskId, e);
                listener.onFailure(e);
                return;
            }
            UpdateRequest updateRequest;
            try {
                updateRequest = buildTaskUpdateRequest(taskId, updatedFields);
            } catch (Exception e) {
                // The update was never dispatched, so nothing else will release the permit.
                if (semaphore != null) {
                    semaphore.release();
                }
                log.error("Failed to update ML task " + taskId, e);
                listener.onFailure(e);
                return;
            }
            // From here on, runAfter releases the permit once the listener has been notified.
            ActionListener<UpdateResponse> actionListener = semaphore == null
                ? listener
                : ActionListener.runAfter(listener, () -> semaphore.release());
            try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
                client.update(updateRequest, ActionListener.runBefore(actionListener, () -> context.restore()));
            } catch (Exception e) {
                actionListener.onFailure(e);
            }
        });
    }

    /**
     * Updates a task's index document on the calling thread, bypassing the cache and semaphore,
     * and only logs the outcome.
     *
     * @param taskId id of the task document
     * @param updatedFields fields to write; {@code last_update_time} is added automatically
     */
    public void updateMLTaskDirectly(String taskId, Map<String, Object> updatedFields) {
        updateMLTaskDirectly(
            taskId,
            updatedFields,
            ActionListener.wrap(r -> log.debug("updated ML task directly: {}", taskId), e -> log.error("Failed to update ML task " + taskId, e))
        );
    }

    /**
     * Updates a task's index document on the calling thread, bypassing the cache and semaphore.
     * The task does not need to be cached.
     *
     * @param taskId id of the task document
     * @param updatedFields fields to write; {@code last_update_time} is added automatically
     * @param listener notified with the update response or the failure
     */
    public void updateMLTaskDirectly(String taskId, Map<String, Object> updatedFields, ActionListener<UpdateResponse> listener) {
        try {
            if (updatedFields == null || updatedFields.isEmpty()) {
                listener.onFailure(new IllegalArgumentException("Updated fields is null or empty"));
                return;
            }
            UpdateRequest updateRequest = buildTaskUpdateRequest(taskId, updatedFields);
            try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
                client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore()));
            } catch (Exception e) {
                listener.onFailure(e);
            }
        } catch (Exception e) {
            log.error("Failed to update ML task " + taskId, e);
            listener.onFailure(e);
        }
    }

    /**
     * Builds the partial-document update for a task: copies {@code updatedFields}, stamps
     * {@code last_update_time} with the current epoch millis and uses an IMMEDIATE refresh policy.
     */
    private static UpdateRequest buildTaskUpdateRequest(String taskId, Map<String, Object> updatedFields) {
        UpdateRequest updateRequest = new UpdateRequest(ML_TASK_INDEX, taskId);
        Map<String, Object> updatedContent = new HashMap<>(updatedFields);
        updatedContent.put(LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli());
        updateRequest.doc(updatedContent);
        updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
        return updateRequest;
    }

    /**
     * Returns whether any cached task refers to the given model.
     *
     * @param modelId the model id to look for (must not be {@code null})
     */
    public boolean containsModel(String modelId) {
        for (MLTaskCache taskCache : taskCaches.values()) {
            if (modelId.equals(taskCache.getMlTask().getModelId())) {
                return true;
            }
        }
        return false;
    }

    /** Returns the ids of cached LOAD_MODEL tasks that are no longer in the CREATED state. */
    public String[] getLocalRunningLoadModelTasks() {
        List<String> runningLoadModelTaskIds = new ArrayList<>();
        for (Map.Entry<String, MLTaskCache> entry : taskCaches.entrySet()) {
            MLTask mlTask = entry.getValue().getMlTask();
            if (mlTask.getTaskType() == MLTaskType.LOAD_MODEL && mlTask.getState() != MLTaskState.CREATED) {
                runningLoadModelTaskIds.add(entry.getKey());
            }
        }
        return runningLoadModelTaskIds.toArray(new String[0]);
    }

    /**
     * Drops stale LOAD_MODEL tasks from the cache.
     *
     * <p>A task is considered stale when it is a LOAD_MODEL task still in the CREATED state, its
     * last update time is more than 10 minutes ago, and its id is not a key of
     * {@code runningLoadModelTasks}. Stale tasks are removed from the cache directly (they are in the
     * CREATED state, so no running-task counter is affected).
     *
     * @param runningLoadModelTasks task ids known to be running (mapped to a set of strings);
     *     may be {@code null} or empty
     */
    public void syncRunningLoadModelTasks(Map<String, Set<String>> runningLoadModelTasks) {
        Instant ttlEndTime = Instant.now().minus(10L, ChronoUnit.MINUTES);
        boolean noRunningTask = runningLoadModelTasks == null || runningLoadModelTasks.isEmpty();
        Set<String> staleTasks = new HashSet<>();
        for (Map.Entry<String, MLTaskCache> entry : taskCaches.entrySet()) {
            String taskId = entry.getKey();
            MLTask mlTask = entry.getValue().getMlTask();
            if (mlTask.getTaskType() != MLTaskType.LOAD_MODEL || mlTask.getState() != MLTaskState.CREATED) {
                continue;
            }
            Instant lastUpdateTime = mlTask.getLastUpdateTime();
            // Without a last update time the task's age is unknown, so it is not treated as stale.
            if (lastUpdateTime == null || !lastUpdateTime.isBefore(ttlEndTime)) {
                continue;
            }
            if (!noRunningTask && runningLoadModelTasks.containsKey(taskId)) {
                continue;
            }
            staleTasks.add(taskId);
        }
        if (!staleTasks.isEmpty()) {
            log.debug("remove stale load tasks : {}", Arrays.toString(staleTasks.toArray(new String[0])));
            for (String taskId : staleTasks) {
                taskCaches.remove(taskId);
            }
        }
    }
}

