package org.wikibrain.sr;

import com.typesafe.config.Config;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.logging.Logger;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.commons.cli.PosixParser;
import org.apache.commons.io.FileUtils;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.DefaultOptionBuilder;
import org.wikibrain.core.WikiBrainException;
import org.wikibrain.core.cmd.Env;
import org.wikibrain.core.cmd.EnvBuilder;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.LocalLinkDao;
import org.wikibrain.core.dao.LocalPageDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.phrases.LinkProbabilityDao;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.dataset.DatasetDao;
import org.wikibrain.sr.ensemble.EnsembleMetric;
import org.wikibrain.sr.esa.SRConceptSpaceGenerator;
import org.wikibrain.sr.milnewitten.MilneWittenMetric;
import org.wikibrain.sr.wikify.Corpus;
import org.wikibrain.sr.word2vec.Word2VecGenerator;
import org.wikibrain.sr.word2vec.Word2VecTrainer;
import org.wikibrain.utils.WpIOUtils;

/* loaded from: input_file:org/wikibrain/sr/SRBuilder.class */
/**
 * Builds (trains and writes) a semantic-relatedness metric together with every
 * submetric it depends on.
 *
 * <p>Typical use: construct the builder for a metric name, optionally adjust
 * it through the setters, then call {@link #build()}. {@link #main(String[])}
 * exposes the same knobs as command-line options.
 */
public class SRBuilder {
    private static final Logger LOG = Logger.getLogger(SRBuilder.class.getName());

    private final Env env;
    private final Configuration config;

    /** Language the metrics are built for; defaults to the environment's default language. */
    private Language language;

    /** Root directory for metric data, read from the {@code sr.metric.path} setting. */
    private final File srDir;

    /** Resolved name of the top-level metric being built. */
    private String metricName;

    /** Names of the datasets (or dataset groups) merged into the training dataset. */
    private List<String> datasetNames;

    /** When true, {@link #build()} first deletes the data directories of the metric and its submetrics. */
    private boolean deleteExistingData = true;

    /** Base result count; mostSimilar training uses twice this value unless the metric config overrides it. */
    private int maxResults = 500;

    /** Passed to {@code BaseSRMetric.setBuildMostSimilarCache} for metrics of that type. */
    private boolean buildCosimilarity = false;

    // NOTE(review): rowIds and colIds are only stored by their setters; nothing in this class reads them.
    private TIntSet rowIds = null;
    private TIntSet colIds = null;

    /** When true, training steps whose metric reports itself as already trained are skipped. */
    private boolean skipBuiltMetrics = false;

    /** Ids handed to mostSimilar training unless the metric config names its own id file; may be null. */
    private TIntSet validMostSimilarIds = null;

    /** Which training steps {@link #buildMetric(String)} performs. */
    private Mode mode = Mode.BOTH;

    /** Selects which parts of a metric are trained. */
    public enum Mode {
        /** Train only {@code similarity()}. */
        SIMILARITY,
        /** Train only {@code mostSimilar()}. */
        MOSTSIMILAR,
        /** Train both {@code similarity()} and {@code mostSimilar()}. */
        BOTH
    }

    /**
     * Creates a builder for the given metric.
     *
     * @param env environment supplying configuration, configurator and default language
     * @param str metric name passed to the configurator's {@code resolveComponentName};
     *            {@link #main(String[])} passes {@code null} when no metric option is given
     * @throws ConfigurationException if the metric name cannot be resolved
     */
    public SRBuilder(Env env, String str) throws ConfigurationException {
        this.env = env;
        this.language = env.getLanguages().getDefaultLanguage();
        this.config = env.getConfiguration();
        this.srDir = new File(this.config.get().getString("sr.metric.path"));
        this.datasetNames = this.config.get().getStringList("sr.dataset.defaultsets");
        this.metricName = env.getConfigurator().resolveComponentName(SRMetric.class, str);
        // Best effort, as before: a failure here is only reported, not raised.
        if (!this.srDir.isDirectory() && !this.srDir.mkdirs()) {
            LOG.warning("could not create metric directory " + this.srDir);
        }
    }

    /**
     * Returns the top-level metric for the current language.
     *
     * @throws ConfigurationException if the metric cannot be obtained from the configurator
     */
    public synchronized SRMetric getMetric() throws ConfigurationException {
        return getMetric(this.metricName);
    }

    /**
     * Returns the named metric for the current language.
     *
     * @param str metric name
     * @throws ConfigurationException if the metric cannot be obtained from the configurator
     */
    public synchronized SRMetric getMetric(String str) throws ConfigurationException {
        return (SRMetric) this.env.getConfigurator().get(SRMetric.class, str, "language", this.language.getLangCode());
    }

    /**
     * Builds the metric and all of its submetrics.
     *
     * <p>Steps, in order: optionally delete existing data, build the concept
     * file if some submetric needs it, initialise every submetric, then build
     * every submetric (dependencies come before the metrics that use them).
     */
    public void build() throws ConfigurationException, DaoException, IOException, WikiBrainException {
        if (this.deleteExistingData) {
            deleteDataDirectories();
        }
        buildConceptsIfNecessary();
        LOG.info("building metric " + this.metricName);
        // The submetric list depends only on configuration, so resolve it once for both passes.
        List<String> submetrics = getSubmetrics(this.metricName);
        for (String name : submetrics) {
            initMetric(name);
        }
        for (String name : submetrics) {
            buildMetric(name);
        }
    }

    /**
     * Deletes {@code <srDir>/<metric>/<langCode>} for the metric and each of its
     * submetrics, where that path exists. Deletion failures are ignored.
     */
    public void deleteDataDirectories() throws ConfigurationException {
        for (String name : getSubmetrics(this.metricName)) {
            File dir = FileUtils.getFile(this.srDir, name, this.language.getLangCode());
            if (dir.exists()) {
                LOG.info("deleting metric directory " + dir);
                FileUtils.deleteQuietly(dir);
            }
        }
    }

    /**
     * Lists the metrics that must be built for {@code str}, dependencies first
     * and {@code str} itself last, without duplicates (the first occurrence of
     * a name determines its position).
     *
     * <ul>
     *   <li>{@code ensemble}: every entry of {@code metrics}, each preceded by its own submetrics;</li>
     *   <li>{@code vector.mostsimilarconcepts}: the submetrics of {@code generator.basemetric};</li>
     *   <li>{@code milnewitten}: the {@code inlink} and {@code outlink} metric names.</li>
     * </ul>
     *
     * @param str metric name
     * @return ordered, duplicate-free list of metric names
     */
    public List<String> getSubmetrics(String str) throws ConfigurationException {
        String metricType = getMetricType(str);
        Config metricConfig = getMetricConfig(str);
        // LinkedHashSet keeps first-insertion order while dropping repeats.
        Set<String> ordered = new LinkedHashSet<String>();
        if (metricType.equals("ensemble")) {
            for (String child : metricConfig.getStringList("metrics")) {
                ordered.addAll(getSubmetrics(child));
                ordered.add(child);
            }
        } else if (metricType.equals("vector.mostsimilarconcepts")) {
            ordered.addAll(getSubmetrics(metricConfig.getString("generator.basemetric")));
        } else if (metricType.equals("milnewitten")) {
            ordered.add(metricConfig.getString("inlink"));
            ordered.add(metricConfig.getString("outlink"));
        }
        ordered.add(str);
        return new ArrayList<String>(ordered);
    }

    /**
     * Prepares a metric before building.
     *
     * <p>Ensemble and Milne-Witten metrics are told not to train their
     * submetrics. A {@code vector.mostsimilarconcepts} metric forces the mode
     * from {@link Mode#SIMILARITY} to {@link Mode#BOTH}.
     *
     * @param str metric name
     */
    public void initMetric(String str) throws ConfigurationException {
        String metricType = getMetricType(str);
        if (metricType.equals("ensemble")) {
            ((EnsembleMetric) getMetric(str)).setTrainSubmetrics(false);
        } else if (metricType.equals("milnewitten")) {
            ((MilneWittenMetric) getMetric(str)).setTrainSubmetrics(false);
        } else if (metricType.equals("vector.mostsimilarconcepts") && this.mode == Mode.SIMILARITY) {
            LOG.warning("metric " + str + " of type " + metricType + " requires mostSimilar... training BOTH");
            this.mode = Mode.BOTH;
        }
    }

    /**
     * Trains a single metric according to the current mode and writes it.
     *
     * <p>For mostSimilar training the result count is {@code maxResults * 2}
     * unless the metric config defines {@code maxResults}; the valid ids are
     * {@code validMostSimilarIds} unless the metric config defines
     * {@code mostSimilarConcepts}, in which case they are read from
     * {@code <mostSimilarConcepts>/<langCode>.txt}.
     *
     * @param str metric name
     */
    public void buildMetric(String str) throws ConfigurationException, DaoException, IOException {
        LOG.info("building component metric " + str);
        if (getMetricType(str).equals("vector.word2vec")) {
            initWord2Vec(str);
        }
        Dataset dataset = getDataset();
        SRMetric metric = getMetric(str);
        if (metric instanceof BaseSRMetric) {
            ((BaseSRMetric) metric).setBuildMostSimilarCache(this.buildCosimilarity);
        }
        if (this.mode == Mode.SIMILARITY || this.mode == Mode.BOTH) {
            if (this.skipBuiltMetrics && metric.similarityIsTrained()) {
                LOG.info("metric " + str + " similarity() is already trained... skipping");
            } else {
                metric.trainSimilarity(dataset);
            }
        }
        if (this.mode == Mode.MOSTSIMILAR || this.mode == Mode.BOTH) {
            if (this.skipBuiltMetrics && metric.mostSimilarIsTrained()) {
                LOG.info("metric " + str + " mostSimilar() is already trained... skipping");
            } else {
                Config metricConfig = getMetricConfig(str);
                int numResults = this.maxResults * 2;
                TIntSet validIds = this.validMostSimilarIds;
                if (metricConfig.hasPath("maxResults")) {
                    numResults = metricConfig.getInt("maxResults");
                }
                if (metricConfig.hasPath("mostSimilarConcepts")) {
                    validIds = readIds(String.format("%s/%s.txt",
                            metricConfig.getString("mostSimilarConcepts"),
                            metric.getLanguage().getLangCode()));
                }
                metric.trainMostSimilar(dataset, numResults, validIds);
            }
        }
        metric.write();
    }

    /**
     * Makes sure the prerequisites of a word2vec metric exist: the link
     * probability dao is built, the configured corpus (if any) is created, and
     * a model file is present, training one from the corpus when it is missing.
     *
     * @param str metric name
     * @throws ConfigurationException if no model file exists and the generator's
     *         {@code corpus} setting is {@code NONE}
     */
    private void initWord2Vec(String str) throws ConfigurationException, IOException, DaoException {
        LinkProbabilityDao linkProbabilityDao = (LinkProbabilityDao) this.env.getConfigurator().get(LinkProbabilityDao.class);
        if (!linkProbabilityDao.isBuilt()) {
            linkProbabilityDao.build();
        }
        Config generatorConfig = getMetricConfig(str).getConfig("generator");
        Corpus corpus = null;
        if (!generatorConfig.getString("corpus").equals("NONE")) {
            corpus = (Corpus) this.env.getConfigurator().get(Corpus.class, generatorConfig.getString("corpus"), "language", this.language.getLangCode());
            if (!corpus.exists()) {
                corpus.create();
            }
        }
        File modelFile = Word2VecGenerator.getModelFile(generatorConfig.getString("modelDir"), this.language);
        if (modelFile.isFile()) {
            return; // an existing model is reused as-is
        }
        if (corpus == null) {
            throw new ConfigurationException("word2vec metric " + str + " cannot build or find model! configuration has no corpus, but model not found at " + modelFile + ".");
        }
        Word2VecTrainer trainer = new Word2VecTrainer((LocalPageDao) this.env.getConfigurator().get(LocalPageDao.class), this.language);
        trainer.train(corpus.getDirectory());
        trainer.save(modelFile);
    }

    /** Reads ids from the file at {@code str} and uses them as the valid mostSimilar ids. */
    private void setValidMostSimilarIdsFromFile(String str) throws IOException {
        setValidMostSimilarIds(readIds(str));
    }

    /** Sets the ids handed to mostSimilar training; may be null. */
    public void setValidMostSimilarIds(TIntSet tIntSet) {
        this.validMostSimilarIds = tIntSet;
    }

    /**
     * Generates the concept file {@code <sr.concepts.path>/<langCode>.txt} when
     * some submetric is of type {@code vector.esa} or
     * {@code vector.mostsimilarconcepts} and the file is missing or has at
     * most one line.
     */
    private void buildConceptsIfNecessary() throws IOException, ConfigurationException, DaoException {
        boolean needsConcepts = false;
        for (String name : getSubmetrics(this.metricName)) {
            String metricType = getMetricType(name);
            if (metricType.equals("vector.esa") || metricType.equals("vector.mostsimilarconcepts")) {
                needsConcepts = true;
            }
        }
        if (!needsConcepts) {
            return;
        }
        File conceptFile = FileUtils.getFile(
                this.env.getConfiguration().get().getString("sr.concepts.path"),
                this.language.getLangCode() + ".txt");
        conceptFile.getParentFile().mkdirs();
        if (!conceptFile.isFile() || FileUtils.readLines(conceptFile).size() <= 1) {
            LOG.info("building concept file " + conceptFile.getAbsolutePath() + " for " + this.metricName);
            LocalLinkDao linkDao = (LocalLinkDao) this.env.getConfigurator().get(LocalLinkDao.class);
            LocalPageDao pageDao = (LocalPageDao) this.env.getConfigurator().get(LocalPageDao.class);
            new SRConceptSpaceGenerator(this.language, linkDao, pageDao).writeConcepts(conceptFile);
            LOG.info("finished creating concept file " + conceptFile.getAbsolutePath() + " with " + FileUtils.readLines(conceptFile).size() + " lines");
        }
    }

    /**
     * Loads every configured dataset (or dataset group) for the current
     * language and combines them into a single {@link Dataset}.
     */
    public Dataset getDataset() throws ConfigurationException, DaoException {
        DatasetDao datasetDao = (DatasetDao) this.env.getConfigurator().get(DatasetDao.class);
        List<Dataset> datasets = new ArrayList<Dataset>();
        for (String name : this.datasetNames) {
            datasets.addAll(datasetDao.getDatasetOrGroup(this.language, name));
        }
        return new Dataset(datasets);
    }

    /** Returns the type of the top-level metric; see {@link #getMetricType(String)}. */
    public String getMetricType() throws ConfigurationException {
        return getMetricType(this.metricName);
    }

    /**
     * Returns the configured {@code type} of a metric. For type {@code vector}
     * the generator type is appended, giving {@code vector.<generator.type>}.
     *
     * @param str metric name
     */
    public String getMetricType(String str) throws ConfigurationException {
        Config metricConfig = getMetricConfig(str);
        String type = metricConfig.getString("type");
        if (type.equals("vector")) {
            type = type + "." + metricConfig.getString("generator.type");
        }
        return type;
    }

    /** Returns the configuration of the top-level metric. */
    public Config getMetricConfig() throws ConfigurationException {
        return getMetricConfig(this.metricName);
    }

    /**
     * Returns the configuration of the named metric.
     *
     * @param str metric name
     */
    public Config getMetricConfig(String str) throws ConfigurationException {
        return this.env.getConfigurator().getConfig(SRMetric.class, str);
    }

    /** Reads row ids from a file with one integer id per line. */
    public void setRowIdsFromFile(String str) throws IOException {
        this.rowIds = readIds(str);
    }

    /** Reads column ids from a file with one integer id per line. */
    public void setColIdsFromFile(String str) throws IOException {
        this.colIds = readIds(str);
    }

    /** Sets the names of the datasets (or dataset groups) to train on. */
    public void setDatasetNames(List<String> list) {
        this.datasetNames = list;
    }

    /** Sets the flag passed to {@code BaseSRMetric.setBuildMostSimilarCache}. */
    public void setBuildCosimilarity(boolean z) {
        this.buildCosimilarity = z;
    }

    /** Sets the base result count; mostSimilar training uses twice this value by default. */
    public void setMaxResults(int i) {
        this.maxResults = i;
    }

    /** Sets the row ids. */
    public void setRowIds(TIntSet tIntSet) {
        this.rowIds = tIntSet;
    }

    /** Sets the column ids. */
    public void setColIds(TIntSet tIntSet) {
        this.colIds = tIntSet;
    }

    /** Sets which training steps are performed. */
    public void setMode(Mode mode) {
        this.mode = mode;
    }

    /** Sets whether {@link #build()} deletes existing metric data first. */
    public void setDeleteExistingData(boolean z) {
        this.deleteExistingData = z;
    }

    /** Sets whether already-trained steps are skipped. */
    public void setSkipBuiltMetrics(boolean z) {
        this.skipBuiltMetrics = z;
    }

    /** Sets the language the metrics are built for. */
    public void setLanguage(Language language) {
        this.language = language;
    }

    /**
     * Reads a set of integer ids from a text file, one id per line.
     * Surrounding whitespace is trimmed and blank lines are skipped.
     *
     * @param str path of the file to read
     * @return the ids found in the file
     * @throws IOException if the file cannot be opened or read
     * @throws NumberFormatException if a non-blank line is not a valid integer
     */
    private static TIntSet readIds(String str) throws IOException {
        TIntSet ids = new TIntHashSet();
        // try-with-resources closes the reader even when reading or parsing fails.
        try (BufferedReader reader = WpIOUtils.openBufferedReader(new File(str))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue;
                }
                ids.add(Integer.parseInt(trimmed));
            }
        }
        return ids;
    }

    /**
     * Command-line entry point: parses the options, configures an
     * {@link SRBuilder} accordingly and runs {@link #build()}. On an option
     * parsing error the problem and the usage help are printed.
     */
    public static void main(String[] strArr) throws ConfigurationException, IOException, WikiBrainException, DaoException {
        Options options = new Options();
        options.addOption(new DefaultOptionBuilder().hasArg().withLongOpt("max-results").withDescription("maximum number of results").create("r"));
        options.addOption(new DefaultOptionBuilder().hasArgs().withLongOpt("gold").withDescription("the set of gold standard datasets to train on").create("g"));
        options.addOption(new DefaultOptionBuilder().hasArg().withLongOpt("delete").withDescription("delete all existing SR data for the metric and its submetrics (true or false, default is true)").create("d"));
        options.addOption(new DefaultOptionBuilder().hasArg().withLongOpt("metric").withDescription("set a local metric").create("m"));
        options.addOption(new DefaultOptionBuilder().hasArg().withLongOpt("rowids").withDescription("page ids for rows of cosimilarity matrices (implies -s)").create("p"));
        options.addOption(new DefaultOptionBuilder().hasArg().withLongOpt("colids").withDescription("page ids for columns of cosimilarity matrices (implies -s)").create("q"));
        options.addOption(new DefaultOptionBuilder().withLongOpt("cosimilarity").withDescription("build cosimilarity matrices").create("s"));
        options.addOption(new DefaultOptionBuilder().withLongOpt("mode").hasArg().withDescription("mode: similarity, mostsimilar, or both").create("o"));
        // The value of -y is read as a file path below, so the option must accept an argument.
        options.addOption(new DefaultOptionBuilder().hasArg().withLongOpt("validMostSimilarIds").withDescription("Set valid most similar ids").create("y"));
        options.addOption(new DefaultOptionBuilder().withLongOpt("skip-built").withDescription("Don't rebuild already built metrics (implies -d false)").create("k"));
        EnvBuilder.addStandardOptions(options);
        try {
            CommandLine parse = new PosixParser().parse(options, strArr);
            String requestedMetric = parse.hasOption("m") ? parse.getOptionValue("m") : null;
            SRBuilder sRBuilder = new SRBuilder(new EnvBuilder(parse).build(), requestedMetric);
            if (parse.hasOption("g")) {
                sRBuilder.setDatasetNames(Arrays.asList(parse.getOptionValues("g")));
            }
            if (parse.hasOption("p")) {
                sRBuilder.setRowIdsFromFile(parse.getOptionValue("p"));
                sRBuilder.setBuildCosimilarity(true);
            }
            if (parse.hasOption("q")) {
                sRBuilder.setColIdsFromFile(parse.getOptionValue("q"));
                sRBuilder.setBuildCosimilarity(true);
            }
            if (parse.hasOption("y")) {
                sRBuilder.setValidMostSimilarIdsFromFile(parse.getOptionValue("y"));
            }
            if (parse.hasOption("s")) {
                sRBuilder.setBuildCosimilarity(true);
            }
            if (parse.hasOption("k")) {
                sRBuilder.setSkipBuiltMetrics(true);
                sRBuilder.setDeleteExistingData(false);
            }
            // Handled after -k so that an explicit -d wins over the value implied by -k.
            if (parse.hasOption("d")) {
                sRBuilder.setDeleteExistingData(Boolean.parseBoolean(parse.getOptionValue("d")));
            }
            if (parse.hasOption("o")) {
                // Locale.ROOT keeps the case conversion independent of the default locale.
                sRBuilder.setMode(Mode.valueOf(parse.getOptionValue("o").toUpperCase(Locale.ROOT)));
            }
            if (parse.hasOption("l")) {
                sRBuilder.setLanguage(Language.getByLangCode(parse.getOptionValue("l")));
            }
            if (parse.hasOption("r")) {
                sRBuilder.setMaxResults(Integer.parseInt(parse.getOptionValue("r")));
            }
            sRBuilder.build();
        } catch (ParseException e) {
            System.err.println("Invalid option usage: " + e.getMessage());
            new HelpFormatter().printHelp("SRBuilder", options);
        }
    }
}
