Skip to content

Commit 4aac7dd

Browse files
committed
fix some test failed and add some comments
1 parent 446ff43 commit 4aac7dd

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ object ScalaReflection extends ScalaReflection {
164164

165165
/** Returns the current path or `GetColumnByOrdinal`. */
166166
def getPath: Expression = {
167-
val dataType = schemaFor(tpe).dataType
167+
val dataType = schemaForDefaultBinaryType(tpe).dataType
168168
if (path.isDefined) {
169169
path.get
170170
} else {
@@ -409,7 +409,8 @@ object ScalaReflection extends ScalaReflection {
409409
val cls = getClassFromType(tpe)
410410

411411
val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
412-
val Schema(dataType, nullable) = schemaFor(fieldType)
412+
val Schema(dataType, nullablity) = schemaForDefaultBinaryType(fieldType)
413+
413414
val clsName = getClassNameFromType(fieldType)
414415
val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
415416
// For tuples, we based grab the inner fields by ordinal instead of name.
@@ -424,7 +425,7 @@ object ScalaReflection extends ScalaReflection {
424425
Some(addToPath(fieldName, dataType, newTypePath)),
425426
newTypePath)
426427

427-
if (!nullable) {
428+
if (!nullablity) {
428429
AssertNotNull(constructor, newTypePath)
429430
} else {
430431
constructor
@@ -445,6 +446,7 @@ object ScalaReflection extends ScalaReflection {
445446
}
446447

447448
case _ =>
449+
// default kryo deserializer
448450
DecodeUsingSerializer(getPath, ClassTag(getClassFromType(tpe)), true)
449451
}
450452
}
@@ -644,7 +646,8 @@ object ScalaReflection extends ScalaReflection {
644646
val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
645647
expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
646648

647-
case other =>
649+
case _ =>
650+
// default kryo serializer
648651
EncodeUsingSerializer(inputObject, true)
649652
}
650653

@@ -712,6 +715,13 @@ object ScalaReflection extends ScalaReflection {
712715
s.toAttributes
713716
}
714717

718+
/**
719+
* Returns a catalyst DataType and its nullability for the given Scala Type using reflection.
720+
* If the tpe mismatched in schemaFor function, the default BinaryType returned
721+
*/
722+
def schemaForDefaultBinaryType(tpe: `Type`): Schema = scala.util.Try(schemaFor(tpe)).toOption
723+
.getOrElse(Schema(BinaryType, nullable = true))
724+
715725
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
716726
def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T])
717727

@@ -775,7 +785,8 @@ object ScalaReflection extends ScalaReflection {
775785
StructField(fieldName, dataType, nullable)
776786
}), nullable = true)
777787
case other =>
778-
Schema(BinaryType, nullable = false)
788+
throw new UnsupportedOperationException(s"Schema for type $other is not supported")
789+
// Schema(BinaryType, nullable = false)
779790
}
780791
}
781792

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,8 +1136,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
11361136
assert(spark.range(1).map { x => new java.sql.Timestamp(100000) }.head ==
11371137
new java.sql.Timestamp(100000))
11381138
}
1139+
1140+
test("fallback to kryo for unknow classes in ExpressionEncoder") {
1141+
val ds = Seq(DefaultKryoEncoderForSubFiled("a", Seq(1), Some(Set(2))),
1142+
DefaultKryoEncoderForSubFiled("b", Seq(3), None)).toDS()
1143+
checkDataset(ds, DefaultKryoEncoderForSubFiled("a", Seq(1), Some(Set(2))),
1144+
DefaultKryoEncoderForSubFiled("b", Seq(3), None))
1145+
1146+
val df = ds.toDF()
1147+
val x = df.schema
1148+
assert(df.schema(0).dataType == StringType)
1149+
assert(df.schema(1).dataType == ArrayType(IntegerType, null = false))
1150+
assert(df.schema(2).dataType == BinaryType)
1151+
}
11391152
}
11401153

1154+
case class DefaultKryoEncoderForSubFiled(a: String, b: Seq[Int], c: Option[Set[Int]])
1155+
11411156
case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])
11421157
case class WithMap(id: String, map_test: scala.collection.Map[Long, String])
11431158

0 commit comments

Comments
 (0)