package org.nd4j.autodiff.samediff.ops;

import java.util.Objects;

import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.GRUCellOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRUCellOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRULayerOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;

/* loaded from: input_file:org/nd4j/autodiff/samediff/ops/SDRNN.class */
/**
 * SameDiff operations for recurrent layers: GRU cell, LSTM cell and LSTM layer, SRU cell and SRU
 * layer.
 *
 * <p>Each method constructs the corresponding op on this namespace's {@code SameDiff} instance
 * ({@code sd}) and wraps the op's output variables in a typed outputs object. Overloads that take
 * a {@code String name} pass it to the op's {@code outputVariables(String)} call. Overloads
 * without a name call {@code outputVariables()}.
 *
 * <p>Arguments annotated {@code @NonNull} are checked explicitly at the top of each method. A
 * null value causes a {@link NullPointerException} whose message names the offending parameter.
 */
public class SDRNN extends SDOps {

    /**
     * Base name for the max-time-series-length scalar created by the unnamed
     * {@code lstmLayer(int, ...)} overload. The actual name is made distinct per SameDiff instance.
     */
    private static final String MAX_TS_LENGTH_BASE_NAME = "lstm_max_ts_length";

    /**
     * Creates the RNN op namespace.
     *
     * @param sameDiff the SameDiff instance that ops created by this namespace are built on
     */
    public SDRNN(SameDiff sameDiff) {
        super(sameDiff);
    }

    /**
     * Builds a single GRU cell step.
     *
     * @param x       input variable for this step
     * @param hLast   previous-step state variable (per the parameter name)
     * @param weights GRU weights
     * @return the GRU cell outputs
     * @throws NullPointerException if any argument is null
     */
    public GRUCellOutputs gru(@NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) {
        Objects.requireNonNull(x, "x is marked @NonNull but is null");
        Objects.requireNonNull(hLast, "hLast is marked @NonNull but is null");
        Objects.requireNonNull(weights, "weights is marked @NonNull but is null");
        return new GRUCellOutputs(new GRUCell(this.sd, x, hLast, weights).outputVariables());
    }

    /**
     * Builds a single GRU cell step with a base name for its outputs.
     *
     * @param name    passed to the op's {@code outputVariables(String)}
     * @param x       input variable for this step
     * @param hLast   previous-step state variable (per the parameter name)
     * @param weights GRU weights
     * @return the GRU cell outputs
     * @throws NullPointerException if {@code x}, {@code hLast} or {@code weights} is null
     */
    public GRUCellOutputs gru(String name, @NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) {
        Objects.requireNonNull(x, "x is marked @NonNull but is null");
        Objects.requireNonNull(hLast, "hLast is marked @NonNull but is null");
        Objects.requireNonNull(weights, "weights is marked @NonNull but is null");
        return new GRUCellOutputs(new GRUCell(this.sd, x, hLast, weights).outputVariables(name));
    }

    /**
     * Builds a single LSTM block cell step.
     *
     * <p>{@code weights} and {@code config} are now validated like in the named overload. Both
     * are passed straight into {@link LSTMBlockCell}, so a null value previously failed later
     * with a less descriptive exception.
     *
     * @param x       input variable for this step
     * @param cLast   previous-step cell-state variable (per the parameter name)
     * @param yLast   previous-step output variable (per the parameter name)
     * @param weights LSTM weights
     * @param config  LSTM configuration
     * @return the LSTM cell outputs
     * @throws NullPointerException if any argument is null
     */
    public LSTMCellOutputs lstmCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
                                    @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config) {
        Objects.requireNonNull(x, "x is marked @NonNull but is null");
        Objects.requireNonNull(cLast, "cLast is marked @NonNull but is null");
        Objects.requireNonNull(yLast, "yLast is marked @NonNull but is null");
        Objects.requireNonNull(weights, "weights is marked @NonNull but is null");
        Objects.requireNonNull(config, "config is marked @NonNull but is null");
        return new LSTMCellOutputs(new LSTMBlockCell(this.sd, x, cLast, yLast, weights, config).outputVariables());
    }

    /**
     * Builds a single LSTM block cell step with a base name for its outputs.
     *
     * @param name    passed to the op's {@code outputVariables(String)}
     * @param x       input variable for this step
     * @param cLast   previous-step cell-state variable (per the parameter name)
     * @param yLast   previous-step output variable (per the parameter name)
     * @param weights LSTM weights
     * @param config  LSTM configuration
     * @return the LSTM cell outputs
     * @throws NullPointerException if any argument other than {@code name} is null
     */
    public LSTMCellOutputs lstmCell(String name, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
                                    @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config) {
        Objects.requireNonNull(x, "x is marked @NonNull but is null");
        Objects.requireNonNull(cLast, "cLast is marked @NonNull but is null");
        Objects.requireNonNull(yLast, "yLast is marked @NonNull but is null");
        Objects.requireNonNull(weights, "weights is marked @NonNull but is null");
        Objects.requireNonNull(config, "config is marked @NonNull but is null");
        return new LSTMCellOutputs(new LSTMBlockCell(this.sd, x, cLast, yLast, weights, config).outputVariables(name));
    }

    /**
     * Builds a full LSTM layer.
     *
     * @param maxTSLength variable holding the maximum time-series length
     * @param x           input variable
     * @param cLast       initial cell-state variable (per the parameter name)
     * @param yLast       initial output variable (per the parameter name)
     * @param weights     LSTM weights
     * @param config      LSTM configuration; its data format is forwarded to the outputs object
     * @return the LSTM layer outputs
     * @throws NullPointerException if any argument is null
     */
    public LSTMLayerOutputs lstmLayer(@NonNull SDVariable maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast,
                                      @NonNull SDVariable yLast, @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config) {
        Objects.requireNonNull(maxTSLength, "maxTSLength is marked @NonNull but is null");
        Objects.requireNonNull(x, "x is marked @NonNull but is null");
        Objects.requireNonNull(cLast, "cLast is marked @NonNull but is null");
        Objects.requireNonNull(yLast, "yLast is marked @NonNull but is null");
        Objects.requireNonNull(weights, "weights is marked @NonNull but is null");
        Objects.requireNonNull(config, "config is marked @NonNull but is null");
        LSTMLayer op = new LSTMLayer(this.sd, maxTSLength, x, cLast, yLast, weights, config);
        return new LSTMLayerOutputs(op.outputVariables(), config.getDataFormat());
    }

    /**
     * Builds a full LSTM layer whose maximum time-series length is a constant.
     *
     * <p>The constant is created with {@code sd.scalar(...)}. Its name comes from
     * {@code sd.generateDistinctCustomVariableName("lstm_max_ts_length")}, the same way the named
     * overload derives its name.
     *
     * @param maxTSLength maximum time-series length
     * @param x           input variable
     * @param cLast       initial cell-state variable (per the parameter name)
     * @param yLast       initial output variable (per the parameter name)
     * @param weights     LSTM weights
     * @param config      LSTM configuration
     * @return the LSTM layer outputs
     * @throws NullPointerException if any object argument is null
     */
    public LSTMLayerOutputs lstmLayer(int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast,
                                      @NonNull SDVariable yLast, @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config) {
        Objects.requireNonNull(x, "x is marked @NonNull but is null");
        Objects.requireNonNull(cLast, "cLast is marked @NonNull but is null");
        Objects.requireNonNull(yLast, "yLast is marked @NonNull but is null");
        Objects.requireNonNull(weights, "weights is marked @NonNull but is null");
        Objects.requireNonNull(config, "config is marked @NonNull but is null");
        // Previously this always used the fixed name "lstm_max_ts_length", so a second call on the
        // same SameDiff requested an already-used variable name. De-duplicate it as the named
        // overload does.
        String scalarName = this.sd.generateDistinctCustomVariableName(MAX_TS_LENGTH_BASE_NAME);
        SDVariable maxTSLengthVar = this.sd.scalar(scalarName, maxTSLength);
        return lstmLayer(maxTSLengthVar, x, cLast, yLast, weights, config);
    }

    /**
     * Builds a full LSTM layer whose maximum time-series length is a constant, with a base name
     * for its outputs.
     *
     * <p>If {@code name} is null, this delegates to the unnamed {@code int} overload. Otherwise
     * the constant is named via {@code sd.generateDistinctCustomVariableName(name + "_max_ts_length")}.
     *
     * @param name        passed to the op's {@code outputVariables(String)}; may be null
     * @param maxTSLength maximum time-series length
     * @param x           input variable
     * @param cLast       initial cell-state variable (per the parameter name)
     * @param yLast       initial output variable (per the parameter name)
     * @param weights     LSTM weights
     * @param config      LSTM configuration
     * @return the LSTM layer outputs
     * @throws NullPointerException if any object argument other than {@code name} is null
     */
    public LSTMLayerOutputs lstmLayer(String name, int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast,
                                      @NonNull SDVariable yLast, @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config) {
        Objects.requireNonNull(x, "x is marked @NonNull but is null");
        Objects.requireNonNull(cLast, "cLast is marked @NonNull but is null");
        Objects.requireNonNull(yLast, "yLast is marked @NonNull but is null");
        Objects.requireNonNull(weights, "weights is marked @NonNull but is null");
        Objects.requireNonNull(config, "config is marked @NonNull but is null");
        if (name == null) {
            return lstmLayer(maxTSLength, x, cLast, yLast, weights, config);
        }
        String scalarName = this.sd.generateDistinctCustomVariableName(name + "_max_ts_length");
        SDVariable maxTSLengthVar = this.sd.scalar(scalarName, maxTSLength);
        return lstmLayer(name, maxTSLengthVar, x, cLast, yLast, weights, config);
    }

    /**
     * Builds a full LSTM layer with a base name for its outputs.
     *
     * @param name        passed to the op's {@code outputVariables(String)}
     * @param maxTSLength variable holding the maximum time-series length
     * @param x           input variable
     * @param cLast       initial cell-state variable (per the parameter name)
     * @param yLast       initial output variable (per the parameter name)
     * @param weights     LSTM weights
     * @param config      LSTM configuration; its data format is forwarded to the outputs object
     * @return the LSTM layer outputs
     * @throws NullPointerException if any argument other than {@code name} is null
     */
    public LSTMLayerOutputs lstmLayer(String name, @NonNull SDVariable maxTSLength, @NonNull SDVariable x,
                                      @NonNull SDVariable cLast, @NonNull SDVariable yLast,
                                      @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config) {
        Objects.requireNonNull(maxTSLength, "maxTSLength is marked @NonNull but is null");
        Objects.requireNonNull(x, "x is marked @NonNull but is null");
        Objects.requireNonNull(cLast, "cLast is marked @NonNull but is null");
        Objects.requireNonNull(yLast, "yLast is marked @NonNull but is null");
        Objects.requireNonNull(weights, "weights is marked @NonNull but is null");
        Objects.requireNonNull(config, "config is marked @NonNull but is null");
        LSTMLayer op = new LSTMLayer(this.sd, maxTSLength, x, cLast, yLast, weights, config);
        return new LSTMLayerOutputs(op.outputVariables(name), config.getDataFormat());
    }

    /**
     * Builds a single SRU cell step.
     *
     * @param x       input variable for this step
     * @param cLast   previous-step cell-state variable (per the parameter name)
     * @param weights SRU weights
     * @return the SRU cell outputs
     * @throws NullPointerException if any argument is null
     */
    public SRUCellOutputs sruCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) {
        Objects.requireNonNull(x, "x is marked @NonNull but is null");
        Objects.requireNonNull(cLast, "cLast is marked @NonNull but is null");
        Objects.requireNonNull(weights, "weights is marked @NonNull but is null");
        return new SRUCellOutputs(new SRUCell(this.sd, x, cLast, weights).outputVariables());
    }

    /**
     * Builds a single SRU cell step with a base name for its outputs.
     *
     * @param name    passed to the op's {@code outputVariables(String)}
     * @param x       input variable for this step
     * @param cLast   previous-step cell-state variable (per the parameter name)
     * @param weights SRU weights
     * @return the SRU cell outputs
     * @throws NullPointerException if {@code x}, {@code cLast} or {@code weights} is null
     */
    public SRUCellOutputs sruCell(String name, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) {
        Objects.requireNonNull(x, "x is marked @NonNull but is null");
        Objects.requireNonNull(cLast, "cLast is marked @NonNull but is null");
        Objects.requireNonNull(weights, "weights is marked @NonNull but is null");
        return new SRUCellOutputs(new SRUCell(this.sd, x, cLast, weights).outputVariables(name));
    }

    /**
     * Builds a full SRU layer without the optional third input.
     *
     * <p>Equivalent to {@code sru(x, initialC, null, weights)}.
     *
     * @param x        input variable
     * @param initialC initial cell-state variable (per the parameter name)
     * @param weights  SRU weights
     * @return the SRU layer outputs
     * @throws NullPointerException if any argument is null
     */
    public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) {
        Objects.requireNonNull(x, "x is marked @NonNull but is null");
        Objects.requireNonNull(initialC, "initialC is marked @NonNull but is null");
        Objects.requireNonNull(weights, "weights is marked @NonNull but is null");
        return sru(x, initialC, (SDVariable) null, weights);
    }

    /**
     * Builds a full SRU layer without the optional third input, with a base name for its outputs.
     *
     * <p>Equivalent to {@code sru(name, x, initialC, null, weights)}.
     *
     * @param name     passed to the op's {@code outputVariables(String)}
     * @param x        input variable
     * @param initialC initial cell-state variable (per the parameter name)
     * @param weights  SRU weights
     * @return the SRU layer outputs
     * @throws NullPointerException if {@code x}, {@code initialC} or {@code weights} is null
     */
    public SRULayerOutputs sru(String name, @NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) {
        Objects.requireNonNull(x, "x is marked @NonNull but is null");
        Objects.requireNonNull(initialC, "initialC is marked @NonNull but is null");
        Objects.requireNonNull(weights, "weights is marked @NonNull but is null");
        return sru(name, x, initialC, (SDVariable) null, weights);
    }

    /**
     * Builds a full SRU layer.
     *
     * @param x        input variable
     * @param initialC initial cell-state variable (per the parameter name)
     * @param mask     optional third input to {@link SRU}; may be null. Its name was lost in
     *                 decompilation and it is presumably a mask (TODO confirm against the SRU op).
     * @param weights  SRU weights
     * @return the SRU layer outputs
     * @throws NullPointerException if {@code x}, {@code initialC} or {@code weights} is null
     */
    public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) {
        Objects.requireNonNull(x, "x is marked @NonNull but is null");
        Objects.requireNonNull(initialC, "initialC is marked @NonNull but is null");
        Objects.requireNonNull(weights, "weights is marked @NonNull but is null");
        return new SRULayerOutputs(new SRU(this.sd, x, initialC, mask, weights).outputVariables());
    }

    /**
     * Builds a full SRU layer with a base name for its outputs.
     *
     * @param name     passed to the op's {@code outputVariables(String)}
     * @param x        input variable
     * @param initialC initial cell-state variable (per the parameter name)
     * @param mask     optional third input to {@link SRU}; may be null. Presumably a mask
     *                 (TODO confirm against the SRU op).
     * @param weights  SRU weights
     * @return the SRU layer outputs
     * @throws NullPointerException if {@code x}, {@code initialC} or {@code weights} is null
     */
    public SRULayerOutputs sru(String name, @NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask,
                               @NonNull SRUWeights weights) {
        Objects.requireNonNull(x, "x is marked @NonNull but is null");
        Objects.requireNonNull(initialC, "initialC is marked @NonNull but is null");
        Objects.requireNonNull(weights, "weights is marked @NonNull but is null");
        return new SRULayerOutputs(new SRU(this.sd, x, initialC, mask, weights).outputVariables(name));
    }
}
