@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution

 import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream}
 import java.nio.charset.StandardCharsets
+import java.util.Map.Entry
 import java.util.concurrent.TimeUnit

 import scala.collection.JavaConverters._
@@ -33,10 +34,8 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Cast, Expression, GenericInternalRow, UnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils}

 trait BaseScriptTransformationExec extends UnaryExecNode {
@@ -47,7 +46,12 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
   def ioschema: ScriptTransformationIOSchema

   protected lazy val inputExpressionsWithoutSerde: Seq[Expression] = {
-    input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone))
+    input.map { in: Expression =>
+      in.dataType match {
+        case ArrayType(_, _) | MapType(_, _, _) | StructType(_) => in
+        case _ => Cast(in, StringType).withTimeZone(conf.sessionLocalTimeZone)
+      }
+    }
   }

   override def producedAttributes: AttributeSet = outputSet -- inputSet
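
Note on the hunk above: array, map, and struct columns now bypass the string cast, so the writer thread can serialize them with Hive-style nested delimiters via buildString (added later in this diff). A minimal sketch, with assumed values and not part of the patch, of why the old Cast was wrong for complex types:

// Casting array<int> to StringType yields Spark's display format,
// while the script on the other end of the pipe expects delimiter-joined values.
val elems = Seq(1, 2, 3)
val castStyle = elems.mkString("[", ", ", "]") // "[1, 2, 3]": what Cast produced
val serdeStyle = elems.mkString("\u0002")      // "1\u00022\u00023": the delimited form
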
@@ -182,58 +186,8 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
   }

   private lazy val outputFieldWriters: Seq[String => Any] = output.map { attr =>
-    val converter = CatalystTypeConverters.createToCatalystConverter(attr.dataType)
-    attr.dataType match {
-      case StringType => wrapperConvertException(data => data, converter)
-      case BooleanType => wrapperConvertException(data => data.toBoolean, converter)
-      case ByteType => wrapperConvertException(data => data.toByte, converter)
-      case BinaryType =>
-        wrapperConvertException(data => UTF8String.fromString(data).getBytes, converter)
-      case IntegerType => wrapperConvertException(data => data.toInt, converter)
-      case ShortType => wrapperConvertException(data => data.toShort, converter)
-      case LongType => wrapperConvertException(data => data.toLong, converter)
-      case FloatType => wrapperConvertException(data => data.toFloat, converter)
-      case DoubleType => wrapperConvertException(data => data.toDouble, converter)
-      case _: DecimalType => wrapperConvertException(data => BigDecimal(data), converter)
-      case DateType if conf.datetimeJava8ApiEnabled =>
-        wrapperConvertException(data => DateTimeUtils.stringToDate(
-          UTF8String.fromString(data),
-          DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
-          .map(DateTimeUtils.daysToLocalDate).orNull, converter)
-      case DateType => wrapperConvertException(data => DateTimeUtils.stringToDate(
-        UTF8String.fromString(data),
-        DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
-        .map(DateTimeUtils.toJavaDate).orNull, converter)
-      case TimestampType if conf.datetimeJava8ApiEnabled =>
-        wrapperConvertException(data => DateTimeUtils.stringToTimestamp(
-          UTF8String.fromString(data),
-          DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
-          .map(DateTimeUtils.microsToInstant).orNull, converter)
-      case TimestampType => wrapperConvertException(data => DateTimeUtils.stringToTimestamp(
-        UTF8String.fromString(data),
-        DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
-        .map(DateTimeUtils.toJavaTimestamp).orNull, converter)
-      case CalendarIntervalType => wrapperConvertException(
-        data => IntervalUtils.stringToInterval(UTF8String.fromString(data)),
-        converter)
-      case udt: UserDefinedType[_] =>
-        wrapperConvertException(data => udt.deserialize(data), converter)
-      case dt =>
-        throw new SparkException(s"${nodeName} without serde does not support " +
-          s"${dt.getClass.getSimpleName} as output data type")
-    }
+    SparkInspectors.unwrapper(attr.dataType, conf, ioschema, 1)
   }
-
-  // Keep consistent with Hive `LazySimpleSerDe`: when there is a type cast error, return null
-  private val wrapperConvertException: (String => Any, Any => Any) => String => Any =
-    (f: String => Any, converter: Any => Any) =>
-      (data: String) => converter {
-        try {
-          f(data)
-        } catch {
-          case NonFatal(_) => null
-        }
-      }
 }

 abstract class BaseScriptTransformationWriterThread extends Thread with Logging {
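
The forty-odd per-type writers removed above all shared one contract, spelled out in the deleted comment: a value that fails to parse becomes null rather than an exception, matching Hive's LazySimpleSerDe. A condensed sketch of that contract, which the new SparkInspectors.unwrapper is presumably expected to preserve (helper name hypothetical, not part of the patch):

import scala.util.control.NonFatal

// Hypothetical helper: wraps a parser so a failed parse yields null, not a throw.
def lenientReader[T](parse: String => T): String => Any =
  (data: String) =>
    try parse(data)
    catch { case NonFatal(_) => null }

val readInt = lenientReader[Int](_.toInt)
readInt("42")   // 42
readInt("oops") // null rather than a NumberFormatException
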
@@ -256,18 +210,23 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging

   protected def processRows(): Unit

+  val wrappers = inputSchema.map(dt => SparkInspectors.wrapper(dt))
+
   protected def processRowsWithoutSerde(): Unit = {
     val len = inputSchema.length
     iter.foreach { row =>
+      val values = row.asInstanceOf[GenericInternalRow].values.zip(wrappers).map {
+        case (value, wrapper) => wrapper(value)
+      }
       val data = if (len == 0) {
         ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")
       } else {
         val sb = new StringBuilder
-        sb.append(row.get(0, inputSchema(0)))
+        buildString(sb, values(0), inputSchema(0), 1)
         var i = 1
         while (i < len) {
           sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"))
-          sb.append(row.get(i, inputSchema(i)))
+          buildString(sb, values(i), inputSchema(i), 1)
           i += 1
         }
         sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES"))
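
Taken with the wrappers above, this loop emits one delimited record per row. A rough illustration, with schema and column values assumed, for a row of schema (id int, tags array<string>) under the default row format:

val record = Seq("1", Seq("a", "b").mkString("\u0002")).mkString("\t") + "\n"
// "1\ta\u0002b\n": fields joined by TOK_TABLEROWFORMATFIELD ("\t"), array
// elements by the level-1 separator ("\u0002"), and the record closed by
// TOK_TABLEROWFORMATLINES ("\n")
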
@@ -277,6 +236,50 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging
       }
     }

+  /**
+   * Convert data to a string according to its data type.
+   *
+   * @param sb The StringBuilder that accumulates the serialized data.
+   * @param obj The object for the current field.
+   * @param dataType The DataType of the current object.
+   * @param level The current separator nesting level.
+   */
+  private def buildString(sb: StringBuilder, obj: Any, dataType: DataType, level: Int): Unit = {
+    (obj, dataType) match {
+      case (list: java.util.List[_], ArrayType(typ, _)) =>
+        val separator = ioSchema.getSeparator(level)
+        (0 until list.size).foreach { i =>
+          if (i > 0) {
+            sb.append(separator)
+          }
+          buildString(sb, list.get(i), typ, level + 1)
+        }
+      case (map: java.util.Map[_, _], MapType(keyType, valueType, _)) =>
+        val separator = ioSchema.getSeparator(level)
+        val keyValueSeparator = ioSchema.getSeparator(level + 1)
+        val entries = map.entrySet().toArray()
+        (0 until entries.size).foreach { i =>
+          if (i > 0) {
+            sb.append(separator)
+          }
+          val entry = entries(i).asInstanceOf[Entry[_, _]]
+          buildString(sb, entry.getKey, keyType, level + 2)
+          sb.append(keyValueSeparator)
+          buildString(sb, entry.getValue, valueType, level + 2)
+        }
+      case (arrayList: java.util.ArrayList[_], StructType(fields)) =>
+        val separator = ioSchema.getSeparator(level)
+        (0 until arrayList.size).foreach { i =>
+          if (i > 0) {
+            sb.append(separator)
+          }
+          buildString(sb, arrayList.get(i), fields(i).dataType, level + 1)
+        }
+      case (other, _) =>
+        sb.append(other)
+    }
+  }
+
   override def run(): Unit = Utils.logUncaughtExceptions {
     TaskContext.setTaskContext(taskContext)

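
A worked trace of buildString's separator levels, under the defaults this patch registers below ('\u0002' at level 1, '\u0003' at level 2, '\u0004' at level 3): serializing a map<string, array<int>> value Map("a" -> Seq(1, 2)) at level 1 joins map entries with getSeparator(1), each key and value with getSeparator(2), and recurses into the array's elements at level 3:

val expected = "a" + '\u0003' + Seq(1, 2).mkString("\u0004")
// "a\u00031\u00042": key, the key/value separator, then \u0004-joined elements
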
@@ -329,14 +332,45 @@ case class ScriptTransformationIOSchema(
     schemaLess: Boolean) extends Serializable {
   import ScriptTransformationIOSchema._

-  val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k))
-  val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k))
+  val inputRowFormatMap = inputRowFormat.toMap.withDefault(k => defaultFormat(k))
+  val outputRowFormatMap = outputRowFormat.toMap.withDefault(k => defaultFormat(k))
+
+  val separators = (getByte(inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 0.toByte) ::
+    getByte(inputRowFormatMap("TOK_TABLEROWFORMATCOLLITEMS"), 1.toByte) ::
+    getByte(inputRowFormatMap("TOK_TABLEROWFORMATMAPKEYS"), 2.toByte) :: Nil) ++
+    (4 to 8).map(_.toByte)
+
+  def getByte(altValue: String, defaultVal: Byte): Byte = {
+    if (altValue != null && altValue.length > 0) {
+      try {
+        java.lang.Byte.parseByte(altValue)
+      } catch {
+        case _: NumberFormatException =>
+          altValue.charAt(0).toByte
+      }
+    } else {
+      defaultVal
+    }
+  }
+
+  def getSeparator(level: Int): Char = {
+    try {
+      separators(level).toChar
+    } catch {
+      case _: IndexOutOfBoundsException =>
+        val msg = "Number of levels of nesting supported for Spark SQL script transform" +
+          " is " + (separators.length - 1) + ". Unable to work with level " + level
+        throw new RuntimeException(msg)
+    }
+  }
 }

 object ScriptTransformationIOSchema {
   val defaultFormat = Map(
     ("TOK_TABLEROWFORMATFIELD", "\t"),
-    ("TOK_TABLEROWFORMATLINES", "\n")
+    ("TOK_TABLEROWFORMATLINES", "\n"),
+    ("TOK_TABLEROWFORMATCOLLITEMS", "\u0002"),
+    ("TOK_TABLEROWFORMATMAPKEYS", "\u0003")
   )

   val defaultIOSchema = ScriptTransformationIOSchema(
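
A standalone sketch of getByte's fallback chain, restating the method added above so it can be tried outside the class: a numeric string is taken as a byte value, a non-numeric string contributes its first character's code point, and a null or empty string falls back to the supplied default:

def getByte(altValue: String, defaultVal: Byte): Byte =
  if (altValue != null && altValue.nonEmpty) {
    try java.lang.Byte.parseByte(altValue)
    catch { case _: NumberFormatException => altValue.charAt(0).toByte }
  } else defaultVal

getByte("2", 0.toByte)  // 2: "2" parses as the byte \u0002
getByte("\t", 0.toByte) // 9: "\t" is not a number, so its first char's code is used
getByte(null, 1.toByte) // 1: the supplied default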