/*
 * Decompiled with CFR 0.152.
 */
package org.anchoranalysis.inference.concurrency;

import java.util.Optional;
import java.util.concurrent.PriorityBlockingQueue;
import org.anchoranalysis.core.functional.FunctionalIterate;
import org.anchoranalysis.core.functional.checked.CheckedFunction;
import org.anchoranalysis.core.log.Logger;
import org.anchoranalysis.inference.InferenceModel;
import org.anchoranalysis.inference.concurrency.ConcurrencyPlan;
import org.anchoranalysis.inference.concurrency.ConcurrentModel;
import org.anchoranalysis.inference.concurrency.ConcurrentModelException;
import org.anchoranalysis.inference.concurrency.CreateModelFailedException;
import org.anchoranalysis.inference.concurrency.CreateModelForPool;
import org.anchoranalysis.inference.concurrency.GPUMessageLogger;
import org.anchoranalysis.inference.concurrency.WithPriority;

/**
 * A pool of {@link ConcurrentModel} instances that threads borrow to run inference.
 *
 * <p>Models are kept in a {@link PriorityBlockingQueue}, so a caller blocks until a model is
 * available. Each model is wrapped in a {@link WithPriority} that records whether it was created
 * for a GPU. The ordering is defined by {@link WithPriority} (presumably GPU models are taken
 * ahead of CPU models - confirm in that class).
 *
 * <p>If the function executed on a GPU model fails with a {@link ConcurrentModelException}, the
 * GPU model is not returned to the pool, a replacement CPU model is added, and the function is
 * retried.
 *
 * @param <T> the type of inference model held by the pool.
 */
public class ConcurrentModelPool<T extends InferenceModel>
implements AutoCloseable {
    /** The models currently available to be borrowed. */
    private final PriorityBlockingQueue<WithPriority<ConcurrentModel<T>>> queue;

    /** Creates models, both at construction and when replacing a failed GPU model. */
    private final CreateModelForPool<T> createModel;

    /**
     * Creates the pool, first attempting to add GPU models and then filling up with CPU models.
     *
     * @param plan provides the desired number of GPU and CPU models.
     * @param createModel creates a model, given a flag indicating whether a GPU should be used.
     * @param logger the logger whose message-logger is passed to {@link GPUMessageLogger}.
     * @throws CreateModelFailedException if a CPU model cannot be created.
     */
    public ConcurrentModelPool(ConcurrencyPlan plan, CreateModelForPool<T> createModel, Logger logger) throws CreateModelFailedException {
        this.createModel = createModel;
        this.queue = new PriorityBlockingQueue<>();

        // GPU creation failures are tolerated, so fewer GPUs than requested may be added.
        int gpusAdded = this.addNumberModels(plan.numberGPUs(), true, createModel);
        GPUMessageLogger.maybeLog(plan.numberGPUs(), gpusAdded, logger.messageLogger());

        // Each successfully-added GPU model reduces the number of CPU models created.
        this.addNumberModels(plan.numberCPUs() - gpusAdded, false, createModel);
    }

    /**
     * Takes a model from the pool (blocking until one is available) and executes a function on it.
     *
     * <p>On success the model is returned to the pool. If the function throws a {@link
     * ConcurrentModelException} while using a GPU model, that model is discarded, a CPU model is
     * added to the pool instead, and the function is attempted again with whichever model is taken
     * next. If it throws while using a CPU model, the cause of the exception is rethrown.
     *
     * @param <S> the return type of the function.
     * @param functionToExecute the function to execute with the borrowed model.
     * @return the value returned by {@code functionToExecute}.
     * @throws Throwable the cause of a {@link ConcurrentModelException} thrown on a CPU model (or
     *     the exception itself if it has no cause), an {@link InterruptedException} if
     *     interrupted while waiting for a model, or a {@link CreateModelFailedException} if the
     *     replacement CPU model cannot be created.
     */
    public <S> S executeOrWait(CheckedFunction<ConcurrentModel<T>, S, ConcurrentModelException> functionToExecute) throws Throwable {
        while (true) {
            WithPriority<ConcurrentModel<T>> model = this.getOrWait();
            try {
                S returnValue = functionToExecute.apply(model.get());
                // Only a model that executed without error is returned to the pool.
                this.giveBack(model);
                return returnValue;
            } catch (ConcurrentModelException e) {
                if (!model.isGPU()) {
                    // NOTE(review): the failed CPU model is neither returned to the pool nor
                    // closed here - confirm this is intended.
                    Throwable cause = e.getCause();
                    if (cause != null) {
                        throw cause;
                    }
                    // Without a cause, rethrow the exception itself rather than "throw null".
                    throw e;
                }
                // The failed GPU model is deliberately not given back; a CPU model replaces it
                // and the loop retries.
                this.addNumberModels(1, false, this.createModel);
            }
        }
    }

    /**
     * Closes every model currently in the pool.
     *
     * <p>All queued models are closed even if some fail to close. Models that are borrowed at the
     * time of the call (and therefore not in the queue) are not closed.
     *
     * @throws Exception the first exception thrown when closing a model, with any subsequent
     *     exceptions attached as suppressed.
     */
    @Override
    public void close() throws Exception {
        Exception firstFailure = null;
        for (WithPriority<ConcurrentModel<T>> model : this.queue) {
            try {
                model.get().getModel().close();
            } catch (Exception e) {
                if (firstFailure == null) {
                    firstFailure = e;
                } else {
                    firstFailure.addSuppressed(e);
                }
            }
        }
        if (firstFailure != null) {
            throw firstFailure;
        }
    }

    /** Takes the head of the queue, blocking until a model becomes available. */
    private WithPriority<ConcurrentModel<T>> getOrWait() throws InterruptedException {
        return this.queue.take();
    }

    /** Returns a previously-borrowed model to the queue. */
    private void giveBack(WithPriority<ConcurrentModel<T>> model) {
        this.queue.put(model);
    }

    /**
     * Attempts to create and add a model a fixed number of times.
     *
     * @param numberModels how many attempts to make.
     * @param useGPU whether the models should use a GPU.
     * @param createModel creates each model.
     * @return the count returned by {@code FunctionalIterate.repeatCountSuccessful} (by its name,
     *     the number of attempts that returned true).
     * @throws CreateModelFailedException if creating a CPU model fails.
     */
    private int addNumberModels(int numberModels, boolean useGPU, CreateModelForPool<T> createModel) throws CreateModelFailedException {
        return FunctionalIterate.repeatCountSuccessful(numberModels, () -> this.addModelCatchGPUException(useGPU, createModel));
    }

    /**
     * Creates and adds a single model, treating a failure to create a GPU model as non-fatal.
     *
     * @return true if a model was added, false otherwise.
     * @throws CreateModelFailedException if creating a CPU model fails.
     */
    private boolean addModelCatchGPUException(boolean useGPU, CreateModelForPool<T> createModel) throws CreateModelFailedException {
        if (!useGPU) {
            return this.createAdd(false, createModel);
        }
        try {
            return this.createAdd(true, createModel);
        } catch (CreateModelFailedException ignored) {
            // A GPU model that cannot be created is simply reported as not added, so the
            // caller can make up the shortfall with CPU models.
            return false;
        }
    }

    /**
     * Creates a model and, if one is returned, adds it to the queue.
     *
     * @return true if a model was created and added, false if {@code createModel} returned empty.
     * @throws CreateModelFailedException if {@code createModel} fails.
     */
    private boolean createAdd(boolean useGPU, CreateModelForPool<T> createModel) throws CreateModelFailedException {
        Optional<ConcurrentModel<T>> model = createModel.create(useGPU);
        if (!model.isPresent()) {
            return false;
        }
        this.queue.add(new WithPriority<ConcurrentModel<T>>(model.get(), useGPU));
        return true;
    }
}

