Skip to content

Commit 0904fc9

Browse files
committed
Support aliases for table value functions
1 parent c0189ab commit 0904fc9

File tree

8 files changed

+122
-18
lines changed

8 files changed

+122
-18
lines changed

sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -472,15 +472,23 @@ identifierComment
472472
;
473473

474474
relationPrimary
475-
: tableIdentifier sample? (AS? strictIdentifier)? #tableName
476-
| '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery
477-
| '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation
478-
| inlineTable #inlineTableDefault2
479-
| identifier '(' (expression (',' expression)*)? ')' #tableValuedFunction
475+
: tableIdentifier sample? (AS? strictIdentifier)? #tableName
476+
| '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery
477+
| '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation
478+
| inlineTable #inlineTableDefault2
479+
| functionTable #tableValuedFunction
480480
;
481481

482482
inlineTable
483-
: VALUES expression (',' expression)* (AS? identifier identifierList?)?
483+
: VALUES expression (',' expression)* tableAlias
484+
;
485+
486+
functionTable
487+
: identifier '(' (expression (',' expression)*)? ')' tableAlias
488+
;
489+
490+
tableAlias
491+
: (AS? strictIdentifier identifierList?)?
484492
;
485493

486494
rowFormat

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.analysis
1919

2020
import java.util.Locale
2121

22-
import org.apache.spark.sql.catalyst.expressions.Expression
23-
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range}
22+
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression}
23+
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Range}
2424
import org.apache.spark.sql.catalyst.rules._
2525
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}
2626

@@ -105,7 +105,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
105105

106106
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
107107
case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
108-
builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match {
108+
val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match {
109109
case Some(tvf) =>
110110
val resolved = tvf.flatMap { case (argList, resolver) =>
111111
argList.implicitCast(u.functionArgs) match {
@@ -125,5 +125,21 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
125125
case _ =>
126126
u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function")
127127
}
128+
129+
// If alias names assigned, add `Project` with the aliases
130+
if (u.outputNames.nonEmpty) {
131+
val outputAttrs = resolvedFunc.output
132+
// Checks if the number of the aliases is equal to expected one
133+
if (u.outputNames.size != outputAttrs.size) {
134+
u.failAnalysis(s"expected ${outputAttrs.size} columns but " +
135+
s"found ${u.outputNames.size} columns")
136+
}
137+
val aliases = outputAttrs.zip(u.outputNames).map {
138+
case (attr, name) => Alias(attr, name)()
139+
}
140+
Project(aliases, resolvedFunc)
141+
} else {
142+
resolvedFunc
143+
}
128144
}
129145
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,16 @@ case class UnresolvedInlineTable(
6666
/**
6767
* A table-valued function, e.g.
6868
* {{{
69-
* select * from range(10);
69+
* select id from range(10);
70+
*
71+
* // Assign alias names
72+
* select t.a from range(10) t(a);
7073
* }}}
7174
*/
72-
case class UnresolvedTableValuedFunction(functionName: String, functionArgs: Seq[Expression])
75+
case class UnresolvedTableValuedFunction(
76+
functionName: String,
77+
functionArgs: Seq[Expression],
78+
outputNames: Seq[String])
7379
extends LeafNode {
7480

7581
override def output: Seq[Attribute] = Nil

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
687687
*/
688688
override def visitTableValuedFunction(ctx: TableValuedFunctionContext)
689689
: LogicalPlan = withOrigin(ctx) {
690-
UnresolvedTableValuedFunction(ctx.identifier.getText, ctx.expression.asScala.map(expression))
690+
val func = ctx.functionTable
691+
val aliases = if (func.tableAlias.identifierList != null) {
692+
visitIdentifierList(func.tableAlias.identifierList)
693+
} else {
694+
Seq.empty
695+
}
696+
697+
val tvf = UnresolvedTableValuedFunction(
698+
func.identifier.getText, func.expression.asScala.map(expression), aliases)
699+
tvf.optionalMap(func.tableAlias.strictIdentifier)(aliasPlan)
691700
}
692701

693702
/**
@@ -705,14 +714,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
705714
}
706715
}
707716

708-
val aliases = if (ctx.identifierList != null) {
709-
visitIdentifierList(ctx.identifierList)
717+
val aliases = if (ctx.tableAlias.identifierList != null) {
718+
visitIdentifierList(ctx.tableAlias.identifierList)
710719
} else {
711720
Seq.tabulate(rows.head.size)(i => s"col${i + 1}")
712721
}
713722

714723
val table = UnresolvedInlineTable(aliases, rows)
715-
table.optionalMap(ctx.identifier)(aliasPlan)
724+
table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan)
716725
}
717726

718727
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier
2525
import org.apache.spark.sql.catalyst.dsl.expressions._
2626
import org.apache.spark.sql.catalyst.dsl.plans._
2727
import org.apache.spark.sql.catalyst.expressions._
28-
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
2928
import org.apache.spark.sql.catalyst.plans.Cross
3029
import org.apache.spark.sql.catalyst.plans.logical._
3130
import org.apache.spark.sql.types._
@@ -441,4 +440,17 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers {
441440

442441
checkAnalysis(SubqueryAlias("tbl", testRelation).as("tbl2"), testRelation)
443442
}
443+
444+
test("SPARK-20311 range(N) as alias") {
445+
def rangeWithAliases(args: Seq[Int], outputNames: Seq[String]): LogicalPlan = {
446+
SubqueryAlias("t", UnresolvedTableValuedFunction("range", args.map(Literal(_)), outputNames))
447+
.select(star())
448+
}
449+
assertAnalysisSuccess(rangeWithAliases(3 :: Nil, "a" :: Nil))
450+
assertAnalysisSuccess(rangeWithAliases(1 :: 4 :: Nil, "b" :: Nil))
451+
assertAnalysisSuccess(rangeWithAliases(2 :: 6 :: 2 :: Nil, "c" :: Nil))
452+
assertAnalysisError(
453+
rangeWithAliases(3 :: Nil, "a" :: "b" :: Nil),
454+
Seq("expected 1 columns but found 2 columns"))
455+
}
444456
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,18 @@ class PlanParserSuite extends PlanTest {
468468
test("table valued function") {
469469
assertEqual(
470470
"select * from range(2)",
471-
UnresolvedTableValuedFunction("range", Literal(2) :: Nil).select(star()))
471+
UnresolvedTableValuedFunction("range", Literal(2) :: Nil, Seq.empty).select(star()))
472+
}
473+
474+
test("SPARK-20311 range(N) as alias") {
475+
assertEqual(
476+
"select * from range(10) AS t",
477+
SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(10) :: Nil, Seq.empty))
478+
.select(star()))
479+
assertEqual(
480+
"select * from range(7) AS t(a)",
481+
SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(7) :: Nil, "a" :: Nil))
482+
.select(star()))
472483
}
473484

474485
test("inline table") {

sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,7 @@ select * from RaNgE(2);
2424

2525
-- Explain
2626
EXPLAIN select * from RaNgE(2);
27+
28+
-- cross-join table valued functions
29+
set spark.sql.crossJoin.enabled=true;
30+
EXPLAIN EXTENDED select * from range(3) cross join range(3);

sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 9
2+
-- Number of queries: 11
33

44

55
-- !query 0
@@ -103,3 +103,41 @@ struct<plan:string>
103103
-- !query 8 output
104104
== Physical Plan ==
105105
*Range (0, 2, step=1, splits=2)
106+
107+
108+
-- !query 9
109+
set spark.sql.crossJoin.enabled=true
110+
-- !query 9 schema
111+
struct<key:string,value:string>
112+
-- !query 9 output
113+
spark.sql.crossJoin.enabled true
114+
115+
116+
-- !query 10
117+
EXPLAIN EXTENDED select * from range(3) cross join range(3)
118+
-- !query 10 schema
119+
struct<plan:string>
120+
-- !query 10 output
121+
== Parsed Logical Plan ==
122+
'Project [*]
123+
+- 'Join Cross
124+
:- 'UnresolvedTableValuedFunction range, [3]
125+
+- 'UnresolvedTableValuedFunction range, [3]
126+
127+
== Analyzed Logical Plan ==
128+
id: bigint, id: bigint
129+
Project [id#xL, id#xL]
130+
+- Join Cross
131+
:- Range (0, 3, step=1, splits=None)
132+
+- Range (0, 3, step=1, splits=None)
133+
134+
== Optimized Logical Plan ==
135+
Join Cross
136+
:- Range (0, 3, step=1, splits=None)
137+
+- Range (0, 3, step=1, splits=None)
138+
139+
== Physical Plan ==
140+
BroadcastNestedLoopJoin BuildRight, Cross
141+
:- *Range (0, 3, step=1, splits=2)
142+
+- BroadcastExchange IdentityBroadcastMode
143+
+- *Range (0, 3, step=1, splits=2)

0 commit comments

Comments
 (0)