package io.trino.execution;

import com.google.common.base.MoreObjects;
import com.google.common.collect.Sets;
import io.airlift.log.Logger;
import io.trino.metadata.InternalNode;
import io.trino.util.FinalizerService;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.IntConsumer;
import javax.annotation.concurrent.ThreadSafe;
import javax.inject.Inject;

@ThreadSafe
/* loaded from: input_file:io/trino/execution/NodeTaskMap.class */
public class NodeTaskMap {
    private static final Logger log = Logger.get(NodeTaskMap.class);
    private final ConcurrentHashMap<InternalNode, NodeTasks> nodeTasksMap = new ConcurrentHashMap<>();
    private final FinalizerService finalizerService;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/trino/execution/NodeTaskMap$NodeTasks.class */
    public static class NodeTasks {
        private final Set<RemoteTask> remoteTasks = Sets.newConcurrentHashSet();
        private final AtomicInteger nodeTotalPartitionedSplitCount = new AtomicInteger();
        private final FinalizerService finalizerService;

        /* JADX INFO: Access modifiers changed from: private */
        @ThreadSafe
        /* loaded from: input_file:io/trino/execution/NodeTaskMap$NodeTasks$TaskPartitionedSplitCountTracker.class */
        public class TaskPartitionedSplitCountTracker {
            private final TaskId taskId;
            private final AtomicInteger localPartitionedSplitCount = new AtomicInteger();

            public TaskPartitionedSplitCountTracker(TaskId taskId) {
                this.taskId = (TaskId) Objects.requireNonNull(taskId, "taskId is null");
            }

            public synchronized void setPartitionedSplitCount(int i) {
                if (i < 0) {
                    NodeTasks.this.nodeTotalPartitionedSplitCount.addAndGet(-this.localPartitionedSplitCount.getAndSet(0));
                    throw new IllegalArgumentException("partitionedSplitCount is negative");
                }
                NodeTasks.this.nodeTotalPartitionedSplitCount.addAndGet(i - this.localPartitionedSplitCount.getAndSet(i));
            }

            public void cleanup() {
                int andSet = this.localPartitionedSplitCount.getAndSet(0);
                if (andSet == 0) {
                    return;
                }
                NodeTaskMap.log.error("BUG! %s for %s leaked with %s partitioned splits.  Cleaning up so server can continue to function.", new Object[]{getClass().getName(), this.taskId, Integer.valueOf(andSet)});
                NodeTasks.this.nodeTotalPartitionedSplitCount.addAndGet(-andSet);
            }

            public String toString() {
                return MoreObjects.toStringHelper(this).add("taskId", this.taskId).add("splits", this.localPartitionedSplitCount).toString();
            }
        }

        public NodeTasks(FinalizerService finalizerService) {
            this.finalizerService = (FinalizerService) Objects.requireNonNull(finalizerService, "finalizerService is null");
        }

        private int getPartitionedSplitCount() {
            return this.nodeTotalPartitionedSplitCount.get();
        }

        private void addTask(RemoteTask remoteTask) {
            if (this.remoteTasks.add(remoteTask)) {
                remoteTask.addStateChangeListener(taskStatus -> {
                    if (taskStatus.getState().isDone()) {
                        this.remoteTasks.remove(remoteTask);
                    }
                });
                if (remoteTask.getTaskStatus().getState().isDone()) {
                    this.remoteTasks.remove(remoteTask);
                }
            }
        }

        public PartitionedSplitCountTracker createPartitionedSplitCountTracker(TaskId taskId) {
            Objects.requireNonNull(taskId, "taskId is null");
            TaskPartitionedSplitCountTracker taskPartitionedSplitCountTracker = new TaskPartitionedSplitCountTracker(taskId);
            Objects.requireNonNull(taskPartitionedSplitCountTracker);
            PartitionedSplitCountTracker partitionedSplitCountTracker = new PartitionedSplitCountTracker(taskPartitionedSplitCountTracker::setPartitionedSplitCount);
            FinalizerService finalizerService = this.finalizerService;
            Objects.requireNonNull(taskPartitionedSplitCountTracker);
            finalizerService.addFinalizer(partitionedSplitCountTracker, taskPartitionedSplitCountTracker::cleanup);
            return partitionedSplitCountTracker;
        }
    }

    /* loaded from: input_file:io/trino/execution/NodeTaskMap$PartitionedSplitCountTracker.class */
    /**
     * Thin handle that forwards a partitioned split count to the {@link IntConsumer}
     * it was constructed with.
     */
    public static class PartitionedSplitCountTracker {
        private final IntConsumer splitSetter;

        /**
         * @param intConsumer receiver of every count passed to {@link #setPartitionedSplitCount(int)}
         * @throws NullPointerException if {@code intConsumer} is null
         */
        public PartitionedSplitCountTracker(IntConsumer intConsumer) {
            splitSetter = Objects.requireNonNull(intConsumer, "splitSetter is null");
        }

        /** Passes {@code i} straight through to the underlying consumer. */
        public void setPartitionedSplitCount(int i) {
            splitSetter.accept(i);
        }

        /** Returns the string form of the underlying consumer. */
        @Override
        public String toString() {
            return String.valueOf(splitSetter);
        }
    }

    /**
     * Creates an empty map; the given finalizer service is handed to every per-node
     * entry created later.
     *
     * @throws NullPointerException if {@code finalizerService} is null
     */
    @Inject
    public NodeTaskMap(FinalizerService finalizerService) {
        Objects.requireNonNull(finalizerService, "finalizerService is null");
        this.finalizerService = finalizerService;
    }

    /**
     * Registers {@code remoteTask} under {@code internalNode}, creating the per-node
     * entry first if the node has none yet.
     */
    public void addTask(InternalNode internalNode, RemoteTask remoteTask) {
        NodeTasks tasksForNode = createOrGetNodeTasks(internalNode);
        tasksForNode.addTask(remoteTask);
    }

    /**
     * Returns the partitioned split total currently recorded for {@code internalNode}.
     *
     * <p>This is a pure lookup: a node with no entry yet reports {@code 0} and no entry is
     * created for it. A freshly created entry would report zero anyway, so the result is the
     * same as creating one, without allocating and permanently retaining an empty entry for
     * every node that is merely queried.
     *
     * @param internalNode node to look up
     * @return the node's partitioned split total, or {@code 0} if nothing is recorded for it
     */
    public int getPartitionedSplitsOnNode(InternalNode internalNode) {
        NodeTasks tasksForNode = this.nodeTasksMap.get(internalNode);
        if (tasksForNode == null) {
            return 0;
        }
        return tasksForNode.getPartitionedSplitCount();
    }

    /**
     * Creates a split count tracker for {@code taskId} whose counts are accumulated under
     * {@code internalNode}, creating the per-node entry first if the node has none yet.
     */
    public PartitionedSplitCountTracker createPartitionedSplitCountTracker(InternalNode internalNode, TaskId taskId) {
        NodeTasks tasksForNode = createOrGetNodeTasks(internalNode);
        return tasksForNode.createPartitionedSplitCountTracker(taskId);
    }

    /**
     * Returns the entry for {@code internalNode}, falling back to {@link #addNodeTask}
     * when the map has none.
     */
    private NodeTasks createOrGetNodeTasks(InternalNode internalNode) {
        NodeTasks existing = nodeTasksMap.get(internalNode);
        return existing != null ? existing : addNodeTask(internalNode);
    }

    /**
     * Inserts a new entry for {@code internalNode} unless one is already mapped, and returns
     * whichever entry ends up in the map (the existing one wins if another caller got there first).
     */
    private NodeTasks addNodeTask(InternalNode internalNode) {
        NodeTasks candidate = new NodeTasks(finalizerService);
        NodeTasks alreadyMapped = nodeTasksMap.putIfAbsent(internalNode, candidate);
        if (alreadyMapped != null) {
            return alreadyMapped;
        }
        return candidate;
    }
}
