package org.wikibrain.sr.ensemble;

import com.typesafe.config.Config;
import gnu.trove.map.TIntFloatMap;
import gnu.trove.map.hash.TIntDoubleHashMap;
import gnu.trove.procedure.TIntDoubleProcedure;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.lang.Language;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.normalize.Normalizer;
import org.wikibrain.sr.utils.Leaderboard;

/* loaded from: input_file:org/wikibrain/sr/ensemble/SimpleEnsembleMetric.class */
/**
 * An SR metric that combines several sub-metrics using fixed, per-metric coefficients.
 *
 * <p>{@code similarity} returns the coefficient-weighted mean of the valid sub-metric scores.
 * {@code cosimilarity} and the {@code mostSimilar} methods combine sub-metric scores as a
 * coefficient-weighted sum (they do not divide by the total weight).
 * NOTE(review): this means {@code cosimilarity(a, b)} is not on the same scale as
 * {@code similarity(a, b)}; confirm that callers only rely on its ordering.
 *
 * <p>This metric keeps no state of its own. The normalizer accessors and data-directory accessors
 * return {@code null} or do nothing, and {@link #read()} / {@link #write()} are no-ops. Training is
 * delegated to the sub-metrics unless it is disabled via {@link #setTrainSubmetrics(boolean)}.
 */
public class SimpleEnsembleMetric implements SRMetric {
    private static final Logger LOG = LoggerFactory.getLogger(SimpleEnsembleMetric.class);

    /** Fraction of the candidate pool each sub-metric is asked to rescore in string mostSimilar. */
    private static final double RESCORE_FRACTION = 0.8d;

    /**
     * Multiplier applied to a sub-metric's lowest returned score to obtain the score assumed for
     * candidates that sub-metric did not return.
     */
    private static final double FALLBACK_SCORE_FACTOR = 0.99d;

    private final String name;
    private final Language language;
    private final SubMetric[] metrics;
    private boolean trainSubmetrics = true;
    private double numCandidateMultiplier = 2.0d;

    /**
     * Builds a {@link SimpleEnsembleMetric} from a config block whose {@code type} is
     * {@code "simple-ensemble"}. The block lists sub-metric names under {@code metrics} and one
     * coefficient per name under {@code coefficients}. A {@code language} runtime parameter is
     * required.
     */
    public static class Provider extends org.wikibrain.conf.Provider<SRMetric> {
        public Provider(Configurator configurator, Configuration configuration) throws ConfigurationException {
            super(configurator, configuration);
        }

        public Class getType() {
            return SRMetric.class;
        }

        public String getPath() {
            return "sr.metric.local";
        }

        /**
         * Creates the ensemble described by {@code config}.
         *
         * <p>Sub-metrics that fail to load are logged and skipped, together with their coefficient.
         * Coefficients beyond the number of listed metrics are ignored, with a warning.
         *
         * @param name          name given to the ensemble metric
         * @param config        the metric's config block
         * @param runtimeParams runtime parameters; must contain {@code "language"}
         * @return the ensemble, or {@code null} if {@code config} is not a {@code simple-ensemble}
         * @throws ConfigurationException   if {@code metrics} or {@code coefficients} is missing, or
         *                                  there are fewer coefficients than metrics
         * @throws IllegalArgumentException if the {@code language} runtime parameter is missing, or
         *                                  no sub-metric could be loaded
         */
        public SRMetric get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
            if (!config.getString("type").equals("simple-ensemble")) {
                return null;
            }
            if (runtimeParams == null || !runtimeParams.containsKey("language")) {
                throw new IllegalArgumentException("Monolingual SR Metric requires 'language' runtime parameter");
            }
            Language lang = Language.getByLangCode(runtimeParams.get("language"));
            if (!config.hasPath("metrics")) {
                throw new ConfigurationException("Ensemble metric has no base metrics to use.");
            }
            if (!config.hasPath("coefficients")) {
                throw new ConfigurationException("Ensemble metric has no coefficients for its base metrics.");
            }
            List<String> metricNames = config.getStringList("metrics");
            List<Double> configuredCoefficients = config.getDoubleList("coefficients");
            // Validate up front: reading a missing coefficient inside the loop below used to be
            // misreported as a metric-loading failure and left the two lists misaligned.
            if (configuredCoefficients.size() < metricNames.size()) {
                throw new ConfigurationException("Ensemble metric " + name + " lists " + metricNames.size()
                        + " metrics but only " + configuredCoefficients.size() + " coefficients.");
            }
            if (configuredCoefficients.size() > metricNames.size()) {
                LOG.warn("Ensemble metric " + name + " has " + configuredCoefficients.size()
                        + " coefficients but only " + metricNames.size() + " metrics; ignoring the extras.");
            }

            List<SRMetric> loaded = new ArrayList<SRMetric>();
            List<Double> coefficients = new ArrayList<Double>();
            for (int i = 0; i < metricNames.size(); i++) {
                String metricName = metricNames.get(i);
                try {
                    SRMetric metric = getConfigurator().get(SRMetric.class, metricName, "language", lang.getLangCode());
                    // Add metric and coefficient together so the two lists always stay aligned.
                    loaded.add(metric);
                    coefficients.add(configuredCoefficients.get(i));
                } catch (Exception e) {
                    // Deliberate best-effort: one broken sub-metric should not disable the ensemble.
                    LOG.error("Loading of metric " + metricName + " failed. Skipping it! Error:", e);
                }
            }
            return new SimpleEnsembleMetric(name, lang, loaded, coefficients);
        }

        /**
         * Raw-typed variant of {@link #get(String, Config, Map)}. The decompiler renamed it from
         * {@code get} (it was a compiler-generated bridge); it is kept only for binary compatibility.
         */
        @SuppressWarnings("unchecked") // the runtime parameter map is always String -> String
        public Object m24get(String str, Config config, Map map) throws ConfigurationException {
            return get(str, config, (Map<String, String>) map);
        }
    }

    /** A sub-metric paired with its ensemble coefficient. */
    private static final class SubMetric {
        final SRMetric metric;
        final double coefficient;

        SubMetric(SRMetric metric, double coefficient) {
            this.metric = metric;
            this.coefficient = coefficient;
        }
    }

    /**
     * Creates an ensemble from parallel lists of sub-metrics and coefficients.
     *
     * @param name         the metric name
     * @param language     the language of the metric
     * @param metrics      the sub-metrics; must be non-empty
     * @param coefficients one coefficient per sub-metric, in the same order
     * @throws IllegalArgumentException if the lists differ in size or are empty
     */
    public SimpleEnsembleMetric(String name, Language language, List<SRMetric> metrics, List<Double> coefficients) {
        if (metrics.size() != coefficients.size()) {
            throw new IllegalArgumentException("Got " + metrics.size() + " metrics but "
                    + coefficients.size() + " coefficients");
        }
        if (metrics.isEmpty()) {
            throw new IllegalArgumentException("Must supply at least one metric to the simple ensemble.");
        }
        this.metrics = new SubMetric[metrics.size()];
        for (int i = 0; i < metrics.size(); i++) {
            this.metrics[i] = new SubMetric(metrics.get(i), coefficients.get(i).doubleValue());
        }
        this.name = name;
        this.language = language;
    }

    @Override
    public String getName() {
        return this.name;
    }

    @Override
    public Language getLanguage() {
        return this.language;
    }

    /**
     * Returns the coefficient-weighted mean of the valid sub-metric similarities between two ids.
     * A sub-metric that returns {@code null} or an invalid result is left out of both the weighted
     * sum and the total weight. The score is NaN if the accumulated weight is not positive.
     *
     * @param explanations ignored; sub-metrics are always queried with {@code false}
     */
    @Override
    public SRResult similarity(int id1, int id2, boolean explanations) throws DaoException {
        double weightedSum = 0.0d;
        double totalWeight = 0.0d;
        for (SubMetric sub : this.metrics) {
            SRResult result = sub.metric.similarity(id1, id2, false);
            if (result != null && result.isValid()) {
                weightedSum += sub.coefficient * result.getScore();
                totalWeight += sub.coefficient;
            }
        }
        return new SRResult(totalWeight > 0.0d ? weightedSum / totalWeight : Double.NaN);
    }

    /**
     * Returns the coefficient-weighted mean of the valid sub-metric similarities between two
     * phrases. Uses the same combination rule as {@link #similarity(int, int, boolean)}.
     *
     * @param explanations ignored; sub-metrics are always queried with {@code false}
     */
    @Override
    public SRResult similarity(String phrase1, String phrase2, boolean explanations) throws DaoException {
        double weightedSum = 0.0d;
        double totalWeight = 0.0d;
        for (SubMetric sub : this.metrics) {
            SRResult result = sub.metric.similarity(phrase1, phrase2, false);
            if (result != null && result.isValid()) {
                weightedSum += sub.coefficient * result.getScore();
                totalWeight += sub.coefficient;
            }
        }
        return new SRResult(totalWeight > 0.0d ? weightedSum / totalWeight : Double.NaN);
    }

    @Override
    public SRResultList mostSimilar(int id, int maxResults) throws DaoException {
        return mostSimilar(id, maxResults, (TIntSet) null);
    }

    /**
     * Finds the ids most similar to {@code id}.
     *
     * <p>Each sub-metric proposes {@code maxResults * numCandidateMultiplier} candidates. The union
     * of these candidates is then scored with this ensemble's {@link #cosimilarity(int[], int[])}
     * and ranked with a {@code Leaderboard} of size {@code maxResults}.
     */
    @Override
    public SRResultList mostSimilar(int id, int maxResults, TIntSet validIds) throws DaoException {
        TIntSet candidates = new TIntHashSet();
        for (SubMetric sub : this.metrics) {
            addResultIds(candidates, sub.metric.mostSimilar(id, numCandidates(maxResults), validIds));
        }
        Leaderboard leaderboard = new Leaderboard(maxResults);
        if (candidates.isEmpty()) {
            // Nothing to score; skip querying every sub-metric with an empty column set.
            return leaderboard.getTop();
        }
        int[] candidateIds = candidates.toArray();
        double[][] scores = cosimilarity(new int[]{id}, candidateIds);
        for (int col = 0; col < candidateIds.length; col++) {
            leaderboard.tallyScore(candidateIds[col], scores[0][col]);
        }
        return leaderboard.getTop();
    }

    @Override
    public SRResultList mostSimilar(String phrase, int maxResults) throws DaoException {
        return mostSimilar(phrase, maxResults, (TIntSet) null);
    }

    /**
     * Finds the ids most similar to {@code phrase}.
     *
     * <p>Each sub-metric proposes {@code maxResults * numCandidateMultiplier} candidates. Each
     * sub-metric then rescores the union of these candidates, returning at most
     * {@code ceil(0.8 * |candidates|)} of them. A candidate that a sub-metric did not return, or
     * returned with a non-finite score, is assigned a fallback score of 0.99 times that
     * sub-metric's lowest returned score. A candidate's final score is the coefficient-weighted
     * sum across sub-metrics.
     */
    @Override
    public SRResultList mostSimilar(String phrase, int maxResults, TIntSet validIds) throws DaoException {
        TIntSet candidates = new TIntHashSet();
        for (SubMetric sub : this.metrics) {
            addResultIds(candidates, sub.metric.mostSimilar(phrase, numCandidates(maxResults), validIds));
        }
        final Leaderboard leaderboard = new Leaderboard(maxResults);
        if (candidates.isEmpty()) {
            return leaderboard.getTop();
        }

        int[] candidateIds = candidates.toArray();
        int rescoreCount = (int) Math.ceil(candidateIds.length * RESCORE_FRACTION);
        TIntDoubleHashMap combined = new TIntDoubleHashMap();
        for (SubMetric sub : this.metrics) {
            SRResultList rescored = sub.metric.mostSimilar(phrase, rescoreCount, candidates);
            if (rescored == null || rescored.numDocs() == 0) {
                continue;
            }
            TIntFloatMap scoreById = rescored.asTroveMap();
            double fallback = rescored.getScore(rescored.numDocs() - 1) * FALLBACK_SCORE_FACTOR;
            // Iterate over the whole candidate pool, not just the ids this sub-metric returned.
            // Otherwise the fallback would never apply to the candidates it left out.
            for (int candidate : candidateIds) {
                double score = scoreById.containsKey(candidate) ? scoreById.get(candidate) : fallback;
                if (!isFiniteScore(score)) {
                    score = fallback;
                }
                if (!isFiniteScore(score)) {
                    // Neither a real nor a fallback score is usable; don't poison the sum with NaN.
                    continue;
                }
                double weighted = score * sub.coefficient;
                combined.adjustOrPutValue(candidate, weighted, weighted);
            }
        }

        combined.forEachEntry(new TIntDoubleProcedure() {
            public boolean execute(int candidate, double score) {
                leaderboard.tallyScore(candidate, score);
                return true;
            }
        });
        return leaderboard.getTop();
    }

    /**
     * Returns a {@code rowIds.length x colIds.length} matrix of coefficient-weighted sums of the
     * sub-metric cosimilarities. Non-finite sub-metric entries contribute nothing. A sub-metric
     * that returns {@code null} is logged and skipped.
     */
    @Override
    public double[][] cosimilarity(int[] rowIds, int[] colIds) throws DaoException {
        double[][] combined = new double[rowIds.length][colIds.length];
        for (SubMetric sub : this.metrics) {
            addWeighted(combined, sub, sub.metric.cosimilarity(rowIds, colIds));
        }
        return combined;
    }

    /**
     * Phrase variant of {@link #cosimilarity(int[], int[])}; uses the same combination rule.
     */
    @Override
    public double[][] cosimilarity(String[] rowPhrases, String[] colPhrases) throws DaoException {
        double[][] combined = new double[rowPhrases.length][colPhrases.length];
        for (SubMetric sub : this.metrics) {
            addWeighted(combined, sub, sub.metric.cosimilarity(rowPhrases, colPhrases));
        }
        return combined;
    }

    @Override
    public double[][] cosimilarity(int[] ids) throws DaoException {
        return cosimilarity(ids, ids);
    }

    @Override
    public double[][] cosimilarity(String[] phrases) throws DaoException {
        return cosimilarity(phrases, phrases);
    }

    /** Adds {@code coefficient * partial[r][c]} into {@code total} for every finite entry. */
    private static void addWeighted(double[][] total, SubMetric sub, double[][] partial) {
        if (partial == null) {
            LOG.warn("Sub-metric " + sub.metric.getName() + " returned no cosimilarity matrix; ignoring it.");
            return;
        }
        for (int r = 0; r < total.length; r++) {
            for (int c = 0; c < total[r].length; c++) {
                double value = partial[r][c];
                if (isFiniteScore(value)) {
                    total[r][c] += value * sub.coefficient;
                }
            }
        }
    }

    /** Adds the id of every result in {@code results} to {@code target}; a null list adds nothing. */
    private static void addResultIds(TIntSet target, SRResultList results) {
        if (results == null) {
            return;
        }
        Iterator<SRResult> it = results.iterator();
        while (it.hasNext()) {
            target.add(it.next().getId());
        }
    }

    /** True when {@code value} is neither NaN nor infinite. */
    private static boolean isFiniteScore(double value) {
        return !Double.isNaN(value) && !Double.isInfinite(value);
    }

    /** Number of candidates each sub-metric is asked for in the first phase of mostSimilar. */
    private int numCandidates(int maxResults) {
        return (int) (maxResults * this.numCandidateMultiplier);
    }

    /**
     * Controls whether the training methods are forwarded to the sub-metrics. When disabled, the
     * {@code *IsTrained} methods report {@code true}.
     */
    public void setTrainSubmetrics(boolean trainSubmetrics) {
        this.trainSubmetrics = trainSubmetrics;
    }

    /** Always {@code null}: the ensemble applies no normalizer of its own. */
    @Override
    public Normalizer getMostSimilarNormalizer() {
        return null;
    }

    /** No-op: the ensemble applies no normalizer of its own. */
    @Override
    public void setMostSimilarNormalizer(Normalizer normalizer) {
    }

    /** Always {@code null}: the ensemble applies no normalizer of its own. */
    @Override
    public Normalizer getSimilarityNormalizer() {
        return null;
    }

    /** No-op: the ensemble applies no normalizer of its own. */
    @Override
    public void setSimilarityNormalizer(Normalizer normalizer) {
    }

    /** Always {@code null}: the ensemble has no data directory. */
    @Override
    public File getDataDir() {
        return null;
    }

    /** No-op: the ensemble has no data directory. */
    @Override
    public void setDataDir(File dataDir) {
    }

    /** No-op: the ensemble has nothing to persist. */
    @Override
    public void write() throws IOException {
    }

    /** No-op: the ensemble has nothing to load. */
    @Override
    public void read() throws IOException {
    }

    /** Trains every sub-metric's similarity on {@code dataset}, unless sub-metric training is disabled. */
    @Override
    public void trainSimilarity(Dataset dataset) throws DaoException {
        if (!this.trainSubmetrics) {
            return;
        }
        for (SubMetric sub : this.metrics) {
            sub.metric.trainSimilarity(dataset);
        }
    }

    /** Trains every sub-metric's mostSimilar on {@code dataset}, unless sub-metric training is disabled. */
    @Override
    public void trainMostSimilar(Dataset dataset, int numResults, TIntSet validIds) {
        if (!this.trainSubmetrics) {
            return;
        }
        for (SubMetric sub : this.metrics) {
            sub.metric.trainMostSimilar(dataset, numResults, validIds);
        }
    }

    /** True if sub-metric training is disabled or every sub-metric reports its similarity trained. */
    @Override
    public boolean similarityIsTrained() {
        if (!this.trainSubmetrics) {
            return true;
        }
        for (SubMetric sub : this.metrics) {
            if (!sub.metric.similarityIsTrained()) {
                return false;
            }
        }
        return true;
    }

    /** True if sub-metric training is disabled or every sub-metric reports its mostSimilar trained. */
    @Override
    public boolean mostSimilarIsTrained() {
        if (!this.trainSubmetrics) {
            return true;
        }
        for (SubMetric sub : this.metrics) {
            if (!sub.metric.mostSimilarIsTrained()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Sets how many first-phase candidates each sub-metric proposes in mostSimilar, as a multiple
     * of the requested result count. Defaults to 2.
     *
     * @throws IllegalArgumentException if {@code multiplier} is not a positive finite number
     */
    public void setNumCandidateMultiplier(double multiplier) {
        if (!(multiplier > 0.0d) || Double.isInfinite(multiplier)) {
            throw new IllegalArgumentException(
                    "numCandidateMultiplier must be a positive finite number, got " + multiplier);
        }
        this.numCandidateMultiplier = multiplier;
    }
}
