Commit 2ed6e7b

vicennial authored and viirya committed
[SPARK-36677][SQL] NestedColumnAliasing should not push down aggregate functions into projections
### What changes were proposed in this pull request?

This PR filters out `ExtractValue`s that contain any aggregate function in the `NestedColumnAliasing` rule, to prevent cases where aggregations are pushed down into projections.

### Why are the changes needed?

To handle a missed corner case in `NestedColumnAliasing` that can cause users to encounter a runtime exception.

Consider the following schema:
```
root
 |-- a: struct (nullable = true)
 |    |-- c: struct (nullable = true)
 |    |    |-- e: string (nullable = true)
 |    |-- d: integer (nullable = true)
 |-- b: string (nullable = true)
```
and the query:
`SELECT MAX(a).c.e FROM (SELECT a, b FROM test_aggregates) GROUP BY b`

Executing the query before this PR results in the error:
```
java.lang.UnsupportedOperationException: Cannot generate code for expression: max(input[0, struct<c:struct<e:string>,d:int>, true])
  at org.apache.spark.sql.errors.QueryExecutionErrors$.cannotGenerateCodeForExpressionError(QueryExecutionErrors.scala:83)
  at org.apache.spark.sql.catalyst.expressions.Unevaluable.doGenCode(Expression.scala:312)
  at org.apache.spark.sql.catalyst.expressions.Unevaluable.doGenCode$(Expression.scala:311)
  at org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression.doGenCode(interfaces.scala:99)
  ...
```
The optimized plan before this PR is:
```
'Aggregate [b#1], [_extract_e#5 AS max(a).c.e#3]
+- 'Project [max(a#0).c.e AS _extract_e#5, b#1]
   +- Relation default.test_aggregates[a#0,b#1] parquet
```

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

A new unit test in `NestedColumnAliasingSuite`. The test consists of the repro mentioned earlier. The produced optimized plan is checked for equivalence with a plan of the form:
```
Aggregate [b#452], [max(a#451).c.e AS max('a)[c][e]#456]
+- LocalRelation <empty>, [a#451, b#452]
```

Closes #33921 from vicennial/spark-36677.
Authored-by: Venkata Sai Akhil Gudesa <[email protected]> Signed-off-by: Liang-Chi Hsieh <[email protected]>
1 parent 5a0ae69 commit 2ed6e7b
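
The check this commit adds can be illustrated without Spark. The following is a minimal sketch on a toy expression tree: `Expr`, `Attr`, `Max`, `GetField`, and `AliasingSketch` are hypothetical stand-ins introduced here for illustration, not Catalyst classes, and `find` only approximates the tree search the rule relies on.

```scala
// Toy model of an expression tree. These types are hypothetical stand-ins,
// not Spark's Catalyst classes.
sealed trait Expr {
  def children: Seq[Expr]

  // Pre-order search for the first node satisfying `p`.
  def find(p: Expr => Boolean): Option[Expr] =
    if (p(this)) Some(this)
    else children.foldLeft(Option.empty[Expr])((acc, c) => acc.orElse(c.find(p)))
}

final case class Attr(name: String) extends Expr {
  val children: Seq[Expr] = Nil
}

// Stands in for an aggregate function such as MAX.
final case class Max(child: Expr) extends Expr {
  val children: Seq[Expr] = Seq(child)
}

// Stands in for a nested-field extraction such as `a.c`.
final case class GetField(child: Expr, field: String) extends Expr {
  val children: Seq[Expr] = Seq(child)
}

object AliasingSketch {
  // Same shape as the check in the diff: does the extraction wrap an aggregate?
  def containsAggregate(e: Expr): Boolean =
    e.find(_.isInstanceOf[Max]).isDefined

  // Deduplicate, then drop extractions that must stay above the aggregation.
  def aliasable(extractions: Seq[Expr]): Seq[Expr] =
    extractions.distinct.filterNot(containsAggregate)
}
```

In this model `GetField(GetField(Max(Attr("a")), "c"), "e")` plays the role of `MAX(a).c.e`; it is filtered out, whereas a plain `a.c.e` extraction remains a candidate for aliasing.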

2 files changed: +40 -2 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala

Lines changed: 13 additions & 2 deletions
@@ -21,6 +21,7 @@ import scala.collection
 import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -258,6 +259,13 @@ object NestedColumnAliasing {
       .filter(!_.references.subsetOf(exclusiveAttrSet))
       .groupBy(_.references.head.canonicalized.asInstanceOf[Attribute])
       .flatMap { case (attr: Attribute, nestedFields: collection.Seq[ExtractValue]) =>
+
+        // Check if `ExtractValue` expressions contain any aggregate functions in their tree. Those
+        // that do should not have an alias generated as it can lead to pushing the aggregate down
+        // into a projection.
+        def containsAggregateFunction(ev: ExtractValue): Boolean =
+          ev.find(_.isInstanceOf[AggregateFunction]).isDefined
+
         // Remove redundant [[ExtractValue]]s if they share the same parent nest field.
         // For example, when `a.b` and `a.b.c` are in project list, we only need to alias `a.b`.
         // Because `a.b` requires all of the inner fields of `b`, we cannot prune `a.b.c`.
@@ -268,15 +276,18 @@ object NestedColumnAliasing {
             val child = e.children.head
             nestedFields.forall(f => child.find(_.semanticEquals(f)).isEmpty)
           case _ => true
-        }.distinct
+        }
+          .distinct
+          // Discard [[ExtractValue]]s that contain aggregate functions.
+          .filterNot(containsAggregateFunction)
 
         // If all nested fields of `attr` are used, we don't need to introduce new aliases.
         // By default, the [[ColumnPruning]] rule uses `attr` already.
         // Note that we need to remove cosmetic variations first, so we only count a
         // nested field once.
         val numUsedNestedFields = dedupNestedFields.map(_.canonicalized).distinct
           .map { nestedField => totalFieldNum(nestedField.dataType) }.sum
-        if (numUsedNestedFields < totalFieldNum(attr.dataType)) {
+        if (dedupNestedFields.nonEmpty && numUsedNestedFields < totalFieldNum(attr.dataType)) {
           Some((attr, dedupNestedFields.toSeq))
         } else {
           None
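
The extra `dedupNestedFields.nonEmpty` condition in the hunk above appears to matter once the new filter can remove every extraction for an attribute: the sum over an empty collection is 0, which is always below the attribute's total field count. A minimal sketch of that arithmetic follows; `GuardSketch`, its method names, and the idea of passing per-extraction field counts as plain integers are assumptions made for illustration and are not part of the Spark code.

```scala
object GuardSketch {
  // Hypothetical simplification of the final decision in the flatMap body.
  // `usedFieldCounts`: field count of each extraction that survived filtering.
  // `totalFields`: field count of the whole attribute.

  // Condition before this commit: an empty input still passes, because 0 < totalFields.
  def aliasWithoutGuard(usedFieldCounts: Seq[Int], totalFields: Int): Boolean =
    usedFieldCounts.sum < totalFields

  // Condition after this commit: an empty input never produces an alias entry.
  def aliasWithGuard(usedFieldCounts: Seq[Int], totalFields: Int): Boolean =
    usedFieldCounts.nonEmpty && usedFieldCounts.sum < totalFields
}
```

With an empty list and a total of 2, the first condition holds and the second does not, which matches the intent of returning `None` when nothing is left to alias.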

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala

Lines changed: 27 additions & 0 deletions
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.catalyst.SchemaPruningTest
+import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
@@ -763,6 +764,32 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
       $"_extract_search_params.col2".as("col2")).analyze
     comparePlans(optimized, query)
   }
+
+  test("SPARK-36677: NestedColumnAliasing should not push down aggregate functions into " +
+    "projections") {
+    val nestedRelation = LocalRelation(
+      'a.struct(
+        'c.struct(
+          'e.string),
+        'd.string),
+      'b.string)
+
+    val plan = nestedRelation
+      .select($"a", $"b")
+      .groupBy($"b")(max($"a").getField("c").getField("e"))
+      .analyze
+
+    val optimized = Optimize.execute(plan)
+
+    // The plan should not contain aggregation functions inside the projection
+    SimpleAnalyzer.checkAnalysis(optimized)
+
+    val expected = nestedRelation
+      .groupBy($"b")(max($"a").getField("c").getField("e"))
+      .analyze
+
+    comparePlans(optimized, expected)
+  }
 }
 
 object NestedColumnAliasingSuite {
