/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.df;

import java.io.IOException;
import java.util.Random;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.option.DefaultOption;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.df.DFUtils;
import org.apache.mahout.df.DecisionForest;
import org.apache.mahout.df.ErrorEstimate;
import org.apache.mahout.df.builder.DefaultTreeBuilder;
import org.apache.mahout.df.builder.TreeBuilder;
import org.apache.mahout.df.callback.ForestPredictions;
import org.apache.mahout.df.callback.MeanTreeCollector;
import org.apache.mahout.df.callback.MultiCallback;
import org.apache.mahout.df.callback.PredictionCallback;
import org.apache.mahout.df.data.Data;
import org.apache.mahout.df.data.DataLoader;
import org.apache.mahout.df.data.Dataset;
import org.apache.mahout.df.ref.SequentialBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.uncommons.maths.Maths;

public class BreimanExample
extends Configured
implements Tool {
    private static final Logger log = LoggerFactory.getLogger(BreimanExample.class);
    private double sumTestErr;
    private double sumTreeErr;
    private double sumOneErr;
    private long sumTimeM;
    private long sumTimeOne;
    private long numNodesM;
    private long numNodesOne;

    private void runIteration(Random rng, Data data, int m, int nbtrees) {
        int nblabels = data.getDataset().nblabels();
        log.info("Splitting the data");
        Data train = data.clone();
        Data test = train.rsplit(rng, (int)((double)data.size() * 0.1));
        int[] trainLabels = train.extractLabels();
        int[] testLabels = test.extractLabels();
        DefaultTreeBuilder treeBuilder = new DefaultTreeBuilder();
        SequentialBuilder forestBuilder = new SequentialBuilder(rng, (TreeBuilder)treeBuilder, train);
        ForestPredictions errorM = new ForestPredictions(train.size(), nblabels);
        treeBuilder.setM(m);
        long time = System.currentTimeMillis();
        log.info("Growing a forest with m={}", (Object)m);
        DecisionForest forestM = forestBuilder.build(nbtrees, (PredictionCallback)errorM);
        this.sumTimeM += System.currentTimeMillis() - time;
        this.numNodesM += forestM.nbNodes();
        double oobM = ErrorEstimate.errorRate((int[])trainLabels, (int[])errorM.computePredictions(rng));
        ForestPredictions errorOne = new ForestPredictions(train.size(), nblabels);
        treeBuilder.setM(1);
        time = System.currentTimeMillis();
        log.info("Growing a forest with m=1");
        DecisionForest forestOne = forestBuilder.build(nbtrees, (PredictionCallback)errorOne);
        this.sumTimeOne += System.currentTimeMillis() - time;
        this.numNodesOne += forestOne.nbNodes();
        double oobOne = ErrorEstimate.errorRate((int[])trainLabels, (int[])errorOne.computePredictions(rng));
        ForestPredictions testError = new ForestPredictions(test.size(), nblabels);
        MeanTreeCollector treeError = new MeanTreeCollector(test, nbtrees);
        errorOne = new ForestPredictions(test.size(), nblabels);
        if (oobM < oobOne) {
            forestM.classify(test, (PredictionCallback)new MultiCallback(new PredictionCallback[]{testError, treeError}));
            forestOne.classify(test, (PredictionCallback)errorOne);
        } else {
            forestOne.classify(test, (PredictionCallback)new MultiCallback(new PredictionCallback[]{testError, treeError, errorOne}));
        }
        this.sumTestErr += ErrorEstimate.errorRate((int[])testLabels, (int[])testError.computePredictions(rng));
        this.sumOneErr += ErrorEstimate.errorRate((int[])testLabels, (int[])errorOne.computePredictions(rng));
        this.sumTreeErr += treeError.meanTreeError();
    }

    public static void main(String[] args) throws Exception {
        ToolRunner.run((Configuration)new Configuration(), (Tool)new BreimanExample(), (String[])args);
    }

    public int run(String[] args) throws IOException {
        Path datasetPath;
        Path dataPath;
        int nbIterations;
        int nbTrees;
        DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
        ArgumentBuilder abuilder = new ArgumentBuilder();
        GroupBuilder gbuilder = new GroupBuilder();
        DefaultOption dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true).withArgument(abuilder.withName("path").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create();
        DefaultOption datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true).withArgument(abuilder.withName("dataset").withMinimum(1).withMaximum(1).create()).withDescription("Dataset path").create();
        DefaultOption nbtreesOpt = obuilder.withLongName("nbtrees").withShortName("t").withRequired(true).withArgument(abuilder.withName("nbtrees").withMinimum(1).withMaximum(1).create()).withDescription("Number of trees to grow, each iteration").create();
        DefaultOption nbItersOpt = obuilder.withLongName("iterations").withShortName("i").withRequired(true).withArgument(abuilder.withName("numIterations").withMinimum(1).withMaximum(1).create()).withDescription("Number of times to repeat the test").create();
        DefaultOption helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h").create();
        Group group = gbuilder.withName("Options").withOption((Option)dataOpt).withOption((Option)datasetOpt).withOption((Option)nbItersOpt).withOption((Option)nbtreesOpt).withOption((Option)helpOpt).create();
        try {
            Parser parser = new Parser();
            parser.setGroup(group);
            CommandLine cmdLine = parser.parse(args);
            if (cmdLine.hasOption("help")) {
                CommandLineUtil.printHelp((Group)group);
                return -1;
            }
            String dataName = cmdLine.getValue((Option)dataOpt).toString();
            String datasetName = cmdLine.getValue((Option)datasetOpt).toString();
            nbTrees = Integer.parseInt(cmdLine.getValue((Option)nbtreesOpt).toString());
            nbIterations = Integer.parseInt(cmdLine.getValue((Option)nbItersOpt).toString());
            dataPath = new Path(dataName);
            datasetPath = new Path(datasetName);
        }
        catch (OptionException e) {
            log.error("Error while parsing options", (Throwable)e);
            CommandLineUtil.printHelp((Group)group);
            return -1;
        }
        FileSystem fs = dataPath.getFileSystem(new Configuration());
        Dataset dataset = Dataset.load((Configuration)this.getConf(), (Path)datasetPath);
        Data data = DataLoader.loadData((Dataset)dataset, (FileSystem)fs, (Path)dataPath);
        int m = (int)Math.floor(Maths.log((double)2.0, (double)data.getDataset().nbAttributes()) + 1.0);
        Random rng = RandomUtils.getRandom();
        for (int iteration = 0; iteration < nbIterations; ++iteration) {
            log.info("Iteration {}", (Object)iteration);
            this.runIteration(rng, data, m, nbTrees);
        }
        log.info("********************************************");
        log.info("Selection error : {}", (Object)(this.sumTestErr / (double)nbIterations));
        log.info("Single Input error : {}", (Object)(this.sumOneErr / (double)nbIterations));
        log.info("One Tree error : {}", (Object)(this.sumTreeErr / (double)nbIterations));
        log.info("Mean Random Input Time : {}", (Object)DFUtils.elapsedTime((long)(this.sumTimeM / (long)nbIterations)));
        log.info("Mean Single Input Time : {}", (Object)DFUtils.elapsedTime((long)(this.sumTimeOne / (long)nbIterations)));
        log.info("Mean Random Input Num Nodes : {}", (Object)(this.numNodesM / (long)nbIterations));
        log.info("Mean Single Input Num Nodes : {}", (Object)(this.numNodesOne / (long)nbIterations));
        return 0;
    }
}

