/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.gradientdescent;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.progress.tasks.LogLevel;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.batch.Batch;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.optimizer.AdamOptimizer;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.TensorFunctions;
import org.neo4j.gds.ml.gradientdescent.GradientDescentConfig;
import org.neo4j.gds.ml.gradientdescent.Objective;
import org.neo4j.gds.ml.gradientdescent.TrainingStopper;
import org.neo4j.gds.termination.TerminationFlag;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Fits the weights of an {@link Objective} with mini-batch gradient descent.
 *
 * <p>An initial pass over the batches computes the starting loss and weight gradients without
 * touching the weights. Every following epoch first applies the previously averaged gradients
 * through an {@link AdamOptimizer}, then consumes a new {@link BatchQueue} in parallel to compute
 * the next loss and gradients. Training stops once the {@link TrainingStopper} created from the
 * {@link GradientDescentConfig} reports termination.
 */
public class Training {
    private final GradientDescentConfig config;
    private final ProgressTracker progressTracker;
    private final LogLevel messageLogLevel;
    private final long trainSize;
    private final TerminationFlag terminationFlag;

    /**
     * @param config          source of the learning rate, the stopping criteria and {@code maxEpochs}
     * @param progressTracker receives the initial, per-epoch and final loss messages
     * @param messageLogLevel log level used for every message sent to {@code progressTracker}
     * @param trainSize       forwarded to {@link Objective#loss} for every batch
     * @param terminationFlag checked before every epoch and handed to the batch queue while consuming
     */
    public Training(
        GradientDescentConfig config,
        ProgressTracker progressTracker,
        LogLevel messageLogLevel,
        long trainSize,
        TerminationFlag terminationFlag
    ) {
        this.config = config;
        this.progressTracker = progressTracker;
        this.messageLogLevel = messageLogLevel;
        this.trainSize = trainSize;
        this.terminationFlag = terminationFlag;
    }

    /**
     * Estimates training memory for an exact number of features.
     *
     * @param numberOfFeatures number of feature columns of each weight matrix
     * @param numberOfClasses  number of rows of each weight matrix
     * @return the estimation; delegates to {@link #memoryEstimation(MemoryRange, int)}
     */
    public static MemoryEstimation memoryEstimation(int numberOfFeatures, int numberOfClasses) {
        return memoryEstimation(MemoryRange.of((long) numberOfFeatures), numberOfClasses);
    }

    /**
     * Estimates training memory for a range of feature counts.
     *
     * <p>Accounts once for the optimizer state ("updater") and once per thread for the summed
     * weight gradients that every {@link ObjectiveUpdateConsumer} allocates.
     *
     * @param numberOfFeaturesRange possible numbers of feature columns
     * @param numberOfClasses       number of rows of each weight matrix
     * @return the memory estimation tree for training
     */
    public static MemoryEstimation memoryEstimation(MemoryRange numberOfFeaturesRange, int numberOfClasses) {
        return MemoryEstimations.builder(Training.class.getSimpleName())
            .add(MemoryEstimations.of(
                "updater",
                numberOfFeaturesRange.apply(features ->
                    AdamOptimizer.sizeInBytes(numberOfClasses, Math.toIntExact(features)))
            ))
            .perThread(
                "weight gradients",
                numberOfFeaturesRange.apply(features ->
                    Weights.sizeInBytes(numberOfClasses, Math.toIntExact(features)))
            )
            .build();
    }

    /**
     * Runs the training loop, updating the objective's weights in place via {@link AdamOptimizer}.
     *
     * @param objective     provides the weights to train and the per-batch loss
     * @param queueSupplier asked for a new batch queue for the initial pass and for every epoch
     * @param concurrency   number of consumers (and degree of parallelism) per pass
     */
    public void train(Objective<?> objective, Supplier<BatchQueue> queueSupplier, int concurrency) {
        AdamOptimizer updater = new AdamOptimizer(objective.weights(), config.learningRate());
        TrainingStopper stopper = TrainingStopper.defaultStopper(config);
        List<Double> losses = new ArrayList<>();

        // Initial pass: loss and gradients of the untouched weights; no update is applied yet.
        List<ObjectiveUpdateConsumer> consumers = executeBatches(concurrency, objective, queueSupplier.get());
        List<? extends Tensor<? extends Tensor<?>>> prevWeightGradients = avgWeightGradients(consumers);
        double initialLoss = avgLoss(consumers);
        progressTracker.logMessage(
            messageLogLevel,
            StringFormatting.formatWithLocale("Initial loss %s", initialLoss)
        );

        while (!stopper.terminated()) {
            // NOTE(review): presumably throws when termination was requested — confirm.
            terminationFlag.assertRunning();
            // Apply the gradients from the previous pass, then measure the updated weights.
            updater.update(prevWeightGradients);
            consumers = executeBatches(concurrency, objective, queueSupplier.get());
            prevWeightGradients = avgWeightGradients(consumers);
            double loss = avgLoss(consumers);
            losses.add(loss);
            stopper.registerLoss(loss);
            progressTracker.logMessage(
                messageLogLevel,
                StringFormatting.formatWithLocale("Epoch %d with loss %s", losses.size(), loss)
            );
        }

        boolean converged = stopper.converged();
        // Guard: if the stopper terminated before the first epoch, there is no epoch loss and
        // losses.get(-1) would throw IndexOutOfBoundsException. Fall back to the initial loss.
        double lastLoss = losses.isEmpty() ? initialLoss : losses.get(losses.size() - 1);
        progressTracker.logMessage(
            messageLogLevel,
            StringFormatting.formatWithLocale(
                "%s after %d out of %d epochs. Initial loss: %s, Last loss: %s.%s",
                converged ? "converged" : "terminated",
                losses.size(),
                config.maxEpochs(),
                initialLoss,
                lastLoss,
                converged ? "" : " Did not converge"
            )
        );
    }

    /**
     * Consumes every batch of {@code batches} with {@code concurrency} fresh consumers.
     *
     * @return the consumers, holding their accumulated loss sums and gradient sums
     */
    private List<ObjectiveUpdateConsumer> executeBatches(int concurrency, Objective<?> objective, BatchQueue batches) {
        List<ObjectiveUpdateConsumer> consumers = new ArrayList<>(concurrency);
        for (int i = 0; i < concurrency; i++) {
            consumers.add(new ObjectiveUpdateConsumer(objective, trainSize));
        }
        batches.parallelConsume(concurrency, consumers, terminationFlag);
        return consumers;
    }

    /**
     * Averages the per-consumer gradient sums over the total number of consumed batches.
     */
    private List<? extends Tensor<? extends Tensor<?>>> avgWeightGradients(List<ObjectiveUpdateConsumer> consumers) {
        List<List<? extends Tensor<?>>> localGradientSums = consumers.stream()
            .map(ObjectiveUpdateConsumer::summedWeightGradients)
            .collect(Collectors.toList());
        return TensorFunctions.averageTensors(localGradientSums, totalConsumedBatches(consumers));
    }

    /**
     * Mean loss over all batches consumed by {@code consumers}.
     *
     * <p>NOTE(review): yields NaN when no batch was consumed (0.0 / 0) — confirm callers never
     * train on an empty queue.
     */
    private double avgLoss(List<ObjectiveUpdateConsumer> consumers) {
        double totalLoss = consumers.stream().mapToDouble(ObjectiveUpdateConsumer::lossSum).sum();
        return totalLoss / (double) totalConsumedBatches(consumers);
    }

    /** Sum of the batch counts of all consumers. */
    private static int totalConsumedBatches(List<ObjectiveUpdateConsumer> consumers) {
        return consumers.stream().mapToInt(ObjectiveUpdateConsumer::consumedBatches).sum();
    }

    /**
     * Accumulates loss and weight gradients over the batches it is fed.
     *
     * <p>The accumulators are plain, unsynchronized fields. NOTE(review): this presumably relies on
     * {@link BatchQueue#parallelConsume} driving each consumer from a single thread — confirm.
     */
    static class ObjectiveUpdateConsumer implements Consumer<Batch> {
        private final Objective<?> objective;
        private final long trainSize;
        // One zero-initialised tensor per weight, shaped like that weight's data.
        private final List<? extends Tensor<?>> summedWeightGradients;
        private double lossSum;
        private int consumedBatches;

        ObjectiveUpdateConsumer(Objective<?> objective, long trainSize) {
            this.objective = objective;
            this.trainSize = trainSize;
            this.summedWeightGradients = objective.weights()
                .stream()
                .map(weight -> weight.data().createWithSameDimensions())
                .collect(Collectors.toList());
            this.consumedBatches = 0;
            this.lossSum = 0.0;
        }

        /**
         * Runs forward and backward passes for one batch, then adds its loss and weight
         * gradients to the running sums.
         */
        @Override
        public void accept(Batch batch) {
            Variable<Scalar> loss = objective.loss(batch, trainSize);
            ComputationContext ctx = new ComputationContext();
            lossSum += ctx.forward(loss).value();
            ctx.backward(loss);

            List<? extends Tensor<?>> localWeightGradient = objective.weights()
                .stream()
                .map(ctx::gradient)
                .collect(Collectors.toList());
            for (int i = 0; i < summedWeightGradients.size(); i++) {
                summedWeightGradients.get(i).addInPlace(localWeightGradient.get(i));
            }
            consumedBatches++;
        }

        /** @return the running per-weight gradient sums (live, not a copy) */
        List<? extends Tensor<?>> summedWeightGradients() {
            return summedWeightGradients;
        }

        /** @return number of batches accepted so far */
        int consumedBatches() {
            return consumedBatches;
        }

        /** @return sum of the per-batch losses accepted so far */
        double lossSum() {
            return lossSum;
        }
    }
}

