package org.wikibrain.sr.vector;

import com.typesafe.config.Config;
import gnu.trove.map.TIntFloatMap;
import gnu.trove.map.hash.TIntDoubleHashMap;
import gnu.trove.map.hash.TIntFloatHashMap;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import java.io.File;
import java.io.IOException;
import java.util.Iterator;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.matrix.MatrixRow;
import org.wikibrain.matrix.SparseMatrix;
import org.wikibrain.matrix.SparseMatrixRow;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.utils.Leaderboard;
import org.wikibrain.sr.utils.SimUtils;
import org.wikibrain.utils.WpIOUtils;

/* loaded from: input_file:org/wikibrain/sr/vector/CosineSimilarity.class */
/**
 * {@link VectorSimilarity} implementation that scores vectors by cosine similarity:
 * the dot product of two vectors divided by the product of their Euclidean norms.
 *
 * <p>{@link #setMatrices} must be called before {@link #mostSimilar}; it records the
 * feature matrix and its transpose and precomputes (or loads from cache files) the norm
 * of every feature row.
 */
public class CosineSimilarity implements VectorSimilarity {
    private static final Logger LOG = LoggerFactory.getLogger(CosineSimilarity.class);

    /** Candidate sets with at least this many ids are searched through the transpose matrix. */
    private static final int INVERTED_INDEX_THRESHOLD = 10000;

    /** Names of the cache files kept in the directory handed to {@link #setMatrices}. */
    private static final String IDS_CACHE_FILE = "cosineSimilarity-ids.bin";
    private static final String LENGTHS_CACHE_FILE = "cosineSimilarity-lengths.bin";
    private static final String MAX_RESULTS_CACHE_FILE = "cosineSimilarity-maxResults.bin";

    /** Feature-matrix row index mapped to that row's norm, as reported by {@code getNorm()}. */
    private TIntFloatHashMap lengths = new TIntFloatHashMap();

    /** Row ids of the transpose matrix. */
    private TIntSet idsInResults = new TIntHashSet();

    /** Largest column count seen in any feature row; -1 until {@link #setMatrices} has run. */
    private int maxResults = -1;

    private SparseMatrix features;
    private SparseMatrix transpose;

    /**
     * Configuration provider that creates a {@link CosineSimilarity} when the configured
     * {@code type} is {@code "cosine"}.
     */
    public static class Provider extends org.wikibrain.conf.Provider<VectorSimilarity> {
        public Provider(Configurator configurator, Configuration configuration) throws ConfigurationException {
            super(configurator, configuration);
        }

        public Class getType() {
            return VectorSimilarity.class;
        }

        public String getPath() {
            return "sr.metric.similarity";
        }

        /**
         * @return a new {@link CosineSimilarity} if {@code config}'s {@code type} is
         *         {@code "cosine"}, otherwise {@code null}
         */
        public VectorSimilarity get(String str, Config config, Map<String, String> map) throws ConfigurationException {
            if ("cosine".equals(config.getString("type"))) {
                return new CosineSimilarity();
            }
            return null;
        }

        // Decompiler-renamed copy of the compiler-generated bridge for get(...). It only
        // delegates to the typed overload and is kept so the class's member set is unchanged.
        /* renamed from: get, reason: collision with other method in class */
        public /* bridge */ /* synthetic */ Object m53get(String str, Config config, Map map) throws ConfigurationException {
            return get(str, config, (Map<String, String>) map);
        }
    }

    /**
     * Records the matrices to search and prepares the per-row norms, the transpose row ids
     * and the maximum row width, loading them from cache files when those are up to date.
     *
     * @param sparseMatrix  feature matrix (one row per item)
     * @param sparseMatrix2 transpose of the feature matrix
     * @param file          directory holding the {@code cosineSimilarity-*.bin} cache files
     * @throws IOException if a cache file cannot be read or written
     */
    @Override // org.wikibrain.sr.vector.VectorSimilarity
    public synchronized void setMatrices(SparseMatrix sparseMatrix, SparseMatrix sparseMatrix2, File file) throws IOException {
        this.features = sparseMatrix;
        this.transpose = sparseMatrix2;

        File idsFile = new File(file, IDS_CACHE_FILE);
        File lengthsFile = new File(file, LENGTHS_CACHE_FILE);
        File maxResultsFile = new File(file, MAX_RESULTS_CACHE_FILE);

        // All three cache files are read below, so all three must exist and be at least as
        // new as the matrix they were derived from (maxResults is derived from the features).
        boolean cacheIsFresh =
                lengthsFile.exists() && lengthsFile.lastModified() >= sparseMatrix.lastModified()
                && idsFile.exists() && idsFile.lastModified() >= sparseMatrix2.lastModified()
                && maxResultsFile.exists() && maxResultsFile.lastModified() >= sparseMatrix.lastModified();

        if (cacheIsFresh) {
            LOG.info("reading matrix information from cache");
            this.lengths = (TIntFloatHashMap) WpIOUtils.readObjectFromFile(lengthsFile);
            this.idsInResults = (TIntSet) WpIOUtils.readObjectFromFile(idsFile);
            this.maxResults = ((Integer) WpIOUtils.readObjectFromFile(maxResultsFile)).intValue();
            return;
        }

        LOG.info("building cached matrix information");
        this.lengths.clear();
        this.idsInResults.clear();
        this.maxResults = 0;
        Iterator<?> rows = sparseMatrix.iterator();
        while (rows.hasNext()) {
            SparseMatrixRow row = (SparseMatrixRow) rows.next();
            this.lengths.put(row.getRowIndex(), (float) row.getNorm());
            this.maxResults = Math.max(this.maxResults, row.getNumCols());
        }
        this.idsInResults.addAll(sparseMatrix2.getRowIds());

        WpIOUtils.writeObjectToFile(lengthsFile, this.lengths);
        WpIOUtils.writeObjectToFile(idsFile, this.idsInResults);
        WpIOUtils.writeObjectToFile(maxResultsFile, Integer.valueOf(this.maxResults));
    }

    /** Delegates to {@link SimUtils#cosineSimilarity} for two matrix rows. */
    @Override // org.wikibrain.sr.vector.VectorSimilarity
    public double similarity(MatrixRow matrixRow, MatrixRow matrixRow2) {
        return SimUtils.cosineSimilarity(matrixRow, matrixRow2);
    }

    /** Delegates to {@link SimUtils#cosineSimilarity} for two sparse maps. */
    @Override // org.wikibrain.sr.vector.VectorSimilarity
    public double similarity(TIntFloatMap tIntFloatMap, TIntFloatMap tIntFloatMap2) {
        return SimUtils.cosineSimilarity(tIntFloatMap, tIntFloatMap2);
    }

    /**
     * Finds the rows most similar to a query vector.
     *
     * @param tIntFloatMap query vector (feature id mapped to value)
     * @param i            size passed to the {@link Leaderboard} that collects the results
     * @param tIntSet      ids to restrict the search to, or {@code null} for no restriction
     * @return the leaderboard's top results
     * @throws IOException if a matrix row cannot be read
     */
    @Override // org.wikibrain.sr.vector.VectorSimilarity
    public SRResultList mostSimilar(TIntFloatMap tIntFloatMap, int i, TIntSet tIntSet) throws IOException {
        // Small candidate sets are scored row by row; otherwise walk the transpose.
        if (tIntSet != null && tIntSet.size() < INVERTED_INDEX_THRESHOLD) {
            return mostSimilarWithRegularIndex(tIntFloatMap, i, tIntSet);
        }
        return mostSimilarWithInvertedIndex(tIntFloatMap, i, tIntSet);
    }

    /**
     * Scores each candidate id by reading its feature row and taking the dot product with
     * the query, normalised by both vector norms.
     */
    private SRResultList mostSimilarWithRegularIndex(TIntFloatMap query, int numResults, TIntSet validIds) throws IOException {
        Leaderboard leaderboard = new Leaderboard(numResults);
        double queryNorm = norm(query);
        for (int id : validIds.toArray()) {
            SparseMatrixRow row = this.features.getRow(id);
            if (row == null) {
                continue;
            }
            double dot = 0.0d;
            int numCols = row.getNumCols();
            for (int col = 0; col < numCols; col++) {
                float queryValue = query.get(row.getColIndex(col));
                // Skip features the query lacks (assumes the map's no-entry value is 0, the
                // Trove default). Negative components are kept so that this path agrees with
                // norm() and with the inverted-index path.
                if (queryValue != 0.0f) {
                    dot += queryValue * row.getColValue(col);
                }
            }
            leaderboard.tallyScore(id, dot / (this.lengths.get(id) * queryNorm));
        }
        return leaderboard.getTop();
    }

    /**
     * Accumulates dot products by walking, for every feature in the query, the matching row
     * of the transpose matrix, then normalises each accumulated value by both vector norms.
     */
    private SRResultList mostSimilarWithInvertedIndex(TIntFloatMap query, int numResults, TIntSet validIds) throws IOException {
        TIntDoubleHashMap dots = new TIntDoubleHashMap(Math.max(100000, numResults * 5));
        for (int featureId : query.keys()) {
            float queryValue = query.get(featureId);
            SparseMatrixRow row = this.transpose.getRow(featureId);
            if (row == null) {
                continue;
            }
            int numCols = row.getNumCols();
            for (int col = 0; col < numCols; col++) {
                int id = row.getColIndex(col);
                if (validIds != null && !validIds.contains(id)) {
                    continue;
                }
                float product = queryValue * row.getColValue(col);
                dots.adjustOrPutValue(id, product, product);
            }
        }

        Leaderboard leaderboard = new Leaderboard(numResults);
        double queryNorm = norm(query);
        for (int id : dots.keys()) {
            leaderboard.tallyScore(id, dots.get(id) / (this.lengths.get(id) * queryNorm));
        }
        return leaderboard.getTop();
    }

    /** @return the lower bound reported for similarity scores, -1 */
    @Override // org.wikibrain.sr.vector.VectorSimilarity
    public double getMinValue() {
        return -1.0d;
    }

    /** @return the upper bound reported for similarity scores, 1 */
    @Override // org.wikibrain.sr.vector.VectorSimilarity
    public double getMaxValue() {
        return 1.0d;
    }

    /** @return the Euclidean norm (square root of the sum of squared values) of {@code vector} */
    private double norm(TIntFloatMap vector) {
        double sumOfSquares = 0.0d;
        for (float value : vector.values()) {
            sumOfSquares += value * value;
        }
        return Math.sqrt(sumOfSquares);
    }
}
