package eus.ixa.ixa.pipe.ml;

import eus.ixa.ixa.pipe.ml.features.XMLFeatureDescriptor;
import eus.ixa.ixa.pipe.ml.parse.AncoraHeadRules;
import eus.ixa.ixa.pipe.ml.parse.HeadRules;
import eus.ixa.ixa.pipe.ml.parse.Parse;
import eus.ixa.ixa.pipe.ml.parse.ParseSampleStream;
import eus.ixa.ixa.pipe.ml.parse.ParserEvaluationMonitor;
import eus.ixa.ixa.pipe.ml.parse.ParserEvaluator;
import eus.ixa.ixa.pipe.ml.parse.ParserFactory;
import eus.ixa.ixa.pipe.ml.parse.ParserModel;
import eus.ixa.ixa.pipe.ml.parse.PennTreebankHeadRules;
import eus.ixa.ixa.pipe.ml.parse.ShiftReduceParser;
import eus.ixa.ixa.pipe.ml.resources.LoadModelResources;
import eus.ixa.ixa.pipe.ml.sequence.BilouCodec;
import eus.ixa.ixa.pipe.ml.sequence.BioCodec;
import eus.ixa.ixa.pipe.ml.sequence.SequenceLabelerCodec;
import eus.ixa.ixa.pipe.ml.sequence.SequenceLabelerFactory;
import eus.ixa.ixa.pipe.ml.sequence.SequenceLabelerModel;
import eus.ixa.ixa.pipe.ml.utils.Flags;
import eus.ixa.ixa.pipe.ml.utils.IOUtils;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import opennlp.tools.cmdline.TerminateToolException;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.TrainingParameters;
import opennlp.tools.util.model.ArtifactSerializer;

/* loaded from: input_file:eus/ixa/ixa/pipe/ml/ShiftReduceParserTrainer.class */
/**
 * Trains and evaluates a {@link ShiftReduceParser} model.
 *
 * <p>An instance is configured from {@link TrainingParameters}: the language, the
 * {@code "TrainSet"} and {@code "TestSet"} settings (paths handed to
 * {@link #getParseStream(String)}), and the head rules returned by
 * {@link #getHeadRules(TrainingParameters)}. The {@code train} methods train a model on the
 * training stream, evaluate it on the test stream and print the F-measure to standard output.
 *
 * <p>NOTE(review): both constructors build the sequence labeler factories before
 * {@link #setSequenceCodec(String)} can possibly have been called, so
 * {@link #getSequenceCodec()} returns {@code null} at that point. Confirm that
 * {@code SequenceLabelerFactory.instantiateSequenceCodec} copes with a {@code null} name.
 */
public class ShiftReduceParserTrainer {
    /** Language as returned by {@code Flags.getLanguage} for the parser parameters. */
    private final String lang;
    /** Value of the {@code "TrainSet"} setting. */
    private final String trainData;
    /** Value of the {@code "TestSet"} setting. */
    private final String testData;
    /** Parse samples read from {@link #trainData}. */
    private final ObjectStream<Parse> trainSamples;
    /** Parse samples read from {@link #testData}. */
    private final ObjectStream<Parse> testSamples;
    /** Head rules loaded for {@link #lang}. */
    private final HeadRules rules;
    private ParserFactory parserFactory;
    /** Short codec name; {@code "BIO"} or {@code Flags.DEFAULT_SEQUENCE_CODEC} are recognised. */
    private String sequenceCodec;
    private SequenceLabelerFactory taggerFactory;
    private SequenceLabelerFactory chunkerFactory;

    /**
     * Creates a trainer that trains the parser together with its tagger and chunker.
     *
     * @param trainingParameters parser parameters (language, train/test sets, head rules file)
     * @param trainingParameters2 parameters used to build the tagger factory
     * @param trainingParameters3 parameters used to build the chunker factory
     * @throws IOException if the data sets, head rules or model resources cannot be read
     */
    public ShiftReduceParserTrainer(TrainingParameters trainingParameters, TrainingParameters trainingParameters2, TrainingParameters trainingParameters3) throws IOException {
        this.lang = Flags.getLanguage(trainingParameters);
        this.trainData = (String) trainingParameters.getSettings().get("TrainSet");
        this.testData = (String) trainingParameters.getSettings().get("TestSet");
        this.trainSamples = getParseStream(this.trainData);
        this.testSamples = getParseStream(this.testData);
        this.rules = getHeadRules(trainingParameters);
        createParserFactory(trainingParameters);
        setTaggerFactory(createSequenceLabelerFactory(trainingParameters2));
        setChunkerFactory(createSequenceLabelerFactory(trainingParameters3));
    }

    /**
     * Creates a trainer for use with an already trained tagger model; only the parser and
     * chunker factories are built, the tagger factory stays {@code null}.
     *
     * @param trainingParameters parser parameters (language, train/test sets, head rules file)
     * @param trainingParameters2 parameters used to build the chunker factory
     * @throws IOException if the data sets, head rules or model resources cannot be read
     */
    public ShiftReduceParserTrainer(TrainingParameters trainingParameters, TrainingParameters trainingParameters2) throws IOException {
        this.lang = Flags.getLanguage(trainingParameters);
        this.trainData = (String) trainingParameters.getSettings().get("TrainSet");
        this.testData = (String) trainingParameters.getSettings().get("TestSet");
        this.trainSamples = getParseStream(this.trainData);
        this.testSamples = getParseStream(this.testData);
        this.rules = getHeadRules(trainingParameters);
        createParserFactory(trainingParameters);
        setChunkerFactory(createSequenceLabelerFactory(trainingParameters2));
    }

    /**
     * Builds the {@link ParserFactory} from a dictionary computed over the training samples and
     * the parse resources named in the parameters, and stores it in this trainer.
     *
     * @param trainingParameters the parser parameters
     * @throws IOException if building the dictionary or loading the resources fails
     */
    public void createParserFactory(TrainingParameters trainingParameters) throws IOException {
        setParserFactory(ParserFactory.create(
                ParserFactory.class.getName(),
                ShiftReduceParser.buildDictionary(this.trainSamples, this.rules, trainingParameters),
                LoadModelResources.loadParseResources(trainingParameters)));
    }

    /**
     * Builds a {@link SequenceLabelerFactory} from the XML feature descriptor generated for the
     * given parameters. The descriptor is echoed to standard error.
     *
     * @param trainingParameters the tagger or chunker parameters
     * @return the newly created factory
     * @throws IOException if the descriptor or the sequence resources cannot be loaded
     */
    public SequenceLabelerFactory createSequenceLabelerFactory(TrainingParameters trainingParameters) throws IOException {
        SequenceLabelerCodec<String> codec = SequenceLabelerFactory.instantiateSequenceCodec(getSequenceCodec());
        String featureDescriptor = XMLFeatureDescriptor.createXMLFeatureDescriptor(trainingParameters);
        System.err.println(featureDescriptor);
        byte[] descriptorBytes = featureDescriptor.getBytes(Charset.forName("UTF-8"));
        return SequenceLabelerFactory.create(
                SequenceLabelerFactory.class.getName(),
                descriptorBytes,
                LoadModelResources.loadSequenceResources(trainingParameters),
                codec);
    }

    /**
     * Trains the parser together with its tagger and chunker, evaluates the result on the test
     * samples and prints the F-measure to standard output.
     *
     * <p>On an {@link IOException} the error is printed and the JVM exits with status 1.
     *
     * @param trainingParameters the parser parameters
     * @param trainingParameters2 the tagger parameters
     * @param trainingParameters3 the chunker parameters
     * @return the trained parser model
     * @throws IllegalStateException if the parser factory or the tagger factory is not set
     */
    public final ParserModel train(TrainingParameters trainingParameters, TrainingParameters trainingParameters2, TrainingParameters trainingParameters3) {
        if (getParserFactory() == null) {
            throw new IllegalStateException("The ParserFactory must be instantiated!!");
        }
        if (getTaggerFactory() == null) {
            throw new IllegalStateException("The TaggerFactory must be instantiated!");
        }
        ParserModel parserModel = null;
        try {
            parserModel = ShiftReduceParser.train(this.lang, this.trainSamples, this.rules, trainingParameters, this.parserFactory, trainingParameters2, this.taggerFactory, trainingParameters3, this.chunkerFactory);
            evaluateAndReport(parserModel);
        } catch (IOException e) {
            exitOnIoError("IO error while loading training and test sets!", e);
        }
        return parserModel;
    }

    /**
     * Trains the parser and its chunker using an already trained tagger model read from the
     * given stream, evaluates the result on the test samples and prints the F-measure to
     * standard output.
     *
     * <p>On an {@link IOException}, whether raised while reading the tagger model or during
     * training/evaluation, the error is printed and the JVM exits with status 1.
     *
     * @param trainingParameters the parser parameters
     * @param inputStream stream from which the tagger model is deserialized
     * @param trainingParameters2 the chunker parameters
     * @return the trained parser model
     * @throws IllegalStateException if the parser factory is not set
     */
    public final ParserModel train(TrainingParameters trainingParameters, InputStream inputStream, TrainingParameters trainingParameters2) {
        if (getParserFactory() == null) {
            throw new IllegalStateException("The ParserFactory must be instantiated!!");
        }
        SequenceLabelerModel taggerModel = null;
        try {
            taggerModel = new SequenceLabelerModel(inputStream);
        } catch (IOException e) {
            // Fail fast: continuing would hand a null tagger model to the parser trainer.
            exitOnIoError("IO error while loading the tagger model!", e);
        }
        ParserModel parserModel = null;
        try {
            parserModel = ShiftReduceParser.train(this.lang, this.trainSamples, this.rules, trainingParameters, this.parserFactory, taggerModel, trainingParameters2, this.chunkerFactory);
            evaluateAndReport(parserModel);
        } catch (IOException e2) {
            exitOnIoError("IO error while loading training and test sets!", e2);
        }
        return parserModel;
    }

    /** Evaluates the model on the test samples and prints the resulting F-measure. */
    private void evaluateAndReport(ParserModel parserModel) throws IOException {
        ParserEvaluator evaluator = new ParserEvaluator(new ShiftReduceParser(parserModel), new ParserEvaluationMonitor[0]);
        evaluator.evaluate(this.testSamples);
        System.out.println("Final Result: \n" + evaluator.getFMeasure());
    }

    /** Prints the message and stack trace to standard error, then exits the JVM with status 1. */
    private static void exitOnIoError(String message, IOException cause) {
        System.err.println(message);
        cause.printStackTrace();
        System.exit(1);
    }

    /**
     * Opens a stream of {@link Parse} samples over the given file.
     *
     * @param str path of the treebank file
     * @return a {@link ParseSampleStream} reading from that file
     * @throws IOException if the file cannot be read
     */
    public static ObjectStream<Parse> getParseStream(String str) throws IOException {
        return new ParseSampleStream(IOUtils.readFileIntoMarkableStreamFactory(str));
    }

    /**
     * Loads the head rules for the language configured in the parameters. English
     * ({@code "en"}) uses the Penn Treebank serializer and Spanish ({@code "es"}) the Ancora
     * serializer; the language comparison ignores case.
     *
     * @param trainingParameters parameters providing the language and the head rules file
     * @return the head rules read from the configured file
     * @throws IOException if the head rules file cannot be opened or read
     * @throws TerminateToolException if the language is not supported, or if the serializer
     *         produces an object that is not a {@link HeadRules}
     */
    public static HeadRules getHeadRules(TrainingParameters trainingParameters) throws IOException {
        final String language = Flags.getLanguage(trainingParameters);
        final ArtifactSerializer<?> serializer;
        if (language.equalsIgnoreCase("en")) {
            serializer = new PennTreebankHeadRules.PennTreebankHeadRulesSerializer();
        } else if (language.equalsIgnoreCase("es")) {
            serializer = new AncoraHeadRules.AncoraHeadRulesSerializer();
        } else {
            // Previously this only logged and then hit a NullPointerException on the serializer.
            throw new TerminateToolException(-1, "HeadRules not suported for language " + language + "!!");
        }
        final Object created;
        // try-with-resources so the file handle is released even if deserialization fails.
        try (InputStream headRulesStream = new FileInputStream(Flags.getHeadRulesFile(trainingParameters))) {
            created = serializer.create(headRulesStream);
        }
        if (created instanceof HeadRules) {
            return (HeadRules) created;
        }
        throw new TerminateToolException(-1, "HeadRules Artifact Serializer must create an object of type HeadRules!");
    }

    /** @return the tagger factory, or {@code null} if none has been set */
    public final SequenceLabelerFactory getTaggerFactory() {
        return this.taggerFactory;
    }

    /**
     * Sets the tagger factory.
     *
     * @param sequenceLabelerFactory the factory to store
     * @return the factory now stored in this trainer
     */
    public final SequenceLabelerFactory setTaggerFactory(SequenceLabelerFactory sequenceLabelerFactory) {
        this.taggerFactory = sequenceLabelerFactory;
        return this.taggerFactory;
    }

    /** @return the chunker factory, or {@code null} if none has been set */
    public final SequenceLabelerFactory getChunkerFactory() {
        return this.chunkerFactory;
    }

    /**
     * Sets the chunker factory.
     *
     * @param sequenceLabelerFactory the factory to store
     * @return the factory now stored in this trainer
     */
    public final SequenceLabelerFactory setChunkerFactory(SequenceLabelerFactory sequenceLabelerFactory) {
        this.chunkerFactory = sequenceLabelerFactory;
        return this.chunkerFactory;
    }

    /** @return the parser factory, or {@code null} if none has been set */
    public final ParserFactory getParserFactory() {
        return this.parserFactory;
    }

    /**
     * Sets the parser factory.
     *
     * @param parserFactory the factory to store
     * @return the factory that was passed in
     */
    public final ParserFactory setParserFactory(ParserFactory parserFactory) {
        this.parserFactory = parserFactory;
        return parserFactory;
    }

    /**
     * Maps the configured codec name to a codec class name.
     *
     * @return the class name of {@link BioCodec} for {@code "BIO"}, the class name of
     *         {@link BilouCodec} for {@code Flags.DEFAULT_SEQUENCE_CODEC}, and {@code null}
     *         for any other value, including an unset codec
     */
    public final String getSequenceCodec() {
        if ("BIO".equals(this.sequenceCodec)) {
            return BioCodec.class.getName();
        }
        if (Flags.DEFAULT_SEQUENCE_CODEC.equals(this.sequenceCodec)) {
            return BilouCodec.class.getName();
        }
        return null;
    }

    /**
     * Sets the short codec name consulted by {@link #getSequenceCodec()}.
     *
     * @param str the codec name, e.g. {@code "BIO"}
     */
    public final void setSequenceCodec(String str) {
        this.sequenceCodec = str;
    }
}
