Skip to content

Commit 236bb18

Browse files
committed
[SPARK-34003][SQL][FOLLOWUP] Fix Rule conflicts between PaddingAndLengthCheckForCharVarchar and ResolveAggregateFunctions
1 parent ff49317 commit 236bb18

File tree

2 files changed

+41
-21
lines changed

2 files changed

+41
-21
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2381,13 +2381,16 @@ class Analyzer(override val catalogManager: CatalogManager)
23812381
val unresolvedSortOrders = sortOrder.filter { s =>
23822382
!s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s)
23832383
}
2384-
val aliasedOrdering =
2385-
unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")())
2386-
val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
2384+
val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")())
2385+
2386+
val aggregateWithExtraOrdering = aggregate.copy(
2387+
aggregateExpressions = aggregate.aggregateExpressions ++ aliasedOrdering)
2388+
23872389
val resolvedAggregate: Aggregate =
2388-
executeSameContext(aggregatedOrdering).asInstanceOf[Aggregate]
2389-
val resolvedAliasedOrdering: Seq[Alias] =
2390-
resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]]
2390+
executeSameContext(aggregateWithExtraOrdering).asInstanceOf[Aggregate]
2391+
2392+
val (reResolvedAggExprs, resolvedAliasedOrdering) =
2393+
resolvedAggregate.aggregateExpressions.splitAt(aggregate.aggregateExpressions.length)
23912394

23922395
// If we pass the analysis check, then the ordering expressions should only reference
23932396
// aggregate expressions or grouping expressions, and it's safe to push them down to
@@ -2401,24 +2404,25 @@ class Analyzer(override val catalogManager: CatalogManager)
24012404
// expression instead.
24022405
val needsPushDown = ArrayBuffer.empty[NamedExpression]
24032406
val orderToAlias = unresolvedSortOrders.zip(aliasedOrdering)
2404-
val evaluatedOrderings = resolvedAliasedOrdering.zip(orderToAlias).map {
2405-
case (evaluated, (order, aliasOrder)) =>
2406-
val index = originalAggExprs.indexWhere {
2407-
case Alias(child, _) => child semanticEquals evaluated.child
2408-
case other => other semanticEquals evaluated.child
2409-
}
2407+
val evaluatedOrderings =
2408+
resolvedAliasedOrdering.asInstanceOf[Seq[Alias]].zip(orderToAlias).map {
2409+
case (evaluated, (order, aliasOrder)) =>
2410+
val index = reResolvedAggExprs.indexWhere {
2411+
case Alias(child, _) => child semanticEquals evaluated.child
2412+
case other => other semanticEquals evaluated.child
2413+
}
24102414

2411-
if (index == -1) {
2412-
if (CharVarcharUtils.getRawType(evaluated.metadata).nonEmpty) {
2413-
needsPushDown += aliasOrder
2414-
order.copy(child = aliasOrder)
2415+
if (index == -1) {
2416+
if (hasCharVarchar(evaluated)) {
2417+
needsPushDown += aliasOrder
2418+
order.copy(child = aliasOrder)
2419+
} else {
2420+
needsPushDown += evaluated
2421+
order.copy(child = evaluated.toAttribute)
2422+
}
24152423
} else {
2416-
needsPushDown += evaluated
2417-
order.copy(child = evaluated.toAttribute)
2424+
order.copy(child = originalAggExprs(index).toAttribute)
24182425
}
2419-
} else {
2420-
order.copy(child = originalAggExprs(index).toAttribute)
2421-
}
24222426
}
24232427

24242428
val sortOrdersMap = unresolvedSortOrders
@@ -2443,6 +2447,13 @@ class Analyzer(override val catalogManager: CatalogManager)
24432447
}
24442448
}
24452449

2450+
def hasCharVarchar(expr: Alias): Boolean = {
2451+
expr.find {
2452+
case ne: NamedExpression => CharVarcharUtils.getRawType(ne.metadata).nonEmpty
2453+
case _ => false
2454+
}.nonEmpty
2455+
}
2456+
24462457
def containsAggregate(condition: Expression): Boolean = {
24472458
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
24482459
}

sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,15 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
474474
checkAnswer(sql("SELECT v, sum(i) FROM t GROUP BY v ORDER BY v"), Row("c", 1))
475475
}
476476
}
477+
478+
test("SPARK-34003: fix char/varchar fails w/ order by functions") {
479+
withTable("t") {
480+
sql(s"CREATE TABLE t(v VARCHAR(3), i INT) USING $format")
481+
sql("INSERT INTO t VALUES ('c', 1)")
482+
checkAnswer(sql("SELECT substr(v, 1, 2), sum(i) FROM t GROUP BY v ORDER BY substr(v, 1, 2)"),
483+
Row("c", 1))
484+
}
485+
}
477486
}
478487

479488
// Some basic char/varchar tests which don't rely on the table implementation.

0 commit comments

Comments
 (0)