package edu.stanford.nlp.classify;

import edu.stanford.nlp.classify.CrossValidator;
import edu.stanford.nlp.classify.LogPrior;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.CGMinimizer;
import edu.stanford.nlp.optimization.DiffFunction;
import edu.stanford.nlp.optimization.Evaluator;
import edu.stanford.nlp.optimization.GoldenSectionLineSearch;
import edu.stanford.nlp.optimization.HasEvaluators;
import edu.stanford.nlp.optimization.HybridMinimizer;
import edu.stanford.nlp.optimization.InefficientSGDMinimizer;
import edu.stanford.nlp.optimization.LineSearcher;
import edu.stanford.nlp.optimization.Minimizer;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.optimization.SGDMinimizer;
import edu.stanford.nlp.optimization.SGDToQNMinimizer;
import edu.stanford.nlp.optimization.SMDMinimizer;
import edu.stanford.nlp.optimization.SQNMinimizer;
import edu.stanford.nlp.optimization.StochasticCalculateMethods;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.MultiClassAccuracyStats;
import edu.stanford.nlp.stats.Scorer;
import edu.stanford.nlp.util.ArrayUtils;
import edu.stanford.nlp.util.Factory;
import edu.stanford.nlp.util.Function;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Timing;
import edu.stanford.nlp.util.Triple;
import java.io.BufferedReader;
import java.util.Arrays;
import java.util.List;

/* loaded from: input_file:edu/stanford/nlp/classify/LinearClassifierFactory.class */
public class LinearClassifierFactory<L, F> extends AbstractLinearClassifierFactory<L, F> {
    private static final long serialVersionUID = 7893768984379107397L;
    /** Convergence tolerance passed as the second argument of every {@code Minimizer.minimize} call in this class. */
    private double TOL;
    /** Memory argument given to the QNMinimizers built by the quasi-Newton / hybrid minimizer factories (read at create() time). */
    private int mem;
    /** The conjugate-gradient factory passes {@code !verbose} to CGMinimizer. */
    private boolean verbose;
    /** Prior handed to the supervised objectives; setSigma/setEpsilon mutate this instance in place. */
    private LogPrior logPrior;
    /** When set, trainWeights runs a held-out sigma search before the final training run. */
    private boolean tuneSigmaHeldOut;
    /** When set (and tuneSigmaHeldOut is not), trainWeights tunes sigma by cross-validation using {@code folds} folds. */
    private boolean tuneSigmaCV;
    /** Number of folds for cross-validated sigma tuning; only assigned by setTuneSigmaCV in the visible code. */
    private int folds;
    /** Lower end of the sigma interval searched by the default GoldenSectionLineSearch. */
    private double min;
    /** Upper end of the sigma interval searched by the default GoldenSectionLineSearch. */
    private double max;
    /** If true, the final run after held-out tuning starts from objective.initial() rather than the tuned weights. */
    private boolean retrainFromScratchAfterSigmaTuning;
    /** Source of the minimizer used for each training call (see getMinimizer). */
    private Factory<Minimizer<DiffFunction>> minimizerCreator;
    /** Forwarded, with {@code evaluators}, to minimizers implementing HasEvaluators; NOTE(review): no setter is visible in this part of the file. */
    private int evalIters;
    private Evaluator[] evaluators;
    /** Candidate sigma values; not referenced in the visible part of this file. */
    protected static final double[] sigmasToTry = {0.5d, 1.0d, 2.0d, 4.0d, 10.0d, 20.0d, 100.0d};
    /** Optional searcher used for held-out sigma tuning instead of the default GoldenSectionLineSearch. */
    private LineSearcher heldOutSearcher;

    /**
     * Minimizer factory handed to the constructors that take no minimizer: every {@link #create()}
     * call builds a new {@link QNMinimizer} with memory 15.
     */
    private static class Factory15 implements Factory<Minimizer<DiffFunction>> {
        private static final long serialVersionUID = 6215752553371189173L;

        /** Memory argument given to each QNMinimizer built here. */
        private static final int MEMORY = 15;

        private Factory15() {
        }

        @Override
        public Minimizer<DiffFunction> create() {
            Minimizer<DiffFunction> fresh = new QNMinimizer(MEMORY);
            return fresh;
        }
    }

    /**
     * Builds {@link LinearClassifier}s from flat weight vectors over fixed feature and label indices.
     * If an objective was supplied, its {@code to2D} is used to unflatten the weights; otherwise the
     * vector is reshaped to featureIndex.size() x labelIndex.size() with {@link ArrayUtils#to2D}.
     */
    public static class LinearClassifierCreator<L, F> implements ClassifierCreator, ProbabilisticClassifierCreator {
        LogConditionalObjectiveFunction objective;
        Index<F> featureIndex;
        Index<L> labelIndex;

        public LinearClassifierCreator(LogConditionalObjectiveFunction objective, Index<F> featureIndex, Index<L> labelIndex) {
            this(featureIndex, labelIndex);
            this.objective = objective;
        }

        public LinearClassifierCreator(Index<F> featureIndex, Index<L> labelIndex) {
            this.featureIndex = featureIndex;
            this.labelIndex = labelIndex;
        }

        /** Unflattens {@code weights} and wraps them in a LinearClassifier. */
        public LinearClassifier createLinearClassifier(double[] weights) {
            double[][] matrix;
            if (this.objective == null) {
                matrix = ArrayUtils.to2D(weights, this.featureIndex.size(), this.labelIndex.size());
            } else {
                matrix = this.objective.to2D(weights);
            }
            return new LinearClassifier(matrix, this.featureIndex, this.labelIndex);
        }

        @Override
        public Classifier createClassifier(double[] weights) {
            LinearClassifier built = createLinearClassifier(weights);
            return built;
        }

        @Override
        public ProbabilisticClassifier createProbabilisticClassifier(double[] weights) {
            LinearClassifier built = createLinearClassifier(weights);
            return built;
        }
    }

    /**
     * Held-out objective for sigma search. For each candidate sigma it trains on {@code trainSet}
     * (warm-started from the previous candidate's weights, with sigma tuning disabled), scores the
     * classifier on {@code devSet}, and returns the negated score so that a minimizing line search
     * maximizes the score. The last trained flat weight vector stays available in {@link #weights}.
     */
    public class NegativeScorer implements Function<Double, Double> {
        public double[] weights = null;
        GeneralDataset<L, F> trainSet;
        GeneralDataset<L, F> devSet;
        Scorer<L> scorer;
        Timing timer;

        public NegativeScorer(GeneralDataset<L, F> trainSet, GeneralDataset<L, F> devSet, Scorer<L> scorer, Timing timer) {
            this.trainSet = trainSet;
            this.devSet = devSet;
            this.scorer = scorer;
            this.timer = timer;
        }

        @Override
        public Double apply(Double sigma) {
            LinearClassifierFactory<L, F> outer = LinearClassifierFactory.this;
            outer.setSigma(sigma.doubleValue());
            double[][] trained = outer.trainWeights(this.trainSet, this.weights, true);
            this.weights = ArrayUtils.flatten(trained);
            LinearClassifier candidate = new LinearClassifier(trained, this.trainSet.featureIndex, this.trainSet.labelIndex);
            double devScore = this.scorer.score(candidate, this.devSet);
            System.err.print("##sigma = " + outer.getSigma() + " ");
            System.err.println("-> average Score: " + devScore);
            System.err.println("##time elapsed: " + this.timer.stop() + " milliseconds.");
            this.timer.restart();
            return -devScore;
        }
    }

    /**
     * Creates a factory that trains with {@link QNMinimizer} (memory 15), tolerance 1e-4 and a
     * prior built as {@code new LogPrior(QUADRATIC.ordinal(), 1.0, 0.0)}.
     */
    public LinearClassifierFactory() {
        this(new Factory15());
        // Already the delegated constructor's default; restated to make the QN memory explicit.
        this.mem = 15;
        this.useQuasiNewton();
    }

    /**
     * Creates a factory that always trains with the given minimizer instance, tolerance 1e-4 and a
     * quadratic prior (sigma 1.0, epsilon 0.0).
     */
    public LinearClassifierFactory(Minimizer<DiffFunction> sharedMinimizer) {
        this(sharedMinimizer, false);
    }

    /**
     * Creates a factory that obtains a minimizer from {@code minimizerFactory} for each training
     * call, with tolerance 1e-4 and a quadratic prior (sigma 1.0, epsilon 0.0).
     */
    public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> minimizerFactory) {
        this(minimizerFactory, false);
    }

    /**
     * Quasi-Newton factory (memory 15) with tolerance 1e-4 and a quadratic prior (sigma 1.0).
     *
     * @param useSum forwarded down the constructor chain; the primary constructor does not read it
     */
    public LinearClassifierFactory(boolean useSum) {
        this(new Factory15(), useSum);
        this.mem = 15;
        this.useQuasiNewton();
    }

    /**
     * Quasi-Newton factory (memory 15) with the given tolerance and a quadratic prior (sigma 1.0).
     *
     * @param tol convergence tolerance passed to the minimizer
     */
    public LinearClassifierFactory(double tol) {
        this(new Factory15(), tol, false);
        this.mem = 15;
        this.useQuasiNewton();
    }

    public LinearClassifierFactory(Minimizer<DiffFunction> minimizer, boolean z) {
        this(minimizer, 1.0E-4d, z);
    }

    /** Factory-backed constructor with tolerance 1e-4 and a quadratic prior (sigma 1.0). */
    public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> minimizerFactory, boolean useSum) {
        this(minimizerFactory, 1.0E-4d, useSum);
    }

    /** Fixed-minimizer factory with the given tolerance and a quadratic prior of sigma 1.0. */
    public LinearClassifierFactory(Minimizer<DiffFunction> sharedMinimizer, double tol, boolean useSum) {
        this(sharedMinimizer, tol, useSum, 1.0d);
    }

    /** Factory-backed constructor with the given tolerance and a quadratic prior of sigma 1.0. */
    public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> minimizerFactory, double tol, boolean useSum) {
        this(minimizerFactory, tol, useSum, 1.0d);
    }

    /**
     * Quasi-Newton factory (memory 15) with a quadratic prior of the given sigma.
     *
     * @param tol convergence tolerance
     * @param useSum not read by the primary constructor
     * @param sigma second argument of the LogPrior that is built
     */
    public LinearClassifierFactory(double tol, boolean useSum, double sigma) {
        this(new Factory15(), tol, useSum, sigma);
        this.mem = 15;
        this.useQuasiNewton();
    }

    /** Fixed-minimizer factory with a quadratic prior of the given sigma (epsilon 0.0). */
    public LinearClassifierFactory(Minimizer<DiffFunction> sharedMinimizer, double tol, boolean useSum, double sigma) {
        this(sharedMinimizer, tol, useSum, LogPrior.LogPriorType.QUADRATIC.ordinal(), sigma);
    }

    /** Factory-backed constructor with a quadratic prior of the given sigma (epsilon 0.0). */
    public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> minimizerFactory, double tol, boolean useSum, double sigma) {
        this(minimizerFactory, tol, useSum, LogPrior.LogPriorType.QUADRATIC.ordinal(), sigma);
    }

    /**
     * Fixed-minimizer factory with an explicit prior type and sigma; epsilon defaults to 0.0.
     *
     * @param prior ordinal of a {@link LogPrior.LogPriorType}
     */
    public LinearClassifierFactory(Minimizer<DiffFunction> sharedMinimizer, double tol, boolean useSum, int prior, double sigma) {
        this(sharedMinimizer, tol, useSum, prior, sigma, 0.0d);
    }

    /**
     * Factory-backed constructor with an explicit prior type and sigma; epsilon defaults to 0.0.
     *
     * @param prior ordinal of a {@link LogPrior.LogPriorType}
     */
    public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> minimizerFactory, double tol, boolean useSum, int prior, double sigma) {
        this(minimizerFactory, tol, useSum, prior, sigma, 0.0d);
    }

    /**
     * Quasi-Newton factory (memory 15) with a fully specified prior.
     *
     * @param prior ordinal of a {@link LogPrior.LogPriorType}
     * @param sigma second LogPrior constructor argument
     * @param epsilon third LogPrior constructor argument
     */
    public LinearClassifierFactory(double tol, boolean useSum, int prior, double sigma, double epsilon) {
        this(new Factory15(), tol, useSum, new LogPrior(prior, sigma, epsilon));
        this.mem = 15;
        this.useQuasiNewton();
    }

    /**
     * Quasi-Newton factory with a fully specified prior and an explicit QN memory size.
     *
     * @param tol convergence tolerance passed to the minimizer
     * @param useSum not read by the primary constructor (kept for signature compatibility)
     * @param prior ordinal of a {@link LogPrior.LogPriorType}
     * @param sigma second LogPrior constructor argument
     * @param epsilon third LogPrior constructor argument
     * @param mem memory argument given to every {@link QNMinimizer} this factory creates
     */
    public LinearClassifierFactory(double tol, boolean useSum, int prior, double sigma, double epsilon, int mem) {
        this(new Factory15(), tol, useSum, new LogPrior(prior, sigma, epsilon));
        // Previously this argument was ignored and the QN memory silently stayed at the default 15.
        // useQuasiNewton() reads this.mem when each minimizer is created, so assigning it here suffices.
        this.mem = mem;
        useQuasiNewton();
    }

    /** Fixed-minimizer factory whose prior is {@code new LogPrior(prior, sigma, epsilon)}. */
    public LinearClassifierFactory(Minimizer<DiffFunction> sharedMinimizer, double tol, boolean useSum, int prior, double sigma, double epsilon) {
        this(sharedMinimizer, tol, useSum, new LogPrior(prior, sigma, epsilon));
    }

    /** Factory-backed constructor whose prior is {@code new LogPrior(prior, sigma, epsilon)}. */
    public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> minimizerFactory, double tol, boolean useSum, int prior, double sigma, double epsilon) {
        this(minimizerFactory, tol, useSum, new LogPrior(prior, sigma, epsilon));
    }

    /**
     * Primary constructor for a fixed minimizer: every training call reuses {@code minimizer}.
     *
     * @param minimizer the minimizer instance returned by every create() call
     * @param tol convergence tolerance
     * @param useSum not read
     * @param logPrior prior used by the supervised objectives
     */
    public LinearClassifierFactory(final Minimizer<DiffFunction> minimizer, double tol, boolean useSum, LogPrior logPrior) {
        // Tuning and evaluation defaults.
        this.mem = 15;
        this.verbose = false;
        this.tuneSigmaHeldOut = false;
        this.tuneSigmaCV = false;
        this.min = 0.1d;
        this.max = 10.0d;
        this.retrainFromScratchAfterSigmaTuning = false;
        this.evalIters = -1;
        this.evaluators = null;
        this.heldOutSearcher = null;
        // The creator hands back the same caller-supplied instance each time.
        this.minimizerCreator = new Factory<Minimizer<DiffFunction>>() {
            private static final long serialVersionUID = -6439748445540743949L;

            @Override
            public Minimizer<DiffFunction> create() {
                return minimizer;
            }
        };
        this.TOL = tol;
        this.logPrior = logPrior;
    }

    /**
     * Primary constructor for a minimizer factory: each training call asks {@code factory} for a minimizer.
     *
     * @param factory source of minimizers
     * @param tol convergence tolerance
     * @param useSum not read
     * @param logPrior prior used by the supervised objectives
     */
    public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> factory, double tol, boolean useSum, LogPrior logPrior) {
        this.minimizerCreator = factory;
        this.TOL = tol;
        this.logPrior = logPrior;
        // Tuning and evaluation defaults.
        this.mem = 15;
        this.verbose = false;
        this.tuneSigmaHeldOut = false;
        this.tuneSigmaCV = false;
        this.min = 0.1d;
        this.max = 10.0d;
        this.retrainFromScratchAfterSigmaTuning = false;
        this.evalIters = -1;
        this.evaluators = null;
        this.heldOutSearcher = null;
    }

    /** Sets the convergence tolerance passed to the minimizer on later training calls. */
    public void setTol(double tol) {
        TOL = tol;
    }

    /** Replaces the prior used by subsequently built objectives. */
    public void setPrior(LogPrior prior) {
        logPrior = prior;
    }

    /** Sets the verbosity flag (the conjugate-gradient minimizer is built with {@code !verbose}). */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /** Installs the factory that supplies a minimizer for each training call. */
    public void setMinimizerCreator(Factory<Minimizer<DiffFunction>> creator) {
        minimizerCreator = creator;
    }

    /** Sets epsilon on the current prior instance (mutates it in place). */
    public void setEpsilon(double epsilon) {
        LogPrior prior = this.logPrior;
        prior.setEpsilon(epsilon);
    }

    /** Sets sigma on the current prior instance (mutates it in place). */
    public void setSigma(double sigma) {
        LogPrior prior = this.logPrior;
        prior.setSigma(sigma);
    }

    /** @return the sigma currently stored in the prior */
    public double getSigma() {
        LogPrior prior = this.logPrior;
        return prior.getSigma();
    }

    /**
     * Trains with {@link QNMinimizer}. The memory is read from {@code mem} whenever a minimizer is
     * created, so a later {@link #setMem} still takes effect.
     */
    public void useQuasiNewton() {
        this.minimizerCreator = new Factory<Minimizer<DiffFunction>>() {
            private static final long serialVersionUID = 9028306475652690036L;

            @Override
            public Minimizer<DiffFunction> create() {
                int memory = LinearClassifierFactory.this.mem;
                return new QNMinimizer(memory);
            }
        };
    }

    /**
     * Trains with {@link QNMinimizer}, passing {@code useRobust} as its second constructor argument.
     * The memory is read from {@code mem} at create() time.
     */
    public void useQuasiNewton(final boolean useRobust) {
        this.minimizerCreator = new Factory<Minimizer<DiffFunction>>() {
            private static final long serialVersionUID = -9108222058357693242L;

            @Override
            public Minimizer<DiffFunction> create() {
                int memory = LinearClassifierFactory.this.mem;
                return new QNMinimizer(memory, useRobust);
            }
        };
    }

    /**
     * Trains with {@link SQNMinimizer}; both arguments are forwarded to its constructor together with
     * the current {@code mem} and {@code false}.
     */
    public void useStochasticQN(final double initialSMDGain, final int stochasticBatchSize) {
        this.minimizerCreator = new Factory<Minimizer<DiffFunction>>() {
            private static final long serialVersionUID = -7760753348350678588L;

            @Override
            public Minimizer<DiffFunction> create() {
                int memory = LinearClassifierFactory.this.mem;
                return new SQNMinimizer(memory, initialSMDGain, stochasticBatchSize, false);
            }
        };
    }

    /** Stochastic meta descent with defaults (0.1, 15, ExternalFiniteDifference, 20). */
    public void useStochasticMetaDescent() {
        double initialSMDGain = 0.1d;
        int stochasticBatchSize = 15;
        int passes = 20;
        useStochasticMetaDescent(initialSMDGain, stochasticBatchSize, StochasticCalculateMethods.ExternalFiniteDifference, passes);
    }

    /** Trains with {@link SMDMinimizer}; all four arguments are forwarded to its constructor in order. */
    public void useStochasticMetaDescent(final double initialSMDGain, final int stochasticBatchSize, final StochasticCalculateMethods stochasticMethod, final int passes) {
        this.minimizerCreator = new Factory<Minimizer<DiffFunction>>() {
            private static final long serialVersionUID = 6860437108371914482L;

            @Override
            public Minimizer<DiffFunction> create() {
                SMDMinimizer smd = new SMDMinimizer(initialSMDGain, stochasticBatchSize, stochasticMethod, passes);
                return smd;
            }
        };
    }

    /** Stochastic gradient descent with defaults (0.1, 15). */
    public void useStochasticGradientDescent() {
        double gain = 0.1d;
        int stochasticBatchSize = 15;
        useStochasticGradientDescent(gain, stochasticBatchSize);
    }

    /** Trains with {@link InefficientSGDMinimizer}; both arguments are forwarded to its constructor. */
    public void useStochasticGradientDescent(final double gainSGD, final int stochasticBatchSize) {
        this.minimizerCreator = new Factory<Minimizer<DiffFunction>>() {
            private static final long serialVersionUID = 2564615420955196299L;

            @Override
            public Minimizer<DiffFunction> create() {
                InefficientSGDMinimizer sgd = new InefficientSGDMinimizer(gainSGD, stochasticBatchSize);
                return sgd;
            }
        };
    }

    /** In-place SGD with defaults (-1, -1, 1.0). */
    public void useInPlaceStochasticGradientDescent() {
        int sgdPasses = -1;
        int tuneSampleSize = -1;
        double sigma = 1.0d;
        useInPlaceStochasticGradientDescent(sgdPasses, tuneSampleSize, sigma);
    }

    /** Trains with {@link SGDMinimizer}, constructed as {@code new SGDMinimizer(sigma, sgdPasses, tuneSampleSize)}. */
    public void useInPlaceStochasticGradientDescent(final int sgdPasses, final int tuneSampleSize, final double sigma) {
        this.minimizerCreator = new Factory<Minimizer<DiffFunction>>() {
            private static final long serialVersionUID = -5319225231759162616L;

            @Override
            public Minimizer<DiffFunction> create() {
                SGDMinimizer sgd = new SGDMinimizer(sigma, sgdPasses, tuneSampleSize);
                return sgd;
            }
        };
    }

    /**
     * Trains with a {@link HybridMinimizer} that starts with an in-place {@link SGDMinimizer} and
     * continues with a {@link QNMinimizer} (memory {@code mem}). {@code sgdPasses} is also passed as
     * the hybrid's third argument.
     */
    public void useHybridMinimizerWithInPlaceSGD(final int sgdPasses, final int tuneSampleSize, final double sigma) {
        this.minimizerCreator = new Factory<Minimizer<DiffFunction>>() {
            private static final long serialVersionUID = -3042400543337763144L;

            @Override
            public Minimizer<DiffFunction> create() {
                SGDMinimizer first = new SGDMinimizer(sigma, sgdPasses, tuneSampleSize);
                QNMinimizer second = new QNMinimizer(LinearClassifierFactory.this.mem);
                return new HybridMinimizer(first, second, sgdPasses);
            }
        };
    }

    /** Trains with {@link SGDToQNMinimizer}; all seven arguments are forwarded to its constructor in order. */
    public void useStochasticGradientDescentToQuasiNewton(final double sgdGain, final int batchSize, final int sgdPasses, final int qnPasses, final int hessSamples, final int qnMem, final boolean outputToFile) {
        this.minimizerCreator = new Factory<Minimizer<DiffFunction>>() {
            private static final long serialVersionUID = 5823852936137599566L;

            @Override
            public Minimizer<DiffFunction> create() {
                SGDToQNMinimizer combined = new SGDToQNMinimizer(sgdGain, batchSize, sgdPasses, qnPasses, hessSamples, qnMem, outputToFile);
                return combined;
            }
        };
    }

    /** SMD-then-QN hybrid with defaults (0.1, 15, ExternalFiniteDifference, 0). */
    public void useHybridMinimizer() {
        double initialSMDGain = 0.1d;
        int stochasticBatchSize = 15;
        int cutoffIteration = 0;
        useHybridMinimizer(initialSMDGain, stochasticBatchSize, StochasticCalculateMethods.ExternalFiniteDifference, cutoffIteration);
    }

    /**
     * Trains with a {@link HybridMinimizer} made of an {@link SMDMinimizer} followed by a
     * {@link QNMinimizer} (memory {@code mem}). {@code cutoffIteration} is passed both to the SMD
     * minimizer and as the hybrid's third argument.
     */
    public void useHybridMinimizer(final double initialSMDGain, final int stochasticBatchSize, final StochasticCalculateMethods stochasticMethod, final int cutoffIteration) {
        this.minimizerCreator = new Factory<Minimizer<DiffFunction>>() {
            @Override
            public Minimizer<DiffFunction> create() {
                SMDMinimizer first = new SMDMinimizer(initialSMDGain, stochasticBatchSize, stochasticMethod, cutoffIteration);
                QNMinimizer second = new QNMinimizer(LinearClassifierFactory.this.mem);
                return new HybridMinimizer(first, second, cutoffIteration);
            }
        };
    }

    /**
     * Sets the QN memory. The QN-based factories read it at create() time, so it also affects an
     * already-selected quasi-Newton or hybrid minimizer on later training calls.
     */
    public void setMem(int memory) {
        mem = memory;
    }

    /** Sets {@code verbose} and then switches to the conjugate-gradient minimizer. */
    public void useConjugateGradientAscent(boolean verbose) {
        this.verbose = verbose;
        this.useConjugateGradientAscent();
    }

    /**
     * Trains with {@link CGMinimizer}, constructed with {@code !verbose} (read at create() time).
     */
    public void useConjugateGradientAscent() {
        this.minimizerCreator = new Factory<Minimizer<DiffFunction>>() {
            private static final long serialVersionUID = -561168861131879990L;

            @Override
            public Minimizer<DiffFunction> create() {
                boolean notVerbose = !LinearClassifierFactory.this.verbose;
                return new CGMinimizer(notVerbose);
            }
        };
    }

    /** No-op retained for API compatibility; the flag is not stored. */
    public void setUseSum(boolean useSum) {
        // intentionally empty
    }

    /**
     * Obtains a minimizer from the configured creator. If it implements {@link HasEvaluators}, the
     * factory's evalIters and evaluators are installed on it.
     */
    private Minimizer<DiffFunction> getMinimizer() {
        Minimizer<DiffFunction> minimizer = this.minimizerCreator.create();
        if (!(minimizer instanceof HasEvaluators)) {
            return minimizer;
        }
        HasEvaluators withEvaluators = (HasEvaluators) minimizer;
        withEvaluators.setEvaluators(this.evalIters, this.evaluators);
        return minimizer;
    }

    /**
     * Adapts existing weights to {@code adaptDataset} by optimizing an
     * AdaptedGaussianPriorObjectiveFunction whose prior is centered on the old weights.
     * The old weights are copied into a matrix sized for the adapted dataset's feature and label
     * indices; entries with no counterpart in {@code origWeights} stay 0.
     *
     * @param origWeights previously trained weights, indexed [feature][label]; NOTE(review): assumed to
     *                    have no more rows than the adapted dataset has features (extra rows throw)
     * @param adaptDataset data to adapt to
     * @return the adapted weights, reshaped by the objective's to2D
     */
    public double[][] adaptWeights(double[][] origWeights, GeneralDataset<L, F> adaptDataset) {
        Minimizer<DiffFunction> minimizer = getMinimizer();
        System.err.println("adaptWeights in LinearClassifierFactory. increase weight dim only");
        int numFeatures = adaptDataset.featureIndex.size();
        int numLabels = adaptDataset.labelIndex.size();
        double[][] newWeights = new double[numFeatures][numLabels];
        // Copy row contents, not row references. A shallow System.arraycopy of the outer array made
        // newWeights alias the caller's rows, and each copied row kept the old label width.
        for (int f = 0; f < origWeights.length; f++) {
            System.arraycopy(origWeights[f], 0, newWeights[f], 0, Math.min(origWeights[f].length, numLabels));
        }
        AdaptedGaussianPriorObjectiveFunction objective = new AdaptedGaussianPriorObjectiveFunction(adaptDataset, this.logPrior, newWeights);
        return objective.to2D(minimizer.minimize(objective, this.TOL, objective.initial()));
    }

    /** Trains weights from the objective's initial point, tuning sigma first if a tuning mode is enabled. */
    @Override
    public double[][] trainWeights(GeneralDataset<L, F> dataset) {
        double[] noInitialWeights = null;
        return trainWeights(dataset, noInitialWeights);
    }

    /** Trains weights from {@code initial} (or the default start when null), with sigma tuning allowed. */
    public double[][] trainWeights(GeneralDataset<L, F> dataset, double[] initial) {
        boolean skipSigmaTuning = false;
        return trainWeights(dataset, initial, skipSigmaTuning);
    }

    /** Same as the four-argument overload, using a minimizer from the configured creator. */
    public double[][] trainWeights(GeneralDataset<L, F> dataset, double[] initial, boolean skipSigmaTuning) {
        Minimizer<DiffFunction> useConfigured = null;
        return trainWeights(dataset, initial, skipSigmaTuning, useConfigured);
    }

    /**
     * Trains weights on {@code dataset}, optionally tuning sigma first.
     *
     * @param dataset training data; an {@link RVFDataset} has ensureRealValues() called on it first
     * @param initial starting point, or null. When null and held-out tuning produced weights, those
     *                are used unless retrainFromScratchAfterSigmaTuning is set; otherwise
     *                objective.initial() is used.
     * @param skipSigmaTuning when true, neither held-out nor cross-validation tuning runs
     * @param minimizer minimizer to use, or null to obtain one via getMinimizer()
     * @return optimized weights reshaped by the objective's to2D
     */
    public double[][] trainWeights(GeneralDataset<L, F> dataset, double[] initial, boolean skipSigmaTuning, Minimizer<DiffFunction> minimizer) {
        Minimizer<DiffFunction> chosen = (minimizer != null) ? minimizer : getMinimizer();
        if (dataset instanceof RVFDataset) {
            ((RVFDataset) dataset).ensureRealValues();
        }
        double[] tunedWeights = null;
        if (!skipSigmaTuning && this.tuneSigmaHeldOut) {
            tunedWeights = heldOutSetSigma(dataset);
        } else if (!skipSigmaTuning && this.tuneSigmaCV) {
            crossValidateSetSigma(dataset, this.folds);
        }
        LogConditionalObjectiveFunction objective = new LogConditionalObjectiveFunction(dataset, this.logPrior);
        double[] start = initial;
        if (start == null && !this.retrainFromScratchAfterSigmaTuning) {
            start = tunedWeights;
        }
        if (start == null) {
            start = objective.initial();
        }
        double[] optimum = chosen.minimize(objective, this.TOL, start);
        return objective.to2D(optimum);
    }

    /** Wraps {@link #trainWeightsSemiSup} weights in a LinearClassifier over {@code data}'s indices. */
    public Classifier<L, F> trainClassifierSemiSup(GeneralDataset<L, F> data, GeneralDataset<L, F> biasedData, double[][] confusionMatrix, double[] initial) {
        double[][] weights = trainWeightsSemiSup(data, biasedData, confusionMatrix, initial);
        return new LinearClassifier(weights, data.featureIndex(), data.labelIndex());
    }

    /**
     * Minimizes a SemiSupervisedLogConditionalObjectiveFunction that combines an unregularized
     * supervised objective on {@code data} with an unregularized biased objective on
     * {@code biasedData}; the factory's prior is applied by the combined objective.
     *
     * @param confusionMatrix forwarded to BiasedLogConditionalObjectiveFunction
     * @param initial starting point, or null to use the supervised objective's initial()
     */
    public double[][] trainWeightsSemiSup(GeneralDataset<L, F> data, GeneralDataset<L, F> biasedData, double[][] confusionMatrix, double[] initial) {
        Minimizer<DiffFunction> minimizer = getMinimizer();
        LogConditionalObjectiveFunction supervised = new LogConditionalObjectiveFunction(data, new LogPrior(LogPrior.LogPriorType.NULL));
        BiasedLogConditionalObjectiveFunction biased = new BiasedLogConditionalObjectiveFunction(biasedData, confusionMatrix, new LogPrior(LogPrior.LogPriorType.NULL));
        SemiSupervisedLogConditionalObjectiveFunction combined = new SemiSupervisedLogConditionalObjectiveFunction(supervised, biased, this.logPrior);
        double[] start = (initial != null) ? initial : supervised.initial();
        double[] optimum = minimizer.minimize(combined, this.TOL, start);
        return supervised.to2D(optimum);
    }

    /**
     * Semi-supervised training with a generalized-expectation term over {@code unlabeledDataList}
     * restricted to {@code geFeatures}.
     *
     * @param convexComboCoeff fourth argument of the SemiSupervisedLogConditionalObjectiveFunction
     */
    public LinearClassifier<L, F> trainSemiSupGE(GeneralDataset<L, F> labeledDataset, List<? extends Datum<L, F>> unlabeledDataList, List<F> geFeatures, double convexComboCoeff) {
        Minimizer<DiffFunction> minimizer = getMinimizer();
        LogConditionalObjectiveFunction supervised = new LogConditionalObjectiveFunction(labeledDataset, new LogPrior(LogPrior.LogPriorType.NULL));
        GeneralizedExpectationObjectiveFunction expectation = new GeneralizedExpectationObjectiveFunction(labeledDataset, unlabeledDataList, geFeatures);
        SemiSupervisedLogConditionalObjectiveFunction combined = new SemiSupervisedLogConditionalObjectiveFunction(supervised, expectation, null, convexComboCoeff);
        double[] optimum = minimizer.minimize(combined, this.TOL, supervised.initial());
        return new LinearClassifier<>(supervised.to2D(optimum), labeledDataset.featureIndex(), labeledDataset.labelIndex());
    }

    /** GE training with up to 10 features of precision >= 0.9 and coefficient 0.5. */
    public LinearClassifier<L, F> trainSemiSupGE(GeneralDataset<L, F> labeledDataset, List<? extends Datum<L, F>> unlabeledDataList) {
        List<F> geFeatures = getHighPrecisionFeatures(labeledDataset, 0.9d, 10);
        return trainSemiSupGE(labeledDataset, unlabeledDataList, geFeatures, 0.5d);
    }

    /** GE training with up to 10 features of precision >= 0.9 and the given coefficient. */
    public LinearClassifier<L, F> trainSemiSupGE(GeneralDataset<L, F> labeledDataset, List<? extends Datum<L, F>> unlabeledDataList, double convexComboCoeff) {
        List<F> geFeatures = getHighPrecisionFeatures(labeledDataset, 0.9d, 10);
        return trainSemiSupGE(labeledDataset, unlabeledDataList, geFeatures, convexComboCoeff);
    }

    /**
     * Selects features that are concentrated on a single label in {@code dataset}.
     * For each feature, precision = (count with its most frequent label) / (total count). Features
     * with precision >= {@code minPrecision} are ranked by total count, and at most
     * {@code maxNumFeatures} of them are kept.
     *
     * @param dataset labeled training data
     * @param minPrecision minimum fraction of occurrences that must share one label
     * @param maxNumFeatures maximum number of features returned
     * @return the selected features, as ordered by Counters.toSortedList on their total counts
     */
    private List<F> getHighPrecisionFeatures(GeneralDataset<L, F> dataset, double minPrecision, int maxNumFeatures) {
        // counts[f][c] = number of occurrences of feature f in examples with label c.
        // Java zero-initializes int arrays, so no explicit fill is needed.
        int[][] counts = new int[dataset.numFeatures()][dataset.numClasses()];
        int[][] data = dataset.data;
        int[] labels = dataset.labels;
        for (int i = 0; i < data.length; i++) {
            int label = labels[i];
            if (data[i] != null) {
                for (int feature : data[i]) {
                    counts[feature][label]++;
                }
            }
        }
        ClassicCounter<F> featureFrequencies = new ClassicCounter<>();
        for (int feature = 0; feature < dataset.numFeatures(); feature++) {
            int max = ArrayMath.max(counts[feature]);
            int total = ArrayMath.sum(counts[feature]);
            F f = dataset.featureIndex.get(feature);
            if (total == 0) {
                // Unseen feature: previously an int division by zero (ArithmeticException).
                continue;
            }
            // Floating-point division. The old int/int division truncated precision to 0 or 1, so only
            // features occurring with exactly one label could ever pass the threshold.
            double precision = (double) max / total;
            if (precision >= minPrecision) {
                featureFrequencies.incrementCount(f, total);
            }
        }
        if (featureFrequencies.size() > maxNumFeatures) {
            Counters.retainTop(featureFrequencies, maxNumFeatures);
        }
        return Counters.toSortedList(featureFrequencies);
    }

    /**
     * Tunes sigma on {@code validation} over [min, max], then trains the final classifier on {@code train}.
     *
     * @param train training data
     * @param validation held-out data used only for sigma selection
     * @param min lower end of the sigma search interval
     * @param max upper end of the sigma search interval
     * @param accuracy not used
     * @return classifier trained on {@code train} with the selected sigma
     */
    public LinearClassifier<L, F> trainClassifierV(GeneralDataset<L, F> train, GeneralDataset<L, F> validation, double min, double max, boolean accuracy) {
        this.labelIndex = train.labelIndex();
        this.featureIndex = train.featureIndex();
        this.min = min;
        this.max = max;
        heldOutSetSigma(train, validation);
        // Pass skipSigmaTuning = true. Previously, with tuneSigmaHeldOut or tuneSigmaCV enabled, the final
        // trainWeights(train) re-tuned sigma on a split of the training data and discarded the
        // validation-set choice just made. With both flags off the behavior is unchanged.
        double[][] weights = trainWeights(train, null, true);
        return new LinearClassifier<>(weights, train.featureIndex(), train.labelIndex());
    }

    /**
     * Enables held-out sigma tuning over [min, max] and trains a classifier on {@code train}.
     * The held-out split is taken from {@code train} itself (via heldOutSetSigma(train)).
     * The tuneSigmaHeldOut flag stays enabled afterwards.
     *
     * @param train training data
     * @param min lower end of the sigma search interval
     * @param max upper end of the sigma search interval
     * @param accuracy not used
     * @return classifier trained with the tuned sigma
     */
    public LinearClassifier<L, F> trainClassifierV(GeneralDataset<L, F> train, double min, double max, boolean accuracy) {
        this.labelIndex = train.labelIndex();
        this.featureIndex = train.featureIndex();
        this.tuneSigmaHeldOut = true;
        this.min = min;
        this.max = max;
        // Because tuneSigmaHeldOut is now set, trainWeights(train) runs the held-out sigma search itself.
        // It warm-starts from the tuned weights unless retrainFromScratchAfterSigmaTuning is set. The
        // former explicit heldOutSetSigma(train) call duplicated that entire search, and its result was
        // then overwritten.
        double[][] weights = trainWeights(train);
        return new LinearClassifier<>(weights, train.featureIndex(), train.labelIndex());
    }

    /** Makes trainWeights tune sigma on a held-out split (disables cross-validation tuning). */
    public void setTuneSigmaHeldOut() {
        tuneSigmaHeldOut = true;
        tuneSigmaCV = false;
    }

    /** Makes trainWeights tune sigma by {@code numFolds}-fold cross-validation (disables held-out tuning). */
    public void setTuneSigmaCV(int numFolds) {
        tuneSigmaHeldOut = false;
        tuneSigmaCV = true;
        folds = numFolds;
    }

    /** No-op retained for API compatibility. */
    public void resetWeight() {
        // intentionally empty
    }

    /** Cross-validated sigma tuning with 5 folds. */
    public void crossValidateSetSigma(GeneralDataset<L, F> dataset) {
        int defaultFolds = 5;
        crossValidateSetSigma(dataset, defaultFolds);
    }

    /**
     * Cross-validated sigma tuning scored by {@code new MultiClassAccuracyStats(2)}, searching
     * [min, max] with a GoldenSectionLineSearch (tolerance 0.01).
     */
    public void crossValidateSetSigma(GeneralDataset<L, F> dataset, int kfold) {
        System.err.println("##you are here.");
        MultiClassAccuracyStats scorer = new MultiClassAccuracyStats(2);
        GoldenSectionLineSearch searcher = new GoldenSectionLineSearch(true, 0.01d, this.min, this.max);
        crossValidateSetSigma(dataset, kfold, scorer, searcher);
    }

    /** Cross-validated sigma tuning with a caller-supplied scorer and a GoldenSectionLineSearch over [min, max]. */
    public void crossValidateSetSigma(GeneralDataset<L, F> dataset, int kfold, Scorer<L> scorer) {
        GoldenSectionLineSearch searcher = new GoldenSectionLineSearch(true, 0.01d, this.min, this.max);
        crossValidateSetSigma(dataset, kfold, scorer, searcher);
    }

    /** Cross-validated sigma tuning with a caller-supplied searcher, scored by {@code new MultiClassAccuracyStats(2)}. */
    public void crossValidateSetSigma(GeneralDataset<L, F> dataset, int kfold, LineSearcher minimizer) {
        MultiClassAccuracyStats scorer = new MultiClassAccuracyStats(2);
        crossValidateSetSigma(dataset, kfold, scorer, minimizer);
    }

    /**
     * Chooses sigma by k-fold cross-validation and stores it in the prior via setSigma.
     * For each sigma proposed by {@code lineSearcher}, every fold is trained (warm-started from that
     * fold's saved weights, sigma tuning skipped) and scored with {@code scorer}. The fold scores are
     * averaged by CrossValidator.computeAverage, and the negated average is returned so that the
     * minimizing search maximizes the score. Also sets featureIndex/labelIndex from the dataset.
     *
     * @param generalDataset data to split into folds
     * @param i number of folds
     * @param scorer evaluates a trained classifier on a fold's held-out part
     * @param lineSearcher one-dimensional search over sigma
     */
    public void crossValidateSetSigma(GeneralDataset<L, F> generalDataset, int i, final Scorer<L> scorer, LineSearcher lineSearcher) {
        System.err.println("##in Cross Validate, folds = " + i);
        System.err.println("##Scorer is " + scorer);
        this.featureIndex = generalDataset.featureIndex;
        this.labelIndex = generalDataset.labelIndex;
        final CrossValidator crossValidator = new CrossValidator(generalDataset, i);
        // Per-fold scoring: triple = (fold train part, fold held-out part, per-fold saved state).
        final Function<Triple<GeneralDataset<L, F>, GeneralDataset<L, F>, CrossValidator.SavedState>, Double> function = new Function<Triple<GeneralDataset<L, F>, GeneralDataset<L, F>, CrossValidator.SavedState>, Double>() { // from class: edu.stanford.nlp.classify.LinearClassifierFactory.12
            @Override // edu.stanford.nlp.util.Function
            public Double apply(Triple<GeneralDataset<L, F>, GeneralDataset<L, F>, CrossValidator.SavedState> triple) {
                GeneralDataset<L, F> first = triple.first();
                GeneralDataset<L, F> second = triple.second();
                // Warm start from this fold's previous weights (null on the first visit, presumably); skip nested tuning.
                double[][] trainWeights = LinearClassifierFactory.this.trainWeights(first, (double[]) triple.third().state, true);
                // Remember this fold's weights for the next sigma candidate.
                triple.third().state = ArrayUtils.flatten(trainWeights);
                double score = scorer.score(new LinearClassifier(trainWeights, first.featureIndex, first.labelIndex), second);
                // Progress marker goes to stdout, unlike the other diagnostics here.
                System.out.print(".");
                return Double.valueOf(score);
            }
        };
        // Objective for the line search: negated average fold score at the given sigma.
        double minimize = lineSearcher.minimize(new Function<Double, Double>() { // from class: edu.stanford.nlp.classify.LinearClassifierFactory.13
            @Override // edu.stanford.nlp.util.Function
            public Double apply(Double d) {
                LinearClassifierFactory.this.setSigma(d.doubleValue());
                Double valueOf = Double.valueOf(crossValidator.computeAverage(function));
                System.err.print("##sigma = " + LinearClassifierFactory.this.getSigma() + " ");
                System.err.println("-> average Score: " + valueOf);
                return Double.valueOf(-valueOf.doubleValue());
            }
        });
        System.err.println("##best sigma: " + minimize);
        // Leave the prior configured with the winning sigma.
        setSigma(minimize);
    }

    /** Installs the line searcher used by heldOutSetSigma(train, dev) instead of the default golden-section search. */
    public void setHeldOutSearcher(LineSearcher searcher) {
        heldOutSearcher = searcher;
    }

    /** Splits {@code dataset} with {@code split(0.3)} and tunes sigma on the resulting (first, second) pair. */
    public double[] heldOutSetSigma(GeneralDataset<L, F> dataset) {
        Pair<GeneralDataset<L, F>, GeneralDataset<L, F>> parts = dataset.split(0.3d);
        GeneralDataset<L, F> trainPart = parts.first();
        GeneralDataset<L, F> devPart = parts.second();
        return heldOutSetSigma(trainPart, devPart);
    }

    /** Splits {@code dataset} with {@code split(0.3)} and tunes sigma using {@code scorer}. */
    public double[] heldOutSetSigma(GeneralDataset<L, F> dataset, Scorer<L> scorer) {
        Pair<GeneralDataset<L, F>, GeneralDataset<L, F>> parts = dataset.split(0.3d);
        GeneralDataset<L, F> trainPart = parts.first();
        GeneralDataset<L, F> devPart = parts.second();
        return heldOutSetSigma(trainPart, devPart, scorer);
    }

    /**
     * Tunes sigma on {@code devSet}, scored by {@code new MultiClassAccuracyStats(2)}. Uses the
     * installed heldOutSearcher if any, else a GoldenSectionLineSearch over [min, max].
     */
    public double[] heldOutSetSigma(GeneralDataset<L, F> trainSet, GeneralDataset<L, F> devSet) {
        LineSearcher searcher;
        if (this.heldOutSearcher != null) {
            searcher = this.heldOutSearcher;
        } else {
            searcher = new GoldenSectionLineSearch(true, 0.01d, this.min, this.max);
        }
        return heldOutSetSigma(trainSet, devSet, new MultiClassAccuracyStats(2), searcher);
    }

    /**
     * Tunes sigma on {@code devSet} using {@code scorer}. Uses the searcher installed via
     * setHeldOutSearcher when present, otherwise a GoldenSectionLineSearch over [min, max].
     *
     * @return final flat weights trained on {@code trainSet} with the chosen sigma
     */
    public double[] heldOutSetSigma(GeneralDataset<L, F> trainSet, GeneralDataset<L, F> devSet, Scorer<L> scorer) {
        // Previously this overload always built a fresh GoldenSectionLineSearch and ignored a searcher
        // configured via setHeldOutSearcher, unlike heldOutSetSigma(trainSet, devSet).
        LineSearcher searcher = (this.heldOutSearcher == null)
                ? new GoldenSectionLineSearch(true, 0.01d, this.min, this.max)
                : this.heldOutSearcher;
        return heldOutSetSigma(trainSet, devSet, scorer, searcher);
    }

    /** Tunes sigma with a caller-supplied searcher, scored by {@code new MultiClassAccuracyStats(2)}. */
    public double[] heldOutSetSigma(GeneralDataset<L, F> trainSet, GeneralDataset<L, F> devSet, LineSearcher minimizer) {
        MultiClassAccuracyStats scorer = new MultiClassAccuracyStats(2);
        return heldOutSetSigma(trainSet, devSet, scorer, minimizer);
    }

    /**
     * Chooses sigma by line search on the negated {@code devSet} score, stores it via setSigma, and
     * retrains on {@code trainSet}. The retraining warm-starts from the weights of the last candidate
     * the search evaluated.
     *
     * @return the final flat weights trained on {@code trainSet}
     */
    public double[] heldOutSetSigma(GeneralDataset<L, F> trainSet, GeneralDataset<L, F> devSet, Scorer<L> scorer, LineSearcher minimizer) {
        this.featureIndex = trainSet.featureIndex;
        this.labelIndex = trainSet.labelIndex;
        Timing timer = new Timing();
        NegativeScorer objective = new NegativeScorer(trainSet, devSet, scorer, timer);
        timer.start();
        double bestSigma = minimizer.minimize(objective);
        System.err.println("##best sigma: " + bestSigma);
        setSigma(bestSigma);
        double[][] finalWeights = trainWeights(trainSet, objective.weights, true);
        return ArrayUtils.flatten(finalWeights);
    }

    /** When true, the final run after held-out tuning ignores the tuned weights and starts from objective.initial(). */
    public void setRetrainFromScratchAfterSigmaTuning(boolean retrainFromScratch) {
        retrainFromScratchAfterSigmaTuning = retrainFromScratch;
    }

    /**
     * Trains directly from an iterable of datums. Label and feature indices are built in iteration
     * order (each datum's label first, then its features) before the objective is constructed.
     */
    public Classifier<L, F> trainClassifier(Iterable<Datum<L, F>> dataIterable) {
        Minimizer<DiffFunction> minimizer = getMinimizer();
        Index<F> features = Generics.newIndex();
        Index<L> labels = Generics.newIndex();
        for (Datum<L, F> datum : dataIterable) {
            labels.add(datum.label());
            features.addAll(datum.asFeatures());
        }
        String summary = String.format("Training linear classifier with %d features and %d labels", Integer.valueOf(features.size()), Integer.valueOf(labels.size()));
        System.err.println(summary);
        LogConditionalObjectiveFunction objective = new LogConditionalObjectiveFunction(dataIterable, this.logPrior, features, labels);
        double[] optimum = minimizer.minimize(objective, this.TOL, objective.initial());
        return new LinearClassifier(objective.to2D(optimum), features, labels);
    }

    /**
     * Trains a linear classifier on a dataset with an explicit prior and per-position float weights.
     *
     * <p>An {@link RVFDataset} first has {@code ensureRealValues()} called on it.
     *
     * @param generalDataset training data
     * @param fArr float array passed straight to {@code LogConditionalObjectiveFunction}
     *             (presumably per-datum weights — confirm)
     * @param logPrior prior used by the objective instead of the factory's own
     * @return the trained classifier, indexed by the dataset's feature and label indices
     */
    public Classifier<L, F> trainClassifier(GeneralDataset<L, F> generalDataset, float[] fArr, LogPrior logPrior) {
        Minimizer<DiffFunction> minimizer = getMinimizer();
        if (generalDataset instanceof RVFDataset) {
            RVFDataset rvfData = (RVFDataset) generalDataset;
            rvfData.ensureRealValues();
        }
        LogConditionalObjectiveFunction objective = new LogConditionalObjectiveFunction(generalDataset, fArr, logPrior);
        double[] startingPoint = objective.initial();
        double[] optimum = minimizer.minimize(objective, this.TOL, startingPoint);
        return new LinearClassifier(objective.to2D(optimum), generalDataset.featureIndex(), generalDataset.labelIndex());
    }

    /**
     * Trains a linear classifier on the dataset with no initial weights.
     *
     * @param generalDataset training data
     * @return the trained classifier
     */
    @Override // edu.stanford.nlp.classify.AbstractLinearClassifierFactory, edu.stanford.nlp.classify.ClassifierFactory
    public LinearClassifier<L, F> trainClassifier(GeneralDataset<L, F> generalDataset) {
        double[] noInitialWeights = null;
        return trainClassifier(generalDataset, noInitialWeights);
    }

    /**
     * Trains a linear classifier on the dataset, optionally warm-starting from initial weights.
     *
     * <p>The initial weights are validated before anything else happens, so invalid input is
     * rejected without side effects on the dataset. An {@link RVFDataset} then has
     * {@code ensureRealValues()} called on it before training.
     *
     * @param generalDataset training data
     * @param dArr initial weights in flattened form, or {@code null} to train without them
     * @return the trained classifier, indexed by the dataset's feature and label indices
     * @throws IllegalArgumentException if any initial weight is NaN or infinite
     */
    public LinearClassifier<L, F> trainClassifier(GeneralDataset<L, F> generalDataset, double[] dArr) {
        if (dArr != null) {
            for (int i = 0; i < dArr.length; i++) {
                double w = dArr[i];
                if (Double.isNaN(w) || Double.isInfinite(w)) {
                    // Keep the historical message prefix; append the offending position for diagnosis.
                    throw new IllegalArgumentException("Initial weights are invalid! (weight " + i + " = " + w + ")");
                }
            }
        }
        if (generalDataset instanceof RVFDataset) {
            ((RVFDataset) generalDataset).ensureRealValues();
        }
        return new LinearClassifier<>(trainWeights(generalDataset, dArr, false), generalDataset.featureIndex(), generalDataset.labelIndex());
    }

    /**
     * Trains a linear classifier starting from a two-dimensional weight matrix.
     *
     * @param generalDataset training data
     * @param dArr initial weights; flattened with {@link ArrayUtils#flatten} before use, may be {@code null}
     * @return the trained classifier
     */
    public LinearClassifier<L, F> trainClassifierWithInitialWeights(GeneralDataset<L, F> generalDataset, double[][] dArr) {
        double[] flatWeights = null;
        if (dArr != null) {
            flatWeights = ArrayUtils.flatten(dArr);
        }
        return trainClassifier(generalDataset, flatWeights);
    }

    /**
     * Trains a linear classifier warm-started from another classifier's weights.
     *
     * @param generalDataset training data
     * @param linearClassifier classifier whose {@code weights()} seed training, or {@code null} for none
     * @return the trained classifier
     */
    public LinearClassifier<L, F> trainClassifierWithInitialWeights(GeneralDataset<L, F> generalDataset, LinearClassifier<L, F> linearClassifier) {
        double[][] startingWeights = (linearClassifier == null) ? null : linearClassifier.weights();
        return trainClassifierWithInitialWeights(generalDataset, startingWeights);
    }

    /**
     * Loads a {@code LinearClassifier<String, String>} from a text serialization.
     *
     * <p>Expected layout, in order:
     * <ol>
     *   <li>a label {@link Index}, then a feature {@link Index}, each read by {@link HashIndex#loadFromReader};</li>
     *   <li>weight lines {@code featureIdx DELIM labelIdx DELIM value}, where DELIM is
     *       {@link LinearClassifier#TEXT_SERIALIZATION_DELIMITER}, ended by an empty line or EOF;</li>
     *   <li>a line with the number of thresholds, then one threshold value per line until EOF.</li>
     * </ol>
     * The threshold values are parsed, so a malformed section is still rejected, but they are
     * discarded: the constructor used here takes only weights and the two indices.
     *
     * @param str location accepted by {@link IOUtils#readerFromString}
     * @return the loaded classifier, or {@code null} if reading or parsing fails for any reason
     *         (the error and stack trace are printed to stderr)
     */
    public static LinearClassifier<String, String> loadFromFilename(String str) {
        // try-with-resources so the reader is closed on every path, including parse failures.
        try (BufferedReader reader = IOUtils.readerFromString(str)) {
            Index<String> labelIndex = HashIndex.loadFromReader(reader);
            Index<String> featureIndex = HashIndex.loadFromReader(reader);
            double[][] weights = new double[featureIndex.size()][labelIndex.size()];

            // Weight section; weightLine counts weight lines (1-based) for error reporting.
            int weightLine = 1;
            for (String line = reader.readLine(); line != null && !line.isEmpty(); line = reader.readLine()) {
                String[] tokens = line.split(LinearClassifier.TEXT_SERIALIZATION_DELIMITER);
                if (tokens.length != 3) {
                    throw new IllegalArgumentException("Error: incorrect number of tokens in weight specifier, line=" + weightLine + " in file " + str);
                }
                weightLine++;
                weights[Integer.parseInt(tokens[0])][Integer.parseInt(tokens[1])] = Double.parseDouble(tokens[2]);
            }

            // Threshold section: a missing count line or too many values fails and yields null below.
            double[] thresholds = new double[Integer.parseInt(reader.readLine())];
            int next = 0;
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                thresholds[next++] = Double.parseDouble(line.trim());
            }

            return new LinearClassifier<>(weights, featureIndex, labelIndex);
        } catch (Exception e) {
            System.err.println("Error in LinearClassifierFactory, loading from file=" + str);
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Not supported by this factory; retained only to satisfy the inherited signature.
     *
     * @param list ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     * @deprecated unsupported; this method always throws
     */
    @Override // edu.stanford.nlp.classify.AbstractLinearClassifierFactory, edu.stanford.nlp.classify.ClassifierFactory
    @Deprecated
    public LinearClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> list) {
        String reason = "Unsupported deprecated method";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Installs evaluators together with the value stored in {@code evalIters}.
     *
     * @param i value for the {@code evalIters} field (presumably how often evaluators run — confirm)
     * @param evaluatorArr evaluators to store; the array is kept by reference, not copied
     */
    public void setEvaluators(int i, Evaluator[] evaluatorArr) {
        evaluators = evaluatorArr;
        evalIters = i;
    }

    /**
     * Builds a {@code LinearClassifierCreator} bound to the dataset's feature and label indices.
     *
     * @param generalDataset dataset whose {@code featureIndex} and {@code labelIndex} fields are used
     * @return a new creator sharing those index objects
     */
    public LinearClassifierCreator<L, F> getClassifierCreator(GeneralDataset<L, F> generalDataset) {
        Index<F> features = generalDataset.featureIndex;
        Index<L> labels = generalDataset.labelIndex;
        return new LinearClassifierCreator<>(features, labels);
    }
}
