package eus.ixa.ixa.pipe.pos.eval;

import eus.ixa.ixa.pipe.pos.train.BaselineFactory;
import eus.ixa.ixa.pipe.pos.train.Flags;
import eus.ixa.ixa.pipe.pos.train.InputOutputUtils;
import java.io.File;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import opennlp.tools.cmdline.postag.POSEvaluationErrorListener;
import opennlp.tools.cmdline.postag.POSTaggerFineGrainedReportListener;
import opennlp.tools.postag.POSSample;
import opennlp.tools.postag.POSTaggerCrossValidator;
import opennlp.tools.postag.POSTaggerEvaluationMonitor;
import opennlp.tools.postag.POSTaggerFactory;
import opennlp.tools.postag.WordTagSampleStream;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.TrainingParameters;
import opennlp.tools.util.eval.EvaluationMonitor;

/* loaded from: input_file:eus/ixa/ixa/pipe/pos/eval/CrossValidator.class */
/**
 * Runs n-fold cross validation of a POS tagger over the training set named in the
 * {@link TrainingParameters}, printing the resulting word accuracy to standard output.
 *
 * <p>The language, training set, feature set ("Opennlp" or the baseline factory), tag
 * dictionary, automatic dictionary cutoff, number of folds and optional evaluation
 * listeners ("error" or "detailed") are all read from the training parameters.
 */
public class CrossValidator {
    /** Dictionary-features value for which no tag dictionary file is passed to the validator. */
    private static final String DICT_OFF = "off";
    /** Auto-dictionary cutoff value for which no cutoff is passed to the validator. */
    private static final int NO_DICT_CUTOFF = -1;
    /** Settings key selecting the optional evaluation listener ("error" or "detailed"). */
    private static final String EVALUATION_TYPE = "EvaluationType";

    private final String lang;
    private final ObjectStream<POSSample> trainSamples;
    private final int dictCutOff;
    private final int folds;
    private POSTaggerFactory posTaggerFactory;
    private final List<EvaluationMonitor<POSSample>> listeners =
            new LinkedList<EvaluationMonitor<POSSample>>();
    /** Set only when the "detailed" evaluation type is requested; otherwise {@code null}. */
    POSTaggerFineGrainedReportListener detailedListener;

    /**
     * Reads every cross-validation setting from the given parameters and opens the training set.
     *
     * @param trainingParameters the training/evaluation configuration
     * @throws IOException if reading the training set fails
     */
    public CrossValidator(TrainingParameters trainingParameters) throws IOException {
        this.lang = Flags.getLanguage(trainingParameters);
        this.trainSamples = new WordTagSampleStream(
                InputOutputUtils.readFileIntoMarkableStreamFactory(
                        Flags.getDataSet("TrainSet", trainingParameters)));
        this.dictCutOff = Flags.getAutoDictFeatures(trainingParameters).intValue();
        this.folds = Flags.getFolds(trainingParameters).intValue();
        createPOSFactory(trainingParameters);
        getEvalListeners(trainingParameters);
    }

    /**
     * Selects the plain OpenNLP factory when the feature set is "Opennlp" (case-insensitive),
     * and the project's {@link BaselineFactory} otherwise.
     */
    private void createPOSFactory(TrainingParameters trainingParameters) {
        if (Flags.getFeatureSet(trainingParameters).equalsIgnoreCase("Opennlp")) {
            this.posTaggerFactory = new POSTaggerFactory();
        } else {
            this.posTaggerFactory = new BaselineFactory();
        }
    }

    /**
     * Registers the evaluation listener requested by the "EvaluationType" setting.
     * A missing or unrecognised value registers no listener.
     */
    private void getEvalListeners(TrainingParameters trainingParameters) {
        // Read once; the literal-first comparisons are null-safe when the key is absent.
        String evaluationType = trainingParameters.getSettings().get(EVALUATION_TYPE);
        if ("error".equalsIgnoreCase(evaluationType)) {
            this.listeners.add(new POSEvaluationErrorListener());
        } else if ("detailed".equalsIgnoreCase(evaluationType)) {
            this.detailedListener = new POSTaggerFineGrainedReportListener();
            this.listeners.add(this.detailedListener);
        }
    }

    /**
     * Cross-validates over the training samples, prints the word accuracy and, when the
     * "detailed" evaluation type was requested, writes the fine-grained report.
     *
     * <p>The training sample stream is closed afterwards, so an instance can be used for
     * only one run. An I/O error during evaluation terminates the JVM with status 1.
     *
     * @param trainingParameters the parameters used to build the cross validator
     * @throws IllegalStateException if no POS tagger factory has been created
     */
    public final void crossValidate(TrainingParameters trainingParameters) {
        POSTaggerCrossValidator validator = null;
        try {
            validator = getPOSTaggerCrossValidator(trainingParameters);
            validator.evaluate(this.trainSamples, this.folds);
        } catch (IOException e) {
            System.err.println("IO error while loading training set!");
            e.printStackTrace();
            System.exit(1);
        } finally {
            closeTrainSamples();
        }
        System.out.println(validator.getWordAccuracy());
        // Previously both branches only printed the accuracy, so the collected
        // fine-grained statistics were never reported.
        if (this.detailedListener != null) {
            this.detailedListener.writeReport();
        }
    }

    /** Closes the training sample stream, reporting (but not propagating) I/O failures. */
    private void closeTrainSamples() {
        try {
            this.trainSamples.close();
        } catch (IOException e) {
            System.err.println("IO error with the train samples!");
        }
    }

    /**
     * Builds the OpenNLP cross validator.
     *
     * <p>The tag dictionary file is omitted (passed as {@code null}) when its name is
     * "off". The cutoff is omitted when it equals -1.
     *
     * @throws IllegalStateException if no POS tagger factory has been created
     */
    private POSTaggerCrossValidator getPOSTaggerCrossValidator(TrainingParameters trainingParameters) {
        File dictPath = new File(Flags.getDictionaryFeatures(trainingParameters));
        if (this.posTaggerFactory == null) {
            throw new IllegalStateException("You must create the POSTaggerFactory features!");
        }
        // Keep the static types File/Integer so the same constructor overload is resolved
        // as with the original per-case calls.
        File tagDictionary = DICT_OFF.equals(dictPath.getName()) ? null : dictPath;
        Integer tagDictCutoff =
                this.dictCutOff == NO_DICT_CUTOFF ? null : Integer.valueOf(this.dictCutOff);
        POSTaggerEvaluationMonitor[] monitors =
                this.listeners.toArray(new POSTaggerEvaluationMonitor[this.listeners.size()]);
        return new POSTaggerCrossValidator(this.lang, trainingParameters, tagDictionary, null,
                tagDictCutoff, this.posTaggerFactory.getClass().getName(), monitors);
    }
}
