package org.grouplens.lenskit.transform.normalize;

import it.unimi.dsi.fastutil.doubles.DoubleIterator;
import java.io.Serializable;
import java.util.Iterator;
import javax.inject.Inject;
import javax.inject.Provider;
import org.grouplens.grapht.annotation.DefaultProvider;
import org.grouplens.lenskit.baseline.MeanDamping;
import org.grouplens.lenskit.core.Shareable;
import org.grouplens.lenskit.core.Transient;
import org.grouplens.lenskit.cursors.Cursor;
import org.grouplens.lenskit.data.dao.DataAccessObject;
import org.grouplens.lenskit.data.event.Rating;
import org.grouplens.lenskit.data.pref.Preference;
import org.grouplens.lenskit.vectors.MutableSparseVector;
import org.grouplens.lenskit.vectors.SparseVector;
import org.grouplens.lenskit.vectors.VectorEntry;

/**
 * Vector normalizer that subtracts a vector's mean and divides by its (optionally damped)
 * standard deviation.
 *
 * <p>For a vector with {@code n} entries and sum of squared deviations {@code ss}, the
 * standard deviation used is {@code sqrt((ss + damping * globalVariance) / (n + damping))}.
 * With a damping of 0 this is the plain population standard deviation of the vector.
 */
@DefaultProvider(Builder.class)
@Shareable
public class MeanVarianceNormalizer extends AbstractVectorNormalizer implements Serializable {
    private static final long serialVersionUID = -7890335060797112954L;

    /** Damping term added to the vector size when computing the standard deviation. */
    private final double damping;
    /** Variance used to pull small vectors' variance toward a global value. */
    private final double globalVariance;

    /**
     * Provider that builds a {@link MeanVarianceNormalizer}. When damping is non-zero, it
     * computes the global variance from the ratings in the DAO.
     */
    public static class Builder implements Provider<MeanVarianceNormalizer> {
        private final double damping;
        private final DataAccessObject dao;

        /**
         * @param dataAccessObject the DAO to read rating events from.
         * @param d the damping value.
         */
        @Inject
        public Builder(@Transient DataAccessObject dataAccessObject, @MeanDamping double d) {
            this.dao = dataAccessObject;
            this.damping = d;
        }

        /**
         * Build the normalizer. The global variance is only computed (it requires two passes
         * over the rating events) when damping is non-zero, because it is unused otherwise.
         *
         * @return a new normalizer with this builder's damping.
         */
        @Override
        public MeanVarianceNormalizer get() {
            double variance = 0.0d;
            if (damping != 0.0d) {
                variance = computeGlobalVariance();
            }
            return new MeanVarianceNormalizer(damping, variance);
        }

        /**
         * Compute the population variance of all non-null rating values, using a two-pass
         * algorithm (mean first, then squared deviations).
         *
         * @return the variance, or 0 if there are no rated events (avoids a NaN from 0/0).
         */
        private double computeGlobalVariance() {
            double sum = 0.0d;
            int count = 0;
            Cursor<Rating> ratings = dao.getEvents(Rating.class);
            try {
                for (Rating rating : ratings.fast()) {
                    Preference pref = rating.getPreference();
                    if (pref != null) {
                        sum += pref.getValue();
                        count++;
                    }
                }
            } finally {
                // Close even if iteration throws, so the cursor is not leaked.
                ratings.close();
            }

            if (count == 0) {
                return 0.0d;
            }
            double mean = sum / count;

            double sumSquares = 0.0d;
            ratings = dao.getEvents(Rating.class);
            try {
                for (Rating rating : ratings.fast()) {
                    Preference pref = rating.getPreference();
                    if (pref != null) {
                        double dev = mean - pref.getValue();
                        sumSquares += dev * dev;
                    }
                }
            } finally {
                ratings.close();
            }
            return sumSquares / count;
        }
    }

    /**
     * Transformation that maps values to {@code (v - mean) / stdev} and back. When the standard
     * deviation is exactly 0, {@code apply} maps every value to 0 and {@code unapply} maps
     * every value to the mean.
     *
     * <p>It is a static nested class because it never uses the enclosing normalizer.
     */
    static class Transform implements VectorTransformation {
        private final double mean;
        private final double stdev;

        /**
         * @param d the mean to subtract.
         * @param d2 the standard deviation to divide by.
         */
        public Transform(double d, double d2) {
            this.mean = d;
            this.stdev = d2;
        }

        @Override
        public MutableSparseVector apply(MutableSparseVector mutableSparseVector) {
            for (VectorEntry entry : mutableSparseVector.fast()) {
                double normalized = stdev == 0.0d ? 0.0d : (entry.getValue() - mean) / stdev;
                mutableSparseVector.set(entry.getKey(), normalized);
            }
            return mutableSparseVector;
        }

        @Override
        public MutableSparseVector unapply(MutableSparseVector mutableSparseVector) {
            for (VectorEntry entry : mutableSparseVector.fast()) {
                double restored = stdev == 0.0d ? mean : (entry.getValue() * stdev) + mean;
                mutableSparseVector.set(entry.getKey(), restored);
            }
            return mutableSparseVector;
        }
    }

    /** Create a normalizer with no damping (plain per-vector mean/standard deviation). */
    public MeanVarianceNormalizer() {
        this(0.0d, 0.0d);
    }

    /**
     * @param d the damping value.
     * @param d2 the global variance toward which small vectors' variance is damped.
     */
    public MeanVarianceNormalizer(double d, double d2) {
        this.damping = d;
        this.globalVariance = d2;
    }

    /** @return the damping value. */
    public double getDamping() {
        return this.damping;
    }

    /** @return the global variance used for damping. */
    public double getGlobalVariance() {
        return this.globalVariance;
    }

    /**
     * Build a transformation from the vector's mean and damped standard deviation. Empty
     * vectors get the identity transformation.
     */
    @Override
    public VectorTransformation makeTransformation(SparseVector sparseVector) {
        if (sparseVector.isEmpty()) {
            return new IdentityVectorNormalizer().makeTransformation(sparseVector);
        }
        double mean = sparseVector.mean();
        double sumSquares = 0.0d;
        // Primitive iterator avoids boxing each value.
        DoubleIterator it = sparseVector.values().iterator();
        while (it.hasNext()) {
            double dev = it.nextDouble() - mean;
            sumSquares += dev * dev;
        }
        double stdev = Math.sqrt((sumSquares + damping * globalVariance)
                                 / (sparseVector.size() + damping));
        return new Transform(mean, stdev);
    }
}
