/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.sgd;

import com.google.common.base.Charsets;
import com.google.common.collect.ConcurrentHashMultiset;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multiset;
import com.google.common.collect.Ordering;
import com.google.common.io.Files;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.nio.charset.Charset;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.util.Version;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.classifier.sgd.CrossFoldLearner;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.ModelDissector;
import org.apache.mahout.classifier.sgd.ModelSerializer;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.classifier.sgd.PriorFunction;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.ep.State;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
import org.apache.mahout.vectorizer.encoders.Dictionary;
import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;

/**
 * Command-line example that trains an {@link AdaptiveLogisticRegression} classifier on a
 * directory of newsgroup documents, one sub-directory per newsgroup.
 *
 * <p>Usage: {@code TrainNewsGroups <base-directory> [leak-type]}. The optional leak type is a
 * non-negative integer (default 0) that controls which parts of a document are encoded:
 * <ul>
 *   <li>{@code leakType % 3} selects how much of a synthetic, label-dependent date is added
 *       as text (see {@link #LEAK_LABELS});</li>
 *   <li>{@code leakType < 6} enables counting of the From/Subject/Keywords/Summary headers;</li>
 *   <li>{@code leakType < 3} enables counting of the message body.</li>
 * </ul>
 *
 * <p>Progress lines, intermediate models ({@code /tmp/news-group-<k>.model}), the final model
 * ({@code /tmp/news-group.model}), a model dissection and a word-count table are written as
 * side effects.
 */
public final class TrainNewsGroups {
    /** Dimension of the hashed feature vectors and of the model. */
    private static final int FEATURES = 10000;
    /** Base of the synthetic document date, in seconds. */
    private static final long DATE_REFERENCE = 853286460L;
    /** Offset, in seconds, applied per newsgroup index to the synthetic date. */
    private static final long MONTH = 2592000L;
    /** Width, in seconds, of the random jitter added to the synthetic date. */
    private static final long WEEK = 604800L;
    /** Upper bound on the number of (shuffled) files used for training. */
    private static final int MAX_TRAINING_FILES = 3000;
    /** Upper bound on the number of (permuted) files fed to the model dissector. */
    private static final int MAX_DISSECT_FILES = 500;

    private static final Random rand = RandomUtils.getRandom();
    private static final String[] LEAK_LABELS = {"none", "month-year", "day-month-year"};
    // Indexed by leakType % 3, in step with LEAK_LABELS; the empty pattern formats to "".
    private static final SimpleDateFormat[] DATE_FORMATS = {
        new SimpleDateFormat("", Locale.ENGLISH),
        new SimpleDateFormat("MMM-yyyy", Locale.ENGLISH),
        new SimpleDateFormat("dd-MMM-yyyy HH:mm:ss", Locale.ENGLISH)
    };
    private static final Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_30);
    private static final FeatureVectorEncoder encoder = new StaticWordValueEncoder("body");
    private static final FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");
    /** Token counts accumulated over every document that has been encoded. */
    private static final Multiset<String> overallCounts = HashMultiset.create();

    private TrainNewsGroups() {
    }

    /**
     * Trains the model, prints progress, dissects the result and dumps word-count statistics.
     *
     * @param args {@code args[0]} is the base directory holding one sub-directory per newsgroup;
     *             the optional {@code args[1]} is the non-negative integer leak type
     * @throws IOException if the base directory or one of its sub-directories cannot be listed,
     *                     or a document or model file cannot be read or written
     * @throws IllegalArgumentException if the base directory argument is missing or the leak
     *                                  type is negative
     */
    public static void main(String[] args) throws IOException {
        if (args.length < 1) {
            throw new IllegalArgumentException("Usage: TrainNewsGroups <base-directory> [leak-type]");
        }
        File base = new File(args[0]);
        int leakType = 0;
        if (args.length > 1) {
            leakType = Integer.parseInt(args[1]);
        }
        if (leakType < 0) {
            // A negative value would produce a negative index into LEAK_LABELS / DATE_FORMATS.
            throw new IllegalArgumentException("Leak type must be non-negative: " + leakType);
        }

        Dictionary newsGroups = new Dictionary();
        encoder.setProbes(2);
        AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(20, FEATURES, new L1());
        learningAlgorithm.setInterval(800);
        learningAlgorithm.setAveragingWindow(500);

        // File.listFiles() returns null (not an empty array) when the path cannot be listed.
        File[] groupDirs = base.listFiles();
        if (groupDirs == null) {
            throw new IOException("Cannot list base directory: " + base);
        }
        List<File> files = Lists.newArrayList();
        for (File newsgroup : groupDirs) {
            if (!newsgroup.isDirectory()) {
                continue;
            }
            File[] documents = newsgroup.listFiles();
            if (documents == null) {
                throw new IOException("Cannot list newsgroup directory: " + newsgroup);
            }
            newsGroups.intern(newsgroup.getName());
            files.addAll(Arrays.asList(documents));
        }
        Collections.shuffle(files);
        System.out.printf("%d training files\n", files.size());

        double averageLL = 0.0;
        double averageCorrect = 0.0;
        int k = 0;
        double step = 0.0;
        int[] bumps = {1, 2, 5};
        // Clamp the training window so a corpus smaller than the limit does not throw.
        int trainingCount = Math.min(MAX_TRAINING_FILES, files.size());
        for (File file : files.subList(0, trainingCount)) {
            String ng = file.getParentFile().getName();
            int actual = newsGroups.intern(ng);
            Vector v = encodeFeatureVector(file, actual, leakType);
            learningAlgorithm.train(actual, v);
            ++k;

            // Reporting happens when k is a multiple of bump * scale; both are derived from
            // step, which advances by 0.25 after every report.
            int bump = bumps[(int) Math.floor(step) % bumps.length];
            int scale = (int) Math.pow(10.0, Math.floor(step / (double) bumps.length));

            // Raw State: its type parameters are not visible from this file.
            State best = learningAlgorithm.getBest();
            double maxBeta = 0.0;
            double nonZeros = 0.0;
            double positive = 0.0;
            double norm = 0.0;
            double lambda = 0.0;
            double mu = 0.0;
            if (best != null) {
                CrossFoldLearner state = learnerOf(best);
                averageCorrect = state.percentCorrect();
                averageLL = state.logLikelihood();

                OnlineLogisticRegression model = (OnlineLogisticRegression) state.getModels().get(0);
                // close() is invoked before the coefficients are read, on every iteration.
                model.close();

                Matrix beta = model.getBeta();
                maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
                nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
                    @Override
                    public double apply(double x) {
                        return Math.abs(x) > 1.0E-6 ? 1.0 : 0.0;
                    }
                });
                positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
                    @Override
                    public double apply(double x) {
                        return x > 0.0 ? 1.0 : 0.0;
                    }
                });
                norm = beta.aggregate(Functions.PLUS, Functions.ABS);

                double[] params = best.getMappedParams();
                lambda = params[0];
                mu = params[1];
            }

            if (k % (bump * scale) != 0) {
                continue;
            }
            if (best != null) {
                ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model", firstModelOf(best));
            }
            step += 0.25;
            System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda, mu);
            System.out.printf("%d\t%.3f\t%.2f\t%s\n", k, averageLL, averageCorrect * 100.0, LEAK_LABELS[leakType % 3]);
        }
        learningAlgorithm.close();
        dissect(leakType, newsGroups, learningAlgorithm, files);
        System.out.println("exiting main");

        ModelSerializer.writeBinary("/tmp/news-group.model", firstModelOf(learningAlgorithm.getBest()));

        List<Integer> counts = Lists.newArrayList();
        System.out.print("Word counts\n");
        for (String word : overallCounts.elementSet()) {
            counts.add(overallCounts.count(word));
        }
        Collections.sort(counts, Ordering.natural().reverse());
        int rank = 0;
        for (Integer count : counts) {
            System.out.printf("%d\t%d\n", rank, count);
            // Ranks 0 through 1000 inclusive are printed.
            if (++rank > 1000) {
                break;
            }
        }
    }

    /** Returns the cross-fold learner carried as payload by the given evolutionary state. */
    private static CrossFoldLearner learnerOf(State best) {
        return ((AdaptiveLogisticRegression.Wrapper) best.getPayload()).getLearner();
    }

    /** Returns the first model of the cross-fold learner carried by the given state. */
    private static OnlineLogisticRegression firstModelOf(State best) {
        return (OnlineLogisticRegression) learnerOf(best).getModels().get(0);
    }

    /**
     * Feeds a random sample of documents through the model dissector and prints the 100
     * entries of its summary.
     *
     * @param leakType          leak type forwarded to {@link #encodeFeatureVector}
     * @param newsGroups        dictionary mapping newsgroup names to target indexes
     * @param learningAlgorithm trained learner whose best model is dissected
     * @param files             candidate documents; at most {@code MAX_DISSECT_FILES} of them,
     *                          chosen by random permutation, are used
     * @throws IOException if a document cannot be read
     */
    private static void dissect(int leakType, Dictionary newsGroups, AdaptiveLogisticRegression learningAlgorithm, Iterable<File> files) throws IOException {
        CrossFoldLearner model = learnerOf(learningAlgorithm.getBest());
        model.close();

        // Raw map on purpose: the value type expected by the encoders and the dissector is not
        // visible from this file (TODO confirm and parameterize).
        TreeMap traceDictionary = Maps.newTreeMap();
        ModelDissector md = new ModelDissector();
        encoder.setTraceDictionary((Map) traceDictionary);
        bias.setTraceDictionary((Map) traceDictionary);

        List<File> permuted = permute(files, rand);
        // Clamp the sample so fewer than MAX_DISSECT_FILES documents does not throw.
        int sampleSize = Math.min(MAX_DISSECT_FILES, permuted.size());
        for (File file : permuted.subList(0, sampleSize)) {
            String ng = file.getParentFile().getName();
            int actual = newsGroups.intern(ng);
            // The trace is rebuilt for every document before it is handed to the dissector.
            traceDictionary.clear();
            Vector v = encodeFeatureVector(file, actual, leakType);
            md.update(v, (Map) traceDictionary, (AbstractVectorClassifier) model);
        }

        List<String> ngNames = Lists.newArrayList(newsGroups.values());
        List<ModelDissector.Weight> weights = md.summary(100);
        for (ModelDissector.Weight w : weights) {
            System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s\n", w.getFeature(), w.getWeight(), ngNames.get(w.getMaxImpact() + 1), w.getCategory(1), w.getWeight(1), w.getCategory(2), w.getWeight(2));
        }
    }

    /**
     * Encodes one document as a sparse feature vector of dimension {@code FEATURES}.
     *
     * <p>The tokens come from three sources, depending on {@code leakType}: a synthetic date
     * string derived from {@code actual} plus random jitter (format chosen by
     * {@code leakType % 3}), the From/Subject/Keywords/Summary header lines including their
     * continuation lines (only when {@code leakType < 6}), and the remainder of the file after
     * the header section (only when {@code leakType < 3}). Each distinct token is added with
     * weight {@code log(1 + count)}, plus a constant bias term.
     *
     * @param file     document to read, decoded as UTF-8
     * @param actual   target index of the document's newsgroup; shifts the synthetic date
     * @param leakType non-negative leak type, see the class documentation
     * @return the encoded feature vector
     * @throws IOException if the file cannot be read
     */
    private static Vector encodeFeatureVector(File file, int actual, int leakType) throws IOException {
        double seconds = (double) (DATE_REFERENCE + (long) actual * MONTH) + (double) WEEK * rand.nextDouble();
        long date = (long) (1000.0 * seconds);
        Multiset<String> words = ConcurrentHashMultiset.create();

        BufferedReader reader = Files.newReader(file, Charsets.UTF_8);
        try {
            String line = reader.readLine();
            Reader dateString = new StringReader(DATE_FORMATS[leakType % 3].format(new Date(date)));
            countWords(analyzer, words, dateString);

            // Header section: runs until the first empty line or end of file.
            while (line != null && line.length() > 0) {
                boolean countHeader = (line.startsWith("From:") || line.startsWith("Subject:") || line.startsWith("Keywords:") || line.startsWith("Summary:")) && leakType < 6;
                // Lines starting with a space continue the current header.
                do {
                    if (countHeader) {
                        countWords(analyzer, words, new StringReader(line));
                    }
                    line = reader.readLine();
                } while (line != null && line.startsWith(" "));
            }
            if (leakType < 3) {
                // Whatever is left in the reader after the header section.
                countWords(analyzer, words, reader);
            }
        } finally {
            reader.close();
        }

        Vector v = new RandomAccessSparseVector(FEATURES);
        bias.addToVector("", 1.0, v);
        for (String word : words.elementSet()) {
            encoder.addToVector(word, Math.log(1 + words.count(word)), v);
        }
        return v;
    }

    /**
     * Tokenizes {@code in} and records every token both in {@code words} and in the global
     * {@code overallCounts}.
     *
     * <p>Each token is added to {@code overallCounts} exactly once. Adding the whole
     * {@code words} collection after each call would re-count tokens gathered by earlier calls
     * for the same document, because callers reuse one collection across several calls.
     *
     * @param analyzer analyzer used to tokenize the text
     * @param words    per-document collection that receives every token
     * @param in       text to tokenize
     * @throws IOException if reading or tokenizing fails
     */
    private static void countWords(Analyzer analyzer, Collection<String> words, Reader in) throws IOException {
        TokenStream ts = analyzer.tokenStream("text", in);
        CharTermAttribute term = ts.addAttribute(CharTermAttribute.class);
        while (ts.incrementToken()) {
            String s = term.toString();
            words.add(s);
            overallCounts.add(s);
        }
    }

    /**
     * Returns a new list holding the given files in random order.
     *
     * <p>Each incoming file is placed at a random index of the list built so far; the element
     * previously at that index, if any, is moved to the end.
     *
     * @param files files to permute; not modified
     * @param rand  source of randomness
     * @return a newly allocated list with the same elements in random order
     */
    private static List<File> permute(Iterable<File> files, Random rand) {
        List<File> r = Lists.newArrayList();
        for (File file : files) {
            int i = rand.nextInt(r.size() + 1);
            if (i == r.size()) {
                r.add(file);
            } else {
                r.add(r.get(i));
                r.set(i, file);
            }
        }
        return r;
    }
}

