package io.trino.execution.scheduler.policy;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Ordering;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.trino.execution.scheduler.StageExecution;
import io.trino.server.DynamicFilterService;
import io.trino.spi.QueryId;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.IndexJoinNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanVisitor;
import io.trino.sql.planner.plan.RemoteSourceNode;
import io.trino.sql.planner.plan.SemiJoinNode;
import io.trino.sql.planner.plan.SpatialJoinNode;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Stream;
import org.jgrapht.DirectedGraph;
import org.jgrapht.EdgeFactory;
import org.jgrapht.alg.StrongConnectivityInspector;
import org.jgrapht.graph.DefaultDirectedGraph;
import oshi.annotation.concurrent.GuardedBy;

/* loaded from: input_file:io/trino/execution/scheduler/policy/PhasedExecutionSchedule.class */
public class PhasedExecutionSchedule implements ExecutionSchedule {
    /** Stage executions keyed by the id of the plan fragment each one runs. */
    private final Map<PlanFragmentId, StageExecution> stagesByFragmentId;
    private final DynamicFilterService dynamicFilterService;
    /** Ordering built from {@link #sortedFragments}; assigned in {@code init}, hence not final. */
    private Ordering<PlanFragmentId> fragmentOrdering;
    /** Fragment ids in the order the visitor finished processing them. */
    private final List<PlanFragmentId> sortedFragments = new ArrayList<>();
    /** Stages that were selected for execution and not yet removed as completed (insertion-ordered). */
    private final Set<StageExecution> activeStages = new LinkedHashSet<>();

    @GuardedBy("this")
    private SettableFuture<Void> rescheduleFuture = SettableFuture.create();
    /** Edge {@code a -> b}: {@code b} is only selected once no unfinished fragment such as {@code a} points at it. */
    private final DirectedGraph<PlanFragmentId, FragmentsEdge> fragmentDependency = new DefaultDirectedGraph<>(new FragmentsEdgeFactory());
    /** Edge {@code a -> b}: fragment {@code b} reads fragment {@code a} through a remote source. */
    private final DirectedGraph<PlanFragmentId, FragmentsEdge> fragmentTopology = new DefaultDirectedGraph<>(new FragmentsEdgeFactory());

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/trino/execution/scheduler/policy/PhasedExecutionSchedule$FragmentSubGraph.class */
    /**
     * Immutable result of visiting part of a plan: the fragments found upstream,
     * the subset of those marked lazy, and whether the fragment currently being
     * visited is itself marked lazy.
     */
    public static class FragmentSubGraph {
        private final Set<PlanFragmentId> upstreamFragments;
        private final Set<PlanFragmentId> lazyUpstreamFragments;
        private final boolean currentFragmentLazy;

        /**
         * @param upstreamFragments fragments upstream of the visited node; must not be null
         * @param lazyUpstreamFragments upstream fragments marked lazy; must not be null
         * @param currentFragmentLazy whether the fragment being visited is marked lazy
         * @throws NullPointerException if either set is null
         */
        public FragmentSubGraph(
                Set<PlanFragmentId> upstreamFragments,
                Set<PlanFragmentId> lazyUpstreamFragments,
                boolean currentFragmentLazy) {
            Objects.requireNonNull(upstreamFragments, "upstreamFragments is null");
            Objects.requireNonNull(lazyUpstreamFragments, "lazyUpstreamFragments is null");
            this.upstreamFragments = upstreamFragments;
            this.lazyUpstreamFragments = lazyUpstreamFragments;
            this.currentFragmentLazy = currentFragmentLazy;
        }

        /** Returns the upstream fragments exactly as passed to the constructor (no copy is made). */
        public Set<PlanFragmentId> getUpstreamFragments() {
            return upstreamFragments;
        }

        /** Returns the lazy upstream fragments exactly as passed to the constructor (no copy is made). */
        public Set<PlanFragmentId> getLazyUpstreamFragments() {
            return lazyUpstreamFragments;
        }

        /** Returns the lazy flag recorded for the fragment being visited. */
        public boolean isCurrentFragmentLazy() {
            return currentFragmentLazy;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    @VisibleForTesting
    /* loaded from: input_file:io/trino/execution/scheduler/policy/PhasedExecutionSchedule$FragmentsEdge.class */
    public static class FragmentsEdge {
        private final PlanFragmentId source;
        private final PlanFragmentId target;

        public FragmentsEdge(PlanFragmentId planFragmentId, PlanFragmentId planFragmentId2) {
            this.source = (PlanFragmentId) Objects.requireNonNull(planFragmentId, "source is null");
            this.target = (PlanFragmentId) Objects.requireNonNull(planFragmentId2, "target is null");
        }

        public PlanFragmentId getSource() {
            return this.source;
        }

        public PlanFragmentId getTarget() {
            return this.target;
        }

        public String toString() {
            return MoreObjects.toStringHelper(this).add("source", this.source).add("target", this.target).toString();
        }

        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            FragmentsEdge fragmentsEdge = (FragmentsEdge) obj;
            return this.source.equals(fragmentsEdge.source) && this.target.equals(fragmentsEdge.target);
        }

        public int hashCode() {
            return Objects.hash(this.source, this.target);
        }
    }

    /* loaded from: input_file:io/trino/execution/scheduler/policy/PhasedExecutionSchedule$FragmentsEdgeFactory.class */
    private static class FragmentsEdgeFactory implements EdgeFactory<PlanFragmentId, FragmentsEdge> {
        private FragmentsEdgeFactory() {
        }

        public FragmentsEdge createEdge(PlanFragmentId planFragmentId, PlanFragmentId planFragmentId2) {
            return new FragmentsEdge(planFragmentId, planFragmentId2);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/trino/execution/scheduler/policy/PhasedExecutionSchedule$Visitor.class */
    public class Visitor extends PlanVisitor<FragmentSubGraph, PlanFragmentId> {
        private final QueryId queryId;
        /** All fragments handed to the visitor, keyed by fragment id. */
        private final Map<PlanFragmentId, PlanFragment> fragments;
        /** Collects ids of fragments whose visit ended with the lazy flag unset. */
        private final ImmutableSet.Builder<PlanFragmentId> nonLazyFragments = ImmutableSet.builder();
        /** Memo of already processed fragments, so each fragment is visited once. */
        private final Map<PlanFragmentId, FragmentSubGraph> fragmentSubGraphs = new HashMap<>();

        /**
         * Indexes the given fragments by id.
         *
         * @param queryId query the fragments belong to; stored as-is and passed to the dynamic filter service
         * @param collection fragments to analyse; must not be null
         * @throws NullPointerException if {@code collection} is null
         * @throws IllegalArgumentException if two fragments share an id (raised by the immutable map collector)
         */
        public Visitor(QueryId queryId, Collection<PlanFragment> collection) {
            this.queryId = queryId;
            // No raw cast here: the stream must stay a Stream<PlanFragment> for getId() to resolve.
            this.fragments = Objects.requireNonNull(collection, "fragments is null").stream()
                    .collect(ImmutableMap.toImmutableMap(PlanFragment::getId, Function.identity()));
        }

        /** Returns an immutable snapshot of the fragment ids recorded as non-lazy so far. */
        public Set<PlanFragmentId> getNonLazyFragments() {
            ImmutableSet<PlanFragmentId> snapshot = nonLazyFragments.build();
            return snapshot;
        }

        /**
         * Registers every fragment as a vertex in both graphs, then processes each
         * fragment that no other fragment reads through a remote source. Fragments
         * that are read by others are reached recursively from those starting points.
         */
        public void processAllFragments() {
            fragments.forEach((fragmentId, fragment) -> {
                PhasedExecutionSchedule.this.fragmentDependency.addVertex(fragmentId);
                PhasedExecutionSchedule.this.fragmentTopology.addVertex(fragmentId);
            });

            // Ids of all fragments that appear as the source of some remote source node.
            Set<PlanFragmentId> consumedFragments = fragments.values().stream()
                    .map(PlanFragment::getRemoteSourceNodes)
                    .flatMap(Collection::stream)
                    .map(RemoteSourceNode::getSourceFragmentIds)
                    .flatMap(Collection::stream)
                    .collect(ImmutableSet.toImmutableSet());

            fragments.keySet().stream()
                    .filter(fragmentId -> !consumedFragments.contains(fragmentId))
                    .forEach(this::processFragment);
        }

        /**
         * Returns the sub graph for the fragment with the given id, computing and
         * caching it on first use. A newly processed fragment is appended to
         * {@code sortedFragments} after its visit has finished.
         *
         * @param planFragmentId id of the fragment to process
         * @return the cached or freshly computed sub graph
         */
        public FragmentSubGraph processFragment(PlanFragmentId planFragmentId) {
            boolean alreadyProcessed = fragmentSubGraphs.containsKey(planFragmentId);
            if (alreadyProcessed) {
                return fragmentSubGraphs.get(planFragmentId);
            }

            // The visit below can re-enter this method for other fragments, so the
            // memo is updated only once the result is available.
            PlanFragment fragment = fragments.get(planFragmentId);
            FragmentSubGraph subGraph = processFragment(fragment);
            FragmentSubGraph previous = fragmentSubGraphs.put(planFragmentId, subGraph);
            Verify.verify(previous == null, "fragment %s was already processed", planFragmentId);
            PhasedExecutionSchedule.this.sortedFragments.add(planFragmentId);
            return subGraph;
        }

        /**
         * Visits the fragment's plan and folds the fragment itself into the result.
         * The fragment id is added to the upstream set. It is also added to the lazy
         * set when the visit reported the fragment as lazy; otherwise it is recorded
         * in {@code nonLazyFragments}.
         *
         * @param planFragment fragment to visit
         * @return sub graph including this fragment; its lazy flag is always {@code false}
         */
        private FragmentSubGraph processFragment(PlanFragment planFragment) {
            PlanFragmentId fragmentId = planFragment.getId();
            FragmentSubGraph subGraph = planFragment.getRoot().accept(this, fragmentId);

            Set<PlanFragmentId> upstreamFragments = ImmutableSet.<PlanFragmentId>builder()
                    .addAll(subGraph.getUpstreamFragments())
                    .add(fragmentId)
                    .build();

            // Declared as Set, not ImmutableSet: the else branch assigns a plain Set.
            Set<PlanFragmentId> lazyUpstreamFragments;
            if (subGraph.isCurrentFragmentLazy()) {
                lazyUpstreamFragments = ImmutableSet.<PlanFragmentId>builder()
                        .addAll(subGraph.getLazyUpstreamFragments())
                        .add(fragmentId)
                        .build();
            } else {
                lazyUpstreamFragments = subGraph.getLazyUpstreamFragments();
                nonLazyFragments.add(fragmentId);
            }
            return new FragmentSubGraph(upstreamFragments, lazyUpstreamFragments, false);
        }

        /**
         * Handles a join as (left, right) with the replicated flag taken from its distribution type.
         *
         * @throws java.util.NoSuchElementException if the join has no distribution type
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public FragmentSubGraph visitJoin(JoinNode joinNode, PlanFragmentId planFragmentId) {
            JoinNode.DistributionType distributionType = joinNode.getDistributionType().orElseThrow();
            boolean replicated = distributionType == JoinNode.DistributionType.REPLICATED;
            return processJoin(replicated, joinNode.getLeft(), joinNode.getRight(), planFragmentId);
        }

        /** Handles a spatial join as (left, right) with the replicated flag taken from its distribution type. */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public FragmentSubGraph visitSpatialJoin(SpatialJoinNode spatialJoinNode, PlanFragmentId planFragmentId) {
            boolean replicated = spatialJoinNode.getDistributionType() == SpatialJoinNode.DistributionType.REPLICATED;
            PlanNode left = spatialJoinNode.getLeft();
            PlanNode right = spatialJoinNode.getRight();
            return processJoin(replicated, left, right, planFragmentId);
        }

        /**
         * Handles a semi join as (source, filtering source) with the replicated flag
         * taken from its distribution type.
         *
         * @throws java.util.NoSuchElementException if the semi join has no distribution type
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public FragmentSubGraph visitSemiJoin(SemiJoinNode semiJoinNode, PlanFragmentId planFragmentId) {
            SemiJoinNode.DistributionType distributionType = semiJoinNode.getDistributionType().orElseThrow();
            boolean replicated = distributionType == SemiJoinNode.DistributionType.REPLICATED;
            return processJoin(replicated, semiJoinNode.getSource(), semiJoinNode.getFilteringSource(), planFragmentId);
        }

        /** Handles an index join as (probe source, index source); it is always treated as replicated. */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public FragmentSubGraph visitIndexJoin(IndexJoinNode indexJoinNode, PlanFragmentId planFragmentId) {
            final boolean replicated = true;
            PlanNode probe = indexJoinNode.getProbeSource();
            PlanNode index = indexJoinNode.getIndexSource();
            return processJoin(replicated, probe, index, planFragmentId);
        }

        /**
         * Shared handling for all join flavours. The second node is visited before
         * the first, and a dependency edge is added from every upstream fragment of
         * the second node to every lazy upstream fragment of the first.
         *
         * <p>The current fragment stays lazy only when the join is replicated, both
         * sides reported it lazy, and the dynamic filter service does not require the
         * stage to be scheduled. In that case the current fragment also gets a
         * dependency on the second node's upstream fragments.
         *
         * @param z whether the join is replicated
         * @param planNode first side (the probe source, for index joins)
         * @param planNode2 second side (the index source, for index joins)
         * @param planFragmentId id of the fragment containing the join
         * @return union of both sides' upstream fragments, the first side's lazy fragments, and the lazy flag
         */
        private FragmentSubGraph processJoin(boolean z, PlanNode planNode, PlanNode planNode2, PlanFragmentId planFragmentId) {
            FragmentSubGraph buildSubGraph = planNode2.accept(this, planFragmentId);
            FragmentSubGraph probeSubGraph = planNode.accept(this, planFragmentId);

            addDependencyEdges(buildSubGraph.getUpstreamFragments(), probeSubGraph.getLazyUpstreamFragments());

            boolean currentFragmentLazy = probeSubGraph.isCurrentFragmentLazy() && buildSubGraph.isCurrentFragmentLazy();
            if (z
                    && currentFragmentLazy
                    && !PhasedExecutionSchedule.this.dynamicFilterService.isStageSchedulingNeededToCollectDynamicFilters(this.queryId, this.fragments.get(planFragmentId))) {
                addDependencyEdges(buildSubGraph.getUpstreamFragments(), ImmutableSet.of(planFragmentId));
            } else {
                currentFragmentLazy = false;
            }

            // The explicit type witness keeps this an ImmutableSet<PlanFragmentId>, not ImmutableSet<Object>.
            Set<PlanFragmentId> upstreamFragments = ImmutableSet.<PlanFragmentId>builder()
                    .addAll(probeSubGraph.getUpstreamFragments())
                    .addAll(buildSubGraph.getUpstreamFragments())
                    .build();
            return new FragmentSubGraph(upstreamFragments, probeSubGraph.getLazyUpstreamFragments(), currentFragmentLazy);
        }

        /**
         * Visits the aggregation's source. For a FINAL or SINGLE step the result
         * keeps the source's upstream fragments but has no lazy fragments and a
         * cleared lazy flag; for any other step the source's result is returned as-is.
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public FragmentSubGraph visitAggregation(AggregationNode aggregationNode, PlanFragmentId planFragmentId) {
            FragmentSubGraph sourceSubGraph = aggregationNode.getSource().accept(this, planFragmentId);
            AggregationNode.Step step = aggregationNode.getStep();
            if (step != AggregationNode.Step.FINAL && step != AggregationNode.Step.SINGLE) {
                return sourceSubGraph;
            }
            return new FragmentSubGraph(sourceSubGraph.getUpstreamFragments(), ImmutableSet.of(), false);
        }

        /**
         * Processes every source fragment of the remote source, records a topology
         * edge from each source fragment to the current fragment, and merges the
         * sources' results.
         *
         * @return merged upstream and lazy upstream fragments, with the lazy flag set
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public FragmentSubGraph visitRemoteSource(RemoteSourceNode remoteSourceNode, PlanFragmentId planFragmentId) {
            // Typed list: a raw List would turn the lambda parameters below into Object.
            List<FragmentSubGraph> sourceSubGraphs = remoteSourceNode.getSourceFragmentIds().stream()
                    .map(this::processFragment)
                    .collect(ImmutableList.toImmutableList());

            remoteSourceNode.getSourceFragmentIds().forEach(sourceFragmentId ->
                    PhasedExecutionSchedule.this.fragmentTopology.addEdge(sourceFragmentId, planFragmentId));

            Set<PlanFragmentId> upstreamFragments = sourceSubGraphs.stream()
                    .flatMap(subGraph -> subGraph.getUpstreamFragments().stream())
                    .collect(ImmutableSet.toImmutableSet());
            Set<PlanFragmentId> lazyUpstreamFragments = sourceSubGraphs.stream()
                    .flatMap(subGraph -> subGraph.getLazyUpstreamFragments().stream())
                    .collect(ImmutableSet.toImmutableSet());
            return new FragmentSubGraph(upstreamFragments, lazyUpstreamFragments, true);
        }

        /**
         * Accepts only exchanges with LOCAL scope and handles them like a generic plan node.
         *
         * @throws IllegalArgumentException if the exchange scope is not LOCAL
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public FragmentSubGraph visitExchange(ExchangeNode exchangeNode, PlanFragmentId planFragmentId) {
            boolean localExchange = exchangeNode.getScope() == ExchangeNode.Scope.LOCAL;
            Preconditions.checkArgument(localExchange, "Only local exchanges are supported in the phased execution scheduler");
            PlanNode node = exchangeNode;
            return visitPlan(node, planFragmentId);
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /**
         * Default handling: visits every source of the node and merges the results.
         * The current fragment is reported lazy only if every source reported it
         * lazy, which is vacuously true for a node with no sources.
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public FragmentSubGraph visitPlan(PlanNode planNode, PlanFragmentId planFragmentId) {
            // Typed list: a raw List would turn the lambda parameters below into Object.
            List<FragmentSubGraph> sourceSubGraphs = planNode.getSources().stream()
                    .map(source -> source.accept(this, planFragmentId))
                    .collect(ImmutableList.toImmutableList());

            Set<PlanFragmentId> upstreamFragments = sourceSubGraphs.stream()
                    .flatMap(subGraph -> subGraph.getUpstreamFragments().stream())
                    .collect(ImmutableSet.toImmutableSet());
            Set<PlanFragmentId> lazyUpstreamFragments = sourceSubGraphs.stream()
                    .flatMap(subGraph -> subGraph.getLazyUpstreamFragments().stream())
                    .collect(ImmutableSet.toImmutableSet());
            boolean currentFragmentLazy = sourceSubGraphs.stream()
                    .allMatch(FragmentSubGraph::isCurrentFragmentLazy);
            return new FragmentSubGraph(upstreamFragments, lazyUpstreamFragments, currentFragmentLazy);
        }

        /**
         * Adds a dependency edge from every fragment in the first set to every
         * fragment in the second set.
         *
         * @param sourceFragments fragments the edges start at
         * @param targetFragments fragments the edges point to
         */
        private void addDependencyEdges(Set<PlanFragmentId> sourceFragments, Set<PlanFragmentId> targetFragments) {
            for (PlanFragmentId targetFragment : targetFragments) {
                for (PlanFragmentId sourceFragment : sourceFragments) {
                    PhasedExecutionSchedule.this.fragmentDependency.addEdge(sourceFragment, targetFragment);
                }
            }
        }
    }

    /**
     * Creates a schedule for the given stages and runs its initialisation.
     *
     * @param stages stage executions to schedule
     * @param dynamicFilterService service consulted when analysing replicated joins; must not be null
     * @return the initialised schedule
     */
    public static PhasedExecutionSchedule forStages(Collection<StageExecution> stages, DynamicFilterService dynamicFilterService) {
        PhasedExecutionSchedule schedule = new PhasedExecutionSchedule(stages, dynamicFilterService);
        // Initialisation is kept out of the constructor; it builds the graphs and selects the first stages.
        schedule.init(stages);
        return schedule;
    }

    /**
     * Indexes the stages by fragment id and stores the dynamic filter service.
     * Graph construction happens later, in {@code init}.
     *
     * @throws NullPointerException if {@code dynamicFilterService} is null
     */
    private PhasedExecutionSchedule(Collection<StageExecution> stages, DynamicFilterService dynamicFilterService) {
        Map<PlanFragmentId, StageExecution> stagesById = stages.stream()
                .collect(ImmutableMap.toImmutableMap(stage -> stage.getFragment().getId(), Function.identity()));
        this.stagesByFragmentId = stagesById;
        this.dynamicFilterService = Objects.requireNonNull(dynamicFilterService, "dynamicFilterService is null");
    }

    /**
     * Builds the dependency graphs and selects the initial stages: every non-lazy
     * fragment plus every fragment with no incoming dependency edge.
     *
     * @param collection the stage executions this schedule was created for
     */
    private void init(Collection<StageExecution> collection) {
        ImmutableSet.Builder<PlanFragmentId> fragmentsToExecute = ImmutableSet.builder();
        fragmentsToExecute.addAll(extractDependenciesAndReturnNonLazyFragments(collection));
        this.fragmentDependency.vertexSet().stream()
                .filter(fragmentId -> this.fragmentDependency.inDegreeOf(fragmentId) == 0)
                .forEach(fragmentsToExecute::add);
        // sortedFragments is complete only after the extraction above, so the ordering is built here.
        this.fragmentOrdering = Ordering.explicit(this.sortedFragments);
        selectForExecution(fragmentsToExecute.build());
    }

    /**
     * Runs one scheduling pass and returns the active stages together with the
     * reschedule future. The future is captured before the pass, so a
     * notification that fires during the pass completes the returned future.
     */
    @Override // io.trino.execution.scheduler.policy.ExecutionSchedule
    public StagesScheduleResult getStagesToSchedule() {
        Optional<ListenableFuture<Void>> futureBeforeScheduling = getRescheduleFuture();
        schedule();
        StagesScheduleResult result = new StagesScheduleResult(activeStages, futureBeforeScheduling);
        return result;
    }

    /** Returns {@code true} once the dependency graph has no vertices left. */
    @Override // io.trino.execution.scheduler.policy.ExecutionSchedule
    public boolean isFinished() {
        Set<PlanFragmentId> remainingFragments = fragmentDependency.vertexSet();
        return remainingFragments.isEmpty();
    }

    /** Returns the current reschedule future; the field is read while holding this object's monitor. */
    @VisibleForTesting
    synchronized Optional<ListenableFuture<Void>> getRescheduleFuture() {
        ListenableFuture<Void> current = rescheduleFuture;
        return Optional.of(current);
    }

    /**
     * Runs one scheduling pass. It selects the fragments released by removing
     * completed stages, and the fragments that directly consume an active stage
     * reporting a blocked task.
     */
    @VisibleForTesting
    void schedule() {
        ImmutableSet.Builder<PlanFragmentId> fragmentsToExecute = ImmutableSet.builder();
        fragmentsToExecute.addAll(removeCompletedStages());
        fragmentsToExecute.addAll(unblockStagesWithFullOutputBuffer());
        selectForExecution(fragmentsToExecute.build());
    }

    /** Exposes the live list of fragment ids in processing order (no copy is made). */
    @VisibleForTesting
    List<PlanFragmentId> getSortedFragments() {
        return sortedFragments;
    }

    /** Exposes the live dependency graph (no copy is made). */
    @VisibleForTesting
    DirectedGraph<PlanFragmentId, FragmentsEdge> getFragmentDependency() {
        return fragmentDependency;
    }

    /** Exposes the live set of active stages (no copy is made). */
    @VisibleForTesting
    Set<StageExecution> getActiveStages() {
        return activeStages;
    }

    /**
     * Removes every stage that {@code isStageCompleted} reports as completed.
     * All stages are checked, not only the active ones.
     *
     * @return fragments whose last remaining dependency was one of the removed stages
     */
    private Set<PlanFragmentId> removeCompletedStages() {
        // Snapshot first; the removal below mutates the graphs and activeStages.
        Set<StageExecution> completedStages = this.stagesByFragmentId.values().stream()
                .filter(this::isStageCompleted)
                .collect(ImmutableSet.toImmutableSet());
        return completedStages.stream()
                .flatMap(stage -> removeCompletedStage(stage).stream())
                .collect(ImmutableSet.toImmutableSet());
    }

    /**
     * Removes the stage's fragment from both graphs and the stage from the active set.
     *
     * @param stage stage to remove
     * @return fragments whose only incoming dependency edge came from this stage's
     *         fragment; empty if the fragment is no longer in the dependency graph
     */
    private Set<PlanFragmentId> removeCompletedStage(StageExecution stage) {
        PlanFragmentId fragmentId = stage.getFragment().getId();
        if (!fragmentDependency.containsVertex(fragmentId)) {
            return ImmutableSet.of();
        }

        // Must be computed before the vertex (and with it its edges) is removed.
        Set<PlanFragmentId> releasedFragments = fragmentDependency.outgoingEdgesOf(fragmentId).stream()
                .map(FragmentsEdge::getTarget)
                .filter(dependentId -> fragmentDependency.inDegreeOf(dependentId) == 1)
                .collect(ImmutableSet.toImmutableSet());

        fragmentDependency.removeVertex(fragmentId);
        fragmentTopology.removeVertex(fragmentId);
        activeStages.remove(stage);
        return releasedFragments;
    }

    /**
     * Finds active stages for which {@code isAnyTaskBlocked()} is true and returns
     * the fragments that consume them directly in the topology graph.
     *
     * @return ids of the immediate downstream fragments of blocked active stages
     */
    private Set<PlanFragmentId> unblockStagesWithFullOutputBuffer() {
        Set<PlanFragmentId> blockedFragments = this.activeStages.stream()
                .filter(StageExecution::isAnyTaskBlocked)
                .map(stage -> stage.getFragment().getId())
                .collect(ImmutableSet.toImmutableSet());
        return blockedFragments.stream()
                .flatMap(fragmentId -> this.fragmentTopology.outgoingEdgesOf(fragmentId).stream())
                .map(FragmentsEdge::getTarget)
                .collect(ImmutableSet.toImmutableSet());
    }

    /**
     * Selects the stages of the given fragments for execution, visiting them in
     * {@code fragmentOrdering} order.
     *
     * @param set ids of the fragments to select
     * @throws NullPointerException if called before {@code fragmentOrdering} has been assigned
     */
    private void selectForExecution(Set<PlanFragmentId> set) {
        Objects.requireNonNull(this.fragmentOrdering, "fragmentOrdering is null");
        set.stream()
                .sorted(this.fragmentOrdering)
                .map(this.stagesByFragmentId::get)
                .forEach(this::selectForExecution);
    }

    /**
     * Adds the stage to the active set unless {@code isStageCompleted} already
     * holds for it. If other fragments depend on the stage's fragment, a state
     * listener is registered that triggers a reschedule once the stage is completed.
     */
    private void selectForExecution(StageExecution stage) {
        if (isStageCompleted(stage)) {
            return;
        }

        PlanFragmentId fragmentId = stage.getFragment().getId();
        boolean hasDependents = fragmentDependency.outDegreeOf(fragmentId) > 0;
        if (hasDependents) {
            stage.addStateChangeListener(newState -> {
                // The stage is re-queried instead of trusting the callback argument.
                if (isStageCompleted(stage)) {
                    notifyReschedule();
                }
            });
        }
        activeStages.add(stage);
    }

    /**
     * Replaces the reschedule future with a fresh one and completes the previous
     * one. The swap happens under this object's monitor; the old future is
     * completed after the monitor is released, so its listeners run outside the lock.
     */
    private void notifyReschedule() {
        SettableFuture<Void> completedFuture;
        synchronized (this) {
            completedFuture = this.rescheduleFuture;
            this.rescheduleFuture = SettableFuture.create();
        }
        // null is the only value a Void future accepts; "(Object) null" does not type-check.
        completedFuture.set(null);
    }

    /**
     * Returns {@code true} when the stage's state is SCHEDULED, RUNNING or
     * FLUSHING, or when the state reports itself as done.
     */
    private boolean isStageCompleted(StageExecution stage) {
        StageExecution.State state = stage.getState();
        if (state == StageExecution.State.SCHEDULED) {
            return true;
        }
        if (state == StageExecution.State.RUNNING) {
            return true;
        }
        if (state == StageExecution.State.FLUSHING) {
            return true;
        }
        return state.isDone();
    }

    /**
     * Runs the visitor over all stage fragments to populate the dependency and
     * topology graphs, then verifies that the dependency graph has no cycle.
     *
     * @param collection stage executions to analyse
     * @return ids of fragments the visitor recorded as non-lazy; empty if there are no stages
     * @throws com.google.common.base.VerifyException if the dependency graph contains a cycle
     */
    private Set<PlanFragmentId> extractDependenciesAndReturnNonLazyFragments(Collection<StageExecution> collection) {
        if (collection.isEmpty()) {
            return ImmutableSet.of();
        }

        // The query id is read from an arbitrary stage.
        QueryId queryId = collection.stream()
                .map(stage -> stage.getStageId().getQueryId())
                .findAny()
                .orElseThrow();
        List<PlanFragment> fragments = collection.stream()
                .map(StageExecution::getFragment)
                .collect(ImmutableList.toImmutableList());

        Visitor visitor = new Visitor(queryId, fragments);
        visitor.processAllFragments();

        // A graph is acyclic exactly when each strongly connected component is a single vertex.
        List<Set<PlanFragmentId>> components = new StrongConnectivityInspector<>(this.fragmentDependency).stronglyConnectedSets();
        Verify.verify(components.size() == this.fragmentDependency.vertexSet().size(), "circular dependency between stages");
        return visitor.getNonLazyFragments();
    }
}
