package org.deeplearning4j.rl4j.policy;

import java.beans.ConstructorProperties;
import java.util.Random;
import org.deeplearning4j.rl4j.learning.StepCountable;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/rl4j/policy/EpsGreedy.class */
/**
 * Epsilon-greedy wrapper around another {@link Policy}.
 *
 * <p>On each call to {@link #nextAction(INDArray)}, a uniform draw from {@code rd} decides
 * between delegating to the wrapped policy and sampling a random action from the MDP's action
 * space. The exploration rate epsilon is annealed linearly from {@code 1} (at step
 * {@code updateStart}) down to {@code minEpsilon} over {@code epsilonNbStep} steps. The step is
 * read from {@code learning}.
 *
 * @param <O>  observation type
 * @param <A>  action type
 * @param <AS> action space type
 */
public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extends Policy<O, A> {
    /** Logger name kept as "EpsGreedy" so existing logging configuration still matches. */
    private static final Logger log = LoggerFactory.getLogger("EpsGreedy");

    /** Epsilon is logged when {@code stepCounter % LOG_INTERVAL == 1}. */
    private static final int LOG_INTERVAL = 500;

    private final Policy<O, A> policy;
    private final MDP<O, A, AS> mdp;
    private final int updateStart;
    private final int epsilonNbStep;
    private final Random rd;
    private final float minEpsilon;
    private final StepCountable learning;

    /**
     * Chooses the next action epsilon-greedily.
     *
     * <p>If {@code rd.nextFloat() > epsilon}, the wrapped policy is asked for the action.
     * Otherwise a random action is taken from {@code mdp.getActionSpace()}. The current epsilon
     * and step counter are logged at INFO level every {@value #LOG_INTERVAL} steps.
     *
     * @param input the observation passed through to the wrapped policy
     * @return the selected action
     */
    @Override // org.deeplearning4j.rl4j.policy.Policy
    public A nextAction(INDArray input) {
        float epsilon = getEpsilon();
        long step = learning.getStepCounter();
        if (step % LOG_INTERVAL == 1) {
            log.info("EP: {} {}", epsilon, step);
        }
        if (rd.nextFloat() > epsilon) {
            return policy.nextAction(input);
        }
        return (A) mdp.getActionSpace().randomAction();
    }

    /**
     * Returns the current exploration rate.
     *
     * <p>The rate is computed as {@code 1 - (step - updateStart) / epsilonNbStep}. The result is
     * raised to at least {@code minEpsilon} and then capped at {@code 1}. A non-positive
     * {@code epsilonNbStep} is treated as a zero-length schedule: epsilon is {@code 1} before
     * {@code updateStart} and {@code min(1, minEpsilon)} from {@code updateStart} on. Without
     * this guard the division yields NaN at {@code step == updateStart}, which would make every
     * action random.
     *
     * @return the current epsilon
     */
    public float getEpsilon() {
        long stepCounter = learning.getStepCounter();
        if (epsilonNbStep <= 0) {
            return stepCounter >= updateStart ? Math.min(1.0f, minEpsilon) : 1.0f;
        }
        // Subtract in long so a large counter or a negative updateStart cannot overflow int.
        float progress = (float) (stepCounter - updateStart) / epsilonNbStep;
        return Math.min(1.0f, Math.max(minEpsilon, 1.0f - progress));
    }

    /**
     * Creates an epsilon-greedy policy.
     *
     * @param policy        policy used for exploitation
     * @param mdp           MDP whose action space supplies random (exploration) actions
     * @param updateStart   step at which epsilon starts to decay from 1
     * @param epsilonNbStep number of steps over which epsilon decays to {@code minEpsilon}
     * @param rd            random source for the explore/exploit decision
     * @param minEpsilon    lower bound on epsilon after annealing
     * @param learning      provider of the current step counter
     */
    @ConstructorProperties({"policy", "mdp", "updateStart", "epsilonNbStep", "rd", "minEpsilon", "learning"})
    public EpsGreedy(Policy<O, A> policy, MDP<O, A, AS> mdp, int updateStart, int epsilonNbStep,
                     Random rd, float minEpsilon, StepCountable learning) {
        this.policy = policy;
        this.mdp = mdp;
        this.updateStart = updateStart;
        this.epsilonNbStep = epsilonNbStep;
        this.rd = rd;
        this.minEpsilon = minEpsilon;
        this.learning = learning;
    }
}
