/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.gradientdescent;

import org.neo4j.gds.ml.gradientdescent.TrainingStopper;

/**
 * A {@link TrainingStopper} that stops training either after a fixed number of
 * registered epochs or after a streak of consecutive epochs that failed to
 * improve on the best loss seen so far.
 *
 * <p>An epoch counts as unproductive when its loss does not fall below the best
 * loss so far by more than {@code tolerance * |bestLoss|}. Any productive epoch
 * resets the streak to zero. Epochs registered before {@code minEpochs} epochs
 * have run never affect the streak, but they do update the best loss.
 */
class StreakStopper implements TrainingStopper {
    /** Number of initial epochs during which the streak is not evaluated. */
    private final int minEpochs;
    /** Length of the unproductive streak at which training is considered converged. */
    private final int patience;
    /** Number of registered epochs at which training terminates regardless of the streak. */
    private final int maxEpochs;
    /** Relative improvement (scaled by the magnitude of the best loss) an epoch must exceed. */
    private final double tolerance;
    /** Number of losses registered so far. */
    private int ranEpochs;
    /** Smallest loss registered so far; starts at {@link Double#MAX_VALUE}. */
    private double bestLoss;
    /** Current count of consecutive unproductive epochs. */
    private int unproductiveStreak;

    /**
     * Creates a stopper with no epochs run and no best loss recorded yet.
     *
     * @param minEpochs number of epochs that must have run before the streak is tracked
     * @param patience  streak length that triggers convergence
     * @param maxEpochs number of epochs after which the stopper terminates
     * @param tolerance relative improvement threshold applied to the best loss
     */
    StreakStopper(int minEpochs, int patience, int maxEpochs, double tolerance) {
        this.bestLoss = Double.MAX_VALUE;
        this.tolerance = tolerance;
        this.maxEpochs = maxEpochs;
        this.patience = patience;
        this.minEpochs = minEpochs;
    }

    /**
     * Records the loss of one more epoch and updates the unproductive streak.
     *
     * @param loss the loss measured for the epoch being registered
     * @throws IllegalStateException if the stopper has already terminated, whether
     *                               through convergence or by reaching the epoch limit
     */
    @Override
    public void registerLoss(double loss) {
        if (terminated()) {
            throw new IllegalStateException("Does not accept losses after convergence");
        }

        // The streak is only evaluated once the minimum number of epochs has run.
        if (ranEpochs >= minEpochs) {
            double improvementThreshold = -tolerance * Math.abs(bestLoss);
            double changeFromBest = loss - bestLoss;
            if (changeFromBest >= improvementThreshold) {
                unproductiveStreak++;
            } else {
                unproductiveStreak = 0;
            }
        }

        ranEpochs++;
        bestLoss = Math.min(bestLoss, loss);
    }

    /**
     * @return {@code true} once the epoch limit is reached or the unproductive
     *         streak has reached the configured patience
     */
    @Override
    public boolean terminated() {
        return patienceExhausted() || ranEpochs >= maxEpochs;
    }

    /**
     * @return {@code true} once the unproductive streak has reached the configured patience
     */
    @Override
    public boolean converged() {
        return patienceExhausted();
    }

    /** Tells whether the unproductive streak is at least as long as the patience. */
    private boolean patienceExhausted() {
        return unproductiveStreak >= patience;
    }
}

