/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.discriminative;

import org.apache.mahout.classifier.discriminative.LinearModel;
import org.apache.mahout.classifier.discriminative.TrainingException;
import org.apache.mahout.math.CardinalityException;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for trainers that fit a {@link LinearModel} by repeatedly passing over a labelled
 * dataset and letting a subclass adjust the model whenever a point is misclassified.
 *
 * <p>Subclasses supply the concrete adjustment rule via
 * {@link #update(double, Vector, LinearModel)}.
 */
public abstract class LinearTrainer {
    private static final Logger log = LoggerFactory.getLogger(LinearTrainer.class);

    /**
     * Pass limit for {@link #train}: training fails once more than this many full passes over the
     * data have completed without convergence.
     */
    private static final int MAX_ITERATIONS = 1000;

    /** The model being trained; its monitor is held while it is classified/updated in train. */
    private final LinearModel model;

    /**
     * Creates a trainer whose model starts with every weight set to {@code init}.
     *
     * @param dimension number of entries in the model's weight vector
     * @param threshold value passed through to the {@link LinearModel} constructor as its threshold
     * @param init initial value assigned to every weight
     * @param initBias initial bias passed through to the {@link LinearModel} constructor
     */
    protected LinearTrainer(int dimension, double threshold, double init, double initBias) {
        DenseVector initialWeights = new DenseVector(dimension);
        initialWeights.assign(init);
        this.model = new LinearModel(initialWeights, initBias, threshold);
    }

    /**
     * Trains the model until one complete pass over {@code dataset} produces no misclassification.
     *
     * <p>Column {@code i} of {@code dataset} is one data point and {@code labelset.get(i)} is its
     * label. A point counts as misclassified when its label is {@code <= 0} but the model predicts
     * {@code true}, or its label is {@code > 0} but the model predicts {@code false}. Each
     * misclassified point is handed to {@link #update}, and another pass is scheduled.
     *
     * @param labelset one label per column of {@code dataset}
     * @param dataset matrix whose columns are the training points
     * @throws CardinalityException if {@code labelset}'s size differs from the column count of
     *     {@code dataset}
     * @throws TrainingException if more than 1000 passes complete without converging
     */
    public void train(Vector labelset, Matrix dataset) throws TrainingException {
        // Loop-invariant: the number of data points (columns) does not change between passes.
        int columnCount = dataset.size()[1];
        if (labelset.size() != columnCount) {
            throw new CardinalityException(labelset.size(), columnCount);
        }
        boolean converged = false;
        int iteration = 0;
        while (!converged) {
            if (iteration > MAX_ITERATIONS) {
                throw new TrainingException("Too many iterations needed to find hyperplane.");
            }
            // Assume this pass converges; any misclassification below clears the flag.
            converged = true;
            for (int i = 0; i < columnCount; ++i) {
                Vector dataPoint = dataset.getColumn(i);
                // Parameterized form defers dataPoint.toString() until debug logging is enabled.
                log.debug("Training point: {}", dataPoint);
                // Classification and update happen atomically with respect to the model's monitor.
                synchronized (model) {
                    boolean prediction = model.classify(dataPoint);
                    double label = labelset.get(i);
                    // Written as two explicit comparisons (not `(label > 0) != prediction`) so a
                    // NaN label is never treated as misclassified, matching the original logic.
                    if ((label <= 0.0 && prediction) || (label > 0.0 && !prediction)) {
                        log.debug("updating");
                        converged = false;
                        update(label, dataPoint, model);
                    }
                }
            }
            ++iteration;
        }
    }

    /**
     * Returns the model owned by this trainer.
     *
     * <p>This is the same instance that {@link #train} passes to {@link #update}.
     *
     * @return the trainer's model
     */
    public LinearModel getModel() {
        return this.model;
    }

    /**
     * Adjusts {@code model} after {@code dataPoint} was misclassified during {@link #train}.
     *
     * <p>Invoked while the caller holds {@code model}'s monitor.
     *
     * @param label the label of the misclassified point, taken from the label vector
     * @param dataPoint the misclassified data point (one column of the training matrix)
     * @param model the model being trained
     */
    protected abstract void update(double label, Vector dataPoint, LinearModel model);
}

