/*
 * Decompiled with CFR 0.152.
 */
package dev.yavuztas.junit;

import dev.yavuztas.junit.ConcurrentTest;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
import java.util.Locale;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.InvocationInterceptor;
import org.junit.jupiter.api.extension.ReflectiveInvocationContext;
import org.junit.platform.commons.util.AnnotationUtils;
import org.junit.platform.commons.util.ClassUtils;
import org.junit.platform.commons.util.Preconditions;
import org.junit.platform.commons.util.ReflectionUtils;

public class ConcurrentExtension
implements InvocationInterceptor {
    private int globalThreadCount;

    public static ConcurrentExtension withGlobalThreadCount(int threadCount) {
        ConcurrentExtension instance = new ConcurrentExtension();
        instance.globalThreadCount = threadCount;
        return instance;
    }

    public void interceptTestMethod(InvocationInterceptor.Invocation<Void> invocation, ReflectiveInvocationContext<Method> invocationContext, ExtensionContext extensionContext) throws Throwable {
        Method testMethod = (Method)invocationContext.getExecutable();
        Optional annotation = AnnotationUtils.findAnnotation((AnnotatedElement)testMethod, ConcurrentTest.class);
        if (!annotation.isPresent()) {
            invocation.proceed();
            return;
        }
        ConcurrentTest concurrentTest = (ConcurrentTest)annotation.get();
        Throwable[] exception = new Throwable[1];
        int threadCount = this.threadCount(concurrentTest, testMethod);
        ExecutorService executorService = Executors.newFixedThreadPool(threadCount);
        for (int i = 0; i < threadCount; ++i) {
            CompletableFuture.runAsync(() -> {
                try {
                    if (concurrentTest.printInfo()) {
                        this.printInfo(testMethod);
                    }
                    ReflectionUtils.invokeMethod((Method)testMethod, invocationContext.getTarget().orElse(null), (Object[])invocationContext.getArguments().toArray());
                }
                catch (Throwable t) {
                    exception[0] = t;
                }
            }, executorService);
        }
        this.awaitTerminationAfterShutdown(executorService, this.timeout(invocationContext.getTargetClass(), testMethod));
        if (exception[0] != null) {
            throw exception[0];
        }
        invocation.skip();
    }

    private void awaitTerminationAfterShutdown(ExecutorService threadPool, Timeout timeout) {
        threadPool.shutdown();
        try {
            if (!threadPool.awaitTermination(timeout != null ? timeout.value() : Long.MAX_VALUE, timeout != null ? timeout.unit() : TimeUnit.NANOSECONDS)) {
                threadPool.shutdownNow();
            }
        }
        catch (InterruptedException ex) {
            threadPool.shutdownNow();
            Thread.currentThread().interrupt();
        }
    }

    private void printInfo(Method testMethod) {
        String message = String.format("Thread#%s - %s(%s)", Thread.currentThread().getId(), testMethod.getName(), ClassUtils.nullSafeToString(Class::getSimpleName, (Class[])testMethod.getParameterTypes()));
        System.out.println(message);
    }

    private int threadCount(ConcurrentTest concurrent, Method method) {
        int count = concurrent.count();
        Preconditions.condition((count > 0 ? 1 : 0) != 0, () -> String.format("Configuration error: @ConcurrentTest on method [%s] must be declared with a positive 'count'.", method));
        return !concurrent.overrideGlobal() && this.globalThreadCount > 0 ? this.globalThreadCount : count;
    }

    private Timeout timeout(Class<?> clazz, Method method) {
        Optional methodTimeout = AnnotationUtils.findAnnotation((AnnotatedElement)method, Timeout.class);
        if (!methodTimeout.isPresent()) {
            return AnnotationUtils.findAnnotation(clazz, Timeout.class).orElse(null);
        }
        return (Timeout)methodTimeout.get();
    }
}

