/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.cf.taste.impl.recommender.svd;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.recommender.svd.AbstractFactorizer;
import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization;
import org.apache.mahout.cf.taste.impl.recommender.svd.SVDPreference;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.common.RandomUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Factorizes the preference matrix of a {@link DataModel} into user ("left") and item ("right")
 * feature vectors.
 *
 * <p>Features are trained one at a time. For each feature the cached preferences are shuffled and
 * {@code numIterations} passes of stochastic gradient descent adjust only that feature's user and
 * item components. After a feature is trained (except the last one), its contribution is added to
 * each preference's cache, so later features are fitted on top of the earlier ones. While a
 * feature is being trained, the features not yet trained are estimated from their initialization
 * value.
 *
 * <p>{@link #factorize()} writes training state into instance fields without synchronization, so
 * calls on the same instance should not overlap.
 */
public final class ExpectationMaximizationSVDFactorizer
extends AbstractFactorizer {
    private static final Logger log = LoggerFactory.getLogger(ExpectationMaximizationSVDFactorizer.class);

    /** Step size applied to every gradient update. */
    private final double learningRate;
    /** Regularization weight that pulls each updated component toward zero. */
    private final double preventOverfitting;
    /** Number of features per user vector and per item vector. */
    private final int numFeatures;
    /** Number of passes over all preferences made while training each single feature. */
    private final int numIterations;
    /** Scale of the uniform noise added to every initial feature value. */
    private final double randomNoise;
    /** User feature vectors, indexed by {@code userIndex(userID)}. */
    private double[][] leftVectors;
    /** Item feature vectors, indexed by {@code itemIndex(itemID)}. */
    private double[][] rightVectors;
    private final DataModel dataModel;
    /** All preferences; each cache holds the summed contribution of the features trained so far. */
    private List<SVDPreference> cachedPreferences;
    /** Initial value of every feature component, before noise is added. */
    private double defaultValue;
    /** One tenth of the preference range divided by numFeatures; scales init noise and trailing estimate. */
    private double interval;

    /**
     * Creates a factorizer with learning rate 0.005, regularization 0.02 and random noise 0.005.
     *
     * @param dataModel     source of users, items and preferences
     * @param numFeatures   number of features per user and item vector
     * @param numIterations number of SGD passes per feature
     * @throws TasteException propagated from the superclass constructor
     */
    public ExpectationMaximizationSVDFactorizer(DataModel dataModel, int numFeatures, int numIterations) throws TasteException {
        this(dataModel, numFeatures, 0.005, 0.02, 0.005, numIterations);
    }

    /**
     * Creates a factorizer with explicit training hyper-parameters.
     *
     * @param dataModel          source of users, items and preferences
     * @param numFeatures        number of features per user and item vector
     * @param learningRate       SGD step size
     * @param preventOverfitting regularization weight applied on every update
     * @param randomNoise        scale of the noise added to the initial feature values
     * @param numIterations      number of SGD passes per feature
     * @throws TasteException propagated from the superclass constructor
     */
    public ExpectationMaximizationSVDFactorizer(DataModel dataModel, int numFeatures, double learningRate, double preventOverfitting, double randomNoise, int numIterations) throws TasteException {
        super(dataModel);
        this.dataModel = dataModel;
        this.numFeatures = numFeatures;
        this.numIterations = numIterations;
        this.learningRate = learningRate;
        this.preventOverfitting = preventOverfitting;
        this.randomNoise = randomNoise;
    }

    /**
     * Trains all features and returns the resulting factorization.
     *
     * @return the factorization built by {@code createFactorization} from the trained vectors
     * @throws TasteException propagated from reading the {@link DataModel}
     */
    @Override
    public Factorization factorize() throws TasteException {
        Random random = RandomUtils.getRandom();
        int numUsers = this.dataModel.getNumUsers();
        int numItems = this.dataModel.getNumItems();
        double prefInterval = this.dataModel.getMaxPreference() - this.dataModel.getMinPreference();
        this.initializeVectors(random, numUsers, numItems, this.getAveragePreference(), prefInterval);

        this.cachedPreferences = new ArrayList<SVDPreference>(numUsers);
        this.cachePreferences();

        // Reported as-is if numIterations is 0 and no pass ever computes an RMSE.
        double rmse = prefInterval;
        for (int feature = 0; feature < this.numFeatures; ++feature) {
            Collections.shuffle(this.cachedPreferences, random);
            for (int iteration = 0; iteration < this.numIterations; ++iteration) {
                double squaredErrorSum = 0.0;
                for (SVDPreference pref : this.cachedPreferences) {
                    int useridx = this.userIndex(pref.getUserID());
                    int itemidx = this.itemIndex(pref.getItemID());
                    double err = this.train(useridx, itemidx, feature, pref);
                    squaredErrorSum += err * err;
                }
                rmse = Math.sqrt(squaredErrorSum / this.cachedPreferences.size());
            }
            // Fold the finished feature into every cache so the next feature fits on top of it.
            // After the last feature nothing reads the cache again, so that pass is skipped.
            if (feature < this.numFeatures - 1) {
                for (SVDPreference pref : this.cachedPreferences) {
                    int useridx = this.userIndex(pref.getUserID());
                    int itemidx = this.itemIndex(pref.getItemID());
                    this.buildCache(useridx, itemidx, feature, pref);
                }
            }
            log.info("Finished training feature {} with RMSE {}.", feature, rmse);
        }
        return this.createFactorization(this.leftVectors, this.rightVectors);
    }

    /**
     * Allocates the user and item vectors and fills them with defaultValue plus uniform noise.
     * Random values are drawn feature by feature, users before items.
     */
    private void initializeVectors(Random random, int numUsers, int numItems, double average, double prefInterval) {
        this.leftVectors = new double[numUsers][this.numFeatures];
        this.rightVectors = new double[numItems][this.numFeatures];
        // Chosen so that numFeatures * defaultValue^2, the noise-free initial user-item dot product,
        // equals the average preference minus 10% of the preference range. The radicand is
        // negative whenever the average is below 10% of the range (e.g. ratings centred on 0).
        // Math.sqrt would then return NaN and poison every factor, so clamp it at zero.
        this.defaultValue = Math.sqrt(Math.max(0.0, (average - prefInterval * 0.1) / this.numFeatures));
        this.interval = prefInterval * 0.1 / this.numFeatures;
        double noiseScale = this.interval * this.randomNoise;
        for (int feature = 0; feature < this.numFeatures; ++feature) {
            for (int userIndex = 0; userIndex < numUsers; ++userIndex) {
                this.leftVectors[userIndex][feature] = this.defaultValue + (random.nextDouble() - 0.5) * noiseScale;
            }
            for (int itemIndex = 0; itemIndex < numItems; ++itemIndex) {
                this.rightVectors[itemIndex][feature] = this.defaultValue + (random.nextDouble() - 0.5) * noiseScale;
            }
        }
    }

    /**
     * Computes the mean of every preference value in the data model.
     *
     * @return the running average over all users' preferences (whatever
     *         {@code FullRunningAverage.getAverage()} reports when there are none)
     * @throws TasteException propagated from the {@link DataModel}
     */
    double getAveragePreference() throws TasteException {
        FullRunningAverage average = new FullRunningAverage();
        LongPrimitiveIterator it = this.dataModel.getUserIDs();
        while (it.hasNext()) {
            for (Preference pref : this.dataModel.getPreferencesFromUser(it.nextLong())) {
                average.addDatum(pref.getValue());
            }
        }
        return average.getAverage();
    }

    /**
     * Performs one SGD step on feature {@code f} for a single preference.
     *
     * @param i    user index into {@code leftVectors}
     * @param j    item index into {@code rightVectors}
     * @param f    feature being trained
     * @param pref the preference; its cache holds the contribution of earlier features
     * @return the prediction error (actual minus predicted) before this step
     */
    private double train(int i, int j, int f, SVDPreference pref) {
        double[] leftVectorI = this.leftVectors[i];
        double[] rightVectorJ = this.rightVectors[j];
        double err = pref.getValue() - this.predictRating(i, j, f, pref, true);
        // Snapshot both components first so each update uses the other's pre-step value. Otherwise
        // the item update would read the user component that was just modified.
        double userFeature = leftVectorI[f];
        double itemFeature = rightVectorJ[f];
        leftVectorI[f] = userFeature + this.learningRate * (err * itemFeature - this.preventOverfitting * userFeature);
        rightVectorJ[f] = itemFeature + this.learningRate * (err * userFeature - this.preventOverfitting * itemFeature);
        return err;
    }

    /** Stores in {@code pref}'s cache the unclamped sum of its cache and feature {@code k}'s product. */
    private void buildCache(int i, int j, int k, SVDPreference pref) {
        pref.setCache(this.predictRating(i, j, k, pref, false));
    }

    /**
     * Predicts a preference as its cached value plus the product of feature {@code f}.
     *
     * <p>When {@code trailing} is true, it also adds an estimate for the
     * {@code numFeatures - f - 1} features not yet trained. Each of those is taken as
     * {@code (defaultValue + interval)^2}. The result is clamped to the data model's
     * [min, max] preference range.
     */
    private double predictRating(int i, int j, int f, SVDPreference pref, boolean trailing) {
        double sum = pref.getCache() + this.leftVectors[i][f] * this.rightVectors[j][f];
        if (trailing) {
            double trailingValue = this.defaultValue + this.interval;
            sum += (this.numFeatures - f - 1) * trailingValue * trailingValue;
            float maxPreference = this.dataModel.getMaxPreference();
            float minPreference = this.dataModel.getMinPreference();
            if (sum > maxPreference) {
                sum = maxPreference;
            } else if (sum < minPreference) {
                sum = minPreference;
            }
        }
        return sum;
    }

    /** Rebuilds {@code cachedPreferences} from every user's preferences, each with a zero cache. */
    private void cachePreferences() throws TasteException {
        this.cachedPreferences.clear();
        LongPrimitiveIterator it = this.dataModel.getUserIDs();
        while (it.hasNext()) {
            for (Preference pref : this.dataModel.getPreferencesFromUser(it.nextLong())) {
                this.cachedPreferences.add(new SVDPreference(pref.getUserID(), pref.getItemID(), pref.getValue(), 0.0));
            }
        }
    }
}

