diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala index 254c7afc8b31..d0d28bd60666 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala @@ -31,9 +31,10 @@ import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.sql.PaimonUtils._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, BasePredicate, Expression, Literal, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, BasePredicate, Expression, IsNull, Literal, UnsafeProjection} import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate +import org.apache.spark.sql.catalyst.plans.FullOuter import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.functions.{col, lit, monotonically_increasing_id, sum} @@ -222,29 +223,26 @@ case class MergeIntoPaimonTable( deletionVectorEnabled: Boolean = false, extraMetadataCols: Seq[PaimonMetadataColumn] = Seq.empty, writeRowTracking: Boolean = false): Dataset[Row] = { - val targetDS = targetDataset - .withColumn(TARGET_ROW_COL, lit(true)) - - val sourceDS = createDataset(sparkSession, sourceTable) - .withColumn(SOURCE_ROW_COL, lit(true)) - - val joinedDS = sourceDS.join(targetDS, toColumn(mergeCondition), "fullOuter") + val targetPlan = targetDataset.queryExecution.analyzed + val targetProject = + Project(targetPlan.output :+ Alias(Literal(true), TARGET_ROW_COL)(), targetPlan) + val sourceProject = + Project(sourceTable.output :+ Alias(Literal(true), SOURCE_ROW_COL)(), sourceTable) + val joinNode = + Join(sourceProject, targetProject, FullOuter, Some(mergeCondition), JoinHint.NONE) + val joinedDS = createDataset(sparkSession, joinNode) val joinedPlan = joinedDS.queryExecution.analyzed - def resolveOnJoinedPlan(exprs: Seq[Expression]): Seq[Expression] = { - resolveExpressions(sparkSession)(exprs, joinedPlan) - } + val resolver = sparkSession.sessionState.conf.resolver + def attribute(name: String) = joinedPlan.output.find(attr => resolver(name, attr.name)) - val targetRowNotMatched = resolveOnJoinedPlan( - Seq(toExpression(sparkSession, col(SOURCE_ROW_COL).isNull))).head - val sourceRowNotMatched = resolveOnJoinedPlan( - Seq(toExpression(sparkSession, col(TARGET_ROW_COL).isNull))).head + val targetRowNotMatched = IsNull(attribute(SOURCE_ROW_COL).get) + val sourceRowNotMatched = IsNull(attribute(TARGET_ROW_COL).get) val matchedExprs = matchedActions.map(_.condition.getOrElse(TrueLiteral)) val notMatchedExprs = notMatchedActions.map(_.condition.getOrElse(TrueLiteral)) - val notMatchedBySourceExprs = notMatchedBySourceActions.map(_.condition.getOrElse(TrueLiteral)) + val notMatchedBySourceExprs = + notMatchedBySourceActions.map(_.condition.getOrElse(TrueLiteral)) - val resolver = sparkSession.sessionState.conf.resolver - def attribute(name: String) = joinedPlan.output.find(attr => resolver(name, attr.name)) val extraMetadataAttributes = extraMetadataCols.flatMap(metadataCol => attribute(metadataCol.name)) val (rowIdAttr, sequenceNumberAttr) = if (writeRowTracking) { @@ -266,7 +264,7 @@ case class MergeIntoPaimonTable( def processMergeActions(actions: Seq[MergeAction]): Seq[Seq[Expression]] = { val columnExprs = actions.map { case UpdateAction(_, assignments) => - var exprs = assignments.map(_.value) + var exprs: Seq[Expression] = assignments.map(_.value) if (writeRowTracking) { exprs ++= Seq(rowIdAttr, Literal(null)) } @@ -280,7 +278,7 @@ case class MergeIntoPaimonTable( noopOutput } case InsertAction(_, assignments) => - var exprs = assignments.map(_.value) + var exprs: Seq[Expression] = assignments.map(_.value) if (writeRowTracking) { exprs ++= Seq(rowIdAttr, sequenceNumberAttr) } diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTestBase.scala index 683108742e3e..6c7ee127616a 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTestBase.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTestBase.scala @@ -690,6 +690,82 @@ abstract class MergeIntoTableTestBase extends PaimonSparkTestBase with PaimonTab } } + test(s"Paimon MergeInto: update with coalesce referencing both source and target columns") { + withTable("source", "target") { + + Seq((1, "guid_src_1"), (3, "guid_src_3")) + .toDF("a", "b") + .createOrReplaceTempView("source") + + createTable("target", "a INT, b STRING", Seq("a")) + spark.sql("INSERT INTO target values (1, 'guid_tgt_1'), (2, 'guid_tgt_2')") + + spark.sql(s""" + |MERGE INTO target AS dest + |USING source AS src + |ON dest.a = src.a + |WHEN MATCHED AND (nullif(cast(src.b as STRING), '') IS NOT NULL) THEN + |UPDATE SET dest.b = COALESCE(nullif(cast(src.b as STRING), ''), dest.b) + |WHEN NOT MATCHED THEN + |INSERT (a, b) VALUES (src.a, src.b) + |""".stripMargin) + + checkAnswer( + spark.sql("SELECT * FROM target ORDER BY a"), + Row(1, "guid_src_1") :: Row(2, "guid_tgt_2") :: Row(3, "guid_src_3") :: Nil) + } + } + + test(s"Paimon MergeInto: two paimon tables with coalesce referencing both source and target") { + withTable("source", "target") { + + createTable("source", "a INT, b STRING", Seq("a")) + createTable("target", "a INT, b STRING", Seq("a")) + spark.sql("INSERT INTO source values (1, 'guid_src_1'), (3, 'guid_src_3')") + spark.sql("INSERT INTO target values (1, 'guid_tgt_1'), (2, 'guid_tgt_2')") + + spark.sql(s""" + |MERGE INTO target AS dest + |USING source AS src + |ON dest.a = src.a + |WHEN MATCHED AND (nullif(cast(src.b as STRING), '') IS NOT NULL) THEN + |UPDATE SET dest.b = COALESCE(nullif(cast(src.b as STRING), ''), dest.b) + |WHEN NOT MATCHED THEN + |INSERT (a, b) VALUES (src.a, src.b) + |""".stripMargin) + + checkAnswer( + spark.sql("SELECT * FROM target ORDER BY a"), + Row(1, "guid_src_1") :: Row(2, "guid_tgt_2") :: Row(3, "guid_src_3") :: Nil) + } + } + + test(s"Paimon MergeInto: subquery source with coalesce referencing both source and target") { + withTable("source", "target") { + + Seq((1, "guid_src_1"), (3, "guid_src_3")) + .toDF("a", "b") + .createOrReplaceTempView("source") + + createTable("target", "a INT, b STRING", Seq("a")) + spark.sql("INSERT INTO target values (1, 'guid_tgt_1'), (2, 'guid_tgt_2')") + + spark.sql(s""" + |MERGE INTO target AS dest + |USING (SELECT * FROM source) AS src + |ON dest.a = src.a + |WHEN MATCHED AND (nullif(cast(src.b as STRING), '') IS NOT NULL) THEN + |UPDATE SET dest.b = COALESCE(nullif(cast(src.b as STRING), ''), dest.b) + |WHEN NOT MATCHED THEN + |INSERT (a, b) VALUES (src.a, src.b) + |""".stripMargin) + + checkAnswer( + spark.sql("SELECT * FROM target ORDER BY a"), + Row(1, "guid_src_1") :: Row(2, "guid_tgt_2") :: Row(3, "guid_src_3") :: Nil) + } + } + test(s"Paimon MergeInto: merge into with varchar") { withTable("source", "target") { createTable("source", "a INT, b VARCHAR(32)", Seq("a"))