package org.wikibrain.sr.milnewitten;

import com.typesafe.config.Config;
import gnu.trove.map.hash.TIntDoubleHashMap;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Map;
import java.util.logging.Logger;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.LocalPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.sr.BaseSRMetric;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.disambig.Disambiguator;
import org.wikibrain.sr.utils.Leaderboard;

/**
 * Milne-Witten semantic-relatedness metric composed of two submetrics, configured as
 * {@code inlink} and {@code outlink}.
 *
 * <p>Pairwise scores ({@link #similarity}, {@link #cosimilarity}) are the equal-weight average of
 * the two submetric scores. {@link #mostSimilar} merges the two submetrics' ranked lists.
 * Training is forwarded to both submetrics when {@link #setTrainSubmetrics} is true (the default).
 * {@link #write()} and {@link #read()} are always forwarded to both submetrics before the
 * base-class state is handled.
 */
public class MilneWittenMetric extends BaseSRMetric {
    private static final Logger LOG = Logger.getLogger(MilneWittenMetric.class.getName());

    /** Weight applied to each submetric's score when averaging pairwise scores. */
    private static final double SUBMETRIC_WEIGHT = 0.5d;

    private final SRMetric inlink;
    private final SRMetric outlink;

    /**
     * Whether the training methods also train the two submetrics. It is volatile because the
     * setter is not synchronized, while the synchronized training methods read it.
     */
    private volatile boolean trainSubmetrics;

    /** Configuration provider that builds a {@link MilneWittenMetric} for entries of type "milnewitten". */
    public static class Provider extends org.wikibrain.conf.Provider<SRMetric> {
        public Provider(Configurator configurator, Configuration configuration) throws ConfigurationException {
            super(configurator, configuration);
        }

        public Class getType() {
            return SRMetric.class;
        }

        public String getPath() {
            return "sr.metric.local";
        }

        /**
         * Builds a Milne-Witten metric from configuration.
         *
         * @param name          name given to the new metric
         * @param config        metric configuration; reads "type", "inlink", "outlink" and "disambiguator"
         * @param runtimeParams runtime parameters; must contain "language" (a language code)
         * @return the configured metric, or null when {@code config}'s "type" is not "milnewitten"
         * @throws IllegalArgumentException if the "language" runtime parameter is missing
         * @throws ConfigurationException   if a dependent component cannot be configured
         */
        public SRMetric get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
            if (!config.getString("type").equals("milnewitten")) {
                return null;
            }
            if (runtimeParams == null || !runtimeParams.containsKey("language")) {
                throw new IllegalArgumentException("Monolingual SR Metric requires 'language' runtime parameter");
            }
            Language language = Language.getByLangCode(runtimeParams.get("language"));
            String langCode = language.getLangCode();

            // Resolved in the same order as the constructor arguments they feed.
            LocalPageDao pageDao = (LocalPageDao) getConfigurator().get(LocalPageDao.class);
            SRMetric inlinkMetric = (SRMetric) getConfigurator().get(
                    SRMetric.class, config.getString("inlink"), "language", langCode);
            SRMetric outlinkMetric = (SRMetric) getConfigurator().get(
                    SRMetric.class, config.getString("outlink"), "language", langCode);
            Disambiguator disambiguator = (Disambiguator) getConfigurator().get(
                    Disambiguator.class, config.getString("disambiguator"), "language", langCode);

            MilneWittenMetric metric = new MilneWittenMetric(
                    name, language, pageDao, inlinkMetric, outlinkMetric, disambiguator);
            MilneWittenMetric.configureBase(getConfigurator(), metric, config);
            return metric;
        }

        /**
         * Untyped overload that delegates to {@link #get(String, Config, Map)}. It is the
         * compiler-generated bridge method, renamed by the decompiler to avoid a name collision.
         */
        public /* bridge */ /* synthetic */ Object m26get(String name, Config config, Map runtimeParams) throws ConfigurationException {
            return get(name, config, (Map<String, String>) runtimeParams);
        }
    }

    /**
     * Creates a metric that combines the given submetrics.
     *
     * @param name          metric name, passed to the base metric
     * @param language      language of the metric, passed to the base metric
     * @param pageDao       page DAO, passed to the base metric
     * @param inlink        submetric used as the "inlink" half
     * @param outlink       submetric used as the "outlink" half
     * @param disambiguator disambiguator, passed to the base metric
     */
    public MilneWittenMetric(String name, Language language, LocalPageDao pageDao, SRMetric inlink, SRMetric outlink, Disambiguator disambiguator) {
        super(name, language, pageDao, disambiguator);
        this.trainSubmetrics = true;
        this.inlink = inlink;
        this.outlink = outlink;
    }

    /** Returns the score bounds for this metric: min 0.0, max 1.1. */
    @Override // org.wikibrain.sr.BaseSRMetric
    public BaseSRMetric.SRConfig getConfig() {
        BaseSRMetric.SRConfig config = new BaseSRMetric.SRConfig();
        // NOTE(review): the upper bound is 1.1 rather than 1.0 -- intent not visible here; confirm.
        config.maxScore = 1.1f;
        config.minScore = 0.0f;
        return config;
    }

    /**
     * Averages the inlink and outlink similarity of two pages.
     *
     * @param pageId1      first page id
     * @param pageId2      second page id
     * @param explanations whether to request and merge the submetrics' explanations
     * @return the averaged score, or a NaN result if either submetric returned null or an invalid result
     */
    @Override // org.wikibrain.sr.BaseSRMetric, org.wikibrain.sr.SRMetric
    public SRResult similarity(int pageId1, int pageId2, boolean explanations) throws DaoException {
        SRResult inlinkResult = this.inlink.similarity(pageId1, pageId2, explanations);
        SRResult outlinkResult = this.outlink.similarity(pageId1, pageId2, explanations);
        if (!isUsableResult(inlinkResult) || !isUsableResult(outlinkResult)) {
            return new SRResult(Double.NaN);
        }
        SRResult result = new SRResult(
                (SUBMETRIC_WEIGHT * inlinkResult.getScore()) + (SUBMETRIC_WEIGHT * outlinkResult.getScore()));
        if (explanations) {
            ArrayList merged = new ArrayList();
            // A submetric may not have produced an explanation list; skip it rather than NPE.
            java.util.Collection<?> inlinkExplanations = inlinkResult.getExplanations();
            if (inlinkExplanations != null) {
                merged.addAll(inlinkExplanations);
            }
            java.util.Collection<?> outlinkExplanations = outlinkResult.getExplanations();
            if (outlinkExplanations != null) {
                merged.addAll(outlinkExplanations);
            }
            result.setExplanations(merged);
        }
        return result;
    }

    /**
     * Averages the inlink and outlink cosimilarity matrices cell by cell.
     *
     * <p>A fresh matrix is returned; the arrays produced by the submetrics are not modified.
     *
     * @param rowIds page ids for the rows
     * @param colIds page ids for the columns
     * @return a {@code rowIds.length x colIds.length} matrix. A cell is NaN when either submetric's
     *         value for it is NaN or infinite.
     */
    @Override // org.wikibrain.sr.BaseSRMetric, org.wikibrain.sr.SRMetric
    public double[][] cosimilarity(int[] rowIds, int[] colIds) throws DaoException {
        double[][] inlinkScores = this.inlink.cosimilarity(rowIds, colIds);
        double[][] outlinkScores = this.outlink.cosimilarity(rowIds, colIds);
        double[][] combined = new double[rowIds.length][colIds.length];
        for (int row = 0; row < rowIds.length; row++) {
            for (int col = 0; col < colIds.length; col++) {
                double inlinkScore = inlinkScores[row][col];
                double outlinkScore = outlinkScores[row][col];
                if (isFiniteScore(inlinkScore) && isFiniteScore(outlinkScore)) {
                    combined[row][col] = (inlinkScore * SUBMETRIC_WEIGHT) + (outlinkScore * SUBMETRIC_WEIGHT);
                } else {
                    combined[row][col] = Double.NaN;
                }
            }
        }
        return combined;
    }

    /**
     * Controls whether the training methods also train the inlink and outlink submetrics.
     *
     * @param trainSubmetrics true to forward training to the submetrics
     */
    public void setTrainSubmetrics(boolean trainSubmetrics) {
        this.trainSubmetrics = trainSubmetrics;
    }

    /** Optionally trains both submetrics on {@code dataset}, then trains this metric's base state. */
    @Override // org.wikibrain.sr.BaseSRMetric, org.wikibrain.sr.SRMetric
    public synchronized void trainSimilarity(Dataset dataset) throws DaoException {
        if (this.trainSubmetrics) {
            this.inlink.trainSimilarity(dataset);
            this.outlink.trainSimilarity(dataset);
        }
        super.trainSimilarity(dataset);
    }

    /** Optionally trains both submetrics' most-similar behavior, then trains this metric's base state. */
    @Override // org.wikibrain.sr.BaseSRMetric, org.wikibrain.sr.SRMetric
    public synchronized void trainMostSimilar(Dataset dataset, int numResults, TIntSet validIds) {
        if (this.trainSubmetrics) {
            this.inlink.trainMostSimilar(dataset, numResults, validIds);
            this.outlink.trainMostSimilar(dataset, numResults, validIds);
        }
        super.trainMostSimilar(dataset, numResults, validIds);
    }

    /** Writes the inlink submetric, then the outlink submetric, then this metric's base state. */
    @Override // org.wikibrain.sr.BaseSRMetric, org.wikibrain.sr.SRMetric
    public void write() throws IOException {
        this.inlink.write();
        this.outlink.write();
        super.write();
    }

    /** Reads the inlink submetric, then the outlink submetric, then this metric's base state. */
    @Override // org.wikibrain.sr.BaseSRMetric, org.wikibrain.sr.SRMetric
    public void read() throws IOException {
        this.inlink.read();
        this.outlink.read();
        super.read();
    }

    /**
     * Merges the two submetrics' most-similar lists into one ranked, normalized list.
     *
     * <p>Each submetric is asked for {@code 2 * maxResults} candidates. Finite scores for the same
     * id are summed. When an id appears in only one list, the other list's missing score is added
     * to it (0.0 if that list was null). The top {@code maxResults} ids are kept and normalized.
     *
     * <p>NOTE(review): scores here are summed, while {@link #similarity} averages them. Ranking is
     * unaffected, but the raw scale differs before normalization. This is left unchanged because
     * persisted normalizers were presumably trained on this scale; confirm before changing it.
     *
     * @param pageId     page to find neighbors of
     * @param maxResults maximum number of results to return
     * @param validIds   candidate restriction passed through to the submetrics
     * @return the normalized top results
     */
    @Override // org.wikibrain.sr.BaseSRMetric, org.wikibrain.sr.SRMetric
    public SRResultList mostSimilar(int pageId, int maxResults, TIntSet validIds) throws DaoException {
        SRResultList inlinkList = this.inlink.mostSimilar(pageId, maxResults * 2, validIds);
        TIntDoubleHashMap combined = new TIntDoubleHashMap(maxResults * 4);
        TIntHashSet inlinkIds = accumulateFiniteScores(inlinkList, combined);

        SRResultList outlinkList = this.outlink.mostSimilar(pageId, maxResults * 2, validIds);
        TIntHashSet outlinkIds = accumulateFiniteScores(outlinkList, combined);

        double inlinkMissing = inlinkList == null ? 0.0d : inlinkList.getMissingScore();
        double outlinkMissing = outlinkList == null ? 0.0d : outlinkList.getMissingScore();
        // An id seen by only one submetric receives the other submetric's "missing" score.
        addMissingScores(combined, inlinkIds, outlinkIds, outlinkMissing);
        addMissingScores(combined, outlinkIds, inlinkIds, inlinkMissing);

        Leaderboard leaderboard = new Leaderboard(maxResults);
        for (int id : combined.keys()) {
            leaderboard.tallyScore(id, combined.get(id));
        }
        return normalize(leaderboard.getTop());
    }

    /** Returns true when {@code result} is non-null and reports itself valid. */
    private static boolean isUsableResult(SRResult result) {
        return result != null && result.isValid();
    }

    /** Returns true when {@code score} is neither NaN nor infinite. */
    private static boolean isFiniteScore(double score) {
        return !Double.isNaN(score) && !Double.isInfinite(score);
    }

    /**
     * Adds every finite score in {@code results} to {@code scores}, summing with any existing
     * value, and returns the ids that contributed. A null list contributes nothing.
     */
    private static TIntHashSet accumulateFiniteScores(SRResultList results, TIntDoubleHashMap scores) {
        TIntHashSet contributed = new TIntHashSet();
        if (results == null) {
            return contributed;
        }
        for (int rank = 0; rank < results.numDocs(); rank++) {
            double score = results.getScore(rank);
            if (isFiniteScore(score)) {
                int id = results.getId(rank);
                scores.adjustOrPutValue(id, score, score);
                contributed.add(id);
            }
        }
        return contributed;
    }

    /** Adds {@code missingScore} to the score of every id in {@code present} that is absent from {@code other}. */
    private static void addMissingScores(TIntDoubleHashMap scores, TIntSet present, TIntSet other, double missingScore) {
        for (int id : present.toArray()) {
            if (!other.contains(id)) {
                scores.adjustValue(id, missingScore);
            }
        }
    }
}
