
Commit 32229c7

Removing Row nested values and placing by generic types
1 parent 0ae9376 commit 32229c7

File tree

5 files changed (+136 additions, -143 deletions)


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala

Lines changed: 0 additions & 62 deletions
@@ -206,68 +206,6 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow {
   override def copy() = new GenericRow(values.clone())
 }
 
-// TODO: this is an awful lot of code duplication. If values would be covariant we could reuse
-// much of GenericRow
-class NativeRow[T](protected[catalyst] val values: Array[T]) extends Row {
-
-  /** No-arg constructor for serialization. */
-  def this() = this(null)
-
-  def this(elementType: NativeType, size: Int) =
-    this(elementType.classTag.newArray(size).asInstanceOf[Array[T]])
-
-  def iterator = values.iterator
-
-  def length = values.length
-
-  def apply(i: Int) = values(i)
-
-  def isNullAt(i: Int) = values(i) == null
-
-  def getInt(i: Int): Int = {
-    if (values(i) == null) sys.error("Failed to check null bit for primitive int value.")
-    values(i).asInstanceOf[Int]
-  }
-
-  def getLong(i: Int): Long = {
-    if (values(i) == null) sys.error("Failed to check null bit for primitive long value.")
-    values(i).asInstanceOf[Long]
-  }
-
-  def getDouble(i: Int): Double = {
-    if (values(i) == null) sys.error("Failed to check null bit for primitive double value.")
-    values(i).asInstanceOf[Double]
-  }
-
-  def getFloat(i: Int): Float = {
-    if (values(i) == null) sys.error("Failed to check null bit for primitive float value.")
-    values(i).asInstanceOf[Float]
-  }
-
-  def getBoolean(i: Int): Boolean = {
-    if (values(i) == null) sys.error("Failed to check null bit for primitive boolean value.")
-    values(i).asInstanceOf[Boolean]
-  }
-
-  def getShort(i: Int): Short = {
-    if (values(i) == null) sys.error("Failed to check null bit for primitive short value.")
-    values(i).asInstanceOf[Short]
-  }
-
-  def getByte(i: Int): Byte = {
-    if (values(i) == null) sys.error("Failed to check null bit for primitive byte value.")
-    values(i).asInstanceOf[Byte]
-  }
-
-  def getString(i: Int): String = {
-    if (values(i) == null) sys.error("Failed to check null bit for primitive String value.")
-    values(i).asInstanceOf[String]
-  }
-
-  def copy() = this
-}
-
-
 class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] {
   def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
     this(ordering.map(BindReferences.bindReference(_, inputSchema)))
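
In short, this file drops the NativeRow wrapper that held arrays of native values. For context, a minimal sketch (not part of the commit; sample values invented for illustration) of how nested field data is carried once it uses plain Scala types, matching the ArrayScalaType / StructScalaType / MapScalaType aliases the commit adds to CatalystConverter further down:

    // After this change, nested field data lives in plain Scala/JVM containers
    // rather than Row-like wrappers such as NativeRow:
    val arrayField: Array[Int]     = Array(1, 2, 3)            // ArrayType data  (ArrayScalaType)
    val structField: Seq[Any]      = Seq("name", 42)           // StructType data (StructScalaType)
    val mapField: Map[String, Int] = Map("a" -> 1, "b" -> 2)   // MapType data    (MapScalaType)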

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala

Lines changed: 2 additions & 2 deletions
@@ -50,7 +50,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
       null
     } else {
       if (child.dataType.isInstanceOf[ArrayType]) {
-        val baseValue = value.asInstanceOf[Seq[_]]
+        val baseValue = value.asInstanceOf[Array[_]]
         val o = key.asInstanceOf[Int]
         if (o >= baseValue.size || o < 0) {
           null
@@ -92,7 +92,7 @@ case class GetField(child: Expression, fieldName: String) extends UnaryExpression
   override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[StructType]
 
   override def eval(input: Row): Any = {
-    val baseValue = child.eval(input).asInstanceOf[Row]
+    val baseValue = child.eval(input).asInstanceOf[Seq[_]]
     if (baseValue == null) null else baseValue(ordinal)
   }
 
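
A small illustrative sketch (sample data invented here, not from the commit) of what the changed casts expect: GetItem over an ArrayType child now receives a Scala Array, and GetField over a StructType child receives a Seq rather than a Row:

    // GetItem-style lookup over an ArrayType value:
    val arrayValue: Any = Array("a", "b", "c")
    val elems = arrayValue.asInstanceOf[Array[_]]
    val ordinal = 2
    val item = if (ordinal >= elems.size || ordinal < 0) null else elems(ordinal)  // "c"

    // GetField-style lookup over a StructType value:
    val structValue: Any = Seq("first", "second")
    val fields = structValue.asInstanceOf[Seq[_]]
    val field = if (fields == null) null else fields(1)  // "second"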

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala

Lines changed: 24 additions & 14 deletions
@@ -23,12 +23,12 @@ import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter}
 import parquet.schema.MessageType
 
 import org.apache.spark.sql.catalyst.types._
-import org.apache.spark.sql.catalyst.expressions.{NativeRow, GenericRow, Row, Attribute}
+import org.apache.spark.sql.catalyst.expressions.{GenericRow, Row, Attribute}
 import org.apache.spark.sql.parquet.CatalystConverter.FieldType
 
 /**
- * Collection of converters of Parquet types (Group and primitive types) that
- * model arrays and maps. The convertions are partly based on the AvroParquet
+ * Collection of converters of Parquet types (group and primitive types) that
+ * model arrays and maps. The conversions are partly based on the AvroParquet
  * converters that are part of Parquet in order to be able to process these
  * types.
  *
@@ -51,7 +51,7 @@ import org.apache.spark.sql.parquet.CatalystConverter.FieldType
  * </ul>
  */
 
-private[parquet] object CatalystConverter {
+private[sql] object CatalystConverter {
   // The type internally used for fields
   type FieldType = StructField
 
@@ -63,6 +63,10 @@ private[parquet] object CatalystConverter {
   val MAP_VALUE_SCHEMA_NAME = "value"
   val MAP_SCHEMA_NAME = "map"
 
+  type ArrayScalaType[T] = Array[T]
+  type StructScalaType[T] = Seq[T]
+  type MapScalaType[K, V] = Map[K, V]
+
   protected[parquet] def createConverter(
       field: FieldType,
       fieldIndex: Int,
@@ -325,7 +329,6 @@ private[parquet] class CatalystPrimitiveRowConverter(
 private[parquet] class CatalystPrimitiveConverter(
     parent: CatalystConverter,
     fieldIndex: Int) extends PrimitiveConverter {
-  // TODO: consider refactoring these together with ParquetTypesConverter
   override def addBinary(value: Binary): Unit =
     parent.updateBinary(fieldIndex, value)
 
@@ -404,6 +407,9 @@ private[parquet] class CatalystArrayConverter(
 
   override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = {
     // fieldIndex is ignored (assumed to be zero but not checked)
+    if(value == null) {
+      throw new IllegalArgumentException("Null values inside Parquet arrays are not supported!")
+    }
     buffer += value
   }
 
@@ -419,7 +425,8 @@ private[parquet] class CatalystArrayConverter(
 
   override def end(): Unit = {
     assert(parent != null)
-    parent.updateField(index, new GenericRow(buffer.toArray))
+    // here we need to make sure to use ArrayScalaType
+    parent.updateField(index, buffer.toArray)
     clearBuffer()
   }
 }
@@ -444,7 +451,8 @@ private[parquet] class CatalystNativeArrayConverter(
 
   type nativeType = elementType.JvmType
 
-  private var buffer: Array[nativeType] = elementType.classTag.newArray(capacity)
+  private var buffer: CatalystConverter.ArrayScalaType[nativeType] =
+    elementType.classTag.newArray(capacity)
 
   private var elements: Int = 0
 
@@ -515,16 +523,18 @@ private[parquet] class CatalystNativeArrayConverter(
 
   override def end(): Unit = {
     assert(parent != null)
+    // here we need to make sure to use ArrayScalaType
     parent.updateField(
       index,
-      new NativeRow[nativeType](buffer.slice(0, elements)))
+      buffer.slice(0, elements))
     clearBuffer()
   }
 
   private def checkGrowBuffer(): Unit = {
     if (elements >= capacity) {
       val newCapacity = 2 * capacity
-      val tmp: Array[nativeType] = elementType.classTag.newArray(newCapacity)
+      val tmp: CatalystConverter.ArrayScalaType[nativeType] =
+        elementType.classTag.newArray(newCapacity)
       Array.copy(buffer, 0, tmp, 0, capacity)
       buffer = tmp
       capacity = newCapacity
@@ -552,8 +562,10 @@ private[parquet] class CatalystStructConverter(
   // TODO: think about reusing the buffer
   override def end(): Unit = {
     assert(!isRootConverter)
-    // TODO: use iterators if possible, avoid Row wrapping!
-    parent.updateField(index, new GenericRow(current.toArray))
+    // here we need to make sure to use StructScalaType
+    // Note: we need to actually make a copy of the array since we
+    // may be in a nested field
+    parent.updateField(index, current.toArray.toSeq)
   }
 }
 
@@ -619,6 +631,7 @@ private[parquet] class CatalystMapConverter(
   }
 
   override def end(): Unit = {
+    // here we need to make sure to use MapScalaType
     parent.updateField(index, map.toMap)
   }
 
@@ -627,6 +640,3 @@ private[parquet] class CatalystMapConverter(
   override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit =
     throw new UnsupportedOperationException
 }
-
-
-
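
As an aside, the checkGrowBuffer change keeps the same doubling strategy, only retyped to the new alias. A standalone sketch of that grow-and-copy pattern (the generic grow helper is written here for illustration, it is not the converter's actual method):

    import scala.reflect.ClassTag

    // Double the capacity and copy the existing elements, as
    // CatalystNativeArrayConverter.checkGrowBuffer does when the buffer is full.
    def grow[T: ClassTag](buffer: Array[T], capacity: Int): (Array[T], Int) = {
      val newCapacity = 2 * capacity
      val tmp = new Array[T](newCapacity)
      Array.copy(buffer, 0, tmp, 0, capacity)
      (tmp, newCapacity)
    }

    val (bigger, newCap) = grow(Array(1, 2, 3, 4), 4)  // bigger.length == 8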

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala

Lines changed: 22 additions & 7 deletions
@@ -140,9 +140,15 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
   private[parquet] def writeValue(schema: DataType, value: Any): Unit = {
     if (value != null && value != Nil) {
       schema match {
-        case t @ ArrayType(_) => writeArray(t, value.asInstanceOf[Row])
-        case t @ MapType(_, _) => writeMap(t, value.asInstanceOf[Map[Any, Any]])
-        case t @ StructType(_) => writeStruct(t, value.asInstanceOf[Row])
+        case t @ ArrayType(_) => writeArray(
+          t,
+          value.asInstanceOf[CatalystConverter.ArrayScalaType[_]])
+        case t @ MapType(_, _) => writeMap(
+          t,
+          value.asInstanceOf[CatalystConverter.MapScalaType[_, _]])
+        case t @ StructType(_) => writeStruct(
+          t,
+          value.asInstanceOf[CatalystConverter.StructScalaType[_]])
         case _ => writePrimitive(schema.asInstanceOf[PrimitiveType], value)
       }
     }
@@ -166,7 +172,9 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
     }
   }
 
-  private[parquet] def writeStruct(schema: StructType, struct: Row): Unit = {
+  private[parquet] def writeStruct(
+      schema: StructType,
+      struct: CatalystConverter.StructScalaType[_]): Unit = {
     if (struct != null && struct != Nil) {
       val fields = schema.fields.toArray
       writer.startGroup()
@@ -183,7 +191,11 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
     }
   }
 
-  private[parquet] def writeArray(schema: ArrayType, array: Row): Unit = {
+  // TODO: support null values, see
+  // https://issues.apache.org/jira/browse/SPARK-1649
+  private[parquet] def writeArray(
+      schema: ArrayType,
+      array: CatalystConverter.ArrayScalaType[_]): Unit = {
     val elementType = schema.elementType
     writer.startGroup()
     if (array.size > 0) {
@@ -198,8 +210,11 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
     writer.endGroup()
   }
 
-  // TODO: this does not allow null values! Should these be supported?
-  private[parquet] def writeMap(schema: MapType, map: Map[_, _]): Unit = {
+  // TODO: support null values, see
+  // https://issues.apache.org/jira/browse/SPARK-1649
+  private[parquet] def writeMap(
+      schema: MapType,
+      map: CatalystConverter.MapScalaType[_, _]): Unit = {
     writer.startGroup()
     if (map.size > 0) {
       writer.startField(CatalystConverter.MAP_SCHEMA_NAME, 0)
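
For illustration, a simplified standalone sketch (not RowWriteSupport itself; the describeValue helper is invented here) of the dispatch writeValue now performs, casting complex values to the plain Scala types behind the CatalystConverter aliases instead of to Row:

    import org.apache.spark.sql.catalyst.types._

    // Match on the Catalyst DataType and treat the value as the corresponding
    // Scala container (Array / Map / Seq), mirroring the new writeValue cases.
    def describeValue(schema: DataType, value: Any): String = schema match {
      case ArrayType(_)  => s"array of ${value.asInstanceOf[Array[_]].length} elements"  // ArrayScalaType
      case MapType(_, _) => s"map with ${value.asInstanceOf[Map[_, _]].size} entries"    // MapScalaType
      case StructType(_) => s"struct with ${value.asInstanceOf[Seq[_]].size} fields"     // StructScalaType
      case _             => s"primitive: $value"
    }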
