package io.trino.execution.buffer;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.units.DataSize;
import io.trino.execution.StateMachine;
import io.trino.execution.buffer.OutputBuffers;
import io.trino.execution.buffer.SerializedPageReference;
import io.trino.memory.context.LocalMemoryContext;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import javax.annotation.concurrent.GuardedBy;

/* loaded from: input_file:io/trino/execution/buffer/BroadcastOutputBuffer.class */
/**
 * Output buffer that delivers every enqueued page to every client buffer.
 *
 * <p>While new client buffers can still be added, enqueued pages are additionally retained (with an
 * extra reference) in {@code initialPagesForNewBuffers}, so that a client buffer created later still
 * receives all previously enqueued pages. Those retained pages are released by
 * {@link #noMoreBuffers()} once no more buffers can be added.
 *
 * <p>State transitions driven by this class: {@code OPEN -> NO_MORE_PAGES | NO_MORE_BUFFERS},
 * then {@code -> FLUSHING} once both conditions hold, and {@code -> FINISHED} via {@link #destroy()}
 * (called directly, or automatically by {@link #checkFlushComplete()} once no more buffers can be
 * added and every client buffer has been destroyed). {@link #fail()} moves to {@code FAILED}.
 *
 * <p>Locking: mutable bookkeeping is guarded by {@code this}. Several methods assert that the caller
 * does <em>not</em> hold this monitor ({@code !Thread.holdsLock(this)}), so calls into them must
 * always be made outside {@code synchronized (this)} sections.
 */
public class BroadcastOutputBuffer implements OutputBuffer {
    private final String taskInstanceId;
    private final StateMachine<BufferState> state;
    private final OutputBufferMemoryManager memoryManager;
    // Run the first time the buffer becomes full while more client buffers may still be added.
    private final Runnable notifyStatusChanged;

    @GuardedBy("this")
    private OutputBuffers outputBuffers = OutputBuffers.createInitialEmptyOutputBuffers(OutputBuffers.BufferType.BROADCAST);

    // Only mutated under "this" (see getBuffer), but getInfo() reads it without the lock,
    // which is why it is a concurrent map.
    @GuardedBy("this")
    private final Map<OutputBuffers.OutputBufferId, ClientBuffer> buffers = new ConcurrentHashMap<>();

    // Pages (each holding an extra reference) replayed into client buffers created later.
    @GuardedBy("this")
    private final List<SerializedPageReference> initialPagesForNewBuffers = new ArrayList<>();
    private final AtomicLong totalPagesAdded = new AtomicLong();
    private final AtomicLong totalRowsAdded = new AtomicLong();
    private final AtomicLong totalBufferedPages = new AtomicLong();
    private final AtomicBoolean hasBlockedBefore = new AtomicBoolean();

    // Release callback handed to client buffers and dereferencePages: subtracts the released
    // page count (which must never go negative) and the released bytes from the memory accounting.
    private final SerializedPageReference.PagesReleasedListener onPagesReleased = (releasedPageCount, releasedSizeInBytes) -> {
        Preconditions.checkState(this.totalBufferedPages.addAndGet(-releasedPageCount) >= 0);
        this.memoryManager.updateMemoryUsage(-releasedSizeInBytes);
    };

    /**
     * Creates a broadcast buffer.
     *
     * @param taskInstanceId identifier passed to every client buffer created by this buffer
     * @param state state machine tracking this buffer's {@link BufferState}
     * @param maxBufferSize memory limit handed to the {@link OutputBufferMemoryManager}
     * @param systemMemoryContextSupplier memory context supplier for the memory manager
     * @param notificationExecutor executor used by the memory manager for notifications
     * @param notifyStatusChanged callback run the first time the buffer blocks while buffers can still be added
     * @throws NullPointerException if any argument is null
     */
    public BroadcastOutputBuffer(
            String taskInstanceId,
            StateMachine<BufferState> state,
            DataSize maxBufferSize,
            Supplier<LocalMemoryContext> systemMemoryContextSupplier,
            Executor notificationExecutor,
            Runnable notifyStatusChanged) {
        this.taskInstanceId = Objects.requireNonNull(taskInstanceId, "taskInstanceId is null");
        this.state = Objects.requireNonNull(state, "state is null");
        this.memoryManager = new OutputBufferMemoryManager(
                Objects.requireNonNull(maxBufferSize, "maxBufferSize is null").toBytes(),
                Objects.requireNonNull(systemMemoryContextSupplier, "systemMemoryContextSupplier is null"),
                Objects.requireNonNull(notificationExecutor, "notificationExecutor is null"));
        this.notifyStatusChanged = Objects.requireNonNull(notifyStatusChanged, "notifyStatusChanged is null");
    }

    @Override
    public void addStateChangeListener(StateMachine.StateChangeListener<BufferState> stateChangeListener) {
        this.state.addStateChangeListener(stateChangeListener);
    }

    /** Returns true once the buffer has reached the {@code FINISHED} state. */
    @Override
    public boolean isFinished() {
        return this.state.get() == BufferState.FINISHED;
    }

    @Override
    public double getUtilization() {
        return this.memoryManager.getUtilization();
    }

    /** Returns true when utilization exceeds 50% and the buffer is still accepting pages. */
    @Override
    public boolean isOverutilized() {
        return getUtilization() > 0.5d && this.state.get().canAddPages();
    }

    /**
     * Returns a lock-free snapshot of buffer statistics and per-client buffer info.
     */
    @Override
    public OutputBufferInfo getInfo() {
        // Read the state once so the reported flags come from a single observation of it.
        BufferState bufferState = this.state.get();
        return new OutputBufferInfo(
                "BROADCAST",
                bufferState,
                bufferState.canAddBuffers(),
                bufferState.canAddPages(),
                this.memoryManager.getBufferedBytes(),
                this.totalBufferedPages.get(),
                this.totalRowsAdded.get(),
                this.totalPagesAdded.get(),
                this.buffers.values().stream()
                        .map(ClientBuffer::getInfo)
                        .collect(ImmutableList.toImmutableList()));
    }

    /**
     * Applies a newer version of the output buffer descriptor: creates any newly declared client
     * buffers and, when the descriptor says no more buffer ids will follow, advances the state and
     * releases the pages retained for future buffers.
     *
     * <p>Updates are ignored when the buffer is in a terminal state or when the given version is not
     * newer than the current one.
     *
     * @throws IllegalStateException if called while holding the lock on this buffer
     * @throws NullPointerException if {@code newOutputBuffers} is null
     */
    @Override
    public void setOutputBuffers(OutputBuffers newOutputBuffers) {
        Preconditions.checkState(!Thread.holdsLock(this), "Cannot set output buffers while holding a lock on this");
        Objects.requireNonNull(newOutputBuffers, "newOutputBuffers is null");
        synchronized (this) {
            BufferState bufferState = this.state.get();
            if (bufferState.isTerminal() || this.outputBuffers.getVersion() >= newOutputBuffers.getVersion()) {
                return;
            }
            this.outputBuffers.checkValidTransition(newOutputBuffers);
            this.outputBuffers = newOutputBuffers;

            for (OutputBuffers.OutputBufferId bufferId : this.outputBuffers.getBuffers().keySet()) {
                if (!this.buffers.containsKey(bufferId)) {
                    ClientBuffer buffer = getBuffer(bufferId);
                    if (!bufferState.canAddPages()) {
                        buffer.setNoMorePages();
                    }
                }
            }

            if (this.outputBuffers.isNoMoreBufferIds()) {
                this.state.compareAndSet(BufferState.OPEN, BufferState.NO_MORE_BUFFERS);
                this.state.compareAndSet(BufferState.NO_MORE_PAGES, BufferState.FLUSHING);
            }
        }

        // noMoreBuffers() and checkFlushComplete() (through destroy()) both assert that the caller
        // does not hold this monitor, so they must run after the synchronized block is exited.
        if (!this.state.get().canAddBuffers()) {
            noMoreBuffers();
        }
        checkFlushComplete();
    }

    @Override
    public ListenableFuture<Void> isFull() {
        return this.memoryManager.getBufferBlockedFuture();
    }

    /**
     * Adds pages to every existing client buffer and, while buffers can still be added, retains
     * them for client buffers created later. Pages are silently ignored once the buffer no longer
     * accepts pages.
     *
     * @throws IllegalStateException if called while holding the lock on this buffer
     * @throws NullPointerException if {@code pages} is null
     */
    @Override
    public void enqueue(List<SerializedPage> pages) {
        Preconditions.checkState(!Thread.holdsLock(this), "Cannot enqueue pages while holding a lock on this");
        Objects.requireNonNull(pages, "pages is null");
        if (!this.state.get().canAddPages()) {
            return;
        }

        // Each reference is created holding one initial reference, dropped at the end of this method.
        ImmutableList.Builder<SerializedPageReference> references = ImmutableList.builderWithExpectedSize(pages.size());
        long bytesAdded = 0;
        long rowCount = 0;
        for (SerializedPage page : pages) {
            bytesAdded += page.getRetainedSizeInBytes();
            rowCount += page.getPositionCount();
            references.add(new SerializedPageReference(page, 1));
        }
        List<SerializedPageReference> serializedPageReferences = references.build();

        this.totalRowsAdded.addAndGet(rowCount);
        this.totalPagesAdded.addAndGet(serializedPageReferences.size());
        this.totalBufferedPages.addAndGet(serializedPageReferences.size());
        this.memoryManager.updateMemoryUsage(bytesAdded);

        Collection<ClientBuffer> buffersSnapshot;
        synchronized (this) {
            if (this.state.get().canAddBuffers()) {
                serializedPageReferences.forEach(SerializedPageReference::addReference);
                this.initialPagesForNewBuffers.addAll(serializedPageReferences);
            }
            // Snapshot under the same lock as the addAll above so a buffer created concurrently
            // gets the pages either from initialPagesForNewBuffers or from this loop, not both/neither.
            buffersSnapshot = safeGetBuffersSnapshot();
        }

        buffersSnapshot.forEach(buffer -> buffer.enqueuePages(serializedPageReferences));
        // Drop the initial reference taken when the references were created.
        SerializedPageReference.dereferencePages(serializedPageReferences, this.onPagesReleased);

        // Notify once, the first time the buffer blocks while more client buffers may still be added.
        if (!this.hasBlockedBefore.get()
                && this.state.get().canAddBuffers()
                && !isFull().isDone()
                && this.hasBlockedBefore.compareAndSet(false, true)) {
            this.notifyStatusChanged.run();
        }
    }

    /**
     * Partitioned enqueue; a broadcast buffer only has partition {@code 0}.
     *
     * @throws IllegalStateException if {@code partition} is not zero
     */
    @Override
    public void enqueue(int partition, List<SerializedPage> pages) {
        Preconditions.checkState(partition == 0, "Expected partition number to be zero");
        enqueue(pages);
    }

    /**
     * Returns pages for the given client buffer starting at {@code startingSequenceId}, creating
     * the client buffer if it does not exist yet.
     *
     * @throws IllegalStateException if called while holding the lock on this buffer
     * @throws IllegalArgumentException if {@code maxSize} is smaller than one byte
     */
    @Override
    public ListenableFuture<BufferResult> get(OutputBuffers.OutputBufferId outputBufferId, long startingSequenceId, DataSize maxSize) {
        Preconditions.checkState(!Thread.holdsLock(this), "Cannot get pages while holding a lock on this");
        Objects.requireNonNull(outputBufferId, "outputBufferId is null");
        Preconditions.checkArgument(maxSize.toBytes() > 0, "maxSize must be at least 1 byte");
        return getBuffer(outputBufferId).getPages(startingSequenceId, maxSize);
    }

    /** Acknowledges pages up to {@code sequenceId} on the given client buffer. */
    @Override
    public void acknowledge(OutputBuffers.OutputBufferId bufferId, long sequenceId) {
        Preconditions.checkState(!Thread.holdsLock(this), "Cannot acknowledge pages while holding a lock on this");
        Objects.requireNonNull(bufferId, "bufferId is null");
        getBuffer(bufferId).acknowledgePages(sequenceId);
    }

    /** Destroys the given client buffer and finishes this buffer if that completes the flush. */
    @Override
    public void abort(OutputBuffers.OutputBufferId bufferId) {
        Preconditions.checkState(!Thread.holdsLock(this), "Cannot abort while holding a lock on this");
        Objects.requireNonNull(bufferId, "bufferId is null");
        getBuffer(bufferId).destroy();
        checkFlushComplete();
    }

    /**
     * Marks that no more pages will be enqueued, stops blocking producers on a full buffer, and
     * propagates no-more-pages to every existing client buffer.
     */
    @Override
    public void setNoMorePages() {
        Preconditions.checkState(!Thread.holdsLock(this), "Cannot set no more pages while holding a lock on this");
        this.state.compareAndSet(BufferState.OPEN, BufferState.NO_MORE_PAGES);
        this.state.compareAndSet(BufferState.NO_MORE_BUFFERS, BufferState.FLUSHING);
        this.memoryManager.setNoBlockOnFull();
        safeGetBuffersSnapshot().forEach(ClientBuffer::setNoMorePages);
        checkFlushComplete();
    }

    /**
     * Moves to {@code FINISHED} (unless already terminal), releases retained pages, destroys all
     * client buffers and frees the buffer's memory. No-op if already in a terminal state.
     */
    @Override
    public void destroy() {
        Preconditions.checkState(!Thread.holdsLock(this), "Cannot destroy while holding a lock on this");
        if (this.state.setIf(BufferState.FINISHED, oldState -> !oldState.isTerminal())) {
            noMoreBuffers();
            safeGetBuffersSnapshot().forEach(ClientBuffer::destroy);
            this.memoryManager.setNoBlockOnFull();
            forceFreeMemory();
        }
    }

    /**
     * Moves to {@code FAILED} (unless already terminal) and frees the buffer's memory. Client
     * buffers are intentionally left untouched here (no destroy / no-more-pages).
     */
    @Override
    public void fail() {
        if (this.state.setIf(BufferState.FAILED, oldState -> !oldState.isTerminal())) {
            this.memoryManager.setNoBlockOnFull();
            forceFreeMemory();
        }
    }

    @Override
    public long getPeakMemoryUsage() {
        return this.memoryManager.getPeakMemoryUsage();
    }

    @VisibleForTesting
    void forceFreeMemory() {
        this.memoryManager.close();
    }

    /**
     * Returns the client buffer for {@code bufferId}, creating it if needed.
     *
     * <p>A buffer may be created before it is declared through {@link #setOutputBuffers}; once no
     * more buffer ids are expected, {@link #noMoreBuffers()} verifies that every created buffer was
     * declared. A new buffer is seeded with the retained pages and with the current no-more-pages /
     * finished state, unless this buffer has already failed.
     *
     * @throws IllegalStateException if buffers can no longer be added and no-more-buffer-ids is set
     */
    private synchronized ClientBuffer getBuffer(OutputBuffers.OutputBufferId bufferId) {
        ClientBuffer existing = this.buffers.get(bufferId);
        if (existing != null) {
            return existing;
        }

        BufferState bufferState = this.state.get();
        Preconditions.checkState(bufferState.canAddBuffers() || !this.outputBuffers.isNoMoreBufferIds(), "No more buffers already set");

        ClientBuffer created = new ClientBuffer(this.taskInstanceId, bufferId, this.onPagesReleased);
        if (bufferState != BufferState.FAILED) {
            created.enqueuePages(this.initialPagesForNewBuffers);
            if (!bufferState.canAddPages()) {
                created.setNoMorePages();
            }
            // A buffer requested after this buffer already finished is destroyed immediately.
            if (bufferState == BufferState.FINISHED) {
                created.destroy();
            }
        }
        this.buffers.put(bufferId, created);
        return created;
    }

    /** Returns an immutable copy of the current client buffers, taken under the lock. */
    private synchronized Collection<ClientBuffer> safeGetBuffersSnapshot() {
        return ImmutableList.copyOf(this.buffers.values());
    }

    /**
     * Drops the pages retained for future client buffers and, if no more buffer ids are expected,
     * verifies that every created client buffer was declared in the output buffer descriptor.
     *
     * @throws IllegalStateException if called while holding the lock on this buffer, or if an
     *         undeclared client buffer exists after no-more-buffer-ids was set
     */
    private void noMoreBuffers() {
        Preconditions.checkState(!Thread.holdsLock(this), "Cannot set no more buffers while holding a lock on this");
        List<SerializedPageReference> retainedPages;
        synchronized (this) {
            retainedPages = ImmutableList.copyOf(this.initialPagesForNewBuffers);
            this.initialPagesForNewBuffers.clear();
            if (this.outputBuffers.isNoMoreBufferIds()) {
                Sets.SetView<OutputBuffers.OutputBufferId> undeclaredBuffers =
                        Sets.difference(this.buffers.keySet(), this.outputBuffers.getBuffers().keySet());
                Preconditions.checkState(undeclaredBuffers.isEmpty(), "Final output buffers does not contain all created buffer ids: %s", undeclaredBuffers);
            }
        }
        // Dereference outside the monitor so the release callback (memory accounting) runs unlocked.
        SerializedPageReference.dereferencePages(retainedPages, this.onPagesReleased);
    }

    /**
     * Finishes this buffer once no more client buffers can be added (state {@code FLUSHING} or
     * {@code NO_MORE_BUFFERS}) and every existing client buffer has been destroyed.
     * Must be called without holding the lock on this buffer (it may call {@link #destroy()}).
     */
    private void checkFlushComplete() {
        BufferState current = this.state.get();
        if (current != BufferState.FLUSHING && current != BufferState.NO_MORE_BUFFERS) {
            return;
        }
        if (safeGetBuffersSnapshot().stream().allMatch(ClientBuffer::isDestroyed)) {
            destroy();
        }
    }

    @VisibleForTesting
    OutputBufferMemoryManager getMemoryManager() {
        return this.memoryManager;
    }
}
