/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.common;

import java.lang.reflect.Constructor;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ml.common.annotation.InputDataSet;
import org.opensearch.ml.common.annotation.MLAlgoOutput;
import org.opensearch.ml.common.annotation.MLAlgoParameter;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.common.parameter.MLOutputType;
import org.reflections.Reflections;
import org.reflections.scanners.Scanner;

/**
 * Maps ML-commons enum keys ({@link FunctionName}, {@link MLOutputType}, {@link MLInputDataType})
 * to the classes annotated for them, and instantiates those classes reflectively.
 *
 * <p>The mapping is built when this class is initialized, by scanning the
 * {@code org.opensearch.ml.common.parameter} package for {@link MLAlgoParameter} /
 * {@link MLAlgoOutput} classes and the {@code org.opensearch.ml.common.dataset} package for
 * {@link InputDataSet} classes.
 */
public class MLCommonsClassLoader {
    private static final Logger logger = LogManager.getLogger(MLCommonsClassLoader.class);

    /**
     * Enum key to annotated implementation class; populated by {@link #loadClassMapping()}.
     *
     * <p>NOTE(review): this is a plain HashMap. It is filled during class initialization, but
     * {@link #loadClassMapping()} is public, so a later re-scan would mutate it without any
     * synchronization against concurrent {@link #initInstance} readers.
     */
    private static final Map<Enum<?>, Class<?>> parameterClassMap = new HashMap<>();

    /**
     * Scans the known packages and registers every annotated class. Called from the static
     * initializer; calling it again re-registers classes, replacing entries with the same key.
     */
    public static void loadClassMapping() {
        loadMLAlgoParameterClassMapping();
        loadMLInputDataSetClassMapping();
    }

    /**
     * Registers {@link MLAlgoParameter} classes under each algorithm they declare, and
     * {@link MLAlgoOutput} classes under their output type.
     */
    private static void loadMLAlgoParameterClassMapping() {
        Reflections reflections = new Reflections("org.opensearch.ml.common.parameter", new Scanner[0]);

        for (Class<?> clazz : reflections.getTypesAnnotatedWith(MLAlgoParameter.class)) {
            // Without honorInherited, Reflections also returns subtypes of annotated classes;
            // those need not carry the annotation themselves, so getAnnotation may be null.
            MLAlgoParameter mlAlgoParameter = clazz.getAnnotation(MLAlgoParameter.class);
            if (mlAlgoParameter == null) {
                continue;
            }
            FunctionName[] algorithms = mlAlgoParameter.algorithms();
            if (algorithms == null) {
                continue;
            }
            for (FunctionName name : algorithms) {
                register(name, clazz);
            }
        }

        for (Class<?> clazz : reflections.getTypesAnnotatedWith(MLAlgoOutput.class)) {
            MLAlgoOutput mlAlgoOutput = clazz.getAnnotation(MLAlgoOutput.class);
            if (mlAlgoOutput == null) {
                continue;
            }
            MLOutputType mlOutputType = mlAlgoOutput.value();
            if (mlOutputType != null) {
                register(mlOutputType, clazz);
            }
        }
    }

    /** Registers {@link InputDataSet} classes under the input data type they declare. */
    private static void loadMLInputDataSetClassMapping() {
        Reflections reflections = new Reflections("org.opensearch.ml.common.dataset", new Scanner[0]);

        for (Class<?> clazz : reflections.getTypesAnnotatedWith(InputDataSet.class)) {
            // Same subtype caveat as in loadMLAlgoParameterClassMapping.
            InputDataSet inputDataSet = clazz.getAnnotation(InputDataSet.class);
            if (inputDataSet == null) {
                continue;
            }
            MLInputDataType value = inputDataSet.value();
            if (value != null) {
                register(value, clazz);
            }
        }
    }

    /**
     * Records {@code clazz} as the implementation for {@code key}. The latest registration wins.
     * A warning is logged when a different class had already claimed the same key, because
     * otherwise the winner depends silently on scan order.
     */
    private static void register(Enum<?> key, Class<?> clazz) {
        Class<?> previous = parameterClassMap.put(key, clazz);
        if (previous != null && previous != clazz) {
            logger.warn("Class {} replaces {} as the mapping for {}", clazz.getName(), previous.getName(), key);
        }
    }

    /**
     * Instantiates the class registered for {@code type} through its public constructor that
     * takes a single {@code constructorParamClass} argument.
     *
     * @param type the enum key the target class was registered under
     * @param in the value passed to the constructor
     * @param constructorParamClass the declared parameter type of the constructor to invoke
     * @return the new instance, or {@code null} if the constructor could not be found or invoked
     *     (the failure is logged)
     * @throws IllegalArgumentException if no class is registered for {@code type}
     * @throws MLException if the invoked constructor threw an {@link MLException}
     */
    public static <T extends Enum<T>, S, I> S initInstance(T type, I in, Class<?> constructorParamClass) {
        Class<?> clazz = parameterClassMap.get(type);
        if (clazz == null) {
            throw new IllegalArgumentException("Can't find class for type " + type);
        }
        try {
            Constructor<?> constructor = clazz.getConstructor(constructorParamClass);
            // S is chosen by the caller and cannot be checked here; a mismatch surfaces as a
            // ClassCastException at the call site.
            @SuppressWarnings("unchecked")
            S instance = (S) constructor.newInstance(in);
            return instance;
        } catch (Exception e) {
            // Exceptions thrown by the constructor arrive wrapped (InvocationTargetException);
            // propagate ML domain errors unchanged, log and swallow everything else.
            Throwable cause = e.getCause();
            if (cause instanceof MLException) {
                throw (MLException) cause;
            }
            logger.error("Failed to init instance for type {}", type, e);
            return null;
        }
    }

    static {
        try {
            // The explicit PrivilegedExceptionAction type is required: a bare lambda matches both
            // doPrivileged(PrivilegedAction) and doPrivileged(PrivilegedExceptionAction), which
            // javac rejects as ambiguous.
            // NOTE(review): AccessController is deprecated for removal since Java 17 (JEP 411).
            AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
                loadClassMapping();
                return null;
            });
        } catch (PrivilegedActionException e) {
            throw new RuntimeException("Can't load class mapping in ML commons", e);
        }
    }
}

