
Commit 08cef63

Support aliases in CUBE/ROLLUP/GROUPING SETS
1 parent 0ef16bd commit 08cef63

4 files changed: 74 additions & 10 deletions
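What this enables, as a rough spark-shell sketch (not part of the commit): with alias resolution in GROUP BY enabled, a SELECT-list alias can now be referenced inside CUBE/ROLLUP/GROUPING SETS. The session `spark`, the `testData` table with integer columns `a` and `b`, and the config key `spark.sql.groupByAliases` (assumed to be the flag read as `conf.groupByAliases` in the rule below) are assumptions here, not defined by this diff.

    // Assumed: a SparkSession `spark` and a registered `testData` table (a: int, b: int).
    spark.conf.set("spark.sql.groupByAliases", "true")  // assumed key behind conf.groupByAliases
    // The alias k1 = a + b is resolved inside CUBE by the new rule.
    spark.sql(
      "SELECT a + b AS k1, b AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2)"
    ).show()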

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 19 additions & 7 deletions
@@ -1003,18 +1003,30 @@ class Analyzer(
    */
   object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] {
 
+    // This is a strict check though, we put this to apply the rule only in alias expressions
+    private def notResolvableByChild(attrName: String, child: LogicalPlan): Boolean =
+      !child.output.exists(a => resolver(a.name, attrName))
+
     override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
       case agg @ Aggregate(groups, aggs, child)
           if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
-            groups.exists(_.isInstanceOf[UnresolvedAttribute]) =>
-        // This is a strict check though, we put this to apply the rule only in alias expressions
-        def notResolvableByChild(attrName: String): Boolean =
-          !child.output.exists(a => resolver(a.name, attrName))
-        agg.copy(groupingExpressions = groups.map {
-          case u: UnresolvedAttribute if notResolvableByChild(u.name) =>
+            groups.exists(!_.resolved) =>
+        agg.copy(groupingExpressions = groups.map { _.transform {
+          case u: UnresolvedAttribute if notResolvableByChild(u.name, child) =>
+            aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u)
+          }
+        })
+
+      case gs @ GroupingSets(selectedGroups, groups, child, aggs)
+          if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
+            (selectedGroups :+ groups).exists(_.exists(_.isInstanceOf[UnresolvedAttribute])) =>
+        def mayResolveAttrByAggregateExprs(exprs: Seq[Expression]): Seq[Expression] = exprs.map {
+          case u: UnresolvedAttribute if notResolvableByChild(u.name, child) =>
             aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u)
           case e => e
-        })
+        }
+        gs.copy(selectedGroupByExprs = selectedGroups.map(mayResolveAttrByAggregateExprs),
+          groupByExprs = mayResolveAttrByAggregateExprs(groups))
     }
   }
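In words: the alias check is factored out into notResolvableByChild and applied both in the Aggregate case and in the new GroupingSets case, so a grouping key that the child relation cannot resolve itself, but that matches a SELECT-list alias, is replaced by the aliased expression. A self-contained sketch of that substitution, modelling keys and aliases as plain strings rather than Spark's expression classes (all names below are illustrative, not Spark APIs):

    object AliasSubstitutionSketch {
      // Grouping keys, child columns, and aliases are modelled as plain strings.
      def resolveGroupingKeys(
          groupingKeys: Seq[String],     // e.g. the keys listed in CUBE(...)
          childColumns: Set[String],     // columns the child relation can resolve itself
          aliases: Map[String, String]   // SELECT-list alias -> underlying expression
      ): Seq[String] =
        groupingKeys.map { key =>
          // Mirrors notResolvableByChild: only keys the child cannot resolve are
          // looked up among the aliases; anything else is left untouched.
          if (childColumns.contains(key)) key else aliases.getOrElse(key, key)
        }

      def main(args: Array[String]): Unit = {
        val resolved = resolveGroupingKeys(
          groupingKeys = Seq("k1", "k2"),
          childColumns = Set("a", "b"),
          aliases = Map("k1" -> "a + b", "k2" -> "b"))
        println(resolved)  // List(a + b, b)
      }
    }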

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 1 addition & 1 deletion
@@ -704,7 +704,7 @@ case class Expand(
  * We will transform GROUPING SETS into logical plan Aggregate(.., Expand) in Analyzer
  *
  * @param selectedGroupByExprs A sequence of selected GroupBy expressions, all exprs should
- *                             exists in groupByExprs.
+ *                             exist in groupByExprs.
  * @param groupByExprs The Group By expressions candidates.
  * @param child Child operator
  * @param aggregations The Aggregation expressions, those non selected group by expressions

sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql

Lines changed: 6 additions & 1 deletion
@@ -54,4 +54,9 @@ SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(co
 ORDER BY GROUPING(course), GROUPING(year), course, year;
 SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING(course);
 SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING_ID(course);
-SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id;
+SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id;
+
+-- Aliases in SELECT could be used in ROLLUP/CUBE/GROUPING SETS
+SELECT a + b AS k1, b AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2);
+SELECT a + b AS k, b, SUM(a - b) FROM testData GROUP BY ROLLUP(k, b);
+SELECT a + b, b AS k, SUM(a - b) FROM testData GROUP BY a + b, k GROUPING SETS(k)
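For reference, the expected outputs in the golden file below are consistent with a testData relation of six rows over integer columns (a, b): (1,1), (1,2), (2,1), (2,2), (3,1), (3,2); these values are inferred from the results, not stated in this diff. A rough spark-shell sketch of an equivalent fixture (the real one is set up by the SQL test harness and may differ):

    // Hypothetical re-creation of the fixture; row values are inferred from
    // the expected output in group-analytics.sql.out.
    import spark.implicits._
    val testData = Seq((1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2)).toDF("a", "b")
    testData.createOrReplaceTempView("testData")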

sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out

Lines changed: 48 additions & 1 deletion
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 26
+-- Number of queries: 29
 
 
 -- !query 0
@@ -328,3 +328,50 @@
 -- !query 25 output
 org.apache.spark.sql.AnalysisException
 grouping__id is deprecated; use grouping_id() instead;
+
+
+-- !query 26
+SELECT a + b AS k1, b AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2)
+-- !query 26 schema
+struct<k1:int,k2:int,sum((a - b)):bigint>
+-- !query 26 output
+2 1 0
+2 NULL 0
+3 1 1
+3 2 -1
+3 NULL 0
+4 1 2
+4 2 0
+4 NULL 2
+5 2 1
+5 NULL 1
+NULL 1 3
+NULL 2 0
+NULL NULL 3
+
+
+-- !query 27
+SELECT a + b AS k, b, SUM(a - b) FROM testData GROUP BY ROLLUP(k, b)
+-- !query 27 schema
+struct<k:int,b:int,sum((a - b)):bigint>
+-- !query 27 output
+2 1 0
+2 NULL 0
+3 1 1
+3 2 -1
+3 NULL 0
+4 1 2
+4 2 0
+4 NULL 2
+5 2 1
+5 NULL 1
+NULL NULL 3
+
+
+-- !query 28
+SELECT a + b, b AS k, SUM(a - b) FROM testData GROUP BY a + b, k GROUPING SETS(k)
+-- !query 28 schema
+struct<(a + b):int,k:int,sum((a - b)):bigint>
+-- !query 28 output
+NULL 1 3
+NULL 2 0
