package org.wikibrain.sr.pairwise;

import gnu.trove.map.TIntDoubleMap;
import gnu.trove.map.TIntFloatMap;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.logging.Logger;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.wikibrain.core.WikiBrainException;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.matrix.MatrixRow;
import org.wikibrain.matrix.SparseMatrix;
import org.wikibrain.matrix.SparseMatrixRow;
import org.wikibrain.matrix.SparseMatrixTransposer;
import org.wikibrain.matrix.SparseMatrixWriter;
import org.wikibrain.matrix.ValueConf;
import org.wikibrain.sr.MonolingualSRMetric;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.UniversalSRMetric;
import org.wikibrain.sr.normalize.IdentityNormalizer;
import org.wikibrain.sr.normalize.Normalizer;
import org.wikibrain.sr.utils.Leaderboard;
import org.wikibrain.utils.ParallelForEach;
import org.wikibrain.utils.Procedure;

/* loaded from: input_file:org/wikibrain/sr/pairwise/MostSimilarCache.class */
/**
 * On-disk cache of precomputed matrices used to answer "most similar" queries for an SR metric.
 *
 * <p>Up to three sparse matrices live as files inside a cache directory:
 * <ul>
 *   <li>{@value #FEATURE_MATRIX} - one feature vector per item;</li>
 *   <li>{@value #FEATURE_TRANSPOSE_MATRIX} - the transpose of the feature matrix;</li>
 *   <li>{@value #COSIMILARITY_MATRIX} - precomputed nearest-neighbour rows per item.</li>
 * </ul>
 * {@link #mostSimilar(int, int, TIntSet)} answers from the cosimilarity matrix when it holds enough
 * neighbours for the request, and otherwise falls back to the configured {@link PairwiseSimilarity}
 * using the feature matrices.
 */
public class MostSimilarCache implements Closeable {
    private static final Logger LOG = Logger.getLogger(MostSimilarCache.class.getName());

    public static final String COSIMILARITY_MATRIX = "cosimilarityMatrix";
    public static final String FEATURE_TRANSPOSE_MATRIX = "featureTransposeMatrix";
    public static final String FEATURE_MATRIX = "featureMatrix";

    /** Number of rows between progress log messages while building the cosimilarity matrix. */
    private static final int PROGRESS_INTERVAL = 10000;

    private final PairwiseSimilarity similarity;
    private final File dir;
    private MonolingualSRMetric monoSr;
    // Never assigned: the universal-metric constructor is unsupported and always throws.
    private UniversalSRMetric universalSr;
    private SparseMatrix featureMatrix;
    private SparseMatrix featureTransposeMatrix;
    private SparseMatrix cosimilarityMatrix;

    /**
     * Creates a cache for a monolingual metric without a pairwise similarity.
     * Without one, lookups can only be answered from a cosimilarity matrix.
     *
     * @param monolingualSRMetric metric used for normalization and for building the cosimilarity matrix
     * @param file                cache directory
     */
    public MostSimilarCache(MonolingualSRMetric monolingualSRMetric, File file) {
        this(monolingualSRMetric, (PairwiseSimilarity) null, file);
    }

    /**
     * Creates a cache for a monolingual metric.
     *
     * <p>If {@code file} is not already a directory, whatever exists at that path is deleted
     * and a directory is created in its place.
     *
     * @param monolingualSRMetric metric used for normalization and for building the cosimilarity matrix
     * @param pairwiseSimilarity  similarity computed from the feature matrices; may be {@code null}
     * @param file                cache directory
     */
    public MostSimilarCache(MonolingualSRMetric monolingualSRMetric, PairwiseSimilarity pairwiseSimilarity, File file) {
        this.monoSr = monolingualSRMetric;
        this.similarity = pairwiseSimilarity;
        this.dir = file;
        if (!dir.isDirectory()) {
            FileUtils.deleteQuietly(file);
            if (!file.mkdirs() && !file.isDirectory()) {
                // The constructor cannot throw IOException; writes will fail later with a clear error.
                LOG.warning("could not create cache directory " + file.getAbsolutePath());
            }
        }
    }

    /**
     * Universal-metric caches are not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public MostSimilarCache(UniversalSRMetric universalSRMetric, PairwiseSimilarity pairwiseSimilarity, File file) {
        throw new UnsupportedOperationException();
    }

    /**
     * Reports whether non-empty feature and transpose matrices are loaded.
     * Equivalent to {@link #hasFeatureAndTransposeMatrices()}.
     */
    public boolean hasCachedMostSimilarVectors() {
        return hasFeatureAndTransposeMatrices();
    }

    /** Closes any loaded matrices and deletes all three matrix files from the cache directory. */
    public void clear() {
        close();
        FileUtils.deleteQuietly(getFeatureMatrixPath());
        FileUtils.deleteQuietly(getFeatureTransposeMatrixPath());
        FileUtils.deleteQuietly(getCosimilarityMatrixPath());
    }

    /**
     * Loads the matrices present in the cache directory.
     *
     * <p>Three file combinations are accepted: all three matrices; only the cosimilarity matrix;
     * or only the feature and transpose matrices. Matrices that are not loaded are set to
     * {@code null}. Previously loaded matrices are closed before being replaced.
     *
     * @throws IOException if none of the accepted combinations is present (state is left
     *                     unchanged), or if reading a matrix fails
     */
    public void read() throws IOException {
        boolean all = hasAllReadableMatrices();
        boolean readFeatures = all || hasJustReadableFeatureAndTranspose();
        boolean readCosimilarity = all || hasJustReadableCosimilarity();
        if (!readFeatures && !readCosimilarity) {
            throw new IOException("No readable matrices");
        }
        // Release the old matrices so that re-reading does not leak open handles.
        close();
        if (readFeatures) {
            featureMatrix = readMatrix(FEATURE_MATRIX);
            featureTransposeMatrix = readMatrix(FEATURE_TRANSPOSE_MATRIX);
        }
        if (readCosimilarity) {
            cosimilarityMatrix = readMatrix(COSIMILARITY_MATRIX);
        }
    }

    /** Reports whether {@link #read()} would find a loadable combination of matrix files. */
    public boolean hasReadableMatrices() {
        return hasAllReadableMatrices() || hasJustReadableCosimilarity() || hasJustReadableFeatureAndTranspose();
    }

    /** Returns the file named {@code str} inside the cache directory. */
    protected File getChildFile(String str) {
        return new File(dir, str);
    }

    /** Reports whether {@code str} names an existing regular file inside the cache directory. */
    protected boolean hasChildFile(String str) {
        return getChildFile(str).isFile();
    }

    /** Reports whether all three matrix files exist. */
    protected boolean hasAllReadableMatrices() {
        return hasChildFile(FEATURE_MATRIX) && hasChildFile(FEATURE_TRANSPOSE_MATRIX) && hasChildFile(COSIMILARITY_MATRIX);
    }

    /** Reports whether the feature and transpose files exist and the cosimilarity file does not. */
    protected boolean hasJustReadableFeatureAndTranspose() {
        return hasChildFile(FEATURE_MATRIX) && hasChildFile(FEATURE_TRANSPOSE_MATRIX) && !hasChildFile(COSIMILARITY_MATRIX);
    }

    /** Reports whether only the cosimilarity file exists (neither feature file is present). */
    protected boolean hasJustReadableCosimilarity() {
        return !hasChildFile(FEATURE_MATRIX) && !hasChildFile(FEATURE_TRANSPOSE_MATRIX) && hasChildFile(COSIMILARITY_MATRIX);
    }

    protected File getFeatureMatrixPath() {
        return getChildFile(FEATURE_MATRIX);
    }

    protected File getFeatureTransposeMatrixPath() {
        return getChildFile(FEATURE_TRANSPOSE_MATRIX);
    }

    protected File getCosimilarityMatrixPath() {
        return getChildFile(COSIMILARITY_MATRIX);
    }

    /** @return the loaded feature matrix, or {@code null} if none is loaded */
    public SparseMatrix getFeatureMatrix() {
        return featureMatrix;
    }

    /** @return the loaded transpose of the feature matrix, or {@code null} if none is loaded */
    public SparseMatrix getFeatureTransposeMatrix() {
        return featureTransposeMatrix;
    }

    /** @return the loaded cosimilarity matrix, or {@code null} if none is loaded */
    public SparseMatrix getCosimilarityMatrix() {
        return cosimilarityMatrix;
    }

    /**
     * Finds the items most similar to item {@code i}.
     *
     * <p>The cosimilarity row for {@code i} is used when it has at least {@code i2} columns and,
     * after filtering by {@code tIntSet}, still yields at least {@code i2} results. Otherwise the
     * pairwise similarity is used if it and both feature matrices are available.
     *
     * @param i       id of the query item
     * @param i2      maximum number of results
     * @param tIntSet if non-null, only ids in this set are returned
     * @return normalized results truncated to {@code i2}, or {@code null} if neither source can answer
     */
    public SRResultList mostSimilar(int i, int i2, TIntSet tIntSet) throws IOException, DaoException {
        if (cosimilarityMatrix != null) {
            SparseMatrixRow row = cosimilarityMatrix.getRow(i);
            if (row != null && row.getNumCols() >= i2) {
                SRResultList cached = rowToResultList(row, i2, tIntSet);
                if (cached != null && cached.numDocs() >= i2) {
                    return normalize(cached, i2);
                }
            }
        }
        if (!canComputeFromFeatures()) {
            return null;
        }
        return normalize(similarity.mostSimilar(this, i, i2, tIntSet), i2);
    }

    /**
     * Finds the items most similar to an arbitrary feature vector using the pairwise similarity.
     *
     * @param tIntFloatMap query vector
     * @param i            maximum number of results
     * @param tIntSet      if non-null, only ids in this set are returned
     * @return normalized results truncated to {@code i}, or {@code null} if the pairwise similarity
     *         or either feature matrix is unavailable
     */
    public SRResultList mostSimilar(TIntFloatMap tIntFloatMap, int i, TIntSet tIntSet) throws IOException, DaoException {
        if (!canComputeFromFeatures()) {
            return null;
        }
        return normalize(similarity.mostSimilar(this, tIntFloatMap, i, tIntSet), i);
    }

    /** Quietly closes every loaded matrix and clears the references; safe to call repeatedly. */
    @Override
    public void close() {
        IOUtils.closeQuietly(featureMatrix);
        IOUtils.closeQuietly(featureTransposeMatrix);
        IOUtils.closeQuietly(cosimilarityMatrix);
        featureMatrix = null;
        featureTransposeMatrix = null;
        cosimilarityMatrix = null;
    }

    /**
     * Writes the feature matrix for the given ids, writes its transpose, and loads both.
     *
     * @param iArr ids whose feature vectors are written
     * @param i    number of worker threads
     */
    public void writeFeatureAndTransposeMatrix(int[] iArr, int i) throws WikiBrainException, InterruptedException, IOException {
        ensureDataDirectoryExists();
        // Release the matrices backed by the files about to be rewritten, as writeCosimilarity does.
        IOUtils.closeQuietly(featureMatrix);
        IOUtils.closeQuietly(featureTransposeMatrix);
        featureMatrix = null;
        featureTransposeMatrix = null;

        final SparseMatrixWriter writer = new SparseMatrixWriter(getFeatureMatrixPath(), new ValueConf());
        ParallelForEach.loop(intArrayToList(iArr), i, new Procedure<Integer>() {
            @Override
            public void call(Integer wpId) throws IOException, DaoException, WikiBrainException {
                writeFeatureVector(writer, wpId);
            }
        }, 10000);
        writer.finish();

        featureMatrix = readMatrix(FEATURE_MATRIX);
        new SparseMatrixTransposer(featureMatrix, getFeatureTransposeMatrixPath()).transpose();
        featureTransposeMatrix = readMatrix(FEATURE_TRANSPOSE_MATRIX);
    }

    /**
     * Builds and loads the cosimilarity matrix: one row of nearest neighbours per id in {@code iArr}.
     *
     * <p>While rows are computed, the metric's similarity and most-similar normalizers are temporarily
     * replaced with {@link IdentityNormalizer}s, so raw scores are stored. The original normalizers are
     * always restored, even on failure.
     *
     * @param iArr ids whose rows are written
     * @param iArr2 if non-null, neighbours are restricted to these ids
     * @param i    maximum neighbours per row
     * @param i2   number of worker threads
     */
    public void writeCosimilarity(int[] iArr, int[] iArr2, final int i, int i2) throws IOException, InterruptedException {
        ensureDataDirectoryExists();
        IOUtils.closeQuietly(cosimilarityMatrix);
        cosimilarityMatrix = null;

        final AtomicInteger rowCounter = new AtomicInteger();
        final AtomicLong cellCounter = new AtomicLong();
        ValueConf valueConf = (similarity == null)
                ? new ValueConf()
                : new ValueConf((float) similarity.getMinValue(), (float) similarity.getMaxValue());
        final SparseMatrixWriter writer = new SparseMatrixWriter(getCosimilarityMatrixPath(), valueConf);
        final TIntSet candidateIds = (iArr2 == null) ? null : new TIntHashSet(iArr2);

        Normalizer savedSimilarityNormalizer = monoSr.getSimilarityNormalizer();
        Normalizer savedMostSimilarNormalizer = monoSr.getMostSimilarNormalizer();
        monoSr.setMostSimilarNormalizer(new IdentityNormalizer());
        monoSr.setSimilarityNormalizer(new IdentityNormalizer());
        try {
            ParallelForEach.loop(intArrayToList(iArr), i2, new Procedure<Integer>() {
                @Override
                public void call(Integer wpId) throws IOException, DaoException {
                    writeSim(writer, wpId, candidateIds, i, rowCounter, cellCounter);
                }
            }, Integer.MAX_VALUE);
        } finally {
            monoSr.setSimilarityNormalizer(savedSimilarityNormalizer);
            monoSr.setMostSimilarNormalizer(savedMostSimilarNormalizer);
        }
        LOG.info("wrote " + cellCounter.get() + " non-zero similarity cells");
        writer.finish();
        cosimilarityMatrix = readMatrix(COSIMILARITY_MATRIX);
    }

    /** Opens the matrix file {@code str} from the cache directory. */
    private SparseMatrix readMatrix(String str) throws IOException {
        return new SparseMatrix(getChildFile(str));
    }

    /** True when lookups can be computed from feature vectors via the pairwise similarity. */
    private boolean canComputeFromFeatures() {
        return similarity != null && featureMatrix != null && featureTransposeMatrix != null;
    }

    /**
     * Computes and writes one cosimilarity row. Invoked concurrently by {@link #writeCosimilarity}.
     * Uses the pairwise similarity when it and both feature matrices are available, and the
     * monolingual metric otherwise.
     *
     * @param atomicInteger shared count of processed rows (drives progress logging)
     * @param atomicLong    shared count of written cells
     */
    public void writeSim(SparseMatrixWriter sparseMatrixWriter, Integer num, TIntSet tIntSet, int i, AtomicInteger atomicInteger, AtomicLong atomicLong) throws IOException, DaoException {
        // Use the value returned by the increment; re-reading the counter races with other workers.
        int processed = atomicInteger.incrementAndGet();
        if (processed % PROGRESS_INTERVAL == 0) {
            LOG.info("finding matches for page " + processed);
        }
        int wpId = num.intValue();
        SRResultList neighbors = canComputeFromFeatures()
                ? similarity.mostSimilar(this, wpId, i, tIntSet)
                : monoSr.mostSimilar(wpId, i, tIntSet);
        if (neighbors == null) {
            return;
        }
        int[] ids = neighbors.getIds();
        // Count cells (not rows): the total is reported as "non-zero similarity cells".
        atomicLong.addAndGet(ids.length);
        sparseMatrixWriter.writeRow(new SparseMatrixRow(sparseMatrixWriter.getValueConf(), wpId, ids, neighbors.getScoresAsFloat()));
    }

    /**
     * Writes the feature vector of one item as a matrix row; empty or missing vectors are skipped.
     *
     * @throws UnsupportedOperationException if a monolingual metric is configured
     * @throws IllegalStateException         if no universal metric is configured
     * @throws WikiBrainException            wrapping DAO or write failures
     */
    public void writeFeatureVector(SparseMatrixWriter sparseMatrixWriter, Integer num) throws WikiBrainException {
        if (monoSr != null) {
            throw new UnsupportedOperationException();
        }
        if (universalSr == null) {
            throw new IllegalStateException("SRFeatureMatrixWriter does not have a local or universal metric defined.");
        }
        try {
            TIntDoubleMap vector = universalSr.getVector(num.intValue());
            if (vector == null || vector.isEmpty()) {
                return;
            }
            LinkedHashMap<Integer, Float> row = new LinkedHashMap<Integer, Float>();
            for (int id : vector.keys()) {
                row.put(Integer.valueOf(id), Float.valueOf((float) vector.get(id)));
            }
            sparseMatrixWriter.writeRow(new SparseMatrixRow(sparseMatrixWriter.getValueConf(), num.intValue(), row));
        } catch (DaoException e) {
            throw new WikiBrainException(e);
        } catch (IOException e) {
            throw new WikiBrainException(e);
        }
    }

    /**
     * Returns the top {@code i} entries of a matrix row, sorted by descending score.
     * If {@code tIntSet} is non-null, only columns it contains are kept.
     */
    private SRResultList rowToResultList(MatrixRow matrixRow, int i, TIntSet tIntSet) {
        Leaderboard leaderboard = new Leaderboard(i);
        int numCols = matrixRow.getNumCols();
        for (int col = 0; col < numCols; col++) {
            int id = matrixRow.getColIndex(col);
            if (tIntSet == null || tIntSet.contains(id)) {
                leaderboard.tallyScore(id, matrixRow.getColValue(col));
            }
        }
        SRResultList top = leaderboard.getTop();
        top.sortDescending();
        return top;
    }

    /** Boxes an int array into a new mutable list, preserving order. */
    private List<Integer> intArrayToList(int[] iArr) {
        List<Integer> list = new ArrayList<Integer>(iArr.length);
        for (int value : iArr) {
            list.add(Integer.valueOf(value));
        }
        return list;
    }

    /**
     * Creates the cache directory if it does not exist.
     *
     * @throws IOException if the directory cannot be created
     */
    private void ensureDataDirectoryExists() throws IOException {
        // Re-check isDirectory(): mkdirs() also returns false if another thread created it first.
        if (!dir.isDirectory() && !dir.mkdirs() && !dir.isDirectory()) {
            throw new IOException("Could not create cache directory " + dir.getAbsolutePath());
        }
    }

    /**
     * Sorts {@code sRResultList} in place (descending), applies the metric's most-similar normalizer,
     * and truncates the result to {@code i} entries.
     *
     * @return the normalized list, or {@code null} if the input is {@code null}
     */
    private SRResultList normalize(SRResultList sRResultList, int i) {
        if (sRResultList == null) {
            return null;
        }
        sRResultList.sortDescending();
        SRResultList normalized = monoSr.getMostSimilarNormalizer().normalize(sRResultList);
        if (normalized.numDocs() > i) {
            normalized.truncate(i);
        }
        return normalized;
    }

    /** Reports whether non-empty feature and transpose matrices are loaded. */
    public boolean hasFeatureAndTransposeMatrices() {
        return featureMatrix != null && featureMatrix.getNumRows() > 0
                && featureTransposeMatrix != null && featureTransposeMatrix.getNumRows() > 0;
    }

    /** Reports whether a non-empty cosimilarity matrix is loaded. */
    public boolean hasCosimilarityMatrix() {
        return cosimilarityMatrix != null && cosimilarityMatrix.getNumRows() > 0;
    }
}
