package io.trino.server;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.reflect.TypeToken;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import io.airlift.concurrent.BoundedExecutor;
import io.airlift.concurrent.MoreFutures;
import io.airlift.jaxrs.AsyncResponseHandler;
import io.airlift.log.Logger;
import io.airlift.slice.Slice;
import io.airlift.stats.TimeStat;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.trino.Session;
import io.trino.TrinoMediaTypes;
import io.trino.execution.FailureInjector;
import io.trino.execution.SqlTaskManager;
import io.trino.execution.TaskId;
import io.trino.execution.TaskInfo;
import io.trino.execution.TaskState;
import io.trino.execution.buffer.BufferResult;
import io.trino.execution.buffer.OutputBuffers;
import io.trino.metadata.SessionPropertyManager;
import io.trino.server.security.ResourceSecurity;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import javax.inject.Inject;
import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE;
import javax.ws.rs.DefaultValue;
import javax.ws.rs.GET;
import javax.ws.rs.HeaderParam;
import javax.ws.rs.POST;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.QueryParam;
import javax.ws.rs.container.AsyncResponse;
import javax.ws.rs.container.Suspended;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.GenericEntity;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.UriInfo;
import org.weakref.jmx.Managed;
import org.weakref.jmx.Nested;

@Path("/v1/task")
/* loaded from: input_file:io/trino/server/TaskResource.class */
public class TaskResource {
    private static final Logger log = Logger.get(TaskResource.class);
    // Extra time added to the wait time when computing the hard timeout passed to
    // AsyncResponseHandler.withTimeout(...) in the long-polling endpoints.
    private static final Duration ADDITIONAL_WAIT_TIME = new Duration(5.0d, TimeUnit.SECONDS);
    // Base wait time used by getResults; it is passed through randomizeWaitTime before use.
    private static final Duration DEFAULT_MAX_WAIT_TIME = new Duration(2.0d, TimeUnit.SECONDS);
    private final SqlTaskManager taskManager;
    private final SessionPropertyManager sessionPropertyManager;
    // Executor handed to AsyncResponseHandler.bindAsyncResponse(...) for completing async responses.
    private final Executor responseExecutor;
    // Scheduler handed to MoreFutures.addTimeout(...) to fire the soft-timeout fallbacks.
    private final ScheduledExecutorService timeoutExecutor;
    // Consulted by injectFailure(...) to decide whether a request should be failed or timed out.
    private final FailureInjector failureInjector;
    // Time from the start of getResults until the results future completes (exported via JMX).
    private final TimeStat readFromOutputBufferTime = new TimeStat();
    // Time from the start of getResults until the callback registered on the AsyncResponse fires (exported via JMX).
    private final TimeStat resultsRequestTime = new TimeStat();

    /**
     * Kinds of requests served by this resource, each flagged as either a
     * task-management request or a results request. The flag is used by
     * failure injection to decide which requests an injected failure applies to.
     */
    public enum RequestType {
        CREATE_OR_UPDATE_TASK(true),
        GET_TASK_INFO(true),
        GET_TASK_STATUS(true),
        ACKNOWLEDGE_AND_GET_NEW_DYNAMIC_FILTER_DOMAINS(true),
        GET_RESULTS(false),
        DESTROY_RESULTS(false);

        /** {@code true} for task-management requests, {@code false} for results requests. */
        private final boolean taskManagement;

        RequestType(boolean taskManagement) {
            this.taskManagement = taskManagement;
        }

        /**
         * @return whether this request type is a task-management request
         */
        public boolean isTaskManagement() {
            return taskManagement;
        }
    }

    /**
     * Creates the resource from its injected collaborators.
     *
     * @throws NullPointerException if any argument is null
     */
    @Inject
    public TaskResource(SqlTaskManager sqlTaskManager, SessionPropertyManager sessionPropertyManager, @ForAsyncHttp BoundedExecutor boundedExecutor, @ForAsyncHttp ScheduledExecutorService scheduledExecutorService, FailureInjector failureInjector) {
        // requireNonNull already returns its argument with the right static type, so no casts are needed.
        taskManager = Objects.requireNonNull(sqlTaskManager, "taskManager is null");
        this.sessionPropertyManager = Objects.requireNonNull(sessionPropertyManager, "sessionPropertyManager is null");
        responseExecutor = Objects.requireNonNull(boundedExecutor, "responseExecutor is null");
        timeoutExecutor = Objects.requireNonNull(scheduledExecutorService, "timeoutExecutor is null");
        this.failureInjector = Objects.requireNonNull(failureInjector, "failureInjector is null");
    }

    /**
     * Returns information about all tasks known to the task manager.
     *
     * @param uriInfo request URI; when it carries a {@code summarize} query parameter,
     *                each {@link TaskInfo} is replaced by its summarized form
     * @return the (optionally summarized) task infos
     */
    @ResourceSecurity(ResourceSecurity.AccessType.INTERNAL_ONLY)
    @GET
    @Produces({"application/json"})
    public List<TaskInfo> getAllTaskInfo(@Context UriInfo uriInfo) {
        // Typed as List<TaskInfo> (not a raw ImmutableList) so the summarize mapping type-checks.
        List<TaskInfo> allTaskInfo = this.taskManager.getAllTaskInfo();
        if (shouldSummarize(uriInfo)) {
            allTaskInfo = ImmutableList.copyOf(Iterables.transform(allTaskInfo, TaskInfo::summarize));
        }
        return allTaskInfo;
    }

    /**
     * Creates or updates a task from the posted {@link TaskUpdateRequest} and resumes the
     * async response with the resulting {@link TaskInfo} (summarized when the request URI
     * carries a {@code summarize} query parameter).
     *
     * @throws NullPointerException if {@code taskUpdateRequest} is null
     */
    @Path("{taskId}")
    @Consumes({"application/json"})
    @ResourceSecurity(ResourceSecurity.AccessType.INTERNAL_ONLY)
    @POST
    @Produces({"application/json"})
    public void createOrUpdateTask(@PathParam("taskId") TaskId taskId, TaskUpdateRequest taskUpdateRequest, @Context UriInfo uriInfo, @Suspended AsyncResponse asyncResponse) {
        Objects.requireNonNull(taskUpdateRequest, "taskUpdateRequest is null");
        Session taskSession = taskUpdateRequest.getSession().toSession(sessionPropertyManager, taskUpdateRequest.getExtraCredentials());

        // When a failure was injected, injectFailure has already dealt with the response.
        boolean failureInjected = injectFailure(taskSession.getTraceToken(), taskId, RequestType.CREATE_OR_UPDATE_TASK, asyncResponse);
        if (!failureInjected) {
            TaskInfo updated = taskManager.updateTask(
                    taskSession,
                    taskId,
                    taskUpdateRequest.getFragment(),
                    taskUpdateRequest.getSplitAssignments(),
                    taskUpdateRequest.getOutputIds(),
                    taskUpdateRequest.getDynamicFilterDomains());
            TaskInfo entity = shouldSummarize(uriInfo) ? updated.summarize() : updated;
            asyncResponse.resume(Response.ok().entity(entity).build());
        }
    }

    /**
     * Returns the {@link TaskInfo} for a task, optionally long-polling for a change.
     *
     * <p>When either the current-version or the max-wait header is absent, the current task
     * info is returned immediately. Otherwise the versioned task-info future from the task
     * manager is used, falling back to the current task info after a randomized wait derived
     * from {@code duration}.
     *
     * @param taskId        task to look up
     * @param l             value of the {@code X-Trino-Current-Version} header, may be null
     * @param duration      value of the {@code X-Trino-Max-Wait} header, may be null
     * @param uriInfo       request URI; a {@code summarize} query parameter requests a summarized result
     * @param asyncResponse response to resume with the result
     * @throws NullPointerException if {@code taskId} is null
     */
    @GET
    @Path("{taskId}")
    @ResourceSecurity(ResourceSecurity.AccessType.INTERNAL_ONLY)
    @Produces({"application/json"})
    public void getTaskInfo(@PathParam("taskId") TaskId taskId, @HeaderParam("X-Trino-Current-Version") Long l, @HeaderParam("X-Trino-Max-Wait") Duration duration, @Context UriInfo uriInfo, @Suspended AsyncResponse asyncResponse) {
        Objects.requireNonNull(taskId, "taskId is null");
        if (injectFailure(this.taskManager.getTraceToken(taskId), taskId, RequestType.GET_TASK_INFO, asyncResponse)) {
            return;
        }
        if (l == null || duration == null) {
            TaskInfo taskInfo = this.taskManager.getTaskInfo(taskId);
            if (shouldSummarize(uriInfo)) {
                taskInfo = taskInfo.summarize();
            }
            asyncResponse.resume(taskInfo);
            return;
        }

        // Keep the randomized wait in a local: it is needed both for the soft timeout below
        // and for deriving the hard timeout of the async response.
        Duration waitTime = randomizeWaitTime(duration);
        ListenableFuture<TaskInfo> futureTaskInfo = MoreFutures.addTimeout(
                this.taskManager.getTaskInfo(taskId, l.longValue()),
                () -> this.taskManager.getTaskInfo(taskId),
                waitTime,
                this.timeoutExecutor);
        if (shouldSummarize(uriInfo)) {
            futureTaskInfo = Futures.transform(futureTaskInfo, TaskInfo::summarize, MoreExecutors.directExecutor());
        }

        // Hard timeout = soft wait + ADDITIONAL_WAIT_TIME.
        Duration timeout = new Duration(waitTime.toMillis() + ADDITIONAL_WAIT_TIME.toMillis(), TimeUnit.MILLISECONDS);
        AsyncResponseHandler.bindAsyncResponse(asyncResponse, futureTaskInfo, this.responseExecutor).withTimeout(timeout);
    }

    /**
     * Returns the status of a task, optionally long-polling for a change.
     *
     * <p>When either the current-version or the max-wait header is absent, the current status
     * is returned immediately. Otherwise the versioned status future from the task manager is
     * used, falling back to the current status after a randomized wait derived from
     * {@code duration}.
     *
     * @param taskId        task to look up
     * @param l             value of the {@code X-Trino-Current-Version} header, may be null
     * @param duration      value of the {@code X-Trino-Max-Wait} header, may be null
     * @param uriInfo       request URI (not used by this method)
     * @param asyncResponse response to resume with the result
     * @throws NullPointerException if {@code taskId} is null
     */
    @GET
    @Path("{taskId}/status")
    @ResourceSecurity(ResourceSecurity.AccessType.INTERNAL_ONLY)
    @Produces({"application/json"})
    public void getTaskStatus(@PathParam("taskId") TaskId taskId, @HeaderParam("X-Trino-Current-Version") Long l, @HeaderParam("X-Trino-Max-Wait") Duration duration, @Context UriInfo uriInfo, @Suspended AsyncResponse asyncResponse) {
        Objects.requireNonNull(taskId, "taskId is null");
        if (injectFailure(this.taskManager.getTraceToken(taskId), taskId, RequestType.GET_TASK_STATUS, asyncResponse)) {
            return;
        }
        if (l == null || duration == null) {
            asyncResponse.resume(this.taskManager.getTaskStatus(taskId));
            return;
        }

        // Keep the randomized wait in a local: it is needed both for the soft timeout
        // and for deriving the hard timeout of the async response.
        Duration waitTime = randomizeWaitTime(duration);
        // Hard timeout = soft wait + ADDITIONAL_WAIT_TIME.
        Duration timeout = new Duration(waitTime.toMillis() + ADDITIONAL_WAIT_TIME.toMillis(), TimeUnit.MILLISECONDS);
        AsyncResponseHandler.bindAsyncResponse(
                        asyncResponse,
                        MoreFutures.addTimeout(
                                this.taskManager.getTaskStatus(taskId, l.longValue()),
                                () -> this.taskManager.getTaskStatus(taskId),
                                waitTime,
                                this.timeoutExecutor),
                        this.responseExecutor)
                .withTimeout(timeout);
    }

    /**
     * Resumes the async response with the result of the task manager's
     * {@code acknowledgeAndGetNewDynamicFilterDomains} call for the given task and version.
     *
     * @param l value of the {@code X-Trino-Current-Version} header; must not be null
     * @throws NullPointerException if {@code taskId} or {@code l} is null
     */
    @GET
    @Path("{taskId}/dynamicfilters")
    @ResourceSecurity(ResourceSecurity.AccessType.INTERNAL_ONLY)
    @Produces({"application/json"})
    public void acknowledgeAndGetNewDynamicFilterDomains(@PathParam("taskId") TaskId taskId, @HeaderParam("X-Trino-Current-Version") Long l, @Context UriInfo uriInfo, @Suspended AsyncResponse asyncResponse) {
        Objects.requireNonNull(taskId, "taskId is null");
        Objects.requireNonNull(l, "currentDynamicFiltersVersion is null");

        boolean failureInjected = injectFailure(taskManager.getTraceToken(taskId), taskId, RequestType.ACKNOWLEDGE_AND_GET_NEW_DYNAMIC_FILTER_DOMAINS, asyncResponse);
        if (!failureInjected) {
            long currentVersion = l.longValue();
            asyncResponse.resume(taskManager.acknowledgeAndGetNewDynamicFilterDomains(taskId, currentVersion));
        }
    }

    /**
     * Aborts or cancels a task and returns its resulting {@link TaskInfo}.
     *
     * @param taskId  task to stop
     * @param z       value of the {@code abort} query parameter (defaults to {@code true});
     *                {@code true} aborts the task, {@code false} cancels it
     * @param uriInfo request URI; a {@code summarize} query parameter requests a summarized result
     * @throws NullPointerException if {@code taskId} is null
     */
    @Path("{taskId}")
    @DELETE
    @ResourceSecurity(ResourceSecurity.AccessType.INTERNAL_ONLY)
    @Produces({"application/json"})
    public TaskInfo deleteTask(@PathParam("taskId") TaskId taskId, @QueryParam("abort") @DefaultValue("true") boolean z, @Context UriInfo uriInfo) {
        Objects.requireNonNull(taskId, "taskId is null");

        TaskInfo result;
        if (z) {
            result = taskManager.abortTask(taskId);
        } else {
            result = taskManager.cancelTask(taskId);
        }
        return shouldSummarize(uriInfo) ? result.summarize() : result;
    }

    /**
     * Fails a task with the exception described by the posted {@link FailTaskRequest}
     * and returns the resulting {@link TaskInfo}.
     *
     * @param uriInfo request URI (not used by this method)
     * @throws NullPointerException if {@code taskId} or {@code failTaskRequest} is null
     */
    @Path("{taskId}/fail")
    @Consumes({"application/json"})
    @ResourceSecurity(ResourceSecurity.AccessType.INTERNAL_ONLY)
    @POST
    @Produces({"application/json"})
    public TaskInfo failTask(@PathParam("taskId") TaskId taskId, FailTaskRequest failTaskRequest, @Context UriInfo uriInfo) {
        Objects.requireNonNull(taskId, "taskId is null");
        Objects.requireNonNull(failTaskRequest, "failTaskRequest is null");

        TaskInfo failed = taskManager.failTask(taskId, failTaskRequest.getFailureInfo().toException());
        return failed;
    }

    /**
     * Long-polls an output buffer of a task for serialized pages.
     *
     * <p>The task manager's results future is given a soft timeout (a randomized wait derived
     * from {@code DEFAULT_MAX_WAIT_TIME}) after which an empty result for the same token is
     * used. The response is {@code 200 OK} with the pages when any are available and
     * {@code 204 No Content} otherwise; in both cases the task-instance-id, page-token,
     * next-token, buffer-complete and task-failed headers are set. If the async response
     * itself hits its hard timeout, a {@code 204} carrying the requested token is sent.
     *
     * @param taskId         task owning the buffer
     * @param outputBufferId buffer to read from
     * @param j              page token to read from
     * @param dataSize       value of the {@code X-Trino-Max-Size} header, forwarded to the task manager
     * @param asyncResponse  response to resume with the result
     * @throws NullPointerException if {@code taskId} or {@code outputBufferId} is null
     */
    @GET
    @Path("{taskId}/results/{bufferId}/{token}")
    @ResourceSecurity(ResourceSecurity.AccessType.INTERNAL_ONLY)
    @Produces({TrinoMediaTypes.TRINO_PAGES})
    public void getResults(@PathParam("taskId") TaskId taskId, @PathParam("bufferId") OutputBuffers.OutputBufferId outputBufferId, @PathParam("token") long j, @HeaderParam("X-Trino-Max-Size") DataSize dataSize, @Suspended AsyncResponse asyncResponse) {
        Objects.requireNonNull(taskId, "taskId is null");
        Objects.requireNonNull(outputBufferId, "bufferId is null");
        if (injectFailure(this.taskManager.getTraceToken(taskId), taskId, RequestType.GET_RESULTS, asyncResponse)) {
            return;
        }

        // Snapshot the failed/aborted state once; it is reported in both possible responses.
        TaskState state = this.taskManager.getTaskStatus(taskId).getState();
        boolean taskFailed = state == TaskState.ABORTED || state == TaskState.FAILED;
        long start = System.nanoTime();

        // Keep the randomized wait in a local: it is needed both for the soft timeout below
        // and for deriving the hard timeout of the async response.
        Duration waitTime = randomizeWaitTime(DEFAULT_MAX_WAIT_TIME);
        ListenableFuture<BufferResult> bufferResultFuture = MoreFutures.addTimeout(
                this.taskManager.getTaskResults(taskId, outputBufferId, j, dataSize),
                () -> BufferResult.emptyResults(this.taskManager.getTaskInstanceId(taskId), j, false),
                waitTime,
                this.timeoutExecutor);

        ListenableFuture<Response> responseFuture = Futures.transform(bufferResultFuture, bufferResult -> {
            List<Slice> serializedPages = bufferResult.getSerializedPages();
            GenericEntity<List<Slice>> entity = null;
            Response.Status status;
            if (serializedPages.isEmpty()) {
                status = Response.Status.NO_CONTENT;
            } else {
                entity = new GenericEntity<>(serializedPages, new TypeToken<List<Slice>>() {}.getType());
                status = Response.Status.OK;
            }
            return Response.status(status)
                    .entity(entity)
                    .header(InternalHeaders.TRINO_TASK_INSTANCE_ID, bufferResult.getTaskInstanceId())
                    .header(InternalHeaders.TRINO_PAGE_TOKEN, Long.valueOf(bufferResult.getToken()))
                    .header(InternalHeaders.TRINO_PAGE_NEXT_TOKEN, Long.valueOf(bufferResult.getNextToken()))
                    .header(InternalHeaders.TRINO_BUFFER_COMPLETE, Boolean.valueOf(bufferResult.isBufferComplete()))
                    .header(InternalHeaders.TRINO_TASK_FAILED, Boolean.valueOf(taskFailed))
                    .build();
        }, MoreExecutors.directExecutor());

        // Hard timeout = soft wait + ADDITIONAL_WAIT_TIME; on expiry reply with an empty
        // result that leaves the token unchanged.
        Duration timeout = new Duration(waitTime.toMillis() + ADDITIONAL_WAIT_TIME.toMillis(), TimeUnit.MILLISECONDS);
        Response timeoutResponse = Response.status(Response.Status.NO_CONTENT)
                .header(InternalHeaders.TRINO_TASK_INSTANCE_ID, this.taskManager.getTaskInstanceId(taskId))
                .header(InternalHeaders.TRINO_PAGE_TOKEN, Long.valueOf(j))
                .header(InternalHeaders.TRINO_PAGE_NEXT_TOKEN, Long.valueOf(j))
                .header(InternalHeaders.TRINO_BUFFER_COMPLETE, false)
                .header(InternalHeaders.TRINO_TASK_FAILED, Boolean.valueOf(taskFailed))
                .build();
        AsyncResponseHandler.bindAsyncResponse(asyncResponse, responseFuture, this.responseExecutor).withTimeout(timeout, timeoutResponse);

        // Record how long the buffer read took and how long the whole request took.
        responseFuture.addListener(() -> {
            this.readFromOutputBufferTime.add(Duration.nanosSince(start));
        }, MoreExecutors.directExecutor());
        asyncResponse.register(th -> {
            this.resultsRequestTime.add(Duration.nanosSince(start));
        });
    }

    /**
     * Forwards an acknowledgement for the given buffer and token to the task manager.
     *
     * @param j page token being acknowledged
     * @throws NullPointerException if {@code taskId} or {@code outputBufferId} is null
     */
    @ResourceSecurity(ResourceSecurity.AccessType.INTERNAL_ONLY)
    @GET
    @Path("{taskId}/results/{bufferId}/{token}/acknowledge")
    public void acknowledgeResults(@PathParam("taskId") TaskId taskId, @PathParam("bufferId") OutputBuffers.OutputBufferId outputBufferId, @PathParam("token") long j) {
        Objects.requireNonNull(taskId, "taskId is null");
        Objects.requireNonNull(outputBufferId, "bufferId is null");

        taskManager.acknowledgeTaskResults(taskId, outputBufferId, j);
    }

    /**
     * Asks the task manager to destroy the results of the given output buffer and resumes
     * the async response with {@code 204 No Content}.
     *
     * @param uriInfo request URI (not used by this method)
     * @throws NullPointerException if {@code taskId} or {@code outputBufferId} is null
     */
    @ResourceSecurity(ResourceSecurity.AccessType.INTERNAL_ONLY)
    @Path("{taskId}/results/{bufferId}")
    @DELETE
    public void destroyTaskResults(@PathParam("taskId") TaskId taskId, @PathParam("bufferId") OutputBuffers.OutputBufferId outputBufferId, @Context UriInfo uriInfo, @Suspended AsyncResponse asyncResponse) {
        Objects.requireNonNull(taskId, "taskId is null");
        Objects.requireNonNull(outputBufferId, "bufferId is null");

        boolean failureInjected = injectFailure(taskManager.getTraceToken(taskId), taskId, RequestType.DESTROY_RESULTS, asyncResponse);
        if (!failureInjected) {
            taskManager.destroyTaskResults(taskId, outputBufferId);
            asyncResponse.resume(Response.noContent().build());
        }
    }

    /**
     * Applies an injected failure, if one is configured for this trace token and task.
     *
     * <p>Request failures resume the response with a server error; request timeouts only set
     * a timeout on the response and never resume it here. Management-request failures apply
     * only to task-management request types, and results-request failures only to the others.
     * A task failure fails the task through the task manager but lets the request proceed.
     *
     * @param optional      trace token of the request's query, if any
     * @param taskId        task the request targets
     * @param requestType   kind of request being served
     * @param asyncResponse response to fail or time out
     * @return {@code true} when the request has been handled here and the caller must stop;
     *         {@code false} when the caller should continue processing normally
     * @throws IllegalArgumentException for an unrecognized injected failure type
     */
    private boolean injectFailure(Optional<String> optional, TaskId taskId, RequestType requestType, AsyncResponse asyncResponse) {
        if (optional.isEmpty()) {
            return false;
        }
        String traceToken = optional.get();
        Optional<FailureInjector.InjectedFailure> lookup = failureInjector.getInjectedFailure(traceToken, taskId.getStageId().getId(), taskId.getPartitionId(), taskId.getAttemptId());
        if (lookup.isEmpty()) {
            return false;
        }
        FailureInjector.InjectedFailure failure = lookup.get();
        Duration injectedTimeout = failureInjector.getRequestTimeout();
        switch (failure.getInjectedFailureType()) {
            case TASK_MANAGEMENT_REQUEST_FAILURE:
                if (requestType.isTaskManagement()) {
                    log.info("Failing %s request for task %s", requestType, taskId);
                    asyncResponse.resume(Response.serverError().build());
                    return true;
                }
                return false;
            case TASK_MANAGEMENT_REQUEST_TIMEOUT:
                if (requestType.isTaskManagement()) {
                    log.info("Timing out %s request for task %s", requestType, taskId);
                    asyncResponse.setTimeout(injectedTimeout.toMillis(), TimeUnit.MILLISECONDS);
                    return true;
                }
                return false;
            case TASK_GET_RESULTS_REQUEST_FAILURE:
                if (!requestType.isTaskManagement()) {
                    log.info("Failing %s request for task %s", requestType, taskId);
                    asyncResponse.resume(Response.serverError().build());
                    return true;
                }
                return false;
            case TASK_GET_RESULTS_REQUEST_TIMEOUT:
                if (!requestType.isTaskManagement()) {
                    log.info("Timing out %s request for task %s", requestType, taskId);
                    asyncResponse.setTimeout(injectedTimeout.toMillis(), TimeUnit.MILLISECONDS);
                    return true;
                }
                return false;
            case TASK_FAILURE:
                log.info("Injecting failure for task %s at %s", taskId, requestType);
                taskManager.failTask(taskId, failure.getTaskFailureException());
                return false;
            default:
                throw new IllegalArgumentException("unexpected failure type: " + failure.getInjectedFailureType());
        }
    }

    /**
     * @return the timing statistic updated when a {@code getResults} results future completes
     */
    @Managed
    @Nested
    public TimeStat getReadFromOutputBufferTime() {
        return readFromOutputBufferTime;
    }

    /**
     * @return the timing statistic updated from the callback that {@code getResults}
     *         registers on its async response
     */
    @Managed
    @Nested
    public TimeStat getResultsRequestTime() {
        return resultsRequestTime;
    }

    /**
     * Tells whether the request asked for summarized output, i.e. whether its query string
     * contains a {@code summarize} parameter (the parameter's value is not inspected).
     */
    private static boolean shouldSummarize(UriInfo uriInfo) {
        boolean summarizeRequested = uriInfo.getQueryParameters().containsKey("summarize");
        return summarizeRequested;
    }

    /**
     * Randomizes a wait time so that the result is at least half of, and less than, the
     * requested wait (in whole milliseconds).
     *
     * <p>Wait times shorter than 2 ms have a zero half-width, which
     * {@link ThreadLocalRandom#nextLong(long)} rejects as a bound; such values are returned
     * unrandomized (truncated to whole milliseconds) instead of throwing.
     *
     * @param duration the maximum wait requested
     * @return the randomized wait time in milliseconds
     */
    private static Duration randomizeWaitTime(Duration duration) {
        long totalMillis = duration.toMillis();
        long halfMillis = totalMillis / 2;
        if (halfMillis <= 0) {
            // nextLong(bound) requires bound > 0; there is nothing to randomize here anyway.
            return new Duration(totalMillis, TimeUnit.MILLISECONDS);
        }
        return new Duration(halfMillis + ThreadLocalRandom.current().nextLong(halfMillis), TimeUnit.MILLISECONDS);
    }
}
