package org.wikibrain.sr;

import com.typesafe.config.Config;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.ArrayUtils;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.WikiBrainException;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.DaoFilter;
import org.wikibrain.core.dao.LocalPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.lang.LocalId;
import org.wikibrain.core.lang.LocalString;
import org.wikibrain.core.model.LocalPage;
import org.wikibrain.core.model.NameSpace;
import org.wikibrain.matrix.SparseMatrix;
import org.wikibrain.matrix.SparseMatrixRow;
import org.wikibrain.matrix.SparseMatrixWriter;
import org.wikibrain.matrix.ValueConf;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.disambig.Disambiguator;
import org.wikibrain.sr.disambig.SimilarityDisambiguator;
import org.wikibrain.sr.normalize.IdentityNormalizer;
import org.wikibrain.sr.normalize.Normalizer;
import org.wikibrain.sr.utils.Leaderboard;
import org.wikibrain.sr.utils.SrNormalizers;
import org.wikibrain.utils.ParallelForEach;
import org.wikibrain.utils.Procedure;
import org.wikibrain.utils.WpIOUtils;
import org.wikibrain.utils.WpThreadUtils;

/* loaded from: input_file:org/wikibrain/sr/BaseSRMetric.class */
/**
 * Skeleton implementation of {@link SRMetric} holding the logic shared by concrete
 * semantic-relatedness metrics:
 * <ul>
 *   <li>data-directory management and persistence of the similarity / mostSimilar normalizers;</li>
 *   <li>phrase disambiguation for the String-based entry points;</li>
 *   <li>pairwise cosimilarity matrices built on top of {@link #similarity(int, int, boolean)};</li>
 *   <li>an optional on-disk "most similar" cache stored as a {@link SparseMatrix}.</li>
 * </ul>
 * Subclasses supply {@link #getConfig()}, {@link #similarity(int, int, boolean)} and
 * {@link #mostSimilar(int, int, TIntSet)}.
 */
public abstract class BaseSRMetric implements SRMetric {
    private static final Logger LOG = Logger.getLogger(BaseSRMetric.class.getName());

    private final String name;
    private final Language language;
    private File dataDir;
    private Disambiguator disambiguator;
    private LocalPageDao localPageDao;
    // When false, read() does not load persisted normalizers.
    private boolean shouldReadNormalizers = true;
    // When true, trainMostSimilar() also rebuilds the mostSimilar cache.
    private boolean buildMostSimilarCache = false;
    // Currently loaded cache matrix, or null when no cache is available.
    private SparseMatrix mostSimilarCache = null;
    // Row ids used when trainMostSimilar() builds the cache; null means "all article pages".
    private TIntSet mostSimilarCacheRowIds = null;
    // NOTE(review): not read anywhere in this class; kept for compatibility.
    private int numSenses = 5;
    private SrNormalizers normalizers = new SrNormalizers();

    /**
     * Score bounds handed to the {@link ValueConf} of the mostSimilar cache writer
     * (see {@link #writeMostSimilarCache(int, TIntSet, TIntSet)}).
     */
    public static class SRConfig {
        public float minScore = -1.1f;
        public float maxScore = 1.1f;
    }

    /**
     * @param str           metric name; {@link #configureBase} also uses it as a data sub-directory
     * @param language      language of the pages and phrases handled by this metric
     * @param localPageDao  page DAO used to enumerate articles when building the cache
     * @param disambiguator resolves phrases to local page ids for the String-based methods
     */
    public BaseSRMetric(String str, Language language, LocalPageDao localPageDao, Disambiguator disambiguator) {
        this.name = str;
        this.language = language;
        this.disambiguator = disambiguator;
        this.localPageDao = localPageDao;
    }

    /**
     * @return the score bounds used when writing the mostSimilar cache matrix
     */
    public abstract SRConfig getConfig();

    @Override // org.wikibrain.sr.SRMetric
    public File getDataDir() {
        return this.dataDir;
    }

    @Override // org.wikibrain.sr.SRMetric
    public String getName() {
        return this.name;
    }

    @Override // org.wikibrain.sr.SRMetric
    public void setDataDir(File file) {
        this.dataDir = file;
    }

    @Override // org.wikibrain.sr.SRMetric
    public void setMostSimilarNormalizer(Normalizer normalizer) {
        this.normalizers.setMostSimilarNormalizer(normalizer);
    }

    @Override // org.wikibrain.sr.SRMetric
    public void setSimilarityNormalizer(Normalizer normalizer) {
        this.normalizers.setSimilarityNormalizer(normalizer);
    }

    @Override // org.wikibrain.sr.SRMetric
    public boolean similarityIsTrained() {
        return this.normalizers.getSimilarityNormalizer().isTrained();
    }

    @Override // org.wikibrain.sr.SRMetric
    public boolean mostSimilarIsTrained() {
        return this.normalizers.getMostSimilarNormalizer().isTrained();
    }

    /**
     * @throws IllegalStateException if the similarity normalizer has not been trained
     */
    protected void ensureSimilarityTrained() {
        if (!similarityIsTrained()) {
            throw new IllegalStateException("Model similarity has not been trained.");
        }
    }

    /**
     * @throws IllegalStateException if the mostSimilar normalizer has not been trained
     */
    protected void ensureMostSimilarTrained() {
        if (!mostSimilarIsTrained()) {
            throw new IllegalStateException("Model mostSimilar has not been trained.");
        }
    }

    /**
     * Replaces the result's score with its normalized value (see {@link #normalize(double)}).
     *
     * @return the same {@code sRResult} instance, mutated in place
     * @throws IllegalStateException if the similarity normalizer has not been trained
     */
    public SRResult normalize(SRResult sRResult) {
        sRResult.score = normalize(sRResult.score);
        return sRResult;
    }

    /**
     * Applies the mostSimilar normalizer to a result list.
     *
     * @throws IllegalStateException if the mostSimilar normalizer has not been trained
     */
    public SRResultList normalize(SRResultList sRResultList) {
        ensureMostSimilarTrained();
        return this.normalizers.getMostSimilarNormalizer().normalize(sRResultList);
    }

    /**
     * Applies the similarity normalizer to a raw score.
     *
     * @throws IllegalStateException if the similarity normalizer has not been trained
     */
    public double normalize(double d) {
        ensureSimilarityTrained();
        return this.normalizers.getSimilarityNormalizer().normalize(d);
    }

    /**
     * Creates the data directory if needed and writes the normalizers into it.
     */
    @Override // org.wikibrain.sr.SRMetric
    public void write() throws IOException {
        WpIOUtils.mkdirsQuietly(this.dataDir);
        this.normalizers.write(this.dataDir);
    }

    /**
     * @param z whether {@link #read()} should load persisted normalizers
     */
    public void setReadNormalizers(boolean z) {
        this.shouldReadNormalizers = z;
    }

    /**
     * Loads persisted state from the data directory. Normalizers are loaded only if reading is
     * enabled and readable normalizers exist. The mostSimilar cache is reopened if its matrix file
     * exists. If the directory does not exist, this logs a warning and returns.
     */
    @Override // org.wikibrain.sr.SRMetric
    public void read() throws IOException {
        if (!this.dataDir.isDirectory()) {
            LOG.warning("directory " + this.dataDir + " does not exist; cannot read files");
            return;
        }
        if (this.shouldReadNormalizers && this.normalizers.hasReadableNormalizers(this.dataDir)) {
            this.normalizers.read(this.dataDir);
        }
        // Forget the old handle before (maybe) opening a new one. Otherwise a closed matrix would
        // stay referenced when no matrix file exists, and getCachedMostSimilar() would read it.
        IOUtils.closeQuietly(this.mostSimilarCache);
        this.mostSimilarCache = null;
        File matrixPath = getMostSimilarMatrixPath();
        if (matrixPath.isFile()) {
            this.mostSimilarCache = new SparseMatrix(matrixPath);
        }
    }

    /**
     * Trains the similarity normalizer on {@code dataset}.
     *
     * @throws IllegalArgumentException if the dataset's language differs from this metric's
     */
    @Override // org.wikibrain.sr.SRMetric
    public synchronized void trainSimilarity(Dataset dataset) throws DaoException {
        if (!dataset.getLanguage().equals(getLanguage())) {
            throw new IllegalArgumentException("SR metric has language " + getLanguage() + " but dataset has language " + dataset.getLanguage());
        }
        this.normalizers.trainSimilarity(this, dataset);
    }

    /**
     * Trains the mostSimilar normalizer on {@code dataset}. If cache building is enabled, it then
     * rebuilds the mostSimilar cache using {@code tIntSet} as the column ids. Failures while
     * building the cache are logged, not propagated.
     *
     * @throws IllegalArgumentException if the dataset's language differs from this metric's
     */
    @Override // org.wikibrain.sr.SRMetric
    public synchronized void trainMostSimilar(Dataset dataset, int i, TIntSet tIntSet) {
        if (!dataset.getLanguage().equals(getLanguage())) {
            throw new IllegalArgumentException("SR metric has language " + getLanguage() + " but dataset has language " + dataset.getLanguage());
        }
        this.normalizers.trainMostSimilar(this, this.disambiguator, dataset, tIntSet, i);
        try {
            if (this.buildMostSimilarCache) {
                writeMostSimilarCache(i, this.mostSimilarCacheRowIds, tIntSet);
            }
        } catch (Exception e) {
            // Deliberately best-effort: a failed cache build must not fail training.
            LOG.log(Level.SEVERE, "writing most similar cache failed:", (Throwable) e);
        }
    }

    @Override // org.wikibrain.sr.SRMetric
    public abstract SRResult similarity(int i, int i2, boolean z) throws DaoException;

    /**
     * Disambiguates both phrases together and scores the resulting pages.
     *
     * @return an empty {@link SRResult} when either phrase cannot be resolved
     */
    @Override // org.wikibrain.sr.SRMetric
    public SRResult similarity(String str, String str2, boolean z) throws DaoException {
        Language language = getLanguage();
        List<LocalId> disambiguateTop = this.disambiguator.disambiguateTop(Arrays.asList(new LocalString(language, str), new LocalString(language, str2)), (Set<LocalString>) null);
        LocalId first = disambiguateTop.get(0);
        LocalId second = disambiguateTop.get(1);
        if (first == null || second == null) {
            return new SRResult();
        }
        return similarity(first.getId(), second.getId(), z);
    }

    /**
     * Debug helper. Disambiguates the phrase pair with every non-SIMILARITY criterion of a
     * {@link SimilarityDisambiguator}, and prints the per-criterion results to stdout when the
     * criteria disagree. Assumes {@link #disambiguator} is a SimilarityDisambiguator (the cast is
     * unchecked) and that {@code list} holds at least two phrases. Not called within this class.
     */
    private void debugSimilarityDisambiguator(List<LocalString> list) throws DaoException {
        String firstPair = null;
        boolean allAgree = true;
        StringBuilder report = new StringBuilder("results for " + list.get(0).getString() + ", " + list.get(1).getString() + "\n");
        for (SimilarityDisambiguator.Criteria criteria : SimilarityDisambiguator.Criteria.values()) {
            if (criteria == SimilarityDisambiguator.Criteria.SIMILARITY) {
                continue;
            }
            List<LocalId> top;
            // setCriteria mutates the shared disambiguator, so criterion + query must be atomic.
            synchronized (this.disambiguator) {
                ((SimilarityDisambiguator) this.disambiguator).setCriteria(criteria);
                top = this.disambiguator.disambiguateTop(list, (Set<LocalString>) null);
            }
            String page1 = describePage(top.get(0));
            String page2 = describePage(top.get(1));
            report.append("\t" + criteria + ": " + page1 + ",  " + page2 + "\n");
            String pair = page1 + page2;
            if (firstPair == null) {
                firstPair = pair;
            }
            if (!firstPair.equals(pair)) {
                allAgree = false;
            }
        }
        if (!allAgree) {
            System.out.println(report.toString());
        }
    }

    /** Renders a resolved page for debug output, or "null" when unresolved. */
    private String describePage(LocalId id) throws DaoException {
        return id == null ? "null" : this.localPageDao.getById(this.language, id.getId()).toString();
    }

    /**
     * Equivalent to {@code mostSimilar(i, i2, null)}.
     */
    @Override // org.wikibrain.sr.SRMetric
    public SRResultList mostSimilar(int i, int i2) throws DaoException {
        return mostSimilar(i, i2, (TIntSet) null);
    }

    @Override // org.wikibrain.sr.SRMetric
    public abstract SRResultList mostSimilar(int i, int i2, TIntSet tIntSet) throws DaoException;

    /**
     * Disambiguates {@code str} and delegates to {@link #mostSimilar(int, int)}.
     *
     * @return a one-element list holding an empty {@link SRResult} when {@code str} cannot be resolved
     */
    @Override // org.wikibrain.sr.SRMetric
    public SRResultList mostSimilar(String str, int i) throws DaoException {
        LocalId resolved = this.disambiguator.disambiguateTop(new LocalString(getLanguage(), str), (Set<LocalString>) null);
        if (resolved == null) {
            return unresolvedResultList();
        }
        return mostSimilar(resolved.getId(), i);
    }

    /**
     * Disambiguates {@code str} and delegates to {@link #mostSimilar(int, int, TIntSet)}.
     *
     * @return a one-element list holding an empty {@link SRResult} when {@code str} cannot be resolved
     */
    @Override // org.wikibrain.sr.SRMetric
    public SRResultList mostSimilar(String str, int i, TIntSet tIntSet) throws DaoException {
        LocalId resolved = this.disambiguator.disambiguateTop(new LocalString(getLanguage(), str), (Set<LocalString>) null);
        if (resolved == null) {
            return unresolvedResultList();
        }
        return mostSimilar(resolved.getId(), i, tIntSet);
    }

    /** Placeholder returned by the String-based mostSimilar methods when disambiguation fails. */
    private static SRResultList unresolvedResultList() {
        SRResultList placeholder = new SRResultList(1);
        placeholder.set(0, new SRResult());
        return placeholder;
    }

    /**
     * Builds an {@code iArr.length x iArr2.length} score matrix. Identical ids get
     * {@code normalize(1.0)}; all other pairs use {@code similarity(a, b, false)}.
     */
    @Override // org.wikibrain.sr.SRMetric
    public double[][] cosimilarity(int[] iArr, int[] iArr2) throws DaoException {
        double[][] dArr = new double[iArr.length][iArr2.length];
        for (int i = 0; i < iArr.length; i++) {
            for (int i2 = 0; i2 < iArr2.length; i2++) {
                if (iArr[i] == iArr2[i2]) {
                    dArr[i][i2] = normalize(1.0d);
                } else {
                    dArr[i][i2] = similarity(iArr[i], iArr2[i2], false).getScore();
                }
            }
        }
        return dArr;
    }

    /**
     * Builds a {@code strArr.length x strArr2.length} score matrix. Equal strings get
     * {@code normalize(1.0)}; all other pairs use {@link #similarity(String, String, boolean)}.
     */
    @Override // org.wikibrain.sr.SRMetric
    public double[][] cosimilarity(String[] strArr, String[] strArr2) throws DaoException {
        double[][] dArr = new double[strArr.length][strArr2.length];
        for (int i = 0; i < strArr.length; i++) {
            for (int i2 = 0; i2 < strArr2.length; i2++) {
                if (strArr[i].equals(strArr2[i2])) {
                    dArr[i][i2] = normalize(1.0d);
                } else {
                    dArr[i][i2] = similarity(strArr[i], strArr2[i2], false).getScore();
                }
            }
        }
        return dArr;
    }

    /**
     * Builds a symmetric square score matrix for {@code iArr}. The diagonal holds
     * {@code normalize(1.0)}; each off-diagonal pair is scored once and mirrored.
     */
    @Override // org.wikibrain.sr.SRMetric
    public double[][] cosimilarity(int[] iArr) throws DaoException {
        double[][] dArr = new double[iArr.length][iArr.length];
        for (int i = 0; i < iArr.length; i++) {
            dArr[i][i] = normalize(1.0d);
        }
        for (int i2 = 0; i2 < iArr.length; i2++) {
            for (int i3 = i2 + 1; i3 < iArr.length; i3++) {
                dArr[i2][i3] = similarity(iArr[i2], iArr[i3], false).getScore();
                dArr[i3][i2] = dArr[i2][i3];
            }
        }
        return dArr;
    }

    /**
     * Disambiguates all phrases together and builds a symmetric square score matrix. The diagonal
     * holds {@code normalize(1.0)}. A pair involving a phrase that could not be resolved receives
     * the score of an empty {@link SRResult}, mirroring {@link #similarity(String, String, boolean)}.
     * Previously such a phrase caused a NullPointerException.
     */
    @Override // org.wikibrain.sr.SRMetric
    public double[][] cosimilarity(String[] strArr) throws DaoException {
        Language lang = getLanguage();
        List<LocalString> phrases = new ArrayList<LocalString>(strArr.length);
        for (String phrase : strArr) {
            phrases.add(new LocalString(lang, phrase));
        }
        List<LocalId> resolved = this.disambiguator.disambiguateTop(phrases, (Set<LocalString>) null);
        int n = strArr.length;
        int[] ids = new int[n];
        boolean[] missing = new boolean[n];
        boolean anyMissing = false;
        for (int i = 0; i < n; i++) {
            LocalId id = resolved.get(i);
            if (id == null) {
                missing[i] = true;
                anyMissing = true;
            } else {
                ids[i] = id.getId();
            }
        }
        if (!anyMissing) {
            return cosimilarity(ids);
        }
        double unresolvedScore = new SRResult().getScore();
        double[][] scores = new double[n][n];
        for (int i = 0; i < n; i++) {
            scores[i][i] = normalize(1.0d);
        }
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                scores[i][j] = (missing[i] || missing[j])
                        ? unresolvedScore
                        : similarity(ids[i], ids[j], false).getScore();
                scores[j][i] = scores[i][j];
            }
        }
        return scores;
    }

    /**
     * Answers a mostSimilar query from the on-disk cache when possible.
     *
     * @param i      row (page) id to look up
     * @param i2     number of results requested
     * @param tIntSet if non-null, only cached columns contained in this set are considered
     * @return the top {@code i2} cached entries sorted descending, or {@code null} in any of these
     *         cases: no cache is loaded, the row is absent, the row has fewer than {@code i2}
     *         columns, or fewer than {@code i2} entries survive the filter
     * @throws DaoException wrapping any IOException raised while reading the cache
     */
    public SRResultList getCachedMostSimilar(int i, int i2, TIntSet tIntSet) throws DaoException {
        if (this.mostSimilarCache == null) {
            return null;
        }
        try {
            SparseMatrixRow row = this.mostSimilarCache.getRow(i);
            if (row == null || row.getNumCols() < i2) {
                return null;
            }
            Leaderboard leaderboard = new Leaderboard(i2);
            for (int i3 = 0; i3 < row.getNumCols(); i3++) {
                int colIndex = row.getColIndex(i3);
                if (tIntSet == null || tIntSet.contains(colIndex)) {
                    leaderboard.tallyScore(colIndex, row.getColValue(i3));
                }
            }
            SRResultList top = leaderboard.getTop();
            if (top.numDocs() < i2) {
                return null;
            }
            top.sortDescending();
            return top;
        } catch (IOException e) {
            throw new DaoException(e);
        }
    }

    /**
     * Builds the mostSimilar cache over all article pages, using them as both rows and columns.
     */
    public void writeMostSimilarCache(int i) throws IOException, DaoException, WikiBrainException {
        writeMostSimilarCache(i, null, null);
    }

    /**
     * Computes {@code mostSimilar} for every row id in parallel, writes the results to
     * {@link #getMostSimilarMatrixPath()}, and reopens the written matrix as the cache.
     * The configured normalizers are temporarily replaced by identity normalizers while the
     * rows are computed, and are always restored afterwards.
     *
     * @param maxHits number of neighbors requested per row
     * @param rowIds  page ids to compute rows for; null means all non-redirect, non-disambiguation
     *                article pages in this metric's language
     * @param colIds  candidate neighbor ids; null means the same default set as {@code rowIds}
     */
    public void writeMostSimilarCache(final int maxHits, TIntSet rowIds, TIntSet colIds) throws IOException, DaoException, WikiBrainException {
        if (rowIds == null || colIds == null) {
            TIntSet allArticleIds = loadArticleIds();
            if (rowIds == null) {
                rowIds = allArticleIds;
            }
            if (colIds == null) {
                colIds = allArticleIds;
            }
        }
        getDataDir().mkdirs();
        // Close AND forget the old cache. The writer below overwrites the same file, and while the
        // field is null getCachedMostSimilar() returns null instead of reading a closed matrix
        // (subclass mostSimilar() calls made by writeSim may consult it).
        IOUtils.closeQuietly(this.mostSimilarCache);
        this.mostSimilarCache = null;

        SRConfig config = getConfig();
        final AtomicInteger rowCounter = new AtomicInteger();
        final AtomicLong cellCounter = new AtomicLong();
        // NOTE(review): the writer is not released if the loop below throws; consider closing it.
        final SparseMatrixWriter writer = new SparseMatrixWriter(getMostSimilarMatrixPath(), new ValueConf(config.minScore, config.maxScore));
        final TIntHashSet colIdCopy = new TIntHashSet(colIds);

        Normalizer similarityNormalizer = getSimilarityNormalizer();
        Normalizer mostSimilarNormalizer = getMostSimilarNormalizer();
        setMostSimilarNormalizer(new IdentityNormalizer());
        setSimilarityNormalizer(new IdentityNormalizer());
        try {
            ParallelForEach.loop(Arrays.asList(ArrayUtils.toObject(rowIds.toArray())), WpThreadUtils.getMaxThreads(), new Procedure<Integer>() {
                public void call(Integer rowId) throws IOException, DaoException {
                    BaseSRMetric.this.writeSim(writer, rowId, colIdCopy, maxHits, rowCounter, cellCounter);
                }
            }, Integer.MAX_VALUE);
        } finally {
            setSimilarityNormalizer(similarityNormalizer);
            setMostSimilarNormalizer(mostSimilarNormalizer);
        }
        LOG.info("wrote " + cellCounter.get() + " non-zero similarity cells");
        writer.finish();
        this.mostSimilarCache = new SparseMatrix(getMostSimilarMatrixPath());
    }

    /** Collects local ids of all non-redirect, non-disambiguation article pages in this language. */
    private TIntSet loadArticleIds() throws DaoException {
        Iterable<LocalPage> pages = this.localPageDao.get(new DaoFilter().setLanguages(getLanguage()).setNameSpaces(NameSpace.ARTICLE).setDisambig(false).setRedirect(false));
        TIntSet ids = new TIntHashSet();
        for (LocalPage page : pages) {
            if (page != null) {
                ids.add(page.getLocalId());
            }
        }
        return ids;
    }

    /**
     * @return location of the mostSimilar cache matrix inside the data directory
     */
    protected File getMostSimilarMatrixPath() {
        return new File(getDataDir(), "mostSimilar.matrix");
    }

    /**
     * Computes and writes one cache row. Called concurrently from
     * {@link #writeMostSimilarCache(int, TIntSet, TIntSet)}.
     *
     * @param writer      destination matrix writer
     * @param rowId       page id whose neighbors are computed
     * @param colIds      candidate neighbor ids
     * @param maxHits     number of neighbors requested
     * @param rowCounter  shared count of processed rows (drives progress logging)
     * @param cellCounter shared count of written cells
     */
    public void writeSim(SparseMatrixWriter writer, Integer rowId, TIntSet colIds, int maxHits, AtomicInteger rowCounter, AtomicLong cellCounter) throws IOException, DaoException {
        // Use the value we incremented; re-reading the shared counter races with other threads.
        int processed = rowCounter.incrementAndGet();
        if (processed % 10000 == 0) {
            LOG.info("finding matches for page " + processed);
        }
        SRResultList neighbors = mostSimilar(rowId.intValue(), maxHits, colIds);
        if (neighbors != null) {
            int[] ids = neighbors.getIds();
            // Count cells, not rows, so the summary log ("non-zero similarity cells") is accurate.
            cellCounter.addAndGet(ids.length);
            writer.writeRow(new SparseMatrixRow(writer.getValueConf(), rowId.intValue(), ids, neighbors.getScoresAsFloat()));
        }
    }

    @Override // org.wikibrain.sr.SRMetric
    public Language getLanguage() {
        return this.language;
    }

    public Disambiguator getDisambiguator() {
        return this.disambiguator;
    }

    public LocalPageDao getLocalPageDao() {
        return this.localPageDao;
    }

    @Override // org.wikibrain.sr.SRMetric
    public Normalizer getMostSimilarNormalizer() {
        return this.normalizers.getMostSimilarNormalizer();
    }

    @Override // org.wikibrain.sr.SRMetric
    public Normalizer getSimilarityNormalizer() {
        return this.normalizers.getSimilarityNormalizer();
    }

    /**
     * @return the loaded cache matrix, or {@code null} if none is loaded
     */
    public SparseMatrix getMostSimilarCache() {
        return this.mostSimilarCache;
    }

    /**
     * Closes the cache matrix, deletes its file, and forgets the handle.
     */
    public void clearMostSimilarCache() {
        IOUtils.closeQuietly(this.mostSimilarCache);
        FileUtils.deleteQuietly(getMostSimilarMatrixPath());
        this.mostSimilarCache = null;
    }

    /**
     * @param z whether {@link #trainMostSimilar} should also rebuild the mostSimilar cache
     */
    public void setBuildMostSimilarCache(boolean z) {
        this.buildMostSimilarCache = z;
    }

    /**
     * @param tIntSet row ids used when {@link #trainMostSimilar} builds the cache (null = all articles)
     */
    public void setMostSimilarCacheRowIds(TIntSet tIntSet) {
        this.mostSimilarCacheRowIds = tIntSet;
    }

    /**
     * Applies configuration shared by all metrics, then calls {@link BaseSRMetric#read()}:
     * <ul>
     *   <li>the data dir is set to {@code <sr.metric.path>/<name>/<langCode>};</li>
     *   <li>normalizers are taken from the metric config keys {@code similaritynormalizer} and
     *       {@code mostsimilarnormalizer};</li>
     *   <li>reading persisted normalizers is disabled when {@code sr.metric.training} is true;</li>
     *   <li>the optional {@code buildMostSimilarCache} key enables cache building.</li>
     * </ul>
     *
     * @throws ConfigurationException wrapping any IOException raised by {@code read()}
     */
    public static void configureBase(Configurator configurator, BaseSRMetric baseSRMetric, Config config) throws ConfigurationException {
        Config config2 = configurator.getConf().get();
        baseSRMetric.setDataDir(FileUtils.getFile(new File(config2.getString("sr.metric.path")), new String[]{baseSRMetric.getName(), baseSRMetric.getLanguage().getLangCode()}));
        baseSRMetric.setSimilarityNormalizer((Normalizer) configurator.get(Normalizer.class, config.getString("similaritynormalizer")));
        baseSRMetric.setMostSimilarNormalizer((Normalizer) configurator.get(Normalizer.class, config.getString("mostsimilarnormalizer")));
        if (config2.getBoolean("sr.metric.training")) {
            baseSRMetric.setReadNormalizers(false);
        }
        if (config.hasPath("buildMostSimilarCache")) {
            baseSRMetric.setBuildMostSimilarCache(config.getBoolean("buildMostSimilarCache"));
        }
        try {
            baseSRMetric.read();
            LOG.info("finished base configuration of metric " + baseSRMetric.getName());
        } catch (IOException e) {
            throw new ConfigurationException(e);
        }
    }
}
