Skip to content

Commit 88134e7

Browse files
dongjoon-hyun authored and cloud-fan committed
[SPARK-16288][SQL] Implement inline table generating function
## What changes were proposed in this pull request?

This PR implements the `inline` table-generating function.

## How was this patch tested?

Pass the Jenkins tests with new test cases.

Author: Dongjoon Hyun <[email protected]>

Closes #13976 from dongjoon-hyun/SPARK-16288.
1 parent 54b27c1 commit 88134e7

File tree

5 files changed

+124
-36
lines changed

5 files changed

+124
-36
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ object FunctionRegistry {
165165
expression[Explode]("explode"),
166166
expression[Greatest]("greatest"),
167167
expression[If]("if"),
168+
expression[Inline]("inline"),
168169
expression[IsNaN]("isnan"),
169170
expression[IfNull]("ifnull"),
170171
expression[IsNull]("isnull"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,3 +195,38 @@ case class Explode(child: Expression) extends ExplodeBase(child, position = fals
195195
extended = "> SELECT _FUNC_(array(10,20));\n 0\t10\n 1\t20")
196196
// scalastyle:on line.size.limit
197197
case class PosExplode(child: Expression) extends ExplodeBase(child, position = true)
198+
199+
/**
 * Explodes an array of structs into a table: each struct element of the input array becomes
 * one output row whose columns are the fields of that struct.
 *
 * A null input array produces no rows. A null struct element produces one row in which every
 * column is null; because of that, the output schema is widened to nullable whenever the input
 * array type allows null elements.
 */
@ExpressionDescription(
  usage = "_FUNC_(a) - Explodes an array of structs into a table.",
  extended = "> SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b')));\n [1,a]\n [2,b]")
case class Inline(child: Expression) extends UnaryExpression with Generator with CodegenFallback {

  override def children: Seq[Expression] = child :: Nil

  /** Accepts only arrays whose element type is a struct; anything else is a type mismatch. */
  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
    case ArrayType(_: StructType, _) =>
      TypeCheckResult.TypeCheckSuccess
    case _ =>
      TypeCheckResult.TypeCheckFailure(
        s"input to function $prettyName should be array of struct type, not ${child.dataType}")
  }

  /**
   * The output schema is the struct type of the array elements. The match is intentionally
   * limited to array-of-struct inputs, which is the only shape `checkInputDataTypes` accepts.
   */
  override def elementSchema: StructType = child.dataType match {
    case ArrayType(st: StructType, false) => st
    // A null element is emitted as an all-null row, so no column can stay non-nullable.
    case ArrayType(st: StructType, true) => st.asNullable
  }

  private lazy val numFields = elementSchema.fields.length

  // Row emitted in place of a null struct element: one null per output column.
  private lazy val generatorNullRow = new GenericInternalRow(numFields)

  /**
   * Evaluates the child to an array and returns one row per element.
   *
   * @param input the row the child expression is evaluated against
   * @return no rows when the array itself is null; otherwise one row per element, with an
   *         all-null row standing in for each null element
   */
  override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
    val inputArray = child.eval(input).asInstanceOf[ArrayData]
    if (inputArray == null) {
      Nil
    } else {
      for (i <- 0 until inputArray.numElements()) yield {
        // Check for null before reading: yielding a null row would break downstream operators.
        if (inputArray.isNullAt(i)) generatorNullRow else inputArray.getStruct(i, numFields)
      }
    }
  }
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -19,53 +19,48 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.sql.catalyst.InternalRow
22-
import org.apache.spark.unsafe.types.UTF8String
22+
import org.apache.spark.sql.types._
2323

2424
class GeneratorExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
25-
private def checkTuple(actual: ExplodeBase, expected: Seq[InternalRow]): Unit = {
26-
assert(actual.eval(null).toSeq === expected)
25+
private def checkTuple(actual: Expression, expected: Seq[InternalRow]): Unit = {
26+
assert(actual.eval(null).asInstanceOf[TraversableOnce[InternalRow]].toSeq === expected)
2727
}
2828

29-
private final val int_array = Seq(1, 2, 3)
30-
private final val str_array = Seq("a", "b", "c")
29+
private final val empty_array = CreateArray(Seq.empty)
30+
private final val int_array = CreateArray(Seq(1, 2, 3).map(Literal(_)))
31+
private final val str_array = CreateArray(Seq("a", "b", "c").map(Literal(_)))
3132

3233
test("explode") {
33-
val int_correct_answer = Seq(Seq(1), Seq(2), Seq(3))
34-
val str_correct_answer = Seq(
35-
Seq(UTF8String.fromString("a")),
36-
Seq(UTF8String.fromString("b")),
37-
Seq(UTF8String.fromString("c")))
34+
val int_correct_answer = Seq(create_row(1), create_row(2), create_row(3))
35+
val str_correct_answer = Seq(create_row("a"), create_row("b"), create_row("c"))
3836

39-
checkTuple(
40-
Explode(CreateArray(Seq.empty)),
41-
Seq.empty)
37+
checkTuple(Explode(empty_array), Seq.empty)
38+
checkTuple(Explode(int_array), int_correct_answer)
39+
checkTuple(Explode(str_array), str_correct_answer)
40+
}
4241

43-
checkTuple(
44-
Explode(CreateArray(int_array.map(Literal(_)))),
45-
int_correct_answer.map(InternalRow.fromSeq(_)))
42+
test("posexplode") {
43+
val int_correct_answer = Seq(create_row(0, 1), create_row(1, 2), create_row(2, 3))
44+
val str_correct_answer = Seq(create_row(0, "a"), create_row(1, "b"), create_row(2, "c"))
4645

47-
checkTuple(
48-
Explode(CreateArray(str_array.map(Literal(_)))),
49-
str_correct_answer.map(InternalRow.fromSeq(_)))
46+
checkTuple(PosExplode(CreateArray(Seq.empty)), Seq.empty)
47+
checkTuple(PosExplode(int_array), int_correct_answer)
48+
checkTuple(PosExplode(str_array), str_correct_answer)
5049
}
5150

52-
test("posexplode") {
53-
val int_correct_answer = Seq(Seq(0, 1), Seq(1, 2), Seq(2, 3))
54-
val str_correct_answer = Seq(
55-
Seq(0, UTF8String.fromString("a")),
56-
Seq(1, UTF8String.fromString("b")),
57-
Seq(2, UTF8String.fromString("c")))
51+
test("inline") {
52+
val correct_answer = Seq(create_row(0, "a"), create_row(1, "b"), create_row(2, "c"))
5853

5954
checkTuple(
60-
PosExplode(CreateArray(Seq.empty)),
55+
Inline(Literal.create(Array(), ArrayType(new StructType().add("id", LongType)))),
6156
Seq.empty)
6257

6358
checkTuple(
64-
PosExplode(CreateArray(int_array.map(Literal(_)))),
65-
int_correct_answer.map(InternalRow.fromSeq(_)))
66-
67-
checkTuple(
68-
PosExplode(CreateArray(str_array.map(Literal(_)))),
69-
str_correct_answer.map(InternalRow.fromSeq(_)))
59+
Inline(CreateArray(Seq(
60+
CreateStruct(Seq(Literal(0), Literal("a"))),
61+
CreateStruct(Seq(Literal(1), Literal("b"))),
62+
CreateStruct(Seq(Literal(2), Literal("c")))
63+
))),
64+
correct_answer)
7065
}
7166
}

sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,4 +89,64 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
8989
exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
9090
Row(3) :: Nil)
9191
}
92+
93+
test("inline raises exception on array of null type") {
  // `array()` carries no struct element type, so analysis is expected to reject it.
  val error = intercept[AnalysisException] {
    spark.range(2).selectExpr("inline(array())")
  }
  assert(error.getMessage.contains("data type mismatch"))
}
99+
100+
test("inline with empty table") {
  // The input has no rows, so no output rows are expected.
  checkAnswer(spark.range(0).selectExpr("inline(array(struct(10, 100)))"), Seq.empty[Row])
}
105+
106+
test("inline on literal") {
  // The literal array is inlined once per input row, so its three structs appear twice.
  val rowsPerInput = Seq(Row(10, 100), Row(20, 200), Row(30, 300))
  val inlineExpr = "inline(array(struct(10, 100), struct(20, 200), struct(30, 300)))"

  checkAnswer(spark.range(2).selectExpr(inlineExpr), rowsPerInput ++ rowsPerInput)
}
112+
113+
test("inline on column") {
  val df = Seq((1, 2)).toDF("a", "b")

  // Asserts that building the projection fails analysis with a data type mismatch.
  def assertTypeMismatch(expr: String): Unit = {
    val message = intercept[AnalysisException] {
      df.selectExpr(expr)
    }.getMessage
    assert(message.contains("data type mismatch"))
  }

  checkAnswer(
    df.selectExpr("inline(array(struct(a), struct(a)))"),
    Seq(Row(1), Row(1)))

  checkAnswer(
    df.selectExpr("inline(array(struct(a, b), struct(a, b)))"),
    Seq(Row(1, 2), Row(1, 2)))

  // Spark thinks [struct<a:int>, struct<b:int>] is heterogeneous due to the name difference.
  assertTypeMismatch("inline(array(struct(a), struct(b)))")

  checkAnswer(
    df.selectExpr("inline(array(struct(a), named_struct('a', b)))"),
    Seq(Row(1), Row(2)))

  // Spark thinks [struct<a:int>, struct<col1:int>] is heterogeneous due to the name difference.
  assertTypeMismatch("inline(array(struct(a), struct(2)))")

  checkAnswer(
    df.selectExpr("inline(array(struct(a), named_struct('a', 2)))"),
    Seq(Row(1), Row(2)))

  // `*` expands to the single struct column produced by the first projection.
  checkAnswer(
    df.selectExpr("struct(a)").selectExpr("inline(array(*))"),
    Seq(Row(1)))

  // Here `*` expands to the array column itself, which is passed to `inline` directly.
  checkAnswer(
    df.selectExpr("array(struct(a), named_struct('a', b))").selectExpr("inline(*)"),
    Seq(Row(1), Row(2)))
}
92152
}

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -241,9 +241,6 @@ private[sql] class HiveSessionCatalog(
241241
"hash", "java_method", "histogram_numeric",
242242
"parse_url", "percentile", "percentile_approx", "reflect", "sentences", "stack", "str_to_map",
243243
"xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long",
244-
"xpath_number", "xpath_short", "xpath_string",
245-
246-
// table generating function
247-
"inline"
244+
"xpath_number", "xpath_short", "xpath_string"
248245
)
249246
}

0 commit comments

Comments
 (0)