/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.models.automl;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.neo4j.gds.ml.api.TrainingMethod;
import org.neo4j.gds.ml.models.TrainerConfig;
import org.neo4j.gds.ml.models.automl.ParameterParser;
import org.neo4j.gds.ml.models.automl.hyperparameter.ConcreteParameter;
import org.neo4j.gds.ml.models.automl.hyperparameter.DoubleRangeParameter;
import org.neo4j.gds.ml.models.automl.hyperparameter.IntegerRangeParameter;
import org.neo4j.gds.ml.models.automl.hyperparameter.NumericalRangeParameter;
import org.neo4j.gds.ml.models.linearregression.LinearRegressionTrainConfig;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainConfig;
import org.neo4j.gds.ml.models.mlp.MLPClassifierTrainConfig;
import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierTrainerConfig;
import org.neo4j.gds.ml.models.randomforest.RandomForestRegressorTrainerConfig;
import org.neo4j.gds.utils.StringFormatting;

/**
 * A trainer configuration whose parameters are either concrete values or numeric ranges
 * (double or integer) that can later be fixed to concrete values via {@link #materialize(Map)}.
 */
public final class TunableTrainerConfig {
    /** Offset applied inward to double-range endpoints when building corner-case configs. */
    static final double EPSILON = 1.0E-8;
    static final List<String> LOG_SCALE_PARAMETERS = List.of("penalty", "learningRate", "tolerance");
    // NOTE(review): raw Class kept for compatibility with package-level users of this map.
    static final Map<String, Class> NON_NUMERIC_PARAMETERS = Map.of("criterion", String.class, "hiddenLayerSizes", List.class, "classWeights", List.class);
    private final Map<String, ConcreteParameter<?>> concreteParameters;
    public final Map<String, DoubleRangeParameter> doubleRanges;
    public final Map<String, IntegerRangeParameter> integerRanges;
    private final TrainingMethod method;

    private TunableTrainerConfig(Map<String, ConcreteParameter<?>> concreteParameters, Map<String, DoubleRangeParameter> doubleRanges, Map<String, IntegerRangeParameter> integerRanges, TrainingMethod method) {
        this.concreteParameters = concreteParameters;
        this.doubleRanges = doubleRanges;
        this.integerRanges = integerRanges;
        this.method = method;
    }

    /**
     * Builds a tunable config from raw user input for the given training method.
     *
     * <p>Range parameters are parsed from {@code userInput}. Concrete parameters are parsed from
     * {@code userInput} merged over the method's default trainer config. All corner-case configs
     * are then materialized eagerly, presumably so that invalid parameter combinations are
     * rejected here rather than later during training (TODO confirm the trainer configs validate
     * in {@code of}).
     *
     * @param userInput user-supplied parameters; values may be concrete or range specifications
     * @param method    the training method whose trainer config is being tuned
     * @return the parsed tunable config
     * @throws IllegalArgumentException if more than 20 range parameters are given
     */
    public static TunableTrainerConfig of(Map<String, Object> userInput, TrainingMethod method) {
        ParameterParser.RangeParameters rangeParameters = ParameterParser.parseRangeParameters(userInput);
        Map<String, Object> defaults = createTrainerConfigFromMap(Map.of(), method).toMap();
        Map<String, Object> inputWithDefaults = fillDefaults(userInput, defaults);
        Map<String, ConcreteParameter<?>> concreteParameters = ParameterParser.parseConcreteParameters(inputWithDefaults);
        TunableTrainerConfig tunableTrainerConfig = new TunableTrainerConfig(concreteParameters, rangeParameters.doubleRanges(), rangeParameters.integerRanges(), method);
        // Drain the stream purely for its side effect of materializing every corner case.
        tunableTrainerConfig.streamCornerCaseConfigs().forEach(unused -> {});
        return tunableTrainerConfig;
    }

    /**
     * Merges user input over defaults. User values win, and the {@code "methodName"} key is dropped.
     */
    private static Map<String, Object> fillDefaults(Map<String, Object> userInput, Map<String, Object> defaults) {
        return Stream.concat(defaults.keySet().stream(), userInput.keySet().stream())
            .distinct()
            .filter(key -> !key.equals("methodName"))
            .collect(Collectors.toMap(key -> key, key -> userInput.getOrDefault(key, defaults.get(key))));
    }

    /**
     * Creates a concrete trainer config from the concrete parameters plus the given
     * hyperparameter values. Given values override concrete parameters with the same key.
     *
     * @param hyperParameterValues concrete values for (some of) the tunable parameters
     * @return the trainer config for this config's training method
     */
    public TrainerConfig materialize(Map<String, Object> hyperParameterValues) {
        Map<String, Object> materializedMap = new HashMap<>();
        this.concreteParameters.forEach((key, value) -> materializedMap.put(key, value.value()));
        materializedMap.putAll(hyperParameterValues);
        return createTrainerConfigFromMap(materializedMap, this.method);
    }

    /**
     * Streams one materialized config per combination of range endpoints (2^k configs for k range
     * parameters). Double endpoints are nudged inward by {@link #EPSILON}; integer endpoints are
     * used as-is.
     *
     * @return a lazy stream of corner-case trainer configs
     * @throws IllegalArgumentException if there are more than 20 range parameters
     */
    public Stream<TrainerConfig> streamCornerCaseConfigs() {
        Map<String, NumericalRangeParameter<?>> rangeParameters = new HashMap<>();
        rangeParameters.putAll(this.doubleRanges);
        rangeParameters.putAll(this.integerRanges);
        int numberOfHyperParameters = rangeParameters.size();
        // Cap keeps the 2^k enumeration bounded and the bitset within int range.
        if (numberOfHyperParameters > 20) {
            throw new IllegalArgumentException("Currently at most 20 hyperparameters are supported");
        }
        return IntStream.range(0, 1 << numberOfHyperParameters).mapToObj(bitset -> {
            Map<String, Object> hyperParameterValues = new HashMap<>();
            int parameterIdx = 0;
            for (Map.Entry<String, NumericalRangeParameter<?>> entry : rangeParameters.entrySet()) {
                // Bit i of the bitset selects min (0) or max (1) for the i-th range parameter.
                boolean useMin = ((bitset >> parameterIdx) & 1) == 0;
                String key = entry.getKey();
                Number materializedValue = endPoint(useMin, this.doubleRanges.containsKey(key), entry.getValue());
                hyperParameterValues.put(key, materializedValue);
                parameterIdx++;
            }
            return materialize(hyperParameterValues);
        });
    }

    /**
     * Returns the min or max of a range. For double ranges the value is moved {@link #EPSILON}
     * towards the interior of the range.
     */
    private static Number endPoint(boolean useMin, boolean isDouble, NumericalRangeParameter<?> range) {
        if (isDouble) {
            return useMin ? (Double) range.min() + EPSILON : (Double) range.max() - EPSILON;
        }
        return useMin ? range.min() : range.max();
    }

    /**
     * Serializes this config: concrete values as-is, ranges via their own {@code toMap()}, plus
     * the training method under {@code "methodName"}.
     */
    public Map<String, Object> toMap() {
        Map<String, Object> result = new HashMap<>();
        this.concreteParameters.forEach((key, value) -> result.put(key, value.value()));
        this.doubleRanges.forEach((key, value) -> result.put(key, value.toMap()));
        this.integerRanges.forEach((key, value) -> result.put(key, value.toMap()));
        result.put("methodName", this.trainingMethod().toString());
        return result;
    }

    public TrainingMethod trainingMethod() {
        return this.method;
    }

    /** Returns {@code true} if this config has no double or integer range parameters. */
    public boolean isConcrete() {
        return this.doubleRanges.isEmpty() && this.integerRanges.isEmpty();
    }

    /**
     * Compares concrete parameters, range parameters and training method. Ranges are included so
     * that equality agrees with what {@link #toMap()} exposes.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        TunableTrainerConfig that = (TunableTrainerConfig) o;
        return Objects.equals(this.concreteParameters, that.concreteParameters)
            && Objects.equals(this.doubleRanges, that.doubleRanges)
            && Objects.equals(this.integerRanges, that.integerRanges)
            && this.method == that.method;
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.concreteParameters, this.doubleRanges, this.integerRanges, this.method);
    }

    /**
     * Dispatches to the trainer-config factory for the given method.
     *
     * @throws IllegalStateException if the method has no trainer config implementation
     */
    private static TrainerConfig createTrainerConfigFromMap(Map<String, Object> configMap, TrainingMethod method) {
        switch (method) {
            case LogisticRegression:
                return LogisticRegressionTrainConfig.of(configMap);
            case RandomForestClassification:
                return RandomForestClassifierTrainerConfig.of(configMap);
            case MLPClassification:
                return MLPClassifierTrainConfig.of(configMap);
            case LinearRegression:
                return LinearRegressionTrainConfig.of(configMap);
            case RandomForestRegression:
                return RandomForestRegressorTrainerConfig.of(configMap);
            default:
                break;
        }
        throw new IllegalStateException(StringFormatting.formatWithLocale("Method %s does not have a trainerConfig Implemented", method.name()));
    }
}

