package org.wikibrain.sr.vector;

import com.typesafe.config.Config;
import gnu.trove.map.TIntFloatMap;
import gnu.trove.set.TIntSet;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.LocalPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.disambig.Disambiguator;

/* loaded from: input_file:org/wikibrain/sr/vector/FancyPhraseVectorBasedSRMetric.class */
/**
 * A {@link VectorBasedSRMetric} that can score raw phrases by building vectors for the phrases
 * themselves.
 *
 * <p>Where the phrase vectors come from is controlled by {@link PhraseMode}:
 * <ul>
 *   <li>the configured {@link VectorGenerator};</li>
 *   <li>a {@link PhraseVectorCreator};</li>
 *   <li>the generator first, with the creator as a fallback;</li>
 *   <li>or not at all.</li>
 * </ul>
 * Whenever no phrase vectors can be obtained, each operation delegates to the corresponding
 * {@link VectorBasedSRMetric} implementation.
 */
public class FancyPhraseVectorBasedSRMetric extends VectorBasedSRMetric {
    /** Fallback source of phrase vectors; may be {@code null} if only the generator is used. */
    private final PhraseVectorCreator phraseVectorCreator;

    /** Strategy for obtaining phrase vectors; never {@code null}. */
    private PhraseMode phraseMode = PhraseMode.BOTH;

    /** Selects where phrase vectors are obtained from. */
    public enum PhraseMode {
        /** Only ask the {@link VectorGenerator} for phrase vectors. */
        GENERATOR,
        /** Only ask the {@link PhraseVectorCreator} for phrase vectors. */
        CREATOR,
        /** Ask the generator first and fall back to the creator when it yields nothing. */
        BOTH,
        /** Never build phrase vectors; always delegate to {@link VectorBasedSRMetric}. */
        NONE
    }

    /** Configurator provider that builds this metric for configs whose {@code type} is {@code "fancyphrasevector"}. */
    public static class Provider extends org.wikibrain.conf.Provider<SRMetric> {
        public Provider(Configurator configurator, Configuration configuration) throws ConfigurationException {
            super(configurator, configuration);
        }

        @Override
        public Class getType() {
            return SRMetric.class;
        }

        @Override
        public String getPath() {
            return "sr.metric.local";
        }

        /**
         * Builds a metric from {@code config}.
         *
         * <p>Reads the keys {@code type}, {@code pageDao}, {@code disambiguator}, {@code generator},
         * {@code similarity}, {@code phrases}, and optionally {@code phraseMode}. The
         * {@code phraseMode} value is matched case-insensitively against {@link PhraseMode}.
         *
         * @param name          name of the metric
         * @param config        metric configuration
         * @param runtimeParams runtime parameters; must contain {@code "language"}
         * @return the metric, or {@code null} if {@code config} is not of type {@code "fancyphrasevector"}
         * @throws IllegalArgumentException if the {@code "language"} runtime parameter is missing
         * @throws ConfigurationException   if a dependent component cannot be built
         */
        @Override
        public SRMetric get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
            if (!config.getString("type").equals("fancyphrasevector")) {
                return null;
            }
            if (runtimeParams == null || !runtimeParams.containsKey("language")) {
                throw new IllegalArgumentException("Monolingual requires 'language' runtime parameter.");
            }
            Language language = Language.getByLangCode(runtimeParams.get("language"));
            String langCode = language.getLangCode();
            Map<String, String> langParams = new HashMap<String, String>();
            langParams.put("language", langCode);

            // Components are resolved in the same order as the constructor arguments.
            LocalPageDao pageDao = (LocalPageDao) getConfigurator().get(
                    LocalPageDao.class, config.getString("pageDao"));
            Disambiguator disambiguator = (Disambiguator) getConfigurator().get(
                    Disambiguator.class, config.getString("disambiguator"), "language", langCode);
            VectorGenerator generator = (VectorGenerator) getConfigurator().construct(
                    VectorGenerator.class, (String) null, config.getConfig("generator"), langParams);
            VectorSimilarity similarity = (VectorSimilarity) getConfigurator().construct(
                    VectorSimilarity.class, (String) null, config.getConfig("similarity"), langParams);
            PhraseVectorCreator creator = (PhraseVectorCreator) getConfigurator().construct(
                    PhraseVectorCreator.class, (String) null, config.getConfig("phrases"), (Map<String, String>) null);

            FancyPhraseVectorBasedSRMetric metric = new FancyPhraseVectorBasedSRMetric(
                    name, language, pageDao, disambiguator, generator, similarity, creator);
            if (config.hasPath("phraseMode")) {
                // Locale.ROOT keeps enum-name matching independent of the JVM's default locale.
                metric.setPhraseMode(PhraseMode.valueOf(config.getString("phraseMode").toUpperCase(Locale.ROOT)));
            }
            FancyPhraseVectorBasedSRMetric.configureBase(getConfigurator(), metric, config);
            return metric;
        }

        /**
         * Bridge method emitted by the decompiler. It is kept only so the class's public method
         * set is unchanged; it simply forwards to the typed {@code get}.
         */
        @SuppressWarnings("unchecked")
        public Object m44get(String str, Config config, Map map) throws ConfigurationException {
            return get(str, config, (Map<String, String>) map);
        }
    }

    /**
     * Creates the metric.
     *
     * @param phraseVectorCreator fallback phrase-vector source. It may be {@code null}; in that case
     *                            any operation that actually needs it (modes CREATOR or BOTH)
     *                            throws {@link IllegalStateException}.
     */
    public FancyPhraseVectorBasedSRMetric(String name, Language language, LocalPageDao pageDao, Disambiguator disambiguator,
            VectorGenerator vectorGenerator, VectorSimilarity vectorSimilarity, PhraseVectorCreator phraseVectorCreator) {
        super(name, language, pageDao, disambiguator, vectorGenerator, vectorSimilarity);
        this.phraseVectorCreator = phraseVectorCreator;
        if (phraseVectorCreator != null) {
            // Registers this metric with the creator (presumably so the creator can call back into it).
            phraseVectorCreator.setMetric(this);
        }
    }

    /**
     * Scores two phrases using phrase vectors when they can be obtained, else via the superclass.
     *
     * @param phrase1 first phrase
     * @param phrase2 second phrase
     * @param explain whether to attach the generator's explanations to the result
     * @throws IllegalStateException if the creator is needed but was not supplied
     */
    @Override
    public SRResult similarity(String phrase1, String phrase2, boolean explain) throws DaoException {
        if (this.phraseMode == PhraseMode.NONE) {
            return super.similarity(phrase1, phrase2, explain);
        }
        TIntFloatMap vector1 = null;
        TIntFloatMap vector2 = null;
        if (usesGenerator()) {
            try {
                vector1 = this.generator.getVector(phrase1);
                vector2 = this.generator.getVector(phrase2);
            } catch (UnsupportedOperationException ignored) {
                // The generator cannot vectorize raw phrases; fall through to the creator.
            }
        }
        if ((vector1 == null || vector2 == null) && usesCreator()) {
            TIntFloatMap[] created = requireCreator().getPhraseVectors(phrase1, phrase2);
            if (created != null) {
                vector1 = created[0];
                vector2 = created[1];
            }
        }
        if (vector1 == null || vector2 == null) {
            return super.similarity(phrase1, phrase2, explain);
        }
        SRResult result = new SRResult(this.similarity.similarity(vector1, vector2));
        if (explain) {
            result.setExplanations(this.generator.getExplanations(phrase1, phrase2, vector1, vector2, result));
        }
        return normalize(result);
    }

    /**
     * Finds items most similar to a phrase using its phrase vector when it can be obtained, else
     * via the superclass.
     *
     * @param phrase     query phrase
     * @param maxResults passed through unchanged (presumably the maximum number of results)
     * @param validIds   passed through unchanged (presumably restricts the candidate ids)
     * @throws DaoException          wrapping any {@link IOException} from the similarity search
     * @throws IllegalStateException if the creator is needed but was not supplied
     */
    @Override
    public SRResultList mostSimilar(String phrase, int maxResults, TIntSet validIds) throws DaoException {
        if (this.phraseMode == PhraseMode.NONE) {
            return super.mostSimilar(phrase, maxResults, validIds);
        }
        TIntFloatMap vector = null;
        if (usesGenerator()) {
            try {
                vector = this.generator.getVector(phrase);
            } catch (UnsupportedOperationException ignored) {
                // The generator cannot vectorize raw phrases; fall through to the creator.
            }
        }
        if (vector == null && usesCreator()) {
            vector = requireCreator().getPhraseVector(phrase);
        }
        if (vector == null) {
            return super.mostSimilar(phrase, maxResults, validIds);
        }
        try {
            return this.similarity.mostSimilar(vector, maxResults, validIds);
        } catch (IOException e) {
            throw new DaoException(e);
        }
    }

    /**
     * Computes the similarity matrix between two lists of phrases.
     *
     * <p>Obtains phrase vectors using the same {@link PhraseMode} rules as
     * {@link #similarity(String, String, boolean)}. Each distinct phrase is vectorized once. If
     * no complete set of phrase vectors can be obtained, the superclass implementation is used.
     *
     * @return a {@code rows.length x cols.length} matrix; all zeros if either input is empty
     * @throws IllegalStateException if the creator is needed but was not supplied, or returns
     *                               fewer vectors than requested
     */
    @Override
    public double[][] cosimilarity(String[] rows, String[] cols) throws DaoException {
        if (rows.length == 0 || cols.length == 0) {
            return new double[rows.length][cols.length];
        }
        if (this.phraseMode == PhraseMode.NONE) {
            return super.cosimilarity(rows, cols);
        }

        // Map each distinct phrase to its slot. Insertion order makes keySet() line up with the slots.
        Map<String, Integer> slots = new LinkedHashMap<String, Integer>();
        for (String phrase : (String[]) ArrayUtils.addAll(rows, cols)) {
            if (!slots.containsKey(phrase)) {
                slots.put(phrase, slots.size());
            }
        }
        String[] distinct = slots.keySet().toArray(new String[slots.size()]);

        TIntFloatMap[] vectors = null;
        if (usesGenerator()) {
            vectors = generatorVectors(distinct);
        }
        if (vectors == null && usesCreator()) {
            vectors = requireCreator().getPhraseVectors(distinct);
        }
        if (vectors == null) {
            return super.cosimilarity(rows, cols);
        }
        return cosimilarity(select(vectors, slots, rows), select(vectors, slots, cols));
    }

    /**
     * Sets how phrase vectors are obtained.
     *
     * @throws NullPointerException if {@code phraseMode} is {@code null}
     */
    public void setPhraseMode(PhraseMode phraseMode) {
        if (phraseMode == null) {
            throw new NullPointerException("phraseMode");
        }
        this.phraseMode = phraseMode;
    }

    /** Returns true if the current mode consults the vector generator. */
    private boolean usesGenerator() {
        return this.phraseMode == PhraseMode.BOTH || this.phraseMode == PhraseMode.GENERATOR;
    }

    /** Returns true if the current mode consults the phrase vector creator. */
    private boolean usesCreator() {
        return this.phraseMode == PhraseMode.BOTH || this.phraseMode == PhraseMode.CREATOR;
    }

    /** Returns the phrase vector creator, failing loudly if the current mode needs one but none was supplied. */
    private PhraseVectorCreator requireCreator() {
        if (this.phraseVectorCreator == null) {
            throw new IllegalStateException("phraseMode is " + this.phraseMode + " but phraseVectorCreator is null");
        }
        return this.phraseVectorCreator;
    }

    /**
     * Asks the generator for one vector per phrase.
     *
     * @return vectors aligned with {@code phrases}, or {@code null} if the generator does not
     *         support phrases or yields no vector for at least one of them
     */
    private TIntFloatMap[] generatorVectors(String[] phrases) throws DaoException {
        TIntFloatMap[] vectors = new TIntFloatMap[phrases.length];
        try {
            for (int i = 0; i < phrases.length; i++) {
                vectors[i] = this.generator.getVector(phrases[i]);
                if (vectors[i] == null) {
                    return null;
                }
            }
        } catch (UnsupportedOperationException ignored) {
            // The generator cannot vectorize raw phrases; the caller falls back to the creator.
            return null;
        }
        return vectors;
    }

    /** Picks, in order, the vector for each of {@code phrases} using the slot assigned in {@code slots}. */
    private static ArrayList<TIntFloatMap> select(TIntFloatMap[] vectors, Map<String, Integer> slots, String[] phrases) {
        ArrayList<TIntFloatMap> selected = new ArrayList<TIntFloatMap>(phrases.length);
        for (String phrase : phrases) {
            Integer slot = slots.get(phrase);
            if (slot == null || slot >= vectors.length) {
                throw new IllegalStateException("no phrase vector for '" + phrase + "' (got " + vectors.length + " vectors)");
            }
            selected.add(vectors[slot]);
        }
        return selected;
    }
}
