/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine;

import java.nio.file.Path;
import java.util.Map;
import lombok.Generated;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.Output;
import org.opensearch.ml.engine.Executable;
import org.opensearch.ml.engine.MLEngineClassLoader;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.engine.TrainAndPredictable;
import org.opensearch.ml.engine.Trainable;

/**
 * Entry point of the ML engine.
 *
 * <p>It resolves the on-disk locations used to cache DJL models (upload, load and model caches),
 * and dispatches train / predict / train-and-predict / execute requests to the implementation that
 * {@link MLEngineClassLoader#initInstance} creates for the requested algorithm or function name.
 */
public class MLEngine {
    /** Root DJL directory: {@code <opensearchDataFolder>/djl}. */
    private final Path djlCachePath;
    /** Model cache directory: {@code <djlCachePath>/models_cache}. */
    private final Path djlModelsCachePath;

    /**
     * Creates an engine whose DJL caches live under the given data folder.
     *
     * @param opensearchDataFolder base data directory; the DJL cache is its {@code djl} sub-directory
     */
    public MLEngine(Path opensearchDataFolder) {
        this.djlCachePath = opensearchDataFolder.resolve("djl");
        this.djlModelsCachePath = this.djlCachePath.resolve("models_cache");
    }

    /**
     * Returns the upload directory for one version of a model.
     *
     * @param modelId model identifier
     * @param modelName model name, used as the last path segment
     * @param version model version, used as the segment between the model id and the name
     * @return {@code <upload root>/<modelId>/<version>/<modelName>}
     */
    public Path getUploadModelPath(String modelId, String modelName, String version) {
        return this.getUploadModelPath(modelId).resolve(version).resolve(modelName);
    }

    /**
     * Returns the upload directory for a model.
     *
     * @param modelId model identifier
     * @return {@code <upload root>/<modelId>}
     */
    public Path getUploadModelPath(String modelId) {
        return this.getUploadModelRootPath().resolve(modelId);
    }

    /**
     * Returns the root directory for uploaded models.
     *
     * @return {@code <djl>/models_cache/upload}
     */
    public Path getUploadModelRootPath() {
        return this.djlModelsCachePath.resolve("upload");
    }

    /**
     * Returns the load directory for a model.
     *
     * @param modelId model identifier
     * @return {@code <load root>/<modelId>}
     */
    public Path getLoadModelPath(String modelId) {
        return this.getLoadModelRootPath().resolve(modelId);
    }

    /**
     * Returns the path of the zip file for a model inside its load directory.
     *
     * @param modelId model identifier
     * @param modelName model name; {@code .zip} is appended to form the file name
     * @return {@code <load root>/<modelId>/<modelName>.zip}, as a string
     */
    public String getLoadModelZipPath(String modelId, String modelName) {
        // Delegate to getLoadModelPath so the "load" layout is defined in one place only.
        return this.getLoadModelPath(modelId).resolve(modelName) + ".zip";
    }

    /**
     * Returns the root directory for loaded models.
     *
     * @return {@code <djl>/models_cache/load}
     */
    public Path getLoadModelRootPath() {
        return this.djlModelsCachePath.resolve("load");
    }

    /**
     * Returns the path of one chunk file for a model inside its load directory.
     *
     * @param modelId model identifier
     * @param chunkNumber chunk index, used as the file name (a {@code null} value yields {@code "null"})
     * @return {@code <load root>/<modelId>/chunks/<chunkNumber>}
     */
    public Path getLoadModelChunkPath(String modelId, Integer chunkNumber) {
        // String.valueOf(Object) keeps the original null-safe "" + chunkNumber semantics.
        return this.getLoadModelPath(modelId).resolve("chunks").resolve(String.valueOf(chunkNumber));
    }

    /**
     * Returns the cache directory for one version of a model.
     *
     * @param modelId model identifier
     * @param modelName model name, used as the last path segment
     * @param version model version, used as the segment between the model id and the name
     * @return {@code <model cache root>/<modelId>/<version>/<modelName>}
     */
    public Path getModelCachePath(String modelId, String modelName, String version) {
        return this.getModelCachePath(modelId).resolve(version).resolve(modelName);
    }

    /**
     * Returns the cache directory for a model.
     *
     * @param modelId model identifier
     * @return {@code <model cache root>/<modelId>}
     */
    public Path getModelCachePath(String modelId) {
        return this.getModelCacheRootPath().resolve(modelId);
    }

    /**
     * Returns the root directory of the model cache.
     *
     * @return {@code <djl>/models_cache/models}
     */
    public Path getModelCacheRootPath() {
        return this.djlModelsCachePath.resolve("models");
    }

    /**
     * Trains a model with the algorithm named in the input.
     *
     * @param input must be a non-null {@link MLInput} with a function name and a non-empty dataset
     * @return the trained model
     * @throws IllegalArgumentException if the input is invalid, or no {@link Trainable} could be
     *     instantiated for the algorithm
     */
    public MLModel train(Input input) {
        this.validateMLInput(input);
        MLInput mlInput = (MLInput) input;
        Trainable trainable = (Trainable) MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class);
        if (trainable == null) {
            throw new IllegalArgumentException("Unsupported algorithm: " + mlInput.getAlgorithm());
        }
        return trainable.train(mlInput);
    }

    /**
     * Instantiates the predictor for a model's algorithm and initializes it with the model.
     *
     * @param mlModel model to load; must not be {@code null}
     * @param params extra initialization parameters, passed through to {@code initModel}
     * @return the initialized predictor
     * @throws IllegalArgumentException if {@code mlModel} is {@code null}, or no {@link Predictable}
     *     could be instantiated for its algorithm (previously surfaced as a NullPointerException)
     */
    public Predictable load(MLModel mlModel, Map<String, Object> params) {
        if (mlModel == null) {
            throw new IllegalArgumentException("ML model should not be null");
        }
        Predictable predictable = (Predictable) MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
        // Same contract as train/predict/trainAndPredict: a null instance means the algorithm is unsupported.
        if (predictable == null) {
            throw new IllegalArgumentException("Unsupported algorithm: " + mlModel.getAlgorithm());
        }
        predictable.initModel(mlModel, params);
        return predictable;
    }

    /**
     * Runs prediction with the algorithm named in the input.
     *
     * @param input must be a non-null {@link MLInput} with a function name and a non-empty dataset
     * @param model model passed to the predictor
     * @return prediction output
     * @throws IllegalArgumentException if the input is invalid, or no {@link Predictable} could be
     *     instantiated for the algorithm
     */
    public MLOutput predict(Input input, MLModel model) {
        this.validateMLInput(input);
        MLInput mlInput = (MLInput) input;
        Predictable predictable = (Predictable) MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class);
        if (predictable == null) {
            throw new IllegalArgumentException("Unsupported algorithm: " + mlInput.getAlgorithm());
        }
        return predictable.predict(mlInput, model);
    }

    /**
     * Trains and predicts in one step with the algorithm named in the input.
     *
     * @param input must be a non-null {@link MLInput} with a function name and a non-empty dataset
     * @return prediction output
     * @throws IllegalArgumentException if the input is invalid, or no {@link TrainAndPredictable}
     *     could be instantiated for the algorithm
     */
    public MLOutput trainAndPredict(Input input) {
        this.validateMLInput(input);
        MLInput mlInput = (MLInput) input;
        TrainAndPredictable trainAndPredictable = (TrainAndPredictable) MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class);
        if (trainAndPredictable == null) {
            throw new IllegalArgumentException("Unsupported algorithm: " + mlInput.getAlgorithm());
        }
        return trainAndPredictable.trainAndPredict(mlInput);
    }

    /**
     * Executes a non-training function named by the input's function name.
     *
     * @param input non-null input with a function name; it is also passed to the instance constructor
     * @return execution output
     * @throws IllegalArgumentException if the input is invalid, or no {@link Executable} could be
     *     instantiated for the function name
     */
    public Output execute(Input input) {
        this.validateInput(input);
        Executable executable = (Executable) MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class);
        if (executable == null) {
            throw new IllegalArgumentException("Unsupported executable function: " + input.getFunctionName());
        }
        return executable.execute(input);
    }

    /**
     * Validates input for the ML algorithm entry points.
     *
     * <p>It checks the basic input validity, that the input is an {@link MLInput}, that it carries
     * a dataset, and, for data-frame datasets, that the frame is non-null and non-empty.
     *
     * @throws IllegalArgumentException on the first failed check
     */
    private void validateMLInput(Input input) {
        this.validateInput(input);
        if (!(input instanceof MLInput)) {
            throw new IllegalArgumentException("Input should be MLInput");
        }
        MLInput mlInput = (MLInput) input;
        MLInputDataset inputDataset = mlInput.getInputDataset();
        if (inputDataset == null) {
            throw new IllegalArgumentException("Input data set should not be null");
        }
        if (inputDataset instanceof DataFrameInputDataset) {
            DataFrame dataFrame = ((DataFrameInputDataset) inputDataset).getDataFrame();
            if (dataFrame == null || dataFrame.size() == 0) {
                throw new IllegalArgumentException("Input data frame should not be null or empty");
            }
        }
    }

    /**
     * Checks that the input and its function name are non-null.
     *
     * @throws IllegalArgumentException if either is {@code null}
     */
    private void validateInput(Input input) {
        if (input == null) {
            throw new IllegalArgumentException("Input should not be null");
        }
        if (input.getFunctionName() == null) {
            throw new IllegalArgumentException("Function name should not be null");
        }
    }

    /**
     * Returns the root DJL cache directory.
     *
     * @return {@code <opensearchDataFolder>/djl}
     */
    @Generated
    public Path getDjlCachePath() {
        return this.djlCachePath;
    }
}

