/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.pytorch.engine;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.engine.EngineException;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.SymbolBlock;
import ai.djl.pytorch.engine.PtGradientCollector;
import ai.djl.pytorch.engine.PtModel;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.pytorch.engine.PtSymbolBlock;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.pytorch.jni.LibUtils;
import ai.djl.training.GradientCollector;
import ai.djl.util.RandomUtils;
import ai.djl.util.Utils;
import java.io.FileNotFoundException;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * The PyTorch implementation of {@link Engine}.
 *
 * <p>Instances are created through {@link #newInstance()}, which loads the PyTorch native library
 * (via {@link LibUtils#loadLibrary()}) and applies optional runtime settings taken from system
 * properties before returning the engine.
 */
public final class PtEngine extends Engine {

    private static final Logger logger = LoggerFactory.getLogger(PtEngine.class);

    /** The name returned by {@link #getEngineName()}. */
    public static final String ENGINE_NAME = "PyTorch";

    /** The value returned by {@link #getRank()}. */
    static final int RANK = 2;

    /** System property: number of inter-op threads to configure, if set. */
    private static final String PROP_NUM_INTEROP_THREADS = "ai.djl.pytorch.num_interop_threads";

    /** System property: number of intra-op threads to configure, if set. */
    private static final String PROP_NUM_THREADS = "ai.djl.pytorch.num_threads";

    /** System property: when {@code true}, cuDNN benchmark mode is enabled. */
    private static final String PROP_CUDNN_BENCHMARK = "ai.djl.pytorch.cudnn_benchmark";

    /** Setting holding a comma-separated list of extra native libraries to load. */
    private static final String EXTRA_LIBRARY_PATH = "PYTORCH_EXTRA_LIBRARY_PATH";

    private PtEngine() {
    }

    /**
     * Loads the PyTorch native library and creates a new engine.
     *
     * <p>After loading the library this method disables gradient mode, applies the optional
     * {@code ai.djl.pytorch.num_interop_threads}, {@code ai.djl.pytorch.num_threads} and
     * {@code ai.djl.pytorch.cudnn_benchmark} system properties, logs the resulting thread counts,
     * and finally loads every extra library listed in {@code PYTORCH_EXTRA_LIBRARY_PATH}.
     *
     * @return a new {@code PtEngine}
     * @throws EngineException if any step fails; the original failure is kept as the cause
     */
    static Engine newInstance() {
        try {
            LibUtils.loadLibrary();
            JniUtils.setGradMode(false);
            applyRuntimeSettings();
            loadExtraLibraries();
            return new PtEngine();
        } catch (Throwable t) {
            // Throwable (not Exception) on purpose: native loading reports failures as Errors
            // such as UnsatisfiedLinkError, which must also surface as an EngineException.
            throw new EngineException("Failed to load PyTorch native library", t);
        }
    }

    /** Applies the optional thread-count and cuDNN-benchmark system properties, then logs the thread counts. */
    private static void applyRuntimeSettings() {
        Integer interopThreads = Integer.getInteger(PROP_NUM_INTEROP_THREADS);
        if (interopThreads != null) {
            JniUtils.setNumInteropThreads(interopThreads);
        }
        Integer intraopThreads = Integer.getInteger(PROP_NUM_THREADS);
        if (intraopThreads != null) {
            JniUtils.setNumThreads(intraopThreads);
        }
        if (Boolean.getBoolean(PROP_CUDNN_BENCHMARK)) {
            JniUtils.setBenchmarkCuDNN(true);
        }
        logger.info("Number of inter-op threads is {}", JniUtils.getNumInteropThreads());
        logger.info("Number of intra-op threads is {}", JniUtils.getNumThreads());
    }

    /**
     * Loads each native library listed in {@code PYTORCH_EXTRA_LIBRARY_PATH}.
     *
     * <p>The value is read with {@code Utils.getEnvOrSystemProperty} and split on commas. Each
     * entry is trimmed, and empty entries (for example from a trailing comma) are skipped. Without
     * this, an empty entry would resolve to the current directory and be passed to
     * {@link System#load(String)}. Nothing happens when the setting is absent.
     *
     * @throws FileNotFoundException if a listed file does not exist
     */
    private static void loadExtraLibraries() throws FileNotFoundException {
        String paths = Utils.getEnvOrSystemProperty(EXTRA_LIBRARY_PATH);
        if (paths == null) {
            return;
        }
        for (String entry : paths.split(",")) {
            String file = entry.trim();
            if (file.isEmpty()) {
                continue;
            }
            Path path = Paths.get(file);
            if (Files.notExists(path)) {
                throw new FileNotFoundException("PyTorch extra Library not found: " + file);
            }
            System.load(path.toAbsolutePath().toString());
        }
    }

    /**
     * Returns the alternative engine for this engine.
     *
     * @return always {@code null}; this engine declares no alternative
     */
    public Engine getAlternativeEngine() {
        return null;
    }

    /**
     * Returns the engine name.
     *
     * @return {@value #ENGINE_NAME}
     */
    public String getEngineName() {
        return ENGINE_NAME;
    }

    /**
     * Returns the engine rank.
     *
     * @return {@link #RANK}
     */
    public int getRank() {
        return RANK;
    }

    /**
     * Returns the version string reported by {@code LibUtils.getVersion()}.
     *
     * @return the PyTorch library version
     */
    public String getVersion() {
        return LibUtils.getVersion();
    }

    /**
     * Checks whether the native library lists the given feature.
     *
     * @param capability the feature name to look up
     * @return {@code true} if {@code JniUtils.getFeatures()} contains {@code capability}
     */
    public boolean hasCapability(String capability) {
        return JniUtils.getFeatures().contains(capability);
    }

    /**
     * Creates a new PyTorch symbol block.
     *
     * @param manager the manager for the block; must be a {@link PtNDManager}, otherwise the
     *     cast fails with a {@link ClassCastException}
     * @return a new {@link PtSymbolBlock}
     */
    public SymbolBlock newSymbolBlock(NDManager manager) {
        return new PtSymbolBlock((PtNDManager) manager);
    }

    /**
     * Creates a new PyTorch model.
     *
     * @param name the model name
     * @param device the device the model is associated with
     * @return a new {@link PtModel}
     */
    public Model newModel(String name, Device device) {
        return new PtModel(name, device);
    }

    /**
     * Creates a new base manager as a sub-manager of the PyTorch system manager.
     *
     * @return a new {@link NDManager}
     */
    public NDManager newBaseManager() {
        return PtNDManager.getSystemManager().newSubManager();
    }

    /**
     * Creates a new base manager bound to a device, as a sub-manager of the PyTorch system manager.
     *
     * @param device the device for the new manager
     * @return a new {@link NDManager}
     */
    public NDManager newBaseManager(Device device) {
        return PtNDManager.getSystemManager().newSubManager(device);
    }

    /**
     * Creates a new gradient collector.
     *
     * @return a new {@link PtGradientCollector}
     */
    public GradientCollector newGradientCollector() {
        return new PtGradientCollector();
    }

    /**
     * Sets the random seed for the base engine, the native library (via {@code JniUtils.setSeed})
     * and {@code RandomUtils.RANDOM}.
     *
     * @param seed the seed value
     */
    public void setRandomSeed(int seed) {
        super.setRandomSeed(seed);
        JniUtils.setSeed(seed);
        RandomUtils.RANDOM.setSeed(seed);
    }

    /**
     * Describes the engine: name, version, the native feature list (one per line) and the path
     * of the loaded libtorch.
     *
     * @return a multi-line description of this engine
     */
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append(getEngineName()).append(':').append(getVersion()).append(", capabilities: [\n");
        for (String feature : JniUtils.getFeatures()) {
            sb.append("\t").append(feature).append(",\n");
        }
        sb.append("]\nPyTorch Library: ").append(LibUtils.getLibtorchPath());
        return sb.toString();
    }
}

