package org.wikibrain.sr.phrasesim;

import com.typesafe.config.Config;
import gnu.trove.map.TIntFloatMap;
import gnu.trove.map.TLongFloatMap;
import gnu.trove.map.hash.TIntDoubleHashMap;
import gnu.trove.map.hash.TIntFloatHashMap;
import gnu.trove.procedure.TIntFloatProcedure;
import gnu.trove.procedure.TLongFloatProcedure;
import gnu.trove.set.TIntSet;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.io.FileUtils;
import org.mapdb.DB;
import org.mapdb.DBMaker;
import org.mapdb.HTreeMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.lang.StringNormalizer;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.normalize.IdentityNormalizer;
import org.wikibrain.sr.normalize.Normalizer;
import org.wikibrain.sr.normalize.PercentileNormalizer;
import org.wikibrain.sr.utils.Leaderboard;
import org.wikibrain.sr.vector.SparseVectorSRMetric;
import org.wikibrain.utils.ParallelForEach;
import org.wikibrain.utils.Procedure;
import org.wikibrain.utils.WpIOUtils;

/* loaded from: input_file:org/wikibrain/sr/phrasesim/KnownPhraseSim.class */
/**
 * An {@link SRMetric} over a growing set of known phrases.
 *
 * <p>Each phrase is turned into a sparse vector by a {@link PhraseCreator}. Persistent state lives
 * in the data directory:
 * <ul>
 *   <li>{@code phrases.mapdb} - MapDB hash map from normalized phrase to {@link KnownPhrase};</li>
 *   <li>{@code cosimilarity.bin} - serialized {@link CosimilarityMatrix} of precomputed neighbors;</li>
 *   <li>{@code scoreNormalizer.bin} - optional trained score {@link Normalizer}.</li>
 * </ul>
 * In memory, phrases are indexed by id and by phrase string, and an inverted index maps each
 * vector feature id to the (phrase id, weight) postings that contain it.
 */
public class KnownPhraseSim implements SRMetric {
    private static final Logger LOGGER = LoggerFactory.getLogger(KnownPhraseSim.class);
    private final StringNormalizer stringNormalizer;
    /** Persistent map: normalized phrase -> KnownPhrase copy. */
    private final HTreeMap<Object, Object> db;
    private final PhraseCreator creator;
    private final Language language;
    private final File dir;
    private final String name;
    private Normalizer scoreNormalizer;
    /** Normalized phrase (and recorded versions) -> phrase. */
    private ConcurrentHashMap<String, KnownPhrase> byPhrase;
    /** Phrase id -> phrase. Only phrases that have a vector are stored here. */
    private ConcurrentHashMap<Integer, KnownPhrase> byId;
    /** Feature id -> (phrase id -> weight). Each posting map is guarded by synchronizing on itself. */
    private final ConcurrentHashMap<Long, TIntFloatMap> invertedIndex;
    private CosimilarityMatrix cosim;
    private final DB phraseDb;

    /**
     * Configuration provider that builds a {@link KnownPhraseSim} for metrics whose config
     * {@code type} is {@code "knownphrase"}.
     */
    public static class Provider extends org.wikibrain.conf.Provider<SRMetric> {
        public Provider(Configurator configurator, Configuration configuration) throws ConfigurationException {
            super(configurator, configuration);
        }

        public Class getType() {
            return SRMetric.class;
        }

        public String getPath() {
            return "sr.metric.local";
        }

        /**
         * Builds the metric from its config block.
         *
         * @param str    metric name; also used as a sub-directory of {@code sr.metric.path}
         * @param config metric config; reads {@code type}, {@code metrics}, {@code coefficients}
         *               and optionally {@code stringnormalizer}
         * @param map    runtime parameters; must contain {@code language}
         * @return the metric, or {@code null} if this provider does not handle {@code config}
         * @throws IllegalArgumentException if the {@code language} runtime parameter is missing
         * @throws ConfigurationException   if the metric cannot be constructed
         */
        public SRMetric get(String str, Config config, Map<String, String> map) throws ConfigurationException {
            if (!config.getString("type").equals("knownphrase")) {
                return null;
            }
            if (map == null || !map.containsKey("language")) {
                throw new IllegalArgumentException("Monolingual SR Metric requires 'language' runtime parameter");
            }
            Language lang = Language.getByLangCode(map.get("language"));

            // The underlying sparse-vector metrics whose vectors are combined into phrase vectors.
            List<String> metricNames = config.getStringList("metrics");
            SparseVectorSRMetric[] metrics = new SparseVectorSRMetric[metricNames.size()];
            for (int k = 0; k < metrics.length; k++) {
                metrics[k] = (SparseVectorSRMetric) getConfigurator().get(
                        SRMetric.class, metricNames.get(k), "language", lang.getLangCode());
            }
            PhraseCreator phraseCreator = new EnsemblePhraseCreator(
                    metrics, KnownPhraseSim.toPrimitive(config.getDoubleList("coefficients")));

            File dataDir = FileUtils.getFile(
                    getConfig().getString("sr.metric.path"), str, lang.getLangCode());
            String normalizerName = config.hasPath("stringnormalizer")
                    ? config.getString("stringnormalizer")
                    : null;
            StringNormalizer normalizer =
                    (StringNormalizer) getConfigurator().get(StringNormalizer.class, normalizerName);
            try {
                return new KnownPhraseSim(str, lang, phraseCreator, dataDir, normalizer);
            } catch (IOException e) {
                throw new ConfigurationException(e);
            }
        }

        /**
         * Untyped forwarder to {@link #get(String, Config, Map)}. This is a decompiler artifact of the
         * compiler-generated bridge method; it is kept so existing callers continue to compile.
         */
        public /* bridge */ /* synthetic */ Object m42get(String str, Config config, Map map) throws ConfigurationException {
            return get(str, config, (Map<String, String>) map);
        }
    }

    /**
     * Creates a metric named {@code "known-phrase-sim"}.
     *
     * @see #KnownPhraseSim(String, Language, PhraseCreator, File, StringNormalizer)
     */
    public KnownPhraseSim(Language language, PhraseCreator phraseCreator, File file, StringNormalizer stringNormalizer) throws IOException {
        this("known-phrase-sim", language, phraseCreator, file, stringNormalizer);
    }

    /**
     * Opens (creating if needed) the metric's data directory, loads all stored phrases, and
     * loads or rebuilds the cosimilarity matrix. A trained score normalizer is loaded if present.
     *
     * @param str              metric name
     * @param language         language of the phrases
     * @param phraseCreator    maps phrase text to a sparse vector
     * @param file             data directory
     * @param stringNormalizer normalizes phrase text before lookup/storage
     * @throws IOException if reading persisted state fails
     */
    public KnownPhraseSim(String str, Language language, PhraseCreator phraseCreator, File file, StringNormalizer stringNormalizer) throws IOException {
        this.scoreNormalizer = new IdentityNormalizer();
        this.invertedIndex = new ConcurrentHashMap<Long, TIntFloatMap>();
        this.cosim = new CosimilarityMatrix();
        this.name = str;
        this.language = language;
        this.creator = phraseCreator;
        this.stringNormalizer = stringNormalizer;
        this.dir = file;
        this.dir.mkdirs();
        this.phraseDb = DBMaker.newFileDB(new File(file, "phrases.mapdb")).mmapFileEnable().transactionDisable().asyncWriteEnable().asyncWriteFlushDelay(100).make();
        this.db = this.phraseDb.getHashMap("phrases");
        readPhrases();
        readCosimilarity();
        File normalizerFile = new File(file, "scoreNormalizer.bin");
        if (normalizerFile.isFile()) {
            this.scoreNormalizer = (Normalizer) WpIOUtils.readObjectFromFile(normalizerFile);
        }
    }

    /** Not supported: all state is loaded in the constructor. */
    @Override // org.wikibrain.sr.SRMetric
    public void read() {
        throw new UnsupportedOperationException("Metric cannot be re-read after creation");
    }

    /** Persists the cosimilarity matrix and commits the phrase store. */
    @Override // org.wikibrain.sr.SRMetric
    public void write() throws IOException {
        flushCosimilarity();
    }

    /**
     * Loads {@code cosimilarity.bin}, falling back to an empty matrix if it cannot be read, then
     * computes neighbor lists (in parallel) for every phrase the matrix does not mark as completed.
     */
    private void readCosimilarity() throws IOException {
        File file = new File(this.dir, "cosimilarity.bin");
        try {
            this.cosim = (CosimilarityMatrix) WpIOUtils.readObjectFromFile(file);
        } catch (Exception e) {
            // Best effort: a missing or unreadable cache is rebuilt below.
            LOGGER.info("Reading cosim file {} failed... rebuilding it from scratch", file);
            LOGGER.debug("Cosimilarity read failure", e);
            this.cosim = new CosimilarityMatrix();
        }
        final TIntSet completed = this.cosim.getCompleted();
        ParallelForEach.loop(this.byId.values(), new Procedure<KnownPhrase>() {
            public void call(KnownPhrase knownPhrase) throws Exception {
                if (completed.contains(knownPhrase.getId())) {
                    return;
                }
                KnownPhraseSim.this.cosim.update(knownPhrase.getId(), KnownPhraseSim.this.indexedMostSimilar(knownPhrase.getVector(), KnownPhraseSim.this.byId.size(), null));
            }
        });
    }

    /** Writes the cosimilarity matrix to {@code cosimilarity.bin} and commits the phrase store. */
    public void flushCosimilarity() throws IOException {
        WpIOUtils.writeObjectToFile(new File(this.dir, "cosimilarity.bin"), this.cosim);
        this.db.getEngine().commit();
    }

    /**
     * Rebuilds the in-memory id/phrase maps and the inverted index from the persistent store.
     * Stored entries without a vector are skipped so they can be re-added later.
     *
     * @throws IllegalStateException if a store key disagrees with its phrase's normalized form
     */
    private void readPhrases() {
        this.byId = new ConcurrentHashMap<Integer, KnownPhrase>();
        this.byPhrase = new ConcurrentHashMap<String, KnownPhrase>();
        for (Map.Entry<Object, Object> entry : this.db.entrySet()) {
            String key = (String) entry.getKey();
            KnownPhrase phrase = (KnownPhrase) entry.getValue();
            if (!key.equals(phrase.getNormalizedPhrase())) {
                throw new IllegalStateException("Phrase store key '" + key
                        + "' does not match stored normalized phrase '" + phrase.getNormalizedPhrase() + "'");
            }
            PhraseVector vector = phrase.getVector();
            if (vector == null) {
                LOGGER.warn("Skipping stored phrase '{}' that has no vector", key);
                continue;
            }
            this.byId.put(phrase.getId(), phrase);
            this.byPhrase.put(phrase.getNormalizedPhrase(), phrase);
            Iterator<String> versions = phrase.getVersions().iterator();
            while (versions.hasNext()) {
                this.byPhrase.put(versions.next(), phrase);
            }
            for (int k = 0; k < vector.ids.length; k++) {
                indexEntry(vector.ids[k], phrase.getId(), vector.vals[k]);
            }
        }
    }

    /** Adds a (phrase id, weight) posting under {@code featureId} in the inverted index. */
    private void indexEntry(long featureId, int phraseId, float weight) {
        TIntFloatMap postings = this.invertedIndex.get(featureId);
        if (postings == null) {
            TIntFloatMap fresh = new TIntFloatHashMap();
            postings = this.invertedIndex.putIfAbsent(featureId, fresh);
            if (postings == null) {
                postings = fresh;
            }
        }
        synchronized (postings) {
            postings.put(phraseId, weight);
        }
    }

    /**
     * Registers a phrase under id {@code i}.
     *
     * <p>If the normalized phrase is already known, the existing entry is incremented with this
     * surface form and {@code i} is ignored. Otherwise the phrase's vector is computed; if the
     * creator yields no vector, the phrase is not registered. New phrases are stored, indexed,
     * and (when a cosimilarity matrix exists) get their neighbor list computed.
     *
     * @param str phrase text
     * @param i   id to assign to a previously unknown phrase
     */
    public void addPhrase(String str, final int i) {
        KnownPhrase knownPhrase = new KnownPhrase(i, str, normalize(str));
        String key = knownPhrase.getNormalizedPhrase();
        KnownPhrase existing = this.byPhrase.putIfAbsent(key, knownPhrase);
        if (existing != null) {
            existing.increment(str);
            // A phrase without a vector is still being created by another call (which will persist
            // it, including this increment) or will be discarded; never persist it vector-less.
            if (existing.getVector() != null) {
                this.db.put(key, new KnownPhrase(existing));
            }
            return;
        }
        TLongFloatMap vector = this.creator.getVector(str);
        if (vector == null) {
            // Undo the placeholder so getId() does not resolve a phrase that has no id/vector.
            this.byPhrase.remove(key, knownPhrase);
            return;
        }
        knownPhrase.setVector(new PhraseVector(vector));
        this.byId.put(i, knownPhrase);
        this.db.put(key, new KnownPhrase(knownPhrase));
        vector.forEachEntry(new TLongFloatProcedure() {
            public boolean execute(long featureId, float weight) {
                KnownPhraseSim.this.indexEntry(featureId, i, weight);
                return true;
            }
        });
        if (this.cosim != null) {
            this.cosim.update(i, indexedMostSimilar(knownPhrase.getVector(), this.byId.size(), null));
        }
    }

    /** Not supported. */
    public void rebuild() {
        throw new UnsupportedOperationException();
    }

    /**
     * Trains a {@link PercentileNormalizer} on raw most-similar scores of 1000 randomly chosen
     * phrases, saves it to {@code scoreNormalizer.bin}, and installs it. On failure the previous
     * normalizer is restored.
     *
     * @throws IllegalStateException if no phrases are known
     * @throws IOException           if the trained normalizer cannot be written
     */
    public void trainNormalizer() throws IOException {
        Normalizer installed = this.scoreNormalizer;
        // Score with the identity normalizer so training observes raw scores.
        this.scoreNormalizer = new IdentityNormalizer();
        try {
            List<Integer> ids = new ArrayList<Integer>(this.byId.keySet());
            if (ids.isEmpty()) {
                throw new IllegalStateException("Cannot train score normalizer: no known phrases");
            }
            Random random = new Random();
            PercentileNormalizer trained = new PercentileNormalizer();
            trained.setPower(10.0d);
            trained.setSampleSize(100000);
            for (int round = 0; round < 1000; round++) {
                int id = ids.get(random.nextInt(ids.size()));
                SRResultList results = mostSimilar(id, ids.size());
                if (results == null) {
                    continue;
                }
                Iterator<SRResult> it = results.iterator();
                while (it.hasNext()) {
                    trained.observe(it.next().getScore());
                }
            }
            trained.observationsFinished();
            WpIOUtils.writeObjectToFile(new File(this.dir, "scoreNormalizer.bin"), trained);
            installed = trained;
        } finally {
            this.scoreNormalizer = installed;
        }
    }

    /** Returns the canonical version of the phrase with id {@code i}, or {@code null} if unknown. */
    public String getPhrase(int i) {
        KnownPhrase phrase = this.byId.get(i);
        return phrase == null ? null : phrase.getCanonicalVersion();
    }

    /** Returns the id of {@code str} (after normalization), or {@code null} if unknown. */
    public Integer getId(String str) {
        KnownPhrase knownPhrase = this.byPhrase.get(normalize(str));
        if (knownPhrase == null) {
            return null;
        }
        return Integer.valueOf(knownPhrase.getId());
    }

    /** Normalizes {@code str} with the configured string normalizer for this metric's language. */
    public String normalize(String str) {
        return this.stringNormalizer.normalize(this.language, str);
    }

    /** Returns the cosimilarity-matrix score for two ids (the score normalizer is not applied). */
    @Override // org.wikibrain.sr.SRMetric
    public SRResult similarity(int i, int i2, boolean z) {
        return new SRResult(this.cosim.similarity(i, i2));
    }

    /** Similarity of two phrases; {@code NaN} if either phrase is unknown. */
    @Override // org.wikibrain.sr.SRMetric
    public SRResult similarity(String str, String str2, boolean z) throws DaoException {
        Integer id = getId(str);
        Integer id2 = getId(str2);
        return (id == null || id2 == null) ? new SRResult(Double.NaN) : similarity(id.intValue(), id2.intValue(), z);
    }

    /** Most similar phrases to {@code str}, or {@code null} if the phrase is unknown. */
    @Override // org.wikibrain.sr.SRMetric
    public SRResultList mostSimilar(String str, int i, TIntSet tIntSet) {
        Integer id = getId(str);
        if (id == null) {
            return null;
        }
        return mostSimilar(id.intValue(), i, tIntSet);
    }

    @Override // org.wikibrain.sr.SRMetric
    public SRResultList mostSimilar(String str, int i) {
        return mostSimilar(str, i, (TIntSet) null);
    }

    /**
     * Returns up to {@code i2} phrases most similar to phrase {@code i}, optionally restricted to
     * {@code tIntSet}. Uses the cosimilarity matrix when present; otherwise scores via the
     * inverted index, or by brute force for small candidate sets.
     *
     * @return the results, or {@code null} if {@code i} is unknown
     */
    @Override // org.wikibrain.sr.SRMetric
    public SRResultList mostSimilar(int i, int i2, TIntSet tIntSet) {
        KnownPhrase knownPhrase = this.byId.get(i);
        if (knownPhrase == null) {
            return null;
        }
        if (this.cosim != null) {
            return this.scoreNormalizer.normalize(this.cosim.mostSimilar(i, i2, tIntSet));
        }
        PhraseVector vector = knownPhrase.getVector();
        return (tIntSet == null || tIntSet.size() >= 10) ? indexedMostSimilar(vector, i2, tIntSet) : mostSimilar(vector, i2, tIntSet);
    }

    /**
     * Scores phrases sharing at least one feature with {@code phraseVector} by cosine similarity,
     * using the inverted index to accumulate dot products.
     *
     * @param phraseVector query vector
     * @param i            maximum number of results
     * @param tIntSet      optional candidate ids; when non-null, only these ids are returned
     * @return the top results
     */
    public SRResultList indexedMostSimilar(PhraseVector phraseVector, int i, TIntSet tIntSet) {
        final TIntDoubleHashMap dots = new TIntDoubleHashMap(i * 5);
        for (int k = 0; k < phraseVector.ids.length; k++) {
            TIntFloatMap postings = this.invertedIndex.get(phraseVector.ids[k]);
            if (postings == null) {
                continue;
            }
            final float queryWeight = phraseVector.vals[k];
            synchronized (postings) {
                postings.forEachEntry(new TIntFloatProcedure() {
                    public boolean execute(int phraseId, float weight) {
                        double product = weight * queryWeight;
                        dots.adjustOrPutValue(phraseId, product, product);
                        return true;
                    }
                });
            }
        }
        Leaderboard leaderboard = new Leaderboard(i);
        double queryNorm = phraseVector.norm2();
        for (int id : dots.keys()) {
            if (tIntSet != null && !tIntSet.contains(id)) {
                continue;
            }
            KnownPhrase candidate = this.byId.get(id);
            if (candidate == null) {
                continue;
            }
            leaderboard.tallyScore(id, dots.get(id) / (queryNorm * candidate.getVector().norm2()));
        }
        return leaderboard.getTop();
    }

    /** Brute-force cosine scoring against {@code tIntSet}, or all known phrases if it is null. */
    private SRResultList mostSimilar(PhraseVector phraseVector, int i, TIntSet tIntSet) {
        Leaderboard leaderboard = new Leaderboard(i);
        if (tIntSet != null) {
            for (int id : tIntSet.toArray()) {
                KnownPhrase knownPhrase = this.byId.get(id);
                if (knownPhrase != null) {
                    leaderboard.tallyScore(id, phraseVector.cosineSim(knownPhrase.getVector()));
                }
            }
        } else {
            for (KnownPhrase knownPhrase : this.byId.values()) {
                leaderboard.tallyScore(knownPhrase.getId(), phraseVector.cosineSim(knownPhrase.getVector()));
            }
        }
        return leaderboard.getTop();
    }

    @Override // org.wikibrain.sr.SRMetric
    public SRResultList mostSimilar(int i, int i2) {
        return mostSimilar(i, i2, (TIntSet) null);
    }

    /**
     * Cosimilarity matrix between two lists of phrases.
     *
     * @throws IllegalArgumentException if any phrase is unknown
     */
    @Override // org.wikibrain.sr.SRMetric
    public double[][] cosimilarity(String[] strArr, String[] strArr2) {
        int[] rowIds = new int[strArr.length];
        for (int k = 0; k < rowIds.length; k++) {
            rowIds[k] = requireId(strArr[k]);
        }
        int[] colIds = new int[strArr2.length];
        for (int k = 0; k < colIds.length; k++) {
            colIds[k] = requireId(strArr2[k]);
        }
        return cosimilarity(rowIds, colIds);
    }

    /** Returns the id of a phrase, failing with a descriptive message if it is unknown. */
    private int requireId(String phrase) {
        Integer id = getId(phrase);
        if (id == null) {
            throw new IllegalArgumentException("Unknown phrase: " + phrase);
        }
        return id;
    }

    /** Square cosimilarity matrix of {@code iArr} against itself. */
    @Override // org.wikibrain.sr.SRMetric
    public double[][] cosimilarity(int[] iArr) throws DaoException {
        return cosimilarity(iArr, iArr);
    }

    /**
     * Square cosimilarity matrix of {@code strArr} against itself.
     *
     * @throws IllegalArgumentException if any phrase is unknown
     */
    @Override // org.wikibrain.sr.SRMetric
    public double[][] cosimilarity(String[] strArr) throws DaoException {
        return cosimilarity(strArr, strArr);
    }

    /** Returns the cosimilarity-matrix vector for {@code str}, or {@code null} if unknown. */
    public float[] getPhraseVector(String str) {
        Integer id = getId(str);
        if (id == null) {
            return null;
        }
        return getPhraseVector(id.intValue());
    }

    /** Returns the cosimilarity-matrix vector for phrase id {@code i}. */
    public float[] getPhraseVector(int i) {
        return this.cosim.getVector(i);
    }

    /**
     * Cosimilarity matrix with rows {@code iArr} and columns {@code iArr2}. Delegates to the
     * cosimilarity matrix when present; otherwise computes normalized cosine similarities,
     * leaving 0 for unknown ids.
     */
    @Override // org.wikibrain.sr.SRMetric
    public double[][] cosimilarity(int[] iArr, int[] iArr2) {
        if (this.cosim != null) {
            return this.cosim.cosimilarity(iArr, iArr2);
        }
        double[][] result = new double[iArr.length][iArr2.length];
        PhraseVector[] columns = new PhraseVector[iArr2.length];
        for (int c = 0; c < iArr2.length; c++) {
            KnownPhrase phrase = this.byId.get(iArr2[c]);
            columns[c] = phrase == null ? null : phrase.getVector();
        }
        for (int r = 0; r < iArr.length; r++) {
            KnownPhrase rowPhrase = this.byId.get(iArr[r]);
            if (rowPhrase == null) {
                continue;
            }
            PhraseVector rowVector = rowPhrase.getVector();
            for (int c = 0; c < columns.length; c++) {
                if (columns[c] != null) {
                    result[r][c] = this.scoreNormalizer.normalize(rowVector.cosineSim(columns[c]));
                }
            }
        }
        return result;
    }

    @Override // org.wikibrain.sr.SRMetric
    public String getName() {
        return this.name;
    }

    @Override // org.wikibrain.sr.SRMetric
    public Language getLanguage() {
        return this.language;
    }

    @Override // org.wikibrain.sr.SRMetric
    public File getDataDir() {
        return this.dir;
    }

    /** Not supported: the data directory is fixed at construction. */
    @Override // org.wikibrain.sr.SRMetric
    public void setDataDir(File file) {
        throw new UnsupportedOperationException();
    }

    public Normalizer getScoreNormalizer() {
        return this.scoreNormalizer;
    }

    /** Returns the shared score normalizer. */
    @Override // org.wikibrain.sr.SRMetric
    public Normalizer getMostSimilarNormalizer() {
        return this.scoreNormalizer;
    }

    /** Not supported: use {@link #trainNormalizer()}. */
    @Override // org.wikibrain.sr.SRMetric
    public void setMostSimilarNormalizer(Normalizer normalizer) {
        throw new UnsupportedOperationException();
    }

    /** Returns the shared score normalizer. */
    @Override // org.wikibrain.sr.SRMetric
    public Normalizer getSimilarityNormalizer() {
        return this.scoreNormalizer;
    }

    /** Not supported: use {@link #trainNormalizer()}. */
    @Override // org.wikibrain.sr.SRMetric
    public void setSimilarityNormalizer(Normalizer normalizer) {
        throw new UnsupportedOperationException();
    }

    /** No-op: this metric is not trained from datasets. */
    @Override // org.wikibrain.sr.SRMetric
    public void trainSimilarity(Dataset dataset) {
    }

    /** No-op: this metric is not trained from datasets. */
    @Override // org.wikibrain.sr.SRMetric
    public void trainMostSimilar(Dataset dataset, int i, TIntSet tIntSet) {
    }

    @Override // org.wikibrain.sr.SRMetric
    public boolean similarityIsTrained() {
        return false;
    }

    @Override // org.wikibrain.sr.SRMetric
    public boolean mostSimilarIsTrained() {
        return false;
    }

    /** Unboxes a list of doubles into a primitive array. */
    public static double[] toPrimitive(List<Double> list) {
        double[] values = new double[list.size()];
        for (int k = 0; k < values.length; k++) {
            values[k] = list.get(k).doubleValue();
        }
        return values;
    }
}
