2020package org .apache .spark .sql .catalyst .analysis
2121
2222import org .apache .spark .sql .AnalysisException
23+ import org .apache .spark .sql .catalyst .ProjectingInternalRow
2324import org .apache .spark .sql .catalyst .expressions .Alias
2425import org .apache .spark .sql .catalyst .expressions .Attribute
2526import org .apache .spark .sql .catalyst .expressions .AttributeReference
@@ -54,6 +55,7 @@ import org.apache.spark.sql.catalyst.plans.logical.ReplaceData
5455import org .apache .spark .sql .catalyst .plans .logical .UpdateAction
5556import org .apache .spark .sql .catalyst .plans .logical .WriteDelta
5657import org .apache .spark .sql .catalyst .util .RowDeltaUtils ._
58+ import org .apache .spark .sql .catalyst .util .WriteDeltaProjections
5759import org .apache .spark .sql .connector .expressions .FieldReference
5860import org .apache .spark .sql .connector .expressions .NamedReference
5961import org .apache .spark .sql .connector .iceberg .catalog .SupportsRowLevelOperations
@@ -62,6 +64,8 @@ import org.apache.spark.sql.connector.iceberg.write.SupportsDelta
6264import org .apache .spark .sql .connector .write .RowLevelOperationTable
6365import org .apache .spark .sql .execution .datasources .v2 .DataSourceV2Relation
6466import org .apache .spark .sql .types .IntegerType
67+ import org .apache .spark .sql .types .StructField
68+ import org .apache .spark .sql .types .StructType
6569
6670/**
6771 * Assigns a rewrite plan for v2 tables that support rewriting data to handle MERGE statements.
@@ -297,7 +301,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand {
297301
298302 // build a plan to write the row delta to the table
299303 val writeRelation = relation.copy(table = operationTable)
300- val projections = buildWriteDeltaProjections (mergeRows, rowAttrs, rowIdAttrs, metadataAttrs)
304+ val projections = buildMergeDeltaProjections (mergeRows, rowAttrs, rowIdAttrs, metadataAttrs)
301305 WriteDelta (writeRelation, mergeRows, relation, projections)
302306 }
303307
@@ -384,4 +388,55 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand {
/**
 * Resolves `ref` against `plan` as an [[AttributeReference]].
 *
 * Thin wrapper that pins the type argument of `ExtendedV2ExpressionUtils.resolveRef`.
 */
private def resolveAttrRef(ref: NamedReference, plan: LogicalPlan): AttributeReference =
  ExtendedV2ExpressionUtils.resolveRef[AttributeReference](ref, plan)
391+
/**
 * Builds the projections used to write a MERGE row delta.
 *
 * The row projection is derived from the action outputs that are not deletes, while the
 * row ID and metadata projections are derived from the action outputs that are not inserts.
 * The row and metadata projections are only defined when the corresponding attribute list
 * is non-empty; the row ID projection is always built.
 *
 * @param mergeRows the node whose output attributes and matched / not matched outputs are used
 * @param rowAttrs attributes to expose through the row projection
 * @param rowIdAttrs attributes to expose through the row ID projection
 * @param metadataAttrs attributes to expose through the metadata projection
 * @return the projections wrapped in [[WriteDeltaProjections]]
 */
private def buildMergeDeltaProjections(
    mergeRows: MergeRows,
    rowAttrs: Seq[Attribute],
    rowIdAttrs: Seq[Attribute],
    metadataAttrs: Seq[Attribute]): WriteDeltaProjections = {

  val mergeOutputAttrs = mergeRows.output
  val actionOutputs = mergeRows.matchedOutputs ++ mergeRows.notMatchedOutputs

  // the first expression of each action output is compared against the operation literal
  val deleteOperation = Literal(DELETE_OPERATION)
  val insertOperation = Literal(INSERT_OPERATION)
  val nonDeleteOutputs = actionOutputs.filter(actionOutput => !(actionOutput.head == deleteOperation))
  val nonInsertOutputs = actionOutputs.filter(actionOutput => !(actionOutput.head == insertOperation))

  // a projection that only exists when there is something to project
  def optionalProjection(
      relevantOutputs: Seq[Seq[Expression]],
      attrs: Seq[Attribute]): Option[ProjectingInternalRow] = {
    if (attrs.isEmpty) {
      None
    } else {
      Some(newLazyProjection(relevantOutputs, mergeOutputAttrs, attrs))
    }
  }

  WriteDeltaProjections(
    optionalProjection(nonDeleteOutputs, rowAttrs),
    newLazyProjection(nonInsertOutputs, mergeOutputAttrs, rowIdAttrs),
    optionalProjection(nonInsertOutputs, metadataAttrs))
}
420+
/**
 * Builds a lazy projection exposing `projectedAttrs` on top of rows shaped like `outputAttrs`.
 *
 * The projection is done by name, ignoring expr IDs.
 *
 * @param outputs action outputs relevant for this projection; used only to derive nullability
 * @param outputAttrs attributes describing the layout of the rows being projected
 * @param projectedAttrs attributes to project
 * @return a [[ProjectingInternalRow]] built from the projection schema and the resolved ordinals
 * @throws IllegalArgumentException if a projected attribute has no output attribute
 *                                  with the same name
 */
private def newLazyProjection(
    outputs: Seq[Seq[Expression]],
    outputAttrs: Seq[Attribute],
    projectedAttrs: Seq[Attribute]): ProjectingInternalRow = {

  val projectedOrdinals = projectedAttrs.map { attr =>
    val ordinal = outputAttrs.indexWhere(_.name == attr.name)
    // indexWhere yields -1 for a missing name; fail here with a descriptive message instead of
    // an opaque index error below or a negative ordinal leaking into the projection
    require(
      ordinal >= 0,
      s"Cannot find attribute ${attr.name} in output attributes: " +
        outputAttrs.map(_.name).mkString(", "))
    ordinal
  }

  val structFields = projectedAttrs.zip(projectedOrdinals).map { case (attr, ordinal) =>
    // output attr is nullable if at least one action may produce null for that attr
    // but row ID and metadata attrs are projected only in update/delete actions and
    // row attrs are projected only in insert/update actions
    // that's why the projection schema must rely only on relevant action outputs
    // instead of blindly inheriting the output attr nullability
    val nullable = outputs.exists(output => output(ordinal).nullable)
    StructField(attr.name, attr.dataType, nullable, attr.metadata)
  }
  val schema = StructType(structFields)

  ProjectingInternalRow(schema, projectedOrdinals)
}
387442}
0 commit comments