
Commit 097bfcb

Zouxxyy and claude committed
[spark] Fix exprId mismatch in MERGE INTO by constructing Join plan directly
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f6cb6c8 commit 097bfcb
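In short: the previous code joined the source and target as Datasets, and that Dataset-level full outer join re-analyzes the combined plan, which can leave the joined plan's attributes with exprIds that no longer match the ones referenced by the pre-resolved merge condition and merge actions (the new tests hit this with COALESCE expressions referencing both source and target columns). Building the full outer Join logical plan directly over the already-analyzed children keeps the original attributes. Below is a minimal, illustrative sketch of the idea, not the exact patch: the helper name buildMergeJoin is made up here, while Project, Alias, Join, FullOuter and JoinHint are the Catalyst classes used in the diff that follows.

// Illustrative sketch only: build the join as a logical plan so the merge
// condition's attribute exprIds stay valid (no re-analysis of a Dataset join).
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, Literal}
import org.apache.spark.sql.catalyst.plans.FullOuter
import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, LogicalPlan, Project}

def buildMergeJoin(
    sourcePlan: LogicalPlan, // analyzed plan of the source table/query
    targetPlan: LogicalPlan, // analyzed plan of the target table scan
    mergeCondition: Expression, // already resolved against the two plans above
    sourceRowCol: String,
    targetRowCol: String): LogicalPlan = {
  // Tag each side with a constant marker column while preserving its output attributes.
  val sourceProject =
    Project(sourcePlan.output :+ Alias(Literal(true), sourceRowCol)(), sourcePlan)
  val targetProject =
    Project(targetPlan.output :+ Alias(Literal(true), targetRowCol)(), targetPlan)
  // Full outer join on the original condition; the attributes it references are
  // exactly those of sourcePlan/targetPlan, so no exprId mismatch can occur.
  Join(sourceProject, targetProject, FullOuter, Some(mergeCondition), JoinHint.NONE)
}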

2 files changed: 94 additions & 20 deletions


paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonTable.scala

Lines changed: 18 additions & 20 deletions
@@ -31,9 +31,10 @@ import org.apache.spark.sql.{Dataset, Row, SparkSession}
 import org.apache.spark.sql.PaimonUtils._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, BasePredicate, Expression, Literal, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, BasePredicate, Expression, IsNull, Literal, UnsafeProjection}
 import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
 import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
+import org.apache.spark.sql.catalyst.plans.FullOuter
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.functions.{col, lit, monotonically_increasing_id, sum}
@@ -222,29 +223,26 @@ case class MergeIntoPaimonTable(
       deletionVectorEnabled: Boolean = false,
       extraMetadataCols: Seq[PaimonMetadataColumn] = Seq.empty,
       writeRowTracking: Boolean = false): Dataset[Row] = {
-    val targetDS = targetDataset
-      .withColumn(TARGET_ROW_COL, lit(true))
-
-    val sourceDS = createDataset(sparkSession, sourceTable)
-      .withColumn(SOURCE_ROW_COL, lit(true))
-
-    val joinedDS = sourceDS.join(targetDS, toColumn(mergeCondition), "fullOuter")
+    val targetPlan = targetDataset.queryExecution.analyzed
+    val targetProject =
+      Project(targetPlan.output :+ Alias(Literal(true), TARGET_ROW_COL)(), targetPlan)
+    val sourceProject =
+      Project(sourceTable.output :+ Alias(Literal(true), SOURCE_ROW_COL)(), sourceTable)
+    val joinNode =
+      Join(sourceProject, targetProject, FullOuter, Some(mergeCondition), JoinHint.NONE)
+    val joinedDS = createDataset(sparkSession, joinNode)
     val joinedPlan = joinedDS.queryExecution.analyzed
 
-    def resolveOnJoinedPlan(exprs: Seq[Expression]): Seq[Expression] = {
-      resolveExpressions(sparkSession)(exprs, joinedPlan)
-    }
+    val resolver = sparkSession.sessionState.conf.resolver
+    def attribute(name: String) = joinedPlan.output.find(attr => resolver(name, attr.name))
 
-    val targetRowNotMatched = resolveOnJoinedPlan(
-      Seq(toExpression(sparkSession, col(SOURCE_ROW_COL).isNull))).head
-    val sourceRowNotMatched = resolveOnJoinedPlan(
-      Seq(toExpression(sparkSession, col(TARGET_ROW_COL).isNull))).head
+    val targetRowNotMatched = IsNull(attribute(SOURCE_ROW_COL).get)
+    val sourceRowNotMatched = IsNull(attribute(TARGET_ROW_COL).get)
     val matchedExprs = matchedActions.map(_.condition.getOrElse(TrueLiteral))
     val notMatchedExprs = notMatchedActions.map(_.condition.getOrElse(TrueLiteral))
-    val notMatchedBySourceExprs = notMatchedBySourceActions.map(_.condition.getOrElse(TrueLiteral))
+    val notMatchedBySourceExprs =
+      notMatchedBySourceActions.map(_.condition.getOrElse(TrueLiteral))
 
-    val resolver = sparkSession.sessionState.conf.resolver
-    def attribute(name: String) = joinedPlan.output.find(attr => resolver(name, attr.name))
     val extraMetadataAttributes =
       extraMetadataCols.flatMap(metadataCol => attribute(metadataCol.name))
     val (rowIdAttr, sequenceNumberAttr) = if (writeRowTracking) {
@@ -266,7 +264,7 @@ case class MergeIntoPaimonTable(
     def processMergeActions(actions: Seq[MergeAction]): Seq[Seq[Expression]] = {
       val columnExprs = actions.map {
        case UpdateAction(_, assignments) =>
-          var exprs = assignments.map(_.value)
+          var exprs: Seq[Expression] = assignments.map(_.value)
          if (writeRowTracking) {
            exprs ++= Seq(rowIdAttr, Literal(null))
          }
@@ -280,7 +278,7 @@
            noopOutput
          }
        case InsertAction(_, assignments) =>
-          var exprs = assignments.map(_.value)
+          var exprs: Seq[Expression] = assignments.map(_.value)
          if (writeRowTracking) {
            exprs ++= Seq(rowIdAttr, sequenceNumberAttr)
          }
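The matched/not-matched markers are then derived without another resolution pass: the marker columns are looked up by name in the joined plan's output via the session's resolver and wrapped in IsNull. A standalone sketch of that lookup, assuming a joined plan built as above; the helper name notMatchedPredicates and the error handling are illustrative (the actual diff keeps the Option and calls .get):

// Illustrative sketch: find marker attributes on the joined plan by name and
// build the "row not matched" predicates directly, avoiding re-resolution.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Attribute, IsNull}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

def notMatchedPredicates(
    spark: SparkSession,
    joinedPlan: LogicalPlan,
    sourceRowCol: String,
    targetRowCol: String): (IsNull, IsNull) = {
  val resolver = spark.sessionState.conf.resolver
  def attribute(name: String): Attribute =
    joinedPlan.output
      .find(attr => resolver(name, attr.name))
      .getOrElse(throw new IllegalStateException(s"marker column $name not found"))
  // A target row had no source match if the source-side marker is null, and vice versa.
  (IsNull(attribute(sourceRowCol)), IsNull(attribute(targetRowCol)))
}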

paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTestBase.scala

Lines changed: 76 additions & 0 deletions
@@ -690,6 +690,82 @@ abstract class MergeIntoTableTestBase extends PaimonSparkTestBase with PaimonTab
     }
   }
 
+  test(s"Paimon MergeInto: update with coalesce referencing both source and target columns") {
+    withTable("source", "target") {
+
+      Seq((1, "guid_src_1"), (3, "guid_src_3"))
+        .toDF("a", "b")
+        .createOrReplaceTempView("source")
+
+      createTable("target", "a INT, b STRING", Seq("a"))
+      spark.sql("INSERT INTO target values (1, 'guid_tgt_1'), (2, 'guid_tgt_2')")
+
+      spark.sql(s"""
+                   |MERGE INTO target AS dest
+                   |USING source AS src
+                   |ON dest.a = src.a
+                   |WHEN MATCHED AND (nullif(cast(src.b as STRING), '') IS NOT NULL) THEN
+                   |UPDATE SET dest.b = COALESCE(nullif(cast(src.b as STRING), ''), dest.b)
+                   |WHEN NOT MATCHED THEN
+                   |INSERT (a, b) VALUES (src.a, src.b)
+                   |""".stripMargin)
+
+      checkAnswer(
+        spark.sql("SELECT * FROM target ORDER BY a"),
+        Row(1, "guid_src_1") :: Row(2, "guid_tgt_2") :: Row(3, "guid_src_3") :: Nil)
+    }
+  }
+
+  test(s"Paimon MergeInto: two paimon tables with coalesce referencing both source and target") {
+    withTable("source", "target") {
+
+      createTable("source", "a INT, b STRING", Seq("a"))
+      createTable("target", "a INT, b STRING", Seq("a"))
+      spark.sql("INSERT INTO source values (1, 'guid_src_1'), (3, 'guid_src_3')")
+      spark.sql("INSERT INTO target values (1, 'guid_tgt_1'), (2, 'guid_tgt_2')")
+
+      spark.sql(s"""
+                   |MERGE INTO target AS dest
+                   |USING source AS src
+                   |ON dest.a = src.a
+                   |WHEN MATCHED AND (nullif(cast(src.b as STRING), '') IS NOT NULL) THEN
+                   |UPDATE SET dest.b = COALESCE(nullif(cast(src.b as STRING), ''), dest.b)
+                   |WHEN NOT MATCHED THEN
+                   |INSERT (a, b) VALUES (src.a, src.b)
+                   |""".stripMargin)
+
+      checkAnswer(
+        spark.sql("SELECT * FROM target ORDER BY a"),
+        Row(1, "guid_src_1") :: Row(2, "guid_tgt_2") :: Row(3, "guid_src_3") :: Nil)
+    }
+  }
+
+  test(s"Paimon MergeInto: subquery source with coalesce referencing both source and target") {
+    withTable("source", "target") {
+
+      Seq((1, "guid_src_1"), (3, "guid_src_3"))
+        .toDF("a", "b")
+        .createOrReplaceTempView("source")
+
+      createTable("target", "a INT, b STRING", Seq("a"))
+      spark.sql("INSERT INTO target values (1, 'guid_tgt_1'), (2, 'guid_tgt_2')")
+
+      spark.sql(s"""
+                   |MERGE INTO target AS dest
+                   |USING (SELECT * FROM source) AS src
+                   |ON dest.a = src.a
+                   |WHEN MATCHED AND (nullif(cast(src.b as STRING), '') IS NOT NULL) THEN
+                   |UPDATE SET dest.b = COALESCE(nullif(cast(src.b as STRING), ''), dest.b)
+                   |WHEN NOT MATCHED THEN
+                   |INSERT (a, b) VALUES (src.a, src.b)
+                   |""".stripMargin)
+
+      checkAnswer(
+        spark.sql("SELECT * FROM target ORDER BY a"),
+        Row(1, "guid_src_1") :: Row(2, "guid_tgt_2") :: Row(3, "guid_src_3") :: Nil)
+    }
+  }
+
   test(s"Paimon MergeInto: merge into with varchar") {
     withTable("source", "target") {
       createTable("source", "a INT, b VARCHAR(32)", Seq("a"))
