/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine;

import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import ai.djl.util.Progress;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.utils.FileUtils;

/**
 * Helper for fetching a packaged model archive, validating its contents, splitting it into
 * chunks, and cleaning up the on-disk locations that {@link MLEngine} associates with a model.
 */
public class ModelHelper {
    @Generated
    private static final Logger log = LogManager.getLogger(ModelHelper.class);

    /** Result-map key whose value is the list returned by {@code FileUtils.splitFileIntoChunks}. */
    public static final String CHUNK_FILES = "chunk_files";
    /** Result-map key whose value is the length, in bytes, of the downloaded zip file. */
    public static final String MODEL_SIZE_IN_BYTES = "model_size_in_bytes";
    /** Result-map key whose value is the result of {@code FileUtils.calculateFileHash} on the zip. */
    public static final String MODEL_FILE_HASH = "model_file_hash";
    /** Chunk size passed to {@code FileUtils.splitFileIntoChunks}. */
    public static final int CHUNK_SIZE = 10000000;
    public static final String PYTORCH_FILE_EXTENSION = ".pt";
    public static final String ONNX_FILE_EXTENSION = ".onnx";
    public static final String TOKENIZER_FILE_NAME = "tokenizer.json";
    public static final String PYTORCH_ENGINE = "PyTorch";
    public static final String ONNX_ENGINE = "OnnxRuntime";

    private final MLEngine mlEngine;

    /**
     * @param mlEngine engine used to resolve the upload, load and cache paths of a model
     */
    public ModelHelper(MLEngine mlEngine) {
        this.mlEngine = mlEngine;
    }

    /**
     * Downloads the model zip from {@code url}, verifies its contents, splits it into chunks of
     * {@link #CHUNK_SIZE} and reports the outcome through {@code listener}.
     *
     * <p>On success the listener receives a map with the keys {@link #CHUNK_FILES},
     * {@link #MODEL_SIZE_IN_BYTES} and {@link #MODEL_FILE_HASH}. The downloaded zip is removed
     * before the listener is notified, and is also removed when any step fails.
     *
     * <p>Any exception escaping the privileged block is forwarded unchanged to
     * {@link ActionListener#onFailure}. Because the work runs as a
     * {@link PrivilegedExceptionAction}, checked exceptions (for example {@link IOException})
     * arrive wrapped in a {@code java.security.PrivilegedActionException}, while unchecked ones
     * (for example the {@link IllegalArgumentException} raised by verification) arrive as-is.
     *
     * @param modelId   id of the model, used to resolve the upload path
     * @param modelName name of the model, used to resolve the upload path
     * @param version   version of the model, used to resolve the upload path
     * @param url       location to download the model zip from
     * @param listener  callback receiving either the result map or the failure
     */
    public void downloadAndSplit(String modelId, String modelName, String version, String url, ActionListener<Map<String, Object>> listener) {
        try {
            // The explicit cast is required: an uncast zero-arg lambda is ambiguous between the
            // PrivilegedAction and PrivilegedExceptionAction overloads, and the body throws
            // checked exceptions, which only PrivilegedExceptionAction permits.
            AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
                Path modelUploadPath = this.mlEngine.getUploadModelPath(modelId, modelName, version);
                String modelPath = modelUploadPath + ".zip";
                Path modelPartsPath = modelUploadPath.resolve("chunks");
                File modelZipFile = new File(modelPath);
                Map<String, Object> result = new HashMap<>();
                try {
                    log.debug("download model to file {}", modelZipFile.getAbsolutePath());
                    DownloadUtils.download(url, modelPath, new ProgressBar());
                    this.verifyModelZipFile(modelPath);
                    List<String> chunkFiles = FileUtils.splitFileIntoChunks(modelZipFile, modelPartsPath, CHUNK_SIZE);
                    result.put(CHUNK_FILES, chunkFiles);
                    result.put(MODEL_SIZE_IN_BYTES, modelZipFile.length());
                    result.put(MODEL_FILE_HASH, FileUtils.calculateFileHash(modelZipFile));
                } finally {
                    // Remove the temporary zip on failure as well as on success so a rejected or
                    // partially downloaded archive is not left behind.
                    // NOTE(review): assumes deleteFileQuietly tolerates a missing file, as its use
                    // in deleteFileCache suggests - confirm against FileUtils.
                    FileUtils.deleteFileQuietly(modelZipFile);
                }
                listener.onResponse(result);
                return null;
            });
        }
        catch (Exception e) {
            listener.onFailure(e);
        }
    }

    /**
     * Checks that the zip contains exactly one model file (name ending in
     * {@link #PYTORCH_FILE_EXTENSION} or {@link #ONNX_FILE_EXTENSION}) and an entry named exactly
     * {@link #TOKENIZER_FILE_NAME}.
     *
     * @param modelZipFilePath path of the zip file to inspect
     * @throws IOException              if the zip file cannot be opened or read
     * @throws IllegalArgumentException if more than one model file is found, or if the model
     *                                  file or the tokenizer file is missing
     */
    private void verifyModelZipFile(String modelZipFilePath) throws IOException {
        boolean hasModelFile = false;
        boolean hasTokenizerFile = false;
        try (ZipFile zipFile = new ZipFile(modelZipFilePath)) {
            Enumeration<? extends ZipEntry> zipEntries = zipFile.entries();
            while (zipEntries.hasMoreElements()) {
                String fileName = zipEntries.nextElement().getName();
                if (fileName.endsWith(PYTORCH_FILE_EXTENSION) || fileName.endsWith(ONNX_FILE_EXTENSION)) {
                    if (hasModelFile) {
                        throw new IllegalArgumentException("Find multiple model files, but expected only one");
                    }
                    hasModelFile = true;
                }
                // Exact match on the entry name: a tokenizer nested in a sub-directory of the
                // archive does not count.
                if (TOKENIZER_FILE_NAME.equals(fileName)) {
                    hasTokenizerFile = true;
                }
            }
        }
        if (!hasModelFile) {
            throw new IllegalArgumentException("Can't find model file");
        }
        if (!hasTokenizerFile) {
            throw new IllegalArgumentException("Can't find tokenizer file");
        }
    }

    /**
     * Deletes the cache, load and upload paths that the engine resolves for {@code modelId}.
     *
     * @param modelId id of the model whose on-disk files should be removed
     */
    public void deleteFileCache(String modelId) {
        FileUtils.deleteFileQuietly(this.mlEngine.getModelCachePath(modelId));
        FileUtils.deleteFileQuietly(this.mlEngine.getLoadModelPath(modelId));
        FileUtils.deleteFileQuietly(this.mlEngine.getUploadModelPath(modelId));
    }
}

