package edu.stanford.nlp.kbp.slotfilling.shallowdive;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.kbp.common.CollectionUtils;
import edu.stanford.nlp.kbp.common.KBPNew;
import edu.stanford.nlp.kbp.common.KBPair;
import edu.stanford.nlp.kbp.common.Maybe;
import edu.stanford.nlp.kbp.common.Pointer;
import edu.stanford.nlp.kbp.common.PostgresUtils;
import edu.stanford.nlp.kbp.common.Props;
import edu.stanford.nlp.kbp.common.SentenceGroup;
import edu.stanford.nlp.kbp.common.Utils;
import edu.stanford.nlp.kbp.slotfilling.classify.KBPDataset;
import edu.stanford.nlp.kbp.slotfilling.classify.RelationClassifier;
import edu.stanford.nlp.kbp.slotfilling.classify.TrainingStatistics;
import edu.stanford.nlp.kbp.slotfilling.ir.KBPIR;
import edu.stanford.nlp.kbp.slotfilling.train.KBPTrainer;
import edu.stanford.nlp.kbp.slotfilling.train.KryoDatumCache;
import edu.stanford.nlp.util.Factory;
import edu.stanford.nlp.util.IterableIterator;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Random;
import java.util.Set;

/* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/shallowdive/DatumOps.class */
public class DatumOps {
    /** Redwood logging channel named "DatumOps", shared by every static and instance method of this class. */
    private static Redwood.RedwoodChannels logger = Redwood.channels(new Object[]{"DatumOps"});
    /** Trainer that {@link #train} hands the assembled dataset to. */
    public final KBPTrainer trainer;
    /** Text operations helper supplied at construction; it is only stored here and not used elsewhere in this class. */
    public final TextOps textOps;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:edu/stanford/nlp/kbp/slotfilling/shallowdive/DatumOps$KBPDatum.class */
    /**
     * One training example: the sentence group for a single entity / slot-value pair together
     * with the relation labels recorded for it.
     *
     * <p>The label sets are public and mutable; they are filled through
     * {@link #registerRelation(boolean, String)} and read directly when the datum is added to
     * the dataset.
     */
    public static class KBPDatum {
        /** Sentences collected for this entity / slot-value pair. */
        public final SentenceGroup group;
        /** Relation names registered with a true truth flag. */
        public Set<String> positiveLabels = new HashSet<>();
        /** Relation names registered with a false truth flag. */
        public Set<String> negativeLabels = new HashSet<>();
        /** Labels of unknown truth; nothing in this class adds to this set. */
        public Set<String> unknownLabels = new HashSet<>();

        /**
         * @param sentenceGroup the group that later sentences for the same pair are merged into
         */
        public KBPDatum(SentenceGroup sentenceGroup) {
            this.group = sentenceGroup;
        }

        /**
         * Records a relation label for this datum.
         *
         * @param z   {@code true} to record the relation as positive, {@code false} as negative
         * @param str the relation name
         */
        public void registerRelation(boolean z, String str) {
            if (z) {
                this.positiveLabels.add(str);
            } else {
                this.negativeLabels.add(str);
            }
        }
    }

    /**
     * Creates the datum operations helper.
     *
     * @param textOps    text operations helper to keep a reference to
     * @param kBPTrainer trainer used by {@link #train} to fit the relation classifier
     */
    public DatumOps(TextOps textOps, KBPTrainer kBPTrainer) {
        this.trainer = kBPTrainer;
        this.textOps = textOps;
    }

    /**
     * Builds the entity / slot-value key for the row the cursor is currently positioned on.
     *
     * @param resultSet a result set positioned on a row exposing the columns {@code entity_name},
     *                  {@code entity_type}, {@code slot_value} and {@code slot_value_type}
     * @return the key assembled from those four columns
     * @throws SQLException if any of the columns cannot be read
     */
    protected static KBPair mkKey(ResultSet resultSet) throws SQLException {
        final String entityName = resultSet.getString("entity_name");
        final String entityType = resultSet.getString("entity_type");
        final String slotValue = resultSet.getString("slot_value");
        final String slotValueType = resultSet.getString("slot_value_type");
        return KBPNew.entName(entityName)
                .entType(entityType)
                .slotValue(slotValue)
                .slotType(slotValueType)
                .KBPair();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Deserializes the sentence group stored in the {@code datum} column of the current row.
     *
     * <p>Loading is best-effort: if the blob is SQL NULL or cannot be deserialized, the problem
     * is logged and an empty group carrying the row's key is returned instead, so callers can
     * keep streaming rows.
     *
     * @param resultSet a result set positioned on a row that exposes the {@code datum} column
     *                  and the key columns read by {@link #mkKey(ResultSet)}
     * @return the deserialized group, or an empty group for this row's key when loading fails
     * @throws SQLException if the row's columns cannot be read
     */
    public static SentenceGroup mkDatum(ResultSet resultSet) throws SQLException {
        final byte[] serialized = resultSet.getBytes("datum");
        if (serialized == null) {
            // A SQL NULL blob would otherwise fail with a NullPointerException inside
            // ByteArrayInputStream; treat it like any other unreadable datum.
            logger.log(new Object[]{"datum column is NULL; using an empty sentence group"});
            return SentenceGroup.empty(mkKey(resultSet));
        }
        try (ByteArrayInputStream in = new ByteArrayInputStream(serialized)) {
            return KryoDatumCache.load(in);
        } catch (IOException | ClassNotFoundException e) {
            logger.log(new Object[]{e});
            return SentenceGroup.empty(mkKey(resultSet));
        }
    }

    /**
     * Builds a training dataset from the datums stored in table {@code str} joined with its
     * companion table {@code str_relations}.
     *
     * <p>Rows are read ordered by entity / slot-value key (true relations first within a key)
     * and folded into one {@link KBPDatum} per key. Every datum that survives negative
     * sub-sampling is added to the dataset while the connection is still open; afterwards the
     * feature-count threshold is applied and summary statistics are logged.
     *
     * @param str name of the datum table; it is spliced into the SQL text verbatim (identifiers
     *            cannot be bound as statement parameters), so it must come from trusted
     *            configuration
     * @return the populated dataset
     * @throws IllegalArgumentException if the query returns no rows
     */
    public KBPDataset<String, String> mkDataset(String str) {
        final KBPDataset<String, String> dataset = new KBPDataset<>();
        Redwood.Util.forceTrack("Creating dataset");
        PostgresUtils.withConnection(str, connection -> {
            final boolean savedAutoCommit = connection.getAutoCommit();
            // The fetch size is a hint for result sets generated *after* it is set, so both it
            // and the auto-commit switch must precede executeQuery to have any effect.
            connection.setAutoCommit(false);
            try (Statement statement = connection.createStatement()) {
                statement.setFetchSize(100000);
                try (ResultSet rows = statement.executeQuery("SELECT d.did, d.entity_name, d.entity_type, d.slot_value, d.slot_value_type, d.datum, r.relation_name, r.truth FROM " + str + " d, " + str + "_relations r WHERE d.did = r.did ORDER BY d.entity_name, d.entity_type, d.slot_value, d.slot_value_type, r.truth DESC;")) {
                    if (!rows.next()) {
                        throw new IllegalArgumentException("No datums in table");
                    }
                    // Consume the rows here, while the statement and result set are still open.
                    IterableIterator<KBPDatum> datums = new IterableIterator<>(
                            CollectionUtils.iteratorFromMaybeFactory(new DatumFactory(rows, new Random(42L))));
                    for (KBPDatum datum : datums) {
                        dataset.addDatum(datum.positiveLabels, datum.negativeLabels, datum.unknownLabels, datum.group, datum.group.sentenceGlossKeys, new Maybe[datum.group.size()]);
                    }
                }
            } finally {
                // Restore the caller's auto-commit mode only once every row has been read.
                if (savedAutoCommit != connection.getAutoCommit()) {
                    connection.setAutoCommit(savedAutoCommit);
                }
            }
        });
        dataset.applyFeatureCountThreshold(Props.FEATURE_COUNT_THRESHOLD);
        Redwood.Util.startTrack(new Object[]{"Dataset Info"});
        logger.log(new Object[]{Redwood.Util.BLUE, "                                size: " + dataset.size()});
        logger.log(new Object[]{Redwood.Util.BLUE, "           number of feature classes: " + dataset.numFeatures()});
        logger.log(new Object[]{Redwood.Util.BLUE, "                 number of relations: " + dataset.numClasses()});
        Redwood.Util.endTrack("Dataset Info");
        Redwood.Util.endTrack("Creating dataset");
        return dataset;
    }

    /**
     * Produces one {@link KBPDatum} per entity / slot-value key from a result set that is sorted
     * by key and already positioned on its first row.
     *
     * <p>Each call to {@link #create()} consumes all rows of the current key and leaves the
     * cursor on the first row of the next key.
     */
    private static final class DatumFactory implements Factory<Maybe<KBPDatum>> {
        private final ResultSet rows;
        private final Random random;
        /** Datum ids already merged into the current group (a did repeats once per relation row). */
        private final Set<Long> didsSeen = new HashSet<>();
        private KBPair currentKey;
        private KBPDatum currentDatum;
        private boolean isDone = false;
        private int numPositives = 0;
        private int numNegatives = 0;

        DatumFactory(ResultSet rows, Random random) throws SQLException {
            this.rows = rows;
            this.random = random;
            this.currentKey = mkKey(rows);
            this.currentDatum = new KBPDatum(SentenceGroup.empty(this.currentKey));
        }

        /**
         * Reads the group the cursor is on.
         *
         * @return {@code Just(datum)} for a kept group, {@code Nothing} for a group dropped by
         *         negative sub-sampling, or {@code null} once the rows are exhausted
         *         (NOTE(review): presumably {@code iteratorFromMaybeFactory} treats {@code null}
         *         as end-of-stream; confirm against CollectionUtils)
         */
        @Override
        public Maybe<KBPDatum> create() {
            try {
                if (isDone) {
                    return null;
                }
                // Rows are sorted with true relations first, so the first row of a group tells
                // whether the pair has any positive label.
                final boolean truth = rows.getBoolean("truth");
                boolean keep = truth;
                if (!truth) {
                    int targetNegatives = (int) (numPositives * Props.TRAIN_NEGATIVES_RATIO);
                    if (numNegatives < targetNegatives) {
                        // The further below the target ratio, the likelier a negative is kept.
                        keep = random.nextDouble() >= Math.pow(0.75d, (double) (targetNegatives - numNegatives));
                    }
                }
                if ((numPositives + numNegatives) % 100000 == 0) {
                    logger.log(new Object[]{"read " + (numPositives + numNegatives) + " sentence groups; [" + numPositives + " pos + " + numNegatives + " neg]; " + Utils.getMemoryUsage()});
                }
                if (keep) {
                    if (truth) {
                        numPositives++;
                    } else {
                        numNegatives++;
                    }
                    didsSeen.add(rows.getLong("did"));
                    currentDatum.group.merge(mkDatum(rows));
                    currentDatum.registerRelation(truth, rows.getString("relation_name"));
                    while (rows.next()) {
                        KBPair key = mkKey(rows);
                        if (!key.equals(currentKey)) {
                            return Maybe.Just(startGroup(key));
                        }
                        // Merge each datum once, but register every relation row.
                        if (didsSeen.add(rows.getLong("did"))) {
                            currentDatum.group.merge(mkDatum(rows));
                        }
                        currentDatum.registerRelation(rows.getBoolean("truth"), rows.getString("relation_name"));
                    }
                } else {
                    // Dropped group: skip ahead to the first row of the next key.
                    while (rows.next()) {
                        KBPair key = mkKey(rows);
                        if (!key.equals(currentKey)) {
                            startGroup(key);
                            return Maybe.Nothing();
                        }
                    }
                }
                isDone = true;
                return currentDatum.group.isEmpty() ? Maybe.Nothing() : Maybe.Just(currentDatum);
            } catch (SQLException e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Starts a fresh, empty datum for {@code key} and returns the one just finished, so that
         * sentences and labels of different pairs never accumulate in the same object.
         */
        private KBPDatum startGroup(KBPair key) {
            KBPDatum finished = currentDatum;
            currentKey = key;
            currentDatum = new KBPDatum(SentenceGroup.empty(key));
            didsSeen.clear();
            return finished;
        }
    }

    /**
     * Trains a relation classifier on the dataset built from table {@code str}, then persists
     * the model and the training statistics.
     *
     * <p>The model is saved to {@code Props.KBP_MODEL_PATH}; a failure there is logged as an
     * error and then passed to the logger's {@code fatal} channel. The statistics are written
     * to {@code train_statistics.ser.gz} under {@code Props.WORK_DIR}; a failure there is only
     * logged.
     *
     * @param str   name of the datum table, as accepted by {@link #mkDataset(String)}
     * @param kbpir not used by this method
     * @return the trained classifier paired with its training statistics
     */
    public Pair<RelationClassifier, TrainingStatistics> train(String str, KBPIR kbpir) {
        Redwood.Util.forceTrack("Training");
        final KBPDataset<String, String> dataset = mkDataset(str);
        final Pair<RelationClassifier, TrainingStatistics> result = this.trainer.trainOnData(dataset);

        // Persist the classifier itself.
        try {
            logger.log(new Object[]{Redwood.Util.BOLD, Redwood.Util.BLUE, "saving model to " + Props.KBP_MODEL_PATH});
            final RelationClassifier classifier = result.first;
            classifier.save(Props.KBP_MODEL_PATH);
        } catch (IOException saveFailure) {
            logger.err(new Object[]{"Could not save model."});
            logger.fatal(new Object[]{saveFailure});
        }

        // Persist the statistics gathered while training.
        try {
            final String statisticsPath = Props.WORK_DIR.getPath() + File.separator + "train_statistics.ser.gz";
            IOUtils.writeObjectToFile(result.second, statisticsPath);
        } catch (IOException statisticsFailure) {
            logger.err(new Object[]{statisticsFailure});
        }

        Redwood.Util.endTrack("Training");
        return result;
    }
}
