@@ -209,7 +209,7 @@ class Analyzer(
209209 object ResolveRecursiveReferneces extends Rule [LogicalPlan ] {
210210 def apply (plan : LogicalPlan ): LogicalPlan = {
211211 val recursiveTables = plan.collect {
212- case rt @ RecursiveTable (name, _, _) if rt.anchorResolved => name -> rt
212+ case rt @ RecursiveTable (name, _, _, _ ) if rt.anchorResolved => name -> rt
213213 }.toMap
214214
215215 plan.resolveOperatorsUp {
@@ -231,123 +231,139 @@ class Analyzer(
231231 substituteCTE(child, relations.foldLeft(Seq .empty[(String , LogicalPlan )]) {
232232 case (resolved, (name, relation)) =>
233233 val recursiveTableName = if (allowRecursion) Some (name) else None
234- resolved :+
235- name -> executeSameContext(substituteCTE(relation, resolved, recursiveTableName))
236- }, None )
234+ val (substitutedPlan, recursiveReferenceFound) =
235+ substituteCTE(relation, resolved, recursiveTableName)
236+ val analyzedPlan = executeSameContext(substitutedPlan)
237+ resolved :+ name -> (
238+ if (recursiveReferenceFound) {
239+ insertRecursiveTable(analyzedPlan, recursiveTableName.get)
240+ } else {
241+ analyzedPlan
242+ })
243+ }, None )._1
237244 case other => other
238245 }
239246
/**
 * Rewrites `plan` so that unresolved relations naming a CTE are replaced.
 *
 * A relation whose table name matches `recursiveTableName` (compared with `resolver`) becomes an
 * [[UnresolvedRecursiveReference]]; otherwise the first entry of `cteRelations` whose name
 * matches is substituted; otherwise the relation is left untouched. CTE definitions of nested
 * [[With]] nodes and the plans of subquery expressions are rewritten recursively.
 *
 * @param plan the plan to rewrite
 * @param cteRelations already processed CTE definitions, as (name, plan) pairs
 * @param recursiveTableName name of the CTE currently being defined, if recursion is allowed
 * @return the rewritten plan and a flag telling whether a recursive reference was inserted
 *         anywhere in it (including nested CTE definitions and subqueries)
 */
def substituteCTE(
    plan: LogicalPlan,
    cteRelations: Seq[(String, LogicalPlan)],
    recursiveTableName: Option[String]): (LogicalPlan, Boolean) = {
  // Set from inside the rewrite rules below, which is why it has to be a var.
  var sawRecursiveReference = false

  val rewritten = plan resolveOperatorsDown {
    case relation: UnresolvedRelation =>
      val tableName = relation.tableIdentifier.table

      recursiveTableName.find(resolver(_, tableName)) match {
        case Some(recursiveName) =>
          // The recursive name takes precedence over any previously defined CTE.
          sawRecursiveReference = true
          UnresolvedRecursiveReference(recursiveName)
        case None =>
          cteRelations
            .collectFirst { case (cteName, ctePlan) if resolver(cteName, tableName) => ctePlan }
            .getOrElse(relation)
      }

    case nested @ With(_, nestedRelations, _) =>
      // Only recursive references are marked inside nested CTE definitions; no outer CTE
      // relations are passed down for them.
      val rewrittenRelations = nestedRelations.map {
        case (cteName, alias @ SubqueryAlias(_, definition)) =>
          val (newDefinition, foundInDefinition) =
            substituteCTE(definition, Seq.empty, recursiveTableName)
          sawRecursiveReference |= foundInDefinition
          (cteName, alias.copy(child = newDefinition))
      }
      nested.copy(cteRelations = rewrittenRelations)

    case operator =>
      // This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
      operator transformExpressions {
        case subquery: SubqueryExpression =>
          val (newSubqueryPlan, foundInSubquery) =
            substituteCTE(subquery.plan, cteRelations, recursiveTableName)
          sawRecursiveReference |= foundInSubquery
          subquery.withNewPlan(newSubqueryPlan)
      }
  }

  (rewritten, sawRecursiveReference)
}
288+
/**
 * Turns the analyzed definition of a recursive CTE into a [[RecursiveTable]] node.
 *
 * `plan` must be a [[SubqueryAlias]] named `recursiveTableName` whose child is a [[Union]].
 * Nested unions are flattened and the resulting terms are split into anchor terms (no reference
 * to the recursive table) and recursive terms. Each recursive term is validated before the
 * [[RecursiveTable]] is built. If no term references the recursive table, the alias is returned
 * with the flattened union as its child.
 *
 * Throws an `AnalysisException` when
 *  - `plan` does not have the shape described above,
 *  - there are recursive terms but no anchor term,
 *  - a recursive reference sits in a position where it is not allowed (see the error message
 *    built in `traversePlanAndCheck`), or
 *  - a recursive term references the recursive table more than once.
 *
 * @param plan analyzed CTE definition
 * @param recursiveTableName name of the recursive CTE
 * @return the plan to register for the CTE
 */
def insertRecursiveTable(plan: LogicalPlan, recursiveTableName: String): LogicalPlan =
  plan match {
    case sa @ SubqueryAlias(alias, union: Union) if alias.identifier == recursiveTableName =>
      // Flattens directly nested unions into a single list of terms.
      def combineUnions(u: Union): Seq[LogicalPlan] = u.children.flatMap {
        case nested: Union => combineUnions(nested)
        case term => Seq(term)
      }

      // Plans of all subquery expressions held directly by `node` (not by its children).
      def subqueryPlans(node: LogicalPlan): Seq[LogicalPlan] =
        node.expressions.flatMap(_.collect { case se: SubqueryExpression => se.plan })

      // True if `term` references the recursive table anywhere, including inside subquery
      // expressions. Subqueries have to be inspected as well: they are not children of the
      // operator, so a plain node traversal would classify such a term as an anchor term and
      // the subquery check below would never see it.
      def referencesRecursiveTable(term: LogicalPlan): Boolean = term.find {
        case UnresolvedRecursiveReference(ref) => ref == recursiveTableName
        case node => subqueryPlans(node).exists(referencesRecursiveTable)
      }.isDefined

      val combinedTerms = combineUnions(union)
      val (recursiveTerms, anchorTerms) = combinedTerms.partition(referencesRecursiveTable)

      if (recursiveTerms.nonEmpty) {
        if (anchorTerms.isEmpty) {
          throw new AnalysisException("There should be at least 1 anchor term defined in the " +
            s"recursive query ${recursiveTableName}")
        }

        // Returns the number of recursive references found among the operators of `plan` and
        // throws if one of them is found where `isRecursiveReferenceAllowed` is false.
        def traversePlanAndCheck(
            plan: LogicalPlan,
            isRecursiveReferenceAllowed: Boolean = true): Int = {
          // A reference inside a subquery expression is never allowed, whatever operator holds
          // the expression. This is a pure check, so the plans are only visited, not rewritten.
          subqueryPlans(plan).foreach(traversePlanAndCheck(_, isRecursiveReferenceAllowed = false))

          plan match {
            case UnresolvedRecursiveReference(ref) if ref == recursiveTableName =>
              if (!isRecursiveReferenceAllowed) {
                throw new AnalysisException(s"Recursive reference ${recursiveTableName} " +
                  "cannot be used here. This can be caused by using it in a different join " +
                  "than inner or left outer or right outer, using it on inner side of an " +
                  "outer join, using it with aggregate or distinct, using it in a subquery " +
                  "or using it multiple times in a recursive term.")
              }
              1
            case Join(left, right, joinType, _, _) =>
              // A reference may only appear on the preserved side(s) of a join.
              val (leftAllowed, rightAllowed) = joinType match {
                case Inner => (isRecursiveReferenceAllowed, isRecursiveReferenceAllowed)
                case LeftOuter => (isRecursiveReferenceAllowed, false)
                case RightOuter => (false, isRecursiveReferenceAllowed)
                case _ => (false, false)
              }
              traversePlanAndCheck(left, leftAllowed) + traversePlanAndCheck(right, rightAllowed)
            case Aggregate(_, _, child) =>
              traversePlanAndCheck(child, isRecursiveReferenceAllowed = false)
            case Distinct(child) =>
              traversePlanAndCheck(child, isRecursiveReferenceAllowed = false)
            case other =>
              other.children.map(traversePlanAndCheck(_, isRecursiveReferenceAllowed)).sum
          }
        }

        recursiveTerms.foreach { recursiveTerm =>
          if (traversePlanAndCheck(recursiveTerm) > 1) {
            throw new AnalysisException(s"Recursive reference ${recursiveTableName} cannot " +
              "be used multiple times in a recursive term")
          }
        }

        RecursiveTable(
          recursiveTableName,
          sa.copy(child = Union(anchorTerms)),
          Union(recursiveTerms),
          None)
      } else {
        // No term is recursive: keep the original alias and only flatten the union.
        sa.copy(child = Union(combinedTerms))
      }

    case _ =>
      throw new AnalysisException(s"Recursive query ${recursiveTableName} should contain " +
        "UNION ALL statements only. This can also be caused by ORDER BY or LIMIT keywords " +
        "used on result of UNION ALL.")
  }
349366 }
350- }
351367
352368 /**
353369 * Substitute child plan with WindowSpecDefinitions.
@@ -1656,7 +1672,7 @@ class Analyzer(
16561672 case RecursiveReference (name, _) =>
16571673 throw new AnalysisException (s " Recursive reference ${name} can't be used in an " +
16581674 " aggregate" )
1659- case RecursiveTable (_, _, recursiveTerm) =>
1675+ case RecursiveTable (_, _, recursiveTerm, _ ) =>
16601676 case o => o.children.map(traversePlanAndCheck)
16611677 }
16621678
0 commit comments