package org.wikibrain.sr.word2vec;

import com.typesafe.config.Config;
import gnu.trove.map.TLongIntMap;
import gnu.trove.map.hash.TLongIntHashMap;
import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.EOFException;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.LocalPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.matrix.DenseMatrix;
import org.wikibrain.matrix.DenseMatrixRow;
import org.wikibrain.matrix.DenseMatrixWriter;
import org.wikibrain.matrix.ValueConf;
import org.wikibrain.sr.Explanation;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.vector.DenseVectorGenerator;
import org.wikibrain.utils.WpIOUtils;

/**
 * A {@link DenseVectorGenerator} backed by a binary word2vec model file.
 *
 * <p>The model file is expected to start with a text header line {@code "<count> <dim>\n"},
 * followed by {@code count} records. Each record is a space-terminated UTF-8 word followed by
 * {@code dim} little-endian 32-bit floats.
 *
 * <p>On first use the model is split into three cache files stored next to it:
 * <ul>
 *   <li>{@code <model>.articles.matrix}: vectors whose word starts with {@code /w/}, keyed by
 *       the integer at index 3 of {@code word.split("/", 5)} (presumably a local page id; verify
 *       against the model builder);</li>
 *   <li>{@code <model>.phrases.matrix}: all other vectors, keyed by a sequential phrase index;</li>
 *   <li>{@code <model>.phrases.txt}: {@code index \t phrase} lines mapping indexes to phrases.</li>
 * </ul>
 * Every cached vector is L2-normalized. The caches are rebuilt whenever any of them is missing
 * or older than the model file.
 */
public class Word2VecGenerator implements DenseVectorGenerator {
    private static final Logger LOG = LoggerFactory.getLogger(Word2VecGenerator.class);

    /** Fallback model path used by {@link #main(String[])} when no argument is given. */
    private static final String DEFAULT_DEBUG_MODEL =
            "/Users/a558989/Projects/wikibrain/base-bh/dat/word2vecRaw/bh.bin";

    private final Language language;
    private final LocalPageDao localPageDao;
    /** The binary word2vec model file; the cache files are derived from its path. */
    private final File path;
    /** Maps {@code hashWord(phrase)} to the phrase's row id in {@link #phraseMatrix}. */
    private TLongIntMap phraseIds;
    private DenseMatrix phraseMatrix;
    private DenseMatrix articleMatrix;

    /**
     * Configuration provider that builds a {@link Word2VecGenerator} for entries whose
     * {@code type} is {@code word2vec} under {@code sr.metric.densegenerator}.
     */
    public static class Provider extends org.wikibrain.conf.Provider<DenseVectorGenerator> {
        public Provider(Configurator configurator, Configuration configuration) throws ConfigurationException {
            super(configurator, configuration);
        }

        public Class<DenseVectorGenerator> getType() {
            return DenseVectorGenerator.class;
        }

        public String getPath() {
            return "sr.metric.densegenerator";
        }

        /**
         * Builds a generator for the language given in the {@code language} runtime parameter.
         *
         * @param str    name of the configuration entry (unused)
         * @param config configuration entry; must contain {@code type} and {@code modelDir}
         * @param map    runtime parameters; must contain {@code language}
         * @return the generator, or {@code null} if {@code type} is not {@code word2vec}
         * @throws IllegalArgumentException if the {@code language} runtime parameter is missing
         * @throws ConfigurationException   if the model file is missing or its caches cannot be built
         */
        public DenseVectorGenerator get(String str, Config config, Map<String, String> map) throws ConfigurationException {
            if (!"word2vec".equals(config.getString("type"))) {
                return null;
            }
            if (!map.containsKey("language")) {
                throw new IllegalArgumentException("Monolingual SR Metric requires 'language' runtime parameter");
            }
            Language lang = Language.getByLangCode(map.get("language"));
            File modelFile = Word2VecGenerator.getModelFile(config.getString("modelDir"), lang);
            if (!modelFile.isFile()) {
                throw new ConfigurationException("Path to word2vec model " + modelFile.getAbsolutePath() + " is not a file. Do you need to download or build the model?");
            }
            try {
                return new Word2VecGenerator(lang, (LocalPageDao) getConfigurator().get(LocalPageDao.class), modelFile);
            } catch (IOException e) {
                throw new ConfigurationException(e);
            }
        }

        /**
         * Raw-typed overload left over from decompiling the compiler-generated bridge method.
         * It only delegates to {@link #get(String, Config, Map)}.
         */
        @SuppressWarnings("unchecked") // callers are expected to pass a String->String map
        public Object m83get(String str, Config config, Map map) throws ConfigurationException {
            return get(str, config, (Map<String, String>) map);
        }
    }

    /**
     * Creates a generator for the given model, building the caches first if they are stale.
     *
     * @param language     language of the model (not used by this class directly)
     * @param localPageDao page DAO (not used by this class directly)
     * @param file         binary word2vec model file
     * @throws IOException if the model or its caches cannot be read or written
     */
    public Word2VecGenerator(Language language, LocalPageDao localPageDao, File file) throws IOException {
        this.language = language;
        this.localPageDao = localPageDao;
        this.path = file;
        read();
    }

    /**
     * Loads the phrase and article caches, rebuilding them from the model if any is stale.
     *
     * @throws IOException if the model or the caches cannot be read or written
     */
    public void read() throws IOException {
        if (isCacheStale()) {
            createWikiBrainModel();
            return;
        }
        LOG.info("phrase and article caches are up to date, loading them...");
        loadCaches();
    }

    /** Returns true if any cache file is missing or older than the model file. */
    private boolean isCacheStale() {
        long modelTime = this.path.lastModified();
        File[] caches = { getArticleMatrixPath(), getPhraseMatrixPath(), getPhraseIdPath() };
        for (File cache : caches) {
            if (!cache.exists() || cache.lastModified() < modelTime) {
                return true;
            }
        }
        return false;
    }

    /** Opens both matrix caches and reads the phrase-id map. */
    private void loadCaches() throws IOException {
        this.phraseMatrix = new DenseMatrix(getPhraseMatrixPath());
        this.articleMatrix = new DenseMatrix(getArticleMatrixPath());
        readPhraseIds();
    }

    /** Reads {@code <model>.phrases.txt} into {@link #phraseIds}, keyed by the phrase hash. */
    private void readPhraseIds() throws IOException {
        TLongIntMap ids = new TLongIntHashMap();
        try (BufferedReader reader = WpIOUtils.openBufferedReader(getPhraseIdPath())) {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] fields = line.split("\t", 2);
                if (fields.length < 2) {
                    LOG.warn("skipping malformed line in {}: '{}'", getPhraseIdPath(), line);
                    continue;
                }
                ids.put(hashWord(fields[1].trim()), Integer.parseInt(fields[0]));
            }
        }
        this.phraseIds = ids;
    }

    /**
     * Rebuilds all three caches from the binary model, then loads them.
     * If the conversion fails, any partially written cache files are deleted.
     */
    private void createWikiBrainModel() throws IOException {
        deleteCaches();
        boolean converted = false;
        try {
            convertModel();
            converted = true;
        } finally {
            if (!converted) {
                // read() only checks existence and timestamps, so a partial cache left
                // behind would later be loaded as if it were complete.
                deleteCaches();
            }
        }
        loadCaches();
    }

    /** Deletes all three cache files, ignoring files that do not exist. */
    private void deleteCaches() {
        FileUtils.deleteQuietly(getPhraseIdPath());
        FileUtils.deleteQuietly(getPhraseMatrixPath());
        FileUtils.deleteQuietly(getArticleMatrixPath());
    }

    /** Streams the binary model into the phrase matrix, article matrix and phrase-id file. */
    private void convertModel() throws IOException {
        ValueConf valueConf = new ValueConf();
        DenseMatrixWriter phraseWriter = new DenseMatrixWriter(getPhraseMatrixPath(), valueConf);
        DenseMatrixWriter articleWriter = new DenseMatrixWriter(getArticleMatrixPath(), valueConf);
        // The model is read one byte at a time; WpIOUtils may already buffer, and an extra
        // buffer is harmless.
        try (BufferedWriter idWriter = WpIOUtils.openWriter(getPhraseIdPath());
             DataInputStream in = new DataInputStream(
                     new BufferedInputStream(WpIOUtils.openInputStream(this.path)))) {
            String header = readHeaderLine(in);
            String[] dims = header.trim().split("\\s+");
            if (dims.length < 2) {
                throw new IOException("malformed word2vec header in " + this.path + ": '" + header + "'");
            }
            int numVectors = Integer.parseInt(dims[0]);
            int dim = Integer.parseInt(dims[1]);
            LOG.info("preparing to read {} vectors of length {}", numVectors, dim);

            // Every row is dense, so all rows share the same column ids 0..dim-1.
            int[] colIds = new int[dim];
            for (int c = 0; c < dim; c++) {
                colIds[c] = c;
            }

            int phraseCount = 0;
            int articleCount = 0;
            for (int i = 0; i < numVectors; i++) {
                String word = readString(in);
                if (i % 5000 == 0) {
                    LOG.info("Read word vector {} ({} of {})", word, i, numVectors);
                }
                float[] vector = readNormalizedVector(in, dim);
                if (word.startsWith("/w/")) {
                    int articleId = Integer.parseInt(word.split("/", 5)[3]);
                    if (articleId >= 0) {
                        articleWriter.writeRow(new DenseMatrixRow(valueConf, articleId, colIds, vector));
                        articleCount++;
                    }
                } else {
                    // Tabs and newlines would break the tab-separated phrase-id file.
                    String phrase = word.replace('\t', ' ').replace('\n', ' ');
                    phraseWriter.writeRow(new DenseMatrixRow(valueConf, phraseCount, colIds, vector));
                    idWriter.write(phraseCount + "\t" + phrase + "\n");
                    phraseCount++;
                }
            }

            // Write a single all-zero row into any matrix that would otherwise be empty.
            if (phraseCount == 0) {
                phraseWriter.writeRow(new DenseMatrixRow(valueConf, 0, colIds, new float[dim]));
            }
            if (articleCount == 0) {
                articleWriter.writeRow(new DenseMatrixRow(valueConf, 0, colIds, new float[dim]));
            }
        }
        phraseWriter.finish();
        articleWriter.finish();
    }

    /**
     * Reads the model's header line up to, but not including, the terminating {@code '\n'}.
     *
     * @throws EOFException if the stream ends before a newline is found
     */
    private static String readHeaderLine(DataInputStream in) throws IOException {
        StringBuilder header = new StringBuilder();
        int b;
        while ((b = in.read()) != '\n') {
            if (b == -1) {
                throw new EOFException("word2vec model ended before the end of its header line");
            }
            header.append((char) b);
        }
        return header.toString();
    }

    /**
     * Reads {@code dim} floats and scales them to unit L2 norm.
     * An all-zero vector is returned unchanged instead of being turned into NaNs.
     */
    private static float[] readNormalizedVector(DataInputStream in, int dim) throws IOException {
        float[] vector = new float[dim];
        double sumOfSquares = 0.0;
        for (int i = 0; i < dim; i++) {
            float v = readFloat(in);
            vector[i] = v;
            sumOfSquares += (double) v * v;
        }
        double norm = Math.sqrt(sumOfSquares);
        if (norm > 0.0) {
            for (int i = 0; i < dim; i++) {
                vector[i] = (float) (vector[i] / norm);
            }
        }
        return vector;
    }

    private File getPhraseMatrixPath() {
        return new File(this.path.getAbsolutePath() + ".phrases.matrix");
    }

    private File getArticleMatrixPath() {
        return new File(this.path.getAbsolutePath() + ".articles.matrix");
    }

    private File getPhraseIdPath() {
        return new File(this.path.getAbsolutePath() + ".phrases.txt");
    }

    /**
     * Reads one word and decodes it as UTF-8.
     *
     * <p>Reading stops at the next space byte or at end of stream; the space is consumed but
     * not included. Newline bytes are skipped; in this format they precede each word after
     * the previous record's vector.
     */
    private static String readString(DataInputStream in) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        int b;
        while ((b = in.read()) != -1 && b != ' ') {
            if (b != '\n') {
                bytes.write(b);
            }
        }
        return new String(bytes.toByteArray(), StandardCharsets.UTF_8);
    }

    /**
     * Reads one little-endian 32-bit float.
     *
     * @throws EOFException if fewer than four bytes remain
     */
    private static float readFloat(DataInputStream in) throws IOException {
        // DataInputStream.readInt is big-endian, so the bytes are reversed for little-endian data.
        return Float.intBitsToFloat(Integer.reverseBytes(in.readInt()));
    }

    /** Returns the article matrix (rows keyed by article id). */
    @Override // org.wikibrain.sr.vector.DenseVectorGenerator
    public DenseMatrix getFeatureMatrix() {
        return this.articleMatrix;
    }

    /**
     * Returns the cached, normalized vector for the article row {@code i}.
     *
     * @return the vector, or {@code null} if the article matrix has no such row
     * @throws DaoException if the matrix cannot be read
     */
    @Override // org.wikibrain.sr.vector.DenseVectorGenerator
    public float[] getVector(int i) throws DaoException {
        try {
            DenseMatrixRow row = this.articleMatrix.getRow(i);
            return row == null ? null : row.getValues();
        } catch (IOException e) {
            throw new DaoException(e);
        }
    }

    /**
     * Returns the cached, normalized vector for a phrase.
     * Underscores are treated as spaces and surrounding whitespace is ignored.
     *
     * @return the vector, or {@code null} if the phrase is unknown
     */
    @Override // org.wikibrain.sr.vector.DenseVectorGenerator
    public float[] getVector(String str) {
        try {
            long hash = hashWord(str);
            // Trove returns 0 for a missing key, and 0 is a valid row id, so check first.
            if (!this.phraseIds.containsKey(hash)) {
                return null;
            }
            DenseMatrixRow row = this.phraseMatrix.getRow(this.phraseIds.get(hash));
            return row == null ? null : row.getValues();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    private static long hashWord(String str) {
        return Word2VecUtils.hashWord(normalize(str));
    }

    /** Not supported for phrases. */
    @Override // org.wikibrain.sr.vector.DenseVectorGenerator
    public List<Explanation> getExplanations(String str, String str2, float[] fArr, float[] fArr2, SRResult sRResult) throws DaoException {
        throw new UnsupportedOperationException();
    }

    /** Explanations are not produced for articles; always returns {@code null}. */
    @Override // org.wikibrain.sr.vector.DenseVectorGenerator
    public List<Explanation> getExplanations(int i, int i2, float[] fArr, float[] fArr2, SRResult sRResult) throws DaoException {
        return null;
    }

    /** Maps underscores to spaces and trims, so {@code "New_York"} and {@code "New York"} match. */
    private static String normalize(String str) {
        return str.replace('_', ' ').trim();
    }

    /** Returns {@code <dir>/<langCode>.bin}. */
    public static File getModelFile(String str, Language language) {
        return getModelFile(new File(str), language);
    }

    /** Returns {@code <dir>/<langCode>.bin}. */
    public static File getModelFile(File file, Language language) {
        return new File(file, language.getLangCode() + ".bin");
    }

    /**
     * Debug entry point: builds (or validates) the caches for a model file.
     *
     * @param strArr optional model path; defaults to the original developer's local path
     */
    public static void main(String[] strArr) throws IOException {
        String modelPath = strArr.length > 0 ? strArr[0] : DEFAULT_DEBUG_MODEL;
        new Word2VecGenerator(null, null, new File(modelPath));
    }
}
