package org.wikibrain.sr.utils;

import gnu.trove.set.TIntSet;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.UniversalPageDao;
import org.wikibrain.core.lang.LocalId;
import org.wikibrain.core.lang.LocalString;
import org.wikibrain.core.model.UniversalPage;
import org.wikibrain.sr.MonolingualSRMetric;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.UniversalSRMetric;
import org.wikibrain.sr.dataset.Dataset;
import org.wikibrain.sr.disambig.Disambiguator;
import org.wikibrain.sr.normalize.IdentityNormalizer;
import org.wikibrain.sr.normalize.Normalizer;
import org.wikibrain.utils.ParallelForEach;
import org.wikibrain.utils.Procedure;

/* loaded from: input_file:org/wikibrain/sr/utils/SrNormalizers.class */
/**
 * Holds the two {@link Normalizer}s used by an SR metric: one applied to
 * pairwise similarity scores and one applied to most-similar result lists.
 *
 * <p>Both normalizers start out as {@link IdentityNormalizer}s. They can be
 * replaced through the setters, trained against a {@link Dataset}, and
 * serialized to / deserialized from a directory using Java object
 * serialization. Each normalizer is stored in its own file, named by
 * {@link #MOST_SIMILAR_NORMALIZER} and {@link #SIMILARITY_NORMALIZER}.
 */
public class SrNormalizers {
    private static final Logger LOG = Logger.getLogger(SrNormalizers.class.getName());

    /** File name (inside the model directory) of the serialized similarity normalizer. */
    public static final String SIMILARITY_NORMALIZER = "similarityNormalizer";

    /** File name (inside the model directory) of the serialized most-similar normalizer. */
    public static final String MOST_SIMILAR_NORMALIZER = "mostSimilarNormalizer";

    private Normalizer mostSimilarNormalizer = new IdentityNormalizer();
    private Normalizer similarityNormalizer = new IdentityNormalizer();

    /** @return the normalizer currently used for most-similar result lists */
    public Normalizer getMostSimilarNormalizer() {
        return this.mostSimilarNormalizer;
    }

    /** @return the normalizer currently used for pairwise similarity scores */
    public Normalizer getSimilarityNormalizer() {
        return this.similarityNormalizer;
    }

    /** @param normalizer the normalizer to use for most-similar result lists */
    public void setMostSimilarNormalizer(Normalizer normalizer) {
        this.mostSimilarNormalizer = normalizer;
    }

    /** @param normalizer the normalizer to use for pairwise similarity scores */
    public void setSimilarityNormalizer(Normalizer normalizer) {
        this.similarityNormalizer = normalizer;
    }

    /**
     * Deletes both serialized normalizer files from the given directory.
     * Missing files and deletion failures are ignored.
     *
     * @param file directory that contains the serialized normalizers
     */
    public void clear(File file) {
        FileUtils.deleteQuietly(new File(file, MOST_SIMILAR_NORMALIZER));
        FileUtils.deleteQuietly(new File(file, SIMILARITY_NORMALIZER));
    }

    /**
     * Checks whether both normalizer files exist in the directory, can be
     * deserialized, and report themselves as trained.
     *
     * @param file directory that contains the serialized normalizers
     * @return {@code true} only if both normalizers are readable and trained
     */
    public boolean hasReadableNormalizers(File file) {
        return isValidNormalizer(file, MOST_SIMILAR_NORMALIZER) && isValidNormalizer(file, SIMILARITY_NORMALIZER);
    }

    /**
     * Replaces both normalizers with the ones deserialized from the directory.
     *
     * @param file directory that contains the serialized normalizers
     * @throws IOException if a file cannot be read or deserialized
     * @throws IllegalStateException if a file is missing or references an unknown class
     */
    public void read(File file) throws IOException {
        this.mostSimilarNormalizer = readNormalizer(file, MOST_SIMILAR_NORMALIZER);
        this.similarityNormalizer = readNormalizer(file, SIMILARITY_NORMALIZER);
    }

    /**
     * Serializes both normalizers into the directory.
     *
     * @param file directory to write the serialized normalizers into
     * @throws IOException if either file cannot be written
     */
    public void write(File file) throws IOException {
        writeNormalizer(file, MOST_SIMILAR_NORMALIZER, this.mostSimilarNormalizer);
        writeNormalizer(file, SIMILARITY_NORMALIZER, this.similarityNormalizer);
    }

    /**
     * Returns whether the named file exists, deserializes to a non-null
     * normalizer, and that normalizer is trained. Any load failure is logged
     * and reported as "invalid" rather than propagated.
     */
    private boolean isValidNormalizer(File dir, String name) {
        File path = new File(dir, name);
        if (!path.isFile()) {
            return false;
        }
        try {
            Normalizer loaded = readNormalizer(dir, name);
            return loaded != null && loaded.isTrained();
        } catch (IOException e) {
            logInvalidNormalizer(path, e);
            return false;
        } catch (IllegalStateException e) {
            // readNormalizer wraps ClassNotFoundException / FileNotFoundException in
            // IllegalStateException; for a validity probe these also mean "not readable".
            logInvalidNormalizer(path, e);
            return false;
        }
    }

    /** Logs a warning that the normalizer at {@code path} could not be loaded. */
    private static void logInvalidNormalizer(File path, Exception cause) {
        LOG.log(Level.WARNING, "Failed to load normalizer at " + path.getAbsolutePath() + ". Setting it to be invalid. Traceback:", cause);
    }

    /**
     * Trains the similarity normalizer on a monolingual metric. Does nothing
     * if the similarity normalizer is an {@link IdentityNormalizer}.
     *
     * <p>While training runs, the {@code similarityNormalizer} field is
     * temporarily replaced by a fresh {@link IdentityNormalizer} and restored
     * afterwards (also on failure). NOTE(review): presumably this makes the
     * metric return un-normalized scores during training - confirm.
     *
     * @param monolingualSRMetric metric whose raw scores are observed
     * @param dataset gold-standard pairs; must be in the metric's language
     * @throws IllegalArgumentException if the dataset and metric languages differ
     */
    public void trainSimilarity(final MonolingualSRMetric monolingualSRMetric, Dataset dataset) {
        if (this.similarityNormalizer instanceof IdentityNormalizer) {
            return;
        }
        if (!dataset.getLanguage().equals(monolingualSRMetric.getLanguage())) {
            throw new IllegalArgumentException("SR metric has language " + monolingualSRMetric.getLanguage() + " but dataset has language " + dataset.getLanguage());
        }
        final Normalizer trainee = this.similarityNormalizer;
        this.similarityNormalizer = new IdentityNormalizer();
        try {
            trainee.reset();
            // NOTE(review): the meaning of the trailing 100 is defined by ParallelForEach.loop - confirm.
            ParallelForEach.loop(dataset.getData(), new Procedure<KnownSim>() {
                public void call(KnownSim knownSim) throws IOException, DaoException {
                    knownSim.maybeSwap();
                    SRResult result = monolingualSRMetric.similarity(knownSim.phrase1, knownSim.phrase2, false);
                    // A missing result is observed as NaN.
                    trainee.observe(result == null ? Double.NaN : result.getScore(), knownSim.similarity);
                }
            }, 100);
            trainee.observationsFinished();
            LOG.info("trained similarity normalizer: " + trainee.dump());
        } finally {
            this.similarityNormalizer = trainee;
        }
    }

    /**
     * Trains the similarity normalizer on a universal metric. Does nothing if
     * the similarity normalizer is an {@link IdentityNormalizer}.
     *
     * <p>While training runs, the {@code similarityNormalizer} field is
     * temporarily replaced by a fresh {@link IdentityNormalizer} and restored
     * afterwards (also on failure).
     *
     * @param universalSRMetric metric whose raw scores are observed
     * @param dataset gold-standard pairs
     */
    public void trainSimilarity(final UniversalSRMetric universalSRMetric, Dataset dataset) {
        if (this.similarityNormalizer instanceof IdentityNormalizer) {
            return;
        }
        final Normalizer trainee = this.similarityNormalizer;
        this.similarityNormalizer = new IdentityNormalizer();
        try {
            trainee.reset();
            ParallelForEach.loop(dataset.getData(), new Procedure<KnownSim>() {
                public void call(KnownSim knownSim) throws IOException, DaoException {
                    knownSim.maybeSwap();
                    LocalString first = new LocalString(knownSim.language, knownSim.phrase1);
                    LocalString second = new LocalString(knownSim.language, knownSim.phrase2);
                    SRResult result = universalSRMetric.similarity(first, second, false);
                    // Same null handling as the monolingual overload: a missing result is observed as NaN.
                    trainee.observe(result == null ? Double.NaN : result.getScore(), knownSim.similarity);
                }
            }, 100);
            trainee.observationsFinished();
            LOG.info("trained similarity normalizer: " + trainee.dump());
        } finally {
            this.similarityNormalizer = trainee;
        }
    }

    /**
     * Trains the most-similar normalizer on a monolingual metric. Does nothing
     * if the most-similar normalizer is an {@link IdentityNormalizer}.
     *
     * <p>Each dataset pair is disambiguated to two local ids; pairs that do
     * not disambiguate to exactly two non-null ids are skipped. While training
     * runs, the {@code mostSimilarNormalizer} field is temporarily replaced by
     * a fresh {@link IdentityNormalizer} and restored afterwards (also on
     * failure).
     *
     * @param monolingualSRMetric metric whose most-similar lists are observed
     * @param disambiguator resolves the dataset phrases to local ids
     * @param dataset gold-standard pairs; must be in the metric's language
     * @param tIntSet passed as the third argument of {@code mostSimilar}
     *                (presumably the set of valid result ids - confirm)
     * @param i passed as the second argument of {@code mostSimilar}
     *          (presumably the maximum number of results - confirm)
     * @throws IllegalArgumentException if the dataset and metric languages differ
     */
    public void trainMostSimilar(final MonolingualSRMetric monolingualSRMetric, final Disambiguator disambiguator, Dataset dataset, final TIntSet tIntSet, final int i) {
        // Guard on the normalizer that is actually trained here.
        if (this.mostSimilarNormalizer instanceof IdentityNormalizer) {
            return;
        }
        if (!dataset.getLanguage().equals(monolingualSRMetric.getLanguage())) {
            throw new IllegalArgumentException("SR metric has language " + monolingualSRMetric.getLanguage() + " but dataset has language " + dataset.getLanguage());
        }
        final Normalizer trainee = this.mostSimilarNormalizer;
        this.mostSimilarNormalizer = new IdentityNormalizer();
        try {
            trainee.reset();
            ParallelForEach.loop(dataset.getData(), new Procedure<KnownSim>() {
                public void call(KnownSim knownSim) throws IOException, DaoException {
                    knownSim.maybeSwap();
                    List<LocalId> ids = disambiguatePair(disambiguator, knownSim);
                    if (ids == null) {
                        return;
                    }
                    LocalId source = ids.get(0);
                    LocalId target = ids.get(1);
                    SRResultList neighbors = monolingualSRMetric.mostSimilar(source.getId(), i, tIntSet);
                    if (neighbors != null) {
                        trainee.observe(neighbors, neighbors.getIndexForId(target.getId()), knownSim.similarity);
                    }
                }
            }, 100);
            trainee.observationsFinished();
            LOG.info("trained most similar normalizer: " + trainee.dump());
        } finally {
            this.mostSimilarNormalizer = trainee;
        }
    }

    /**
     * Trains the most-similar normalizer on a universal metric. Does nothing
     * if the most-similar normalizer is an {@link IdentityNormalizer}.
     *
     * <p>Each dataset pair is disambiguated to two local ids, which are then
     * mapped to universal page ids; pairs that do not disambiguate to exactly
     * two non-null ids, or whose first universal page cannot be loaded, are
     * skipped. While training runs, the {@code mostSimilarNormalizer} field is
     * temporarily replaced by a fresh {@link IdentityNormalizer} and restored
     * afterwards (also on failure).
     *
     * @param universalSRMetric metric whose most-similar lists are observed
     * @param disambiguator resolves the dataset phrases to local ids
     * @param universalPageDao maps local pages to universal pages
     * @param i passed as the second argument of {@code getUnivPageId} and
     *          {@code getById} (presumably the universal algorithm id - confirm)
     * @param dataset gold-standard pairs
     * @param tIntSet passed as the third argument of {@code mostSimilar}
     *                (presumably the set of valid result ids - confirm)
     * @param i2 passed as the second argument of {@code mostSimilar}
     *           (presumably the maximum number of results - confirm)
     */
    public void trainMostSimilar(final UniversalSRMetric universalSRMetric, final Disambiguator disambiguator, final UniversalPageDao universalPageDao, final int i, Dataset dataset, final TIntSet tIntSet, final int i2) {
        // Guard on the normalizer that is actually trained here.
        if (this.mostSimilarNormalizer instanceof IdentityNormalizer) {
            return;
        }
        final Normalizer trainee = this.mostSimilarNormalizer;
        this.mostSimilarNormalizer = new IdentityNormalizer();
        try {
            trainee.reset();
            ParallelForEach.loop(dataset.getData(), new Procedure<KnownSim>() {
                public void call(KnownSim knownSim) throws IOException, DaoException {
                    knownSim.maybeSwap();
                    List<LocalId> ids = disambiguatePair(disambiguator, knownSim);
                    if (ids == null) {
                        return;
                    }
                    int sourceUnivId = universalPageDao.getUnivPageId(ids.get(0).asLocalPage(), i);
                    int targetUnivId = universalPageDao.getUnivPageId(ids.get(1).asLocalPage(), i);
                    UniversalPage sourcePage = universalPageDao.getById(sourceUnivId, i);
                    if (sourcePage == null) {
                        return;
                    }
                    SRResultList neighbors = universalSRMetric.mostSimilar(sourcePage, i2, tIntSet);
                    if (neighbors == null) {
                        return;
                    }
                    trainee.observe(neighbors, neighbors.getIndexForId(targetUnivId), knownSim.similarity);
                }
            }, 100);
            trainee.observationsFinished();
            LOG.info("trained most similar normalizer: " + trainee.dump());
        } finally {
            this.mostSimilarNormalizer = trainee;
        }
    }

    /**
     * Disambiguates both phrases of a known-similarity pair.
     *
     * @return a list of exactly two non-null ids (phrase1 first, phrase2
     *         second), or {@code null} if the disambiguator did not produce
     *         such a list
     */
    private static List<LocalId> disambiguatePair(Disambiguator disambiguator, KnownSim knownSim) throws IOException, DaoException {
        List<LocalString> phrases = new ArrayList<LocalString>();
        phrases.add(new LocalString(knownSim.language, knownSim.phrase1));
        phrases.add(new LocalString(knownSim.language, knownSim.phrase2));
        List<LocalId> ids = disambiguator.disambiguateTop(phrases, (Set<LocalString>) null);
        if (ids == null || ids.size() != 2 || ids.get(0) == null || ids.get(1) == null) {
            return null;
        }
        return ids;
    }

    /**
     * Deserializes the normalizer stored in {@code dir/name}.
     *
     * @throws IOException if the stream cannot be read or deserialized
     * @throws IllegalStateException if the file does not exist or the
     *         serialized class cannot be found
     */
    private Normalizer readNormalizer(File dir, String name) throws IOException {
        FileInputStream fileStream;
        try {
            fileStream = new FileInputStream(new File(dir, name));
        } catch (FileNotFoundException e) {
            throw new IllegalStateException(e);
        }
        ObjectInputStream objectStream = null;
        try {
            // The ObjectInputStream constructor reads the stream header and may throw;
            // the file stream is closed in the finally block in that case as well.
            objectStream = new ObjectInputStream(fileStream);
            return (Normalizer) objectStream.readObject();
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException(e);
        } finally {
            IOUtils.closeQuietly(objectStream);
            IOUtils.closeQuietly(fileStream);
        }
    }

    /**
     * Serializes {@code normalizer} into {@code dir/name}, closing the
     * underlying file even when serialization fails.
     *
     * @throws IOException if the file cannot be opened, written, or closed
     */
    private void writeNormalizer(File dir, String name, Normalizer normalizer) throws IOException {
        FileOutputStream fileStream = new FileOutputStream(new File(dir, name));
        try {
            ObjectOutputStream objectStream = new ObjectOutputStream(fileStream);
            objectStream.writeObject(normalizer);
            objectStream.flush();
            // Close loudly on the success path so a failed final write is reported.
            objectStream.close();
        } finally {
            // No-op after a successful close; releases the file handle if anything above threw.
            IOUtils.closeQuietly(fileStream);
        }
    }
}
