package org.grouplens.lenskit.baseline;

import it.unimi.dsi.fastutil.longs.Long2DoubleMap;
import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap;
import it.unimi.dsi.fastutil.longs.LongIterator;
import it.unimi.dsi.fastutil.longs.LongSortedSet;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import org.grouplens.lenskit.RecommenderComponentBuilder;
import org.grouplens.lenskit.data.pref.Preference;
import org.grouplens.lenskit.data.vector.MutableSparseVector;
import org.grouplens.lenskit.data.vector.UserVector;
import org.grouplens.lenskit.params.MeanSmoothing;
import org.grouplens.lenskit.params.meta.Built;
import org.grouplens.lenskit.util.CollectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Baseline predictor that scores an item as the global mean rating plus a
 * per-item offset.
 *
 * <p>The offsets are produced by {@link #computeItemAverages} and looked up by
 * item ID at prediction time; the score returned for an item is
 * {@code globalMean + itemMeans.get(item)}, so an item with no stored offset
 * receives the global mean plus the map's default return value.
 */
@Built
public class ItemMeanPredictor implements BaselinePredictor {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(ItemMeanPredictor.class);

    /** Per-item offsets from the global mean, keyed by item ID. */
    private final Long2DoubleMap itemMeans;

    /** Global mean rating that the per-item offsets are added to. */
    protected final double globalMean;

    /**
     * Builder that derives the item offsets and global mean from the ratings
     * of the snapshot held by the parent builder.
     */
    public static class Builder extends RecommenderComponentBuilder<ItemMeanPredictor> {
        /** Smoothing weight handed to {@code computeItemAverages}; zero by default. */
        private double damping = 0.0;

        /**
         * Sets the smoothing (damping) weight used when computing item means.
         *
         * @param d the damping weight
         */
        @MeanSmoothing
        public void setDamping(double d) {
            damping = d;
        }

        /**
         * Computes the item offsets and global mean from the snapshot's
         * ratings and wraps them in a new predictor.
         *
         * @return the newly built predictor
         */
        @Override // org.grouplens.lenskit.Builder
        public ItemMeanPredictor build() {
            Long2DoubleMap offsets = new Long2DoubleOpenHashMap();
            double mean = ItemMeanPredictor.computeItemAverages(
                    this.snapshot.getRatings().fastIterator(), damping, offsets);
            return new ItemMeanPredictor(offsets, mean);
        }
    }

    /**
     * Creates a predictor from precomputed item offsets and a global mean.
     *
     * <p>A map that is {@link Serializable} is stored as-is (no copy is
     * made); any other map is copied into a {@link Long2DoubleOpenHashMap}.
     *
     * @param long2DoubleMap item ID to offset-from-global-mean mapping
     * @param d the global mean rating
     */
    public ItemMeanPredictor(Long2DoubleMap long2DoubleMap, double d) {
        this.itemMeans = (long2DoubleMap instanceof Serializable)
                ? long2DoubleMap
                : new Long2DoubleOpenHashMap(long2DoubleMap);
        this.globalMean = d;
    }

    /**
     * Computes the global mean rating and fills {@code long2DoubleMap} with
     * each item's smoothed offset from that mean.
     *
     * <p>For an item with rating sum {@code s} and rating count {@code n},
     * the stored offset is {@code (s + d * mean) / (n + d) - mean}, or
     * {@code 0} when {@code n + d} is not positive. The map's default return
     * value is set to {@code 0} as a side effect.
     *
     * @param it iterator over the ratings to average
     * @param d smoothing weight applied toward the global mean
     * @param long2DoubleMap output map; receives one offset per rated item
     * @return the global mean, or {@code 0} if there were no ratings
     */
    public static double computeItemAverages(Iterator<? extends Preference> it, double d, Long2DoubleMap long2DoubleMap) {
        long2DoubleMap.defaultReturnValue(0.0);
        Long2IntOpenHashMap ratingCounts = new Long2IntOpenHashMap();
        ratingCounts.defaultReturnValue(0);

        // First pass: accumulate the overall sum/count and per-item sums/counts.
        // The output map temporarily holds per-item rating sums.
        double ratingSum = 0.0;
        int ratingTotal = 0;
        while (it.hasNext()) {
            Preference rating = it.next();
            long item = rating.getItemId();
            double score = rating.getValue();
            ratingSum += score;
            ratingTotal += 1;
            long2DoubleMap.put(item, score + long2DoubleMap.get(item));
            ratingCounts.put(item, 1 + ratingCounts.get(item));
        }

        final double mean;
        if (ratingTotal > 0) {
            mean = ratingSum / ratingTotal;
        } else {
            mean = 0.0;
        }
        logger.debug("Computed global mean {} for {} items", mean, long2DoubleMap.size());
        logger.debug("Computing item means, smoothing={}", d);

        // Second pass: replace each per-item sum with its smoothed offset.
        for (LongIterator items = ratingCounts.keySet().iterator(); items.hasNext();) {
            long item = items.nextLong();
            double weight = ratingCounts.get(item) + d;
            double smoothedSum = long2DoubleMap.get(item) + (d * mean);
            double offset = weight > 0.0 ? (smoothedSum / weight) - mean : 0.0;
            long2DoubleMap.put(item, offset);
        }
        return mean;
    }

    /**
     * Predicts a score for each requested item; the user vector is not
     * consulted.
     *
     * @param userVector the user's rating vector (unused here)
     * @param collection IDs of the items to score
     * @return a vector mapping each requested item ID to its item mean
     */
    @Override // org.grouplens.lenskit.baseline.BaselinePredictor
    public MutableSparseVector predict(UserVector userVector, Collection<Long> collection) {
        long[] itemIds = CollectionUtils.fastCollection(collection).toLongArray();
        // A LongSortedSet is taken to be in key order already; anything else is sorted here.
        boolean presorted = collection instanceof LongSortedSet;
        if (!presorted) {
            Arrays.sort(itemIds);
        }

        double[] scores = new double[itemIds.length];
        int slot = 0;
        for (long id : itemIds) {
            scores[slot++] = getItemMean(id);
        }
        return MutableSparseVector.wrap(itemIds, scores);
    }

    /**
     * Returns the predicted score for an item: the global mean plus whatever
     * the offset map returns for the item ID.
     *
     * <p>NOTE(review): the decompiler reported that this method's access
     * modifier was changed from {@code protected}; confirm against the
     * original source before relying on public visibility.
     *
     * @param j the item ID
     * @return the item's mean-based score
     */
    public double getItemMean(long j) {
        double offset = this.itemMeans.get(j);
        return this.globalMean + offset;
    }
}
