package io.openlineage.spark3.agent.lifecycle.plan.column;

import datahub.spark2.shaded.org.slf4j.Logger;
import datahub.spark2.shaded.org.slf4j.LoggerFactory;
import io.openlineage.spark.agent.lifecycle.plan.column.ColumnLevelLineageBuilder;
import io.openlineage.spark.agent.lifecycle.plan.column.ColumnLevelLineageContext;
import io.openlineage.spark.agent.util.ExtensionPlanUtils;
import io.openlineage.spark.agent.util.ScalaConversionUtils;
import io.openlineage.spark.extension.scala.v1.ColumnLevelLineageNode;
import io.openlineage.spark.extension.scala.v1.ExpressionDependencyWithDelegate;
import io.openlineage.spark.extension.scala.v1.ExpressionDependencyWithIdentifier;
import io.openlineage.spark.shaded.org.apache.commons.lang3.reflect.MethodUtils;
import io.openlineage.spark3.agent.lifecycle.plan.column.visitors.ExpressionDependencyVisitor;
import io.openlineage.spark3.agent.lifecycle.plan.column.visitors.IcebergMergeIntoDependencyVisitor;
import io.openlineage.spark3.agent.lifecycle.plan.column.visitors.UnionDependencyVisitor;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import org.apache.spark.sql.catalyst.expressions.ExprId;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression;
import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.catalyst.plans.logical.Project;
import org.apache.spark.sql.execution.datasources.LogicalRelation;
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation;
import scala.collection.Seq;
import scala.runtime.BoxedUnit;

/* loaded from: input_file:io/openlineage/spark3/agent/lifecycle/plan/column/ExpressionDependencyCollector.class */
/**
 * Collects column-level dependencies between Spark {@link ExprId}s found in a logical plan and
 * records them in the {@link ColumnLevelLineageBuilder} obtained from a
 * {@link ColumnLevelLineageContext}.
 *
 * <p>For every plan node this class, in order:
 * <ol>
 *   <li>applies each registered {@link ExpressionDependencyVisitor} whose {@code isDefinedAt}
 *       accepts the node;
 *   <li>delegates to {@code CustomCollectorsUtils.collectExpressionDependencies};
 *   <li>runs a node-type specific step for extension {@link ColumnLevelLineageNode}s,
 *       {@link Project}, {@link Aggregate}, or a {@link LogicalRelation} backed by a
 *       {@link JDBCRelation}.
 * </ol>
 */
public class ExpressionDependencyCollector {
    private static final Logger log = LoggerFactory.getLogger(ExpressionDependencyCollector.class);

    /** Visitors tried against every plan node; only those whose isDefinedAt accepts it are applied. */
    private static final List<ExpressionDependencyVisitor> expressionDependencyVisitors =
            Arrays.asList(new UnionDependencyVisitor(), new IcebergMergeIntoDependencyVisitor());

    /**
     * Collects expression dependencies from every node of {@code plan} (visited through
     * {@link LogicalPlan#foreach}).
     *
     * <p>NOTE: decompiler metadata indicated this method was package-private upstream; it is kept
     * public here so existing callers continue to compile.
     *
     * @param context lineage context whose builder receives the dependencies
     * @param plan root of the logical plan to scan
     */
    public static void collect(ColumnLevelLineageContext context, LogicalPlan plan) {
        plan.foreach(node -> {
            collectFromNode(context, node);
            return BoxedUnit.UNIT;
        });
    }

    /**
     * Collects expression dependencies contributed by a single plan node (children are not visited).
     *
     * @param context lineage context whose builder receives the dependencies
     * @param plan the plan node to inspect
     */
    static void collectFromNode(ColumnLevelLineageContext context, LogicalPlan plan) {
        for (ExpressionDependencyVisitor visitor : expressionDependencyVisitors) {
            if (visitor.isDefinedAt(plan)) {
                visitor.apply(plan, context.getBuilder());
            }
        }
        CustomCollectorsUtils.collectExpressionDependencies(context, plan);

        if (plan instanceof ColumnLevelLineageNode) {
            extensionColumnLineage(context, (ColumnLevelLineageNode) plan);
        } else if (plan instanceof Project) {
            traverseNamedExpressions(
                    ScalaConversionUtils.fromSeq(((Project) plan).projectList()), context.getBuilder());
        } else if (plan instanceof Aggregate) {
            traverseNamedExpressions(
                    ScalaConversionUtils.fromSeq(((Aggregate) plan).aggregateExpressions()),
                    context.getBuilder());
        } else if (plan instanceof LogicalRelation
                && ((LogicalRelation) plan).relation() instanceof JDBCRelation) {
            JdbcColumnLineageCollector.extractExpressionsFromJDBC(plan, context.getBuilder());
        }
    }

    /**
     * Traverses each named expression, attributing every dependency found beneath it to that
     * expression's own {@link ExprId}.
     */
    private static void traverseNamedExpressions(
            List<NamedExpression> expressions, ColumnLevelLineageBuilder builder) {
        for (NamedExpression namedExpression : expressions) {
            traverseExpression((Expression) namedExpression, namedExpression.exprId(), builder);
        }
    }

    /**
     * Records dependencies declared by an extension plan node through
     * {@link ColumnLevelLineageNode#columnLevelLineageDependencies}.
     *
     * <p>Delegate-style dependencies (an output id plus a Spark {@link Expression}) are processed
     * first by traversing the expression; identifier-style dependencies (an output id plus explicit
     * input ids) are processed afterwards. Delegates whose payload is not a Spark
     * {@link Expression} are ignored.
     */
    private static void extensionColumnLineage(
            ColumnLevelLineageContext context, ColumnLevelLineageNode node) {
        List<?> dependencies = ScalaConversionUtils.fromSeq(
                node.columnLevelLineageDependencies(
                                ExtensionPlanUtils.context(context.getEvent(), context.getOlContext()))
                        .toSeq());
        ColumnLevelLineageBuilder builder = context.getBuilder();

        // Pass 1: delegates wrapping a Spark expression.
        for (Object dependency : dependencies) {
            if (!(dependency instanceof ExpressionDependencyWithDelegate)) {
                continue;
            }
            ExpressionDependencyWithDelegate delegate = (ExpressionDependencyWithDelegate) dependency;
            if (delegate.expression() instanceof Expression) {
                traverseExpression(
                        (Expression) delegate.expression(),
                        ExprId.apply(delegate.outputExprId().exprId()),
                        builder);
            }
        }

        // Pass 2: explicit output -> inputs mappings.
        for (Object dependency : dependencies) {
            if (!(dependency instanceof ExpressionDependencyWithIdentifier)) {
                continue;
            }
            ExpressionDependencyWithIdentifier identifier = (ExpressionDependencyWithIdentifier) dependency;
            ExprId outputExprId = ExprId.apply(identifier.outputExprId().exprId());
            ScalaConversionUtils.fromSeq(identifier.inputExprIds().toSeq())
                    .forEach(input -> builder.addDependency(outputExprId, ExprId.apply(input.exprId())));
        }
    }

    /**
     * Recursively walks {@code expression} and records, against {@code exprId}, every
     * {@link NamedExpression} id found in the tree (other than {@code exprId} itself) and the
     * result id(s) of any {@link AggregateExpression}.
     *
     * @param expression root of the expression tree to walk
     * @param exprId output id the discovered dependencies are attributed to
     * @param columnLevelLineageBuilder builder receiving the dependencies
     */
    public static void traverseExpression(
            Expression expression, ExprId exprId, ColumnLevelLineageBuilder columnLevelLineageBuilder) {
        if (expression instanceof NamedExpression) {
            ExprId namedId = ((NamedExpression) expression).exprId();
            // Avoid a self-dependency when the root named expression is visited.
            if (!namedId.equals(exprId)) {
                columnLevelLineageBuilder.addDependency(exprId, namedId);
            }
        }
        if (expression.children() != null) {
            for (Expression child : ScalaConversionUtils.fromSeq(expression.children())) {
                traverseExpression(child, exprId, columnLevelLineageBuilder);
            }
        }
        if (expression instanceof AggregateExpression) {
            addAggregateResultDependencies(
                    (AggregateExpression) expression, exprId, columnLevelLineageBuilder);
        }
    }

    /**
     * Records the result id(s) of an aggregate against {@code exprId}.
     *
     * <p>If {@link AggregateExpression} exposes {@code resultId()} it is used directly; otherwise a
     * {@code resultIds()} method returning a {@link Seq} of {@link ExprId}s is invoked reflectively.
     * Any failure on the reflective path is logged as a warning and otherwise ignored
     * (best-effort).
     */
    @SuppressWarnings("unchecked")
    private static void addAggregateResultDependencies(
            AggregateExpression aggregate, ExprId exprId, ColumnLevelLineageBuilder builder) {
        if (AggregateResultIdSupport.HAS_RESULT_ID) {
            builder.addDependency(exprId, aggregate.resultId());
            return;
        }
        try {
            // Unchecked: the reflective call is only known to return a Seq; elements are assumed to be ExprIds.
            Seq<ExprId> resultIds = (Seq<ExprId>) MethodUtils.invokeMethod(aggregate, "resultIds");
            for (ExprId resultId : ScalaConversionUtils.fromSeq(resultIds)) {
                builder.addDependency(exprId, resultId);
            }
        } catch (Exception e) {
            log.warn("Failed extracting resultIds from AggregateExpression", e);
        }
    }

    /**
     * Lazily computed flag telling whether {@link AggregateExpression} has an accessible
     * {@code resultId()} method. Kept in a nested holder so the reflective lookup runs once, on
     * first use, instead of on every aggregate visited.
     */
    private static final class AggregateResultIdSupport {
        static final boolean HAS_RESULT_ID =
                MethodUtils.getAccessibleMethod(AggregateExpression.class, "resultId") != null;

        private AggregateResultIdSupport() {}
    }
}
