/*
 * Decompiled with CFR 0.152.
 */
package io.mxnet.caffetranslator;

import io.mxnet.caffetranslator.CaffePrototxtLexer;
import io.mxnet.caffetranslator.CaffePrototxtParser;
import io.mxnet.caffetranslator.CreateModelListener;
import io.mxnet.caffetranslator.GenerationHelper;
import io.mxnet.caffetranslator.GeneratorOutput;
import io.mxnet.caffetranslator.Layer;
import io.mxnet.caffetranslator.MLModel;
import io.mxnet.caffetranslator.Optimizer;
import io.mxnet.caffetranslator.Solver;
import io.mxnet.caffetranslator.SymbolGenerator;
import io.mxnet.caffetranslator.SymbolGeneratorFactory;
import io.mxnet.caffetranslator.Utils;
import io.mxnet.caffetranslator.generators.AccuracyMetricsGenerator;
import io.mxnet.caffetranslator.generators.BatchNormGenerator;
import io.mxnet.caffetranslator.generators.ConcatGenerator;
import io.mxnet.caffetranslator.generators.ConvolutionGenerator;
import io.mxnet.caffetranslator.generators.DeconvolutionGenerator;
import io.mxnet.caffetranslator.generators.DropoutGenerator;
import io.mxnet.caffetranslator.generators.EltwiseGenerator;
import io.mxnet.caffetranslator.generators.FCGenerator;
import io.mxnet.caffetranslator.generators.FlattenGenerator;
import io.mxnet.caffetranslator.generators.PermuteGenerator;
import io.mxnet.caffetranslator.generators.PluginIntLayerGenerator;
import io.mxnet.caffetranslator.generators.PluginLossGenerator;
import io.mxnet.caffetranslator.generators.PoolingGenerator;
import io.mxnet.caffetranslator.generators.PowerGenerator;
import io.mxnet.caffetranslator.generators.ReluGenerator;
import io.mxnet.caffetranslator.generators.ScaleGenerator;
import io.mxnet.caffetranslator.generators.SoftmaxOutputGenerator;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import org.antlr.v4.runtime.CharStream;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.stringtemplate.v4.ST;
import org.stringtemplate.v4.STGroup;
import org.stringtemplate.v4.STRawGroupDir;

/**
 * Translates a Caffe training prototxt plus its solver prototxt into MXNet code.
 *
 * <p>Usage: construct with the two prototxt paths, optionally call
 * {@link #setParamsFilePath(String)}, then call {@link #generateMXNetCode()}.
 * The parsed model and solver are kept in mutable fields with no synchronization.
 */
public class Converter {
    private final String trainPrototxt;
    private final String solverPrototxt;
    private final MLModel mlModel;
    private final STGroup stGroup;
    private final SymbolGeneratorFactory generators;
    private final String NL;
    private final GenerationHelper gh;
    /** Optional parameter file; when non-null a params loader is emitted. */
    private String paramsFilePath;
    /** Populated by {@link #parseSolverPrototxt()}; read by the generate* methods. */
    private Solver solver;

    /**
     * Creates a converter and registers the built-in layer generators.
     *
     * @param trainPrototxt  path to the Caffe network (training) prototxt
     * @param solverPrototxt path to the Caffe solver prototxt
     */
    Converter(String trainPrototxt, String solverPrototxt) {
        this.trainPrototxt = trainPrototxt;
        this.solverPrototxt = solverPrototxt;
        this.mlModel = new MLModel();
        this.stGroup = new STRawGroupDir("templates");
        this.generators = SymbolGeneratorFactory.getInstance();
        this.NL = System.lineSeparator();
        this.gh = new GenerationHelper();
        this.addGenerators();
    }

    /**
     * Registers one generator per natively supported Caffe layer type.
     *
     * <p>The two plugin generators are registered under the keys
     * {@code "PluginIntLayerGenerator"} and {@code "CaffePluginLossLayer"}; they are
     * used as fallbacks by {@link #resolveGenerator(String)} for unknown layer types.
     */
    private void addGenerators() {
        this.generators.addGenerator("Convolution", new ConvolutionGenerator());
        this.generators.addGenerator("Deconvolution", new DeconvolutionGenerator());
        this.generators.addGenerator("Pooling", new PoolingGenerator());
        this.generators.addGenerator("InnerProduct", new FCGenerator());
        this.generators.addGenerator("ReLU", new ReluGenerator());
        this.generators.addGenerator("SoftmaxWithLoss", new SoftmaxOutputGenerator());
        this.generators.addGenerator("PluginIntLayerGenerator", new PluginIntLayerGenerator());
        this.generators.addGenerator("CaffePluginLossLayer", new PluginLossGenerator());
        this.generators.addGenerator("Permute", new PermuteGenerator());
        this.generators.addGenerator("Concat", new ConcatGenerator());
        this.generators.addGenerator("BatchNorm", new BatchNormGenerator());
        this.generators.addGenerator("Power", new PowerGenerator());
        this.generators.addGenerator("Eltwise", new EltwiseGenerator());
        this.generators.addGenerator("Flatten", new FlattenGenerator());
        this.generators.addGenerator("Dropout", new DropoutGenerator());
        this.generators.addGenerator("Scale", new ScaleGenerator());
    }

    /**
     * Parses the training prototxt (read as UTF-8) and feeds it into {@link #mlModel}
     * through a {@link CreateModelListener}.
     *
     * @return {@code false} if the file could not be read; otherwise {@code true}.
     *         The return value does not reflect parser syntax errors.
     */
    public boolean parseTrainingPrototxt() {
        CharStream cs;
        // try-with-resources: the original leaked the FileInputStream.
        try (FileInputStream fis = new FileInputStream(new File(this.trainPrototxt))) {
            cs = CharStreams.fromStream(fis, StandardCharsets.UTF_8);
        } catch (IOException e) {
            System.err.println("Unable to read prototxt: " + this.trainPrototxt);
            return false;
        }
        CaffePrototxtLexer lexer = new CaffePrototxtLexer(cs);
        CommonTokenStream tokens = new CommonTokenStream(lexer);
        CaffePrototxtParser parser = new CaffePrototxtParser(tokens);
        CreateModelListener modelCreator = new CreateModelListener(parser, this.mlModel);
        parser.addParseListener(modelCreator);
        parser.prototxt();
        return true;
    }

    /**
     * Parses the solver prototxt into {@link #solver}.
     *
     * @return the result of {@link Solver#parsePrototxt()}
     */
    public boolean parseSolverPrototxt() {
        this.solver = new Solver(this.solverPrototxt);
        return this.solver.parsePrototxt();
    }

    /**
     * Parses both prototxt files and renders the complete MXNet code.
     *
     * <p>Sections are emitted in this order: imports, logger, param initializer,
     * metrics classes, the optional params loader, data iterators, input variables,
     * one segment per translated non-data layer, the loss grouping (if several
     * losses exist), validation metrics, and finally the runner.
     *
     * @return the generated code, or {@code ""} if either prototxt failed to parse
     */
    public String generateMXNetCode() {
        // Short-circuit keeps the original order: training prototxt first, then solver.
        if (!this.parseTrainingPrototxt() || !this.parseSolverPrototxt()) {
            return "";
        }
        StringBuilder code = new StringBuilder();
        code.append(this.generateImports()).append(this.NL);
        code.append(this.generateLogger()).append(this.NL);
        code.append(this.generateParamInitializer()).append(this.NL);
        code.append(this.generateMetricsClasses()).append(this.NL);
        if (this.paramsFilePath != null) {
            code.append(this.generateParamsLoader()).append(this.NL);
        }
        code.append(this.generateIterators());
        code.append(this.generateInputVars());

        List<Layer> layers = this.mlModel.getNonDataLayers();
        int layerIndex = 0;
        while (layerIndex < layers.size()) {
            Layer layer = layers.get(layerIndex);
            SymbolGenerator generator = this.resolveGenerator(layer.getType());
            if (generator == null) {
                // Layers with no generator (Accuracy) are skipped here; they are
                // handled by generateValidationMetrics.
                ++layerIndex;
                continue;
            }
            GeneratorOutput out = generator.generate(layer, this.mlModel);
            code.append(out.code).append(this.NL);
            // A generator may consume several consecutive layers. Always advance
            // by at least one so a zero/negative count cannot loop forever.
            layerIndex += Math.max(1, out.numLayersTranslated);
        }

        String loss = this.getLoss(this.mlModel, code);
        code.append(this.generateValidationMetrics(this.mlModel));
        code.append(this.generateRunner(loss));
        return code.toString();
    }

    /**
     * Picks the generator for a layer type.
     *
     * <p>A registered generator wins. Otherwise Accuracy layers get none, types
     * ending in "loss" (case-insensitive) get the plugin-loss generator, and all
     * other types get the plugin intermediate-layer generator.
     *
     * @param layerType the Caffe layer type
     * @return the generator to use, or {@code null} to skip the layer
     */
    private SymbolGenerator resolveGenerator(String layerType) {
        SymbolGenerator generator = this.generators.getGenerator(layerType);
        if (generator != null) {
            return generator;
        }
        if (layerType.equalsIgnoreCase("Accuracy")) {
            return null;
        }
        if (layerType.toLowerCase(Locale.ROOT).endsWith("loss")) {
            return this.generators.getGenerator("CaffePluginLossLayer");
        }
        return this.generators.getGenerator("PluginIntLayerGenerator");
    }

    /** Renders the "logging" template, passing the model name. */
    private String generateLogger() {
        ST st = this.gh.getTemplate("logging");
        st.add("name", this.mlModel.getName());
        return st.render();
    }

    /**
     * Renders the "runner" template from solver properties, iterator names,
     * data/label names, initializer, optimizer and learning-rate update code.
     *
     * @param loss name of the loss symbol variable to train against
     */
    private String generateRunner(String loss) {
        ST st = this.gh.getTemplate("runner");
        st.add("max_iter", this.solver.getProperty("max_iter"));
        st.add("stepsize", this.solver.getProperty("stepsize"));
        st.add("snapshot", this.solver.getProperty("snapshot"));
        st.add("test_interval", this.solver.getProperty("test_interval"));
        st.add("test_iter", this.solver.getProperty("test_iter"));
        st.add("snapshot_prefix", this.solver.getProperty("snapshot_prefix"));
        st.add("train_data_itr", this.getIteratorName("TRAIN"));
        st.add("test_data_itr", this.getIteratorName("TEST"));
        // solver_mode (default "cpu") becomes e.g. "mx.cpu()" / "mx.gpu()".
        String mode = this.solver.getProperty("solver_mode", "cpu").toLowerCase(Locale.ROOT);
        st.add("ctx", String.format("mx.%s()", mode));
        st.add("loss", loss);
        st.add("data_names", this.getDataNames());
        st.add("label_names", this.getLabelNames());
        st.add("init_params", this.generateInitializer());
        st.add("init_optimizer", this.generateOptimizer());
        st.add("gamma", this.solver.getProperty("gamma"));
        st.add("power", this.solver.getProperty("power"));
        st.add("lr_update", this.generateLRUpdate());
        return st.render();
    }

    /** Renders the "param_initializer" template. */
    private String generateParamInitializer() {
        return this.gh.getTemplate("param_initializer").render();
    }

    /**
     * Renders the "metrics_classes" template. The solver's {@code display} and
     * {@code average_loss} are passed only when set.
     */
    private String generateMetricsClasses() {
        ST st = this.gh.getTemplate("metrics_classes");
        String display = this.solver.getProperty("display");
        String averageLoss = this.solver.getProperty("average_loss");
        if (display != null) {
            st.add("display", display);
        }
        if (averageLoss != null) {
            st.add("average_loss", averageLoss);
        }
        return st.render();
    }

    /** Renders the "params_loader" template. */
    private String generateParamsLoader() {
        return this.gh.getTemplate("params_loader").render();
    }

    /**
     * Determines the loss variable for training.
     *
     * <p>Collects the top of every layer whose type ends in "loss"
     * (case-insensitive). With one loss, that variable is returned. With several,
     * a "group" template combining them into {@code combined_loss} is appended to
     * {@code out}, and that name is returned.
     *
     * @param model the parsed model
     * @param out   code buffer that receives the grouping code, if any
     * @return the loss variable name, or {@code "unknown_loss"} if none was found
     */
    private String getLoss(MLModel model, StringBuilder out) {
        List<String> losses = new ArrayList<>();
        for (Layer layer : model.getLayerList()) {
            if (layer.getType().toLowerCase(Locale.ROOT).endsWith("loss")) {
                losses.add(this.gh.getVarname(layer.getTop()));
            }
        }
        if (losses.size() == 1) {
            return losses.get(0);
        }
        if (losses.size() > 1) {
            String lossVar = "combined_loss";
            ST st = this.gh.getTemplate("group");
            st.add("var", lossVar);
            st.add("symbols", losses);
            out.append(st.render());
            return lossVar;
        }
        System.err.println("No loss found");
        return "unknown_loss";
    }

    /**
     * Renders the learning-rate update code for the solver's {@code lr_policy}
     * (default "fixed", compared case-insensitively).
     *
     * <p>"fixed" yields no code. "multistep" passes all {@code stepvalue}
     * properties to its template. step/exp/inv/poly/sigmoid render
     * {@code lrpolicy_<policy>}. Unknown policies are reported on stderr and
     * emitted as a comment line. The result is passed through {@link Utils#indent}.
     */
    private String generateLRUpdate() {
        String lrPolicy = this.solver.getProperty("lr_policy", "fixed").toLowerCase(Locale.ROOT);
        String code;
        switch (lrPolicy) {
            case "fixed":
                code = "";
                break;
            case "multistep": {
                ST st = this.gh.getTemplate("lrpolicy_multistep");
                st.add("steps", this.solver.getProperties("stepvalue"));
                code = st.render();
                break;
            }
            case "step":
            case "exp":
            case "inv":
            case "poly":
            case "sigmoid":
                code = this.gh.getTemplate("lrpolicy_" + lrPolicy).render();
                break;
            default: {
                String message = "Unknown lr_policy: " + lrPolicy;
                System.err.println(message);
                code = "# " + message + this.NL;
                break;
            }
        }
        return Utils.indent(code, 2, true, 4);
    }

    /** Delegates to {@link AccuracyMetricsGenerator} for validation metrics code. */
    private String generateValidationMetrics(MLModel mlModel) {
        return new AccuracyMetricsGenerator().generate(mlModel);
    }

    /** Returns optimizer initialization code built from the parsed solver. */
    private String generateOptimizer() {
        return new Optimizer(this.solver).generateInitCode();
    }

    /** Renders the "init_params" template with the (possibly null) params file path. */
    private String generateInitializer() {
        ST st = this.gh.getTemplate("init_params");
        st.add("params_file", this.paramsFilePath);
        return st.render();
    }

    /** Renders the "imports" template. */
    private String generateImports() {
        return this.gh.getTemplate("imports").render();
    }

    /** Concatenates the iterator code for every data layer, in model order. */
    private StringBuilder generateIterators() {
        StringBuilder code = new StringBuilder();
        for (Layer layer : this.mlModel.getDataLayers()) {
            code.append(this.generateIterator(layer));
        }
        return code;
    }

    /**
     * Returns the iterator variable name of the first data layer in {@code phase}.
     *
     * <p>The name format ({@code <phase>_<layer>_itr}) must match
     * {@link #generateIterator(Layer)}. The phase is passed as the default to
     * {@code getAttr}, so a layer without {@code include.phase} presumably matches
     * any phase (assuming getAttr returns the default when absent — confirm).
     *
     * @param phase "TRAIN" or "TEST" (compared case-insensitively)
     * @return the iterator name, or {@code null} if no data layer matches
     */
    private String getIteratorName(String phase) {
        for (Layer layer : this.mlModel.getDataLayers()) {
            String layerPhase = layer.getAttr("include.phase", phase);
            if (phase.equalsIgnoreCase(layerPhase)) {
                return layerPhase.toLowerCase(Locale.ROOT) + "_" + layer.getName() + "_itr";
            }
        }
        return null;
    }

    /** Quoted names of top #0 (data) of each TRAIN-phase data layer. */
    private List<String> getDataNames() {
        return this.getDataNames(0);
    }

    /** Quoted names of top #1 (label) of each TRAIN-phase data layer. */
    private List<String> getLabelNames() {
        return this.getDataNames(1);
    }

    /**
     * Collects {@code 'name'} (single-quoted) for the top at {@code topIndex} of every
     * data layer whose {@code include.phase} is "train" (case-insensitive).
     *
     * <p>Layers without a phase, or with fewer than {@code topIndex + 1} tops, are
     * skipped instead of throwing.
     */
    private List<String> getDataNames(int topIndex) {
        List<String> dataList = new ArrayList<>();
        for (Layer layer : this.mlModel.getDataLayers()) {
            // Constant-first comparison: a missing include.phase must not NPE.
            if (!"train".equalsIgnoreCase(layer.getAttr("include.phase"))) {
                continue;
            }
            List<String> tops = layer.getTops();
            if (topIndex < 0 || topIndex >= tops.size()) {
                continue;
            }
            String dataName = tops.get(topIndex);
            if (dataName != null) {
                dataList.add(String.format("'%s'", dataName));
            }
        }
        return dataList;
    }

    /**
     * Emits one variable declaration per distinct top of all data layers.
     * Insertion order is kept so the generated code is reproducible.
     */
    private StringBuilder generateInputVars() {
        StringBuilder code = new StringBuilder();
        LinkedHashSet<String> tops = new LinkedHashSet<>();
        for (Layer layer : this.mlModel.getDataLayers()) {
            tops.addAll(layer.getTops());
        }
        for (String top : tops) {
            code.append(this.gh.generateVar(this.gh.getVarname(top), top, null, null, null, null));
        }
        code.append(this.NL);
        return code;
    }

    /**
     * Renders the "iterator" template for one data layer.
     *
     * <p>The layer's prototxt has CRs removed, each newline turned into
     * {@code " \"} + newline, and is wrapped in single quotes and indented. Data
     * and label names come from tops #0 and #1; a missing top is reported on
     * stderr and rendered as {@code "???"}.
     */
    private String generateIterator(Layer layer) {
        // NOTE(review): a data layer without include.phase would NPE here — confirm
        // whether the model builder guarantees the attribute.
        String iteratorName = layer.getAttr("include.phase").toLowerCase(Locale.ROOT)
                + "_" + layer.getName() + "_itr";
        ST st = this.stGroup.getInstanceOf("iterator");

        String prototxt = layer.getPrototxt()
                .replace("\r", "")
                .replace("\n", " \\\n");
        prototxt = Utils.indent("'" + prototxt + "'", 1, true, 4);
        st.add("iter_name", iteratorName);
        st.add("prototxt", prototxt);

        List<String> tops = layer.getTops();
        String dataName = "???";
        if (tops.size() >= 1) {
            dataName = tops.get(0);
        } else {
            System.err.println(String.format("Data layer %s doesn't have data", layer.getName()));
        }
        st.add("data_name", dataName);

        String labelName = "???";
        // The label is the second top; the original checked ">= 1" and could throw here.
        if (tops.size() >= 2) {
            labelName = tops.get(1);
        } else {
            System.err.println(String.format("Data layer %s doesn't have label", layer.getName()));
        }
        st.add("label_name", labelName);

        if (layer.hasAttr("data_param.num_examples")) {
            st.add("num_examples", layer.getAttr("data_param.num_examples"));
        }
        return st.render();
    }

    /**
     * Sets the parameter file path. When non-null, a params loader section is
     * emitted and the path is passed to the "init_params" template.
     */
    public void setParamsFilePath(String paramsFilePath) {
        this.paramsFilePath = paramsFilePath;
    }
}

