
Commit 3e1456c

marmbrus authored and AndreSchumacher committed
WIP: Directly serialize catalyst attributes.
1 parent f7aeba3 commit 3e1456c

5 files changed: +110 −23 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala

Lines changed: 75 additions & 4 deletions
@@ -19,26 +19,86 @@ package org.apache.spark.sql.catalyst.types
 
 import java.sql.Timestamp
 
+import scala.util.parsing.combinator.RegexParsers
+
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror}
 
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
 import org.apache.spark.util.Utils
 
+/**
+ *
+ */
+object DataType extends RegexParsers {
+  protected lazy val primitiveType: Parser[DataType] =
+    "StringType" ^^^ StringType |
+    "FloatType" ^^^ FloatType |
+    "IntegerType" ^^^ IntegerType |
+    "ByteType" ^^^ ByteType |
+    "ShortType" ^^^ ShortType |
+    "DoubleType" ^^^ DoubleType |
+    "LongType" ^^^ LongType |
+    "BinaryType" ^^^ BinaryType |
+    "BooleanType" ^^^ BooleanType |
+    "DecimalType" ^^^ DecimalType |
+    "TimestampType" ^^^ TimestampType
+
+  protected lazy val arrayType: Parser[DataType] =
+    "ArrayType" ~> "(" ~> dataType <~ ")" ^^ ArrayType
+
+  protected lazy val mapType: Parser[DataType] =
+    "MapType" ~> "(" ~> dataType ~ "," ~ dataType <~ ")" ^^ {
+      case t1 ~ _ ~ t2 => MapType(t1, t2)
+    }
+
+  protected lazy val structField: Parser[StructField] =
+    ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ {
+      case name ~ tpe ~ nullable =>
+        StructField(name, tpe, nullable = nullable)
+    }
+
+  protected lazy val boolVal: Parser[Boolean] =
+    "true" ^^^ true |
+    "false" ^^^ false
+
+
+  protected lazy val structType: Parser[DataType] =
+    "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ {
+      case fields => new StructType(fields)
+    }
+
+  protected lazy val dataType: Parser[DataType] =
+    arrayType |
+    mapType |
+    structType |
+    primitiveType
+
+  /**
+   * Parses a string representation of a DataType.
+   *
+   * TODO: Generate parser as pickler...
+   */
+  def apply(asString: String): DataType = parseAll(dataType, asString) match {
+    case Success(result, _) => result
+    case failure: NoSuccess => sys.error(s"Unsupported dataType: $asString, $failure")
+  }
+}
+
 abstract class DataType {
   /** Matches any expression that evaluates to this DataType */
   def unapply(a: Expression): Boolean = a match {
     case e: Expression if e.dataType == this => true
     case _ => false
   }
 
-  def isPrimitive(): Boolean = false
+  def isPrimitive: Boolean = false
 }
 
 case object NullType extends DataType
 
 trait PrimitiveType extends DataType {
-  override def isPrimitive() = true
+  override def isPrimitive = true
 }
 
 abstract class NativeType extends DataType {
@@ -167,6 +227,17 @@ case object FloatType extends FractionalType {
 case class ArrayType(elementType: DataType) extends DataType
 
 case class StructField(name: String, dataType: DataType, nullable: Boolean)
-case class StructType(fields: Seq[StructField]) extends DataType
+
+object StructType {
+  def fromAttributes(attributes: Seq[Attribute]): StructType = {
+    StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable)))
+  }
+
+  //def apply(fields: Seq[StructField]) = new StructType(fields.toIndexedSeq)
+}
+
+case class StructType(fields: Seq[StructField]) extends DataType {
+  def toAttributes = fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)())
+}
 
 case class MapType(keyType: DataType, valueType: DataType) extends DataType
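
The new DataType parser is meant to read back the default case-class toString of a schema, which is what the Parquet write path below stores in the job configuration. A minimal round-trip sketch (not part of the commit; the example schema and variable names are made up):

    import org.apache.spark.sql.catalyst.types._

    // Build a small schema and serialize it via the default case-class toString.
    val schema = StructType(Seq(
      StructField("id", IntegerType, nullable = false),
      StructField("name", StringType, nullable = true)))
    val asString = schema.toString
    // e.g. "StructType(List(StructField(id,IntegerType,false), StructField(name,StringType,true)))"

    // Parse it back with the combinator parser defined above.
    val parsed = DataType(asString)
    assert(parsed == schema)   // case-class equality should hold after the round trip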

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala

Lines changed: 3 additions & 2 deletions
@@ -36,6 +36,7 @@ import parquet.schema.MessageType
 import org.apache.spark.{Logging, SerializableWritable, SparkContext, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row}
+import org.apache.spark.sql.catalyst.types.StructType
 import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
 
 /**
@@ -167,7 +168,7 @@ case class InsertIntoParquetTable(
     val job = new Job(sc.hadoopConfiguration)
 
     val writeSupport =
-      if (child.output.map(_.dataType).forall(_.isPrimitive())) {
+      if (child.output.map(_.dataType).forall(_.isPrimitive)) {
         logger.debug("Initializing MutableRowWriteSupport")
         classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport]
       } else {
@@ -178,7 +179,7 @@ case class InsertIntoParquetTable(
 
     // TODO: move that to function in object
     val conf = ContextUtil.getConfiguration(job)
-    conf.set(RowWriteSupport.PARQUET_ROW_SCHEMA, relation.parquetSchema.toString)
+    conf.set(RowWriteSupport.PARQUET_ROW_SCHEMA, StructType.fromAttributes(relation.output).toString)
 
     val fspath = new Path(relation.path)
     val fs = fspath.getFileSystem(conf)

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala

Lines changed: 11 additions & 15 deletions
@@ -82,39 +82,35 @@ private[parquet] object RowReadSupport {
  * A `parquet.hadoop.api.WriteSupport` for Row ojects.
  */
 private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
-  def setSchema(schema: MessageType, configuration: Configuration) {
-    // for testing
-    this.schema = schema
-    // TODO: could use Attributes themselves instead of Parquet schema?
+
+
+  def setSchema(schema: Seq[Attribute], configuration: Configuration) {
     configuration.set(
       RowWriteSupport.PARQUET_ROW_SCHEMA,
-      schema.toString)
+      StructType.fromAttributes(schema).toString)
     configuration.set(
       ParquetOutputFormat.WRITER_VERSION,
       ParquetProperties.WriterVersion.PARQUET_1_0.toString)
   }
 
-  def getSchema(configuration: Configuration): MessageType = {
-    MessageTypeParser.parseMessageType(configuration.get(RowWriteSupport.PARQUET_ROW_SCHEMA))
-  }
-
-  private[parquet] var schema: MessageType = null
   private[parquet] var writer: RecordConsumer = null
   private[parquet] var attributes: Seq[Attribute] = null
 
   override def init(configuration: Configuration): WriteSupport.WriteContext = {
-    schema = if (schema == null) getSchema(configuration) else schema
-    attributes = ParquetTypesConverter.convertToAttributes(schema)
-    log.debug(s"write support initialized for requested schema $schema")
+    attributes = DataType(configuration.get(RowWriteSupport.PARQUET_ROW_SCHEMA)) match {
+      case s: StructType => s.toAttributes
+      case other => sys.error(s"Can convert $attributes to row")
+    }
+    log.debug(s"write support initialized for requested schema $attributes")
     ParquetRelation.enableLogForwarding()
     new WriteSupport.WriteContext(
-      schema,
+      ParquetTypesConverter.convertFromAttributes(attributes),
       new java.util.HashMap[java.lang.String, java.lang.String]())
   }
 
   override def prepareForWrite(recordConsumer: RecordConsumer): Unit = {
     writer = recordConsumer
-    log.debug(s"preparing for write with schema $schema")
+    log.debug(s"preparing for write with schema $attributes")
   }
 
   override def write(record: Row): Unit = {
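
Together with the change in ParquetTableOperations.scala above, the schema now travels through the Hadoop Configuration as a catalyst schema string instead of a Parquet MessageType. A rough sketch of that round trip, using the key and helpers from the diff (the standalone helper functions and the error message are illustrative only, and assume access from within the parquet package):

    import org.apache.hadoop.conf.Configuration
    import org.apache.spark.sql.catalyst.expressions.Attribute
    import org.apache.spark.sql.catalyst.types.{DataType, StructType}

    // Writer side (cf. InsertIntoParquetTable / setSchema): store the schema string in the conf.
    def storeSchema(attributes: Seq[Attribute], conf: Configuration): Unit =
      conf.set(RowWriteSupport.PARQUET_ROW_SCHEMA,
        StructType.fromAttributes(attributes).toString)

    // Reader side (cf. init above): parse the string back into catalyst attributes.
    def readSchema(conf: Configuration): Seq[Attribute] =
      DataType(conf.get(RowWriteSupport.PARQUET_ROW_SCHEMA)) match {
        case s: StructType => s.toAttributes
        case other => sys.error(s"Expected a StructType, found $other")
      }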

sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala

Lines changed: 18 additions & 1 deletion
@@ -67,7 +67,17 @@ case class Nested(i: Int, s: String)
 
 case class Data(array: Seq[Int], nested: Nested)
 
-class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
+case class AllDataTypes(
+    stringField: String,
+    intField: Int,
+    longField: Long,
+    floatField: Float,
+    doubleField: Double,
+    shortField: Short,
+    byteField: Byte,
+    booleanField: Boolean)
+
+class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll {
   import TestData._
   TestData // Load test data tables.
 
@@ -100,6 +110,13 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
     // here we should also unregister the table??
   }
 
+  test("Read/Write All Types") {
+    val data = AllDataTypes("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true)
+    val tempDir = getTempFilePath("parquetTest").getCanonicalPath
+    sparkContext.parallelize(data :: Nil).saveAsParquetFile(tempDir)
+    assert(parquetFile(tempDir).collect().head === data)
+  }
+
   test("self-join parquet files") {
     val x = ParquetTestData.testData.as('x)
     val y = ParquetTestData.testData.as('y)

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala

Lines changed: 3 additions & 1 deletion
@@ -208,7 +208,9 @@ object HiveMetastoreTypes extends RegexParsers {
     }
 
   protected lazy val structType: Parser[DataType] =
-    "struct" ~> "<" ~> repsep(structField,",") <~ ">" ^^ StructType
+    "struct" ~> "<" ~> repsep(structField,",") <~ ">" ^^ {
+      case fields => new StructType(fields)
+    }
 
   protected lazy val dataType: Parser[DataType] =
     arrayType |
