Skip to content

Commit 22f594a

Browse files
committed
Maximum repeatedly substituted alias size
1 parent 69dab94 commit 22f594a

File tree

5 files changed

+83
-3
lines changed

5 files changed

+83
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,8 @@ object CollapseProject extends Rule[LogicalPlan] {
658658

659659
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
660660
case p1 @ Project(_, p2: Project) =>
661-
if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
661+
if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList) ||
662+
hasOversizedRepeatedAliases(p1.projectList, p2.projectList)) {
662663
p1
663664
} else {
664665
p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
@@ -691,6 +692,28 @@ object CollapseProject extends Rule[LogicalPlan] {
691692
}.exists(!_.deterministic))
692693
}
693694

695+
private def hasOversizedRepeatedAliases(
696+
upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = {
697+
val aliases = collectAliases(lower)
698+
699+
// Count how many times each alias is used in the upper Project.
700+
// If an alias is only used once, we can safely substitute it without increasing the overall
701+
// tree size
702+
val referenceCounts = AttributeMap(
703+
upper
704+
.flatMap(_.collect { case a: Attribute => a })
705+
.groupBy(identity)
706+
.mapValues(_.size).toSeq
707+
)
708+
709+
// Check for any aliases that are used more than once, and are larger than the configured
710+
// maximum size
711+
aliases.exists({ case (attribute, expression) =>
712+
referenceCounts.getOrElse(attribute, 0) > 1 &&
713+
expression.treeSize > SQLConf.get.maxRepeatedAliasSize
714+
})
715+
}
716+
694717
private def buildCleanedProjectList(
695718
upper: Seq[NamedExpression],
696719
lower: Seq[NamedExpression]): Seq[NamedExpression] = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
2525
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
2626
import org.apache.spark.sql.catalyst.plans._
2727
import org.apache.spark.sql.catalyst.plans.logical._
28+
import org.apache.spark.sql.internal.SQLConf
2829

2930
/**
3031
* A pattern that matches any number of project or filter operations on top of another relational
@@ -60,8 +61,13 @@ object PhysicalOperation extends PredicateHelper {
6061
plan match {
6162
case Project(fields, child) if fields.forall(_.deterministic) =>
6263
val (_, filters, other, aliases) = collectProjectsAndFilters(child)
63-
val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
64-
(Some(substitutedFields), filters, other, collectAliases(substitutedFields))
64+
if (hasOversizedRepeatedAliases(fields, aliases)) {
65+
// Skip substitution if it could overly increase the overall tree size and risk OOMs
66+
(None, Nil, plan, Map.empty)
67+
} else {
68+
val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
69+
(Some(substitutedFields), filters, other, collectAliases(substitutedFields))
70+
}
6571

6672
case Filter(condition, child) if condition.deterministic =>
6773
val (fields, filters, other, aliases) = collectProjectsAndFilters(child)
@@ -79,6 +85,26 @@ object PhysicalOperation extends PredicateHelper {
7985
case a @ Alias(child, _) => a.toAttribute -> child
8086
}.toMap
8187

88+
private def hasOversizedRepeatedAliases(fields: Seq[Expression],
89+
aliases: Map[Attribute, Expression]): Boolean = {
90+
// Count how many times each alias is used in the fields.
91+
// If an alias is only used once, we can safely substitute it without increasing the overall
92+
// tree size
93+
val referenceCounts = AttributeMap(
94+
fields
95+
.flatMap(_.collect { case a: Attribute => a })
96+
.groupBy(identity)
97+
.mapValues(_.size).toSeq
98+
)
99+
100+
// Check for any aliases that are used more than once, and are larger than the configured
101+
// maximum size
102+
aliases.exists({ case (attribute, expression) =>
103+
referenceCounts.getOrElse(attribute, 0) > 1 &&
104+
expression.treeSize > SQLConf.get.maxRepeatedAliasSize
105+
})
106+
}
107+
82108
private def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = {
83109
expr.transform {
84110
case a @ Alias(ref: AttributeReference, name) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
8989

9090
lazy val containsChild: Set[TreeNode[_]] = children.toSet
9191

92+
lazy val treeSize: Long = children.map(_.treeSize).sum + 1
93+
9294
private lazy val _hashCode: Int = scala.util.hashing.MurmurHash3.productHash(this)
9395
override def hashCode(): Int = _hashCode
9496

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1642,6 +1642,15 @@ object SQLConf {
16421642
"a SparkConf entry.")
16431643
.booleanConf
16441644
.createWithDefault(true)
1645+
1646+
val MAX_REPEATED_ALIAS_SIZE =
1647+
buildConf("spark.sql.maxRepeatedAliasSize")
1648+
.internal()
1649+
.doc("The maximum size of alias expression that will be substituted multiple times " +
1650+
"(size defined by the number of nodes in the expression tree). " +
1651+
"Used by the CollapseProject optimizer, and PhysicalOperation.")
1652+
.intConf
1653+
.createWithDefault(100)
16451654
}
16461655

16471656
/**
@@ -2071,6 +2080,8 @@ class SQLConf extends Serializable with Logging {
20712080
def setCommandRejectsSparkCoreConfs: Boolean =
20722081
getConf(SQLConf.SET_COMMAND_REJECTS_SPARK_CORE_CONFS)
20732082

2083+
def maxRepeatedAliasSize: Int = getConf(SQLConf.MAX_REPEATED_ALIAS_SIZE)
2084+
20742085
/** ********************** SQLConf functionality methods ************ */
20752086

20762087
/** Set Spark SQL configuration properties. */

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,4 +138,22 @@ class CollapseProjectSuite extends PlanTest {
138138
assert(projects.size === 1)
139139
assert(hasMetadata(optimized))
140140
}
141+
142+
test("ensure oversize aliases are not repeatedly substituted") {
143+
var query: LogicalPlan = testRelation
144+
for( a <- 1 to 100) {
145+
query = query.select(('a + 'b).as('a), ('a - 'b).as('b))
146+
}
147+
val projects = Optimize.execute(query.analyze).collect { case p: Project => p }
148+
assert(projects.size >= 12)
149+
}
150+
151+
test("ensure oversize aliases are still substituted once") {
152+
var query: LogicalPlan = testRelation
153+
for( a <- 1 to 20) {
154+
query = query.select(('a + 'b).as('a), 'b)
155+
}
156+
val projects = Optimize.execute(query.analyze).collect { case p: Project => p }
157+
assert(projects.size === 1)
158+
}
141159
}

0 commit comments

Comments (0)