Skip to content

Commit 80e487e

Browse files
committed
asNullable in UDT.
1 parent 587d88b commit 80e487e

File tree

4 files changed

+31
-21
lines changed

4 files changed

+31
-21
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,8 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
182182
case _ => false
183183
}
184184
}
185+
186+
private[spark] override def asNullable: VectorUDT = this
185187
}
186188

187189
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -271,12 +271,13 @@ abstract class DataType {
271271
/** Check if `this` and `other` are the same data type when ignoring nullability
272272
* (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
273273
*/
274-
private[sql] def sameType(other: DataType): Boolean = DataType.equalsIgnoreNullability(this, other)
274+
private[spark] def sameType(other: DataType): Boolean =
275+
DataType.equalsIgnoreNullability(this, other)
275276

276277
/** Returns the same data type but set all nullability fields are true
277278
* (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
278279
*/
279-
private[sql] def asNullable: DataType
280+
private[spark] def asNullable: DataType
280281
}
281282

282283
/**
@@ -293,7 +294,7 @@ class NullType private() extends DataType {
293294
// Defined with a private constructor so the companion object is the only possible instantiation.
294295
override def defaultSize: Int = 1
295296

296-
private[sql] override def asNullable: NullType = this
297+
private[spark] override def asNullable: NullType = this
297298
}
298299

299300
case object NullType extends NullType
@@ -360,7 +361,7 @@ class StringType private() extends NativeType with PrimitiveType {
360361
*/
361362
override def defaultSize: Int = 4096
362363

363-
private[sql] override def asNullable: StringType = this
364+
private[spark] override def asNullable: StringType = this
364365
}
365366

366367
case object StringType extends StringType
@@ -396,7 +397,7 @@ class BinaryType private() extends NativeType with PrimitiveType {
396397
*/
397398
override def defaultSize: Int = 4096
398399

399-
private[sql] override def asNullable: BinaryType = this
400+
private[spark] override def asNullable: BinaryType = this
400401
}
401402

402403
case object BinaryType extends BinaryType
@@ -423,7 +424,7 @@ class BooleanType private() extends NativeType with PrimitiveType {
423424
*/
424425
override def defaultSize: Int = 1
425426

426-
private[sql] override def asNullable: BooleanType = this
427+
private[spark] override def asNullable: BooleanType = this
427428
}
428429

429430
case object BooleanType extends BooleanType
@@ -455,7 +456,7 @@ class TimestampType private() extends NativeType {
455456
*/
456457
override def defaultSize: Int = 12
457458

458-
private[sql] override def asNullable: TimestampType = this
459+
private[spark] override def asNullable: TimestampType = this
459460
}
460461

461462
case object TimestampType extends TimestampType
@@ -485,7 +486,7 @@ class DateType private() extends NativeType {
485486
*/
486487
override def defaultSize: Int = 4
487488

488-
private[sql] override def asNullable: DateType = this
489+
private[spark] override def asNullable: DateType = this
489490
}
490491

491492
case object DateType extends DateType
@@ -545,7 +546,7 @@ class LongType private() extends IntegralType {
545546

546547
override def simpleString = "bigint"
547548

548-
private[sql] override def asNullable: LongType = this
549+
private[spark] override def asNullable: LongType = this
549550
}
550551

551552
case object LongType extends LongType
@@ -576,7 +577,7 @@ class IntegerType private() extends IntegralType {
576577

577578
override def simpleString = "int"
578579

579-
private[sql] override def asNullable: IntegerType = this
580+
private[spark] override def asNullable: IntegerType = this
580581
}
581582

582583
case object IntegerType extends IntegerType
@@ -607,7 +608,7 @@ class ShortType private() extends IntegralType {
607608

608609
override def simpleString = "smallint"
609610

610-
private[sql] override def asNullable: ShortType = this
611+
private[spark] override def asNullable: ShortType = this
611612
}
612613

613614
case object ShortType extends ShortType
@@ -638,7 +639,7 @@ class ByteType private() extends IntegralType {
638639

639640
override def simpleString = "tinyint"
640641

641-
private[sql] override def asNullable: ByteType = this
642+
private[spark] override def asNullable: ByteType = this
642643
}
643644

644645
case object ByteType extends ByteType
@@ -706,7 +707,7 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
706707
case None => "decimal(10,0)"
707708
}
708709

709-
private[sql] override def asNullable: DecimalType = this
710+
private[spark] override def asNullable: DecimalType = this
710711
}
711712

712713

@@ -766,7 +767,7 @@ class DoubleType private() extends FractionalType {
766767
*/
767768
override def defaultSize: Int = 8
768769

769-
private[sql] override def asNullable: DoubleType = this
770+
private[spark] override def asNullable: DoubleType = this
770771
}
771772

772773
case object DoubleType extends DoubleType
@@ -796,7 +797,7 @@ class FloatType private() extends FractionalType {
796797
*/
797798
override def defaultSize: Int = 4
798799

799-
private[sql] override def asNullable: FloatType = this
800+
private[spark] override def asNullable: FloatType = this
800801
}
801802

802803
case object FloatType extends FloatType
@@ -846,7 +847,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
846847

847848
override def simpleString = s"array<${elementType.simpleString}>"
848849

849-
private[sql] override def asNullable: ArrayType = ArrayType(elementType.asNullable, containsNull = true)
850+
private[spark] override def asNullable: ArrayType =
851+
ArrayType(elementType.asNullable, containsNull = true)
850852
}
851853

852854

@@ -1093,7 +1095,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
10931095
private[sql] def merge(that: StructType): StructType =
10941096
StructType.merge(this, that).asInstanceOf[StructType]
10951097

1096-
private[sql] override def asNullable: StructType = {
1098+
private[spark] override def asNullable: StructType = {
10971099
val newFields = fields.map {
10981100
case StructField(name, dataType, nullable, metadata) =>
10991101
StructField(name, dataType.asNullable, nullable = true, metadata)
@@ -1154,7 +1156,7 @@ case class MapType(
11541156

11551157
override def simpleString = s"map<${keyType.simpleString},${valueType.simpleString}>"
11561158

1157-
private[sql] override def asNullable: MapType =
1159+
private[spark] override def asNullable: MapType =
11581160
MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true)
11591161
}
11601162

@@ -1210,7 +1212,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
12101212
*/
12111213
override def defaultSize: Int = 4096
12121214

1213-
private[sql] override def sameType(other: DataType): Boolean = ???
1214-
1215-
private[sql] override def asNullable: DataType = ???
1215+
/**
1216+
* For UDT, asNullable will not change the nullability of its internal sqlType and just returns
1217+
* itself.
1218+
*/
1219+
private[spark] override def asNullable: UserDefinedType[UserType] = this
12161220
}

sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,6 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
5959
}
6060

6161
override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]
62+
63+
private[spark] override def asNullable: ExamplePointUDT = this
6264
}

sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
6262
}
6363

6464
override def userClass = classOf[MyDenseVector]
65+
66+
private[spark] override def asNullable: MyDenseVectorUDT = this
6567
}
6668

6769
class UserDefinedTypeSuite extends QueryTest {

0 commit comments

Comments
 (0)