Skip to content

Commit f155856

Browse files
author
Bogdan Raducanu
committed
ported from master
1 parent 82fcc13 commit f155856

File tree

3 files changed

+24
-4
lines changed

3 files changed

+24
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,6 @@ class CodegenContext {
555555
addNewFunction(compareFunc, funcCode)
556556
s"this.$compareFunc($c1, $c2)"
557557
case schema: StructType =>
558-
INPUT_ROW = "i"
559558
val comparisons = GenerateOrdering.genComparisons(this, schema)
560559
val compareFunc = freshName("compareStruct")
561560
val funcCode: String =
@@ -566,7 +565,6 @@ class CodegenContext {
566565
if (a instanceof UnsafeRow && b instanceof UnsafeRow && a.equals(b)) {
567566
return 0;
568567
}
569-
InternalRow i = null;
570568
$comparisons
571569
return 0;
572570
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,12 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
7373
*/
7474
def genComparisons(ctx: CodegenContext, ordering: Seq[SortOrder]): String = {
7575
val comparisons = ordering.map { order =>
76+
val oldCurrentVars = ctx.currentVars
77+
ctx.INPUT_ROW = "i"
78+
// to use INPUT_ROW we must make sure currentVars is null
79+
ctx.currentVars = null
7680
val eval = order.child.genCode(ctx)
81+
ctx.currentVars = oldCurrentVars
7782
val asc = order.isAscending
7883
val isNullA = ctx.freshName("isNullA")
7984
val primitiveA = ctx.freshName("primitiveA")
@@ -119,7 +124,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
119124
"""
120125
}
121126

122-
ctx.splitExpressions(
127+
val code = ctx.splitExpressions(
123128
expressions = comparisons,
124129
funcName = "compare",
125130
arguments = Seq(("InternalRow", "a"), ("InternalRow", "b")),
@@ -142,6 +147,12 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
142147
"""
143148
}.mkString
144149
})
150+
// make sure INPUT_ROW is declared even if splitExpressions
151+
// returns an inlined block
152+
s"""
153+
|InternalRow ${ctx.INPUT_ROW} = null;
154+
|$code
155+
""".stripMargin
145156
}
146157

147158
protected def create(ordering: Seq[SortOrder]): BaseOrdering = {
@@ -165,7 +176,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
165176
${ctx.declareAddedFunctions()}
166177

167178
public int compare(InternalRow a, InternalRow b) {
168-
InternalRow ${ctx.INPUT_ROW} = null; // Holds current row being evaluated.
169179
$comparisons
170180
return 0;
171181
}

sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,4 +113,16 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
113113
p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
114114
assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
115115
}
116+
117+
test("SPARK-19512 codegen for comparing structs is incorrect") {
118+
// this would raise CompileException before the fix
119+
spark.range(10)
120+
.selectExpr("named_struct('a', id) as col1", "named_struct('a', id+2) as col2")
121+
.filter("col1 = col2").count()
122+
// this would raise java.lang.IndexOutOfBoundsException before the fix
123+
spark.range(10)
124+
.selectExpr("named_struct('a', id, 'b', id) as col1",
125+
"named_struct('a',id+2, 'b',id+2) as col2")
126+
.filter("col1 = col2").count()
127+
}
116128
}

0 commit comments

Comments
 (0)