/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.model;

import com.google.common.math.Quantiles;
import java.util.DoubleSummaryStatistics;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.stream.DoubleStream;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.profile.MLPredictRequestStats;

/**
 * Per-model cache entry holding the model's state, function name, loaded predictor, the set of
 * worker node ids serving it, and bounded FIFO queues of recent inference / predict-request
 * durations used to build latency statistics.
 *
 * <p>The worker-node set and the duration queues are concurrent collections. The scalar fields
 * ({@code modelState}, {@code functionName}, {@code predictor}) are plain fields.
 * NOTE(review): if they are written and read from different threads, confirm that the callers
 * provide the needed happens-before ordering.
 */
public class MLModelCache {
    @Generated
    private static final Logger log = LogManager.getLogger(MLModelCache.class);
    private MLModelState modelState;
    private FunctionName functionName;
    private Predictable predictor;
    private final Set<String> workerNodes = ConcurrentHashMap.newKeySet();
    private final Queue<Double> modelInferenceDurationQueue = new ConcurrentLinkedQueue<Double>();
    private final Queue<Double> predictRequestDurationQueue = new ConcurrentLinkedQueue<Double>();

    /**
     * Removes a single worker node id from this model's worker set.
     *
     * @param nodeId id of the node to remove; absent ids are ignored
     */
    public void removeWorkerNode(String nodeId) {
        this.workerNodes.remove(nodeId);
    }

    /**
     * Removes every id in {@code removedNodes} from this model's worker set.
     *
     * @param removedNodes ids of nodes to remove
     */
    public void removeWorkerNodes(Set<String> removedNodes) {
        this.workerNodes.removeAll(removedNodes);
    }

    /**
     * Adds a worker node id to this model's worker set.
     *
     * @param nodeId id of the node now serving this model
     */
    public void addWorkerNode(String nodeId) {
        this.workerNodes.add(nodeId);
    }

    /**
     * Returns a point-in-time copy of the worker node ids.
     *
     * @return a new array of worker node ids (empty if none)
     */
    public String[] getWorkerNodes() {
        return this.workerNodes.toArray(new String[0]);
    }

    /**
     * Replaces the worker set so that it contains exactly {@code workerNodes}.
     *
     * <p>Implemented as retain-then-add rather than clear-then-add, so that nodes present in
     * both the old and new sets never disappear transiently for concurrent readers. Passing
     * {@code null} throws before the current set is modified.
     *
     * @param workerNodes the authoritative set of worker node ids
     */
    public void syncWorkerNode(Set<String> workerNodes) {
        this.workerNodes.retainAll(workerNodes);
        this.workerNodes.addAll(workerNodes);
    }

    /** Removes all worker node ids. */
    public void clearWorkerNodes() {
        this.workerNodes.clear();
    }

    /**
     * Resets this cache entry and releases the predictor.
     *
     * <p>Clears state, function name, worker nodes and both duration queues. If a predictor is
     * present it is detached first and then closed. This keeps the closed instance from being
     * returned by {@link #getPredictor()} and avoids closing it twice if {@code clear()} is
     * called again.
     */
    public void clear() {
        this.modelState = null;
        this.functionName = null;
        this.workerNodes.clear();
        this.modelInferenceDurationQueue.clear();
        this.predictRequestDurationQueue.clear();
        Predictable current = this.predictor;
        this.predictor = null;
        if (current != null) {
            current.close();
        }
    }

    /**
     * Records one model-inference duration sample.
     *
     * @param duration the measured duration
     * @param maxRequestCount maximum number of samples to retain; {@code <= 0} disables recording
     *     and empties the queue
     */
    public void addModelInferenceDuration(double duration, long maxRequestCount) {
        this.addInferenceDuration(duration, maxRequestCount, this.modelInferenceDurationQueue);
    }

    /**
     * Records one predict-request duration sample.
     *
     * @param duration the measured duration
     * @param maxRequestCount maximum number of samples to retain; {@code <= 0} disables recording
     *     and empties the queue
     */
    public void addPredictRequestDuration(double duration, long maxRequestCount) {
        this.addInferenceDuration(duration, maxRequestCount, this.predictRequestDurationQueue);
    }

    /**
     * Trims {@code queue} to make room for one more sample, then appends {@code duration} when
     * recording is enabled ({@code maxRequestCount > 0}).
     */
    private void addInferenceDuration(double duration, long maxRequestCount, Queue<Double> queue) {
        this.resizeInferenceQueue(maxRequestCount, queue);
        if (maxRequestCount > 0L) {
            queue.add(duration);
        }
    }

    /**
     * Applies a new retention limit to both duration queues.
     *
     * @param maxRequestCount new maximum number of samples; {@code <= 0} empties both queues
     */
    public void resizeMonitoringQueue(long maxRequestCount) {
        log.debug("resize inference duration monitoring queue with size {}", maxRequestCount);
        this.resizeInferenceQueue(maxRequestCount, this.predictRequestDurationQueue);
        this.resizeInferenceQueue(maxRequestCount, this.modelInferenceDurationQueue);
    }

    /**
     * Evicts the oldest samples until {@code queue} holds at most {@code maxRequestCount - 1}
     * entries, leaving room for one subsequent add. Clears the queue when
     * {@code maxRequestCount <= 0}.
     */
    private void resizeInferenceQueue(long maxRequestCount, Queue<Double> queue) {
        if (maxRequestCount <= 0L) {
            queue.clear();
            return;
        }
        // ConcurrentLinkedQueue.size() walks the whole queue, so compute the overflow once
        // instead of re-counting after every eviction (which would be quadratic).
        long excess = queue.size() - maxRequestCount + 1L;
        for (long i = 0L; i < excess; i++) {
            if (queue.poll() == null) {
                break; // drained concurrently; nothing left to evict
            }
        }
    }

    /**
     * Builds latency statistics (count, min, max, average, p50/p90/p99) over the retained samples.
     *
     * <p>All statistics are computed from a single snapshot of the queue. They are therefore
     * mutually consistent even while other threads add or evict samples, and a queue that
     * empties concurrently yields {@code null} instead of a failure from the quantile
     * computation.
     *
     * @param modelInference {@code true} for model-inference durations, {@code false} for
     *     predict-request durations
     * @return the statistics, or {@code null} if no samples are recorded
     */
    public MLPredictRequestStats getInferenceStats(boolean modelInference) {
        Queue<Double> queue = modelInference ? this.modelInferenceDurationQueue : this.predictRequestDurationQueue;
        double[] snapshot = queue.stream().mapToDouble(Double::doubleValue).toArray();
        if (snapshot.length == 0) {
            return null;
        }
        DoubleSummaryStatistics summary = DoubleStream.of(snapshot).summaryStatistics();
        Quantiles.Scale percentiles = Quantiles.percentiles();
        MLPredictRequestStats.MLPredictRequestStatsBuilder statsBuilder = MLPredictRequestStats.builder();
        statsBuilder.count(summary.getCount());
        statsBuilder.max(summary.getMax());
        statsBuilder.min(summary.getMin());
        statsBuilder.average(summary.getAverage());
        // compute(double...) copies its input, so the snapshot can be reused for each percentile.
        statsBuilder.p50(percentiles.index(50).compute(snapshot));
        statsBuilder.p90(percentiles.index(90).compute(snapshot));
        statsBuilder.p99(percentiles.index(99).compute(snapshot));
        return statsBuilder.build();
    }

    /**
     * Reports whether this entry still carries meaningful data.
     *
     * @return {@code true} if a model state is set or at least one worker node is registered
     */
    public boolean isValidCache() {
        return this.modelState != null || !this.workerNodes.isEmpty();
    }

    @Generated
    protected void setModelState(MLModelState modelState) {
        this.modelState = modelState;
    }

    @Generated
    protected MLModelState getModelState() {
        return this.modelState;
    }

    @Generated
    protected void setFunctionName(FunctionName functionName) {
        this.functionName = functionName;
    }

    @Generated
    protected FunctionName getFunctionName() {
        return this.functionName;
    }

    @Generated
    protected void setPredictor(Predictable predictor) {
        this.predictor = predictor;
    }

    @Generated
    protected Predictable getPredictor() {
        return this.predictor;
    }
}

