package org.campagnelab.dl.framework.training;

import it.unimi.dsi.logging.ProgressLogger;
import org.campagnelab.dl.framework.iterators.MDSHelper;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/campagnelab/dl/framework/training/SequentialTrainer.class */
public class SequentialTrainer implements Trainer {
    /**
     * Training stops early once more than this many consecutive minibatches produce a NaN score.
     */
    private static final int MAX_CONSECUTIVE_NAN_SCORES = 100;

    private static final Logger LOG = LoggerFactory.getLogger(SequentialTrainer.class);

    /** When true, {@code progressLogger.update()} is called after every minibatch. */
    private boolean logSpeed;

    /** Sum of the non-NaN minibatch scores observed during the most recent call to {@link #train}. */
    double score;

    /** Number of minibatch scores accumulated into {@link #score}. */
    int n;

    /**
     * Fits the graph on every MultiDataSet produced by the iterator, one minibatch at a time.
     *
     * <p>Each minibatch is attached via {@code MDSHelper.attach} before fitting and always
     * detached afterwards, even if fitting throws. Non-NaN scores are accumulated so that
     * {@link #getScore()} can report their mean. If the score is NaN for more than
     * {@value #MAX_CONSECUTIVE_NAN_SCORES} consecutive minibatches, an error is logged and
     * training stops early.
     *
     * @param computationGraph     the graph to train
     * @param multiDataSetIterator source of minibatches; consumed until exhausted or until
     *                             training aborts
     * @param progressLogger       updated once per minibatch when speed logging is enabled
     * @return the total number of examples (rows of the first feature array) used for training,
     *         including when training stops early because of repeated NaN scores
     */
    @Override // org.campagnelab.dl.framework.training.Trainer
    public int train(ComputationGraph computationGraph, MultiDataSetIterator multiDataSetIterator, ProgressLogger progressLogger) {
        int numExamplesUsed = 0;
        int consecutiveNanScores = 0;
        this.score = 0.0d;
        this.n = 0;
        while (multiDataSetIterator.hasNext()) {
            MultiDataSet multiDataSet = multiDataSetIterator.next();
            MDSHelper.attach(multiDataSet);
            try {
                computationGraph.fit(multiDataSet);
                double batchScore = computationGraph.score();
                if (Double.isNaN(batchScore)) {
                    consecutiveNanScores++;
                } else {
                    // A valid score resets the NaN streak and contributes to the running mean.
                    consecutiveNanScores = 0;
                    this.score += batchScore;
                    this.n++;
                }
                numExamplesUsed += multiDataSet.getFeatures(0).size(0);
            } finally {
                // Always release the minibatch, even if fit/score threw.
                multiDataSet.detach();
            }
            if (this.logSpeed) {
                progressLogger.update();
            }
            if (consecutiveNanScores > MAX_CONSECUTIVE_NAN_SCORES) {
                LOG.error("Nan score encountered too many consecutive times");
                // Report the examples actually consumed, consistent with the normal exit path.
                return numExamplesUsed;
            }
        }
        return numExamplesUsed;
    }

    /**
     * Enables or disables per-minibatch progress updates.
     *
     * @param z true to call {@code progressLogger.update()} after each minibatch
     */
    @Override // org.campagnelab.dl.framework.training.Trainer
    public void setLogSpeed(boolean z) {
        this.logSpeed = z;
    }

    /**
     * Returns the mean of the non-NaN minibatch scores from the most recent {@link #train} call.
     *
     * @return {@code score / n}; this is NaN when no non-NaN score was recorded ({@code n == 0})
     */
    @Override // org.campagnelab.dl.framework.training.Trainer
    public double getScore() {
        return this.score / this.n;
    }
}
