package org.nd4j.parameterserver.distributed.training;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.ServiceLoader;
import lombok.NonNull;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.logic.Storage;
import org.nd4j.parameterserver.distributed.logic.completion.Clipboard;
import org.nd4j.parameterserver.distributed.messages.TrainingMessage;
import org.nd4j.parameterserver.distributed.transport.Transport;

/**
 * Singleton registry that dispatches {@link TrainingMessage}s to the {@link TrainingDriver}
 * registered for the message's type.
 *
 * <p>Drivers are discovered exactly once, through {@link ServiceLoader}, when the singleton is
 * created. Each driver is stored under the key returned by its {@code targetMessageClass()};
 * when a message is dispatched it is looked up by {@code message.getClass().getSimpleName()},
 * so the two values must match. If several discovered drivers report the same key, the one
 * loaded last replaces the earlier ones.
 *
 * <p>NOTE(review): nothing here is synchronized. {@link #init} overwrites shared fields and is
 * presumably expected to run once, before any training is dispatched — confirm with callers.
 *
 * @deprecated marked deprecated in the original code; no replacement is referenced here.
 */
@Deprecated
public class TrainerProvider {
    private static final TrainerProvider INSTANCE = new TrainerProvider();

    /** Discovered drivers, keyed by {@code TrainingDriver#targetMessageClass()}. */
    protected Map<String, TrainingDriver<?>> trainers = new HashMap<>();
    protected VoidConfiguration voidConfiguration;
    protected Transport transport;
    protected Clipboard clipboard;
    protected Storage storage;

    /** Private: use {@link #getInstance()}. Discovers drivers eagerly. */
    private TrainerProvider() {
        loadProviders();
    }

    /**
     * Returns the process-wide instance.
     *
     * @return the singleton {@code TrainerProvider}
     */
    public static TrainerProvider getInstance() {
        return INSTANCE;
    }

    /**
     * Discovers every {@link TrainingDriver} available through {@link ServiceLoader} and
     * registers it under its {@code targetMessageClass()} key. A later driver with the same key
     * replaces an earlier one.
     *
     * @throws ND4JIllegalStateException if no driver at all was found
     */
    protected void loadProviders() {
        // ServiceLoader.load needs a class literal, which is necessarily raw; binding each element
        // to TrainingDriver<?> keeps the rest of the code free of raw types.
        for (TrainingDriver<?> driver : ServiceLoader.load(TrainingDriver.class)) {
            this.trainers.put(driver.targetMessageClass(), driver);
        }
        if (this.trainers.isEmpty()) {
            throw new ND4JIllegalStateException("No TrainingDrivers were found via ServiceLoader mechanism");
        }
    }

    /**
     * Stores the shared collaborators and passes them to the {@code init} method of every
     * registered driver.
     *
     * @param voidConfiguration configuration handed to each driver; must not be null
     * @param transport         transport handed to each driver; must not be null
     * @param storage           storage handed to each driver; must not be null
     * @param clipboard         clipboard handed to each driver; must not be null
     * @throws NullPointerException if any argument is null
     */
    public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Transport transport, @NonNull Storage storage, @NonNull Clipboard clipboard) {
        // Explicit checks keep these exact messages regardless of the Lombok version in use.
        if (voidConfiguration == null) {
            throw new NullPointerException("voidConfiguration is marked @NonNull but is null");
        }
        if (transport == null) {
            throw new NullPointerException("transport is marked @NonNull but is null");
        }
        if (storage == null) {
            throw new NullPointerException("storage is marked @NonNull but is null");
        }
        if (clipboard == null) {
            throw new NullPointerException("clipboard is marked @NonNull but is null");
        }
        this.voidConfiguration = voidConfiguration;
        this.transport = transport;
        this.clipboard = clipboard;
        this.storage = storage;
        for (TrainingDriver<?> driver : this.trainers.values()) {
            driver.init(voidConfiguration, transport, storage, clipboard);
        }
    }

    /**
     * Looks up the driver registered for the runtime type of {@code t}.
     *
     * <p>NOTE(review): matching uses the class's simple name, so two message classes with the
     * same simple name in different packages would map to the same driver — confirm intended.
     *
     * @param t   the message to route; must not be null
     * @param <T> the message type
     * @return the driver registered under {@code t.getClass().getSimpleName()}
     * @throws ND4JIllegalStateException if no driver is registered for that name
     */
    @SuppressWarnings("unchecked") // drivers are registered by message-type name, so the key fixes T
    protected <T extends TrainingMessage> TrainingDriver<T> getTrainer(T t) {
        String messageType = t.getClass().getSimpleName();
        TrainingDriver<T> trainingDriver = (TrainingDriver<T>) this.trainers.get(messageType);
        if (trainingDriver == null) {
            throw new ND4JIllegalStateException("Can't find trainer for [" + messageType + "]");
        }
        return trainingDriver;
    }

    /**
     * Routes {@code t} to its driver and calls that driver's {@code startTraining}.
     *
     * @param t   the message to process; must not be null
     * @param <T> the message type
     * @throws ND4JIllegalStateException if no driver is registered for the message's type
     */
    public <T extends TrainingMessage> void doTraining(T t) {
        getTrainer(t).startTraining(t);
    }
}
