
Commit 38b9e69

kiszk authored and cloud-fan committed
[SPARK-18284][SQL] Make ExpressionEncoder.serializer.nullable precise
## What changes were proposed in this pull request?

This PR makes `ExpressionEncoder.serializer.nullable` `false` for a flat encoder of a primitive type. It is currently always `true`, which is too conservative: while `ExpressionEncoder.schema` carries the correct information (e.g. `<IntegerType, false>`), `serializer.head.nullable` of the `ExpressionEncoder` obtained from `encoderFor[T]` is always `true`. This is accomplished by checking whether the type is a primitive type; if it is, `nullable` is set to `false`.

## How was this patch tested?

Added new tests for encoders and DataFrames.

Author: Kazuaki Ishizaki <[email protected]>

Closes #15780 from kiszk/SPARK-18284.
1 parent 70c5549 commit 38b9e69
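As a hedged illustration of the fix (a minimal sketch against the Catalyst internals of this era, mirroring the new `ExpressionEncoderSuite` assertions below rather than a captured session):

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Flat encoder for a primitive: the serialized value can never be null,
// so with this patch the serializer reports nullable = false.
val intEnc = ExpressionEncoder[Int]()
assert(!intEnc.serializer.head.nullable)

// A boxed type can still hold null, so its serializer stays nullable.
val boxedEnc = ExpressionEncoder[java.lang.Integer]()
assert(boxedEnc.serializer.head.nullable)
```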

File tree: 8 files changed, +96 -21 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala

Lines changed: 3 additions & 1 deletion
@@ -396,12 +396,14 @@ object JavaTypeInference {
 
       case _ if mapType.isAssignableFrom(typeToken) =>
         val (keyType, valueType) = mapKeyValueType(typeToken)
+
         ExternalMapToCatalyst(
           inputObject,
           ObjectType(keyType.getRawType),
           serializerFor(_, keyType),
           ObjectType(valueType.getRawType),
-          serializerFor(_, valueType)
+          serializerFor(_, valueType),
+          valueNullable = true
         )
 
       case other =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 5 additions & 2 deletions
@@ -498,7 +498,8 @@ object ScalaReflection extends ScalaReflection {
           dataTypeFor(keyType),
           serializerFor(_, keyType, keyPath),
           dataTypeFor(valueType),
-          serializerFor(_, valueType, valuePath))
+          serializerFor(_, valueType, valuePath),
+          valueNullable = !valueType.typeSymbol.asClass.isPrimitive)
 
       case t if t <:< localTypeOf[String] =>
         StaticInvoke(
@@ -590,7 +591,9 @@ object ScalaReflection extends ScalaReflection {
             "cannot be used as field name\n" + walkedTypePath.mkString("\n"))
         }
 
-        val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
+        val fieldValue = Invoke(
+          AssertNotNull(inputObject, walkedTypePath), fieldName, dataTypeFor(fieldType),
+          returnNullable = !fieldType.typeSymbol.asClass.isPrimitive)
         val clsName = getClassNameFromType(fieldType)
         val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
         expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
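The practical effect on product types, sketched to mirror the `DatasetSuite` assertions added below (illustrative only; assumes a SparkSession with `spark.implicits._` in scope):

```scala
// A primitive field is now marked non-nullable, while a reference-typed
// field remains nullable.
case class P(i: Int, s: String)

val df = Seq(P(1, "a")).toDF
assert(df.schema("i").nullable == false)
assert(df.schema("s").nullable == true)
```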

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

Lines changed: 2 additions & 5 deletions
@@ -60,7 +60,7 @@ object ExpressionEncoder {
     val cls = mirror.runtimeClass(tpe)
     val flat = !ScalaReflection.definedByConstructorParams(tpe)
 
-    val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true)
+    val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = !cls.isPrimitive)
     val nullSafeInput = if (flat) {
       inputObject
     } else {
@@ -71,10 +71,7 @@ object ExpressionEncoder {
     val serializer = ScalaReflection.serializerFor[T](nullSafeInput)
     val deserializer = ScalaReflection.deserializerFor[T]
 
-    val schema = ScalaReflection.schemaFor[T] match {
-      case ScalaReflection.Schema(s: StructType, _) => s
-      case ScalaReflection.Schema(dt, nullable) => new StructType().add("value", dt, nullable)
-    }
+    val schema = serializer.dataType
 
     new ExpressionEncoder[T](
       schema,
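Deriving `schema` directly from the serializer makes the schema and serializer nullability agree by construction. A sketch of the expected flat-encoder schema (assumed output, not a captured log):

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Expected to print something like:
//   StructType(StructField(value,IntegerType,false))
println(ExpressionEncoder[Int]().schema)
```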

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
       ctx.addMutableState("boolean", classChildVarIsNull, "")
 
       val classChildVar =
-        LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType)
+        LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType, child.nullable)
 
       val initCode = s"${classChildVar.value} = ${childGen.value};\n" +
         s"${classChildVar.isNull} = ${childGen.isNull};"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 15 additions & 9 deletions
@@ -171,15 +171,18 @@ case class StaticInvoke(
  * @param arguments An optional list of expressions, whos evaluation will be passed to the function.
  * @param propagateNull When true, and any of the arguments is null, null will be returned instead
  *                      of calling the function.
+ * @param returnNullable When false, indicating the invoked method will always return
+ *                       non-null value.
  */
 case class Invoke(
     targetObject: Expression,
     functionName: String,
     dataType: DataType,
     arguments: Seq[Expression] = Nil,
-    propagateNull: Boolean = true) extends InvokeLike {
+    propagateNull: Boolean = true,
+    returnNullable : Boolean = true) extends InvokeLike {
 
-  override def nullable: Boolean = true
+  override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable
   override def children: Seq[Expression] = targetObject +: arguments
 
   override def eval(input: InternalRow): Any =
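A minimal sketch of the new `Invoke.nullable` semantics (constructing Catalyst expressions directly; `needNullCheck` comes from `InvokeLike` and is false here since there are no arguments):

```scala
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types.{IntegerType, ObjectType}

// Non-nullable target invoking a method declared to always return a value:
// the Invoke is now non-nullable, where previously nullable was always true.
val target = BoundReference(0, ObjectType(classOf[java.lang.Integer]), nullable = false)
val inv = Invoke(target, "intValue", IntegerType, returnNullable = false)
assert(!inv.nullable)
```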
@@ -405,13 +408,15 @@ case class WrapOption(child: Expression, optType: DataType)
  * A place holder for the loop variable used in [[MapObjects]]. This should never be constructed
  * manually, but will instead be passed into the provided lambda function.
  */
-case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends LeafExpression
+case class LambdaVariable(
+    value: String,
+    isNull: String,
+    dataType: DataType,
+    nullable: Boolean = true) extends LeafExpression
   with Unevaluable with NonSQLExpression {
 
-  override def nullable: Boolean = true
-
   override def genCode(ctx: CodegenContext): ExprCode = {
-    ExprCode(code = "", value = value, isNull = isNull)
+    ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false")
   }
 }
 
@@ -592,7 +597,8 @@ object ExternalMapToCatalyst {
       keyType: DataType,
       keyConverter: Expression => Expression,
       valueType: DataType,
-      valueConverter: Expression => Expression): ExternalMapToCatalyst = {
+      valueConverter: Expression => Expression,
+      valueNullable: Boolean): ExternalMapToCatalyst = {
     val id = curId.getAndIncrement()
     val keyName = "ExternalMapToCatalyst_key" + id
     val valueName = "ExternalMapToCatalyst_value" + id
@@ -601,11 +607,11 @@ object ExternalMapToCatalyst {
     ExternalMapToCatalyst(
       keyName,
       keyType,
-      keyConverter(LambdaVariable(keyName, "false", keyType)),
+      keyConverter(LambdaVariable(keyName, "false", keyType, false)),
       valueName,
       valueIsNull,
       valueType,
-      valueConverter(LambdaVariable(valueName, valueIsNull, valueType)),
+      valueConverter(LambdaVariable(valueName, valueIsNull, valueType, valueNullable)),
       inputMap
     )
   }

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala

Lines changed: 18 additions & 1 deletion
@@ -24,7 +24,7 @@ import java.util.Arrays
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.runtime.universe.TypeTag
 
-import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.{Encoder, Encoders}
 import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
 import org.apache.spark.sql.catalyst.analysis.AnalysisTest
 import org.apache.spark.sql.catalyst.dsl.plans._
@@ -300,6 +300,11 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
   encodeDecodeTest(
     ReferenceValueClass(ReferenceValueClass.Container(1)), "reference value class")
 
+  encodeDecodeTest(Option(31), "option of int")
+  encodeDecodeTest(Option.empty[Int], "empty option of int")
+  encodeDecodeTest(Option("abc"), "option of string")
+  encodeDecodeTest(Option.empty[String], "empty option of string")
+
   productTest(("UDT", new ExamplePoint(0.1, 0.2)))
 
   test("nullable of encoder schema") {
@@ -338,6 +343,18 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
     }
   }
 
+  test("nullable of encoder serializer") {
+    def checkNullable[T: Encoder](nullable: Boolean): Unit = {
+      assert(encoderFor[T].serializer.forall(_.nullable === nullable))
+    }
+
+    // test for flat encoders
+    checkNullable[Int](false)
+    checkNullable[Option[Int]](true)
+    checkNullable[java.lang.Integer](true)
+    checkNullable[String](true)
+  }
+
   test("null check for map key") {
     val encoder = ExpressionEncoder[Map[String, Int]]()
     val e = intercept[RuntimeException](encoder.toRow(Map(("a", 1), (null, 2))))

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 51 additions & 1 deletion
@@ -28,7 +28,10 @@ import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
+
+case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2)
+case class TestDataPoint2(x: Int, s: String)
 
 class DatasetSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -969,6 +972,53 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     assert(dataset.collect() sameElements Array(resultValue, resultValue))
   }
 
+  test("SPARK-18284: Serializer should have correct nullable value") {
+    val df1 = Seq(1, 2, 3, 4).toDF
+    assert(df1.schema(0).nullable == false)
+    val df2 = Seq(Integer.valueOf(1), Integer.valueOf(2)).toDF
+    assert(df2.schema(0).nullable == true)
+
+    val df3 = Seq(Seq(1, 2), Seq(3, 4)).toDF
+    assert(df3.schema(0).nullable == true)
+    assert(df3.schema(0).dataType.asInstanceOf[ArrayType].containsNull == false)
+    val df4 = Seq(Seq("a", "b"), Seq("c", "d")).toDF
+    assert(df4.schema(0).nullable == true)
+    assert(df4.schema(0).dataType.asInstanceOf[ArrayType].containsNull == true)
+
+    val df5 = Seq((0, 1.0), (2, 2.0)).toDF("id", "v")
+    assert(df5.schema(0).nullable == false)
+    assert(df5.schema(1).nullable == false)
+    val df6 = Seq((0, 1.0, "a"), (2, 2.0, "b")).toDF("id", "v1", "v2")
+    assert(df6.schema(0).nullable == false)
+    assert(df6.schema(1).nullable == false)
+    assert(df6.schema(2).nullable == true)
+
+    val df7 = (Tuple1(Array(1, 2, 3)) :: Nil).toDF("a")
+    assert(df7.schema(0).nullable == true)
+    assert(df7.schema(0).dataType.asInstanceOf[ArrayType].containsNull == false)
+
+    val df8 = (Tuple1(Array((null: Integer), (null: Integer))) :: Nil).toDF("a")
+    assert(df8.schema(0).nullable == true)
+    assert(df8.schema(0).dataType.asInstanceOf[ArrayType].containsNull == true)
+
+    val df9 = (Tuple1(Map(2 -> 3)) :: Nil).toDF("m")
+    assert(df9.schema(0).nullable == true)
+    assert(df9.schema(0).dataType.asInstanceOf[MapType].valueContainsNull == false)
+
+    val df10 = (Tuple1(Map(1 -> (null: Integer))) :: Nil).toDF("m")
+    assert(df10.schema(0).nullable == true)
+    assert(df10.schema(0).dataType.asInstanceOf[MapType].valueContainsNull == true)
+
+    val df11 = Seq(TestDataPoint(1, 2.2, "a", null),
+      TestDataPoint(3, 4.4, "null", (TestDataPoint2(33, "b")))).toDF
+    assert(df11.schema(0).nullable == false)
+    assert(df11.schema(1).nullable == false)
+    assert(df11.schema(2).nullable == true)
+    assert(df11.schema(3).nullable == true)
+    assert(df11.schema(3).dataType.asInstanceOf[StructType].fields(0).nullable == false)
+    assert(df11.schema(3).dataType.asInstanceOf[StructType].fields(1).nullable == true)
+  }
+
   Seq(true, false).foreach { eager =>
     def testCheckpointing(testName: String)(f: => Unit): Unit = {
       test(s"Dataset.checkpoint() - $testName (eager = $eager)") {

sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ class FileStreamSinkSuite extends StreamTest {
 
     val outputDf = spark.read.parquet(outputDir)
     val expectedSchema = new StructType()
-      .add(StructField("value", IntegerType))
+      .add(StructField("value", IntegerType, nullable = false))
       .add(StructField("id", IntegerType))
     assert(outputDf.schema === expectedSchema)
 
