package org.wikibrain.sr;

import edu.emory.mathcs.backport.java.util.Arrays;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import java.io.File;
import java.io.IOException;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.io.IOUtils;
import org.wikibrain.core.WikiBrainException;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.DaoFilter;
import org.wikibrain.core.dao.UniversalPageDao;
import org.wikibrain.core.lang.LocalId;
import org.wikibrain.core.lang.LocalString;
import org.wikibrain.core.model.UniversalPage;
import org.wikibrain.matrix.SparseMatrixRow;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.disambig.Disambiguator;
import org.wikibrain.sr.normalize.Normalizer;
import org.wikibrain.sr.pairwise.MostSimilarCache;
import org.wikibrain.sr.pairwise.PairwiseSimilarity;
import org.wikibrain.sr.utils.Leaderboard;
import org.wikibrain.sr.utils.SrNormalizers;
import org.wikibrain.utils.WpIOUtils;

/* loaded from: input_file:org/wikibrain/sr/BaseUniversalSRMetric.class */
/**
 * Skeletal {@link UniversalSRMetric}.
 *
 * <p>This class:
 * <ul>
 *   <li>resolves {@link LocalString} phrases to {@link UniversalPage}s through a
 *       {@link Disambiguator} and a {@link UniversalPageDao};</li>
 *   <li>manages the similarity / most-similar normalizers;</li>
 *   <li>can answer most-similar queries from a precomputed {@link MostSimilarCache}.</li>
 * </ul>
 *
 * <p>Subclasses implement the {@link UniversalPage}-based {@code similarity} and
 * {@code mostSimilar} methods. The phrase-based overloads here delegate to them.
 */
public abstract class BaseUniversalSRMetric implements UniversalSRMetric {
    private static final Logger LOG = Logger.getLogger(BaseUniversalSRMetric.class.getName());

    protected UniversalPageDao universalPageDao;
    protected Disambiguator disambiguator;
    /** Algorithm id used for every universal-page lookup made by this metric. */
    protected int algorithmId;
    private SrNormalizers normalizers = new SrNormalizers();
    /** Most-similar cache installed by {@link #readCosimilarity}; {@code null} until then. */
    private MostSimilarCache mostSimilarMatrices = null;

    /**
     * @param disambiguator    maps phrases to local pages
     * @param universalPageDao maps local pages to universal pages
     * @param algorithmId      universal-page algorithm id used for all lookups
     */
    public BaseUniversalSRMetric(Disambiguator disambiguator, UniversalPageDao universalPageDao, int algorithmId) {
        this.universalPageDao = universalPageDao;
        this.disambiguator = disambiguator;
        this.algorithmId = algorithmId;
    }

    /**
     * Reports whether the loaded most-similar cache has a row for a universal page.
     *
     * @param i universal page id
     * @return {@code false} if no cache has been loaded yet or the cache has no row for {@code i}
     * @throws RuntimeException wrapping any {@link IOException} raised while reading the cache
     */
    public boolean hasCachedMostSimilarUniversal(int i) {
        if (mostSimilarMatrices == null) {
            // Nothing has been read via readCosimilarity yet.
            return false;
        }
        try {
            return mostSimilarMatrices.getCosimilarityMatrix().getRow(i) != null;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Answers a most-similar query from the precomputed cache.
     *
     * @param i        universal page id to query
     * @param i2       maximum number of results
     * @param tIntSet  optional candidate filter; {@code null} accepts every cached column
     * @return the top results sorted by descending score, or {@code null} if the row is not
     *     cached or reading it failed
     */
    public SRResultList getCachedMostSimilarUniversal(int i, int i2, TIntSet tIntSet) {
        if (!hasCachedMostSimilarUniversal(i)) {
            return null;
        }
        try {
            SparseMatrixRow row = mostSimilarMatrices.getCosimilarityMatrix().getRow(i);
            Leaderboard leaderboard = new Leaderboard(i2);
            for (int col = 0; col < row.getNumCols(); col++) {
                int candidateId = row.getColIndex(col);
                if (tIntSet == null || tIntSet.contains(candidateId)) {
                    leaderboard.tallyScore(candidateId, row.getColValue(col));
                }
            }
            SRResultList top = leaderboard.getTop();
            top.sortDescending();
            return top;
        } catch (IOException e) {
            // Still reported to callers as a cache miss, but no longer silently.
            LOG.log(Level.WARNING, "failed to read cached most-similar row for universal id " + i, e);
            return null;
        }
    }

    @Override
    public abstract SRResult similarity(UniversalPage universalPage, UniversalPage universalPage2, boolean z) throws DaoException;

    /**
     * Disambiguates two phrases, each using the other as context, and scores their pages.
     *
     * <p>The phrase-to-page resolution uses {@link #algorithmId}.
     *
     * @return an empty {@link SRResult} if either phrase cannot be disambiguated
     */
    @Override
    public SRResult similarity(LocalString localString, LocalString localString2, boolean z) throws DaoException {
        Set<LocalString> context = new HashSet<>();
        context.add(localString2);
        LocalId first = disambiguator.disambiguateTop(localString, context);
        context.clear();
        context.add(localString);
        LocalId second = disambiguator.disambiguateTop(localString2, context);
        if (first == null || second == null) {
            return new SRResult();
        }
        return similarity(toUniversalPage(first), toUniversalPage(second), z);
    }

    @Override
    public abstract SRResultList mostSimilar(UniversalPage universalPage, int i) throws DaoException;

    @Override
    public abstract SRResultList mostSimilar(UniversalPage universalPage, int i, TIntSet tIntSet) throws DaoException;

    /**
     * Disambiguates a phrase without context and returns its most similar universal pages.
     *
     * @return a one-element list holding an empty {@link SRResult} if the phrase cannot be
     *     disambiguated
     */
    @Override
    public SRResultList mostSimilar(LocalString localString, int i) throws DaoException {
        LocalId localId = disambiguator.disambiguateTop(localString, (Set<LocalString>) null);
        if (localId == null) {
            return singleEmptyResult();
        }
        return mostSimilar(toUniversalPage(localId), i);
    }

    /**
     * Like {@link #mostSimilar(LocalString, int)}, restricted to the candidate ids in {@code tIntSet}.
     */
    @Override
    public SRResultList mostSimilar(LocalString localString, int i, TIntSet tIntSet) throws DaoException {
        LocalId localId = disambiguator.disambiguateTop(localString, (Set<LocalString>) null);
        if (localId == null) {
            return singleEmptyResult();
        }
        return mostSimilar(toUniversalPage(localId), i, tIntSet);
    }

    /** @throws IllegalStateException if the similarity normalizer is not trained */
    protected void ensureSimilarityTrained() {
        if (!normalizers.getSimilarityNormalizer().isTrained()) {
            throw new IllegalStateException("Model default similarity has not been trained.");
        }
    }

    /** @throws IllegalStateException if the most-similar normalizer is not trained */
    protected void ensureMostSimilarTrained() {
        if (!normalizers.getMostSimilarNormalizer().isTrained()) {
            throw new IllegalStateException("Model default mostSimilar has not been trained.");
        }
    }

    /**
     * Normalizes a result's score in place with the similarity normalizer.
     *
     * @return the same {@code sRResult} instance, with its score replaced
     */
    protected SRResult normalize(SRResult sRResult) {
        ensureSimilarityTrained();
        sRResult.score = normalizers.getSimilarityNormalizer().normalize(sRResult.score);
        return sRResult;
    }

    /** Returns the list produced by the most-similar normalizer for {@code sRResultList}. */
    protected SRResultList normalize(SRResultList sRResultList) {
        ensureMostSimilarTrained();
        return normalizers.getMostSimilarNormalizer().normalize(sRResultList);
    }

    /** Writes the normalizers under {@code <str>/<getName()>}, creating the directory if needed. */
    @Override
    public void write(String str) throws IOException {
        File dir = new File(str, getName());
        WpIOUtils.mkdirsQuietly(dir);
        normalizers.write(dir);
    }

    /**
     * Reads normalizers from {@code <str>/<getName()>} if readable ones are present there.
     *
     * <p>If the directory is missing, this only logs a warning.
     */
    @Override
    public void read(String str) throws IOException {
        File dir = new File(str, getName());
        if (!dir.isDirectory()) {
            LOG.warning("directory " + dir + " does not exist; cannot read files");
        } else if (normalizers.hasReadableNormalizers(dir)) {
            normalizers.read(dir);
        }
    }

    @Override
    public void trainSimilarity(Dataset dataset) throws DaoException {
        normalizers.trainSimilarity(this, dataset);
    }

    @Override
    public void trainMostSimilar(Dataset dataset, int i, TIntSet tIntSet) throws DaoException {
        normalizers.trainMostSimilar(this, disambiguator, universalPageDao, algorithmId, dataset, tIntSet, i);
    }

    @Override
    public void setMostSimilarNormalizer(Normalizer normalizer) {
        normalizers.setMostSimilarNormalizer(normalizer);
    }

    @Override
    public void setSimilarityNormalizer(Normalizer normalizer) {
        normalizers.setSimilarityNormalizer(normalizer);
    }

    /**
     * Computes the similarity matrix between two lists of universal ids.
     *
     * @return {@code result[r][c]} = similarity of {@code iArr[r]} and {@code iArr2[c]}; 1.0 when
     *     the ids are equal
     */
    @Override
    public double[][] cosimilarity(int[] iArr, int[] iArr2) throws IOException, DaoException {
        double[][] result = new double[iArr.length][iArr2.length];
        for (int r = 0; r < iArr.length; r++) {
            for (int c = 0; c < iArr2.length; c++) {
                if (iArr[r] == iArr2[c]) {
                    result[r][c] = 1.0d;
                } else {
                    result[r][c] = similarity(new UniversalPage(iArr[r], algorithmId),
                            new UniversalPage(iArr2[c], algorithmId), false).getScore();
                }
            }
        }
        return result;
    }

    /**
     * Disambiguates both phrase lists and delegates to {@link #cosimilarity(int[], int[])}.
     *
     * <p>Each list is disambiguated using the other list's phrases as context.
     * NOTE(review): if the disambiguator returns {@code null} for a phrase, this throws a
     * NullPointerException. Confirm the disambiguator's contract.
     */
    @Override
    public double[][] cosimilarity(LocalString[] localStringArr, LocalString[] localStringArr2) throws IOException, DaoException {
        // Fully qualified to avoid a clash with the file's backport Arrays import.
        List<LocalString> rowPhrases = java.util.Arrays.asList(localStringArr);
        List<LocalString> colPhrases = java.util.Arrays.asList(localStringArr2);
        List<LocalId> rowIds = disambiguator.disambiguateTop(rowPhrases, new HashSet<LocalString>(colPhrases));
        List<LocalId> colIds = disambiguator.disambiguateTop(colPhrases, new HashSet<LocalString>(rowPhrases));
        return cosimilarity(toUnivIds(rowIds, localStringArr.length), toUnivIds(colIds, localStringArr2.length));
    }

    /**
     * Computes the symmetric similarity matrix of a list of universal ids.
     *
     * <p>The diagonal is 1.0. Each unordered pair is scored once and mirrored.
     */
    @Override
    public double[][] cosimilarity(int[] iArr) throws IOException, DaoException {
        int n = iArr.length;
        double[][] result = new double[n][n];
        for (int r = 0; r < n; r++) {
            result[r][r] = 1.0d;
            for (int c = r + 1; c < n; c++) {
                // Use this metric's algorithm id, as cosimilarity(int[], int[]) does
                // (previously hard-coded to 0).
                double score = similarity(new UniversalPage(iArr[r], algorithmId),
                        new UniversalPage(iArr[c], algorithmId), false).getScore();
                result[r][c] = score;
                result[c][r] = score;
            }
        }
        return result;
    }

    /**
     * Disambiguates the phrases without context and delegates to {@link #cosimilarity(int[])}.
     *
     * <p>NOTE(review): if the disambiguator returns {@code null} for a phrase, this throws a
     * NullPointerException. Confirm the disambiguator's contract.
     */
    @Override
    public double[][] cosimilarity(LocalString[] localStringArr) throws IOException, DaoException {
        List<LocalId> ids = disambiguator.disambiguateTop(java.util.Arrays.asList(localStringArr), (Set<LocalString>) null);
        return cosimilarity(toUnivIds(ids, localStringArr.length));
    }

    @Override
    public int getAlgorithmId() {
        return algorithmId;
    }

    /**
     * Intended to precompute and persist the most-similar cache under {@code <str>/<getName()>}.
     *
     * <p>NOTE(review): as written, this method never persists anything. It:
     * <ol>
     *   <li>constructs a {@link MostSimilarCache};</li>
     *   <li>collects the ids of every universal page for {@link #algorithmId};</li>
     *   <li>then unconditionally throws.</li>
     * </ol>
     * The actual write call appears to have been lost (the body looks like decompiler
     * output). Restore it from the original sources before relying on this method.
     *
     * @throws RuntimeException always, wrapping an {@link InterruptedException}
     */
    protected void writeCosimilarity(String str, int i, PairwiseSimilarity pairwiseSimilarity) throws IOException, DaoException, WikiBrainException {
        try {
            new MostSimilarCache(this, pairwiseSimilarity, new File(str, getName()));
            Iterable<UniversalPage> pages = universalPageDao.get(new DaoFilter().setAlgorithmIds(algorithmId));
            TIntHashSet pageIds = new TIntHashSet();
            for (UniversalPage page : pages) {
                if (page != null) {
                    pageIds.add(page.getUnivId());
                }
            }
            throw new InterruptedException();
        } catch (InterruptedException e) {
            // Keep the cause rather than throwing a bare RuntimeException.
            throw new RuntimeException(e);
        }
    }

    /**
     * Replaces the current most-similar cache with one rooted at {@code <str>/<getName()>}.
     *
     * <ul>
     *   <li>Any previously loaded cache is closed quietly first.</li>
     *   <li>Matrices are read only if the cache reports them readable.</li>
     *   <li>The new cache is installed either way.</li>
     * </ul>
     */
    protected void readCosimilarity(String str, PairwiseSimilarity pairwiseSimilarity) throws IOException {
        if (mostSimilarMatrices != null) {
            IOUtils.closeQuietly(mostSimilarMatrices);
            mostSimilarMatrices = null;
        }
        MostSimilarCache cache = new MostSimilarCache(this, pairwiseSimilarity, new File(str, getName()));
        if (cache.hasReadableMatrices()) {
            cache.read();
        }
        mostSimilarMatrices = cache;
    }

    /** Maps a disambiguated local page to its universal page under {@link #algorithmId}. */
    private UniversalPage toUniversalPage(LocalId localId) throws DaoException {
        int univId = universalPageDao.getUnivPageId(localId.asLocalPage(), algorithmId);
        return universalPageDao.getById(univId, algorithmId);
    }

    /** Maps the first {@code count} local ids to universal ids under {@link #algorithmId}. */
    private int[] toUnivIds(List<LocalId> localIds, int count) throws DaoException {
        int[] univIds = new int[count];
        for (int k = 0; k < count; k++) {
            univIds[k] = universalPageDao.getUnivPageId(localIds.get(k).asLocalPage(), algorithmId);
        }
        return univIds;
    }

    /** Builds the placeholder list returned when a phrase cannot be disambiguated. */
    private static SRResultList singleEmptyResult() {
        SRResultList results = new SRResultList(1);
        results.set(0, new SRResult());
        return results;
    }
}
