Skip to content

Commit 5b19b91

Browse files
committed
Fix:Improve some unit tests for NullExpressionsSuite and TypeCoercionSuite
1 parent 2a23cdd commit 5b19b91

File tree

2 files changed

+93
-23
lines changed

2 files changed

+93
-23
lines changed

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala

Lines changed: 76 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -484,24 +484,50 @@ class TypeCoercionSuite extends PlanTest {
484484
}
485485

486486
test("coalesce casts") {
487-
ruleTest(TypeCoercion.FunctionArgumentConversion,
488-
Coalesce(Literal(1.0)
489-
:: Literal(1)
490-
:: Literal.create(1.0, FloatType)
491-
:: Nil),
492-
Coalesce(Cast(Literal(1.0), DoubleType)
493-
:: Cast(Literal(1), DoubleType)
494-
:: Cast(Literal.create(1.0, FloatType), DoubleType)
495-
:: Nil))
496-
ruleTest(TypeCoercion.FunctionArgumentConversion,
497-
Coalesce(Literal(1L)
498-
:: Literal(1)
499-
:: Literal(new java.math.BigDecimal("1000000000000000000000"))
500-
:: Nil),
501-
Coalesce(Cast(Literal(1L), DecimalType(22, 0))
502-
:: Cast(Literal(1), DecimalType(22, 0))
503-
:: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
504-
:: Nil))
487+
val rule = TypeCoercion.FunctionArgumentConversion
488+
489+
val intLit = Literal(1)
490+
val longLit = Literal.create(1L)
491+
val doubleLit = Literal(1.0)
492+
val stringLit = Literal.create("c", StringType)
493+
val nullLit = Literal.create(null, NullType)
494+
val floatNullLit = Literal.create(null, FloatType)
495+
val floatLit = Literal.create(1.0f, FloatType)
496+
val timestampLit = Literal.create("2017-04-12", TimestampType)
497+
val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000"))
498+
499+
ruleTest(rule,
500+
Coalesce(Seq(doubleLit, intLit, floatLit)),
501+
Coalesce(Seq(Cast(doubleLit, DoubleType),
502+
Cast(intLit, DoubleType), Cast(floatLit, DoubleType))))
503+
504+
ruleTest(rule,
505+
Coalesce(Seq(longLit, intLit, decimalLit)),
506+
Coalesce(Seq(Cast(longLit, DecimalType(22, 0)),
507+
Cast(intLit, DecimalType(22, 0)), Cast(decimalLit, DecimalType(22, 0)))))
508+
509+
ruleTest(rule,
510+
Coalesce(Seq(nullLit, intLit)),
511+
Coalesce(Seq(Cast(nullLit, IntegerType), Cast(intLit, IntegerType))))
512+
513+
ruleTest(rule,
514+
Coalesce(Seq(timestampLit, stringLit)),
515+
Coalesce(Seq(Cast(timestampLit, StringType), Cast(stringLit, StringType))))
516+
517+
ruleTest(rule,
518+
Coalesce(Seq(nullLit, floatNullLit, intLit)),
519+
Coalesce(Seq(Cast(nullLit, FloatType), Cast(floatNullLit, FloatType),
520+
Cast(intLit, FloatType))))
521+
522+
ruleTest(rule,
523+
Coalesce(Seq(nullLit, intLit, decimalLit, doubleLit)),
524+
Coalesce(Seq(Cast(nullLit, DoubleType), Cast(intLit, DoubleType),
525+
Cast(decimalLit, DoubleType), Cast(doubleLit, DoubleType))))
526+
527+
ruleTest(rule,
528+
Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit)),
529+
Coalesce(Seq(Cast(nullLit, StringType), Cast(floatNullLit, StringType),
530+
Cast(doubleLit, StringType), Cast(stringLit, StringType))))
505531
}
506532

507533
test("CreateArray casts") {
@@ -675,6 +701,14 @@ class TypeCoercionSuite extends PlanTest {
675701

676702
test("type coercion for If") {
677703
val rule = TypeCoercion.IfCoercion
704+
val intLit = Literal(1)
705+
val doubleLit = Literal(1.0)
706+
val trueLit = Literal.create(true, BooleanType)
707+
val falseLit = Literal.create(false, BooleanType)
708+
val stringLit = Literal.create("c", StringType)
709+
val floatLit = Literal.create(1.0f, FloatType)
710+
val timestampLit = Literal.create("2017-04-12", TimestampType)
711+
val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000"))
678712

679713
ruleTest(rule,
680714
If(Literal(true), Literal(1), Literal(1L)),
@@ -685,12 +719,32 @@ class TypeCoercionSuite extends PlanTest {
685719
If(Literal.create(null, BooleanType), Literal(1), Literal(1)))
686720

687721
ruleTest(rule,
688-
If(AssertTrue(Literal.create(true, BooleanType)), Literal(1), Literal(2)),
689-
If(Cast(AssertTrue(Literal.create(true, BooleanType)), BooleanType), Literal(1), Literal(2)))
722+
If(AssertTrue(trueLit), Literal(1), Literal(2)),
723+
If(Cast(AssertTrue(trueLit), BooleanType), Literal(1), Literal(2)))
724+
725+
ruleTest(rule,
726+
If(AssertTrue(falseLit), Literal(1), Literal(2)),
727+
If(Cast(AssertTrue(falseLit), BooleanType), Literal(1), Literal(2)))
728+
729+
ruleTest(rule,
730+
If(trueLit, intLit, doubleLit),
731+
If(trueLit, Cast(intLit, DoubleType), doubleLit))
732+
733+
ruleTest(rule,
734+
If(trueLit, floatLit, doubleLit),
735+
If(trueLit, Cast(floatLit, DoubleType), doubleLit))
736+
737+
ruleTest(rule,
738+
If(trueLit, floatLit, decimalLit),
739+
If(trueLit, Cast(floatLit, DoubleType), Cast(decimalLit, DoubleType)))
740+
741+
ruleTest(rule,
742+
If(falseLit, stringLit, doubleLit),
743+
If(falseLit, stringLit, Cast(doubleLit, StringType)))
690744

691745
ruleTest(rule,
692-
If(AssertTrue(Literal.create(false, BooleanType)), Literal(1), Literal(2)),
693-
If(Cast(AssertTrue(Literal.create(false, BooleanType)), BooleanType), Literal(1), Literal(2)))
746+
If(trueLit, timestampLit, stringLit),
747+
If(trueLit, Cast(timestampLit, StringType), stringLit))
694748
}
695749

696750
test("type coercion for CaseKeyWhen") {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,30 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
9797
val doubleLit = Literal.create(2.2, DoubleType)
9898
val stringLit = Literal.create("c", StringType)
9999
val nullLit = Literal.create(null, NullType)
100-
100+
val floatNullLit = Literal.create(null, FloatType)
101+
val floatLit = Literal.create(1.01f, FloatType)
102+
val timestampLit = Literal.create("2017-04-12", TimestampType)
103+
val decimalLit = Literal.create(10.2, DecimalType(20, 2))
104+
105+
assert(analyze(new Nvl(decimalLit, stringLit)).dataType == StringType)
106+
assert(analyze(new Nvl(doubleLit, decimalLit)).dataType == DoubleType)
107+
assert(analyze(new Nvl(decimalLit, doubleLit)).dataType == DoubleType)
108+
assert(analyze(new Nvl(decimalLit, floatLit)).dataType == DoubleType)
109+
assert(analyze(new Nvl(floatLit, decimalLit)).dataType == DoubleType)
110+
111+
assert(analyze(new Nvl(timestampLit, stringLit)).dataType == StringType)
101112
assert(analyze(new Nvl(intLit, doubleLit)).dataType == DoubleType)
102113
assert(analyze(new Nvl(intLit, stringLit)).dataType == StringType)
103114
assert(analyze(new Nvl(stringLit, doubleLit)).dataType == StringType)
115+
assert(analyze(new Nvl(doubleLit, stringLit)).dataType == StringType)
104116

105117
assert(analyze(new Nvl(nullLit, intLit)).dataType == IntegerType)
106118
assert(analyze(new Nvl(doubleLit, nullLit)).dataType == DoubleType)
107119
assert(analyze(new Nvl(nullLit, stringLit)).dataType == StringType)
120+
121+
assert(analyze(new Nvl(floatLit, stringLit)).dataType == StringType)
122+
assert(analyze(new Nvl(floatLit, doubleLit)).dataType == DoubleType)
123+
assert(analyze(new Nvl(floatNullLit, intLit)).dataType == FloatType)
108124
}
109125

110126
test("AtLeastNNonNulls") {

0 commit comments

Comments
 (0)