Commit 4e25fcb

Adding resolution of complex ArrayTypes
1 parent f8f8911 commit 4e25fcb

File tree

5 files changed: +167 -22 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala

Lines changed: 2 additions & 1 deletion

@@ -90,7 +90,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
       | failure("illegal character")
     )
 
-  override def identChar = letter | elem('.') | elem('_')
+  override def identChar = letter | elem('.') | elem('_') | elem('[') | elem(']')
 
   override def whitespace: Parser[Any] = rep(
     whitespaceChar
@@ -390,6 +390,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
     FALSE ^^^ Literal(false, BooleanType) |
     cast |
     "(" ~> expression <~ ")" |
+    "[" ~> literal <~ "]" |
    function |
    "-" ~> literal ^^ UnaryMinus |
    ident ^^ UnresolvedAttribute |
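Note: the two hunks work together. Widening identChar means a bracketed path such as contacts[1].name now scans as a single identifier token, to be picked apart later during resolution, while the new `"[" ~> literal <~ "]"` alternative accepts a bare bracketed literal as a primary expression. A minimal sketch of the lexer effect, using plain scala-parser-combinators rather than Spark's SqlParser (BracketIdentDemo is an illustrative name, not part of this commit):

import scala.util.parsing.combinator.lexical.StdLexical
import scala.util.parsing.combinator.syntactical.StandardTokenParsers

object BracketIdentDemo extends StandardTokenParsers {
  override val lexical = new StdLexical {
    // Mirrors the patched rule: '[' and ']' are now legal identifier characters,
    // so "contacts[1].name" scans as one identifier token.
    override def identChar = letter | elem('.') | elem('_') | elem('[') | elem(']')
  }

  def main(args: Array[String]): Unit = {
    val scanner = new lexical.Scanner("contacts[1].name")
    println(scanner.first) // identifier contacts[1].name  (a single token)
  }
}

Folding the brackets into the identifier keeps the grammar itself almost unchanged and pushes all ordinal handling into LogicalPlan.resolve, at the cost of letting some malformed names past the lexer.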

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 61 additions & 5 deletions

@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.types.StructType
+import org.apache.spark.sql.catalyst.types.{DataType, ArrayType, StructType}
 import org.apache.spark.sql.catalyst.trees
 
 abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
@@ -54,9 +54,41 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
   /**
    * Optionally resolves the given string to a
    * [[catalyst.expressions.NamedExpression NamedExpression]]. The attribute is expressed as
-   * as string in the following form: `[scope].AttributeName.[nested].[fields]...`.
+   * a string in the following form: `[scope].AttributeName.[nested].[fields]...`. Fields
+   * can contain ordinal expressions, such as `field[i][j][k]...`.
    */
   def resolve(name: String): Option[NamedExpression] = {
+    def expandFunc(expType: (Expression, DataType), field: String): (Expression, DataType) = {
+      val (exp, t) = expType
+      val ordinalRegExp = """(\[(\d+)\])""".r
+      val fieldName = if (field.matches("\\w*(\\[\\d\\])+")) {
+        field.substring(0, field.indexOf("["))
+      } else {
+        field
+      }
+      t match {
+        case ArrayType(elementType) =>
+          val ordinals = ordinalRegExp.findAllIn(field).matchData.map(_.group(2))
+          (ordinals.foldLeft(exp)((v1: Expression, v2: String) => GetItem(v1, Literal(v2.toInt))), elementType)
+        case StructType(fields) =>
+          // Note: this only works if we are not on the top level!
+          val structField = fields.find(_.name == fieldName)
+          if (!structField.isDefined) {
+            throw new TreeNodeException(
+              this, s"Trying to resolve Attribute but field ${fieldName} is not defined")
+          }
+          structField.get.dataType match {
+            case ArrayType(elementType) =>
+              val ordinals = ordinalRegExp.findAllIn(field).matchData.map(_.group(2))
+              (ordinals.foldLeft(GetField(exp, fieldName).asInstanceOf[Expression])((v1: Expression, v2: String) => GetItem(v1, Literal(v2.toInt))), elementType)
+            case _ =>
+              (GetField(exp, fieldName), structField.get.dataType)
+          }
+        case _ =>
+          expType
+      }
+    }
+
     val parts = name.split("\\.")
     // Collect all attributes that are output by this node's children where either the first part
     // matches the name or where the first part matches the scope and the second part matches the
@@ -67,16 +99,40 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
       val remainingParts =
         if (option.qualifiers.contains(parts.head) && parts.size > 1) parts.drop(1) else parts
       if (option.name == remainingParts.head) (option, remainingParts.tail.toList) :: Nil else Nil
+      // TODO from rebase!
+      /*val remainingParts = if (option.qualifiers contains parts.head) parts.drop(1) else parts
+      val relevantRemaining =
+        if (remainingParts.head.matches("\\w*\\[(\\d+)\\]")) { // array field name
+          remainingParts.head.substring(0, remainingParts.head.indexOf("["))
+        } else {
+          remainingParts.head
+        }
+      if (option.name == relevantRemaining) (option, remainingParts.tail.toList) :: Nil else Nil*/
     }
 
     options.distinct match {
-      case (a, Nil) :: Nil => Some(a) // One match, no nested fields, use it.
+      case (a, Nil) :: Nil => {
+        a.dataType match {
+          case ArrayType(elementType) =>
+            val expression = expandFunc((a: Expression, a.dataType), name)._1
+            Some(Alias(expression, name)())
+          case _ => Some(a)
+        }
+      } // One match, no nested fields, use it.
       // One match, but we also need to extract the requested nested field.
       case (a, nestedFields) :: Nil =>
        a.dataType match {
          case StructType(fields) =>
-            Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
-          case _ => None // Don't know how to resolve these field references
+            // This is for compatibility with earlier code. TODO: why only nestedFields and not parts?
+            if ((parts(0) :: nestedFields).forall(!_.matches("\\w*\\[\\d+\\]+"))) { // not nested arrays, only fields
+              Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
+            } else {
+              val expression = parts.foldLeft((a: Expression, a.dataType))(expandFunc)._1
+              Some(Alias(expression, nestedFields.last)())
+            }
+          case _ =>
+            val expression = parts.foldLeft((a: Expression, a.dataType))(expandFunc)._1
+            Some(Alias(expression, nestedFields.last)())
        }
      case Nil => None // No matches.
      case ambiguousReferences =>
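Note: the heart of the change is expandFunc's fold. Each dotted part of the attribute name contributes an optional GetField for the named portion (when below the top level) and one GetItem per `[i]` ordinal. A toy re-implementation to show just the fold (Attr, GetField, and GetItem here are plain stand-in case classes, not catalyst's expression types, and no datatypes are tracked):

object ExpandSketch {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class GetField(child: Expr, field: String) extends Expr
  case class GetItem(child: Expr, ordinal: Int) extends Expr

  // One step of the fold: strip the ordinals off the part, wrap a GetField
  // for the named portion, then a GetItem per ordinal, innermost first.
  def expand(base: Expr, part: String): Expr = {
    val ordinalRegExp = """(\[(\d+)\])""".r
    val fieldName = if (part.contains("[")) part.substring(0, part.indexOf("[")) else part
    val withField: Expr = if (fieldName.isEmpty) base else GetField(base, fieldName)
    val ordinals = ordinalRegExp.findAllIn(part).matchData.map(_.group(2).toInt)
    ordinals.foldLeft(withField)((e, i) => GetItem(e, i))
  }

  def main(args: Array[String]): Unit = {
    // In the real resolve() the first part names the attribute itself, and its
    // ordinals apply directly; start from that point for "booleanNumberPairs[0]".
    val base: Expr = GetItem(Attr("booleanNumberPairs"), 0)
    val resolved = Seq("value[1]", "truth").foldLeft(base)(expand)
    println(resolved)
    // GetField(GetItem(GetField(GetItem(Attr(booleanNumberPairs),0),value),1),truth)
  }
}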

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala

Lines changed: 5 additions & 1 deletion

@@ -29,11 +29,15 @@ abstract class DataType {
     case e: Expression if e.dataType == this => true
     case _ => false
   }
+
+  def isPrimitive(): Boolean = false
 }
 
 case object NullType extends DataType
 
-trait PrimitiveType
+trait PrimitiveType extends DataType {
+  override def isPrimitive() = true
+}
 
 abstract class NativeType extends DataType {
   type JvmType
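Note: PrimitiveType previously did not even extend DataType, so there was no uniform way to ask an arbitrary type whether it is primitive. The default-plus-override pattern in miniature (all *Demo names below are made up for illustration):

// A conservative default on the base class, flipped to true by the marker trait.
abstract class DataTypeDemo { def isPrimitive(): Boolean = false }
trait PrimitiveTypeDemo extends DataTypeDemo { override def isPrimitive() = true }

case object NullTypeDemo extends DataTypeDemo
case object DemoIntType extends PrimitiveTypeDemo

object IsPrimitiveDemo extends App {
  println(NullTypeDemo.isPrimitive()) // false
  println(DemoIntType.isPrimitive()) // true
}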

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala

Lines changed: 42 additions & 1 deletion

@@ -139,7 +139,7 @@ private[sql] object ParquetTestData {
       |optional group longs {
       |repeated int64 values;
       |}
-      |required group booleanNumberPairs {
+      |repeated group entries {
       |required double value;
       |optional boolean truth;
       |}
@@ -153,8 +153,23 @@ private[sql] object ParquetTestData {
       |}
     """.stripMargin
 
+  val testNestedSchema3 =
+    """
+      |message TestNested3 {
+      |required int32 x;
+      |repeated group booleanNumberPairs {
+      |required int32 key;
+      |repeated group value {
+      |required double nestedValue;
+      |optional boolean truth;
+      |}
+      |}
+      |}
+    """.stripMargin
+
   val testNestedDir1 = Utils.createTempDir()
   val testNestedDir2 = Utils.createTempDir()
+  val testNestedDir3 = Utils.createTempDir()
 
   lazy val testNestedData1 = new ParquetRelation(testNestedDir1.toURI.toString)
   lazy val testNestedData2 = new ParquetRelation(testNestedDir2.toURI.toString)
@@ -285,6 +300,32 @@ private[sql] object ParquetTestData {
     writer.close()
   }
 
+  def writeNestedFile3() {
+    testNestedDir3.delete()
+    val path: Path = testNestedDir3
+    val schema: MessageType = MessageTypeParser.parseMessageType(testNestedSchema3)
+
+    val r1 = new SimpleGroup(schema)
+    r1.add(0, 1)
+    val g1 = r1.addGroup(1)
+    g1.add(0, 1)
+    val ng1 = g1.addGroup(1)
+    ng1.add(0, 1.5)
+    ng1.add(1, false)
+    val ng2 = g1.addGroup(1)
+    ng2.add(0, 2.5)
+    ng2.add(1, true)
+    val g2 = r1.addGroup(1)
+    g2.add(0, 2)
+    val ng3 = g2.addGroup(1)
+    ng3.add(0, 3.5)
+    ng3.add(1, false)
+
+    val writeSupport = new TestGroupWriteSupport(schema)
+    val writer = new ParquetWriter[Group](path, writeSupport)
+    writer.write(r1)
+    writer.close()
+  }
 
  def readNestedFile(path: File, schemaString: String): Unit = {
    val configuration = new Configuration()
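For orientation, the single record written by writeNestedFile3 has the logical shape below; the positional add(i, ...) calls map to schema fields in declaration order. Sketched as plain Scala collections (NestedFile3Shape is an illustrative name; the real data are Parquet groups):

object NestedFile3Shape extends App {
  // x = 1, plus two booleanNumberPairs entries; the first entry's "value"
  // array has two elements, the second's has one.
  val record = Map(
    "x" -> 1,
    "booleanNumberPairs" -> Seq(
      Map("key" -> 1,
          "value" -> Seq(Map("nestedValue" -> 1.5, "truth" -> false),
                         Map("nestedValue" -> 2.5, "truth" -> true))),
      Map("key" -> 2,
          "value" -> Seq(Map("nestedValue" -> 3.5, "truth" -> false)))))
  println(record)
}

This is exactly the shape the "nested structs" test below indexes into: booleanNumberPairs[0].value[1].truth is true, while booleanNumberPairs[1].value[0].truth is false.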

sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala

Lines changed: 57 additions & 14 deletions

@@ -34,6 +34,11 @@ import org.apache.spark.sql.SchemaRDD
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.types.IntegerType
 import org.apache.spark.util.Utils
+import org.apache.spark.sql.catalyst.types.{StringType, IntegerType, DataType}
+import org.apache.spark.sql.{parquet, SchemaRDD}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import scala.Tuple2
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 
 // Implicits
 import org.apache.spark.sql.test.TestSQLContext._
@@ -432,9 +437,9 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
     assert(result(0)(2)(0) === (1.toLong << 32))
     assert(result(0)(2)(1) === (1.toLong << 33))
     assert(result(0)(2)(2) === (1.toLong << 34))
-    assert(result(0)(3).size === 2)
-    assert(result(0)(3)(0) === 2.5)
-    assert(result(0)(3)(1) === false)
+    assert(result(0)(3)(0).size === 2)
+    assert(result(0)(3)(0)(0) === 2.5)
+    assert(result(0)(3)(0)(1) === false)
     assert(result(0)(4).size === 2)
     assert(result(0)(4)(0).size === 2)
     assert(result(0)(4)(1).size === 1)
@@ -452,23 +457,61 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
     assert(tmp(0)(0) === "Julien Le Dem")
   }
 
+  test("Projection in addressbook") {
+    implicit def anyToRow(value: Any): Row = value.asInstanceOf[Row]
+    val data = TestSQLContext
+      .parquetFile(ParquetTestData.testNestedDir1.toString)
+      .toSchemaRDD
+    data.registerAsTable("data")
+    val tmp = sql("SELECT owner, contacts[1].name FROM data").collect()
+    assert(tmp.size === 2)
+    assert(tmp(0).size === 2)
+    assert(tmp(0)(0) === "Julien Le Dem")
+    assert(tmp(0)(1) === "Chris Aniszczyk")
+    assert(tmp(1)(0) === "A. Nonymous")
+    assert(tmp(1)(1) === null)
+  }
+
   test("Simple query on nested int data") {
     implicit def anyToRow(value: Any): Row = value.asInstanceOf[Row]
     val data = TestSQLContext
       .parquetFile(ParquetTestData.testNestedDir2.toString)
       .toSchemaRDD
     data.registerAsTable("data")
-    val tmp = sql("SELECT booleanNumberPairs.value, booleanNumberPairs.truth FROM data").collect()
-    assert(tmp(0)(0) === 2.5)
-    assert(tmp(0)(1) === false)
-    val result = sql("SELECT outerouter FROM data").collect()
-    // TODO: why does this not work?
-    //val result = sql("SELECT outerouter.values FROM data").collect()
-    // TODO: .. or this:
-    // val result = sql("SELECT outerouter[0] FROM data").collect()
-    assert(result(0)(0)(0)(0)(0) === 7)
-    assert(result(0)(0)(0)(1)(0) === 8)
-    assert(result(0)(0)(1)(0)(0) === 9)
+    val result1 = sql("SELECT entries[0].value FROM data").collect()
+    assert(result1.size === 1)
+    assert(result1(0).size === 1)
+    assert(result1(0)(0) === 2.5)
+    val result2 = sql("SELECT entries[0] FROM data").collect()
+    assert(result2.size === 1)
+    assert(result2(0)(0).size === 2)
+    assert(result2(0)(0)(0) === 2.5)
+    assert(result2(0)(0)(1) === false)
+    val result3 = sql("SELECT outerouter FROM data").collect()
+    assert(result3(0)(0)(0)(0)(0) === 7)
+    assert(result3(0)(0)(0)(1)(0) === 8)
+    assert(result3(0)(0)(1)(0)(0) === 9)
+  }
+
+  test("nested structs") {
+    implicit def anyToRow(value: Any): Row = value.asInstanceOf[Row]
+    ParquetTestData.writeNestedFile3()
+    val data = TestSQLContext
+      .parquetFile(ParquetTestData.testNestedDir3.toString)
+      .toSchemaRDD
+    data.registerAsTable("data")
+    val result1 = sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect()
+    assert(result1.size === 1)
+    assert(result1(0).size === 1)
+    assert(result1(0)(0) === false)
+    val result2 = sql("SELECT booleanNumberPairs[0].value[1].truth FROM data").collect()
+    assert(result2.size === 1)
+    assert(result2(0).size === 1)
+    assert(result2(0)(0) === true)
+    val result3 = sql("SELECT booleanNumberPairs[1].value[0].truth FROM data").collect()
+    assert(result3.size === 1)
+    assert(result3(0).size === 1)
+    assert(result3(0)(0) === false)
   }
 
  /**
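Two test-side details worth flagging. First, the bracket ordinals in these queries are 0-based: contacts[1] in the projection test selects the second contact. Second, nested results come back typed as Any, and each test defines a local implicit anyToRow to downcast on demand, which is what lets the assertions chain-index like result(0)(3)(0)(1). The same trick demonstrated on plain Seqs (a sketch, not Spark code; ChainIndexDemo is an illustrative name):

import scala.language.implicitConversions

object ChainIndexDemo extends App {
  type Row = Seq[Any] // stand-in for catalyst's Row, which also indexes by position
  implicit def anyToRow(value: Any): Row = value.asInstanceOf[Row]

  val result: Row = Seq(Seq(2.5, false)) // one outer row holding one nested row
  // result(0) is typed Any; the implicit converts it so (1) can be applied.
  assert(result(0)(1) == false)
  println("chained indexing works")
}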
