@@ -31,9 +31,10 @@ import org.apache.spark.sql.{Dataset, Row, SparkSession}
3131import org .apache .spark .sql .PaimonUtils ._
3232import org .apache .spark .sql .catalyst .InternalRow
3333import org .apache .spark .sql .catalyst .encoders .ExpressionEncoder
34- import org .apache .spark .sql .catalyst .expressions .{Alias , Attribute , BasePredicate , Expression , Literal , UnsafeProjection }
34+ import org .apache .spark .sql .catalyst .expressions .{Alias , Attribute , BasePredicate , Expression , IsNull , Literal , UnsafeProjection }
3535import org .apache .spark .sql .catalyst .expressions .Literal .TrueLiteral
3636import org .apache .spark .sql .catalyst .expressions .codegen .GeneratePredicate
37+ import org .apache .spark .sql .catalyst .plans .FullOuter
3738import org .apache .spark .sql .catalyst .plans .logical ._
3839import org .apache .spark .sql .execution .datasources .v2 .DataSourceV2Relation
3940import org .apache .spark .sql .functions .{col , lit , monotonically_increasing_id , sum }
@@ -222,29 +223,26 @@ case class MergeIntoPaimonTable(
222223 deletionVectorEnabled : Boolean = false ,
223224 extraMetadataCols : Seq [PaimonMetadataColumn ] = Seq .empty,
224225 writeRowTracking : Boolean = false ): Dataset [Row ] = {
225- val targetDS = targetDataset
226- .withColumn(TARGET_ROW_COL , lit(true ))
227-
228- val sourceDS = createDataset(sparkSession, sourceTable)
229- .withColumn(SOURCE_ROW_COL , lit(true ))
230-
231- val joinedDS = sourceDS.join(targetDS, toColumn(mergeCondition), " fullOuter" )
226+ val targetPlan = targetDataset.queryExecution.analyzed
227+ val targetProject =
228+ Project (targetPlan.output :+ Alias (Literal (true ), TARGET_ROW_COL )(), targetPlan)
229+ val sourceProject =
230+ Project (sourceTable.output :+ Alias (Literal (true ), SOURCE_ROW_COL )(), sourceTable)
231+ val joinNode =
232+ Join (sourceProject, targetProject, FullOuter , Some (mergeCondition), JoinHint .NONE )
233+ val joinedDS = createDataset(sparkSession, joinNode)
232234 val joinedPlan = joinedDS.queryExecution.analyzed
233235
234- def resolveOnJoinedPlan (exprs : Seq [Expression ]): Seq [Expression ] = {
235- resolveExpressions(sparkSession)(exprs, joinedPlan)
236- }
236+ val resolver = sparkSession.sessionState.conf.resolver
237+ def attribute (name : String ) = joinedPlan.output.find(attr => resolver(name, attr.name))
237238
238- val targetRowNotMatched = resolveOnJoinedPlan(
239- Seq (toExpression(sparkSession, col(SOURCE_ROW_COL ).isNull))).head
240- val sourceRowNotMatched = resolveOnJoinedPlan(
241- Seq (toExpression(sparkSession, col(TARGET_ROW_COL ).isNull))).head
239+ val targetRowNotMatched = IsNull (attribute(SOURCE_ROW_COL ).get)
240+ val sourceRowNotMatched = IsNull (attribute(TARGET_ROW_COL ).get)
242241 val matchedExprs = matchedActions.map(_.condition.getOrElse(TrueLiteral ))
243242 val notMatchedExprs = notMatchedActions.map(_.condition.getOrElse(TrueLiteral ))
244- val notMatchedBySourceExprs = notMatchedBySourceActions.map(_.condition.getOrElse(TrueLiteral ))
243+ val notMatchedBySourceExprs =
244+ notMatchedBySourceActions.map(_.condition.getOrElse(TrueLiteral ))
245245
246- val resolver = sparkSession.sessionState.conf.resolver
247- def attribute (name : String ) = joinedPlan.output.find(attr => resolver(name, attr.name))
248246 val extraMetadataAttributes =
249247 extraMetadataCols.flatMap(metadataCol => attribute(metadataCol.name))
250248 val (rowIdAttr, sequenceNumberAttr) = if (writeRowTracking) {
@@ -266,7 +264,7 @@ case class MergeIntoPaimonTable(
266264 def processMergeActions (actions : Seq [MergeAction ]): Seq [Seq [Expression ]] = {
267265 val columnExprs = actions.map {
268266 case UpdateAction (_, assignments) =>
269- var exprs = assignments.map(_.value)
267+ var exprs : Seq [ Expression ] = assignments.map(_.value)
270268 if (writeRowTracking) {
271269 exprs ++= Seq (rowIdAttr, Literal (null ))
272270 }
@@ -280,7 +278,7 @@ case class MergeIntoPaimonTable(
280278 noopOutput
281279 }
282280 case InsertAction (_, assignments) =>
283- var exprs = assignments.map(_.value)
281+ var exprs : Seq [ Expression ] = assignments.map(_.value)
284282 if (writeRowTracking) {
285283 exprs ++= Seq (rowIdAttr, sequenceNumberAttr)
286284 }
0 commit comments