package org.deeplearning4j.arbiter;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import org.deeplearning4j.arbiter.BaseNetworkSpace;
import org.deeplearning4j.arbiter.layers.LayerSpace;
import org.deeplearning4j.arbiter.layers.fixed.FixedLayerSpace;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.api.TaskCreatorProvider;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper;
import org.deeplearning4j.arbiter.optimize.serde.jackson.YamlMapper;
import org.deeplearning4j.arbiter.task.ComputationGraphTaskCreator;
import org.deeplearning4j.arbiter.util.LeafUtils;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.annotation.JsonTypeName;

/**
 * Hyperparameter search space for {@link ComputationGraph} configurations.
 *
 * <p>Holds per-layer parameter spaces, graph vertices, network input/output names and an optional
 * input-type space. {@link #getValue(double[])} turns one candidate point of the search space into
 * a concrete {@link GraphConfiguration}. Instances are created through {@link Builder} or
 * deserialized with {@link #fromJson(String)} / {@link #fromYaml(String)}.
 */
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "@class")
@JsonTypeName("ComputationGraphSpace")
public class ComputationGraphSpace extends BaseNetworkSpace<GraphConfiguration> {

    /** Layer parameter spaces, added to the graph in list order. */
    @JsonProperty
    protected List<BaseNetworkSpace.LayerConf> layerSpaces = new ArrayList<>();

    /** Graph vertices; these are added to every generated graph as-is (they are not sampled). */
    @JsonProperty
    protected List<VertexConf> vertices = new ArrayList<>();

    /** Names of the network inputs. */
    @JsonProperty
    protected String[] networkInputs;

    /** Names of the network outputs. */
    @JsonProperty
    protected String[] networkOutputs;

    /** Optional space for the network input types; ignored when {@code null}. */
    @JsonProperty
    protected ParameterSpace<InputType[]> inputTypes;

    /** Total hyperparameter count, summed over the unique leaves at construction time. */
    @JsonProperty
    protected int numParameters;

    /** Training workspace mode applied to generated configurations when non-null. */
    @JsonProperty
    protected WorkspaceMode trainingWorkspaceMode;

    /** Inference workspace mode applied to generated configurations when non-null. */
    @JsonProperty
    protected WorkspaceMode inferenceWorkspaceMode;

    /**
     * Optional early stopping configuration, passed through to each generated
     * {@link GraphConfiguration}. When {@code null}, {@link #toString()} reports the fixed epoch count.
     */
    protected EarlyStoppingConfiguration<ComputationGraph> earlyStoppingConfiguration;

    /** Fluent builder for {@link ComputationGraphSpace}. */
    public static class Builder extends BaseNetworkSpace.Builder<Builder> {
        protected List<BaseNetworkSpace.LayerConf> layerList = new ArrayList<>();
        protected List<VertexConf> vertexList = new ArrayList<>();
        protected EarlyStoppingConfiguration<ComputationGraph> earlyStoppingConfiguration;
        protected String[] networkInputs;
        protected String[] networkOutputs;
        protected ParameterSpace<InputType[]> inputTypes;
        protected WorkspaceMode trainingWorkspaceMode;
        protected WorkspaceMode inferenceWorkspaceMode;

        /** Sets the early stopping configuration attached to every generated configuration. */
        public Builder earlyStoppingConfiguration(EarlyStoppingConfiguration<ComputationGraph> earlyStoppingConfiguration) {
            this.earlyStoppingConfiguration = earlyStoppingConfiguration;
            return this;
        }

        /** Alias of {@link #addLayer(String, LayerSpace, String...)}. */
        public Builder layer(String layerName, LayerSpace<? extends Layer> layerSpace, String... layerInputs) {
            return addLayer(layerName, layerSpace, layerInputs);
        }

        /** Alias of {@link #addLayer(String, LayerSpace, InputPreProcessor, String...)}. */
        public Builder layer(String layerName, LayerSpace<? extends Layer> layerSpace,
                             InputPreProcessor preProcessor, String... layerInputs) {
            return addLayer(layerName, layerSpace, preProcessor, layerInputs);
        }

        /** Adds a concrete layer by wrapping it in a {@link FixedLayerSpace}. */
        public Builder layer(String layerName, Layer layer, String... layerInputs) {
            return layer(layerName, new FixedLayerSpace<>(layer), layerInputs);
        }

        /**
         * Adds a layer space with no input preprocessor.
         *
         * @param layerName   name of the layer in the graph
         * @param layerSpace  search space for the layer
         * @param layerInputs names of the vertices/layers feeding this layer
         * @return this builder
         */
        public Builder addLayer(String layerName, LayerSpace<? extends Layer> layerSpace, String... layerInputs) {
            return addLayer(layerName, layerSpace, null, layerInputs);
        }

        /**
         * Adds a layer space with an optional input preprocessor. The layer conf is created with
         * {@code FixedValue(1)} and {@code false} for its remaining options.
         *
         * @param layerName    name of the layer in the graph
         * @param layerSpace   search space for the layer
         * @param preProcessor input preprocessor for the layer; may be {@code null}
         * @param layerInputs  names of the vertices/layers feeding this layer
         * @return this builder
         */
        public Builder addLayer(String layerName, LayerSpace<? extends Layer> layerSpace,
                                InputPreProcessor preProcessor, String... layerInputs) {
            this.layerList.add(new BaseNetworkSpace.LayerConf(layerSpace, layerName, layerInputs,
                    new FixedValue<>(1), false, preProcessor));
            return this;
        }

        /** Adds a graph vertex that is used unchanged in every generated configuration. */
        public Builder addVertex(String vertexName, GraphVertex vertex, String... vertexInputs) {
            this.vertexList.add(new VertexConf(vertex, vertexName, vertexInputs));
            return this;
        }

        /** Sets (replaces) the network input names. */
        public Builder addInputs(String... inputNames) {
            this.networkInputs = inputNames;
            return this;
        }

        /** Sets (replaces) the network output names. */
        public Builder setOutputs(String... outputNames) {
            this.networkOutputs = outputNames;
            return this;
        }

        /** Sets fixed input types (wrapped in a {@link FixedValue}). */
        public Builder setInputTypes(InputType... inputTypes) {
            return setInputTypes(new FixedValue<InputType[]>(inputTypes));
        }

        /** Sets the input-type search space. */
        public Builder setInputTypes(ParameterSpace<InputType[]> inputTypes) {
            this.inputTypes = inputTypes;
            return this;
        }

        /** Sets the training workspace mode applied to generated configurations. */
        public Builder trainingWorkspaceMode(WorkspaceMode workspaceMode) {
            this.trainingWorkspaceMode = workspaceMode;
            return this;
        }

        /** Sets the inference workspace mode applied to generated configurations. */
        public Builder inferenceWorkspaceMode(WorkspaceMode workspaceMode) {
            this.inferenceWorkspaceMode = workspaceMode;
            return this;
        }

        @Override
        public ComputationGraphSpace build() {
            return new ComputationGraphSpace(this);
        }
    }

    /** A graph vertex together with its name and the names of its inputs. */
    public static class VertexConf {
        protected GraphVertex graphVertex;
        protected String vertexName;
        protected String[] inputs;

        public VertexConf(GraphVertex graphVertex, String vertexName, String[] inputs) {
            this.graphVertex = graphVertex;
            this.vertexName = vertexName;
            this.inputs = inputs;
        }

        /** No-arg constructor; presumably used for reflective/deserialization construction. */
        public VertexConf() {
        }

        public GraphVertex getGraphVertex() {
            return this.graphVertex;
        }

        public String getVertexName() {
            return this.vertexName;
        }

        public String[] getInputs() {
            return this.inputs;
        }

        public void setGraphVertex(GraphVertex graphVertex) {
            this.graphVertex = graphVertex;
        }

        public void setVertexName(String vertexName) {
            this.vertexName = vertexName;
        }

        public void setInputs(String[] inputs) {
            this.inputs = inputs;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof VertexConf)) {
                return false;
            }
            VertexConf other = (VertexConf) obj;
            return other.canEqual(this)
                    && Objects.equals(getGraphVertex(), other.getGraphVertex())
                    && Objects.equals(getVertexName(), other.getVertexName())
                    && Arrays.deepEquals(getInputs(), other.getInputs());
        }

        protected boolean canEqual(Object obj) {
            return obj instanceof VertexConf;
        }

        @Override
        public int hashCode() {
            int result = 1;
            result = result * 59 + hashOrDefault(getGraphVertex());
            result = result * 59 + hashOrDefault(getVertexName());
            result = result * 59 + Arrays.deepHashCode(getInputs());
            return result;
        }

        @Override
        public String toString() {
            return "ComputationGraphSpace.VertexConf(graphVertex=" + getGraphVertex() + ", vertexName="
                    + getVertexName() + ", inputs=" + Arrays.deepToString(getInputs()) + ")";
        }
    }

    /**
     * Creates the space from a builder.
     *
     * <p>The builder's layer and vertex lists are copied, so that continuing to use the builder after
     * {@code build()} cannot change this space or leave {@link #numParameters} stale.
     */
    protected ComputationGraphSpace(Builder builder) {
        super(builder);
        this.earlyStoppingConfiguration = builder.earlyStoppingConfiguration;
        this.layerSpaces = new ArrayList<>(builder.layerList);
        this.vertices = new ArrayList<>(builder.vertexList);
        this.networkInputs = builder.networkInputs;
        this.networkOutputs = builder.networkOutputs;
        this.inputTypes = builder.inputTypes;
        this.trainingWorkspaceMode = builder.trainingWorkspaceMode;
        this.inferenceWorkspaceMode = builder.inferenceWorkspaceMode;

        // Leaves shared between several layers are counted once.
        // NOTE(review): collectLeaves() is overridable and runs during construction, before any
        // subclass fields are initialized.
        List<ParameterSpace> uniqueLeaves = LeafUtils.getUniqueObjects(collectLeaves());
        for (ParameterSpace leaf : uniqueLeaves) {
            this.numParameters += leaf.numParameters();
        }
    }

    /** No-arg constructor; presumably for Jackson deserialization via fromJson/fromYaml. */
    private ComputationGraphSpace() {
    }

    /**
     * Builds a concrete {@link GraphConfiguration} for one point of the search space.
     *
     * <p>Layer spaces and the optional input-type space are sampled from {@code values}. Vertices
     * are added unchanged. Optional global settings (backprop, pretrain, backprop type, TBPTT
     * lengths, workspace modes) are applied only when set.
     *
     * @param values candidate hyperparameter values, passed to every sampled parameter space
     * @return the sampled configuration, carrying this space's early stopping config and epoch count
     */
    @Override
    public GraphConfiguration getValue(double[] values) {
        ComputationGraphConfiguration.GraphBuilder graphBuilder = randomGlobalConf(values).graphBuilder();
        graphBuilder.addInputs(this.networkInputs);
        graphBuilder.setOutputs(this.networkOutputs);
        if (this.inputTypes != null) {
            graphBuilder.setInputTypes(this.inputTypes.getValue(values));
        }
        for (BaseNetworkSpace.LayerConf layerConf : this.layerSpaces) {
            Layer layer = (Layer) layerConf.layerSpace.getValue(values);
            graphBuilder.addLayer(layerConf.getLayerName(), layer, layerConf.getPreProcessor(), layerConf.getInputs());
        }
        for (VertexConf vertexConf : this.vertices) {
            graphBuilder.addVertex(vertexConf.getVertexName(), vertexConf.getGraphVertex(), vertexConf.getInputs());
        }
        if (this.backprop != null) {
            graphBuilder.backprop(((Boolean) this.backprop.getValue(values)).booleanValue());
        }
        if (this.pretrain != null) {
            graphBuilder.pretrain(((Boolean) this.pretrain.getValue(values)).booleanValue());
        }
        if (this.backpropType != null) {
            graphBuilder.backpropType((BackpropType) this.backpropType.getValue(values));
        }
        if (this.tbpttFwdLength != null) {
            graphBuilder.tBPTTForwardLength(((Integer) this.tbpttFwdLength.getValue(values)).intValue());
        }
        if (this.tbpttBwdLength != null) {
            graphBuilder.tBPTTBackwardLength(((Integer) this.tbpttBwdLength.getValue(values)).intValue());
        }

        ComputationGraphConfiguration configuration = graphBuilder.build();
        // Workspace modes are set on the built configuration, not through the graph builder.
        if (this.trainingWorkspaceMode != null) {
            configuration.setTrainingWorkspaceMode(this.trainingWorkspaceMode);
        }
        if (this.inferenceWorkspaceMode != null) {
            configuration.setInferenceWorkspaceMode(this.inferenceWorkspaceMode);
        }
        return new GraphConfiguration(configuration, this.earlyStoppingConfiguration, Integer.valueOf(this.numEpochs));
    }

    /**
     * Decompiler-generated alias of {@link #getValue(double[])}, kept for source compatibility.
     *
     * @deprecated use {@link #getValue(double[])}
     */
    @Deprecated
    public GraphConfiguration m2getValue(double[] values) {
        return getValue(values);
    }

    /** Returns the hyperparameter count computed at construction. */
    @Override
    public int numParameters() {
        return this.numParameters;
    }

    /**
     * Collects the leaf parameter spaces: those of the base network space, then those of every
     * layer space, then the input-type space (if set). Vertices contribute no leaves.
     */
    @Override
    public List<ParameterSpace> collectLeaves() {
        List<ParameterSpace> leaves = super.collectLeaves();
        for (BaseNetworkSpace.LayerConf layerConf : this.layerSpaces) {
            leaves.addAll(layerConf.layerSpace.collectLeaves());
        }
        if (this.inputTypes != null) {
            leaves.add(this.inputTypes);
        }
        return leaves;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(super.toString());
        for (BaseNetworkSpace.LayerConf layerConf : this.layerSpaces) {
            sb.append("Layer config: \"").append(layerConf.layerName).append("\", ").append(layerConf.layerSpace)
                    .append(", inputs: ").append(layerConf.inputs == null ? "[]" : Arrays.toString(layerConf.inputs))
                    .append("\n");
        }
        for (VertexConf vertexConf : this.vertices) {
            sb.append("GraphVertex: \"").append(vertexConf.vertexName).append("\", ").append(vertexConf.graphVertex)
                    .append(", inputs: ").append(vertexConf.inputs == null ? "[]" : Arrays.toString(vertexConf.inputs))
                    .append("\n");
        }
        if (this.earlyStoppingConfiguration != null) {
            sb.append("Early stopping configuration:").append(this.earlyStoppingConfiguration.toString()).append("\n");
        } else {
            sb.append("Training # epochs:").append(this.numEpochs).append("\n");
        }
        if (this.inputTypes != null) {
            sb.append("Input types: ").append(this.inputTypes).append("\n");
        }
        return sb.toString();
    }

    /**
     * Deserializes a space from JSON.
     *
     * @throws RuntimeException wrapping the {@link IOException} if parsing fails
     */
    public static ComputationGraphSpace fromJson(String json) {
        try {
            return JsonMapper.getMapper().readValue(json, ComputationGraphSpace.class);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Deserializes a space from YAML.
     *
     * @throws RuntimeException wrapping the {@link IOException} if parsing fails
     */
    public static ComputationGraphSpace fromYaml(String yaml) {
        try {
            return YamlMapper.getMapper().readValue(yaml, ComputationGraphSpace.class);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public List<BaseNetworkSpace.LayerConf> getLayerSpaces() {
        return this.layerSpaces;
    }

    public List<VertexConf> getVertices() {
        return this.vertices;
    }

    public String[] getNetworkInputs() {
        return this.networkInputs;
    }

    public String[] getNetworkOutputs() {
        return this.networkOutputs;
    }

    public ParameterSpace<InputType[]> getInputTypes() {
        return this.inputTypes;
    }

    public int getNumParameters() {
        return this.numParameters;
    }

    public WorkspaceMode getTrainingWorkspaceMode() {
        return this.trainingWorkspaceMode;
    }

    public WorkspaceMode getInferenceWorkspaceMode() {
        return this.inferenceWorkspaceMode;
    }

    public EarlyStoppingConfiguration<ComputationGraph> getEarlyStoppingConfiguration() {
        return this.earlyStoppingConfiguration;
    }

    @Override
    public void setLayerSpaces(List<BaseNetworkSpace.LayerConf> layerSpaces) {
        this.layerSpaces = layerSpaces;
    }

    public void setVertices(List<VertexConf> vertices) {
        this.vertices = vertices;
    }

    public void setNetworkInputs(String[] networkInputs) {
        this.networkInputs = networkInputs;
    }

    public void setNetworkOutputs(String[] networkOutputs) {
        this.networkOutputs = networkOutputs;
    }

    public void setInputTypes(ParameterSpace<InputType[]> inputTypes) {
        this.inputTypes = inputTypes;
    }

    public void setNumParameters(int numParameters) {
        this.numParameters = numParameters;
    }

    public void setTrainingWorkspaceMode(WorkspaceMode workspaceMode) {
        this.trainingWorkspaceMode = workspaceMode;
    }

    public void setInferenceWorkspaceMode(WorkspaceMode workspaceMode) {
        this.inferenceWorkspaceMode = workspaceMode;
    }

    public void setEarlyStoppingConfiguration(EarlyStoppingConfiguration<ComputationGraph> earlyStoppingConfiguration) {
        this.earlyStoppingConfiguration = earlyStoppingConfiguration;
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ComputationGraphSpace)) {
            return false;
        }
        ComputationGraphSpace other = (ComputationGraphSpace) obj;
        if (!other.canEqual(this) || !super.equals(obj)) {
            return false;
        }
        return Objects.equals(getLayerSpaces(), other.getLayerSpaces())
                && Objects.equals(getVertices(), other.getVertices())
                && Arrays.deepEquals(getNetworkInputs(), other.getNetworkInputs())
                && Arrays.deepEquals(getNetworkOutputs(), other.getNetworkOutputs())
                && Objects.equals(getInputTypes(), other.getInputTypes())
                && getNumParameters() == other.getNumParameters()
                && Objects.equals(getTrainingWorkspaceMode(), other.getTrainingWorkspaceMode())
                && Objects.equals(getInferenceWorkspaceMode(), other.getInferenceWorkspaceMode())
                && Objects.equals(getEarlyStoppingConfiguration(), other.getEarlyStoppingConfiguration());
    }

    @Override
    protected boolean canEqual(Object obj) {
        return obj instanceof ComputationGraphSpace;
    }

    @Override
    public int hashCode() {
        int result = super.hashCode();
        result = result * 59 + hashOrDefault(getLayerSpaces());
        result = result * 59 + hashOrDefault(getVertices());
        result = result * 59 + Arrays.deepHashCode(getNetworkInputs());
        result = result * 59 + Arrays.deepHashCode(getNetworkOutputs());
        result = result * 59 + hashOrDefault(getInputTypes());
        result = result * 59 + getNumParameters();
        result = result * 59 + hashOrDefault(getTrainingWorkspaceMode());
        result = result * 59 + hashOrDefault(getInferenceWorkspaceMode());
        result = result * 59 + hashOrDefault(getEarlyStoppingConfiguration());
        return result;
    }

    /** Hash of {@code value}, or 43 for {@code null} (the constant used by these hashCode formulas). */
    private static int hashOrDefault(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    static {
        TaskCreatorProvider.registerDefaultTaskCreatorClass(ComputationGraphSpace.class, ComputationGraphTaskCreator.class);
    }
}
