/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.cf.taste.hadoop.als;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.mapreduce.lib.partition.HashPartitioner;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
import org.apache.mahout.cf.taste.hadoop.als.FeatureVectorWithRatingWritable;
import org.apache.mahout.cf.taste.hadoop.als.IndexedVarIntWritable;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.VarIntWritable;
import org.apache.mahout.math.VarLongWritable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.als.AlternateLeastSquaresSolver;

public class ParallelALSFactorizationJob
extends AbstractJob {
    static final String NUM_FEATURES = ParallelALSFactorizationJob.class.getName() + ".numFeatures";
    static final String LAMBDA = ParallelALSFactorizationJob.class.getName() + ".lambda";
    static final String MAP_TRANSPOSED = ParallelALSFactorizationJob.class.getName() + ".mapTransposed";
    static final String STEP_ONE = "fixMcomputeU";
    static final String STEP_TWO = "fixUcomputeM";
    private String tempDir;

    public static void main(String[] args) throws Exception {
        ToolRunner.run((Tool)new ParallelALSFactorizationJob(), (String[])args);
    }

    public int run(String[] args) throws Exception {
        this.addInputOption();
        this.addOutputOption();
        this.addOption("lambda", "l", "", true);
        this.addOption("numFeatures", "f", "", true);
        this.addOption("numIterations", "i", "", true);
        Map<String, String> parsedArgs = this.parseArguments(args);
        if (parsedArgs == null) {
            return -1;
        }
        int numFeatures = Integer.parseInt(parsedArgs.get("--numFeatures"));
        int numIterations = Integer.parseInt(parsedArgs.get("--numIterations"));
        double lambda = Double.parseDouble(parsedArgs.get("--lambda"));
        this.tempDir = parsedArgs.get("--tempDir");
        Job itemRatings = this.prepareJob(this.getInputPath(), this.pathToItemRatings(), TextInputFormat.class, PrefsToRatingsMapper.class, VarIntWritable.class, FeatureVectorWithRatingWritable.class, Reducer.class, VarIntWritable.class, FeatureVectorWithRatingWritable.class, SequenceFileOutputFormat.class);
        itemRatings.waitForCompletion(true);
        Job userRatings = this.prepareJob(this.getInputPath(), this.pathToUserRatings(), TextInputFormat.class, PrefsToRatingsMapper.class, VarIntWritable.class, FeatureVectorWithRatingWritable.class, Reducer.class, VarIntWritable.class, FeatureVectorWithRatingWritable.class, SequenceFileOutputFormat.class);
        userRatings.getConfiguration().setBoolean(MAP_TRANSPOSED, Boolean.TRUE.booleanValue());
        userRatings.waitForCompletion(true);
        Job initializeM = this.prepareJob(this.getInputPath(), this.pathToM(-1), TextInputFormat.class, ItemIDRatingMapper.class, VarLongWritable.class, FloatWritable.class, InitializeMReducer.class, VarIntWritable.class, FeatureVectorWithRatingWritable.class, SequenceFileOutputFormat.class);
        initializeM.getConfiguration().setInt(NUM_FEATURES, numFeatures);
        initializeM.waitForCompletion(true);
        for (int n = 0; n < numIterations; ++n) {
            this.iterate(n, numFeatures, lambda);
        }
        Job uAsMatrix = this.prepareJob(this.pathToU(numIterations - 1), new Path(this.getOutputPath(), "U"), SequenceFileInputFormat.class, ToMatrixMapper.class, IntWritable.class, VectorWritable.class, Reducer.class, IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class);
        uAsMatrix.waitForCompletion(true);
        Job mAsMatrix = this.prepareJob(this.pathToM(numIterations - 1), new Path(this.getOutputPath(), "M"), SequenceFileInputFormat.class, ToMatrixMapper.class, IntWritable.class, VectorWritable.class, Reducer.class, IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class);
        mAsMatrix.waitForCompletion(true);
        return 0;
    }

    private void iterate(int currentIteration, int numFeatures, double lambda) throws IOException, ClassNotFoundException, InterruptedException {
        this.joinAndSolve(this.pathToM(currentIteration - 1), this.pathToItemRatings(), this.pathToU(currentIteration), numFeatures, lambda, currentIteration, STEP_ONE);
        this.joinAndSolve(this.pathToU(currentIteration), this.pathToUserRatings(), this.pathToM(currentIteration), numFeatures, lambda, currentIteration, STEP_TWO);
    }

    private void joinAndSolve(Path featureMatrix, Path ratingMatrix, Path outputPath, int numFeatures, double lambda, int currentIteration, String step) throws IOException, ClassNotFoundException, InterruptedException {
        Path joinPath = new Path(ratingMatrix.toString() + ',' + featureMatrix);
        Path featureVectorWithRatingPath = this.joinAndSolvePath(currentIteration, step);
        Job joinToFeatureVectorWithRating = this.prepareJob(joinPath, featureVectorWithRatingPath, SequenceFileInputFormat.class, Mapper.class, VarIntWritable.class, FeatureVectorWithRatingWritable.class, JoinFeatureVectorAndRatingsReducer.class, IndexedVarIntWritable.class, FeatureVectorWithRatingWritable.class, SequenceFileOutputFormat.class);
        joinToFeatureVectorWithRating.waitForCompletion(true);
        Job solve = this.prepareJob(featureVectorWithRatingPath, outputPath, SequenceFileInputFormat.class, Mapper.class, IndexedVarIntWritable.class, FeatureVectorWithRatingWritable.class, SolvingReducer.class, VarIntWritable.class, FeatureVectorWithRatingWritable.class, SequenceFileOutputFormat.class);
        Configuration solveConf = solve.getConfiguration();
        solve.setPartitionerClass(HashPartitioner.class);
        solve.setGroupingComparatorClass(IndexedVarIntWritable.GroupingComparator.class);
        solveConf.setInt(NUM_FEATURES, numFeatures);
        solveConf.set(LAMBDA, String.valueOf(lambda));
        solve.waitForCompletion(true);
    }

    private Path joinAndSolvePath(int currentIteration, String step) {
        return new Path(this.tempDir, "joinAndSolve-" + currentIteration + '-' + step);
    }

    private Path pathToM(int iteration) {
        return new Path(this.tempDir, "M-" + iteration);
    }

    private Path pathToU(int iteration) {
        return new Path(this.tempDir, "U-" + iteration);
    }

    private Path pathToItemRatings() {
        return new Path(this.tempDir, "itemsAsFeatureWithRatingWritable");
    }

    private Path pathToUserRatings() {
        return new Path(this.tempDir, "usersAsFeatureWithRatingWritable");
    }

    static class InitializeMReducer
    extends Reducer<VarLongWritable, FloatWritable, VarIntWritable, FeatureVectorWithRatingWritable> {
        private int numFeatures;
        private static final Random random = RandomUtils.getRandom();

        InitializeMReducer() {
        }

        protected void setup(Reducer.Context ctx) throws IOException, InterruptedException {
            super.setup(ctx);
            this.numFeatures = ctx.getConfiguration().getInt(NUM_FEATURES, -1);
            if (this.numFeatures < 1) {
                throw new IllegalStateException("numFeatures was not set correctly!");
            }
        }

        protected void reduce(VarLongWritable itemID, Iterable<FloatWritable> ratings, Reducer.Context ctx) throws IOException, InterruptedException {
            FullRunningAverage averageRating = new FullRunningAverage();
            for (FloatWritable rating : ratings) {
                averageRating.addDatum(rating.get());
            }
            int itemIDIndex = TasteHadoopUtils.idToIndex(itemID.get());
            DenseVector columnOfM = new DenseVector(this.numFeatures);
            columnOfM.setQuick(0, averageRating.getAverage());
            for (int n = 1; n < this.numFeatures; ++n) {
                columnOfM.setQuick(n, random.nextDouble());
            }
            ctx.write((Object)new VarIntWritable(itemIDIndex), (Object)new FeatureVectorWithRatingWritable(itemIDIndex, (Vector)columnOfM));
        }
    }

    static class ItemIDRatingMapper
    extends Mapper<LongWritable, Text, VarLongWritable, FloatWritable> {
        ItemIDRatingMapper() {
        }

        protected void map(LongWritable key, Text value, Mapper.Context ctx) throws IOException, InterruptedException {
            String[] tokens = TasteHadoopUtils.splitPrefTokens(value.toString());
            ctx.write((Object)new VarLongWritable(Long.parseLong(tokens[1])), (Object)new FloatWritable(Float.parseFloat(tokens[2])));
        }
    }

    static class SolvingReducer
    extends Reducer<IndexedVarIntWritable, FeatureVectorWithRatingWritable, VarIntWritable, FeatureVectorWithRatingWritable> {
        private int numFeatures;
        private double lambda;
        private AlternateLeastSquaresSolver solver;

        SolvingReducer() {
        }

        protected void setup(Reducer.Context ctx) throws IOException, InterruptedException {
            super.setup(ctx);
            this.lambda = Double.parseDouble(ctx.getConfiguration().get(LAMBDA));
            this.numFeatures = ctx.getConfiguration().getInt(NUM_FEATURES, -1);
            if (this.numFeatures < 1) {
                throw new IllegalStateException("numFeatures was not set correctly!");
            }
            this.solver = new AlternateLeastSquaresSolver();
        }

        protected void reduce(IndexedVarIntWritable key, Iterable<FeatureVectorWithRatingWritable> values, Reducer.Context ctx) throws IOException, InterruptedException {
            ArrayList<Vector> UorMColumns = new ArrayList<Vector>();
            RandomAccessSparseVector ratingVector = new RandomAccessSparseVector(Integer.MAX_VALUE);
            int n = 0;
            for (FeatureVectorWithRatingWritable value : values) {
                ratingVector.setQuick(n++, (double)value.getRating().floatValue());
                UorMColumns.add(value.getFeatureVector());
            }
            Vector uiOrmj = this.solver.solve(UorMColumns, (Vector)new SequentialAccessSparseVector((Vector)ratingVector), this.lambda, this.numFeatures);
            ctx.write((Object)new VarIntWritable(key.getValue()), (Object)new FeatureVectorWithRatingWritable(key.getValue(), uiOrmj));
        }
    }

    static class JoinFeatureVectorAndRatingsReducer
    extends Reducer<VarIntWritable, FeatureVectorWithRatingWritable, IndexedVarIntWritable, FeatureVectorWithRatingWritable> {
        JoinFeatureVectorAndRatingsReducer() {
        }

        protected void reduce(VarIntWritable id, Iterable<FeatureVectorWithRatingWritable> values, Reducer.Context ctx) throws IOException, InterruptedException {
            Vector featureVector = null;
            HashMap<Integer, Float> ratings = new HashMap<Integer, Float>();
            for (FeatureVectorWithRatingWritable featureVectorWithRatingWritable : values) {
                if (featureVectorWithRatingWritable.getFeatureVector() == null) {
                    ratings.put(featureVectorWithRatingWritable.getIDIndex(), featureVectorWithRatingWritable.getRating());
                    continue;
                }
                featureVector = featureVectorWithRatingWritable.getFeatureVector().clone();
            }
            if (featureVector == null || ratings.isEmpty()) {
                throw new IllegalStateException("Unable to join data for " + id);
            }
            for (Map.Entry entry : ratings.entrySet()) {
                ctx.write((Object)new IndexedVarIntWritable((Integer)entry.getKey(), id.get()), (Object)new FeatureVectorWithRatingWritable(id.get(), (Float)entry.getValue(), featureVector));
            }
        }
    }

    static class PrefsToRatingsMapper
    extends Mapper<LongWritable, Text, VarIntWritable, FeatureVectorWithRatingWritable> {
        private boolean transpose;

        PrefsToRatingsMapper() {
        }

        protected void setup(Mapper.Context ctx) throws IOException, InterruptedException {
            this.transpose = ctx.getConfiguration().getBoolean(MAP_TRANSPOSED, false);
        }

        protected void map(LongWritable offset, Text line, Mapper.Context ctx) throws IOException, InterruptedException {
            String[] tokens = TasteHadoopUtils.splitPrefTokens(line.toString());
            int keyIDIndex = TasteHadoopUtils.idToIndex(Long.parseLong(tokens[this.transpose ? 0 : 1]));
            int valueIDIndex = TasteHadoopUtils.idToIndex(Long.parseLong(tokens[this.transpose ? 1 : 0]));
            float rating = Float.parseFloat(tokens[2]);
            ctx.write((Object)new VarIntWritable(keyIDIndex), (Object)new FeatureVectorWithRatingWritable(valueIDIndex, rating));
        }
    }

    static class ToMatrixMapper
    extends Mapper<VarIntWritable, FeatureVectorWithRatingWritable, IntWritable, VectorWritable> {
        ToMatrixMapper() {
        }

        protected void map(VarIntWritable key, FeatureVectorWithRatingWritable value, Mapper.Context ctx) throws IOException, InterruptedException {
            ctx.write((Object)new IntWritable(key.get()), (Object)new VectorWritable(value.getFeatureVector()));
        }
    }
}

