/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline;

import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Collection;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.neo4j.gds.core.utils.TimeUtil;
import org.neo4j.gds.ml.models.TrainerConfig;
import org.neo4j.gds.ml.models.TrainingMethod;
import org.neo4j.gds.ml.models.automl.TunableTrainerConfig;
import org.neo4j.gds.ml.pipeline.AutoTuningConfig;
import org.neo4j.gds.ml.pipeline.ExecutableNodePropertyStep;
import org.neo4j.gds.ml.pipeline.FeatureStep;
import org.neo4j.gds.ml.pipeline.Pipeline;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Base class for training pipelines.
 *
 * <p>A pipeline holds an ordered list of node property steps, an ordered list of feature steps,
 * and a parameter space of candidate trainer configurations grouped by {@link TrainingMethod}.
 * The set of methods that may receive candidates is fixed at construction time by the
 * {@link TrainingType} passed to the constructor.
 *
 * @param <FEATURE_STEP> the concrete feature step type used by the subclass
 */
public abstract class TrainingPipeline<FEATURE_STEP extends FeatureStep>
implements Pipeline<FEATURE_STEP> {
    protected final List<ExecutableNodePropertyStep> nodePropertySteps = new ArrayList<>();
    protected final List<FEATURE_STEP> featureSteps = new ArrayList<>();
    private final ZonedDateTime creationTime = TimeUtil.now();
    // Only methods supported by the pipeline's TrainingType get a key (see constructor);
    // lookups for any other method return null.
    protected Map<TrainingMethod, List<TunableTrainerConfig>> trainingParameterSpace = new EnumMap<>(TrainingMethod.class);
    protected AutoTuningConfig autoTuningConfig = AutoTuningConfig.DEFAULT_CONFIG;

    /**
     * Converts a parameter space into a plain map representation.
     *
     * @param parameterSpace candidate configurations grouped by training method
     * @return a map keyed by {@code TrainingMethod#toString()} whose values are the
     *     {@link TunableTrainerConfig#toMap()} representations of each candidate, in list order
     */
    public static Map<String, List<Map<String, Object>>> toMapParameterSpace(Map<TrainingMethod, List<TunableTrainerConfig>> parameterSpace) {
        return parameterSpace.entrySet().stream().collect(Collectors.toMap(
            entry -> entry.getKey().toString(),
            entry -> entry.getValue().stream().map(TunableTrainerConfig::toMap).collect(Collectors.toList())
        ));
    }

    /**
     * Creates an empty pipeline whose parameter space has one (initially empty) candidate list
     * per method supported by {@code trainingType}.
     *
     * @param trainingType determines which training methods may receive trainer configs
     */
    protected TrainingPipeline(TrainingType trainingType) {
        for (TrainingMethod method : trainingType.supportedMethods()) {
            this.trainingParameterSpace.put(method, new ArrayList<>());
        }
    }

    /**
     * Builds a map describing this pipeline.
     *
     * <p>Entries from {@link #additionalEntries()} are added last, so they replace any base entry
     * with the same key.
     *
     * @return a new mutable map with the keys {@code featurePipeline},
     *     {@code trainingParameterSpace} and {@code autoTuningConfig}, plus any additional entries
     */
    public Map<String, Object> toMap() {
        Map<String, Object> map = new HashMap<>();
        map.put("featurePipeline", this.featurePipelineDescription());
        map.put("trainingParameterSpace", TrainingPipeline.toMapParameterSpace(this.trainingParameterSpace));
        map.put("autoTuningConfig", this.autoTuningConfig().toMap());
        map.putAll(this.additionalEntries());
        return map;
    }

    /** Returns a name identifying the concrete pipeline type. */
    public abstract String type();

    /** Returns the description of the feature pipeline stored under {@code featurePipeline} in {@link #toMap()}. */
    protected abstract Map<String, List<Map<String, Object>>> featurePipelineDescription();

    /** Returns subclass-specific entries merged last into {@link #toMap()}. */
    protected abstract Map<String, Object> additionalEntries();

    /** Total number of candidate trainer configs across all training methods. */
    private int numberOfTrainerConfigs() {
        return this.trainingParameterSpace().values().stream().mapToInt(List::size).sum();
    }

    /**
     * Appends a node property step.
     *
     * @param step the step to add
     * @throws IllegalArgumentException if an existing step already uses the same mutate property
     */
    public void addNodePropertyStep(ExecutableNodePropertyStep step) {
        this.validateUniqueMutateProperty(step);
        this.nodePropertySteps.add(step);
    }

    /**
     * Appends a feature step.
     *
     * @param featureStep the step to add
     */
    public void addFeatureStep(FEATURE_STEP featureStep) {
        this.featureSteps.add(featureStep);
    }

    /** Returns the internal, mutable list of node property steps (not a copy). */
    @Override
    public List<ExecutableNodePropertyStep> nodePropertySteps() {
        return this.nodePropertySteps;
    }

    /** Returns the internal, mutable list of feature steps (not a copy). */
    @Override
    public List<FEATURE_STEP> featureSteps() {
        return this.featureSteps;
    }

    /** Returns the internal parameter space map (not a copy). */
    public Map<TrainingMethod, List<TunableTrainerConfig>> trainingParameterSpace() {
        return this.trainingParameterSpace;
    }

    /** Number of candidate trainer configs for which {@link TunableTrainerConfig#isConcrete()} is true. */
    private int concreteTrainerConfigsCount() {
        return (int) this.trainingParameterSpace().values().stream()
            .flatMap(Collection::stream)
            .filter(TunableTrainerConfig::isConcrete)
            .count();
    }

    /**
     * Returns the number of model selection trials this pipeline will run.
     *
     * <p>If every candidate config is concrete, the result is the number of candidates.
     * Otherwise it is {@code autoTuningConfig().maxTrials()} plus the number of concrete candidates.
     *
     * @return the number of trials; 0 when the parameter space is empty
     */
    public int numberOfModelSelectionTrials() {
        int concreteCount = this.concreteTrainerConfigsCount();
        int totalCount = this.numberOfTrainerConfigs();
        if (concreteCount == totalCount) {
            return totalCount;
        }
        return this.autoTuningConfig().maxTrials() + concreteCount;
    }

    /**
     * Adds a candidate configuration to the parameter space of its training method.
     *
     * @param trainingConfig the candidate to add
     * @throws IllegalArgumentException if the config's training method is not supported by this pipeline
     */
    public void addTrainerConfig(TunableTrainerConfig trainingConfig) {
        this.candidatesFor(trainingConfig.trainingMethod()).add(trainingConfig);
    }

    /**
     * Adds a concrete trainer configuration, converted via {@link TrainerConfig#toTunableConfig()},
     * to the parameter space of its training method.
     *
     * @param trainingConfig the candidate to add
     * @throws IllegalArgumentException if the config's training method is not supported by this pipeline
     */
    public void addTrainerConfig(TrainerConfig trainingConfig) {
        this.candidatesFor(trainingConfig.method()).add(trainingConfig.toTunableConfig());
    }

    /**
     * Looks up the candidate list for {@code method}, failing with a descriptive message instead
     * of returning null when the method has no entry in the parameter space.
     */
    private List<TunableTrainerConfig> candidatesFor(TrainingMethod method) {
        List<TunableTrainerConfig> candidates = this.trainingParameterSpace.get(method);
        if (candidates == null) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "Training method `%s` is not supported by this pipeline. Supported methods are %s.",
                method,
                this.trainingParameterSpace.keySet()
            ));
        }
        return candidates;
    }

    /** Returns the auto-tuning configuration; {@link AutoTuningConfig#DEFAULT_CONFIG} unless replaced. */
    public AutoTuningConfig autoTuningConfig() {
        return this.autoTuningConfig;
    }

    /**
     * Replaces the auto-tuning configuration.
     *
     * @param autoTuningConfig the new configuration
     */
    public void setAutoTuningConfig(AutoTuningConfig autoTuningConfig) {
        this.autoTuningConfig = autoTuningConfig;
    }

    /**
     * Checks that at least one model selection trial would run.
     *
     * @throws IllegalArgumentException if {@link #numberOfModelSelectionTrials()} is 0
     */
    public void validateTrainingParameterSpace() {
        if (this.numberOfModelSelectionTrials() == 0) {
            throw new IllegalArgumentException("Need at least one model candidate for training.");
        }
    }

    /**
     * Rejects {@code step} if an already-added node property step has the same mutate property.
     *
     * @throws IllegalArgumentException naming the duplicated property and the existing step's procedure
     */
    private void validateUniqueMutateProperty(ExecutableNodePropertyStep step) {
        String newMutatePropertyName = step.mutateNodeProperty();
        for (ExecutableNodePropertyStep existingStep : this.nodePropertySteps) {
            if (newMutatePropertyName.equals(existingStep.mutateNodeProperty())) {
                throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                    "The value of `%s` is expected to be unique, but %s was already specified in the %s procedure.",
                    "mutateProperty",
                    newMutatePropertyName,
                    existingStep.procName()
                ));
            }
        }
    }

    /** Returns the time at which this pipeline instance was created. */
    public ZonedDateTime creationTime() {
        return this.creationTime;
    }

    /** The kind of training a pipeline performs, which fixes the training methods it accepts. */
    protected enum TrainingType {
        CLASSIFICATION(List.of(TrainingMethod.LogisticRegression, TrainingMethod.RandomForestClassification, TrainingMethod.MLPClassification)),
        REGRESSION(List.of(TrainingMethod.LinearRegression, TrainingMethod.RandomForestRegression));

        private final List<TrainingMethod> supportedMethods;

        TrainingType(List<TrainingMethod> supportedMethods) {
            this.supportedMethods = supportedMethods;
        }

        /** Returns the immutable list of training methods accepted by this training type. */
        List<TrainingMethod> supportedMethods() {
            return this.supportedMethods;
        }
    }
}

