package org.wikibrain.sr.dataset;

import com.typesafe.config.Config;
import com.typesafe.config.ConfigValue;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringEscapeUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.lang.LanguageSet;
import org.wikibrain.core.lang.LocalId;
import org.wikibrain.core.lang.LocalString;
import org.wikibrain.sr.disambig.Disambiguator;
import org.wikibrain.sr.utils.KnownSim;
import org.wikibrain.utils.WpIOUtils;

/* loaded from: input_file:org/wikibrain/sr/dataset/DatasetDao.class */
public class DatasetDao {
    /** Logger for this class (previously registered under {@code Dataset}, which mislabelled its output). */
    private static final Logger LOG = LoggerFactory.getLogger(DatasetDao.class);

    /** Resource directory holding the bundled dataset files. The misspelled name is kept because the constant is public. */
    public static final String RESOURCE_DATSET = "/datasets";

    /** Resource path of the tab-separated dataset index read by {@link #readInfos()}. */
    public static final String RESOURCE_DATASET_INFO = "/datasets/info.tsv";

    /** Known datasets: each entry pairs a dataset name with the languages it supports. */
    private final Collection<Info> info;

    /** Maps a group name to the names of the datasets (or nested groups) that make it up. */
    private Map<String, List<String>> groups = new HashMap<String, List<String>>();

    /** Whether {@code Dataset.normalize()} is called on every dataset after it is read. */
    private boolean normalize = true;

    /** Whether phrases are resolved to ids with {@link #disambiguator} while reading. */
    private boolean resolvePhrases = false;

    /** Disambiguator used when {@link #resolvePhrases} is true; may be null otherwise. */
    private Disambiguator disambiguator = null;

    /* loaded from: input_file:org/wikibrain/sr/dataset/DatasetDao$Info.class */
    /**
     * Index entry for a single dataset: its name and the set of languages it supports.
     */
    public static class Info {
        private String name;
        private LanguageSet languages;

        /**
         * @param str         name of the dataset
         * @param languageSet languages the dataset supports
         */
        public Info(String str, LanguageSet languageSet) {
            name = str;
            languages = languageSet;
        }

        /** @return the languages this dataset supports */
        public LanguageSet getLanguages() {
            return languages;
        }

        /** @return the dataset name */
        public String getName() {
            return name;
        }
    }

    /* loaded from: input_file:org/wikibrain/sr/dataset/DatasetDao$Provider.class */
    /**
     * Configuration provider that builds {@link DatasetDao} instances for the
     * {@code sr.dataset.dao} configuration path.
     */
    public static class Provider extends org.wikibrain.conf.Provider<DatasetDao> {
        public Provider(Configurator configurator, Configuration configuration) throws ConfigurationException {
            super(configurator, configuration);
        }

        /** @return the type this provider creates, {@code DatasetDao.class} */
        public Class<DatasetDao> getType() {
            return DatasetDao.class;
        }

        /** @return the configuration path served by this provider */
        public String getPath() {
            return "sr.dataset.dao";
        }

        /**
         * Builds a DAO from {@code config}.
         *
         * <p>Optional keys {@code normalize}, {@code disambig} and {@code resolvePhrases} are applied
         * in that order; the order matters because setting a disambiguator switches phrase
         * resolution on, and an explicit {@code resolvePhrases} value must be able to override it.
         * Dataset groups are taken from {@code sr.dataset.groups} in the provider's own configuration.
         *
         * @return the configured DAO, or {@code null} if the configured type is not {@code "resource"}
         */
        public DatasetDao get(String str, Config config, Map<String, String> map) throws ConfigurationException {
            if (!"resource".equals(config.getString("type"))) {
                return null;
            }

            DatasetDao dao = new DatasetDao();
            if (config.hasPath("normalize")) {
                boolean normalize = config.getBoolean("normalize");
                dao.setNormalize(normalize);
            }
            if (config.hasPath("disambig")) {
                String disambigName = config.getString("disambig");
                dao.setDisambiguator((Disambiguator) getConfigurator().get(Disambiguator.class, disambigName));
            }
            if (config.hasPath("resolvePhrases")) {
                boolean resolve = config.getBoolean("resolvePhrases");
                dao.setResolvePhrases(resolve);
            }

            Map<String, List<String>> groupMembers = new HashMap<String, List<String>>();
            for (Map.Entry<String, ConfigValue> entry : getConfig().get().getConfig("sr.dataset.groups").entrySet()) {
                // Each group value is expected to unwrap to a list of dataset names.
                @SuppressWarnings("unchecked")
                List<String> members = (List<String>) entry.getValue().unwrapped();
                groupMembers.put(entry.getKey(), members);
            }
            dao.setGroups(groupMembers);
            return dao;
        }

        /* renamed from: get, reason: collision with other method in class */
        public /* bridge */ /* synthetic */ Object m10get(String str, Config config, Map map) throws ConfigurationException {
            return get(str, config, (Map<String, String>) map);
        }
    }

    /**
     * Creates a DAO whose dataset index is loaded with {@link #readInfos()}.
     *
     * @throws RuntimeException wrapping the {@link DaoException} if the index cannot be read
     */
    public DatasetDao() {
        Collection<Info> loaded;
        try {
            loaded = readInfos();
        } catch (DaoException e) {
            throw new RuntimeException(e);
        }
        this.info = loaded;
    }

    /**
     * Creates a DAO over an explicitly supplied dataset index.
     *
     * @param collection the known datasets; stored by reference, not copied
     */
    public DatasetDao(Collection<Info> collection) {
        info = collection;
    }

    /**
     * Controls whether datasets are normalized after being read.
     *
     * @param z {@code true} to call {@code normalize()} on each dataset that is read
     */
    public void setNormalize(boolean z) {
        normalize = z;
    }

    /**
     * Loads every indexed dataset whose language set contains the given language.
     *
     * @param language language the datasets must support
     * @return the loaded datasets, in index order
     * @throws DaoException if any matching dataset cannot be loaded
     */
    public List<Dataset> getAllInLanguage(Language language) throws DaoException {
        List<Dataset> matches = new ArrayList<Dataset>();
        for (Info candidate : this.info) {
            if (!candidate.getLanguages().containsLanguage(language)) {
                continue;
            }
            Dataset loaded = get(language, candidate.getName());
            matches.add(loaded);
        }
        return matches;
    }

    /**
     * Reads a dataset from a file on disk; the file's name becomes the dataset name.
     *
     * @param language language of the phrases in the file
     * @param file     dataset file to read
     * @return the parsed dataset
     * @throws DaoException if the file cannot be opened or parsed
     */
    public Dataset read(Language language, File file) throws DaoException {
        String datasetName = file.getName();
        BufferedReader reader;
        try {
            reader = WpIOUtils.openBufferedReader(file);
        } catch (IOException e) {
            throw new DaoException(e);
        }
        return read(datasetName, language, reader);
    }

    /**
     * Loads a bundled dataset, or a group of datasets, by name.
     *
     * <p>If {@code str} names a group, every member is loaded (members may themselves be groups)
     * and the results are combined into one dataset named {@code str}. Otherwise the name is
     * looked up in the index, ignoring case, and the matching resource is read.
     *
     * @param language language the dataset must support
     * @param str      dataset or group name; must not contain a path separator
     * @return the loaded dataset
     * @throws DaoException if the name looks like a path, is unknown, does not support
     *                      {@code language}, or the resource cannot be read
     */
    public Dataset get(Language language, String str) throws DaoException {
        if (isGroup(str)) {
            return new Dataset(str, getGroup(language, str));
        }
        if (str.contains("/") || str.contains("\\")) {
            throw new DaoException("get() reads a dataset by name for a jar. Try read() instead?");
        }
        Info info = getInfo(str);
        if (info == null) {
            throw new DaoException("no dataset with name '" + str + "'");
        }
        if (!info.getLanguages().containsLanguage(language)) {
            throw new DaoException("dataset '" + str + "' does not support language " + language);
        }
        // getInfo() matches names case-insensitively, so build the resource path from the
        // indexed spelling rather than the caller's; otherwise a differently-cased name passes
        // the lookup and then fails to open.
        // NOTE(review): assumes the name in the index is the resource file name - confirm.
        String resourcePath = RESOURCE_DATSET + "/" + info.getName();
        try {
            return read(str, language, WpIOUtils.openResource(resourcePath));
        } catch (IOException e) {
            throw new DaoException(e);
        }
    }

    /**
     * @param str name to test
     * @return {@code true} if {@code str} is a key of the configured dataset groups
     */
    public boolean isGroup(String str) {
        boolean known = groups.containsKey(str);
        return known;
    }

    /**
     * Loads every member of a dataset group.
     *
     * @param language language the member datasets must support
     * @param str      name of the group
     * @return one dataset per group member, in the group's configured order
     * @throws DaoException if no group named {@code str} exists, or a member cannot be loaded
     */
    public List<Dataset> getGroup(Language language, String str) throws DaoException {
        List<String> members = this.groups.get(str);
        if (members == null) {
            // Previously this surfaced as a NullPointerException with no hint of the cause.
            throw new DaoException("no dataset group with name '" + str + "'");
        }
        List<Dataset> datasets = new ArrayList<Dataset>(members.size());
        for (String member : members) {
            datasets.add(get(language, member));
        }
        return datasets;
    }

    /**
     * Resolves a name that may be either a group or a single dataset.
     *
     * @param language language the dataset(s) must support
     * @param str      group or dataset name
     * @return the group's members if {@code str} is a group, otherwise a one-element list
     * @throws DaoException if loading fails
     */
    public List<Dataset> getDatasetOrGroup(Language language, String str) throws DaoException {
        if (isGroup(str)) {
            return getGroup(language, str);
        }
        return Arrays.asList(get(language, str));
    }

    /**
     * Finds the index entry for a dataset name, ignoring case.
     *
     * @param str dataset name to look up
     * @return the first matching entry, or {@code null} if none matches
     */
    public Info getInfo(String str) {
        Info match = null;
        Iterator<Info> entries = this.info.iterator();
        while (match == null && entries.hasNext()) {
            Info candidate = entries.next();
            if (candidate.name.equalsIgnoreCase(str)) {
                match = candidate;
            }
        }
        return match;
    }

    /**
     * Sets the disambiguator used to resolve phrases to ids while reading datasets.
     *
     * <p>Supplying a non-null disambiguator switches phrase resolution on. Supplying {@code null}
     * clears the disambiguator and switches phrase resolution off, so that a later read does not
     * dereference a missing disambiguator.
     *
     * @param disambiguator the disambiguator to use, or {@code null} to disable phrase resolution
     */
    public void setDisambiguator(Disambiguator disambiguator) {
        this.disambiguator = disambiguator;
        this.resolvePhrases = disambiguator != null;
    }

    /**
     * Switches phrase resolution on or off.
     *
     * @param z {@code true} to resolve phrases with the configured disambiguator while reading
     * @throws IllegalStateException if {@code z} is true but no disambiguator has been set;
     *                               the current setting is left unchanged in that case
     */
    public void setResolvePhrases(boolean z) {
        // Validate before mutating so a rejected call cannot leave the flag enabled
        // without a disambiguator to back it.
        if (z && this.disambiguator == null) {
            throw new IllegalStateException("resolve phrases set to true, but no disambiguator specified.");
        }
        this.resolvePhrases = z;
    }

    /**
     * Replaces the dataset group definitions.
     *
     * @param map group name to member dataset (or group) names; stored by reference.
     *            {@code null} is treated as "no groups" so later lookups do not fail.
     */
    public void setGroups(Map<String, List<String>> map) {
        this.groups = (map != null) ? map : new HashMap<String, List<String>>();
    }

    /**
     * Parses a dataset from a reader and always closes the reader, on success or failure.
     *
     * <p>Each line must contain at least three fields: phrase 1, phrase 2 and a numeric
     * similarity. Fields are split on commas when the name ends in "csv" (ignoring case) and
     * on tabs otherwise. When phrase resolution is enabled, each phrase is passed to the
     * disambiguator and the resulting id, if any, is stored on the entry. The dataset is
     * normalized before being returned if normalization is enabled.
     *
     * @param str            dataset name; also selects the field delimiter
     * @param language       language of the phrases
     * @param bufferedReader source of the dataset lines; closed by this method
     * @return the parsed dataset
     * @throws DaoException if a line has too few fields, the similarity is not a number,
     *                      or reading fails
     */
    protected Dataset read(String str, Language language, BufferedReader bufferedReader) throws DaoException {
        List<KnownSim> entries = new ArrayList<KnownSim>();
        try {
            String delimiter = str.toLowerCase().endsWith("csv") ? "," : "\t";
            String line;
            while ((line = bufferedReader.readLine()) != null) {
                String[] fields = line.split(delimiter);
                if (fields.length < 3) {
                    throw new DaoException("Invalid line in dataset file " + str + ": '" + StringEscapeUtils.escapeJava(line) + "'");
                }
                double similarity;
                try {
                    similarity = Double.parseDouble(fields[2]);
                } catch (NumberFormatException e) {
                    // Report the offending line instead of leaking a bare NumberFormatException.
                    throw new DaoException("Invalid similarity in dataset file " + str + ": '" + StringEscapeUtils.escapeJava(line) + "'");
                }
                KnownSim knownSim = new KnownSim(fields[0], fields[1], similarity, language);
                if (this.resolvePhrases) {
                    LocalId id1 = this.disambiguator.disambiguateTop(new LocalString(language, knownSim.phrase1), (Set<LocalString>) null);
                    LocalId id2 = this.disambiguator.disambiguateTop(new LocalString(language, knownSim.phrase2), (Set<LocalString>) null);
                    if (id1 != null) {
                        knownSim.wpId1 = id1.getId();
                    }
                    if (id2 != null) {
                        knownSim.wpId2 = id2.getId();
                    }
                }
                entries.add(knownSim);
            }
        } catch (IOException e) {
            throw new DaoException(e);
        } finally {
            // Close on every path; previously the reader leaked whenever parsing failed.
            IOUtils.closeQuietly(bufferedReader);
        }
        Dataset dataset = new Dataset(str, language, entries);
        if (this.normalize) {
            dataset.normalize();
        }
        return dataset;
    }

    /**
     * Writes a dataset to a file, one entry per line, as
     * {@code phrase1<TAB>phrase2<TAB>similarity}.
     *
     * <p>NOTE(review): {@link FileWriter} encodes with the platform default charset on Java
     * versions before 18; confirm this matches the encoding used when reading.
     *
     * @param dataset dataset whose entries are written
     * @param file    destination file; overwritten if it exists
     * @throws DaoException if the file cannot be opened, written or closed
     */
    public void write(Dataset dataset, File file) throws DaoException {
        BufferedWriter writer = null;
        try {
            writer = new BufferedWriter(new FileWriter(file));
            for (KnownSim knownSim : dataset.getData()) {
                writer.write(knownSim.phrase1 + "\t" + knownSim.phrase2 + "\t" + knownSim.similarity + "\n");
            }
            // Close explicitly on the success path so a failed final flush is reported.
            writer.close();
        } catch (IOException e) {
            throw new DaoException(e);
        } finally {
            // No-op after a successful close; releases the file handle if writing failed.
            IOUtils.closeQuietly(writer);
        }
    }

    /**
     * Reads the dataset index from the {@link #RESOURCE_DATASET_INFO} resource.
     *
     * @return the index entries
     * @throws DaoException if the resource cannot be opened or read
     */
    public static Collection<Info> readInfos() throws DaoException {
        BufferedReader indexReader;
        try {
            indexReader = WpIOUtils.openResource(RESOURCE_DATASET_INFO);
        } catch (IOException e) {
            throw new DaoException(e);
        }
        return readInfos(indexReader);
    }

    /**
     * Parses a dataset index and closes the reader when done.
     *
     * <p>Each non-blank line holds a dataset name and a language-set string separated by a tab;
     * any further columns are ignored. Blank lines are skipped.
     *
     * @param bufferedReader source of the index lines; always closed by this method
     * @return the index entries, in file order
     * @throws DaoException if a line has fewer than two columns or reading fails
     */
    public static Collection<Info> readInfos(BufferedReader bufferedReader) throws DaoException {
        try {
            List<Info> infos = new ArrayList<Info>();
            String line;
            while ((line = bufferedReader.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    // Tolerate blank lines (e.g. a trailing newline) instead of crashing on them.
                    continue;
                }
                String[] columns = trimmed.split("\t");
                if (columns.length < 2) {
                    // Previously this surfaced as an ArrayIndexOutOfBoundsException.
                    throw new DaoException("Invalid line in dataset info: '" + StringEscapeUtils.escapeJava(line) + "'");
                }
                infos.add(new Info(columns[0], new LanguageSet(columns[1])));
            }
            return infos;
        } catch (IOException e) {
            throw new DaoException(e);
        } finally {
            IOUtils.closeQuietly(bufferedReader);
        }
    }
}
