/*
 * Decompiled with CFR 0.152.
 */
package io.cdap.mmds.plugin;

import com.google.common.base.Joiner;
import com.google.common.collect.Sets;
import io.cdap.cdap.api.annotation.Description;
import io.cdap.cdap.api.annotation.Name;
import io.cdap.cdap.api.annotation.Plugin;
import io.cdap.cdap.api.data.format.StructuredRecord;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.cdap.api.dataset.lib.FileSet;
import io.cdap.cdap.api.dataset.lib.IndexedTable;
import io.cdap.cdap.api.spark.sql.DataFrames;
import io.cdap.cdap.etl.api.PipelineConfigurer;
import io.cdap.cdap.etl.api.StageConfigurer;
import io.cdap.cdap.etl.api.batch.SparkCompute;
import io.cdap.cdap.etl.api.batch.SparkExecutionPluginContext;
import io.cdap.mmds.api.AlgorithmType;
import io.cdap.mmds.api.Modeler;
import io.cdap.mmds.data.ModelKey;
import io.cdap.mmds.data.ModelMeta;
import io.cdap.mmds.data.ModelTable;
import io.cdap.mmds.modeler.Modelers;
import io.cdap.mmds.modeler.feature.FeatureGenerator;
import io.cdap.mmds.modeler.feature.FeatureGeneratorPredictor;
import io.cdap.mmds.plugin.PredictorConf;
import io.cdap.mmds.plugin.RecordToRow;
import io.cdap.mmds.plugin.RowToRecord;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.StructType;

@Plugin(type="sparkcompute")
@Name(value="MLPredictor")
@Description(value="Uses a deployed model to add a prediction field to incoming records.")
public class MLPredictor
extends SparkCompute<StructuredRecord, StructuredRecord> {
    private final PredictorConf conf;
    private String featuregenPath;
    private String modelPath;
    private String targetIndexPath;
    private Schema inputSchema;
    private Schema outputSchema;
    private Schema.Type predictionType;
    private FeatureGenerator featureGenerator;
    private Modeler modeler;

    public MLPredictor(PredictorConf conf) {
        this.conf = conf;
    }

    public void configurePipeline(PipelineConfigurer pipelineConfigurer) throws IllegalArgumentException {
        StageConfigurer stageConfigurer = pipelineConfigurer.getStageConfigurer();
        Schema inputSchema = stageConfigurer.getInputSchema();
        if (inputSchema == null) {
            throw new IllegalArgumentException("ML Predictor cannot be used with a null input schema. Please connect it to stages that have a set output schema.");
        }
        this.conf.validate(inputSchema);
        stageConfigurer.setOutputSchema(this.conf.getOutputSchema());
    }

    public void initialize(SparkExecutionPluginContext context) throws Exception {
        this.inputSchema = context.getInputSchema();
        this.conf.validate(this.inputSchema);
        this.outputSchema = this.conf.getOutputSchema();
        Schema predictionSchema = this.outputSchema.getField(this.conf.getPredictionField()).getSchema();
        predictionSchema = predictionSchema.isNullable() ? predictionSchema.getNonNullable() : predictionSchema;
        this.predictionType = predictionSchema.getType();
        IndexedTable modelTable = (IndexedTable)context.getDataset("experiment_model_meta");
        ModelTable modelMetaTable = new ModelTable(modelTable);
        ModelKey key = new ModelKey(this.conf.getExperimentID(), this.conf.getModelID());
        ModelMeta meta = modelMetaTable.get(key);
        if (meta == null) {
            throw new IllegalArgumentException(String.format("Could not find model '%s' in experiment '%s'.", this.conf.getModelID(), this.conf.getExperimentID()));
        }
        this.modeler = Modelers.getModeler((String)meta.getAlgorithm());
        if (this.modeler == null) {
            throw new IllegalArgumentException(String.format("Model '%s' in experiment '%s' uses unknown algorithm '%s'", this.conf.getModelID(), this.conf.getExperimentID(), meta.getAlgorithm()));
        }
        if (this.modeler.getAlgorithm().getType() == AlgorithmType.REGRESSION && this.predictionType == Schema.Type.STRING) {
            throw new IllegalArgumentException(String.format("Invalid getType for prediction field '%s'. Model '%s' in experiment '%s' is a regression model, which only supports double predictions.", this.conf.getPredictionField(), this.conf.getModelID(), this.conf.getExperimentID()));
        }
        HashSet featureSet = new HashSet(meta.getFeatures());
        HashSet<String> inputFields = new HashSet<String>();
        for (Schema.Field field : this.inputSchema.getFields()) {
            inputFields.add(field.getName());
        }
        Sets.SetView missingFeatures = Sets.difference(featureSet, inputFields);
        if (!missingFeatures.isEmpty()) {
            throw new IllegalArgumentException(String.format("Input is missing feature fields %s.", Joiner.on((char)',').join((Iterable)missingFeatures)));
        }
        FileSet modelFiles = (FileSet)context.getDataset("experiment_model_components");
        this.featuregenPath = this.getComponentPath(modelFiles, "featuregen");
        if (this.featuregenPath == null) {
            throw new IllegalArgumentException(String.format("Could not find feature generation data for model '%s' in experiment '%s'. Please verify that the same model and model meta datasets used to train the model are used here.", this.conf.getModelID(), this.conf.getExperimentID()));
        }
        this.featureGenerator = new FeatureGeneratorPredictor(meta.getFeatures(), meta.getCategoricalFeatures(), this.featuregenPath);
        this.modelPath = this.getComponentPath(modelFiles, "model");
        if (this.modelPath == null) {
            throw new IllegalArgumentException(String.format("Could not find the files for model '%s' in experiment '%s'. Please verify that the model was successfully trained.", this.conf.getModelID(), this.conf.getExperimentID()));
        }
        this.targetIndexPath = this.getComponentPath(modelFiles, "targetindices");
        if (this.targetIndexPath == null && this.modeler.getAlgorithm().getType() == AlgorithmType.CLASSIFICATION && this.predictionType == Schema.Type.STRING) {
            throw new IllegalArgumentException(String.format("Could not find target index data for model '%s' in experiment '%s'. Please change the prediction field type to double.", this.conf.getModelID(), this.conf.getExperimentID()));
        }
    }

    public JavaRDD<StructuredRecord> transform(SparkExecutionPluginContext sparkExecutionPluginContext, JavaRDD<StructuredRecord> javaRDD) throws Exception {
        PredictionModel model = this.modeler.loadPredictor(this.modelPath);
        StructType rowType = (StructType)DataFrames.toDataType((Schema)this.inputSchema);
        JavaRDD rowRDD = javaRDD.map((Function)new RecordToRow(rowType));
        SQLContext sqlContext = new SQLContext(sparkExecutionPluginContext.getSparkContext().sc());
        Dataset rawData = sqlContext.createDataFrame(rowRDD, rowType);
        HashSet featureSet = new HashSet(this.featureGenerator.getFeatures());
        ArrayList<String> extraFields = new ArrayList<String>();
        for (Schema.Field outputField : this.outputSchema.getFields()) {
            String outputFieldName = outputField.getName();
            if (this.conf.getPredictionField().equals(outputFieldName) || featureSet.contains(outputFieldName)) continue;
            extraFields.add(outputFieldName);
        }
        Dataset featureData = this.featureGenerator.generateFeatures(rawData, extraFields);
        Dataset predictions = model.transform(featureData);
        if (this.modeler.getAlgorithm().getType() == AlgorithmType.CLASSIFICATION && this.predictionType == Schema.Type.STRING) {
            StringIndexerModel indexerModel = StringIndexerModel.load((String)this.targetIndexPath);
            String[] labels = indexerModel.labels();
            IndexToString reverseIndex = new IndexToString().setLabels(labels).setInputCol("_prediction").setOutputCol(this.conf.getPredictionField());
            predictions = reverseIndex.transform(predictions);
        } else {
            predictions = predictions.withColumnRenamed("_prediction", this.conf.getPredictionField());
        }
        Column[] cols = new Column[this.outputSchema.getFields().size()];
        int i = 0;
        for (Schema.Field outputField : this.outputSchema.getFields()) {
            cols[i] = new Column(outputField.getName());
            ++i;
        }
        predictions = predictions.select(cols);
        JavaRDD output = predictions.toJavaRDD().map((Function)new RowToRecord(this.outputSchema));
        return output;
    }

    @Nullable
    private String getComponentPath(FileSet modelFiles, String component) throws IOException {
        return modelFiles.getLocation(this.conf.getExperimentID()).append(this.conf.getModelID()).append(component).toURI().getPath();
    }
}

