/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.sgd;

import com.google.common.base.Preconditions;
import java.util.Iterator;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.classifier.sgd.DefaultGradient;
import org.apache.mahout.classifier.sgd.Gradient;
import org.apache.mahout.classifier.sgd.PriorFunction;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;

/**
 * Shared implementation for online logistic regression learners trained one example at a time.
 *
 * <p>The model is a coefficient matrix {@link #beta}; training and regularization touch rows
 * {@code 0 .. numCategories - 2} and the columns of the non-zero features of each instance.
 * The prior is applied lazily: {@link #updateSteps} remembers the step at which each feature
 * was last brought up to date, and {@link #regularize(Vector)} applies the missed steps when a
 * feature is seen again. {@link #close()} applies everything still pending and seals the model.
 *
 * <p>Subclasses supply the learning-rate schedule through {@link #currentLearningRate()} and
 * {@link #perTermLearningRate(int)}. This class does not itself allocate {@code beta},
 * {@code updateSteps}, {@code updateCounts} or {@code prior}; they must be set before training.
 */
public abstract class AbstractOnlineLogisticRegression
        extends AbstractVectorClassifier
        implements OnlineLearner {
    /** Coefficient matrix, indexed as (category row, feature column). */
    protected Matrix beta;
    /** Number of target categories; one more than the number of rows updated in {@link #beta}. */
    protected int numCategories;
    /** Training step counter, advanced once per training example and once when sealing. */
    protected int step;
    /** For each feature, the step at which its coefficients were last updated or regularized. */
    protected Vector updateSteps;
    /** For each feature, how many training instances had it as a non-zero entry. */
    protected Vector updateCounts;
    /** Multiplier applied to the learning rate when ageing coefficients with the prior. */
    private double lambda = 1.0E-5;
    /** Prior used by {@link #regularize(Vector)} to age coefficients. */
    protected PriorFunction prior;
    /** True once {@link #close()} has run and no training has happened since. */
    private boolean sealed;
    /** Strategy that produces the per-category gradient for a training example. */
    private Gradient gradient = new DefaultGradient();

    /**
     * Sets the regularization weight.
     *
     * @param lambda new value returned by {@link #getLambda()}
     * @return this learner, to allow chained configuration
     */
    public AbstractOnlineLogisticRegression lambda(double lambda) {
        this.lambda = lambda;
        return this;
    }

    /**
     * Applies the multinomial logistic link to a vector of linear scores, with an implicit
     * additional category whose score is zero.
     *
     * <p>Note that {@code v} is overwritten in place while computing the result.
     *
     * @param v linear scores; modified by this call
     * @return the result of dividing the exponentiated scores by their normalizer
     */
    public Vector link(Vector v) {
        double max = v.maxValue();
        if (max >= 40.0) {
            // Shift by the maximum before exponentiating so exp() cannot overflow. With a maximum
            // this large the implicit "+1" term is below double round-off and is dropped.
            v.assign(Functions.minus(max)).assign(Functions.EXP);
            return v.divide(v.norm(1.0));
        }
        v.assign(Functions.EXP);
        return v.divide(1.0 + v.norm(1.0));
    }

    /**
     * Computes the logistic function {@code 1 / (1 + exp(-r))}.
     *
     * <p>The two branches are algebraically equal; each one only ever exponentiates a
     * non-positive number, so {@code Math.exp} cannot overflow.
     *
     * @param r linear score
     * @return value of the logistic function at {@code r}
     */
    public double link(double r) {
        if (r < 0.0) {
            double s = Math.exp(r);
            return s / (1.0 + s);
        }
        double s = Math.exp(-r);
        return 1.0 / (1.0 + s);
    }

    /**
     * Returns the raw linear scores {@code beta * instance}, after applying any pending
     * regularization to the coefficients of the instance's non-zero features.
     *
     * @param instance feature vector
     * @return product of the coefficient matrix and the instance
     */
    @Override
    public Vector classifyNoLink(Vector instance) {
        regularize(instance);
        return beta.times(instance);
    }

    /**
     * Returns the dot product of the first coefficient row with the instance. Unlike
     * {@link #classifyNoLink(Vector)} this does not trigger regularization.
     *
     * @param instance feature vector
     * @return linear score for row 0 of {@link #beta}
     */
    public double classifyScalarNoLink(Vector instance) {
        return beta.getRow(0).dot(instance);
    }

    /**
     * Returns the linked scores for an instance.
     *
     * @param instance feature vector
     * @return {@link #link(Vector)} applied to {@link #classifyNoLink(Vector)}
     */
    @Override
    public Vector classify(Vector instance) {
        return link(classifyNoLink(instance));
    }

    /**
     * Returns a single linked score for a two-category model.
     *
     * @param instance feature vector
     * @return logistic function of the row-0 linear score
     * @throws IllegalArgumentException if the model does not have exactly two categories
     */
    @Override
    public double classifyScalar(Vector instance) {
        Preconditions.checkArgument(numCategories() == 2, "Can only call classifyScalar with two categories");
        // Bring the coefficients this instance touches up to date before scoring.
        regularize(instance);
        return link(classifyScalarNoLink(instance));
    }

    /**
     * Performs one training step on a single example.
     *
     * <p>The model is unsealed, pending regularization is applied for the instance's features,
     * the gradient is obtained from the configured {@link Gradient}, and each coefficient of a
     * non-zero feature is moved along it. Finally the per-feature bookkeeping is refreshed and
     * the step counter advanced.
     *
     * @param trackingKey not used by this implementation
     * @param groupKey passed through to the gradient; may be {@code null}
     * @param actual target category of the example
     * @param instance feature vector of the example
     */
    @Override
    public void train(long trackingKey, String groupKey, int actual, Vector instance) {
        unseal();
        double learningRate = currentLearningRate();

        // Apply the prior for the steps these features missed before moving their coefficients.
        regularize(instance);

        Vector grad = gradient.apply(groupKey, actual, instance, this);
        int rows = numCategories - 1;
        for (int i = 0; i < rows; i++) {
            double rowScale = grad.get(i) * learningRate;
            Iterator<Vector.Element> nonZeros = instance.iterateNonZero();
            while (nonZeros.hasNext()) {
                int j = nonZeros.next().index();
                double newValue = beta.getQuick(i, j) + rowScale * perTermLearningRate(j) * instance.get(j);
                beta.setQuick(i, j, newValue);
            }
        }

        // Record that these features are current as of this step and count the occurrence.
        Iterator<Vector.Element> touched = instance.iterateNonZero();
        while (touched.hasNext()) {
            int j = touched.next().index();
            updateSteps.setQuick(j, getStep());
            updateCounts.setQuick(j, updateCounts.getQuick(j) + 1.0);
        }
        nextStep();
    }

    /**
     * Trains on one example without a group key.
     *
     * @param trackingKey forwarded unchanged
     * @param actual target category of the example
     * @param instance feature vector of the example
     */
    @Override
    public void train(long trackingKey, int actual, Vector instance) {
        train(trackingKey, null, actual, instance);
    }

    /**
     * Trains on one example with tracking key {@code 0} and no group key.
     *
     * @param actual target category of the example
     * @param instance feature vector of the example
     */
    @Override
    public void train(int actual, Vector instance) {
        train(0L, null, actual, instance);
    }

    /**
     * Lazily applies the prior to the coefficients of every non-zero feature of
     * {@code instance}, covering the steps elapsed since that feature was last brought up to
     * date. Does nothing if {@link #updateSteps} has not been set or the model is sealed.
     *
     * @param instance vector whose non-zero positions select the features to regularize
     */
    public void regularize(Vector instance) {
        if (updateSteps == null || isSealed()) {
            return;
        }
        int rows = numCategories - 1;
        if (rows <= 0) {
            // No coefficient rows to age; leave the bookkeeping untouched.
            return;
        }

        double learningRate = currentLearningRate();
        int currentStep = getStep();

        Iterator<Vector.Element> nonZeros = instance.iterateNonZero();
        while (nonZeros.hasNext()) {
            int j = nonZeros.next().index();
            double missingUpdates = currentStep - updateSteps.get(j);
            if (!(missingUpdates > 0.0)) {
                continue;
            }
            double rate = getLambda() * learningRate * perTermLearningRate(j);
            // updateSteps is per feature, not per row, so every row must be aged before the
            // step is recorded; recording it inside the row loop would leave all rows after
            // the first permanently un-regularized.
            for (int i = 0; i < rows; i++) {
                beta.set(i, j, prior.age(beta.get(i, j), missingUpdates, rate));
            }
            updateSteps.set(j, currentStep);
        }
    }

    /**
     * Returns the learning-rate factor specific to feature {@code var1}.
     *
     * @param var1 feature index
     * @return multiplier applied to updates of that feature's coefficients
     */
    public abstract double perTermLearningRate(int var1);

    /**
     * Returns the global learning rate to use for the current step.
     *
     * @return current learning rate
     */
    public abstract double currentLearningRate();

    /**
     * Sets the prior used for regularization.
     *
     * @param prior new prior
     */
    public void setPrior(PriorFunction prior) {
        this.prior = prior;
    }

    /**
     * Replaces the gradient strategy used by {@code train}.
     *
     * @param gradient new gradient strategy
     */
    public void setGradient(Gradient gradient) {
        this.gradient = gradient;
    }

    /**
     * Returns the prior used for regularization.
     *
     * @return current prior, or {@code null} if none has been set
     */
    public PriorFunction getPrior() {
        return prior;
    }

    /**
     * Seals the model via {@link #close()} and returns the internal coefficient matrix.
     *
     * @return the live coefficient matrix (not a copy)
     */
    public Matrix getBeta() {
        close();
        return beta;
    }

    /**
     * Sets a single coefficient.
     *
     * @param i category row
     * @param j feature column
     * @param betaIJ new coefficient value
     */
    public void setBeta(int i, int j, double betaIJ) {
        beta.set(i, j, betaIJ);
    }

    /**
     * Returns the number of target categories.
     *
     * @return number of categories
     */
    @Override
    public int numCategories() {
        return numCategories;
    }

    /**
     * Returns the number of features, taken from the column count of {@link #beta}.
     *
     * @return number of columns in the coefficient matrix
     */
    public int numFeatures() {
        return beta.numCols();
    }

    /**
     * Returns the regularization weight.
     *
     * @return current lambda
     */
    public double getLambda() {
        return lambda;
    }

    /**
     * Returns the current training step.
     *
     * @return step counter
     */
    public int getStep() {
        return step;
    }

    /** Advances the training step counter by one. */
    protected void nextStep() {
        step++;
    }

    /**
     * Reports whether the model is sealed.
     *
     * @return {@code true} if {@link #close()} has run since the last training call
     */
    public boolean isSealed() {
        return sealed;
    }

    /** Marks the model as unsealed so that regularization is applied again. */
    protected void unseal() {
        sealed = false;
    }

    /** Regularizes every feature by passing an all-ones vector spanning all columns of beta. */
    private void regularizeAll() {
        Vector all = new DenseVector(beta.numCols());
        all.assign(1.0);
        regularize(all);
    }

    /**
     * Seals the model: advances the step once, applies all pending regularization to every
     * feature, and sets the sealed flag. Calling it on an already sealed model has no effect.
     */
    @Override
    public void close() {
        if (sealed) {
            return;
        }
        step++;
        regularizeAll();
        sealed = true;
    }

    /**
     * Copies coefficients, step counter and per-feature bookkeeping from another model into
     * this one. The prior, gradient, lambda and sealed flag are not copied.
     *
     * @param other model to copy from
     * @throws IllegalArgumentException if the two models have different numbers of categories
     */
    public void copyFrom(AbstractOnlineLogisticRegression other) {
        Preconditions.checkArgument(numCategories == other.numCategories, "Can't copy unless number of target categories is the same");
        beta.assign(other.beta);
        step = other.step;
        updateSteps.assign(other.updateSteps);
        updateCounts.assign(other.updateCounts);
    }

    /**
     * Checks that no coefficient is NaN or infinite.
     *
     * @return {@code true} if every entry of {@link #beta} is a finite number
     */
    public boolean validModel() {
        // Map each coefficient to 1 if it is NaN/infinite and 0 otherwise, then sum.
        double badCount = beta.aggregate(Functions.PLUS, new DoubleFunction() {
            @Override
            public double apply(double v) {
                return Double.isNaN(v) || Double.isInfinite(v) ? 1.0 : 0.0;
            }
        });
        return badCount < 1.0;
    }
}

