package io.trino.execution.scheduler.policy;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import com.google.common.graph.EndpointPair;
import com.google.common.util.concurrent.ListenableFuture;
import io.trino.execution.DynamicFilterConfig;
import io.trino.execution.ExecutionFailureInfo;
import io.trino.execution.RemoteTask;
import io.trino.execution.StageId;
import io.trino.execution.StateMachine;
import io.trino.execution.TaskId;
import io.trino.execution.TaskStatus;
import io.trino.execution.scheduler.StageExecution;
import io.trino.execution.scheduler.TaskLifecycleListener;
import io.trino.metadata.FunctionManager;
import io.trino.metadata.InternalNode;
import io.trino.metadata.MetadataManager;
import io.trino.metadata.Split;
import io.trino.server.DynamicFilterService;
import io.trino.spi.QueryId;
import io.trino.spi.type.TypeOperators;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNodeId;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import org.assertj.core.api.Assertions;
import org.testng.annotations.Test;

/* loaded from: input_file:io/trino/execution/scheduler/policy/TestPhasedExecutionSchedule.class */
public class TestPhasedExecutionSchedule {
    // Passed as the second argument to PhasedExecutionSchedule.forStages(...) by every test in this class.
    // Built from the test metadata manager, the testing function manager, a fresh TypeOperators
    // and a no-arg DynamicFilterConfig.
    private final DynamicFilterService dynamicFilterService = new DynamicFilterService(MetadataManager.createTestMetadataManager(), FunctionManager.createTestingFunctionManager(), new TypeOperators(), new DynamicFilterConfig());

    /* loaded from: input_file:io/trino/execution/scheduler/policy/TestPhasedExecutionSchedule$TestingStageExecution.class */
    /**
     * Minimal {@link StageExecution} test double.
     *
     * <p>Only three pieces of state are backed by real values: the plan fragment, the current
     * state (with at most one change listener) and the "any task blocked" flag. The stage id is
     * a fixed constant and every remaining operation throws
     * {@link UnsupportedOperationException}.
     */
    // NOTE(review): most methods below look like StageExecution implementations; consider adding
    // @Override once confirmed against the interface (it is not visible from this file).
    private static class TestingStageExecution implements StageExecution {
        private final PlanFragment fragment;
        // Holds only the most recently registered listener; a second registration replaces the first.
        private StateMachine.StateChangeListener<StageExecution.State> stateChangeListener;
        private boolean anyTaskBlocked;
        private StageExecution.State state = StageExecution.State.SCHEDULING;

        /**
         * @param planFragment fragment this stage reports from {@link #getFragment()}
         * @throws NullPointerException if {@code planFragment} is null
         */
        public TestingStageExecution(PlanFragment planFragment) {
            // requireNonNull is generic, so no cast is needed on its result.
            this.fragment = Objects.requireNonNull(planFragment, "fragment is null");
        }

        /** Returns the fragment supplied at construction. */
        public PlanFragment getFragment() {
            return this.fragment;
        }

        /** Returns the flag last set through {@link #setAnyTaskBlocked(boolean)} (initially false). */
        public boolean isAnyTaskBlocked() {
            return this.anyTaskBlocked;
        }

        /** Test hook: sets the value returned by {@link #isAnyTaskBlocked()}. */
        public void setAnyTaskBlocked(boolean z) {
            this.anyTaskBlocked = z;
        }

        /**
         * Test hook: stores the new state and, if a listener has been registered, notifies it
         * synchronously with that state.
         */
        public void setState(StageExecution.State state) {
            this.state = state;
            if (this.stateChangeListener != null) {
                this.stateChangeListener.stateChanged(state);
            }
        }

        /** Returns the current state; {@code SCHEDULING} until {@link #setState} is called. */
        public StageExecution.State getState() {
            return this.state;
        }

        /**
         * Registers the listener notified by {@link #setState}. The listener is not invoked at
         * registration time, and any previously registered listener is discarded.
         *
         * @throws NullPointerException if {@code stateChangeListener} is null
         */
        public void addStateChangeListener(StateMachine.StateChangeListener<StageExecution.State> stateChangeListener) {
            // No raw-type cast: requireNonNull already returns the parameterized listener type.
            this.stateChangeListener = Objects.requireNonNull(stateChangeListener, "stateChangeListener is null");
        }

        /** Returns a freshly built, constant stage id (query "id", stage 0) for every instance. */
        public StageId getStageId() {
            return new StageId(new QueryId("id"), 0);
        }

        // ---- Everything below is intentionally unsupported by this test double. ----

        public int getAttemptId() {
            throw new UnsupportedOperationException();
        }

        public void beginScheduling() {
            throw new UnsupportedOperationException();
        }

        public void transitionToSchedulingSplits() {
            throw new UnsupportedOperationException();
        }

        public TaskLifecycleListener getTaskLifecycleListener() {
            throw new UnsupportedOperationException();
        }

        public void schedulingComplete() {
            throw new UnsupportedOperationException();
        }

        public void schedulingComplete(PlanNodeId planNodeId) {
            throw new UnsupportedOperationException();
        }

        public void cancel() {
            throw new UnsupportedOperationException();
        }

        public void abort() {
            throw new UnsupportedOperationException();
        }

        public void recordGetSplitTime(long j) {
            throw new UnsupportedOperationException();
        }

        public Optional<RemoteTask> scheduleTask(InternalNode internalNode, int i, Multimap<PlanNodeId, Split> multimap) {
            throw new UnsupportedOperationException();
        }

        public void failTask(TaskId taskId, Throwable th) {
            throw new UnsupportedOperationException();
        }

        public List<RemoteTask> getAllTasks() {
            throw new UnsupportedOperationException();
        }

        public List<TaskStatus> getTaskStatuses() {
            throw new UnsupportedOperationException();
        }

        public Optional<ExecutionFailureInfo> getFailureCause() {
            throw new UnsupportedOperationException();
        }
    }

    /**
     * Partitioned inner join over a "build" and a "probe" table scan.
     *
     * <p>Asserts the sorted fragment order, the single build-to-probe dependency edge, and the
     * scheduling sequence: build and join first, probe only after the build stage reaches
     * {@code FLUSHING}, and an empty, finished schedule once probe and join are {@code FINISHED}.
     */
    @Test
    public void testPartitionedJoin() {
        PlanFragment buildFragment = PlanUtils.createTableScanPlanFragment("build");
        PlanFragment probeFragment = PlanUtils.createTableScanPlanFragment("probe");
        PlanFragment joinFragment = PlanUtils.createJoinPlanFragment(JoinNode.Type.INNER, JoinNode.DistributionType.PARTITIONED, "join", buildFragment, probeFragment);

        TestingStageExecution buildStage = new TestingStageExecution(buildFragment);
        TestingStageExecution probeStage = new TestingStageExecution(probeFragment);
        TestingStageExecution joinStage = new TestingStageExecution(joinFragment);

        PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of(buildStage, probeStage, joinStage), this.dynamicFilterService);
        Assertions.assertThat(schedule.getSortedFragments()).containsExactly(buildFragment.getId(), probeFragment.getId(), joinFragment.getId());

        // The only dependency edge is build -> probe.
        Assertions.assertThat(schedule.getFragmentDependency().edges()).containsExactlyInAnyOrder(EndpointPair.ordered(buildFragment.getId(), probeFragment.getId()));

        // Build and join are the stages reported as scheduling right after construction.
        Assertions.assertThat(getSchedulingFragments(schedule)).containsExactly(buildFragment.getId(), joinFragment.getId());

        // The reschedule future completes when the build stage moves to FLUSHING;
        // after the next schedule() call the probe stage is scheduling as well.
        ListenableFuture<?> buildReschedule = schedule.getRescheduleFuture().orElseThrow();
        Assertions.assertThat(buildReschedule).isNotDone();
        buildStage.setState(StageExecution.State.FLUSHING);
        Assertions.assertThat(buildReschedule).isDone();
        schedule.schedule();
        Assertions.assertThat(getSchedulingFragments(schedule)).containsExactly(joinFragment.getId(), probeFragment.getId());

        // Finishing the probe stage alone leaves the new reschedule future pending;
        // once the join stage finishes too, the schedule drains and reports finished.
        ListenableFuture<?> finalReschedule = schedule.getRescheduleFuture().orElseThrow();
        Assertions.assertThat(finalReschedule).isNotDone();
        probeStage.setState(StageExecution.State.FINISHED);
        Assertions.assertThat(finalReschedule).isNotDone();
        joinStage.setState(StageExecution.State.FINISHED);
        schedule.schedule();
        Assertions.assertThat(getSchedulingFragments(schedule)).isEmpty();
        Assertions.assertThat(schedule.isFinished()).isTrue();
    }

    /**
     * Broadcast join whose fragment ("probe") consumes a "build" table scan.
     *
     * <p>Asserts that only the build stage is scheduling at first, and that the join fragment
     * joins it after the build stage reports a blocked task and {@code schedule()} runs again.
     */
    @Test
    public void testBroadcastSourceJoin() {
        PlanFragment buildFragment = PlanUtils.createTableScanPlanFragment("build");
        PlanFragment broadcastJoinFragment = PlanUtils.createBroadcastJoinPlanFragment("probe", buildFragment);

        TestingStageExecution buildStage = new TestingStageExecution(buildFragment);
        TestingStageExecution broadcastJoinStage = new TestingStageExecution(broadcastJoinFragment);

        // The join stage is deliberately listed first in the input set.
        PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of(broadcastJoinStage, buildStage), this.dynamicFilterService);
        Assertions.assertThat(schedule.getSortedFragments()).containsExactly(buildFragment.getId(), broadcastJoinFragment.getId());

        // Single dependency edge: build -> join.
        Assertions.assertThat(schedule.getFragmentDependency().edges()).containsExactlyInAnyOrder(EndpointPair.ordered(buildFragment.getId(), broadcastJoinFragment.getId()));

        // Initially only the build stage is scheduling.
        Assertions.assertThat(getSchedulingFragments(schedule)).containsExactly(buildFragment.getId());

        // With a build task blocked, the next schedule() call adds the join fragment.
        buildStage.setAnyTaskBlocked(true);
        schedule.schedule();
        Assertions.assertThat(getSchedulingFragments(schedule)).containsExactly(buildFragment.getId(), broadcastJoinFragment.getId());
    }

    /**
     * Replicated inner join whose probe side is an aggregation over a table scan.
     *
     * <p>Asserts the sorted fragment order, the single build-to-join dependency edge, and that
     * the build, source and aggregation fragments are the ones initially scheduling.
     */
    @Test
    public void testAggregation() {
        PlanFragment sourceFragment = PlanUtils.createTableScanPlanFragment("probe");
        PlanFragment aggregationFragment = PlanUtils.createAggregationFragment("aggregation", sourceFragment);
        PlanFragment buildFragment = PlanUtils.createTableScanPlanFragment("build");
        PlanFragment joinFragment = PlanUtils.createJoinPlanFragment(JoinNode.Type.INNER, JoinNode.DistributionType.REPLICATED, "join", buildFragment, aggregationFragment);

        TestingStageExecution sourceStage = new TestingStageExecution(sourceFragment);
        TestingStageExecution aggregationStage = new TestingStageExecution(aggregationFragment);
        TestingStageExecution buildStage = new TestingStageExecution(buildFragment);
        TestingStageExecution joinStage = new TestingStageExecution(joinFragment);

        PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of(sourceStage, aggregationStage, buildStage, joinStage), this.dynamicFilterService);
        Assertions.assertThat(schedule.getSortedFragments()).containsExactly(buildFragment.getId(), sourceFragment.getId(), aggregationFragment.getId(), joinFragment.getId());

        // Single dependency edge: build -> join.
        Assertions.assertThat(schedule.getFragmentDependency().edges()).containsExactly(EndpointPair.ordered(buildFragment.getId(), joinFragment.getId()));

        // Everything except the join fragment is scheduling from the start.
        Assertions.assertThat(getSchedulingFragments(schedule)).containsExactly(buildFragment.getId(), sourceFragment.getId(), aggregationFragment.getId());
    }

    /**
     * Same plan shape as the aggregation test, but the join stage is set to {@code ABORTED}
     * while it is not yet among the scheduling fragments.
     *
     * <p>Asserts that after the remaining stages are set to {@code FINISHED} and
     * {@code schedule()} is called, the schedule reports finished.
     */
    @Test
    public void testDependentStageAbortedBeforeStarted() {
        PlanFragment sourceFragment = PlanUtils.createTableScanPlanFragment("probe");
        PlanFragment aggregationFragment = PlanUtils.createAggregationFragment("aggregation", sourceFragment);
        PlanFragment buildFragment = PlanUtils.createTableScanPlanFragment("build");
        PlanFragment joinFragment = PlanUtils.createJoinPlanFragment(JoinNode.Type.INNER, JoinNode.DistributionType.REPLICATED, "join", buildFragment, aggregationFragment);

        TestingStageExecution sourceStage = new TestingStageExecution(sourceFragment);
        TestingStageExecution aggregationStage = new TestingStageExecution(aggregationFragment);
        TestingStageExecution buildStage = new TestingStageExecution(buildFragment);
        TestingStageExecution joinStage = new TestingStageExecution(joinFragment);

        PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of(sourceStage, aggregationStage, buildStage, joinStage), this.dynamicFilterService);
        Assertions.assertThat(schedule.getSortedFragments()).containsExactly(buildFragment.getId(), sourceFragment.getId(), aggregationFragment.getId(), joinFragment.getId());

        // Single dependency edge: build -> join.
        Assertions.assertThat(schedule.getFragmentDependency().edges()).containsExactly(EndpointPair.ordered(buildFragment.getId(), joinFragment.getId()));

        // The join fragment is not scheduling yet.
        Assertions.assertThat(getSchedulingFragments(schedule)).containsExactly(buildFragment.getId(), sourceFragment.getId(), aggregationFragment.getId());

        // Abort the join stage first, then finish the others in this exact order.
        joinStage.setState(StageExecution.State.ABORTED);
        buildStage.setState(StageExecution.State.FINISHED);
        aggregationStage.setState(StageExecution.State.FINISHED);
        sourceStage.setState(StageExecution.State.FINISHED);

        schedule.schedule();
        Assertions.assertThat(schedule.isFinished()).isTrue();
    }

    /**
     * One join fragment fed by a broadcast build, a partitioned build and a probe scan.
     *
     * <p>Asserts the three dependency edges and the scheduling sequence: both builds plus the
     * join first; the broadcast build drops out after it reaches {@code FLUSHING}; the probe
     * replaces the partitioned build after that one reaches {@code FLUSHING} too.
     */
    @Test
    public void testStageWithBroadcastAndPartitionedJoin() {
        PlanFragment broadcastBuildFragment = PlanUtils.createTableScanPlanFragment("broadcast_build");
        PlanFragment partitionedBuildFragment = PlanUtils.createTableScanPlanFragment("partitioned_build");
        PlanFragment probeFragment = PlanUtils.createTableScanPlanFragment("probe");
        PlanFragment joinFragment = PlanUtils.createBroadcastAndPartitionedJoinPlanFragment("join", broadcastBuildFragment, partitionedBuildFragment, probeFragment);

        TestingStageExecution broadcastBuildStage = new TestingStageExecution(broadcastBuildFragment);
        TestingStageExecution partitionedBuildStage = new TestingStageExecution(partitionedBuildFragment);
        TestingStageExecution probeStage = new TestingStageExecution(probeFragment);
        TestingStageExecution joinStage = new TestingStageExecution(joinFragment);

        PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of(broadcastBuildStage, partitionedBuildStage, probeStage, joinStage), this.dynamicFilterService);

        // Expected edges: broadcast build -> probe, partitioned build -> probe, broadcast build -> join.
        Assertions.assertThat(schedule.getFragmentDependency().edges()).containsExactlyInAnyOrder(
                EndpointPair.ordered(broadcastBuildFragment.getId(), probeFragment.getId()),
                EndpointPair.ordered(partitionedBuildFragment.getId(), probeFragment.getId()),
                EndpointPair.ordered(broadcastBuildFragment.getId(), joinFragment.getId()));

        // Both builds and the join are scheduling; the probe is not.
        Assertions.assertThat(getSchedulingFragments(schedule)).containsExactly(partitionedBuildFragment.getId(), broadcastBuildFragment.getId(), joinFragment.getId());

        // Broadcast build FLUSHING: it leaves the scheduling set, the probe is still absent.
        broadcastBuildStage.setState(StageExecution.State.FLUSHING);
        schedule.schedule();
        Assertions.assertThat(getSchedulingFragments(schedule)).containsExactly(partitionedBuildFragment.getId(), joinFragment.getId());

        // Partitioned build FLUSHING: the probe now appears alongside the join.
        partitionedBuildStage.setState(StageExecution.State.FLUSHING);
        schedule.schedule();
        Assertions.assertThat(getSchedulingFragments(schedule)).containsExactly(joinFragment.getId(), probeFragment.getId());
    }

    /**
     * Broadcast join ("probe") whose build side is itself an inner join ("nested_join") of two
     * table scans.
     *
     * <p>Asserts the four dependency edges and the scheduling sequence as the nested build
     * finishes, the nested join reports a blocked task, and the nested probe finishes.
     */
    @Test
    public void testSourceStageBroadcastJoinWithPartitionedJoinBuildSide() {
        PlanFragment nestedBuildFragment = PlanUtils.createTableScanPlanFragment("nested_join_build");
        PlanFragment nestedProbeFragment = PlanUtils.createTableScanPlanFragment("nested_join_probe");
        PlanFragment nestedJoinFragment = PlanUtils.createJoinPlanFragment(JoinNode.Type.INNER, "nested_join", nestedBuildFragment, nestedProbeFragment);
        PlanFragment broadcastJoinFragment = PlanUtils.createBroadcastJoinPlanFragment("probe", nestedJoinFragment);

        TestingStageExecution nestedBuildStage = new TestingStageExecution(nestedBuildFragment);
        TestingStageExecution nestedProbeStage = new TestingStageExecution(nestedProbeFragment);
        TestingStageExecution nestedJoinStage = new TestingStageExecution(nestedJoinFragment);
        TestingStageExecution broadcastJoinStage = new TestingStageExecution(broadcastJoinFragment);

        PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of(nestedBuildStage, nestedProbeStage, nestedJoinStage, broadcastJoinStage), this.dynamicFilterService);

        // Every nested fragment has an edge to the broadcast join, plus nested build -> nested probe.
        Assertions.assertThat(schedule.getFragmentDependency().edges()).containsExactlyInAnyOrder(
                EndpointPair.ordered(nestedBuildFragment.getId(), broadcastJoinFragment.getId()),
                EndpointPair.ordered(nestedJoinFragment.getId(), broadcastJoinFragment.getId()),
                EndpointPair.ordered(nestedBuildFragment.getId(), nestedProbeFragment.getId()),
                EndpointPair.ordered(nestedProbeFragment.getId(), broadcastJoinFragment.getId()));

        // Nested build and nested join are scheduling at the start.
        Assertions.assertThat(getSchedulingFragments(schedule)).containsExactly(nestedBuildFragment.getId(), nestedJoinFragment.getId());

        // Nested join SCHEDULED and nested build FINISHED: only the nested probe is scheduling.
        nestedJoinStage.setState(StageExecution.State.SCHEDULED);
        nestedBuildStage.setState(StageExecution.State.FINISHED);
        schedule.schedule();
        Assertions.assertThat(getSchedulingFragments(schedule)).containsExactly(nestedProbeFragment.getId());

        // A blocked nested-join task brings in the broadcast join fragment.
        nestedJoinStage.setAnyTaskBlocked(true);
        schedule.schedule();
        Assertions.assertThat(getSchedulingFragments(schedule)).containsExactly(nestedProbeFragment.getId(), broadcastJoinFragment.getId());

        // Nested probe FINISHED: only the broadcast join remains scheduling.
        nestedProbeStage.setState(StageExecution.State.FINISHED);
        schedule.schedule();
        Assertions.assertThat(getSchedulingFragments(schedule)).containsExactly(broadcastJoinFragment.getId());
    }

    /**
     * Maps the stages returned by {@code getSchedulingStages()} to their fragment ids.
     *
     * <p>The result is an {@link ImmutableSet}, which iterates in the order elements were
     * collected; the {@code containsExactly} assertions in the tests compare against that order.
     *
     * @param phasedExecutionSchedule schedule whose currently scheduling stages are inspected
     * @return ids of the fragments of those stages
     */
    private Set<PlanFragmentId> getSchedulingFragments(PhasedExecutionSchedule phasedExecutionSchedule) {
        // No raw (Set) cast: the collector already yields ImmutableSet<PlanFragmentId>.
        return phasedExecutionSchedule.getSchedulingStages().stream()
                .map(stage -> stage.getFragment().getId())
                .collect(ImmutableSet.toImmutableSet());
    }
}
