/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.clustering.lda;

import com.google.common.base.Preconditions;
import java.io.IOException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Random;
import org.apache.commons.cli2.Option;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.clustering.lda.LDADocumentTopicMapper;
import org.apache.mahout.clustering.lda.LDAInference;
import org.apache.mahout.clustering.lda.LDAReducer;
import org.apache.mahout.clustering.lda.LDAState;
import org.apache.mahout.clustering.lda.LDAWordTopicMapper;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.IntPairWritable;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterator;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Command-line driver that estimates an LDA topic model.
 *
 * <p>The driver repeatedly runs an estimation pass over the input (as a MapReduce job, or
 * in-process when the sequential flag is set), writing the model after pass {@code i} to
 * {@code <output>/state-i}. It stops once the iteration limit is reached or the relative change
 * in log-likelihood drops below {@link #OVERALL_CONVERGENCE}, and finally writes per-document
 * topic distributions to {@code <output>/docTopics}.
 *
 * <p>State files are sequence files of {@code (IntPairWritable(topic, word), DoubleWritable)}.
 * Two reserved keys are used: a word of {@link #TOPIC_SUM_KEY} marks the log-total of a topic, and
 * a topic of {@link #LOG_LIKELIHOOD_KEY} marks the log-likelihood of the whole model.
 */
public final class LDADriver extends AbstractJob {
    private static final String TOPIC_SMOOTHING_OPTION = "topicSmoothing";
    private static final String NUM_WORDS_OPTION = "numWords";
    private static final String NUM_TOPICS_OPTION = "numTopics";
    static final String STATE_IN_KEY = "org.apache.mahout.clustering.lda.stateIn";
    static final String NUM_TOPICS_KEY = "org.apache.mahout.clustering.lda.numTopics";
    static final String NUM_WORDS_KEY = "org.apache.mahout.clustering.lda.numWords";
    static final String TOPIC_SMOOTHING_KEY = "org.apache.mahout.clustering.lda.topicSmoothing";
    /** Reserved topic id marking the record that carries the model log-likelihood. */
    static final int LOG_LIKELIHOOD_KEY = -2;
    /** Reserved word id marking the record that carries a topic's log-total. */
    static final int TOPIC_SUM_KEY = -1;
    /** Relative log-likelihood change below which training is considered converged. */
    static final double OVERALL_CONVERGENCE = 1.0E-5;
    private static final Logger log = LoggerFactory.getLogger(LDADriver.class);

    /** Glob matching the part files inside a state or input directory. */
    private static final String PART_GLOB = "part-*";
    private static final String NUM_WORDS_TOO_SMALL_MESSAGE = "This is probably because the --numWords argument is set too small.  \n\tIt needs to be >= than the number of words (terms actually) in the corpus and can be \n\tlarger if some storage inefficiency can be tolerated.";

    // The fields below are only used by the in-process (sequential) code path.
    private LDAState state = null;
    private LDAState newState = null;
    private LDAInference inference = null;
    private Iterable<Pair<Writable, VectorWritable>> trainingCorpus = null;

    private LDADriver() {
    }

    /**
     * Command-line entry point; delegates to {@link ToolRunner} so generic Hadoop options are
     * honoured.
     *
     * @param args command-line arguments
     * @throws Exception if the job fails
     */
    public static void main(String[] args) throws Exception {
        ToolRunner.run(new Configuration(), new LDADriver(), args);
    }

    /**
     * Loads the model state referenced by {@link #STATE_IN_KEY} in {@code job}.
     *
     * @param job configuration carrying the state path, topic/word counts and smoothing
     * @return the populated state
     */
    public static LDAState createState(Configuration job) {
        return LDADriver.createState(job, false);
    }

    /**
     * Builds a model state from the settings in {@code job}.
     *
     * @param job   configuration carrying {@link #STATE_IN_KEY}, {@link #NUM_TOPICS_KEY},
     *              {@link #NUM_WORDS_KEY} and {@link #TOPIC_SMOOTHING_KEY}
     * @param empty when {@code true} no files are read and a blank state (zero matrix,
     *              negative-infinity log-totals, zero log-likelihood) is returned
     * @return the state, either blank or filled from the {@code part-*} files of the state path
     * @throws IllegalArgumentException if a stored value is infinite, a cell is set twice, or a
     *                                  regular record has a negative topic or word id
     */
    public static LDAState createState(Configuration job, boolean empty) {
        String statePath = job.get(STATE_IN_KEY);
        int numTopics = Integer.parseInt(job.get(NUM_TOPICS_KEY));
        int numWords = Integer.parseInt(job.get(NUM_WORDS_KEY));
        double topicSmoothing = Double.parseDouble(job.get(TOPIC_SMOOTHING_KEY));
        Path dir = new Path(statePath);
        Matrix pWgT = new DenseMatrix(numTopics, numWords);
        double[] logTotals = new double[numTopics];
        Arrays.fill(logTotals, Double.NEGATIVE_INFINITY);
        double ll = 0.0;
        if (!empty) {
            Iterable<Pair<IntPairWritable, DoubleWritable>> records =
                new SequenceFileDirIterable<IntPairWritable, DoubleWritable>(
                    new Path(dir, PART_GLOB), PathType.GLOB, null, null, true, job);
            for (Pair<IntPairWritable, DoubleWritable> record : records) {
                int topic = record.getFirst().getFirst();
                int word = record.getFirst().getSecond();
                double value = record.getSecond().get();
                if (word == TOPIC_SUM_KEY) {
                    logTotals[topic] = value;
                    Preconditions.checkArgument(!Double.isInfinite(value));
                } else if (topic == LOG_LIKELIHOOD_KEY) {
                    ll = value;
                } else {
                    // Guava only substitutes %s placeholders, so %s is used to get the ids into the message.
                    Preconditions.checkArgument(topic >= 0, "topic should be non-negative, not %s", topic);
                    Preconditions.checkArgument(word >= 0, "word should be non-negative not %s", word);
                    // A non-zero cell means the same (topic, word) pair was already read.
                    Preconditions.checkArgument(pWgT.getQuick(topic, word) == 0.0);
                    pWgT.setQuick(topic, word, value);
                    Preconditions.checkArgument(!Double.isInfinite(pWgT.getQuick(topic, word)));
                }
            }
        }
        return new LDAState(numTopics, numWords, topicSmoothing, pWgT, logTotals, ll);
    }

    /**
     * Parses the command line and runs LDA training.
     *
     * @param args command-line arguments
     * @return {@code 0} on success, {@code -1} if the arguments could not be parsed
     */
    @Override
    public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
        this.addInputOption();
        this.addOutputOption();
        this.addOption((Option)DefaultOptionCreator.overwriteOption().create());
        this.addOption(NUM_TOPICS_OPTION, "k", "The total number of topics in the corpus", true);
        // numWords is parsed unconditionally below, so it is registered as required just like
        // numTopics; otherwise omitting it surfaces as a NumberFormatException on null.
        this.addOption(NUM_WORDS_OPTION, "v", "The total number of words in the corpus (can be approximate, needs to exceed the actual value)", true);
        this.addOption(TOPIC_SMOOTHING_OPTION, "a", "Topic smoothing parameter. Default is 50/numTopics.", "-1.0");
        this.addOption((Option)DefaultOptionCreator.maxIterationsOption().withRequired(false).create());
        if (this.parseArguments(args) == null) {
            return -1;
        }
        Path input = this.getInputPath();
        Path output = this.getOutputPath();
        if (this.hasOption("overwrite")) {
            HadoopUtil.delete(this.getConf(), output);
        }
        int maxIterations = Integer.parseInt(this.getOption("maxIter"));
        int numTopics = Integer.parseInt(this.getOption(NUM_TOPICS_OPTION));
        int numWords = Integer.parseInt(this.getOption(NUM_WORDS_OPTION));
        double topicSmoothing = Double.parseDouble(this.getOption(TOPIC_SMOOTHING_OPTION));
        // NOTE(review): any value below 1 (not just the -1.0 sentinel) is replaced by the default,
        // so an explicit smoothing such as 0.1 is silently ignored. Kept as-is; confirm intent.
        if (topicSmoothing < 1.0) {
            topicSmoothing = 50.0 / (double)numTopics;
        }
        boolean runSequential = false;
        this.run(this.getConf(), input, output, numTopics, numWords, topicSmoothing, maxIterations, runSequential);
        return 0;
    }

    /**
     * Finds the {@code state-N} directory with the highest iteration number under {@code stateDir}.
     *
     * @return the most recent state path, or {@code null} if none exists
     * @throws IOException if listing fails or a matching directory name carries no valid number
     */
    private Path getLastKnownStatePath(Configuration conf, Path stateDir) throws IOException {
        // Resolve the file system from the path itself so non-default file systems work.
        FileSystem fs = stateDir.getFileSystem(conf);
        Path lastPath = null;
        int maxIteration = Integer.MIN_VALUE;
        for (FileStatus fstatus : LDADriver.globOrEmpty(fs, new Path(stateDir, "state-*"))) {
            int iteration;
            try {
                iteration = LDADriver.iterationOf(fstatus.getPath());
            }
            catch (NumberFormatException nfe) {
                throw new IOException("Malformed state directory name: " + fstatus.getPath(), nfe);
            }
            if (iteration > maxIteration) {
                maxIteration = iteration;
                lastPath = fstatus.getPath();
            }
        }
        return lastPath;
    }

    /** Globs {@code pattern}, mapping the {@code null} that globStatus may return to an empty array. */
    private static FileStatus[] globOrEmpty(FileSystem fs, Path pattern) throws IOException {
        FileStatus[] matches = fs.globStatus(pattern);
        return matches == null ? new FileStatus[0] : matches;
    }

    /**
     * Extracts the iteration number from a {@code state-N} path name.
     *
     * @throws NumberFormatException if the name has no numeric part after the first dash
     */
    private static int iterationOf(Path statePath) {
        String name = statePath.getName();
        String[] parts = name.split("-");
        if (parts.length < 2) {
            throw new NumberFormatException("No iteration number in state name: " + name);
        }
        return Integer.parseInt(parts[1]);
    }

    /**
     * Runs the training loop, resuming from the latest state found under {@code output} (or from
     * a freshly written random state), then writes the document/topic distributions.
     *
     * @param maxIterations iteration limit; values below 1 mean "run until converged"
     * @param runSequential {@code true} to train in-process instead of via MapReduce
     */
    private void run(Configuration conf, Path input, Path output, int numTopics, int numWords, double topicSmoothing, int maxIterations, boolean runSequential) throws IOException, InterruptedException, ClassNotFoundException {
        Path stateIn = this.getLastKnownStatePath(conf, output);
        if (stateIn == null) {
            stateIn = new Path(output, "state-0");
            LDADriver.writeInitialState(conf, stateIn, numTopics, numWords);
        }
        conf.set(STATE_IN_KEY, stateIn.toString());
        conf.set(NUM_TOPICS_KEY, Integer.toString(numTopics));
        conf.set(NUM_WORDS_KEY, Integer.toString(numWords));
        conf.set(TOPIC_SMOOTHING_KEY, Double.toString(topicSmoothing));
        double oldLL = Double.NEGATIVE_INFINITY;
        boolean converged = false;
        int iteration = LDADriver.iterationOf(stateIn) + 1;
        while (!converged && (maxIterations < 1 || iteration <= maxIterations)) {
            log.info("LDA Iteration {}", (Object)iteration);
            conf.set(STATE_IN_KEY, stateIn.toString());
            Path stateOut = new Path(output, "state-" + iteration);
            double ll = runSequential ? this.runIterationSequential(conf, input, stateOut) : LDADriver.runIteration(conf, input, stateIn, stateOut);
            // On the first pass oldLL is -Infinity, so relChange is NaN and never counts as converged.
            double relChange = (oldLL - ll) / oldLL;
            log.info("Iteration {} finished. Log Likelihood: {}", (Object)iteration, (Object)ll);
            log.info("(Old LL: {})", (Object)oldLL);
            log.info("(Rel Change: {})", (Object)relChange);
            converged = iteration > 3 && relChange < OVERALL_CONVERGENCE;
            stateIn = stateOut;
            oldLL = ll;
            ++iteration;
        }
        Path docTopics = new Path(output, "docTopics");
        if (runSequential) {
            this.computeDocumentTopicProbabilitiesSequential(conf, input, docTopics);
        } else {
            LDADriver.computeDocumentTopicProbabilities(conf, input, stateIn, docTopics, numTopics, numWords, topicSmoothing);
        }
    }

    /**
     * Writes a random starting model: one {@code part-k} file per topic holding the log of a
     * random pseudo-count for every word, followed by the log of their sum under
     * {@link #TOPIC_SUM_KEY}.
     *
     * @param conf configuration of the running job, so the state is written with the same
     *             file-system settings the rest of the driver uses
     */
    private static void writeInitialState(Configuration conf, Path statePath, int numTopics, int numWords) throws IOException {
        FileSystem fs = statePath.getFileSystem(conf);
        DoubleWritable v = new DoubleWritable();
        Random random = RandomUtils.getRandom();
        for (int k = 0; k < numTopics; ++k) {
            Path path = new Path(statePath, "part-" + k);
            SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, path, IntPairWritable.class, DoubleWritable.class);
            try {
                double total = 0.0;
                for (int w = 0; w < numWords; ++w) {
                    // The small offset keeps the pseudo-count strictly positive so its log is finite.
                    double pseudocount = random.nextDouble() + 1.0E-8;
                    total += pseudocount;
                    v.set(Math.log(pseudocount));
                    writer.append((Writable)new IntPairWritable(k, w), (Writable)v);
                }
                v.set(Math.log(total));
                writer.append((Writable)new IntPairWritable(k, TOPIC_SUM_KEY), (Writable)v);
            }
            finally {
                writer.close();
            }
        }
    }

    /**
     * Persists {@code state} under {@code statePath}: one {@code part-k} file per topic plus a
     * {@code part--2} file holding the log-likelihood.
     */
    private static void writeState(Configuration job, LDAState state, Path statePath) throws IOException {
        FileSystem fs = statePath.getFileSystem(job);
        DoubleWritable v = new DoubleWritable();
        for (int k = 0; k < state.getNumTopics(); ++k) {
            Path path = new Path(statePath, "part-" + k);
            SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path, IntPairWritable.class, DoubleWritable.class);
            try {
                for (int w = 0; w < state.getNumWords(); ++w) {
                    v.set(state.logProbWordGivenTopic(w, k) + state.getLogTotal(k));
                    writer.append((Writable)new IntPairWritable(k, w), (Writable)v);
                }
                v.set(state.getLogTotal(k));
                writer.append((Writable)new IntPairWritable(k, TOPIC_SUM_KEY), (Writable)v);
            }
            finally {
                writer.close();
            }
        }
        Path llPath = new Path(statePath, "part-" + LOG_LIKELIHOOD_KEY);
        SequenceFile.Writer llWriter = new SequenceFile.Writer(fs, job, llPath, IntPairWritable.class, DoubleWritable.class);
        try {
            v.set(state.getLogLikelihood());
            llWriter.append((Writable)new IntPairWritable(LOG_LIKELIHOOD_KEY, LOG_LIKELIHOOD_KEY), (Writable)v);
        }
        finally {
            llWriter.close();
        }
    }

    /**
     * Scans the {@code part-*} files of a state directory for the log-likelihood record.
     *
     * @return the stored log-likelihood, or {@code 0.0} if no such record exists
     */
    private static double findLL(Path statePath, Configuration job) throws IOException {
        FileSystem fs = statePath.getFileSystem(job);
        double ll = 0.0;
        for (FileStatus status : LDADriver.globOrEmpty(fs, new Path(statePath, PART_GLOB))) {
            SequenceFileIterator<IntPairWritable, DoubleWritable> iterator =
                new SequenceFileIterator<IntPairWritable, DoubleWritable>(status.getPath(), true, job);
            try {
                while (iterator.hasNext()) {
                    Pair<IntPairWritable, DoubleWritable> record = iterator.next();
                    if (record.getFirst().getFirst() == LOG_LIKELIHOOD_KEY) {
                        ll = record.getSecond().get();
                        // Only stops scanning this file; later files are still examined.
                        break;
                    }
                }
            }
            finally {
                iterator.close();
            }
        }
        return ll;
    }

    /**
     * Lazily initialises the sequential-mode fields: the current state (from the configuration),
     * the in-memory training corpus and the inference engine bound to the current state.
     */
    private void prepareSequential(Configuration conf, Path input) throws IOException {
        if (this.state == null) {
            this.state = LDADriver.createState(conf);
        }
        if (this.trainingCorpus == null) {
            this.trainingCorpus = LDADriver.loadCorpus(conf, input);
        }
        if (this.inference == null) {
            this.inference = new LDAInference(this.state);
        }
    }

    /** Reads every record of the {@code part-*} files under {@code input} into memory. */
    private static Iterable<Pair<Writable, VectorWritable>> loadCorpus(Configuration conf, Path input) throws IOException {
        Class<? extends Writable> keyClass = LDADriver.peekAtSequenceFileForKeyType(conf, input);
        FileSystem fs = input.getFileSystem(conf);
        LinkedList<Pair<Writable, VectorWritable>> corpus = new LinkedList<Pair<Writable, VectorWritable>>();
        for (FileStatus fileStatus : LDADriver.globOrEmpty(fs, new Path(input, PART_GLOB))) {
            SequenceFile.Reader reader = new SequenceFile.Reader(fs, fileStatus.getPath(), conf);
            try {
                Writable key = ReflectionUtils.newInstance(keyClass, conf);
                VectorWritable value = new VectorWritable();
                while (reader.next(key, (Writable)value)) {
                    corpus.add(new Pair<Writable, VectorWritable>(key, value));
                    // Fresh instances per record: the stored pair must not be overwritten by the next read.
                    key = ReflectionUtils.newInstance(keyClass, conf);
                    value = new VectorWritable();
                }
            }
            finally {
                reader.close();
            }
        }
        return corpus;
    }

    /**
     * Runs inference on one document, translating an index overflow into a hint about
     * {@code --numWords}.
     *
     * @throws IllegalStateException if inference fails with an ArrayIndexOutOfBoundsException
     */
    private LDAInference.InferredDocument inferDocument(Vector wordCounts) {
        try {
            return this.inference.infer(wordCounts);
        }
        catch (ArrayIndexOutOfBoundsException e1) {
            throw new IllegalStateException(NUM_WORDS_TOO_SMALL_MESSAGE, e1);
        }
    }

    /**
     * Performs one in-process training pass over the corpus and writes the new state.
     *
     * @return the summed log-likelihood of all documents under the state used for this pass
     */
    private double runIterationSequential(Configuration conf, Path input, Path stateOut) throws IOException {
        this.prepareSequential(conf, input);
        this.newState = LDADriver.createState(conf, true);
        double ll = 0.0;
        for (Pair<Writable, VectorWritable> slice : this.trainingCorpus) {
            Vector wordCounts = slice.getSecond().get();
            LDAInference.InferredDocument doc = this.inferDocument(wordCounts);
            Iterator<Vector.Element> iter = wordCounts.iterateNonZero();
            while (iter.hasNext()) {
                Vector.Element e = iter.next();
                int w = e.index();
                for (int k = 0; k < this.state.getNumTopics(); ++k) {
                    double vwUpdate = doc.phi(k, w) + Math.log(e.get());
                    this.newState.updateLogProbGivenTopic(w, k, vwUpdate);
                    this.newState.updateLogTotals(k, vwUpdate);
                }
            }
            // The likelihood belongs to the document, so it is counted once per document rather
            // than once per distinct term.
            ll += doc.getLogLikelihood();
        }
        this.newState.setLogLikelihood(ll);
        LDADriver.writeState(conf, this.newState, stateOut);
        this.state = this.newState;
        this.newState = null;
        // Re-bind inference to the updated model; otherwise every later pass would keep
        // inferring against the initial state.
        this.inference = new LDAInference(this.state);
        return ll;
    }

    /**
     * Runs one training pass as a MapReduce job reading {@code stateIn} and writing {@code stateOut}.
     *
     * @return the log-likelihood found in {@code stateOut}
     * @throws InterruptedException if the job does not complete successfully
     */
    private static double runIteration(Configuration conf, Path input, Path stateIn, Path stateOut) throws IOException, InterruptedException, ClassNotFoundException {
        conf.set(STATE_IN_KEY, stateIn.toString());
        Job job = new Job(conf, "LDA Driver running runIteration over stateIn: " + stateIn);
        job.setOutputKeyClass(IntPairWritable.class);
        job.setOutputValueClass(DoubleWritable.class);
        FileInputFormat.addInputPaths(job, input.toString());
        FileOutputFormat.setOutputPath(job, stateOut);
        job.setMapperClass(LDAWordTopicMapper.class);
        job.setReducerClass(LDAReducer.class);
        job.setCombinerClass(LDAReducer.class);
        job.setOutputFormatClass(SequenceFileOutputFormat.class);
        job.setInputFormatClass(SequenceFileInputFormat.class);
        job.setJarByClass(LDADriver.class);
        if (!job.waitForCompletion(true)) {
            throw new InterruptedException("LDA Iteration failed processing " + stateIn);
        }
        return LDADriver.findLL(stateOut, conf);
    }

    /**
     * Runs a map-only job that writes one vector per input document to {@code outputPath}, using
     * the model stored at {@code stateIn}.
     *
     * @throws InterruptedException if the job does not complete successfully
     */
    private static void computeDocumentTopicProbabilities(Configuration conf, Path input, Path stateIn, Path outputPath, int numTopics, int numWords, double topicSmoothing) throws IOException, InterruptedException, ClassNotFoundException {
        conf.set(STATE_IN_KEY, stateIn.toString());
        conf.set(NUM_TOPICS_KEY, Integer.toString(numTopics));
        conf.set(NUM_WORDS_KEY, Integer.toString(numWords));
        conf.set(TOPIC_SMOOTHING_KEY, Double.toString(topicSmoothing));
        Job job = new Job(conf, "LDA Driver computing p(topic|doc) for all docs/topics with stateIn: " + stateIn);
        job.setOutputKeyClass(LDADriver.peekAtSequenceFileForKeyType(conf, input));
        job.setOutputValueClass(VectorWritable.class);
        FileInputFormat.addInputPaths(job, input.toString());
        FileOutputFormat.setOutputPath(job, outputPath);
        job.setMapperClass(LDADocumentTopicMapper.class);
        job.setNumReduceTasks(0);
        job.setOutputFormatClass(SequenceFileOutputFormat.class);
        job.setInputFormatClass(SequenceFileInputFormat.class);
        job.setJarByClass(LDADriver.class);
        if (!job.waitForCompletion(true)) {
            throw new InterruptedException("LDA failed to compute and output document topic probabilities with: " + stateIn);
        }
    }

    /**
     * In-process counterpart of the document/topic job: infers every document of the training
     * corpus against the current state and writes one record per document to {@code outputPath}.
     */
    private void computeDocumentTopicProbabilitiesSequential(Configuration conf, Path input, Path outputPath) throws IOException {
        // The training loop may have run zero passes (e.g. when resuming at the iteration limit),
        // in which case nothing has been loaded yet.
        this.prepareSequential(conf, input);
        Class<? extends Writable> keyClass = LDADriver.peekAtSequenceFileForKeyType(conf, input);
        FileSystem fs = outputPath.getFileSystem(conf);
        SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, outputPath, keyClass, VectorWritable.class);
        try {
            for (Pair<Writable, VectorWritable> slice : this.trainingCorpus) {
                LDAInference.InferredDocument doc = this.inferDocument(slice.getSecond().get());
                // Emit the document's own key with its inferred topic weights instead of a blank
                // key/value pair. NOTE(review): L1-normalising gamma is assumed to match what
                // LDADocumentTopicMapper emits; confirm against that class.
                writer.append(slice.getFirst(), (Writable)new VectorWritable(doc.getGamma().normalize(1.0)));
            }
        }
        finally {
            writer.close();
        }
    }

    /**
     * Determines the key class of the input sequence file(s). When {@code input} contains
     * {@code part-*} files the first one is inspected, otherwise {@code input} itself is opened.
     *
     * @return the key class, or {@code Text.class} if the file cannot be read
     */
    private static Class<? extends Writable> peekAtSequenceFileForKeyType(Configuration conf, Path input) {
        try {
            FileSystem fs = input.getFileSystem(conf);
            Path toPeek = input;
            // A directory cannot be opened as a sequence file, so look inside it first.
            FileStatus[] parts = LDADriver.globOrEmpty(fs, new Path(input, PART_GLOB));
            if (parts.length > 0) {
                toPeek = parts[0].getPath();
            }
            SequenceFile.Reader reader = new SequenceFile.Reader(fs, toPeek, conf);
            try {
                return reader.getKeyClass().asSubclass(Writable.class);
            }
            finally {
                reader.close();
            }
        }
        catch (IOException ioe) {
            // Best-effort fallback kept from the original behaviour, but no longer silent.
            log.warn("Could not determine key class of " + input + "; assuming Text", ioe);
            return Text.class;
        }
    }
}

