package org.wikibrain.sr.ensemble;

import com.typesafe.config.Config;
import gnu.trove.set.TIntSet;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.DaoFilter;
import org.wikibrain.core.dao.LocalPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.lang.LocalId;
import org.wikibrain.core.lang.LocalString;
import org.wikibrain.sr.BaseSRMetric;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.disambig.Disambiguator;
import org.wikibrain.sr.utils.KnownSim;
import org.wikibrain.utils.Function;
import org.wikibrain.utils.ParallelForEach;
import org.wikibrain.utils.Procedure;

/* loaded from: input_file:org/wikibrain/sr/ensemble/EnsembleMetric.class */
public class EnsembleMetric extends BaseSRMetric {
    private static final Logger LOG = LoggerFactory.getLogger(EnsembleMetric.class);

    /** Lower bound on how many results are requested from each submetric in mostSimilar queries. */
    public static final int MIN_SEARCH_DEPTH = 500;

    /** Factor applied to the caller's requested result count when querying submetrics. */
    public static final int SEARCH_MULTIPLIER = 3;

    /** Underlying metrics whose outputs are combined. */
    private final List<SRMetric> metrics;

    /** Strategy that turns the per-submetric outputs into a single prediction. */
    private final Ensemble ensemble;

    /**
     * When true, phrase-based queries are delegated to BaseSRMetric's phrase handling instead of
     * being forwarded verbatim to each submetric.
     */
    private boolean resolvePhrases;

    /** When true, training this metric first trains every submetric on the same dataset. */
    private boolean trainSubmetrics;

    /**
     * Builds an {@link EnsembleMetric} from configuration entries found under
     * {@code sr.metric.local} whose {@code type} is {@code "ensemble"}.
     */
    public static class Provider extends org.wikibrain.conf.Provider<SRMetric> {
        public Provider(Configurator configurator, Configuration configuration) throws ConfigurationException {
            super(configurator, configuration);
        }

        public Class getType() {
            return SRMetric.class;
        }

        public String getPath() {
            return "sr.metric.local";
        }

        /**
         * Creates an ensemble metric for one language.
         *
         * @param name          name given to the constructed metric
         * @param config        metric configuration; reads {@code type}, {@code metrics},
         *                      {@code pageDao}, {@code ensemble}, {@code disambiguator} and,
         *                      if present, {@code resolvephrases}
         * @param runtimeParams must contain a {@code language} entry holding a language code
         * @return the configured metric, or {@code null} if {@code type} is not {@code "ensemble"}
         * @throws IllegalArgumentException if the {@code language} runtime parameter is missing
         * @throws ConfigurationException   if no base metrics are configured, the ensemble type is
         *                                  neither {@code "linear"} nor {@code "even"}, or the page
         *                                  count cannot be read from the page DAO
         */
        public SRMetric get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
            if (!config.getString("type").equals("ensemble")) {
                return null;
            }
            if (runtimeParams == null || !runtimeParams.containsKey("language")) {
                throw new IllegalArgumentException("Monolingual SR Metric requires 'language' runtime parameter");
            }
            Language language = Language.getByLangCode(runtimeParams.get("language"));
            if (!config.hasPath("metrics")) {
                throw new ConfigurationException("Ensemble metric has no base metrics to use.");
            }
            List<SRMetric> baseMetrics = new ArrayList<SRMetric>();
            for (String metricName : config.getStringList("metrics")) {
                baseMetrics.add((SRMetric) getConfigurator().get(
                        SRMetric.class, metricName, "language", language.getLangCode()));
            }
            LocalPageDao pageDao = (LocalPageDao) getConfigurator().get(LocalPageDao.class, config.getString("pageDao"));
            try {
                int numPages = pageDao.getCount(DaoFilter.normalPageFilter(language));
                String ensembleType = config.getString("ensemble");
                Ensemble ensemble;
                if (ensembleType.equals("linear")) {
                    ensemble = new CorrelationEnsemble(baseMetrics.size(), numPages);
                } else if (ensembleType.equals("even")) {
                    ensemble = new EvenEnsemble();
                } else {
                    throw new ConfigurationException(
                            "Unknown ensemble type '" + ensembleType + "' (expected 'linear' or 'even')");
                }
                Disambiguator disambiguator = (Disambiguator) getConfigurator().get(
                        Disambiguator.class, config.getString("disambiguator"), "language", language.getLangCode());
                EnsembleMetric ensembleMetric = new EnsembleMetric(name, language, baseMetrics, ensemble, disambiguator, pageDao);
                if (config.hasPath("resolvephrases")) {
                    ensembleMetric.setResolvePhrases(config.getBoolean("resolvephrases"));
                }
                EnsembleMetric.configureBase(getConfigurator(), ensembleMetric, config);
                return ensembleMetric;
            } catch (DaoException e) {
                throw new ConfigurationException(e);
            }
        }

        /**
         * Decompiler-generated untyped overload of {@link #get(String, Config, Map)}. It is kept
         * only because it is public; it simply forwards to the typed method.
         */
        @SuppressWarnings("unchecked")
        public Object m18get(String str, Config config, Map map) throws ConfigurationException {
            return get(str, config, (Map<String, String>) map);
        }
    }

    /**
     * @param name          metric name
     * @param language      language this metric operates on
     * @param metrics       submetrics whose outputs are combined
     * @param ensemble      strategy used to combine submetric outputs
     * @param disambiguator disambiguator handed to BaseSRMetric and used during mostSimilar training
     * @param pageDao       page DAO handed to BaseSRMetric
     */
    public EnsembleMetric(String name, Language language, List<SRMetric> metrics, Ensemble ensemble,
                          Disambiguator disambiguator, LocalPageDao pageDao) {
        super(name, language, pageDao, disambiguator);
        this.resolvePhrases = true;
        this.trainSubmetrics = true;
        this.metrics = metrics;
        this.ensemble = ensemble;
    }

    /** Returns the (live, mutable) list of submetrics. */
    public List<SRMetric> getMetrics() {
        return this.metrics;
    }

    /** Controls whether phrase queries go through BaseSRMetric (true) or straight to submetrics (false). */
    public void setResolvePhrases(boolean resolvePhrases) {
        this.resolvePhrases = resolvePhrases;
    }

    /** Returns a default-constructed SRConfig. */
    @Override
    public BaseSRMetric.SRConfig getConfig() {
        return new BaseSRMetric.SRConfig();
    }

    /**
     * Asks every submetric for the similarity of two pages and combines the answers with the
     * ensemble, then normalizes the result.
     *
     * @param pageId1      first page id
     * @param pageId2      second page id
     * @param explanations forwarded unchanged to each submetric (presumably requests explanations)
     */
    @Override
    public SRResult similarity(int pageId1, int pageId2, boolean explanations) throws DaoException {
        List<SRResult> results = new ArrayList<SRResult>(metrics.size());
        for (SRMetric metric : metrics) {
            results.add(metric.similarity(pageId1, pageId2, explanations));
        }
        return normalize(ensemble.predictSimilarity(results));
    }

    /**
     * Phrase similarity. If {@code resolvePhrases} is set, this delegates to BaseSRMetric;
     * otherwise the raw phrases are sent to each submetric and the answers are combined.
     */
    @Override
    public SRResult similarity(String phrase1, String phrase2, boolean explanations) throws DaoException {
        if (resolvePhrases) {
            return super.similarity(phrase1, phrase2, explanations);
        }
        List<SRResult> results = new ArrayList<SRResult>(metrics.size());
        for (SRMetric metric : metrics) {
            results.add(metric.similarity(phrase1, phrase2, explanations));
        }
        return normalize(ensemble.predictSimilarity(results));
    }

    /**
     * Returns the pages most similar to {@code pageId}, serving from the most-similar cache when
     * possible. Each submetric is queried with a deeper result count (see {@link #getMaxResults})
     * so the ensemble has candidates to re-rank.
     */
    @Override
    public SRResultList mostSimilar(int pageId, int maxResults, TIntSet validIds) throws DaoException {
        SRResultList cached = getCachedMostSimilar(pageId, maxResults, validIds);
        if (cached != null) {
            return cached;
        }
        int depth = getMaxResults(maxResults);
        List<SRResultList> lists = new ArrayList<SRResultList>(metrics.size());
        for (SRMetric metric : metrics) {
            lists.add(metric.mostSimilar(pageId, depth, validIds));
        }
        return normalize(ensemble.predictMostSimilar(lists, maxResults, validIds));
    }

    /**
     * Phrase variant of {@link #mostSimilar(int, int, TIntSet)}. Delegates to BaseSRMetric when
     * {@code resolvePhrases} is set; otherwise forwards the phrase to each submetric.
     */
    @Override
    public SRResultList mostSimilar(String phrase, int maxResults, TIntSet validIds) throws DaoException {
        if (resolvePhrases) {
            return super.mostSimilar(phrase, maxResults, validIds);
        }
        int depth = getMaxResults(maxResults);
        List<SRResultList> lists = new ArrayList<SRResultList>(metrics.size());
        for (SRMetric metric : metrics) {
            lists.add(metric.mostSimilar(phrase, depth, validIds));
        }
        return normalize(ensemble.predictMostSimilar(lists, maxResults, validIds));
    }

    /**
     * Trains the ensemble's similarity model on {@code dataset}. Submetrics are trained first when
     * {@code trainSubmetrics} is set. A submetric that fails, or returns null, on a pair
     * contributes NaN for that pair.
     */
    @Override
    public void trainSimilarity(Dataset dataset) throws DaoException {
        if (trainSubmetrics) {
            for (SRMetric metric : metrics) {
                metric.trainSimilarity(dataset);
            }
        }
        // The procedure below may run on several ParallelForEach worker threads at once, and
        // ArrayList is not thread-safe, so appends are guarded by the list's monitor.
        final List<EnsembleSim> ensembleSims = new ArrayList<EnsembleSim>();
        ParallelForEach.loop(dataset.getData(), new Procedure<KnownSim>() {
            public void call(KnownSim ks) throws Exception {
                EnsembleSim sim = new EnsembleSim(ks);
                for (SRMetric metric : EnsembleMetric.this.metrics) {
                    double score = Double.NaN;
                    try {
                        SRResult result = metric.similarity(ks.phrase1, ks.phrase2, false);
                        if (result != null) {
                            score = result.getScore();
                        }
                    } catch (Exception e) {
                        LOG.warn("Local sr metric " + metric.getName() + " failed for " + ks, e);
                    }
                    sim.add(score, 0);
                }
                synchronized (ensembleSims) {
                    ensembleSims.add(sim);
                }
            }
        }, 100);
        ensemble.trainSimilarity(ensembleSims);
        super.trainSimilarity(dataset);
    }

    /**
     * Trains the ensemble's mostSimilar model. The most-similar cache is cleared first, and
     * submetrics are trained when {@code trainSubmetrics} is set.
     *
     * For each known pair, both phrases are disambiguated together. Pairs whose first phrase does
     * not resolve to a positive id yield null. For every submetric, the score and index of the
     * second page in the first page's neighbour list are recorded. (NaN, -1) is recorded when the
     * second page is absent or the submetric fails.
     */
    @Override
    public void trainMostSimilar(Dataset dataset, final int numResults, final TIntSet validIds) {
        if (getMostSimilarCache() != null) {
            clearMostSimilarCache();
        }
        if (trainSubmetrics) {
            for (SRMetric metric : metrics) {
                metric.trainMostSimilar(dataset, numResults, validIds);
            }
        }
        final int depth = getMaxResults(numResults);
        this.ensemble.trainMostSimilar(ParallelForEach.loop(dataset.getData(), new Function<KnownSim, EnsembleSim>() {
            public EnsembleSim call(KnownSim ks) throws DaoException {
                List<LocalId> ids = EnsembleMetric.this.getDisambiguator().disambiguateTop(
                        Arrays.asList(new LocalString(ks.language, ks.phrase1),
                                      new LocalString(ks.language, ks.phrase2)),
                        (Set<LocalString>) null);
                if (ids.isEmpty() || ids.get(0).getId() <= 0) {
                    return null;
                }
                int queryId = ids.get(0).getId();
                EnsembleSim sim = new EnsembleSim(ks);
                for (SRMetric metric : EnsembleMetric.this.metrics) {
                    double score = Double.NaN;
                    int rank = -1;
                    try {
                        SRResultList neighbours = metric.mostSimilar(queryId, depth, validIds);
                        if (neighbours != null) {
                            int index = neighbours.getIndexForId(ids.get(1).getId());
                            if (index >= 0) {
                                score = neighbours.getScore(index);
                                rank = index;
                            }
                        }
                    } catch (Exception e) {
                        LOG.warn("Local sr metric " + metric.getName() + " failed for " + queryId, e);
                    }
                    sim.add(score, rank);
                }
                return sim;
            }
        }, 100));
        super.trainMostSimilar(dataset, numResults, validIds);
    }

    /**
     * Number of results to request from each submetric for a query wanting {@code numResults}:
     * {@code numResults * SEARCH_MULTIPLIER}, clamped to at least {@link #MIN_SEARCH_DEPTH} and
     * at most {@link Integer#MAX_VALUE} (avoids int overflow for huge requests).
     */
    public int getMaxResults(int numResults) {
        long scaled = (long) numResults * SEARCH_MULTIPLIER;
        return (int) Math.max(MIN_SEARCH_DEPTH, Math.min(Integer.MAX_VALUE, scaled));
    }

    /** Controls whether training this metric also trains its submetrics. */
    public void setTrainSubmetrics(boolean trainSubmetrics) {
        this.trainSubmetrics = trainSubmetrics;
    }

    /** Writes base metric state, then the ensemble model to {@code <dataDir>/ensemble}. */
    @Override
    public void write() throws IOException {
        super.write();
        ensemble.write(new File(getDataDir(), "ensemble").getAbsolutePath());
    }

    /** Reads base metric state, then the ensemble model from {@code <dataDir>/ensemble}. */
    @Override
    public void read() throws IOException {
        super.read();
        ensemble.read(new File(getDataDir(), "ensemble").getAbsolutePath());
    }
}
