package org.wikibrain.sr.vector;

import com.typesafe.config.Config;
import gnu.trove.map.TIntFloatMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.DaoFilter;
import org.wikibrain.core.dao.LocalPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.model.LocalPage;
import org.wikibrain.core.model.NameSpace;
import org.wikibrain.matrix.MatrixRow;
import org.wikibrain.matrix.SparseMatrix;
import org.wikibrain.matrix.SparseMatrixRow;
import org.wikibrain.matrix.SparseMatrixTransposer;
import org.wikibrain.matrix.SparseMatrixWriter;
import org.wikibrain.matrix.ValueConf;
import org.wikibrain.sr.BaseSRMetric;
import org.wikibrain.sr.SRMetric;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.disambig.Disambiguator;
import org.wikibrain.utils.ParallelForEach;
import org.wikibrain.utils.Procedure;
import org.wikibrain.utils.WbArrayUtils;
import org.wikibrain.utils.WpThreadUtils;

/* loaded from: input_file:org/wikibrain/sr/vector/VectorBasedSRMetric.class */
/**
 * A semantic-relatedness metric that compares feature vectors.
 *
 * <p>Page and phrase vectors come from a {@link VectorGenerator}. Pairs of vectors are scored by a
 * {@link VectorSimilarity}, and the raw scores are passed through {@code normalize}. Page vectors
 * are read from an on-disk feature matrix instead of being regenerated in two cases: after
 * {@link #buildFeatureAndTransposeMatrices(TIntSet)} has run, or after {@link #read()} has found
 * the matrix files in the data directory.
 *
 * <p>An optional {@link FeatureFilter} may be applied to page vectors. The phrase-based
 * operations, {@code mostSimilar(int, ...)} and the list-based cosimilarity do not support a
 * filter; they throw {@link UnsupportedOperationException} while one is set.
 */
public class VectorBasedSRMetric extends BaseSRMetric {
    private static final Logger LOG = Logger.getLogger(VectorBasedSRMetric.class.getName());

    /** Produces page and phrase vectors. */
    protected final VectorGenerator generator;

    /** Scores pairs of vectors and answers vector-based mostSimilar queries. */
    protected final VectorSimilarity similarity;

    /** Score range, taken from {@link #similarity}'s min and max values at construction. */
    protected final BaseSRMetric.SRConfig config;

    /** Optional filter applied to page vectors; {@code null} when unset. */
    private FeatureFilter featureFilter;

    /** On-disk matrix whose rows are keyed by page id; {@code null} until built or read. */
    private SparseMatrix featureMatrix;

    /** Transpose of {@link #featureMatrix}; {@code null} until built or read. */
    private SparseMatrix transposeMatrix;

    /**
     * Configuration provider that builds a {@link VectorBasedSRMetric} for metric configs whose
     * {@code type} is {@code "vector"}.
     */
    /* loaded from: input_file:org/wikibrain/sr/vector/VectorBasedSRMetric$Provider.class */
    public static class Provider extends org.wikibrain.conf.Provider<SRMetric> {
        public Provider(Configurator configurator, Configuration configuration) throws ConfigurationException {
            super(configurator, configuration);
        }

        public Class getType() {
            return SRMetric.class;
        }

        public String getPath() {
            return "sr.metric.local";
        }

        /**
         * Builds a vector-based metric from its config.
         *
         * @param name          the metric's name
         * @param config        the metric's config block; must contain {@code type}, {@code pageDao},
         *                      {@code disambiguator}, {@code generator} and {@code similarity}
         * @param runtimeParams runtime parameters; must contain {@code "language"}
         * @return the configured metric, or {@code null} if {@code type} is not {@code "vector"}
         * @throws IllegalArgumentException if {@code runtimeParams} is null or lacks {@code "language"}
         */
        @Override
        public SRMetric get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
            if (!config.getString("type").equals("vector")) {
                return null;
            }
            if (runtimeParams == null || !runtimeParams.containsKey("language")) {
                throw new IllegalArgumentException("Monolingual requires 'language' runtime parameter.");
            }
            Language language = Language.getByLangCode(runtimeParams.get("language"));
            Configurator configurator = getConfigurator();

            // The generator and similarity receive the resolved language code as a runtime parameter.
            Map<String, String> componentParams = new HashMap<String, String>();
            componentParams.put("language", language.getLangCode());

            LocalPageDao pageDao = (LocalPageDao) configurator.get(LocalPageDao.class, config.getString("pageDao"));
            Disambiguator disambiguator = (Disambiguator) configurator.get(
                    Disambiguator.class, config.getString("disambiguator"), "language", language.getLangCode());
            VectorGenerator generator = (VectorGenerator) configurator.construct(
                    VectorGenerator.class, (String) null, config.getConfig("generator"), componentParams);
            VectorSimilarity similarity = (VectorSimilarity) configurator.construct(
                    VectorSimilarity.class, (String) null, config.getConfig("similarity"), componentParams);

            VectorBasedSRMetric metric = new VectorBasedSRMetric(name, language, pageDao, disambiguator, generator, similarity);
            VectorBasedSRMetric.configureBase(configurator, metric, config);
            return metric;
        }

        /* renamed from: get, reason: collision with other method in class */
        // NOTE(review): decompiler stand-in for the compiler-generated bridge method. It only
        // delegates to get(...); it is kept so the public surface is unchanged.
        @SuppressWarnings("unchecked")
        public Object m54get(String str, Config config, Map map) throws ConfigurationException {
            return get(str, config, (Map<String, String>) map);
        }
    }

    /**
     * Creates a vector-based metric.
     *
     * @param name             the metric's name, passed to the base class
     * @param language         the language the metric operates in
     * @param pageDao          page DAO; also used to enumerate article ids when building matrices
     * @param disambiguator    disambiguator, passed to the base class
     * @param vectorGenerator  produces page and phrase vectors
     * @param vectorSimilarity scores vector pairs; its min/max values define the score range
     */
    public VectorBasedSRMetric(String name, Language language, LocalPageDao pageDao, Disambiguator disambiguator, VectorGenerator vectorGenerator, VectorSimilarity vectorSimilarity) {
        super(name, language, pageDao, disambiguator);
        this.generator = vectorGenerator;
        this.similarity = vectorSimilarity;
        this.config = new BaseSRMetric.SRConfig();
        this.config.minScore = (float) vectorSimilarity.getMinValue();
        this.config.maxScore = (float) vectorSimilarity.getMaxValue();
    }

    /**
     * Scores two phrases by comparing their generated vectors.
     *
     * <p>Falls back to the base implementation in two cases: the generator throws
     * {@link UnsupportedOperationException} for phrases, or it returns no vector for either phrase.
     *
     * @throws UnsupportedOperationException if a feature filter is set
     */
    @Override
    public SRResult similarity(String phrase1, String phrase2, boolean explanations) throws DaoException {
        if (featureFilter != null) {
            throw new UnsupportedOperationException();
        }
        TIntFloatMap vector1 = null;
        TIntFloatMap vector2 = null;
        try {
            vector1 = generator.getVector(phrase1);
            vector2 = generator.getVector(phrase2);
        } catch (UnsupportedOperationException ignored) {
            // The generator cannot vectorize phrases; handled by the fallback below.
        }
        if (vector1 == null || vector2 == null) {
            return super.similarity(phrase1, phrase2, explanations);
        }
        SRResult result = new SRResult(similarity.similarity(vector1, vector2));
        if (explanations) {
            result.setExplanations(generator.getExplanations(phrase1, phrase2, vector1, vector2, result));
        }
        return normalize(result);
    }

    /**
     * Scores two pages by comparing their (optionally filtered) vectors.
     *
     * <p>Vectors come from the feature matrix when one is loaded, otherwise from the generator.
     *
     * @return the normalized result, or {@code null} if either page has no vector
     * @throws DaoException wrapping any {@link IOException} from matrix access
     */
    @Override
    public SRResult similarity(int pageId1, int pageId2, boolean explanations) throws DaoException {
        try {
            if (!hasFeatureMatrix()) {
                TIntFloatMap vector1 = getPageVector(pageId1);
                TIntFloatMap vector2 = getPageVector(pageId2);
                if (vector1 == null || vector2 == null) {
                    return null;
                }
                SRResult result = new SRResult(similarity.similarity(vector1, vector2));
                // Honor the explanations flag here too, as the feature-matrix path below does.
                if (explanations) {
                    result.setExplanations(generator.getExplanations(pageId1, pageId2, vector1, vector2, result));
                }
                return normalize(result);
            }
            SparseMatrixRow row1 = featureMatrix.getRow(pageId1);
            SparseMatrixRow row2 = featureMatrix.getRow(pageId2);
            if (row1 == null || row2 == null) {
                return null;
            }
            if (featureFilter != null) {
                row1 = featureFilter.filter(pageId1, row1);
                row2 = featureFilter.filter(pageId2, row2);
            }
            SRResult result = new SRResult(similarity.similarity((MatrixRow) row1, (MatrixRow) row2));
            if (explanations) {
                result.setExplanations(generator.getExplanations(pageId1, pageId2, row1.asTroveMap(), row2.asTroveMap(), result));
            }
            return normalize(result);
        } catch (IOException e) {
            throw new DaoException(e);
        }
    }

    /**
     * Finds the pages most similar to a phrase using the phrase's generated vector.
     *
     * <p>Falls back to the base implementation when the generator cannot vectorize the phrase.
     *
     * @throws UnsupportedOperationException if a feature filter is set
     */
    @Override
    public SRResultList mostSimilar(String phrase, int maxResults, TIntSet validIds) throws DaoException {
        if (featureFilter != null) {
            throw new UnsupportedOperationException();
        }
        TIntFloatMap vector = null;
        try {
            vector = generator.getVector(phrase);
        } catch (UnsupportedOperationException ignored) {
            // The generator cannot vectorize phrases; handled by the fallback below.
        }
        if (vector == null) {
            return super.mostSimilar(phrase, maxResults, validIds);
        }
        try {
            return similarity.mostSimilar(vector, maxResults, validIds);
        } catch (IOException e) {
            throw new DaoException(e);
        }
    }

    /**
     * Finds the pages most similar to a page using the page's vector.
     *
     * @return the result list, or {@code null} if the page has no vector
     * @throws UnsupportedOperationException if a feature filter is set
     */
    @Override
    public SRResultList mostSimilar(int pageId, int maxResults, TIntSet validIds) throws DaoException {
        if (featureFilter != null) {
            throw new UnsupportedOperationException();
        }
        try {
            TIntFloatMap vector = getPageVector(pageId);
            if (vector == null) {
                return null;
            }
            return similarity.mostSimilar(vector, maxResults, validIds);
        } catch (IOException e) {
            throw new DaoException(e);
        }
    }

    /** Delegates entirely to the base implementation. */
    @Override
    public void trainSimilarity(Dataset dataset) throws DaoException {
        super.trainSimilarity(dataset);
    }

    /**
     * (Re)builds the feature and transpose matrices for {@code validIds}, then runs base training.
     *
     * @throws RuntimeException wrapping any {@link IOException} from matrix construction (also logged)
     */
    @Override
    public void trainMostSimilar(Dataset dataset, int numResults, TIntSet validIds) {
        try {
            buildFeatureAndTransposeMatrices(validIds);
            super.trainMostSimilar(dataset, numResults, validIds);
        } catch (IOException e) {
            LOG.log(Level.SEVERE, "training failed", e);
            throw new RuntimeException(e);
        }
    }

    /** Pairwise similarity of every page id against every page id in {@code ids}. */
    @Override
    public double[][] cosimilarity(int[] ids) throws DaoException {
        return cosimilarity(ids, ids);
    }

    /** Pairwise similarity of every phrase against every phrase in {@code phrases}. */
    @Override
    public double[][] cosimilarity(String[] phrases) throws DaoException {
        return cosimilarity(phrases, phrases);
    }

    /**
     * Computes the similarity matrix between two lists of phrases.
     *
     * <p>Each distinct phrase is vectorized once. The base implementation is used instead if the
     * generator throws {@link UnsupportedOperationException}. A phrase with no vector yields 0.0 cells.
     *
     * @throws UnsupportedOperationException if a feature filter is set
     */
    @Override
    public double[][] cosimilarity(String[] rowPhrases, String[] colPhrases) throws DaoException {
        if (featureFilter != null) {
            throw new UnsupportedOperationException();
        }
        if (rowPhrases.length == 0 || colPhrases.length == 0) {
            return new double[rowPhrases.length][colPhrases.length];
        }
        Map<String, TIntFloatMap> vectors = new HashMap<String, TIntFloatMap>();
        try {
            for (String phrase : ArrayUtils.addAll(rowPhrases, colPhrases)) {
                if (!vectors.containsKey(phrase)) {
                    vectors.put(phrase, generator.getVector(phrase));
                }
            }
        } catch (UnsupportedOperationException ignored) {
            // The generator cannot vectorize phrases; let the base implementation handle it.
            return super.cosimilarity(rowPhrases, colPhrases);
        }
        List<TIntFloatMap> rows = new ArrayList<TIntFloatMap>(rowPhrases.length);
        for (String phrase : rowPhrases) {
            rows.add(vectors.get(phrase));
        }
        List<TIntFloatMap> columns = new ArrayList<TIntFloatMap>(colPhrases.length);
        for (String phrase : colPhrases) {
            columns.add(vectors.get(phrase));
        }
        return cosimilarity(rows, columns);
    }

    /**
     * Computes the similarity matrix between two lists of page ids.
     *
     * <p>Cells involving a page that has no vector are left at 0.0.
     *
     * @throws DaoException wrapping any {@link IOException} from vector or matrix access
     */
    @Override
    public double[][] cosimilarity(int[] rowIds, int[] colIds) throws DaoException {
        if (!hasFeatureMatrix()) {
            // Build each distinct page's vector once.
            Map<Integer, TIntFloatMap> vectors = new HashMap<Integer, TIntFloatMap>();
            for (int id : ArrayUtils.addAll(colIds, rowIds)) {
                if (!vectors.containsKey(id)) {
                    try {
                        vectors.put(id, getPageVector(id));
                    } catch (IOException e) {
                        throw new DaoException(e);
                    }
                }
            }
            List<TIntFloatMap> rows = new ArrayList<TIntFloatMap>(rowIds.length);
            for (int id : rowIds) {
                rows.add(vectors.get(id));
            }
            List<TIntFloatMap> columns = new ArrayList<TIntFloatMap>(colIds.length);
            for (int id : colIds) {
                columns.add(vectors.get(id));
            }
            return cosimilarity(rows, columns);
        }

        TIntObjectHashMap<SparseMatrixRow> rowsById = new TIntObjectHashMap<SparseMatrixRow>(rowIds.length + colIds.length);
        for (int id : ArrayUtils.addAll(rowIds, colIds)) {
            if (rowsById.containsKey(id)) {
                continue;
            }
            try {
                SparseMatrixRow row = featureMatrix.getRow(id);
                if (row != null) {
                    if (featureFilter != null) {
                        row = featureFilter.filter(id, row);
                    }
                    rowsById.put(id, row);
                }
            } catch (IOException e) {
                throw new DaoException(e);
            }
        }

        double[][] results = new double[rowIds.length][colIds.length];
        for (int i = 0; i < rowIds.length; i++) {
            MatrixRow row = rowsById.get(rowIds[i]);
            if (row == null) {
                continue;
            }
            for (int j = 0; j < colIds.length; j++) {
                MatrixRow column = rowsById.get(colIds[j]);
                if (column != null) {
                    results[i][j] = normalize(similarity.similarity(row, column));
                }
            }
        }
        return results;
    }

    /**
     * Computes the normalized similarity matrix between two lists of vectors.
     *
     * <p>{@code null} entries (items without a vector) are skipped and their cells stay 0.0. This
     * matches the feature-matrix path of {@link #cosimilarity(int[], int[])}, so null vectors are
     * never handed to the {@link VectorSimilarity}.
     *
     * @throws UnsupportedOperationException if a feature filter is set
     */
    public double[][] cosimilarity(List<TIntFloatMap> rows, List<TIntFloatMap> columns) {
        if (featureFilter != null) {
            throw new UnsupportedOperationException();
        }
        double[][] results = new double[rows.size()][columns.size()];
        for (int i = 0; i < rows.size(); i++) {
            TIntFloatMap row = rows.get(i);
            if (row == null) {
                continue;
            }
            for (int j = 0; j < columns.size(); j++) {
                TIntFloatMap column = columns.get(j);
                if (column != null) {
                    results[i][j] = normalize(similarity.similarity(row, column));
                }
            }
        }
        return results;
    }

    /**
     * Builds the feature matrix and its transpose, replacing any previously loaded matrices.
     *
     * <p>Each page's vector is written as a matrix row in parallel; empty or missing vectors are
     * skipped. The matrices are written under the data directory and then handed to
     * {@link #similarity}.
     *
     * @param pageIds pages to include; {@code null} means every non-redirect, non-disambiguation
     *                article in this metric's language
     * @throws IOException if vector generation, writing, or transposing fails
     */
    public synchronized void buildFeatureAndTransposeMatrices(TIntSet pageIds) throws IOException {
        if (pageIds == null) {
            pageIds = getAllPageIds();
        }
        IOUtils.closeQuietly(featureMatrix);
        IOUtils.closeQuietly(transposeMatrix);
        // Clearing the matrix makes getPageVector() consult the generator while rows are written.
        featureMatrix = null;
        transposeMatrix = null;
        getDataDir().mkdirs();

        final SparseMatrixWriter writer = new SparseMatrixWriter(getFeatureMatrixPath(),
                new ValueConf((float) similarity.getMinValue(), (float) similarity.getMaxValue()));
        ParallelForEach.loop(WbArrayUtils.toList(pageIds.toArray()), WpThreadUtils.getMaxThreads(), new Procedure<Integer>() {
            @Override
            public void call(Integer pageId) throws IOException {
                TIntFloatMap vector = VectorBasedSRMetric.this.getPageVector(pageId.intValue());
                if (vector == null || vector.isEmpty()) {
                    return;
                }
                writer.writeRow(new SparseMatrixRow(writer.getValueConf(), pageId.intValue(), vector));
            }
        }, 10000);
        writer.finish();

        featureMatrix = new SparseMatrix(getFeatureMatrixPath());
        new SparseMatrixTransposer(featureMatrix, getTransposeMatrixPath()).transpose();
        transposeMatrix = new SparseMatrix(getTransposeMatrixPath());
        similarity.setMatrices(featureMatrix, transposeMatrix, getDataDir());
    }

    /** Returns ids of all non-redirect, non-disambiguation article pages in this metric's language. */
    private TIntSet getAllPageIds() throws IOException {
        DaoFilter filter = new DaoFilter()
                .setLanguages(getLanguage())
                .setDisambig(false)
                .setRedirect(false)
                .setNameSpaces(NameSpace.ARTICLE);
        TIntHashSet ids = new TIntHashSet();
        try {
            Iterator<?> pages = getLocalPageDao().get(filter).iterator();
            while (pages.hasNext()) {
                ids.add(((LocalPage) pages.next()).getLocalId());
            }
            return ids;
        } catch (DaoException e) {
            throw new IOException(e);
        }
    }

    /** Location of the feature matrix inside the data directory. */
    protected File getFeatureMatrixPath() {
        return new File(getDataDir(), "feature.matrix");
    }

    /** Location of the transposed feature matrix inside the data directory. */
    protected File getTransposeMatrixPath() {
        return new File(getDataDir(), "featureTranspose.matrix");
    }

    /**
     * Reads base state, then loads the feature and transpose matrices if both files exist.
     *
     * <p>Loaded matrices replace any open ones and are handed to {@link #similarity}.
     */
    @Override
    public void read() throws IOException {
        super.read();
        if (getFeatureMatrixPath().isFile() && getTransposeMatrixPath().isFile()) {
            IOUtils.closeQuietly(featureMatrix);
            IOUtils.closeQuietly(transposeMatrix);
            featureMatrix = new SparseMatrix(getFeatureMatrixPath());
            transposeMatrix = new SparseMatrix(getTransposeMatrixPath());
            similarity.setMatrices(featureMatrix, transposeMatrix, getDataDir());
        }
    }

    /**
     * Returns a page's vector, with the feature filter applied if one is set.
     *
     * <p>The vector comes from the feature matrix when one is loaded, otherwise from the generator.
     *
     * @return the vector, or {@code null} if the page is not in the loaded feature matrix (or
     *         whatever the generator/filter return when no matrix is loaded)
     * @throws IOException on matrix access failure, or wrapping a generator {@link DaoException}
     */
    public TIntFloatMap getPageVector(int pageId) throws IOException {
        TIntFloatMap vector;
        if (hasFeatureMatrix()) {
            SparseMatrixRow row = featureMatrix.getRow(pageId);
            if (row == null) {
                return null;
            }
            vector = row.asTroveMap();
        } else {
            try {
                vector = generator.getVector(pageId);
            } catch (DaoException e) {
                throw new IOException(e);
            }
        }
        return featureFilter == null ? vector : featureFilter.filter(pageId, vector);
    }

    /** True if a non-empty feature matrix is loaded. */
    protected boolean hasFeatureMatrix() {
        return featureMatrix != null && featureMatrix.getNumRows() > 0;
    }

    /** True if a non-empty transpose matrix is loaded. */
    protected boolean hasTransposeMatrix() {
        return transposeMatrix != null && transposeMatrix.getNumRows() > 0;
    }

    public VectorGenerator getGenerator() {
        return generator;
    }

    public VectorSimilarity getSimilarity() {
        return similarity;
    }

    /** Sets (or clears, with {@code null}) the filter applied to page vectors. */
    public void setFeatureFilter(FeatureFilter featureFilter) {
        this.featureFilter = featureFilter;
    }

    @Override
    public BaseSRMetric.SRConfig getConfig() {
        return config;
    }
}
