Skip to content

Commit 6341b06

Browse files
peter-tothulysses-you
authored andcommitted
[SPARK-40086][SPARK-42049][SQL] Improve AliasAwareOutputPartitioning and AliasAwareQueryOutputOrdering to take all aliases into account
### What changes were proposed in this pull request? Currently `AliasAwareOutputPartitioning` and `AliasAwareQueryOutputOrdering` take only the last alias of each aliased expression into account. We could avoid some extra shuffles and sorts with better alias handling. ### Why are the changes needed? Performance improvement, and this also fixes the issue in #39475. ### Does this PR introduce _any_ user-facing change? Yes, this PR fixes the issue in #39475. ### How was this patch tested? Added new UT. Closes #37525 from peter-toth/SPARK-40086-fix-aliasawareoutputexpression. Lead-authored-by: Peter Toth <[email protected]> Co-authored-by: ulysses-you <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 720fe2f commit 6341b06

File tree

19 files changed

+464
-119
lines changed

19 files changed

+464
-119
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3024,3 +3024,27 @@ case class SplitPart (
30243024
partNum = newChildren.apply(2))
30253025
}
30263026
}
3027+
3028+
/**
3029+
* An internal function that converts the empty string to null for partition values.
3030+
* This function should be only used in V1Writes.
3031+
*/
3032+
case class Empty2Null(child: Expression) extends UnaryExpression with String2StringExpression {
3033+
override def convert(v: UTF8String): UTF8String = if (v.numBytes() == 0) null else v
3034+
3035+
override def nullable: Boolean = true
3036+
3037+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
3038+
nullSafeCodeGen(ctx, ev, c => {
3039+
s"""if ($c.numBytes() == 0) {
3040+
| ${ev.isNull} = true;
3041+
| ${ev.value} = null;
3042+
|} else {
3043+
| ${ev.value} = $c;
3044+
|}""".stripMargin
3045+
})
3046+
}
3047+
3048+
override protected def withNewChildInternal(newChild: Expression): Empty2Null =
3049+
copy(child = newChild)
3050+
}
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.plans
19+
20+
import scala.collection.mutable
21+
22+
import org.apache.spark.sql.catalyst.SQLConfHelper
23+
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Empty2Null, Expression, NamedExpression, SortOrder}
24+
import org.apache.spark.sql.internal.SQLConf
25+
26+
/**
27+
* A trait that provides functionality to handle aliases in the `outputExpressions`.
28+
*/
29+
trait AliasAwareOutputExpression extends SQLConfHelper {
30+
protected val aliasCandidateLimit = conf.getConf(SQLConf.EXPRESSION_PROJECTION_CANDIDATE_LIMIT)
31+
protected def outputExpressions: Seq[NamedExpression]
32+
/**
33+
* This method can be used to strip expression which does not affect the result, for example:
34+
* strip the expression which is ordering agnostic for output ordering.
35+
*/
36+
protected def strip(expr: Expression): Expression = expr
37+
38+
// Build an `Expression` -> `Attribute` alias map.
39+
// There can be multiple alias defined for the same expressions but it doesn't make sense to store
40+
// more than `aliasCandidateLimit` attributes for an expression. In those cases the old logic
41+
// handled only the last alias so we need to make sure that we give precedence to that.
42+
// If the `outputExpressions` contain simple attributes we need to add those too to the map.
43+
private lazy val aliasMap = {
44+
val aliases = mutable.Map[Expression, mutable.ArrayBuffer[Attribute]]()
45+
outputExpressions.reverse.foreach {
46+
case a @ Alias(child, _) =>
47+
val buffer = aliases.getOrElseUpdate(strip(child).canonicalized, mutable.ArrayBuffer.empty)
48+
if (buffer.size < aliasCandidateLimit) {
49+
buffer += a.toAttribute
50+
}
51+
case _ =>
52+
}
53+
outputExpressions.foreach {
54+
case a: Attribute if aliases.contains(a.canonicalized) =>
55+
val buffer = aliases(a.canonicalized)
56+
if (buffer.size < aliasCandidateLimit) {
57+
buffer += a
58+
}
59+
case _ =>
60+
}
61+
aliases
62+
}
63+
64+
protected def hasAlias: Boolean = aliasMap.nonEmpty
65+
66+
/**
67+
* Return a stream of expressions in which the original expression is projected with `aliasMap`.
68+
*/
69+
protected def projectExpression(expr: Expression): Stream[Expression] = {
70+
val outputSet = AttributeSet(outputExpressions.map(_.toAttribute))
71+
expr.multiTransformDown {
72+
// Mapping with aliases
73+
case e: Expression if aliasMap.contains(e.canonicalized) =>
74+
aliasMap(e.canonicalized).toSeq ++ (if (e.containsChild.nonEmpty) Seq(e) else Seq.empty)
75+
76+
// Prune if we encounter an attribute that we can't map and it is not in output set.
77+
// This prune will go up to the closest `multiTransformDown()` call and returns `Stream.empty`
78+
// there.
79+
case a: Attribute if !outputSet.contains(a) => Seq.empty
80+
}
81+
}
82+
}
83+
84+
/**
85+
* A trait that handles aliases in the `orderingExpressions` to produce `outputOrdering` that
86+
* satisfies ordering requirements.
87+
*/
88+
trait AliasAwareQueryOutputOrdering[T <: QueryPlan[T]]
89+
extends AliasAwareOutputExpression { self: QueryPlan[T] =>
90+
protected def orderingExpressions: Seq[SortOrder]
91+
92+
override protected def strip(expr: Expression): Expression = expr match {
93+
case e: Empty2Null => strip(e.child)
94+
case _ => expr
95+
}
96+
97+
override final def outputOrdering: Seq[SortOrder] = {
98+
if (hasAlias) {
99+
// Take the first `SortOrder`s only until they can be projected.
100+
// E.g. we have child ordering `Seq(SortOrder(a), SortOrder(b))` then
101+
// if only `a AS x` can be projected then we can return `Seq(SortOrder(x))`
102+
// but if only `b AS y` can be projected we can't return `Seq(SortOrder(y))`.
103+
orderingExpressions.iterator.map { sortOrder =>
104+
val orderingSet = mutable.Set.empty[Expression]
105+
val sameOrderings = sortOrder.children.toStream
106+
.flatMap(projectExpression)
107+
.filter(e => orderingSet.add(e.canonicalized))
108+
.take(aliasCandidateLimit)
109+
if (sameOrderings.nonEmpty) {
110+
Some(sortOrder.copy(child = sameOrderings.head,
111+
sameOrderExpressions = sameOrderings.tail))
112+
} else {
113+
None
114+
}
115+
}.takeWhile(_.isDefined).flatten.toSeq
116+
} else {
117+
orderingExpressions
118+
}
119+
}
120+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
5353
@transient
5454
lazy val outputSet: AttributeSet = AttributeSet(output)
5555

56+
/**
57+
* Returns the output ordering that this plan generates, although the semantics differ in logical
58+
* and physical plans. In the logical plan it means global ordering of the data while in physical
59+
* it means ordering in each partition.
60+
*/
61+
def outputOrdering: Seq[SortOrder] = Nil
62+
5663
// Override `treePatternBits` to propagate bits for its expressions.
5764
override lazy val treePatternBits: BitSet = {
5865
val bits: BitSet = getDefaultTreePatternBits

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
2020
import org.apache.spark.internal.Logging
2121
import org.apache.spark.sql.catalyst.analysis._
2222
import org.apache.spark.sql.catalyst.expressions._
23-
import org.apache.spark.sql.catalyst.plans.QueryPlan
23+
import org.apache.spark.sql.catalyst.plans.{AliasAwareQueryOutputOrdering, QueryPlan}
2424
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.LogicalPlanStats
2525
import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, UnaryLike}
2626
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
@@ -141,11 +141,6 @@ abstract class LogicalPlan
141141
*/
142142
def refresh(): Unit = children.foreach(_.refresh())
143143

144-
/**
145-
* Returns the output ordering that this plan generates.
146-
*/
147-
def outputOrdering: Seq[SortOrder] = Nil
148-
149144
/**
150145
* Returns true iff `other`'s output is semantically the same, i.e.:
151146
* - it contains the same number of `Attribute`s;
@@ -205,8 +200,10 @@ trait UnaryNode extends LogicalPlan with UnaryLike[LogicalPlan] {
205200
*/
206201
trait BinaryNode extends LogicalPlan with BinaryLike[LogicalPlan]
207202

208-
abstract class OrderPreservingUnaryNode extends UnaryNode {
209-
override final def outputOrdering: Seq[SortOrder] = child.outputOrdering
203+
trait OrderPreservingUnaryNode extends UnaryNode
204+
with AliasAwareQueryOutputOrdering[LogicalPlan] {
205+
override protected def outputExpressions: Seq[NamedExpression] = child.output
206+
override protected def orderingExpressions: Seq[SortOrder] = child.outputOrdering
210207
}
211208

212209
object LogicalPlanIntegrity {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ object Subquery {
6969
case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
7070
extends OrderPreservingUnaryNode {
7171
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
72+
override protected def outputExpressions: Seq[NamedExpression] = projectList
7273
override def maxRows: Option[Long] = child.maxRows
7374
override def maxRowsPerPartition: Option[Long] = child.maxRowsPerPartition
7475

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,16 @@ object SQLConf {
443443
.booleanConf
444444
.createWithDefault(true)
445445

446+
val EXPRESSION_PROJECTION_CANDIDATE_LIMIT =
447+
buildConf("spark.sql.optimizer.expressionProjectionCandidateLimit")
448+
.doc("The maximum number of the candidate of output expressions whose alias are replaced." +
449+
" It can preserve the output partitioning and ordering." +
450+
" Negative value means disable this optimization.")
451+
.internal()
452+
.version("3.4.0")
453+
.intConf
454+
.createWithDefault(100)
455+
446456
val COMPRESS_CACHED = buildConf("spark.sql.inMemoryColumnarStorage.compressed")
447457
.doc("When set to true Spark SQL will automatically select a compression codec for each " +
448458
"column based on statistics of the data.")

sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala

Lines changed: 26 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -16,52 +16,42 @@
1616
*/
1717
package org.apache.spark.sql.execution
1818

19-
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression, SortOrder}
20-
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning}
19+
import scala.collection.mutable
2120

22-
/**
23-
* A trait that provides functionality to handle aliases in the `outputExpressions`.
24-
*/
25-
trait AliasAwareOutputExpression extends UnaryExecNode {
26-
protected def outputExpressions: Seq[NamedExpression]
27-
28-
private lazy val aliasMap = outputExpressions.collect {
29-
case a @ Alias(child, _) => child.canonicalized -> a.toAttribute
30-
}.toMap
31-
32-
protected def hasAlias: Boolean = aliasMap.nonEmpty
33-
34-
protected def normalizeExpression(exp: Expression): Expression = {
35-
exp.transformDown {
36-
case e: Expression => aliasMap.getOrElse(e.canonicalized, e)
37-
}
38-
}
39-
}
21+
import org.apache.spark.sql.catalyst.expressions.Expression
22+
import org.apache.spark.sql.catalyst.plans.{AliasAwareOutputExpression, AliasAwareQueryOutputOrdering}
23+
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection, UnknownPartitioning}
4024

4125
/**
4226
* A trait that handles aliases in the `outputExpressions` to produce `outputPartitioning` that
4327
* satisfies distribution requirements.
4428
*/
45-
trait AliasAwareOutputPartitioning extends AliasAwareOutputExpression {
29+
trait PartitioningPreservingUnaryExecNode extends UnaryExecNode
30+
with AliasAwareOutputExpression {
4631
final override def outputPartitioning: Partitioning = {
47-
val normalizedOutputPartitioning = if (hasAlias) {
48-
child.outputPartitioning match {
32+
if (hasAlias) {
33+
flattenPartitioning(child.outputPartitioning).flatMap {
4934
case e: Expression =>
50-
normalizeExpression(e).asInstanceOf[Partitioning]
51-
case other => other
35+
// We need unique partitionings but if the input partitioning is
36+
// `HashPartitioning(Seq(id + id))` and we have `id -> a` and `id -> b` aliases then after
37+
// the projection we have 4 partitionings:
38+
// `HashPartitioning(Seq(a + a))`, `HashPartitioning(Seq(a + b))`,
39+
// `HashPartitioning(Seq(b + a))`, `HashPartitioning(Seq(b + b))`, but
40+
// `HashPartitioning(Seq(a + b))` is the same as `HashPartitioning(Seq(b + a))`.
41+
val partitioningSet = mutable.Set.empty[Expression]
42+
projectExpression(e)
43+
.filter(e => partitioningSet.add(e.canonicalized))
44+
.take(aliasCandidateLimit)
45+
.asInstanceOf[Stream[Partitioning]]
46+
case o => Seq(o)
47+
} match {
48+
case Seq() => UnknownPartitioning(child.outputPartitioning.numPartitions)
49+
case Seq(p) => p
50+
case ps => PartitioningCollection(ps)
5251
}
5352
} else {
5453
child.outputPartitioning
5554
}
56-
57-
flattenPartitioning(normalizedOutputPartitioning).filter {
58-
case hashPartitioning: HashPartitioning => hashPartitioning.references.subsetOf(outputSet)
59-
case _ => true
60-
} match {
61-
case Seq() => UnknownPartitioning(child.outputPartitioning.numPartitions)
62-
case Seq(singlePartitioning) => singlePartitioning
63-
case seqWithMultiplePartitionings => PartitioningCollection(seqWithMultiplePartitionings)
64-
}
6555
}
6656

6757
private def flattenPartitioning(partitioning: Partitioning): Seq[Partitioning] = {
@@ -74,18 +64,5 @@ trait AliasAwareOutputPartitioning extends AliasAwareOutputExpression {
7464
}
7565
}
7666

77-
/**
78-
* A trait that handles aliases in the `orderingExpressions` to produce `outputOrdering` that
79-
* satisfies ordering requirements.
80-
*/
81-
trait AliasAwareOutputOrdering extends AliasAwareOutputExpression {
82-
protected def orderingExpressions: Seq[SortOrder]
83-
84-
final override def outputOrdering: Seq[SortOrder] = {
85-
if (hasAlias) {
86-
orderingExpressions.map(normalizeExpression(_).asInstanceOf[SortOrder])
87-
} else {
88-
orderingExpressions
89-
}
90-
}
91-
}
67+
trait OrderPreservingUnaryExecNode
68+
extends UnaryExecNode with AliasAwareQueryOutputOrdering[SparkPlan]

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
179179
def requiredChildDistribution: Seq[Distribution] =
180180
Seq.fill(children.size)(UnspecifiedDistribution)
181181

182-
/** Specifies how data is ordered in each partition. */
183-
def outputOrdering: Seq[SortOrder] = Nil
184-
185182
/** Specifies sort order for each partition requirements on the input data for this operator. */
186183
def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
187184

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@ package org.apache.spark.sql.execution.aggregate
2020
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, NamedExpression}
2121
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Final, PartialMerge}
2222
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
23-
import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, ExplainUtils, UnaryExecNode}
23+
import org.apache.spark.sql.execution.{ExplainUtils, PartitioningPreservingUnaryExecNode, UnaryExecNode}
2424
import org.apache.spark.sql.execution.streaming.StatefulOperatorPartitioning
2525

2626
/**
2727
* Holds common logic for aggregate operators
2828
*/
29-
trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning {
29+
trait BaseAggregateExec extends UnaryExecNode with PartitioningPreservingUnaryExecNode {
3030
def requiredChildDistributionExpressions: Option[Seq[Expression]]
3131
def isStreaming: Boolean
3232
def numShufflePartitions: Option[Int]

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.expressions.aggregate._
2424
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
2525
import org.apache.spark.sql.catalyst.util.truncatedString
26-
import org.apache.spark.sql.execution.{AliasAwareOutputOrdering, SparkPlan}
26+
import org.apache.spark.sql.execution.{OrderPreservingUnaryExecNode, SparkPlan}
2727
import org.apache.spark.sql.execution.metric.SQLMetrics
2828
import org.apache.spark.sql.internal.SQLConf
2929

@@ -41,7 +41,7 @@ case class SortAggregateExec(
4141
resultExpressions: Seq[NamedExpression],
4242
child: SparkPlan)
4343
extends AggregateCodegenSupport
44-
with AliasAwareOutputOrdering {
44+
with OrderPreservingUnaryExecNode {
4545

4646
override lazy val metrics = Map(
4747
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

0 commit comments

Comments
 (0)