package org.wikibrain.sr.vector;

import com.typesafe.config.Config;
import gnu.trove.procedure.TIntProcedure;
import gnu.trove.set.TIntSet;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.LocalPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.matrix.DenseMatrix;
import org.wikibrain.matrix.DenseMatrixRow;
import org.wikibrain.matrix.knn.KNNFinder;
import org.wikibrain.matrix.knn.Neighborhood;
import org.wikibrain.matrix.knn.RandomProjectionKNNFinder;
import org.wikibrain.sr.BaseSRMetric;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.disambig.Disambiguator;
import org.wikibrain.sr.ensemble.EnsembleMetric;
import org.wikibrain.sr.utils.Leaderboard;
import org.wikibrain.sr.utils.SimUtils;

/* loaded from: input_file:org/wikibrain/sr/vector/DenseVectorSRMetric.class */
/**
 * An SR metric that scores relatedness as the cosine similarity between dense feature vectors
 * produced by a {@link DenseVectorGenerator}.
 *
 * <p>Page vectors are read from the generator's feature matrix. Phrase vectors are requested from
 * the generator directly. The {@link BaseSRMetric} implementation is used as a fallback when the
 * generator throws {@link UnsupportedOperationException} for phrases (or, for single-pair
 * similarity, returns no vector).
 *
 * <p>Nearest-neighbor queries over all pages can optionally be accelerated by a {@link KNNFinder}.
 * It is built by {@link #trainMostSimilar}, persisted by {@link #write()} and restored by
 * {@link #read()}.
 */
public class DenseVectorSRMetric extends BaseSRMetric {
    private static final Logger LOG = LoggerFactory.getLogger(DenseVectorSRMetric.class);

    /** File name, inside the metric's data directory, of the persisted KNN index. */
    private static final String KNN_FILE_NAME = "knn.bin";

    protected final DenseVectorGenerator generator;
    protected final BaseSRMetric.SRConfig config;

    /** Feature matrix from the generator; the constructor guarantees it is non-null. */
    private final DenseMatrix articleFeatures;

    /** Optional nearest-neighbor index; null until trained, loaded, or set explicitly. */
    private KNNFinder accelerator;

    /** Number of candidates the accelerator examines per requested result. */
    private double acceleratorMultiplier = 100.0d;

    /** Lower bound on the number of candidates the accelerator examines. */
    private int minAcceleratorCandidates = EnsembleMetric.MIN_SEARCH_DEPTH;

    /**
     * Configuration provider that builds a {@link DenseVectorSRMetric} when the metric's
     * {@code type} is {@code "densevector"}.
     */
    public static class Provider extends org.wikibrain.conf.Provider<SRMetric> {
        public Provider(Configurator configurator, Configuration configuration) throws ConfigurationException {
            super(configurator, configuration);
        }

        @Override
        public Class getType() {
            return SRMetric.class;
        }

        @Override
        public String getPath() {
            return "sr.metric.local";
        }

        /**
         * Builds the metric from its configuration.
         *
         * @param name          name of the metric
         * @param config        metric configuration; reads {@code type}, {@code pageDao},
         *                      {@code disambiguator} and {@code generator}
         * @param runtimeParams must contain a {@code language} code
         * @return the configured metric, or null if {@code type} is not {@code "densevector"}
         * @throws IllegalArgumentException if the {@code language} runtime parameter is missing
         */
        @Override
        public SRMetric get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
            if (!config.getString("type").equals("densevector")) {
                return null;
            }
            if (runtimeParams == null || !runtimeParams.containsKey("language")) {
                throw new IllegalArgumentException("Monolingual requires 'language' runtime parameter.");
            }
            Language language = Language.getByLangCode(runtimeParams.get("language"));

            Map<String, String> generatorParams = new HashMap<String, String>();
            generatorParams.put("language", language.getLangCode());

            LocalPageDao pageDao = (LocalPageDao) getConfigurator().get(
                    LocalPageDao.class, config.getString("pageDao"));
            Disambiguator disambiguator = (Disambiguator) getConfigurator().get(
                    Disambiguator.class, config.getString("disambiguator"),
                    "language", language.getLangCode());
            DenseVectorGenerator generator = (DenseVectorGenerator) getConfigurator().construct(
                    DenseVectorGenerator.class, (String) null, config.getConfig("generator"), generatorParams);

            DenseVectorSRMetric metric = new DenseVectorSRMetric(name, language, pageDao, disambiguator, generator);
            configureBase(getConfigurator(), metric, config);
            return metric;
        }
    }

    /**
     * Creates a metric backed by the given generator's feature matrix.
     *
     * @throws IllegalArgumentException if the generator has no feature matrix
     */
    public DenseVectorSRMetric(String name, Language language, LocalPageDao pageDao,
                               Disambiguator disambiguator, DenseVectorGenerator generator) {
        super(name, language, pageDao, disambiguator);
        this.generator = generator;
        this.articleFeatures = generator.getFeatureMatrix();
        if (this.articleFeatures == null) {
            throw new IllegalArgumentException("DenseVectorGenerator returned a null feature matrix");
        }
        this.config = new BaseSRMetric.SRConfig();
    }

    /**
     * Phrase similarity as the normalized cosine of the generator's phrase vectors. Falls back to
     * the base-class implementation when either phrase has no vector.
     */
    @Override
    public SRResult similarity(String phrase1, String phrase2, boolean explanations) throws DaoException {
        float[] vector1 = null;
        float[] vector2 = null;
        try {
            vector1 = generator.getVector(phrase1);
            vector2 = generator.getVector(phrase2);
        } catch (UnsupportedOperationException ignored) {
            // The generator cannot embed phrases; the null check below triggers the fallback.
        }
        if (vector1 == null || vector2 == null) {
            return super.similarity(phrase1, phrase2, explanations);
        }
        SRResult result = new SRResult(SimUtils.cosineSimilarity(vector1, vector2));
        if (explanations) {
            result.setExplanations(generator.getExplanations(phrase1, phrase2, vector1, vector2, result));
        }
        return normalize(result);
    }

    /** Page similarity as the normalized cosine of the two pages' feature vectors. */
    @Override
    public SRResult similarity(int pageId1, int pageId2, boolean explanations) throws DaoException {
        try {
            float[] vector1 = getPageVector(pageId1);
            float[] vector2 = getPageVector(pageId2);
            SRResult result = new SRResult(normalize(SimUtils.cosineSimilarity(vector1, vector2)));
            if (explanations) {
                result.setExplanations(generator.getExplanations(pageId1, pageId2, vector1, vector2, result));
            }
            return result;
        } catch (IOException e) {
            throw new DaoException(e);
        }
    }

    /**
     * Pages most similar to a phrase. Falls back to the base-class implementation when an
     * {@link UnsupportedOperationException} is thrown, e.g. when the generator cannot embed phrases.
     */
    @Override
    public SRResultList mostSimilar(String phrase, int maxResults, TIntSet validIds) throws DaoException {
        try {
            return mostSimilar(generator.getVector(phrase), maxResults, validIds);
        } catch (IOException e) {
            throw new DaoException(e);
        } catch (UnsupportedOperationException e) {
            return super.mostSimilar(phrase, maxResults, validIds);
        }
    }

    /** Pages most similar to the given page. */
    @Override
    public SRResultList mostSimilar(int pageId, int maxResults, TIntSet validIds) throws DaoException {
        try {
            return mostSimilar(getPageVector(pageId), maxResults, validIds);
        } catch (IOException e) {
            throw new DaoException(e);
        }
    }

    /**
     * Finds the pages whose feature vectors are most cosine-similar to {@code vector}.
     *
     * <p>The KNN accelerator is used only for unrestricted searches ({@code validIds == null}).
     * Restricted searches score each candidate directly; previously they threw
     * {@link UnsupportedOperationException} whenever an accelerator was present.
     *
     * @param vector     query vector; null yields an empty list
     * @param maxResults maximum number of results
     * @param validIds   optional set of candidate page ids, or null for all pages
     * @return normalized results
     */
    public SRResultList mostSimilar(final float[] vector, int maxResults, TIntSet validIds) throws IOException {
        if (vector == null) {
            return new SRResultList(0);
        }
        SRResultList top;
        if (accelerator != null && validIds == null) {
            top = acceleratedMostSimilar(vector, maxResults);
        } else {
            top = exhaustiveMostSimilar(vector, maxResults, validIds);
        }
        return normalize(top);
    }

    /** Scores every page, or every id in {@code validIds}, against {@code vector}. */
    private SRResultList exhaustiveMostSimilar(final float[] vector, int maxResults, TIntSet validIds) throws IOException {
        final Leaderboard leaderboard = new Leaderboard(maxResults);
        if (validIds == null) {
            Iterator<?> rows = articleFeatures.iterator();
            while (rows.hasNext()) {
                DenseMatrixRow row = (DenseMatrixRow) rows.next();
                leaderboard.tallyScore(row.getRowIndex(), SimUtils.cosineSimilarity(row.getValues(), vector));
            }
        } else {
            validIds.forEach(new TIntProcedure() {
                @Override
                public boolean execute(int id) {
                    try {
                        float[] pageVector = getPageVector(id);
                        if (pageVector != null) {
                            leaderboard.tallyScore(id, SimUtils.cosineSimilarity(pageVector, vector));
                        }
                    } catch (Exception e) {
                        // Best effort: a single failing candidate must not abort the whole search.
                        LOG.warn("similarity for {} failed: ", id, e);
                    }
                    return true; // keep iterating
                }
            });
        }
        return leaderboard.getTop();
    }

    /** Answers an unrestricted query with the KNN accelerator. */
    private SRResultList acceleratedMostSimilar(float[] vector, int maxResults) throws IOException {
        int numCandidates = (int) Math.max(minAcceleratorCandidates, maxResults * acceleratorMultiplier);
        Neighborhood neighborhood = accelerator.query(vector, maxResults, numCandidates, null);
        SRResultList top = new SRResultList(neighborhood.size());
        for (int i = 0; i < neighborhood.size(); i++) {
            top.set(i, neighborhood.getId(i), neighborhood.getScore(i));
        }
        return top;
    }

    @Override
    public void trainSimilarity(Dataset dataset) throws DaoException {
        super.trainSimilarity(dataset);
    }

    /**
     * Builds a random-projection KNN accelerator over the feature matrix, then performs the
     * base-class training.
     *
     * @throws IllegalStateException if building the index fails with an I/O error
     */
    @Override
    public void trainMostSimilar(Dataset dataset, int numResults, TIntSet validIds) {
        try {
            RandomProjectionKNNFinder finder = new RandomProjectionKNNFinder(articleFeatures);
            finder.build();
            this.accelerator = finder;
            super.trainMostSimilar(dataset, numResults, validIds);
        } catch (IOException e) {
            throw new IllegalStateException("Unexpected exception: " + e, e);
        }
    }

    @Override
    public double[][] cosimilarity(int[] pageIds) throws DaoException {
        return cosimilarity(pageIds, pageIds);
    }

    @Override
    public double[][] cosimilarity(String[] phrases) throws DaoException {
        return cosimilarity(phrases, phrases);
    }

    /**
     * Normalized cosine similarity matrix between two lists of phrases. Falls back to the
     * base-class implementation if the generator cannot embed phrases.
     */
    @Override
    public double[][] cosimilarity(String[] rowPhrases, String[] colPhrases) throws DaoException {
        if (rowPhrases.length == 0 || colPhrases.length == 0) {
            return new double[rowPhrases.length][colPhrases.length];
        }
        float[][] rowVectors = new float[rowPhrases.length][];
        float[][] colVectors = new float[colPhrases.length][];
        try {
            for (int i = 0; i < rowPhrases.length; i++) {
                rowVectors[i] = generator.getVector(rowPhrases[i]);
            }
            for (int j = 0; j < colPhrases.length; j++) {
                colVectors[j] = generator.getVector(colPhrases[j]);
            }
        } catch (UnsupportedOperationException e) {
            return super.cosimilarity(rowPhrases, colPhrases);
        }
        return normalizedCosineMatrix(rowVectors, colVectors);
    }

    /** Normalized cosine similarity matrix between two lists of pages. */
    @Override
    public double[][] cosimilarity(int[] rowIds, int[] colIds) throws DaoException {
        if (rowIds.length == 0 || colIds.length == 0) {
            return new double[rowIds.length][colIds.length];
        }
        try {
            float[][] rowVectors = new float[rowIds.length][];
            float[][] colVectors = new float[colIds.length][];
            for (int i = 0; i < rowIds.length; i++) {
                rowVectors[i] = getPageVector(rowIds[i]);
            }
            for (int j = 0; j < colIds.length; j++) {
                colVectors[j] = getPageVector(colIds[j]);
            }
            return normalizedCosineMatrix(rowVectors, colVectors);
        } catch (IOException e) {
            throw new DaoException(e);
        }
    }

    /** Computes {@code normalize(cos(rows[i], cols[j]))} for every pair. */
    private double[][] normalizedCosineMatrix(float[][] rows, float[][] cols) {
        double[][] result = new double[rows.length][cols.length];
        for (int i = 0; i < rows.length; i++) {
            for (int j = 0; j < cols.length; j++) {
                result[i][j] = normalize(SimUtils.cosineSimilarity(rows[i], cols[j]));
            }
        }
        return result;
    }

    /** Reads base-class state and, if present, the persisted KNN index. */
    @Override
    public void read() throws IOException {
        super.read();
        RandomProjectionKNNFinder finder = new RandomProjectionKNNFinder(articleFeatures);
        if (finder.load(new File(getDataDir(), KNN_FILE_NAME))) {
            this.accelerator = finder;
        }
    }

    /** Writes base-class state and, when one exists, the KNN index. */
    @Override
    public void write() throws IOException {
        super.write();
        // No accelerator exists until trainMostSimilar/read/setAccelerator; nothing to persist then.
        if (accelerator != null) {
            accelerator.save(new File(getDataDir(), KNN_FILE_NAME));
        }
    }

    /**
     * Returns the feature vector for a page.
     *
     * @return the page's row from the feature matrix, or null if the page has no row
     * @throws IOException if reading the vector fails
     */
    public float[] getPageVector(int pageId) throws IOException {
        if (articleFeatures == null) {
            // Defensive: the constructor rejects a null matrix, so this should not be reached.
            try {
                return generator.getVector(pageId);
            } catch (DaoException e) {
                throw new IOException(e);
            }
        }
        DenseMatrixRow row = articleFeatures.getRow(pageId);
        return row == null ? null : row.getValues();
    }

    public DenseVectorGenerator getGenerator() {
        return generator;
    }

    public void setAccelerator(KNNFinder accelerator) {
        this.accelerator = accelerator;
    }

    public void setAcceleratorMultiplier(double multiplier) {
        this.acceleratorMultiplier = multiplier;
    }

    public void setMinAcceleratorCandidates(int minCandidates) {
        this.minAcceleratorCandidates = minCandidates;
    }

    @Override
    public BaseSRMetric.SRConfig getConfig() {
        return config;
    }
}
