Skip to content

Commit 83fef40

Browse files
committed
address review comment
1 parent 39ffaa9 commit 83fef40

File tree

1 file changed

+24
-7
lines changed
  • sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions

1 file changed

+24
-7
lines changed

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -830,14 +830,31 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
830830

831831
test("SPARK-22500: cast for struct should not generate codes beyond 64KB") {
832832
val N = 1000
833-
val from = new StructType(
833+
834+
val from1 = new StructType(
834835
(1 to N).map(i => StructField(s"s$i", StringType)).toArray)
835-
val to = new StructType(
836+
val to1 = new StructType(
836837
(1 to N).map(i => StructField(s"i$i", IntegerType)).toArray)
837-
838-
val input = Row.fromSeq((1 to N).map(i => i.toString))
839-
val output = Row.fromSeq((1 to N))
840-
841-
checkEvaluation(cast(Literal.create(input, from), to), output)
838+
val input1 = Row.fromSeq((1 to N).map(i => i.toString))
839+
val output1 = Row.fromSeq((1 to N))
840+
checkEvaluation(cast(Literal.create(input1, from1), to1), output1)
841+
842+
val from2 = new StructType(
843+
(1 to N).map(i => StructField(s"a$i", ArrayType(StringType, containsNull = false))).toArray)
844+
val to2 = new StructType(
845+
(1 to N).map(i => StructField(s"i$i", ArrayType(IntegerType, containsNull = true))).toArray)
846+
val input2 = Row.fromSeq((1 to N).map(_ => Seq("456", "true", "78.9")))
847+
val output2 = Row.fromSeq((1 to N).map(_ => Seq(456, null, 78)))
848+
checkEvaluation(cast(Literal.create(input2, from2), to2), output2)
849+
850+
val from3 = new StructType(
851+
(1 to N).map(i => StructField(s"s$i",
852+
StructType(Seq(StructField("l$i", IntegerType, nullable = true))))).toArray)
853+
val to3 = new StructType(
854+
(1 to N).map(i => StructField(s"s$i",
855+
StructType(Seq(StructField("l$i", LongType, nullable = true))))).toArray)
856+
val input3 = Row.fromSeq((1 to N).map(i => Row(i)))
857+
val output3 = Row.fromSeq((1 to N).map(i => Row(i.toLong)))
858+
checkEvaluation(cast(Literal.create(input3, from3), to3), output3)
842859
}
843860
}

0 commit comments

Comments
 (0)