/*
 * Decompiled with CFR 0.152.
 */
package io.ray.streaming.jobgraph;

import io.ray.streaming.api.Language;
import io.ray.streaming.api.partition.Partition;
import io.ray.streaming.api.partition.impl.ForwardPartition;
import io.ray.streaming.api.partition.impl.RoundRobinPartition;
import io.ray.streaming.jobgraph.JobEdge;
import io.ray.streaming.jobgraph.JobGraph;
import io.ray.streaming.jobgraph.JobVertex;
import io.ray.streaming.operator.ChainStrategy;
import io.ray.streaming.operator.StreamOperator;
import io.ray.streaming.operator.chain.ChainedOperator;
import io.ray.streaming.python.PythonOperator;
import io.ray.streaming.python.PythonPartition;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;

public class JobGraphOptimizer {
    private final JobGraph jobGraph;
    private Set<JobVertex> visited = new HashSet<JobVertex>();
    private Map<Integer, JobVertex> vertexMap;
    private Map<JobVertex, Set<JobEdge>> outputEdgesMap;
    private Map<Integer, Pair<JobVertex, List<JobVertex>>> mergedVertexMap;

    public JobGraphOptimizer(JobGraph jobGraph) {
        this.jobGraph = jobGraph;
        this.vertexMap = jobGraph.getJobVertices().stream().collect(Collectors.toMap(JobVertex::getVertexId, Function.identity()));
        this.outputEdgesMap = this.vertexMap.keySet().stream().collect(Collectors.toMap(id -> this.vertexMap.get(id), id -> new HashSet<JobEdge>(jobGraph.getVertexOutputEdges((int)id))));
        this.mergedVertexMap = new HashMap<Integer, Pair<JobVertex, List<JobVertex>>>();
    }

    public JobGraph optimize() {
        this.jobGraph.getSourceVertices().forEach(vertex -> {
            ArrayList<JobVertex> verticesToMerge = new ArrayList<JobVertex>();
            verticesToMerge.add((JobVertex)vertex);
            this.mergeVerticesRecursively((JobVertex)vertex, (List<JobVertex>)verticesToMerge);
        });
        List<JobVertex> vertices = this.mergedVertexMap.values().stream().map(Pair::getLeft).collect(Collectors.toList());
        return new JobGraph(this.jobGraph.getJobName(), this.jobGraph.getJobConfig(), vertices, this.createEdges());
    }

    private void mergeVerticesRecursively(JobVertex vertex, List<JobVertex> verticesToMerge) {
        if (!this.visited.contains(vertex)) {
            this.visited.add(vertex);
            Set<JobEdge> outputEdges = this.outputEdgesMap.get(vertex);
            if (outputEdges.isEmpty()) {
                this.mergeAndAddVertex(verticesToMerge);
            } else {
                outputEdges.forEach(edge -> {
                    JobVertex succeedingVertex = this.vertexMap.get(edge.getTargetVertexId());
                    if (this.canBeChained(vertex, succeedingVertex, (JobEdge)edge)) {
                        verticesToMerge.add(succeedingVertex);
                        this.mergeVerticesRecursively(succeedingVertex, verticesToMerge);
                    } else {
                        this.mergeAndAddVertex(verticesToMerge);
                        ArrayList<JobVertex> newMergedVertices = new ArrayList<JobVertex>();
                        newMergedVertices.add(succeedingVertex);
                        this.mergeVerticesRecursively(succeedingVertex, newMergedVertices);
                    }
                });
            }
        }
    }

    private void mergeAndAddVertex(List<JobVertex> verticesToMerge) {
        JobVertex mergedVertex;
        JobVertex headVertex = verticesToMerge.get(0);
        Language language = headVertex.getLanguage();
        if (verticesToMerge.size() == 1) {
            mergedVertex = headVertex;
        } else {
            StreamOperator operator;
            List<StreamOperator> operators = verticesToMerge.stream().map(v -> this.vertexMap.get(v.getVertexId()).getStreamOperator()).collect(Collectors.toList());
            List<Map<String, String>> configs = verticesToMerge.stream().map(v -> this.vertexMap.get(v.getVertexId()).getConfig()).collect(Collectors.toList());
            if (language == Language.JAVA) {
                operator = ChainedOperator.newChainedOperator(operators, configs);
            } else {
                List<PythonOperator> pythonOperators = operators.stream().map(o -> (PythonOperator)o).collect(Collectors.toList());
                operator = new PythonOperator.ChainedPythonOperator(pythonOperators, configs);
            }
            mergedVertex = new JobVertex(headVertex.getVertexId(), headVertex.getParallelism(), headVertex.getVertexType(), operator, new HashMap<String, String>());
        }
        this.mergedVertexMap.put(mergedVertex.getVertexId(), (Pair<JobVertex, List<JobVertex>>)Pair.of((Object)mergedVertex, verticesToMerge));
    }

    private List<JobEdge> createEdges() {
        ArrayList<JobEdge> edges = new ArrayList<JobEdge>();
        this.mergedVertexMap.forEach((id, pair) -> {
            JobVertex mergedVertex = (JobVertex)pair.getLeft();
            List mergedVertices = (List)pair.getRight();
            JobVertex tailVertex = (JobVertex)mergedVertices.get(mergedVertices.size() - 1);
            if (this.outputEdgesMap.containsKey(tailVertex)) {
                this.outputEdgesMap.get(tailVertex).forEach(edge -> {
                    Pair<JobVertex, List<JobVertex>> downstreamPair = this.mergedVertexMap.get(edge.getTargetVertexId());
                    Partition partition = this.changePartition(edge.getPartition());
                    JobEdge newEdge = new JobEdge(mergedVertex.getVertexId(), ((JobVertex)downstreamPair.getLeft()).getVertexId(), partition);
                    edges.add(newEdge);
                });
            }
        });
        return edges;
    }

    private Partition changePartition(Partition partition) {
        if (partition instanceof PythonPartition) {
            PythonPartition pythonPartition = (PythonPartition)partition;
            if (!pythonPartition.isConstructedFromBinary() && pythonPartition.getFunctionName().equals("ForwardPartition")) {
                return PythonPartition.RoundRobinPartition;
            }
            return partition;
        }
        if (partition instanceof ForwardPartition) {
            return new RoundRobinPartition();
        }
        return partition;
    }

    private boolean canBeChained(JobVertex precedingVertex, JobVertex succeedingVertex, JobEdge edge) {
        if (this.jobGraph.getVertexOutputEdges(precedingVertex.getVertexId()).size() > 1 || this.jobGraph.getVertexInputEdges(succeedingVertex.getVertexId()).size() > 1) {
            return false;
        }
        if (precedingVertex.getParallelism() != succeedingVertex.getParallelism()) {
            return false;
        }
        if (precedingVertex.getStreamOperator().getChainStrategy() == ChainStrategy.NEVER || succeedingVertex.getStreamOperator().getChainStrategy() == ChainStrategy.NEVER || succeedingVertex.getStreamOperator().getChainStrategy() == ChainStrategy.HEAD) {
            return false;
        }
        if (precedingVertex.getLanguage() != succeedingVertex.getLanguage()) {
            return false;
        }
        Partition partition = edge.getPartition();
        if (!(partition instanceof PythonPartition)) {
            return partition instanceof ForwardPartition;
        }
        PythonPartition pythonPartition = (PythonPartition)partition;
        return !pythonPartition.isConstructedFromBinary() && pythonPartition.getFunctionName().equals("ForwardPartition");
    }
}

