package org.wikibrain.sr.word2vec;

import com.typesafe.config.Config;
import gnu.trove.list.array.TCharArrayList;
import gnu.trove.map.TIntFloatMap;
import gnu.trove.map.TIntObjectMap;
import gnu.trove.map.hash.TIntFloatHashMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.EOFException;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.LocalPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.model.LocalPage;
import org.wikibrain.sr.Explanation;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.vector.VectorGenerator;
import org.wikibrain.utils.ObjectDb;
import org.wikibrain.utils.WpIOUtils;

/* loaded from: input_file:org/wikibrain/sr/word2vec/Word2VecGenerator.class */
public class Word2VecGenerator implements VectorGenerator {
    private static final Logger LOG = Logger.getLogger(Word2VecGenerator.class.getName());
    private final Language language;
    private final LocalPageDao localPageDao;
    private ObjectDb<float[]> phraseDb;
    private TIntObjectMap<float[]> articles;

    /* loaded from: input_file:org/wikibrain/sr/word2vec/Word2VecGenerator$Provider.class */
    public static class Provider extends org.wikibrain.conf.Provider<VectorGenerator> {
        public Provider(Configurator configurator, Configuration configuration) throws ConfigurationException {
            super(configurator, configuration);
        }

        public Class<VectorGenerator> getType() {
            return VectorGenerator.class;
        }

        public String getPath() {
            return "sr.metric.generator";
        }

        public VectorGenerator get(String str, Config config, Map<String, String> map) throws ConfigurationException {
            if (!config.getString("type").equals("word2vec")) {
                return null;
            }
            if (!map.containsKey("language")) {
                throw new IllegalArgumentException("Monolingual SR Metric requires 'language' runtime parameter");
            }
            Language byLangCode = Language.getByLangCode(map.get("language"));
            File modelFile = Word2VecGenerator.getModelFile(config.getString("modelDir"), byLangCode);
            if (!modelFile.isFile()) {
                throw new ConfigurationException("Path to word2vec model " + modelFile.getAbsolutePath() + " is not a file. Do you need to download or build the model?");
            }
            try {
                return new Word2VecGenerator(byLangCode, (LocalPageDao) getConfigurator().get(LocalPageDao.class), modelFile);
            } catch (IOException e) {
                throw new ConfigurationException(e);
            }
        }

        /* renamed from: get, reason: collision with other method in class */
        public /* bridge */ /* synthetic */ Object m63get(String str, Config config, Map map) throws ConfigurationException {
            return get(str, config, (Map<String, String>) map);
        }
    }

    public Word2VecGenerator(Language language, LocalPageDao localPageDao, File file) throws IOException {
        this.language = language;
        this.localPageDao = localPageDao;
        read(file);
    }

    public void read(File file) throws IOException {
        File file2 = new File(file.getAbsolutePath() + ".phrases");
        File file3 = new File(file.getAbsolutePath() + ".articles");
        if (!file2.exists() || !file3.exists() || file2.lastModified() < file.lastModified() || file3.lastModified() < file.lastModified()) {
            createWikiBrainModel(file, file2, file3);
            return;
        }
        LOG.info("phrase and article caches are up to date, loading them...");
        this.phraseDb = new ObjectDb<>(file2);
        this.articles = (TIntObjectMap) WpIOUtils.readObjectFromFile(file3);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v40, types: [float[], java.lang.Object, java.io.Serializable] */
    private void createWikiBrainModel(File file, File file2, File file3) throws IOException {
        FileUtils.deleteQuietly(file2);
        FileUtils.deleteQuietly(file3);
        this.phraseDb = new ObjectDb<>(file2, true);
        this.articles = new TIntObjectHashMap();
        DataInputStream dataInputStream = null;
        InputStream inputStream = null;
        try {
            inputStream = WpIOUtils.openInputStream(file);
            dataInputStream = new DataInputStream(inputStream);
            String str = "";
            while (true) {
                char read = (char) dataInputStream.read();
                if (read == '\n') {
                    break;
                } else {
                    str = str + read;
                }
            }
            String[] split = str.split(" ");
            int parseInt = Integer.parseInt(split[0]);
            int parseInt2 = Integer.parseInt(split[1]);
            LOG.info("preparing to read " + parseInt + " with length " + parseInt2 + " vectors");
            for (int i = 0; i < parseInt; i++) {
                String readString = readString(dataInputStream);
                if (i % 50000 == 0) {
                    LOG.info("Read word vector " + readString + " (" + i + " of " + parseInt + ")");
                }
                ?? r0 = new float[parseInt2];
                double d = 0.0d;
                for (int i2 = 0; i2 < parseInt2; i2++) {
                    d += r0 * r0;
                    r0[i2] = readFloat(dataInputStream);
                }
                double sqrt = Math.sqrt(d);
                for (int i3 = 0; i3 < parseInt2; i3++) {
                    r0[i3] = (float) (r0[r1] / sqrt);
                }
                if (readString.startsWith("/w/")) {
                    this.articles.put(Integer.valueOf(readString.split("/", 5)[3]).intValue(), (Object) r0);
                } else {
                    this.phraseDb.put(normalize(readString), (Serializable) r0);
                }
            }
            IOUtils.closeQuietly(inputStream);
            IOUtils.closeQuietly(dataInputStream);
            this.phraseDb.flush();
            WpIOUtils.writeObjectToFile(file3, this.articles);
        } catch (Throwable th) {
            IOUtils.closeQuietly(inputStream);
            IOUtils.closeQuietly(dataInputStream);
            throw th;
        }
    }

    private static String readString(DataInputStream dataInputStream) throws IOException {
        char c;
        TCharArrayList tCharArrayList = new TCharArrayList();
        while (true) {
            int read = dataInputStream.read();
            if (read >= 0 && (c = (char) read) != ' ') {
                if (c != '\n') {
                    tCharArrayList.add(c);
                }
            }
        }
        return new String(tCharArrayList.toArray());
    }

    private static float readFloat(InputStream inputStream) throws IOException {
        byte[] bArr = new byte[4];
        inputStream.read(bArr);
        return getFloat(bArr);
    }

    private static float getFloat(byte[] bArr) {
        return Float.intBitsToFloat(0 | ((bArr[0] & 255) << 0) | ((bArr[1] & 255) << 8) | ((bArr[2] & 255) << 16) | ((bArr[3] & 255) << 24));
    }

    @Override // org.wikibrain.sr.vector.VectorGenerator
    public TIntFloatMap getVector(int i) throws DaoException {
        float[] fArr = (float[]) this.articles.get(i);
        if (fArr == null) {
            return null;
        }
        TIntFloatHashMap tIntFloatHashMap = new TIntFloatHashMap(fArr.length);
        for (int i2 = 0; i2 < fArr.length; i2++) {
            tIntFloatHashMap.put(i2, fArr[i2]);
        }
        return tIntFloatHashMap;
    }

    @Override // org.wikibrain.sr.vector.VectorGenerator
    public TIntFloatMap getVector(String str) {
        try {
            float[] fArr = (float[]) this.phraseDb.get(normalize(str));
            if (fArr == null) {
                return null;
            }
            TIntFloatHashMap tIntFloatHashMap = new TIntFloatHashMap(fArr.length);
            for (int i = 0; i < fArr.length; i++) {
                tIntFloatHashMap.put(i, fArr[i]);
            }
            return tIntFloatHashMap;
        } catch (IOException e) {
            throw new RuntimeException(e);
        } catch (ClassNotFoundException e2) {
            throw new RuntimeException(e2);
        }
    }

    @Override // org.wikibrain.sr.vector.VectorGenerator
    public List<Explanation> getExplanations(LocalPage localPage, LocalPage localPage2, TIntFloatMap tIntFloatMap, TIntFloatMap tIntFloatMap2, SRResult sRResult) throws DaoException {
        return null;
    }

    private static String normalize(String str) {
        return str.replace('_', ' ').trim();
    }

    public static File getModelFile(String str, Language language) {
        return getModelFile(new File(str), language);
    }

    public static File getModelFile(File file, Language language) {
        return new File(file, language.getLangCode() + ".bin");
    }
}
