
Commit adc9ded

[SPARK-31937][SQL] Support processing array and map type using spark noserde mode
1 parent b33fa53 commit adc9ded

4 files changed: +464 -60 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala

Lines changed: 14 additions & 1 deletion
@@ -174,8 +174,9 @@ object CatalystTypeConverters {
             convertedIterable += elementConverter.toCatalyst(item)
           }
           new GenericArrayData(convertedIterable.toArray)
+        case g: GenericArrayData => new GenericArrayData(g.array.map(elementConverter.toCatalyst))
         case other => throw new IllegalArgumentException(
           s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
             + s"cannot be converted to an array of ${elementType.catalogString}")
       }
     }
@@ -213,6 +214,9 @@ object CatalystTypeConverters {
       scalaValue match {
         case map: Map[_, _] => ArrayBasedMapData(map, keyFunction, valueFunction)
         case javaMap: JavaMap[_, _] => ArrayBasedMapData(javaMap, keyFunction, valueFunction)
+        case map: ArrayBasedMapData =>
+          ArrayBasedMapData(map.keyArray.array.zip(map.valueArray.array).toMap,
+            keyFunction, valueFunction)
         case other => throw new IllegalArgumentException(
           s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
             + "cannot be converted to a map type with "
@@ -263,6 +267,15 @@ object CatalystTypeConverters {
             idx += 1
           }
           new GenericInternalRow(ar)
+        case g: GenericInternalRow =>
+          val ar = new Array[Any](structType.size)
+          val values = g.values
+          var idx = 0
+          while (idx < structType.size) {
+            ar(idx) = converters(idx).toCatalyst(values(idx))
+            idx += 1
+          }
+          new GenericInternalRow(ar)
         case other => throw new IllegalArgumentException(
           s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
             + s"cannot be converted to ${structType.catalogString}")
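To illustrate the new cases in this file, a minimal sketch (not part of the commit): createToCatalystConverter previously accepted only external Scala/Java values, so a value that was already in Catalyst form (for example a GenericArrayData) fell through to the IllegalArgumentException branch; with these cases it is now re-converted element by element instead.

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types.{ArrayType, IntegerType}

val toCatalyst = CatalystTypeConverters.createToCatalystConverter(ArrayType(IntegerType))

// Existing path: an external Scala collection.
val a1 = toCatalyst(Seq(1, 2, 3))

// Path added by this hunk: a value that is already a GenericArrayData no longer throws.
val a2 = toCatalyst(new GenericArrayData(Array[Any](1, 2, 3)))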

sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala

Lines changed: 93 additions & 59 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
 
 import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream}
 import java.nio.charset.StandardCharsets
+import java.util.Map.Entry
 import java.util.concurrent.TimeUnit
 
 import scala.collection.JavaConverters._
@@ -33,10 +34,8 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Cast, Expression, GenericInternalRow, UnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils}
 
 trait BaseScriptTransformationExec extends UnaryExecNode {
@@ -47,7 +46,12 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
   def ioschema: ScriptTransformationIOSchema
 
   protected lazy val inputExpressionsWithoutSerde: Seq[Expression] = {
-    input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone))
+    input.map { in: Expression =>
+      in.dataType match {
+        case ArrayType(_, _) | MapType(_, _, _) | StructType(_) => in
+        case _ => Cast(in, StringType).withTimeZone(conf.sessionLocalTimeZone)
+      }
+    }
   }
 
   override def producedAttributes: AttributeSet = outputSet -- inputSet
@@ -182,58 +186,8 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
   }
 
   private lazy val outputFieldWriters: Seq[String => Any] = output.map { attr =>
-    val converter = CatalystTypeConverters.createToCatalystConverter(attr.dataType)
-    attr.dataType match {
-      case StringType => wrapperConvertException(data => data, converter)
-      case BooleanType => wrapperConvertException(data => data.toBoolean, converter)
-      case ByteType => wrapperConvertException(data => data.toByte, converter)
-      case BinaryType =>
-        wrapperConvertException(data => UTF8String.fromString(data).getBytes, converter)
-      case IntegerType => wrapperConvertException(data => data.toInt, converter)
-      case ShortType => wrapperConvertException(data => data.toShort, converter)
-      case LongType => wrapperConvertException(data => data.toLong, converter)
-      case FloatType => wrapperConvertException(data => data.toFloat, converter)
-      case DoubleType => wrapperConvertException(data => data.toDouble, converter)
-      case _: DecimalType => wrapperConvertException(data => BigDecimal(data), converter)
-      case DateType if conf.datetimeJava8ApiEnabled =>
-        wrapperConvertException(data => DateTimeUtils.stringToDate(
-          UTF8String.fromString(data),
-          DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
-          .map(DateTimeUtils.daysToLocalDate).orNull, converter)
-      case DateType => wrapperConvertException(data => DateTimeUtils.stringToDate(
-        UTF8String.fromString(data),
-        DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
-        .map(DateTimeUtils.toJavaDate).orNull, converter)
-      case TimestampType if conf.datetimeJava8ApiEnabled =>
-        wrapperConvertException(data => DateTimeUtils.stringToTimestamp(
-          UTF8String.fromString(data),
-          DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
-          .map(DateTimeUtils.microsToInstant).orNull, converter)
-      case TimestampType => wrapperConvertException(data => DateTimeUtils.stringToTimestamp(
-        UTF8String.fromString(data),
-        DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
-        .map(DateTimeUtils.toJavaTimestamp).orNull, converter)
-      case CalendarIntervalType => wrapperConvertException(
-        data => IntervalUtils.stringToInterval(UTF8String.fromString(data)),
-        converter)
-      case udt: UserDefinedType[_] =>
-        wrapperConvertException(data => udt.deserialize(data), converter)
-      case dt =>
-        throw new SparkException(s"${nodeName} without serde does not support " +
-          s"${dt.getClass.getSimpleName} as output data type")
-    }
+    SparkInspectors.unwrapper(attr.dataType, conf, ioschema, 1)
   }
-
-  // Keep consistent with Hive `LazySimpleSerde`, when there is a type case error, return null
-  private val wrapperConvertException: (String => Any, Any => Any) => String => Any =
-    (f: String => Any, converter: Any => Any) =>
-      (data: String) => converter {
-        try {
-          f(data)
-        } catch {
-          case NonFatal(_) => null
-        }
-      }
 }
 
 abstract class BaseScriptTransformationWriterThread extends Thread with Logging {
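SparkInspectors is introduced in one of the other files of this commit (not shown in this excerpt), so only its call site is visible here. As a purely illustrative sketch -- assuming, as the call site suggests, that unwrapper(dataType, conf, ioschema, level) returns a String => Any that splits a field on the separator for its nesting level and converts the pieces -- a hand-rolled reader for an array-of-int output field could look like:

import org.apache.spark.sql.execution.ScriptTransformationIOSchema

// Illustrative stand-in only; this is not the committed SparkInspectors code.
def arrayOfIntReader(ioschema: ScriptTransformationIOSchema, level: Int): String => Any = {
  val sep = ioschema.getSeparator(level)
  // With the default separators, "1\u00022\u00023" parses to Array(1, 2, 3).
  (field: String) => field.split(sep).map(_.toInt)
}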
@@ -256,18 +210,23 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging
 
   protected def processRows(): Unit
 
+  val wrappers = inputSchema.map(dt => SparkInspectors.wrapper(dt))
+
   protected def processRowsWithoutSerde(): Unit = {
     val len = inputSchema.length
     iter.foreach { row =>
+      val values = row.asInstanceOf[GenericInternalRow].values.zip(wrappers).map {
+        case (value, wrapper) => wrapper(value)
+      }
       val data = if (len == 0) {
         ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")
       } else {
         val sb = new StringBuilder
-        sb.append(row.get(0, inputSchema(0)))
+        buildString(sb, values(0), inputSchema(0), 1)
         var i = 1
         while (i < len) {
           sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"))
-          sb.append(row.get(i, inputSchema(i)))
+          buildString(sb, values(i), inputSchema(i), 1)
           i += 1
         }
         sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES"))
@@ -277,6 +236,50 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging
       }
     }
 
+  /**
+   * Convert data to string according to the data type.
+   *
+   * @param sb The StringBuilder to store the serialized data.
+   * @param obj The object for the current field.
+   * @param dataType The DataType for the current Object.
+   * @param level The current level of separator.
+   */
+  private def buildString(sb: StringBuilder, obj: Any, dataType: DataType, level: Int): Unit = {
+    (obj, dataType) match {
+      case (list: java.util.List[_], ArrayType(typ, _)) =>
+        val separator = ioSchema.getSeparator(level)
+        (0 until list.size).foreach { i =>
+          if (i > 0) {
+            sb.append(separator)
+          }
+          buildString(sb, list.get(i), typ, level + 1)
+        }
+      case (map: java.util.Map[_, _], MapType(keyType, valueType, _)) =>
+        val separator = ioSchema.getSeparator(level)
+        val keyValueSeparator = ioSchema.getSeparator(level + 1)
+        val entries = map.entrySet().toArray()
+        (0 until entries.size).foreach { i =>
+          if (i > 0) {
+            sb.append(separator)
+          }
+          val entry = entries(i).asInstanceOf[Entry[_, _]]
+          buildString(sb, entry.getKey, keyType, level + 2)
+          sb.append(keyValueSeparator)
+          buildString(sb, entry.getValue, valueType, level + 2)
+        }
+      case (arrayList: java.util.ArrayList[_], StructType(fields)) =>
+        val separator = ioSchema.getSeparator(level)
+        (0 until arrayList.size).foreach { i =>
+          if (i > 0) {
+            sb.append(separator)
+          }
+          buildString(sb, arrayList.get(i), fields(i).dataType, level + 1)
+        }
+      case (other, _) =>
+        sb.append(other)
+    }
+  }
+
   override def run(): Unit = Utils.logUncaughtExceptions {
     TaskContext.setTaskContext(taskContext)

@@ -329,14 +332,45 @@ case class ScriptTransformationIOSchema(
     schemaLess: Boolean) extends Serializable {
   import ScriptTransformationIOSchema._
 
-  val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k))
-  val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k))
+  val inputRowFormatMap = inputRowFormat.toMap.withDefault(k => defaultFormat(k))
+  val outputRowFormatMap = outputRowFormat.toMap.withDefault(k => defaultFormat(k))
+
+  val separators = (getByte(inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 0.toByte) ::
+    getByte(inputRowFormatMap("TOK_TABLEROWFORMATCOLLITEMS"), 1.toByte) ::
+    getByte(inputRowFormatMap("TOK_TABLEROWFORMATMAPKEYS"), 2.toByte) :: Nil) ++
+    (4 to 8).map(_.toByte)
+
+  def getByte(altValue: String, defaultVal: Byte): Byte = {
+    if (altValue != null && altValue.length > 0) {
+      try {
+        java.lang.Byte.parseByte(altValue)
+      } catch {
+        case _: NumberFormatException =>
+          altValue.charAt(0).toByte
+      }
+    } else {
+      defaultVal
+    }
+  }
+
+  def getSeparator(level: Int): Char = {
+    try {
+      separators(level).toChar
+    } catch {
+      case _: IndexOutOfBoundsException =>
+        val msg = "Number of levels of nesting supported for Spark SQL script transform" +
+          " is " + (separators.length - 1) + " Unable to work with level " + level
+        throw new RuntimeException(msg)
+    }
+  }
 }
 
 object ScriptTransformationIOSchema {
   val defaultFormat = Map(
     ("TOK_TABLEROWFORMATFIELD", "\t"),
-    ("TOK_TABLEROWFORMATLINES", "\n")
+    ("TOK_TABLEROWFORMATLINES", "\n"),
+    ("TOK_TABLEROWFORMATCOLLITEMS", "\u0002"),
+    ("TOK_TABLEROWFORMATMAPKEYS", "\u0003")
   )
 
   val defaultIOSchema = ScriptTransformationIOSchema(
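Under the default row format above, separators resolves to the bytes Seq(9, 2, 3, 4, 5, 6, 7, 8) (the field, collection-item and map-key delimiters, then 4 through 8 for deeper nesting), and getSeparator throws once those levels run out. A minimal usage sketch (not part of the commit) of the separators and of the line buildString writes for a row such as (1, array(1, 2), map("a" -> 7)):

import org.apache.spark.sql.execution.ScriptTransformationIOSchema

val ioschema = ScriptTransformationIOSchema.defaultIOSchema

ioschema.getSeparator(1)  // '\u0002' -- between array elements and between map entries
ioschema.getSeparator(2)  // '\u0003' -- between a map key and its value
ioschema.getSeparator(3)  // '\u0004' -- next nesting level, and so on up to level 7
// ioschema.getSeparator(8) throws a RuntimeException: only 7 levels of nesting are supported.

// Fields themselves are joined with the TOK_TABLEROWFORMATFIELD delimiter ("\t" by default),
// so the row (1, array(1, 2), map("a" -> 7)) is written to the script's stdin as:
val line = Seq("1", Seq(1, 2).mkString("\u0002"), "a" + "\u0003" + "7").mkString("\t")
// line == "1\t1\u00022\ta\u00037"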
