package io.openlineage.spark34.agent.lifecycle.plan.column;

import com.linkedin.metadata.aspect.models.graph.Edge;
import datahub.spark2.shaded.org.slf4j.Logger;
import datahub.spark2.shaded.org.slf4j.LoggerFactory;
import io.openlineage.spark.agent.lifecycle.plan.column.ColumnLevelLineageContext;
import io.openlineage.spark.agent.lifecycle.plan.column.ColumnLevelLineageVisitor;
import io.openlineage.spark.agent.util.ReflectionUtils;
import io.openlineage.spark.agent.util.ScalaConversionUtils;
import io.openlineage.spark.api.OpenLineageContext;
import io.openlineage.spark.shaded.org.apache.commons.lang3.reflect.MethodUtils;
import io.openlineage.spark3.agent.lifecycle.plan.column.InputFieldsCollector;
import io.openlineage.spark3.agent.lifecycle.plan.column.OutputFieldsCollector;
import java.lang.reflect.InvocationTargetException;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.spark.sql.catalyst.expressions.AttributeReference;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.plans.logical.DeltaMergeAction;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import scala.collection.Seq;
import scala.collection.mutable.ArrayBuffer;

/* loaded from: input_file:io/openlineage/spark34/agent/lifecycle/plan/column/MergeIntoCommandEdgeColumnLineageBuilder.class */
/**
 * Column-level lineage visitor for the Databricks {@code MergeIntoCommandEdge} logical plan node.
 *
 * <p>Every {@code collect*} method is a no-op unless the canonical class name of the visited node
 * ends with {@link #CLASS}. The node's {@code target}, source, {@code matchedClauses} and
 * {@code notMatchedClauses} members are read reflectively by calling the accessor of that name,
 * so this class does not reference the {@code MergeIntoCommandEdge} type directly.
 */
public class MergeIntoCommandEdgeColumnLineageBuilder implements ColumnLevelLineageVisitor {
    private static final Logger log = LoggerFactory.getLogger(MergeIntoCommandEdgeColumnLineageBuilder.class);

    /** Suffix of the canonical class name identifying the plan nodes handled by this visitor. */
    private static final String CLASS = "sql.transaction.tahoe.commands.MergeIntoCommandEdge";

    protected OpenLineageContext context;

    /**
     * Asks {@link ReflectionUtils#hasClasses} about the two classes this visitor relies on.
     *
     * @return the result of {@code ReflectionUtils.hasClasses} for the Databricks
     *     {@code MergeIntoCommandEdge} and {@code DeltaMergeIntoNotMatchedClause} class names
     */
    public static boolean hasClasses() {
        return ReflectionUtils.hasClasses(
                "com.databricks.sql.transaction.tahoe.commands.MergeIntoCommandEdge",
                "org.apache.spark.sql.catalyst.plans.logical.DeltaMergeIntoNotMatchedClause");
    }

    /**
     * Creates the visitor.
     *
     * @param openLineageContext context stored in {@link #context}
     */
    public MergeIntoCommandEdgeColumnLineageBuilder(OpenLineageContext openLineageContext) {
        this.context = openLineageContext;
    }

    /**
     * Collects input fields for a merge node.
     *
     * <p>Order matters: inputs of the target plan are collected first, then every input whose
     * expression id is not the child of a resolvable merge action is dropped from the builder,
     * and only afterwards are the inputs of the source plan collected.
     *
     * @param columnLevelLineageContext lineage context whose builder is updated
     * @param logicalPlan visited plan node; ignored unless it is a {@code MergeIntoCommandEdge}
     */
    @Override // io.openlineage.spark.agent.lifecycle.plan.column.ColumnLevelLineageVisitor
    public void collectInputs(ColumnLevelLineageContext columnLevelLineageContext, LogicalPlan logicalPlan) {
        if (!isMergeIntoCommandEdge(logicalPlan)) {
            return;
        }

        this.<LogicalPlan>getFieldFromNode(logicalPlan, "target")
                .ifPresent(target -> InputFieldsCollector.collect(columnLevelLineageContext, target));

        // The cast is safe: resolvableMergeActions only yields actions whose child is an AttributeReference.
        List<?> referencedExprIds = resolvableMergeActions(columnLevelLineageContext, logicalPlan)
                .map(action -> ((AttributeReference) action.child()).exprId())
                .collect(Collectors.toList());

        // Removing keys through the key-set view removes the corresponding map entries.
        columnLevelLineageContext.getBuilder().getInputs().keySet().retainAll(referencedExprIds);

        // NOTE(review): Edge.EDGE_FIELD_SOURCE comes from an unrelated DataHub class and is presumably
        // the literal "source" substituted by the decompiler; confirm and consider a local literal.
        this.<LogicalPlan>getFieldFromNode(logicalPlan, Edge.EDGE_FIELD_SOURCE)
                .ifPresent(source -> InputFieldsCollector.collect(columnLevelLineageContext, source));
    }

    /**
     * Collects output fields from the merge node's target plan.
     *
     * @param columnLevelLineageContext lineage context passed to {@link OutputFieldsCollector}
     * @param logicalPlan visited plan node; ignored unless it is a {@code MergeIntoCommandEdge}
     */
    @Override // io.openlineage.spark.agent.lifecycle.plan.column.ColumnLevelLineageVisitor
    public void collectOutputs(ColumnLevelLineageContext columnLevelLineageContext, LogicalPlan logicalPlan) {
        if (!isMergeIntoCommandEdge(logicalPlan)) {
            return;
        }
        this.<LogicalPlan>getFieldFromNode(logicalPlan, "target")
                .ifPresent(target -> OutputFieldsCollector.collect(columnLevelLineageContext, target));
    }

    /**
     * Registers, for every resolvable merge action, a dependency from the output expression id
     * looked up by the action's target column name to the expression id of the action's child.
     *
     * @param columnLevelLineageContext lineage context whose builder receives the dependencies
     * @param logicalPlan visited plan node; ignored unless it is a {@code MergeIntoCommandEdge}
     */
    @Override // io.openlineage.spark.agent.lifecycle.plan.column.ColumnLevelLineageVisitor
    public void collectExpressionDependencies(ColumnLevelLineageContext columnLevelLineageContext, LogicalPlan logicalPlan) {
        if (!isMergeIntoCommandEdge(logicalPlan)) {
            return;
        }
        // get() cannot fail here: resolvableMergeActions already filtered on isPresent() for the same lookup.
        resolvableMergeActions(columnLevelLineageContext, logicalPlan).forEach(action ->
                columnLevelLineageContext.getBuilder().addDependency(
                        columnLevelLineageContext.getBuilder().getOutputExprIdByFieldName(action.targetColNameParts().mkString()).get(),
                        ((AttributeReference) action.child()).exprId()));
    }

    /**
     * Streams the actions of all matched clauses followed by the actions of all not-matched
     * clauses of the given node. A clause list that cannot be read is treated as empty.
     *
     * <p>NOTE(review): the element types of the two clause sequences were lost (raw {@code Seq});
     * the lambdas call {@code actions()} on each element, so they are presumably
     * {@code DeltaMergeIntoMatchedClause} / {@code DeltaMergeIntoNotMatchedClause}. Restore the
     * generic types together with the matching imports before recompiling this method.
     *
     * @param logicalPlan plan node exposing {@code matchedClauses} and {@code notMatchedClauses}
     * @return stream of the clause actions, matched clauses first
     */
    public Stream<Expression> getMergeActions(LogicalPlan logicalPlan) {
        return Stream.concat(ScalaConversionUtils.fromSeq((Seq) getFieldFromNode(logicalPlan, "matchedClauses").orElse(new ArrayBuffer())).stream().flatMap(deltaMergeIntoMatchedClause -> {
            return ScalaConversionUtils.fromSeq(deltaMergeIntoMatchedClause.actions()).stream();
        }), ScalaConversionUtils.fromSeq((Seq) getFieldFromNode(logicalPlan, "notMatchedClauses").orElse(new ArrayBuffer())).stream().flatMap(deltaMergeIntoNotMatchedClause -> {
            return ScalaConversionUtils.fromSeq(deltaMergeIntoNotMatchedClause.actions()).stream();
        }));
    }

    /**
     * Tells whether the node is handled by this visitor.
     *
     * <p>{@link Class#getCanonicalName()} returns {@code null} for classes without a canonical
     * name (for example anonymous or local classes); such nodes are simply not a match instead of
     * causing a {@link NullPointerException}.
     */
    private static boolean isMergeIntoCommandEdge(LogicalPlan logicalPlan) {
        String canonicalName = logicalPlan.getClass().getCanonicalName();
        return canonicalName != null && canonicalName.endsWith(CLASS);
    }

    /**
     * Streams the merge actions that are {@link DeltaMergeAction}s, whose child is an
     * {@link AttributeReference}, and whose target column name resolves to an output expression id
     * in the context's builder.
     */
    private Stream<DeltaMergeAction> resolvableMergeActions(
            ColumnLevelLineageContext columnLevelLineageContext, LogicalPlan logicalPlan) {
        return getMergeActions(logicalPlan)
                .filter(expression -> expression instanceof DeltaMergeAction)
                .map(expression -> (DeltaMergeAction) expression)
                .filter(action -> action.child() instanceof AttributeReference)
                .filter(action -> columnLevelLineageContext.getBuilder()
                        .getOutputExprIdByFieldName(action.targetColNameParts().mkString())
                        .isPresent());
    }

    /**
     * Reads a member of the plan node by reflectively invoking the accessor named {@code str}.
     *
     * @param logicalPlan node to read from
     * @param str name of the accessor method to invoke
     * @param <T> type the caller expects; not checked at runtime
     * @return the accessor's result, or empty when it returned {@code null} or the reflective
     *     call failed (the failure is logged at warn level together with its cause)
     */
    @SuppressWarnings("unchecked") // the member type is only known to the caller; reflection yields Object
    private <T> Optional<T> getFieldFromNode(LogicalPlan logicalPlan, String str) {
        try {
            return Optional.ofNullable((T) MethodUtils.invokeMethod(logicalPlan, str));
        } catch (IllegalAccessException | NoSuchMethodException | InvocationTargetException e) {
            log.warn("Couldn't extract field {} from {}", str, logicalPlan, e);
            return Optional.empty();
        }
    }
}
