package io.trino.execution.executor.scheduler;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import io.trino.annotation.NotThreadSafe;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/**
 * Per-group bookkeeping for a weighted scheduler: tracks the tasks belonging to one group,
 * which of them are runnable, running or blocked, and the group's aggregate weight.
 *
 * <p>The group's {@link State} is derived from its tasks after every mutation:
 * {@code BLOCKED} when every known task is blocked (which includes having no tasks at all),
 * {@code RUNNABLE} when at least one task sits in the runnable queue, and {@code RUNNING}
 * otherwise.
 *
 * <p>Not thread-safe; callers must synchronize externally.
 *
 * @param <T> task handle type; used as a hash key, so it should have consistent
 *     {@code equals}/{@code hashCode}
 */
@NotThreadSafe
final class SchedulingGroup<T> {
    // Aggregate weight of the group; adjusted by enqueue, dequeue, block and addWeight.
    private long weight;
    // Every task known to this group, in any state, keyed by handle.
    private final Map<T, Task> tasks = new HashMap<>();
    // Tasks waiting to run, prioritized by their weight at enqueue time.
    private final PriorityQueue<T> runnableQueue = new PriorityQueue<>();
    // Handles of tasks currently in the BLOCKED state.
    private final Set<T> blocked = new HashSet<>();
    // Weights of non-blocked (runnable or running) tasks; source of baselineWeight().
    private final PriorityQueue<T> baselineWeights = new PriorityQueue<>();
    private State state = State.BLOCKED;

    /**
     * Marks {@code handle} runnable and adds it to the runnable queue, registering it if new.
     *
     * <p>A newly seen task starts at the current {@link #baselineWeight()}. A task coming back
     * from {@code BLOCKED} leaves the blocked set and has the current baseline added back to the
     * delta it saved in {@link #block}. The group weight then drops the task's uncommitted weight
     * (charged up front by {@link #dequeue}) and gains {@code deltaWeight} instead.
     *
     * @param handle the task to enqueue
     * @param deltaWeight weight to commit to the task and add to the group weight
     */
    public void enqueue(T handle, long deltaWeight) {
        Task task = tasks.get(handle);
        if (task == null) {
            task = new Task(baselineWeight());
            tasks.put(handle, task);
        } else if (task.state() == State.BLOCKED) {
            blocked.remove(handle);
            // Re-base the saved delta onto the current baseline.
            task.addWeight(baselineWeight());
        }

        // Replace the speculative charge from dequeue() with the actual delta. This must read
        // uncommittedWeight() before commitWeight() updates the task.
        weight -= task.uncommittedWeight();
        weight += deltaWeight;
        task.commitWeight(deltaWeight);

        task.setState(State.RUNNABLE);
        runnableQueue.add(handle, task.weight());
        baselineWeights.addOrReplace(handle, task.weight());
        updateState();
    }

    /**
     * Takes the task at the head of the runnable queue and marks it {@code RUNNING}.
     *
     * <p>{@code expectedWeight} is recorded on the task as uncommitted weight and added to the
     * group weight immediately; a later {@link #enqueue} of the same task subtracts it again.
     *
     * @param expectedWeight weight to charge up front for this run
     * @return the handle of the dequeued task
     * @throws IllegalArgumentException if the group is not {@code RUNNABLE}
     */
    public T dequeue(long expectedWeight) {
        Preconditions.checkArgument(state == State.RUNNABLE, "Group is not runnable: %s", state);

        T handle = runnableQueue.takeOrThrow();
        Task task = tasks.get(handle);
        task.setUncommittedWeight(expectedWeight);
        task.setState(State.RUNNING);
        weight += expectedWeight;

        baselineWeights.addOrReplace(handle, task.weight());
        updateState();
        return handle;
    }

    /**
     * Removes {@code handle} from every internal structure. The group weight is left unchanged.
     *
     * @param handle the task to forget
     * @throws IllegalArgumentException if the task is unknown to this group
     */
    public void finish(T handle) {
        Preconditions.checkArgument(tasks.containsKey(handle), "Unknown task: %s", handle);
        tasks.remove(handle);
        blocked.remove(handle);
        runnableQueue.removeIfPresent(handle);
        baselineWeights.removeIfPresent(handle);
        updateState();
    }

    /**
     * Commits {@code deltaWeight} to {@code handle} and moves it to the {@code BLOCKED} state.
     *
     * <p>While blocked, the task's weight is stored relative to the current baseline (it is
     * printed as "saved delta" by {@link #toString()}) and re-based in {@link #enqueue}.
     *
     * @param handle the task to block
     * @param deltaWeight weight to commit to the task and add to the group weight
     * @throws IllegalArgumentException if the task is unknown or is still in the runnable queue
     */
    public void block(T handle, long deltaWeight) {
        Preconditions.checkArgument(tasks.containsKey(handle), "Unknown task: %s", handle);
        Preconditions.checkArgument(!runnableQueue.contains(handle), "Task is already in queue: %s", handle);

        // NOTE(review): unlike enqueue(), this does not subtract task.uncommittedWeight() from the
        // group weight before adding deltaWeight — confirm this asymmetry is intentional.
        weight += deltaWeight;

        Task task = tasks.get(handle);
        task.commitWeight(deltaWeight);
        task.setState(State.BLOCKED);
        // Baseline is read while this task's own entry is still in baselineWeights.
        task.addWeight(-baselineWeight());
        blocked.add(handle);
        baselineWeights.remove(handle);
        updateState();
    }

    /**
     * Returns the next priority in the baseline-weight queue, or {@code 0} when no task is
     * runnable or running. Presumably this is the smallest such weight — depends on
     * {@code PriorityQueue} ordering, TODO confirm.
     */
    public long baselineWeight() {
        return baselineWeights.isEmpty() ? 0L : baselineWeights.nextPriority();
    }

    /** Adds {@code delta} directly to the group weight without touching any task. */
    public void addWeight(long delta) {
        weight += delta;
    }

    /** Recomputes {@link #state} from the blocked set and the runnable queue. */
    private void updateState() {
        if (blocked.size() == tasks.size()) {
            state = State.BLOCKED;
        } else if (runnableQueue.isEmpty()) {
            state = State.RUNNING;
        } else {
            state = State.RUNNABLE;
        }
    }

    /** Returns the group's aggregate weight. */
    public long weight() {
        return weight;
    }

    /** Returns an immutable snapshot of the handles of every task in this group. */
    public Set<T> tasks() {
        return ImmutableSet.copyOf(tasks.keySet());
    }

    /** Returns the group's current state. */
    public State state() {
        return state;
    }

    /** Returns the head of the runnable queue without removing it. */
    public T peek() {
        return runnableQueue.peek();
    }

    /** Returns the number of tasks currently in the runnable queue. */
    public int runnableCount() {
        return runnableQueue.size();
    }

    /** Returns one line per task: a "=>" marker for the queue head, the handle, and its state. */
    @Override
    public String toString() {
        // The queue head does not change while formatting, so look it up once.
        T head = peek();
        StringBuilder builder = new StringBuilder();
        for (Map.Entry<T, Task> entry : tasks.entrySet()) {
            T handle = entry.getKey();
            Task task = entry.getValue();
            String details = switch (task.state()) {
                case BLOCKED -> "[BLOCKED, saved delta = %s]".formatted(task.weight());
                case RUNNABLE -> "[RUNNABLE, weight = %s]".formatted(task.weight());
                case RUNNING -> "[RUNNING, weight = %s, uncommitted = %s]".formatted(task.weight(), task.uncommittedWeight());
                default -> throw new IllegalStateException("Unexpected task state: " + task.state());
            };
            // Compare by equality, not identity: the queued handle may be an equal but distinct
            // instance from the map key.
            builder.append(Objects.equals(handle, head) ? "=>" : "  ")
                    .append(' ')
                    .append(handle)
                    .append(' ')
                    .append(details)
                    .append('\n');
        }
        return builder.toString();
    }
}
