
Commit e2ce0ca

JoshRosen authored and hvanhovell committed
[SPARK-17618] Fix invalid comparisons between UnsafeRow and other row formats
## What changes were proposed in this pull request?

This patch addresses a correctness bug in Spark 1.6.x in which `coalesce()` declares that it can process `UnsafeRow`s but mis-declares that it always outputs safe rows. If an `UnsafeRow` and another row type are compared for equality, we get spurious `false` results, leading to wrong answers in operators that perform whole-row comparison (such as `distinct()` or `except()`). An example of a query impacted by this bug is given in the [JIRA ticket](https://issues.apache.org/jira/browse/SPARK-17618).

The problem is that the validity of our row-format conversion rules depends on operators that handle `UnsafeRow`s (signalled by overriding `canProcessUnsafeRows`) correctly reporting their output row format (done by overriding `outputsUnsafeRows`). In apache#9024, we overrode `canProcessUnsafeRows` but forgot to override `outputsUnsafeRows`, leading to the incorrect `equals()` comparison. Our interface design is flawed because correctness depends on operators correctly overriding multiple methods; this problem could have been prevented by a design that coupled the row-format methods and metadata into a single method or class, so that all three methods had to be overridden at the same time.

This patch addresses the issue by adding the missing `outputsUnsafeRows` overrides. To ensure that bugs in this logic are uncovered sooner, I have also modified `UnsafeRow.equals()` to throw an `IllegalArgumentException` if it is called with an `InternalRow` that is not an `UnsafeRow`.

## How was this patch tested?

I believe that the stronger misuse-checking in `UnsafeRow.equals()` is sufficient to detect and prevent this class of bug.

Author: Josh Rosen <[email protected]>

Closes apache#15185 from JoshRosen/SPARK-17618.
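To make the failure mode concrete, here is a minimal self-contained sketch (the `Row`/`BinaryRow`/`GenericRow` types are stand-ins, not Spark's actual classes) of how an equality method that only recognizes its own binary format yields spurious `false` results:

```scala
// Stand-in types illustrating the bug; not Spark code.
sealed trait Row { def values: Seq[Any] }

final case class GenericRow(values: Seq[Any]) extends Row

final class BinaryRow(val values: Seq[Any]) extends Row {
  // Mirrors the pre-fix UnsafeRow.equals(): compares against another
  // BinaryRow, silently returns false for any other Row implementation.
  override def equals(other: Any): Boolean = other match {
    case b: BinaryRow => b.values == values
    case _            => false // spurious false for logically equal rows
  }
  override def hashCode: Int = values.hashCode
}

object SpuriousFalseDemo extends App {
  val unsafe: Row = new BinaryRow(Seq(1, "x"))
  val safe: Row   = GenericRow(Seq(1, "x"))
  // Logically equal rows compare unequal, so operators relying on
  // whole-row equals() (distinct(), except()) can return wrong answers.
  println(unsafe == safe) // false
}
```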
1 parent 7aded55 commit e2ce0ca

File tree

3 files changed: +12 -3 lines changed


sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 6 additions & 1 deletion
```diff
@@ -30,6 +30,7 @@
 import java.util.HashSet;
 import java.util.Set;
 
+import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.types.ArrayType;
 import org.apache.spark.sql.types.BinaryType;
 import org.apache.spark.sql.types.BooleanType;
@@ -610,8 +611,12 @@ public boolean equals(Object other) {
       return (sizeInBytes == o.sizeInBytes) &&
         ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset,
           sizeInBytes);
+    } else if (other == null || !(other instanceof InternalRow)) {
+      return false;
+    } else {
+      throw new IllegalArgumentException(
+        "Cannot compare UnsafeRow to " + other.getClass().getName());
     }
-    return false;
   }
 
   /**
```
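The patched `equals()` now distinguishes three cases: another `UnsafeRow` (compare bytes), `null` or a non-row object (plain `false`), and any other `InternalRow` (fail fast, since that indicates a missed format conversion). A hedged Scala restatement with stand-in types (`InternalRowLike`/`UnsafeRowLike` are illustrative, not Spark's classes):

```scala
import java.util.Arrays

// Stand-in types; the real classes are Spark's InternalRow and UnsafeRow.
trait InternalRowLike
final class UnsafeRowLike(val bytes: Array[Byte]) extends InternalRowLike {
  override def equals(other: Any): Boolean = other match {
    case o: UnsafeRowLike => Arrays.equals(bytes, o.bytes) // byte-for-byte
    case _: InternalRowLike =>
      // A non-binary row reaching here means a row-format conversion was
      // skipped upstream; fail loudly instead of reporting "not equal".
      throw new IllegalArgumentException(
        s"Cannot compare UnsafeRow to ${other.getClass.getName}")
    case _ => false // null and arbitrary non-row objects are simply unequal
  }
  override def hashCode: Int = Arrays.hashCode(bytes)
}
```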

sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala

Lines changed: 1 addition & 0 deletions
```diff
@@ -96,6 +96,7 @@ case class Window(
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 
   override def canProcessUnsafeRows: Boolean = true
+  override def outputsUnsafeRows: Boolean = false
 
   /**
    * Create a bound ordering object for a given frame type and offset. A bound ordering object is
```
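`Window` can consume unsafe rows but, per the now-explicit `false`, does not emit them, so the planner must insert a conversion before any unsafe-only consumer. The sketch below is a simplified stand-in for that converter-insertion rule (modeled loosely on Spark 1.6's `EnsureRowFormats`; all types are stubs, not Spark code):

```scala
// Stub plan nodes carrying the three row-format flags.
trait Plan {
  def canProcessSafeRows: Boolean   = true
  def canProcessUnsafeRows: Boolean = false
  def outputsUnsafeRows: Boolean    = false
}
case class ConvertToUnsafe(child: Plan) extends Plan {
  override def outputsUnsafeRows: Boolean = true
}
case class ConvertToSafe(child: Plan) extends Plan

// Insert a converter only when the child's declared output format does not
// match what the parent can process.
def ensureFormat(parent: Plan, child: Plan): Plan =
  if (child.outputsUnsafeRows && !parent.canProcessUnsafeRows) ConvertToSafe(child)
  else if (!child.outputsUnsafeRows && !parent.canProcessSafeRows) ConvertToUnsafe(child)
  else child // formats already agree; no conversion inserted

// If an operator emits safe rows but reports outputsUnsafeRows = true, no
// converter is inserted and mixed formats reach whole-row comparisons.
```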

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -251,6 +251,7 @@ case class Coalesce(numPartitions: Int, child: SparkPlan) extends UnaryNode {
   }
 
   override def canProcessUnsafeRows: Boolean = true
+  override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows
 }
 
 /**
@@ -319,17 +320,19 @@ case class AppendColumns[T, U](
   // We are using an unsafe combiner.
   override def canProcessSafeRows: Boolean = false
   override def canProcessUnsafeRows: Boolean = true
+  override def outputsUnsafeRows: Boolean = true
 
   override def output: Seq[Attribute] = child.output ++ newColumns
 
   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { iter =>
       val tBoundEncoder = tEncoder.bind(child.output)
       val combiner = GenerateUnsafeRowJoiner.create(tEncoder.schema, uEncoder.schema)
-      iter.map { row =>
+      val unsafeRows: Iterator[UnsafeRow] = iter.map { row =>
         val newColumns = uEncoder.toRow(func(tBoundEncoder.fromRow(row)))
-        combiner.join(row.asInstanceOf[UnsafeRow], newColumns.asInstanceOf[UnsafeRow]): InternalRow
+        combiner.join(row.asInstanceOf[UnsafeRow], newColumns.asInstanceOf[UnsafeRow])
       }
+      unsafeRows
     }
   }
 }
```
