/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.models.automl;

import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.SplittableRandom;
import java.util.stream.Collectors;
import org.neo4j.gds.ml.api.TrainingMethod;
import org.neo4j.gds.ml.models.TrainerConfig;
import org.neo4j.gds.ml.models.automl.HyperParameterOptimizer;
import org.neo4j.gds.ml.models.automl.TunableTrainerConfig;
import org.neo4j.gds.ml.models.automl.hyperparameter.DoubleRangeParameter;
import org.neo4j.gds.ml.models.automl.hyperparameter.IntegerRangeParameter;

/**
 * {@link HyperParameterOptimizer} that performs a random search over a parameter space.
 *
 * <p>Every concrete (fully specified) config is returned first, in the order in which it
 * was found in the parameter space. Afterwards, as long as at least one tunable config
 * exists, up to {@code maxTrials} further configs are produced by picking a tunable
 * config uniformly at random and sampling a value for each of its double and integer ranges.
 */
public class RandomSearch implements HyperParameterOptimizer {
    /** Smallest lower bound used when sampling on a log scale; keeps {@code Math.log} finite. */
    private static final double LOG_SCALE_MIN = 1.0E-20;

    private final List<TunableTrainerConfig> concreteConfigs;
    private final List<TunableTrainerConfig> tunableConfigs;
    private final int totalNumberOfTrials;
    private final int numberOfConcreteTrials;
    private final SplittableRandom random;
    private int numberOfFinishedTrials;

    /**
     * Creates a random search with a fixed seed.
     *
     * @param parameterSpace candidate configs grouped by training method
     * @param maxTrials      number of randomly sampled trials to run in addition to the concrete configs
     * @param randomSeed     seed for the random number generator
     */
    public RandomSearch(Map<TrainingMethod, List<TunableTrainerConfig>> parameterSpace, int maxTrials, long randomSeed) {
        this(parameterSpace, maxTrials, Optional.of(randomSeed));
    }

    /**
     * Creates a random search.
     *
     * @param parameterSpace candidate configs grouped by training method
     * @param maxTrials      number of randomly sampled trials to run in addition to the concrete configs
     * @param randomSeed     seed for the random number generator; when empty, an unseeded
     *                       {@link SplittableRandom} is used
     */
    public RandomSearch(Map<TrainingMethod, List<TunableTrainerConfig>> parameterSpace, int maxTrials, Optional<Long> randomSeed) {
        // Single pass over the parameter space; both partitions keep encounter order.
        Map<Boolean, List<TunableTrainerConfig>> configsByConcreteness = parameterSpace.values()
            .stream()
            .flatMap(Collection::stream)
            .collect(Collectors.partitioningBy(TunableTrainerConfig::isConcrete));
        this.concreteConfigs = configsByConcreteness.get(true);
        this.tunableConfigs = configsByConcreteness.get(false);
        this.numberOfConcreteTrials = this.concreteConfigs.size();
        this.totalNumberOfTrials = maxTrials + this.numberOfConcreteTrials;
        this.random = randomSeed.map(SplittableRandom::new).orElseGet(SplittableRandom::new);
        this.numberOfFinishedTrials = 0;
    }

    /**
     * @return {@code true} while concrete configs remain, or while sampled trials remain
     *         and there is at least one tunable config to sample from
     */
    @Override
    public boolean hasNext() {
        boolean concreteTrialsRemain = this.numberOfFinishedTrials < this.numberOfConcreteTrials;
        boolean sampledTrialsRemain = this.numberOfFinishedTrials < this.totalNumberOfTrials
            && !this.tunableConfigs.isEmpty();
        return concreteTrialsRemain || sampledTrialsRemain;
    }

    /**
     * Returns the next config to try: the next concrete config if any remain, otherwise a
     * config sampled from a randomly chosen tunable config.
     *
     * @return the next trainer config
     * @throws IllegalStateException if {@link #hasNext()} is {@code false}
     */
    @Override
    public TrainerConfig next() {
        if (!this.hasNext()) {
            throw new IllegalStateException("RandomSearch has already exhausted the maximum trials or the parameter space.");
        }
        if (this.numberOfFinishedTrials < this.numberOfConcreteTrials) {
            // Concrete configs have nothing to sample, so they are materialized without overrides.
            TunableTrainerConfig concreteConfig = this.concreteConfigs.get(this.numberOfFinishedTrials);
            this.numberOfFinishedTrials++;
            return concreteConfig.materialize(Map.of());
        }
        this.numberOfFinishedTrials++;
        int chosenIndex = this.random.nextInt(this.tunableConfigs.size());
        return this.sample(this.tunableConfigs.get(chosenIndex));
    }

    /**
     * Draws one value for every double and integer range of {@code tunableConfig} and
     * materializes the config with those values.
     */
    private TrainerConfig sample(TunableTrainerConfig tunableConfig) {
        Map<String, Object> hyperParameterValues = new HashMap<>();
        // Double ranges are sampled before integer ranges; the order determines how the
        // random sequence is consumed and therefore the result for a given seed.
        tunableConfig.doubleRanges.forEach((name, range) -> hyperParameterValues.put((String)name, this.sampleDouble((DoubleRangeParameter)range)));
        tunableConfig.integerRanges.forEach((name, range) -> hyperParameterValues.put((String)name, this.sampleInteger((IntegerRangeParameter)range)));
        return tunableConfig.materialize(hyperParameterValues);
    }

    /**
     * Samples an integer uniformly from {@code [min, max)}; the upper bound is exclusive
     * because {@link SplittableRandom#nextInt(int, int)} excludes its bound.
     *
     * <p>A degenerate range with {@code min == max} yields that single value without
     * consuming randomness, instead of failing inside {@code nextInt}.
     *
     * @throws IllegalArgumentException if {@code min > max}
     */
    private int sampleInteger(IntegerRangeParameter range) {
        int lowerBound = (Integer)range.min();
        int upperBound = (Integer)range.max();
        if (lowerBound == upperBound) {
            return lowerBound;
        }
        return this.random.nextInt(lowerBound, upperBound);
    }

    /**
     * Samples a double from {@code [min, max)}, either uniformly or, for log-scale ranges,
     * uniformly in log space (lower bounds below {@code LOG_SCALE_MIN} are raised to it).
     *
     * <p>A degenerate range with {@code min == max} yields that single value without
     * consuming randomness, instead of failing inside {@code nextDouble}.
     *
     * @throws IllegalArgumentException if the (possibly log-transformed) lower bound is not
     *                                  smaller than the upper bound
     */
    private double sampleDouble(DoubleRangeParameter range) {
        double lowerBound = (Double)range.min();
        double upperBound = (Double)range.max();
        if (lowerBound == upperBound) {
            return lowerBound;
        }
        if (!range.logScale()) {
            return this.random.nextDouble(lowerBound, upperBound);
        }
        // Clamp before taking the logarithm so a zero or tiny minimum stays finite.
        double logLowerBound = lowerBound < LOG_SCALE_MIN ? Math.log(LOG_SCALE_MIN) : Math.log(lowerBound);
        double logUpperBound = Math.log(upperBound);
        return Math.exp(this.random.nextDouble(logLowerBound, logUpperBound));
    }
}

