
Commit 24af7b2

[SPARK-24497][SQL] add tests, fix nested WITH, fix Exchange reuse
1 parent f5feb63 commit 24af7b2

7 files changed: +1857 -208 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 86 additions & 70 deletions
@@ -209,7 +209,7 @@ class Analyzer(
   object ResolveRecursiveReferneces extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = {
       val recursiveTables = plan.collect {
-        case rt @ RecursiveTable(name, _, _) if rt.anchorResolved => name -> rt
+        case rt @ RecursiveTable(name, _, _, _) if rt.anchorResolved => name -> rt
       }.toMap

       plan.resolveOperatorsUp {
@@ -231,123 +231,139 @@ class Analyzer(
         substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) {
           case (resolved, (name, relation)) =>
             val recursiveTableName = if (allowRecursion) Some(name) else None
-            resolved :+
-              name -> executeSameContext(substituteCTE(relation, resolved, recursiveTableName))
-        }, None)
+            val (substitutedPlan, recursiveReferenceFound) =
+              substituteCTE(relation, resolved, recursiveTableName)
+            val analyzedPlan = executeSameContext(substitutedPlan)
+            resolved :+ name -> (
+              if (recursiveReferenceFound) {
+                insertRecursiveTable(analyzedPlan, recursiveTableName.get)
+              } else {
+                analyzedPlan
+              })
+        }, None)._1
       case other => other
     }

     def substituteCTE(
         plan: LogicalPlan,
         cteRelations: Seq[(String, LogicalPlan)],
-        recursiveTableName: Option[String]): LogicalPlan = {
-      def substitute(
-          plan: LogicalPlan,
-          inSubQuery: Boolean = false): (LogicalPlan, Boolean) = {
-        var recursiveReferenceFound = false
-
-        val newPlan = plan resolveOperatorsDown {
-          case u: UnresolvedRelation =>
-            val table = u.tableIdentifier.table
-
-            val recursiveReference = recursiveTableName.find(resolver(_, table)).map { name =>
-              if (inSubQuery) {
-                throw new AnalysisException(
-                  s"Recursive reference ${name} can't be used in a subquery")
-              }
-
-              recursiveReferenceFound = true
+        recursiveTableName: Option[String]): (LogicalPlan, Boolean) = {
+      var recursiveReferenceFound = false

-              UnresolvedRecursiveReference(name)
-            }
+      val newPlan = plan resolveOperatorsDown {
+        case u: UnresolvedRelation =>
+          val table = u.tableIdentifier.table

-            recursiveReference
-              .orElse(cteRelations.find(x => resolver(x._1, table)).map(_._2))
-              .getOrElse(u)
+          val recursiveReference = recursiveTableName.find(resolver(_, table)).map { name =>
+            recursiveReferenceFound = true

-          case other =>
-            // This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
-            other transformExpressions {
-              case e: SubqueryExpression => e.withNewPlan(substitute(e.plan, true)._1)
-            }
-        }
+            UnresolvedRecursiveReference(name)
+          }

-        (newPlan, recursiveReferenceFound)
+          recursiveReference
+            .orElse(cteRelations.find(x => resolver(x._1, table)).map(_._2))
+            .getOrElse(u)
+        case w @ With(_, cteRelations, _) =>
+          w.copy(cteRelations = cteRelations.map {
+            case (name, sa @ SubqueryAlias(_, plan)) =>
+              val (substitutedPlan, recursiveReferenceFoundInCTE) =
+                substituteCTE(plan, Seq.empty, recursiveTableName)
+              recursiveReferenceFound |= recursiveReferenceFoundInCTE
+              (name, sa.copy(child = substitutedPlan))
+          })
+        case other =>
+          // This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
+          other transformExpressions {
+            case e: SubqueryExpression =>
+              val (substitutedPlan, recursiveReferenceFoundInSubQuery) =
+                substituteCTE(e.plan, cteRelations, recursiveTableName)
+
+              recursiveReferenceFound |= recursiveReferenceFoundInSubQuery
+              e.withNewPlan(substitutedPlan)
+          }
       }

+      (newPlan, recursiveReferenceFound)
+    }
+
+    def insertRecursiveTable(plan: LogicalPlan, recursiveTableName: String): LogicalPlan =
       plan match {
-        case SubqueryAlias(name, u: Union) if recursiveTableName.contains(name.identifier) =>
+        case sa @ SubqueryAlias(name, u: Union) if name.identifier == recursiveTableName =>
           def combineUnions(union: Union): Seq[LogicalPlan] = union.children.flatMap {
             case u: Union => combineUnions(u)
             case o => Seq(o)
           }

-          val substitutedTerms = combineUnions(u).map(substitute(_))
-          val (anchorTerms, recursiveTerms) = substitutedTerms.partition(!_._2)
+          val combinedTerms = combineUnions(u)
+          val (anchorTerms, recursiveTerms) = combinedTerms.partition(!_.collectFirst {
+            case UnresolvedRecursiveReference(name) if name == recursiveTableName => true
+          }.isDefined)

           if (!recursiveTerms.isEmpty) {
             if (anchorTerms.isEmpty) {
               throw new AnalysisException("There should be at least 1 anchor term defined in the " +
-                s"recursive query ${recursiveTableName.get}")
+                s"recursive query ${recursiveTableName}")
             }

-            val recursiveTermPlans = recursiveTerms.map(_._1)
-
             def traversePlanAndCheck(
                 plan: LogicalPlan,
-                isRecursiveReferenceAllowed: Boolean = true): Boolean = plan match {
-              case UnresolvedRecursiveReference(name) if recursiveTableName.contains(name) =>
+                isRecursiveReferenceAllowed: Boolean = true): Int = plan match {
+              case UnresolvedRecursiveReference(name) if name == recursiveTableName =>
                 if (!isRecursiveReferenceAllowed) {
-                  throw new AnalysisException(s"Recursive reference ${recursiveTableName.get} " +
+                  throw new AnalysisException(s"Recursive reference ${recursiveTableName} " +
                     "cannot be used here. This can be caused by using it in a different join " +
                     "than inner or left outer or right outer, using it on inner side of an " +
-                    "outer join or using it in an aggregate or with a distinct statement")
+                    "outer join, using it with aggregate or distinct, using it in a subquery " +
+                    "or using it multiple times in a recursive term.")
                 }
-                true
+                1
               case Join(left, right, Inner, _, _) =>
-                val l = traversePlanAndCheck(left, isRecursiveReferenceAllowed)
-                val r = traversePlanAndCheck(right, isRecursiveReferenceAllowed)
-                if (l && r) {
-                  throw new AnalysisException(s"Recursive reference ${recursiveTableName.get} " +
-                    "cannot be used on both sides of an inner join")
-                }
-                l || r
+                traversePlanAndCheck(left, isRecursiveReferenceAllowed) +
+                  traversePlanAndCheck(right, isRecursiveReferenceAllowed)
              case Join(left, right, LeftOuter, _, _) =>
-                traversePlanAndCheck(left, isRecursiveReferenceAllowed) ||
+                traversePlanAndCheck(left, isRecursiveReferenceAllowed) +
                  traversePlanAndCheck(right, false)
              case Join(left, right, RightOuter, _, _) =>
-                traversePlanAndCheck(left, false) ||
+                traversePlanAndCheck(left, false) +
                  traversePlanAndCheck(right, isRecursiveReferenceAllowed)
              case Join(left, right, _, _, _) =>
-                traversePlanAndCheck(left, false) || traversePlanAndCheck(right, false)
+                traversePlanAndCheck(left, false) +
+                  traversePlanAndCheck(right, false)
              case Aggregate(_, _, child) => traversePlanAndCheck(child, false)
              case Distinct(child) => traversePlanAndCheck(child, false)
              case o =>
-                o.children.map(traversePlanAndCheck(_, isRecursiveReferenceAllowed)).contains(true)
+                o transformExpressions {
+                  case se: SubqueryExpression =>
+                    traversePlanAndCheck(se.plan, false)
+                    se
+                }
+                o.children
+                  .map(traversePlanAndCheck(_, isRecursiveReferenceAllowed))
+                  .foldLeft(0)(_ + _)
            }

-            recursiveTermPlans.foreach(traversePlanAndCheck(_))
+            recursiveTerms.foreach { recursiveTerm =>
+              if (traversePlanAndCheck(recursiveTerm) > 1) {
+                throw new AnalysisException(s"Recursive reference ${recursiveTableName} cannot " +
+                  "be used multiple times in a recursive term")
+              }
+            }

            RecursiveTable(
-              recursiveTableName.get,
-              SubqueryAlias(name, Union(anchorTerms.map(_._1))),
-              Union(recursiveTermPlans))
+              recursiveTableName,
+              sa.copy(child = Union(anchorTerms)),
+              Union(recursiveTerms),
+              None)
          } else {
-            SubqueryAlias(name, Union(substitutedTerms.map(_._1)))
+            SubqueryAlias(recursiveTableName, Union(combinedTerms))
          }

        case _ =>
-          val (substitutedPlan, recursiveReferenceFound) = substitute(plan)
-
-          if (recursiveReferenceFound) {
-            throw new AnalysisException(s"Recursive query ${recursiveTableName.get} should " +
-              "contain UNION ALL statements only")
-          }
-
-          substitutedPlan
+          throw new AnalysisException(s"Recursive query ${recursiveTableName} should contain " +
+            "UNION ALL statements only. This can also be caused by ORDER BY or LIMIT keywords " +
+            "used on result of UNION ALL.")
      }
    }
-  }

  /**
   * Substitute child plan with WindowSpecDefinitions.
@@ -1656,7 +1672,7 @@ class Analyzer(
      case RecursiveReference(name, _) =>
        throw new AnalysisException(s"Recursive reference ${name} can't be used in an " +
          "aggregate")
-      case RecursiveTable(_, _, recursiveTerm) =>
+      case RecursiveTable(_, _, recursiveTerm, _) =>
      case o => o.children.map(traversePlanAndCheck)
    }

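For orientation, the analyzer changes above substitute recursive references inside a WITH clause (now also inside nested WITHs and subqueries) and then validate where the reference may legally appear. The sketch below shows the shape of query this targets; it is not part of the commit, and both the exact WITH RECURSIVE syntax and the availability of the feature in a given build are assumptions based on this work-in-progress.

// Illustrative sketch only -- not from this commit. Assumes a local build in
// which the SPARK-24497 recursive CTE support is available.
import org.apache.spark.sql.SparkSession

object RecursiveCteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("recursive-cte-sketch")
      .master("local[*]")
      .getOrCreate()

    // Anchor term (SELECT 0) UNION ALL'ed with a recursive term that refers to
    // `r` itself; the analyzer rewrites that reference into an
    // UnresolvedRecursiveReference and checks it is used in an allowed position.
    val df = spark.sql(
      """WITH RECURSIVE r AS (
        |  SELECT 0 AS level
        |  UNION ALL
        |  SELECT level + 1 FROM r WHERE level < 10
        |)
        |SELECT * FROM r
        |""".stripMargin)

    df.show()
    spark.stop()
  }
}
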
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 4 additions & 2 deletions
@@ -62,7 +62,8 @@ case class Subquery(child: LogicalPlan) extends OrderPreservingUnaryNode {
 case class RecursiveTable(
     name: String,
     anchorTerm: LogicalPlan,
-    recursiveTerm: LogicalPlan) extends LogicalPlan {
+    recursiveTerm: LogicalPlan,
+    limit: Option[Long]) extends LogicalPlan {
   override def children: Seq[LogicalPlan] = Seq(anchorTerm, recursiveTerm)

   override def output: Seq[Attribute] = anchorTerm.output.map(_.withNullability(true))
@@ -553,7 +554,8 @@ case class With(

   override def simpleString(maxFields: Int): String = {
     val cteAliases = truncatedString(cteRelations.map(_._1), "[", ", ", "]", maxFields)
-    s"CTE $cteAliases"
+    val recursive = if (allowRecursion) " recursive" else ""
+    s"CTE$recursive $cteAliases"
   }

   override def innerChildren: Seq[LogicalPlan] = cteRelations.map(_._2)

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 2 additions & 2 deletions
@@ -608,9 +608,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.ProjectExec(projectList, planLater(child)) :: Nil
       case logical.Filter(condition, child) =>
         execution.FilterExec(condition, planLater(child)) :: Nil
-      case logical.RecursiveTable(name, anchorTerm, recursiveTerm) =>
+      case logical.RecursiveTable(name, anchorTerm, recursiveTerm, limit) =>
         execution.RecursiveTableExec(
-          name, planLater(anchorTerm), planLater(recursiveTerm)) :: Nil
+          name, planLater(anchorTerm), planLater(recursiveTerm), limit) :: Nil
       case logical.RecursiveReference(name, output) =>
         execution.RecursiveReferenceExec(name, output) :: Nil
       case f: logical.TypedFilter =>

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 11 additions & 5 deletions
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.LongType
@@ -232,7 +232,8 @@ case class FilterExec(condition: Expression, child: SparkPlan)
 case class RecursiveTableExec(
     name: String,
     anchorTerm: SparkPlan,
-    recursiveTerm: SparkPlan) extends SparkPlan {
+    recursiveTerm: SparkPlan,
+    limit: Option[Long]) extends SparkPlan { // TODO: how to implement limit?
   override def children: Seq[SparkPlan] = Seq(anchorTerm, recursiveTerm)

   override def output: Seq[Attribute] = anchorTerm.output
@@ -243,9 +244,10 @@ case class RecursiveTableExec(
     var temp = anchorTerm.execute().map(_.copy()).cache()
     var tempCount = temp.count()
     var result = temp
+    var sumCount = tempCount
     var level = 0
     val levelLimit = conf.recursionLevelLimit
-    do {
+    while ((level == 0 || tempCount > 0) && limit.map(_ < sumCount).getOrElse(true)) {
       if (level > levelLimit) {
         throw new SparkException("Recursion level limit reached but query hasn't exhausted, try " +
           s"increasing ${SQLConf.RECURSION_LEVEL_LIMIT.key}")
@@ -261,13 +263,17 @@ case class RecursiveTableExec(
       if (level > 0) {
         newRecursiveTerm.reset()
       }
-      newRecursiveTerm.foreach {
+
+      def updateRecursiveTables(plan: SparkPlan): Unit = plan.foreach {
        _ match {
          case rr: RecursiveReferenceExec if rr.name == name => rr.recursiveTable = temp
+          case ReusedExchangeExec(_, child) => updateRecursiveTables(child)
          case _ =>
        }
      }

+      updateRecursiveTables(newRecursiveTerm)
+
      val newTemp = newRecursiveTerm.execute().map(_.copy()).cache()
      tempCount = newTemp.count()
      temp.unpersist()
@@ -276,7 +282,7 @@ case class RecursiveTableExec(
      result = result.union(temp)

      level = level + 1
-    } while (tempCount > 0)
+    }

    result
  }
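
The new while loop above terminates when a level produces no rows or when the optional row limit is exhausted, and the reference-wiring step now also looks through ReusedExchangeExec nodes. As a rough mental model of that iteration only, here is a sketch using plain Scala collections in place of cached RDDs; every name below is invented for illustration and is not a Spark API, and the limit check is written as it is presumably intended (stop once the accumulated row count reaches the limit).

// Rough, self-contained model of the fixpoint loop in RecursiveTableExec.
// `step` stands in for re-executing the recursive term against the rows
// produced by the previous level.
object RecursiveLoopModel {
  def evaluate[T](
      anchor: Seq[T],
      step: Seq[T] => Seq[T],
      limit: Option[Long],
      levelLimit: Int = 100): Seq[T] = {
    var temp = anchor                 // rows produced by the previous level
    var result = anchor               // everything accumulated so far
    var sumCount = anchor.size.toLong
    var level = 0
    // Iterate while the previous level produced rows and the optional row
    // limit has not been reached yet.
    while ((level == 0 || temp.nonEmpty) && limit.forall(sumCount < _)) {
      if (level > levelLimit) {
        throw new IllegalStateException(s"Recursion level limit $levelLimit reached")
      }
      val newTemp = step(temp)        // "execute the recursive term" on the last level
      result = result ++ newTemp
      sumCount += newTemp.size
      temp = newTemp
      level += 1
    }
    result
  }

  def main(args: Array[String]): Unit = {
    // Anchor Seq(0); each level increments and keeps values <= 10: prints 0..10.
    val out = evaluate[Int](Seq(0), prev => prev.map(_ + 1).filter(_ <= 10), limit = None)
    println(out.mkString(", "))
  }
}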

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala

Lines changed: 18 additions & 3 deletions
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType

@@ -56,6 +56,10 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan
     child.execute()
   }

+  override def doReset(): Unit = {
+    child.reset()
+  }
+
   override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
     child.executeBroadcast()
   }
@@ -90,11 +94,22 @@ case class ReuseExchange(conf: SQLConf) extends Rule[SparkPlan] {
       return plan
     }
     // Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls.
-    val exchanges = mutable.HashMap[StructType, ArrayBuffer[Exchange]]()
+    // TODO: document recursion related changes
+    val allExchanges = mutable.Stack[mutable.HashMap[StructType, ArrayBuffer[Exchange]]]()
+    allExchanges.push(mutable.HashMap[StructType, ArrayBuffer[Exchange]]())
+    val recursiveTables = mutable.Set.empty[String]
    plan.transformUp {
+      case rr @ RecursiveReferenceExec(name, _) if !recursiveTables.contains(name) =>
+        allExchanges.push(mutable.HashMap[StructType, ArrayBuffer[Exchange]]())
+        recursiveTables += name
+        rr
+      case rt @ RecursiveTableExec(name, _, _, _) =>
+        allExchanges.pop()
+        recursiveTables -= name
+        rt
      case exchange: Exchange =>
        // the exchanges that have same results usually also have same schemas (same column names).
-        val sameSchema = exchanges.getOrElseUpdate(exchange.schema, ArrayBuffer[Exchange]())
+        val sameSchema = allExchanges.top.getOrElseUpdate(exchange.schema, ArrayBuffer[Exchange]())
        val samePlan = sameSchema.find { e =>
          exchange.sameResult(e)
        }
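
The ReuseExchange change above scopes the schema-to-exchange map with a stack: a fresh scope is pushed the first time a RecursiveReferenceExec for a given name is seen and popped again at the enclosing RecursiveTableExec, so an exchange inside a recursive term is only deduplicated against exchanges in the same scope. Below is a stripped-down sketch of just that scoping bookkeeping; strings stand in for plans and schemas, the method names are invented, and nothing here is Spark API.

import scala.collection.mutable

// Minimal model of stack-of-scopes reuse: lookups only consult the innermost
// scope, so entries registered outside a recursive term are not reused inside it.
object ScopedReuseModel {
  private val scopes = mutable.Stack[mutable.HashMap[String, String]]()
  scopes.push(mutable.HashMap.empty)            // outermost scope

  def enterRecursiveTerm(): Unit = scopes.push(mutable.HashMap.empty)
  def leaveRecursiveTerm(): Unit = scopes.pop()

  /** Returns the plan already registered for `schema`, or registers `plan`. */
  def reuseOrRegister(schema: String, plan: String): String =
    scopes.top.getOrElseUpdate(schema, plan)

  def main(args: Array[String]): Unit = {
    println(reuseOrRegister("int,string", "exchange#1"))  // exchange#1 (registered)
    println(reuseOrRegister("int,string", "exchange#2"))  // exchange#1 (reused)

    enterRecursiveTerm()
    println(reuseOrRegister("int,string", "exchange#3"))  // exchange#3 (no cross-scope reuse)
    leaveRecursiveTerm()

    println(reuseOrRegister("int,string", "exchange#4"))  // exchange#1 again
  }
}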
