/*
 * Decompiled with CFR 0.152.
 */
package io.mxnet.caffetranslator;

import io.mxnet.caffetranslator.GeneratorOutput;
import io.mxnet.caffetranslator.Layer;
import java.util.ArrayList;
import java.util.List;
import org.stringtemplate.v4.ST;
import org.stringtemplate.v4.STErrorListener;
import org.stringtemplate.v4.STGroup;
import org.stringtemplate.v4.STGroupFile;
import org.stringtemplate.v4.STRawGroupDir;
import org.stringtemplate.v4.misc.STMessage;

/**
 * Shared helpers for the Caffe-to-MXNet code generators: StringTemplate lookup,
 * conversion of Caffe blob/layer names into generated variable names, and mapping
 * of Caffe filler types to MXNet initializer expressions.
 */
public class GenerationHelper {
    /** Raw templates loaded from the {@code templates} directory; consulted first. */
    protected final STGroup stGroupDir = new STRawGroupDir("templates");
    /** Templates defined in {@code templates/symbols.stg}; consulted as a fallback. */
    protected final STGroup stGroupFile = new STGroupFile("templates/symbols.stg");

    /**
     * Creates the helper and installs a listener on both template groups that
     * silences compile-time/run-time template errors but rethrows I/O and
     * internal errors.
     */
    public GenerationHelper() {
        STErrorListener errListener = new SuppressSTErrorsListener();
        this.stGroupDir.setListener(errListener);
        this.stGroupFile.setListener(errListener);
    }

    /**
     * Returns a fresh instance of the named template, looking first in the raw
     * template directory and then in {@code symbols.stg}.
     *
     * @param name template name
     * @return a new template instance, or {@code null} if neither group defines it
     */
    public ST getTemplate(String name) {
        ST st = this.stGroupDir.getInstanceOf(name);
        if (st != null) {
            return st;
        }
        return this.stGroupFile.getInstanceOf(name);
    }

    /**
     * Renders the {@code var} template for a single variable declaration.
     *
     * @param varName generated variable name (bound to template attribute {@code var})
     * @param symName symbol name (bound to {@code name})
     * @param lr_mult learning-rate multiplier text (bound to {@code lr_mult})
     * @param wd_mult weight-decay multiplier text (bound to {@code wd_mult})
     * @param init initializer expression (bound to {@code init})
     * @param shape variable shape (bound to {@code shape})
     * @return the rendered text
     * @throws IllegalStateException if no {@code var} template can be found
     */
    public String generateVar(String varName, String symName, String lr_mult, String wd_mult, String init, List<Integer> shape) {
        ST st = this.getTemplate("var");
        if (st == null) {
            // getTemplate returns null when neither template group defines "var";
            // fail with a descriptive message instead of an NPE on st.add below.
            throw new IllegalStateException("Template 'var' not found in templates/ or templates/symbols.stg");
        }
        st.add("var", varName);
        st.add("name", symName);
        st.add("lr_mult", lr_mult);
        st.add("wd_mult", wd_mult);
        st.add("init", init);
        st.add("shape", shape);
        return st.render();
    }

    /**
     * Maps a Caffe filler type/value pair to an MXNet initializer expression.
     *
     * <p>A missing type defaults to {@code "constant"} and a missing value to
     * {@code "0"}. Only the {@code constant} filler uses the value; the other
     * supported types ignore it. Unsupported types are reported on stderr and
     * mapped to the placeholder {@code "UnknownInitializer"}.
     *
     * @param fillerType Caffe filler type, may be {@code null}
     * @param fillerValue filler value text, may be {@code null}
     * @return the initializer expression, or {@code null} if both arguments are {@code null}
     */
    public String getInit(String fillerType, String fillerValue) {
        if (fillerType == null && fillerValue == null) {
            return null;
        }
        String type = fillerType == null ? "constant" : fillerType;
        String value = fillerValue == null ? "0" : fillerValue;
        switch (type) {
            case "xavier":
                return "mx.initializer.Xavier()";
            case "gaussian":
                return "mx.initializer.Normal()";
            case "constant":
                return String.format("mx.initializer.Constant(%s)", value);
            case "bilinear":
                return "mx.initializer.Bilinear()";
            default:
                System.err.println("Initializer " + type + " not supported");
                return "UnknownInitializer";
        }
    }

    /**
     * Converts a name into a variable name by replacing every code point that is
     * not a letter, digit or underscore with a single {@code '_'}.
     *
     * <p>The scan is by Unicode code point rather than by UTF-16 {@code char}, so a
     * supplementary character (a surrogate pair) is either kept whole or replaced
     * by exactly one underscore. A leading digit is left as-is.
     *
     * @param name the source name; must not be {@code null}
     * @return the sanitized name
     * @throws NullPointerException if {@code name} is {@code null}
     */
    public String getVarname(String name) {
        StringBuilder sb = new StringBuilder(name.length());
        int i = 0;
        while (i < name.length()) {
            int cp = name.codePointAt(i);
            if (Character.isLetterOrDigit(cp) || cp == '_') {
                sb.appendCodePoint(cp);
            } else {
                sb.append('_');
            }
            i += Character.charCount(cp);
        }
        return sb.toString();
    }

    /**
     * Applies {@link #getVarname(String)} to every element of {@code names}.
     *
     * @param names source names
     * @return a new mutable list of sanitized names, in the same order
     */
    public List<String> getVarNames(List<String> names) {
        List<String> list = new ArrayList<>(names.size());
        for (String name : names) {
            list.add(this.getVarname(name));
        }
        return list;
    }

    /**
     * Fills the common {@code name}, {@code data} and {@code var} attributes of a
     * layer template from the layer's name, bottom and top respectively. Bottom and
     * top are passed through {@link #getVarname(String)}.
     *
     * @param st template to populate
     * @param layer source layer
     */
    public void fillNameDataAndVar(ST st, Layer layer) {
        st.add("name", layer.getName());
        st.add("data", this.getVarname(layer.getBottom()));
        st.add("var", this.getVarname(layer.getTop()));
    }

    /**
     * Copies one layer attribute into a template attribute.
     *
     * <p>Looks up {@code key} on the layer, then each of {@code altKeys} in order,
     * stopping at the first non-null value. If none is found, {@code defaultValue}
     * is used. If that is also {@code null}, a message is printed to stderr and
     * the placeholder {@code "???"} is used.
     *
     * @param st template to populate
     * @param name template attribute to set
     * @param layer layer to read from
     * @param key primary attribute key
     * @param defaultValue fallback value, may be {@code null}
     * @param altKeys alternative attribute keys tried in order, may be empty or {@code null}
     */
    public void simpleFillTemplate(ST st, String name, Layer layer, String key, String defaultValue, String... altKeys) {
        String value = layer.getAttr(key);
        if (value == null && altKeys != null) {
            for (String altKey : altKeys) {
                value = layer.getAttr(altKey);
                if (value != null) {
                    break;
                }
            }
        }
        if (value == null) {
            value = defaultValue;
        }
        if (value == null) {
            System.err.println(String.format("Layer %s does not contain attribute %s or alternates", layer.getName(), key));
            value = "???";
        }
        st.add(name, value);
    }

    /**
     * Wraps generated code and the count of translated layers.
     *
     * @param code generated code
     * @param numLayersTranslated number of layers translated
     * @return a new {@link GeneratorOutput}
     */
    public GeneratorOutput makeGeneratorOutput(String code, int numLayersTranslated) {
        return new GeneratorOutput(code, numLayersTranslated);
    }

    /**
     * Emits one {@code param_initializer.add_param(...)} line, terminated with the
     * platform line separator, that registers {@code initializer} for the
     * {@code childIndex}-th child of {@code varname}.
     *
     * @param varname generated variable whose children are indexed
     * @param childIndex index into {@code get_children()}
     * @param initializer initializer expression
     * @return the generated line including the trailing line separator
     */
    public String initializeParam(String varname, int childIndex, String initializer) {
        return String.format("param_initializer.add_param(%s.get_children()[%d].name, %s)", varname, childIndex, initializer)
                + System.lineSeparator();
    }

    /**
     * Error listener that deliberately ignores template compile-time and run-time
     * errors but converts I/O and internal errors into unchecked exceptions.
     * Static because it needs no access to the enclosing helper.
     */
    private static class SuppressSTErrorsListener implements STErrorListener {

        @Override
        public void compileTimeError(STMessage msg) {
            // Intentionally ignored: silencing these is this listener's purpose.
        }

        @Override
        public void runTimeError(STMessage msg) {
            // Intentionally ignored: silencing these is this listener's purpose.
        }

        @Override
        public void IOError(STMessage msg) {
            // Chain the underlying cause (if any) so the original stack trace survives.
            throw new RuntimeException(msg.toString(), msg.cause);
        }

        @Override
        public void internalError(STMessage msg) {
            throw new RuntimeException(msg.toString(), msg.cause);
        }
    }
}

