Commit d5411d2

Spark 3.2: Fix nullability in merge-on-read projections (#5917)
1 parent 3170e6e commit d5411d2

4 files changed: 95 additions & 55 deletions


spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala

Lines changed: 56 additions & 1 deletion
@@ -20,6 +20,7 @@
 package org.apache.spark.sql.catalyst.analysis

 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.ProjectingInternalRow
 import org.apache.spark.sql.catalyst.expressions.Alias
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
@@ -54,6 +55,7 @@ import org.apache.spark.sql.catalyst.plans.logical.ReplaceData
 import org.apache.spark.sql.catalyst.plans.logical.UpdateAction
 import org.apache.spark.sql.catalyst.plans.logical.WriteDelta
 import org.apache.spark.sql.catalyst.util.RowDeltaUtils._
+import org.apache.spark.sql.catalyst.util.WriteDeltaProjections
 import org.apache.spark.sql.connector.expressions.FieldReference
 import org.apache.spark.sql.connector.expressions.NamedReference
 import org.apache.spark.sql.connector.iceberg.catalog.SupportsRowLevelOperations
@@ -62,6 +64,8 @@ import org.apache.spark.sql.connector.iceberg.write.SupportsDelta
 import org.apache.spark.sql.connector.write.RowLevelOperationTable
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.StructField
+import org.apache.spark.sql.types.StructType

 /**
  * Assigns a rewrite plan for v2 tables that support rewriting data to handle MERGE statements.
@@ -297,7 +301,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand {

     // build a plan to write the row delta to the table
     val writeRelation = relation.copy(table = operationTable)
-    val projections = buildWriteDeltaProjections(mergeRows, rowAttrs, rowIdAttrs, metadataAttrs)
+    val projections = buildMergeDeltaProjections(mergeRows, rowAttrs, rowIdAttrs, metadataAttrs)
     WriteDelta(writeRelation, mergeRows, relation, projections)
   }

@@ -384,4 +388,55 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand {
   private def resolveAttrRef(ref: NamedReference, plan: LogicalPlan): AttributeReference = {
     ExtendedV2ExpressionUtils.resolveRef[AttributeReference](ref, plan)
   }
+
+  private def buildMergeDeltaProjections(
+      mergeRows: MergeRows,
+      rowAttrs: Seq[Attribute],
+      rowIdAttrs: Seq[Attribute],
+      metadataAttrs: Seq[Attribute]): WriteDeltaProjections = {
+
+    val outputAttrs = mergeRows.output
+
+    val outputs = mergeRows.matchedOutputs ++ mergeRows.notMatchedOutputs
+    val insertAndUpdateOutputs = outputs.filterNot(_.head == Literal(DELETE_OPERATION))
+    val updateAndDeleteOutputs = outputs.filterNot(_.head == Literal(INSERT_OPERATION))
+
+    val rowProjection = if (rowAttrs.nonEmpty) {
+      Some(newLazyProjection(insertAndUpdateOutputs, outputAttrs, rowAttrs))
+    } else {
+      None
+    }
+
+    val rowIdProjection = newLazyProjection(updateAndDeleteOutputs, outputAttrs, rowIdAttrs)
+
+    val metadataProjection = if (metadataAttrs.nonEmpty) {
+      Some(newLazyProjection(updateAndDeleteOutputs, outputAttrs, metadataAttrs))
+    } else {
+      None
+    }
+
+    WriteDeltaProjections(rowProjection, rowIdProjection, metadataProjection)
+  }
+
+  // the projection is done by name, ignoring expr IDs
+  private def newLazyProjection(
+      outputs: Seq[Seq[Expression]],
+      outputAttrs: Seq[Attribute],
+      projectedAttrs: Seq[Attribute]): ProjectingInternalRow = {
+
+    val projectedOrdinals = projectedAttrs.map(attr => outputAttrs.indexWhere(_.name == attr.name))
+
+    val structFields = projectedAttrs.zip(projectedOrdinals).map { case (attr, ordinal) =>
+      // output attr is nullable if at least one action may produce null for that attr
+      // but row ID and metadata attrs are projected only in update/delete actions and
+      // row attrs are projected only in insert/update actions
+      // that's why the projection schema must rely only on relevant action outputs
+      // instead of blindly inheriting the output attr nullability
+      val nullable = outputs.exists(output => output(ordinal).nullable)
+      StructField(attr.name, attr.dataType, nullable, attr.metadata)
+    }
+    val schema = StructType(structFields)
+
+    ProjectingInternalRow(schema, projectedOrdinals)
+  }
 }
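
The comment inside the new newLazyProjection carries the core idea of the fix: an attribute can look nullable in the combined merge output solely because an irrelevant action emits a null literal for it (per the comment removed from RewriteRowLevelCommand below, columns such as metadata are always null for freshly inserted records). Here is a minimal sketch of that rule in plain Scala, with no Spark dependencies; Expr and the sample outputs are hypothetical stand-ins for Catalyst expressions, not Iceberg code.

```scala
object NullabilitySketch extends App {
  // toy stand-in for a Catalyst expression: only nullability matters here
  case class Expr(nullable: Boolean)

  // each merge action emits an operation tag plus one expression per output
  // column; column 0 plays the role of a NOT NULL row ID column
  val outputs: Seq[(String, Seq[Expr])] = Seq(
    ("INSERT", Seq(Expr(nullable = true))),  // no existing row ID yet -> null literal
    ("UPDATE", Seq(Expr(nullable = false))), // real, non-null row ID
    ("DELETE", Seq(Expr(nullable = false)))  // real, non-null row ID
  )

  // naive rule: nullable if any action may produce null at that ordinal;
  // this wrongly marks the NOT NULL row ID column as nullable
  val naive = outputs.exists { case (_, cols) => cols(0).nullable }

  // fixed rule: row IDs are projected only for UPDATE/DELETE actions,
  // so only those outputs should drive the projection schema
  val fixed = outputs
    .filterNot { case (op, _) => op == "INSERT" }
    .exists { case (_, cols) => cols(0).nullable }

  println(s"naive nullability: $naive, per-action nullability: $fixed") // true, false
}
```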

spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala

Lines changed: 7 additions & 19 deletions
@@ -37,7 +37,6 @@ import org.apache.spark.sql.connector.write.RowLevelOperationTable
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
-import scala.collection.compat.immutable.ArraySeq
 import scala.collection.mutable

 trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
@@ -79,20 +78,15 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
       metadataAttrs: Seq[Attribute]): WriteDeltaProjections = {

     val rowProjection = if (rowAttrs.nonEmpty) {
-      Some(newLazyProjection(plan, rowAttrs, usePlanTypes = true))
+      Some(newLazyProjection(plan, rowAttrs))
     } else {
       None
     }

-    // in MERGE, the plan may contain both delete and insert records that may affect
-    // the nullability of metadata columns (e.g. metadata columns for new records are always null)
-    // since metadata columns are never passed with new records to insert,
-    // use the actual metadata column types instead of the ones present in the plan
-
-    val rowIdProjection = newLazyProjection(plan, rowIdAttrs, usePlanTypes = false)
+    val rowIdProjection = newLazyProjection(plan, rowIdAttrs)

     val metadataProjection = if (metadataAttrs.nonEmpty) {
-      Some(newLazyProjection(plan, metadataAttrs, usePlanTypes = false))
+      Some(newLazyProjection(plan, metadataAttrs))
     } else {
       None
     }
@@ -103,17 +97,11 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
   // the projection is done by name, ignoring expr IDs
   private def newLazyProjection(
       plan: LogicalPlan,
-      attrs: Seq[Attribute],
-      usePlanTypes: Boolean): ProjectingInternalRow = {
+      projectedAttrs: Seq[Attribute]): ProjectingInternalRow = {

-    val colOrdinals = attrs.map(attr => plan.output.indexWhere(_.name == attr.name))
-    val schema = if (usePlanTypes) {
-      val planAttrs = colOrdinals.map(plan.output(_))
-      StructType.fromAttributes(planAttrs)
-    } else {
-      StructType.fromAttributes(attrs)
-    }
-    ProjectingInternalRow(schema, colOrdinals)
+    val projectedOrdinals = projectedAttrs.map(attr => plan.output.indexWhere(_.name == attr.name))
+    val schema = StructType.fromAttributes(projectedOrdinals.map(plan.output(_)))
+    ProjectingInternalRow(schema, projectedOrdinals)
   }

   protected def resolveRequiredMetadataAttrs(
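
With the usePlanTypes flag gone, the simplified newLazyProjection always derives the projection schema from the plan's own output attributes, resolved by name. A hedged sketch of that by-name ordinal resolution, reduced to plain Scala collections; Attr is a hypothetical stand-in for a Catalyst Attribute, and the sample columns are illustrative only.

```scala
object ByNameProjectionSketch extends App {
  // hypothetical stand-in for a Catalyst Attribute: a name plus type info
  case class Attr(name: String, dataType: String, nullable: Boolean)

  val planOutput = Seq(
    Attr("id", "int", nullable = false),
    Attr("dep", "string", nullable = true),
    Attr("_file", "string", nullable = true))

  // attributes to project, e.g. a metadata column followed by a data column
  val projectedAttrs = Seq(planOutput(2), planOutput(0))

  // resolve ordinals by name, ignoring expression IDs
  val projectedOrdinals = projectedAttrs.map(attr => planOutput.indexWhere(_.name == attr.name))

  // the schema is now always taken from the plan's attributes at those ordinals
  val schema = projectedOrdinals.map(planOutput(_))

  println(projectedOrdinals)  // List(2, 0)
  println(schema.map(_.name)) // List(_file, id)
}
```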

spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java

Lines changed: 0 additions & 35 deletions
@@ -20,9 +20,7 @@

 import java.util.Map;
 import org.apache.iceberg.TableProperties;
-import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
-import org.junit.Test;

 public class TestCopyOnWriteMerge extends TestMerge {

@@ -40,37 +38,4 @@ public TestCopyOnWriteMerge(
   protected Map<String, String> extraTableProperties() {
     return ImmutableMap.of(TableProperties.MERGE_MODE, "copy-on-write");
   }
-
-  @Test
-  public void testMergeWithTableWithNonNullableColumn() {
-    createAndInitTable(
-        "id INT NOT NULL, dep STRING",
-        "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
-
-    createOrReplaceView(
-        "source",
-        "id INT NOT NULL, dep STRING",
-        "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n"
-            + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n"
-            + "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
-
-    sql(
-        "MERGE INTO %s AS t USING source AS s "
-            + "ON t.id == s.id "
-            + "WHEN MATCHED AND t.id = 1 THEN "
-            + "  UPDATE SET * "
-            + "WHEN MATCHED AND t.id = 6 THEN "
-            + "  DELETE "
-            + "WHEN NOT MATCHED AND s.id = 2 THEN "
-            + "  INSERT *",
-        tableName);
-
-    ImmutableList<Object[]> expectedRows =
-        ImmutableList.of(
-            row(1, "emp-id-1"), // updated
-            row(2, "emp-id-2") // new
-            );
-    assertEquals(
-        "Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
-  }
 }

spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java

Lines changed: 32 additions & 0 deletions
@@ -1687,6 +1687,38 @@ public void testMergeShouldResolveWhenThereAreNoUnresolvedExpressionsOrColumns()
         "Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
   }

+  @Test
+  public void testMergeWithTableWithNonNullableColumn() {
+    createAndInitTable(
+        "id INT NOT NULL, dep STRING",
+        "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+    createOrReplaceView(
+        "source",
+        "id INT NOT NULL, dep STRING",
+        "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n"
+            + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n"
+            + "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+    sql(
+        "MERGE INTO %s AS t USING source AS s "
+            + "ON t.id == s.id "
+            + "WHEN MATCHED AND t.id = 1 THEN "
+            + "  UPDATE SET * "
+            + "WHEN MATCHED AND t.id = 6 THEN "
+            + "  DELETE "
+            + "WHEN NOT MATCHED AND s.id = 2 THEN "
+            + "  INSERT *",
+        tableName);
+
+    ImmutableList<Object[]> expectedRows =
+        ImmutableList.of(
+            row(1, "emp-id-1"), // updated
+            row(2, "emp-id-2")); // new
+    assertEquals(
+        "Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
+  }
+
   @Test
   public void testMergeWithNonExistingColumns() {
     createAndInitTable("id INT, c STRUCT<n1:INT,n2:STRUCT<dn1:INT,dn2:INT>>");