@@ -484,24 +484,50 @@ class TypeCoercionSuite extends PlanTest {
484484 }
485485
486486 test(" coalesce casts" ) {
487- ruleTest(TypeCoercion .FunctionArgumentConversion ,
488- Coalesce (Literal (1.0 )
489- :: Literal (1 )
490- :: Literal .create(1.0 , FloatType )
491- :: Nil ),
492- Coalesce (Cast (Literal (1.0 ), DoubleType )
493- :: Cast (Literal (1 ), DoubleType )
494- :: Cast (Literal .create(1.0 , FloatType ), DoubleType )
495- :: Nil ))
496- ruleTest(TypeCoercion .FunctionArgumentConversion ,
497- Coalesce (Literal (1L )
498- :: Literal (1 )
499- :: Literal (new java.math.BigDecimal (" 1000000000000000000000" ))
500- :: Nil ),
501- Coalesce (Cast (Literal (1L ), DecimalType (22 , 0 ))
502- :: Cast (Literal (1 ), DecimalType (22 , 0 ))
503- :: Cast (Literal (new java.math.BigDecimal (" 1000000000000000000000" )), DecimalType (22 , 0 ))
504- :: Nil ))
487+ val rule = TypeCoercion .FunctionArgumentConversion
488+
489+ val intLit = Literal (1 )
490+ val longLit = Literal .create(1L )
491+ val doubleLit = Literal (1.0 )
492+ val stringLit = Literal .create(" c" , StringType )
493+ val nullLit = Literal .create(null , NullType )
494+ val floatNullLit = Literal .create(null , FloatType )
495+ val floatLit = Literal .create(1.0f , FloatType )
496+ val timestampLit = Literal .create(" 2017-04-12" , TimestampType )
497+ val decimalLit = Literal (new java.math.BigDecimal (" 1000000000000000000000" ))
498+
499+ ruleTest(rule,
500+ Coalesce (Seq (doubleLit, intLit, floatLit)),
501+ Coalesce (Seq (Cast (doubleLit, DoubleType ),
502+ Cast (intLit, DoubleType ), Cast (floatLit, DoubleType ))))
503+
504+ ruleTest(rule,
505+ Coalesce (Seq (longLit, intLit, decimalLit)),
506+ Coalesce (Seq (Cast (longLit, DecimalType (22 , 0 )),
507+ Cast (intLit, DecimalType (22 , 0 )), Cast (decimalLit, DecimalType (22 , 0 )))))
508+
509+ ruleTest(rule,
510+ Coalesce (Seq (nullLit, intLit)),
511+ Coalesce (Seq (Cast (nullLit, IntegerType ), Cast (intLit, IntegerType ))))
512+
513+ ruleTest(rule,
514+ Coalesce (Seq (timestampLit, stringLit)),
515+ Coalesce (Seq (Cast (timestampLit, StringType ), Cast (stringLit, StringType ))))
516+
517+ ruleTest(rule,
518+ Coalesce (Seq (nullLit, floatNullLit, intLit)),
519+ Coalesce (Seq (Cast (nullLit, FloatType ), Cast (floatNullLit, FloatType ),
520+ Cast (intLit, FloatType ))))
521+
522+ ruleTest(rule,
523+ Coalesce (Seq (nullLit, intLit, decimalLit, doubleLit)),
524+ Coalesce (Seq (Cast (nullLit, DoubleType ), Cast (intLit, DoubleType ),
525+ Cast (decimalLit, DoubleType ), Cast (doubleLit, DoubleType ))))
526+
527+ ruleTest(rule,
528+ Coalesce (Seq (nullLit, floatNullLit, doubleLit, stringLit)),
529+ Coalesce (Seq (Cast (nullLit, StringType ), Cast (floatNullLit, StringType ),
530+ Cast (doubleLit, StringType ), Cast (stringLit, StringType ))))
505531 }
506532
507533 test(" CreateArray casts" ) {
@@ -675,6 +701,14 @@ class TypeCoercionSuite extends PlanTest {
675701
676702 test(" type coercion for If" ) {
677703 val rule = TypeCoercion .IfCoercion
704+ val intLit = Literal (1 )
705+ val doubleLit = Literal (1.0 )
706+ val trueLit = Literal .create(true , BooleanType )
707+ val falseLit = Literal .create(false , BooleanType )
708+ val stringLit = Literal .create(" c" , StringType )
709+ val floatLit = Literal .create(1.0f , FloatType )
710+ val timestampLit = Literal .create(" 2017-04-12" , TimestampType )
711+ val decimalLit = Literal (new java.math.BigDecimal (" 1000000000000000000000" ))
678712
679713 ruleTest(rule,
680714 If (Literal (true ), Literal (1 ), Literal (1L )),
@@ -685,12 +719,32 @@ class TypeCoercionSuite extends PlanTest {
685719 If (Literal .create(null , BooleanType ), Literal (1 ), Literal (1 )))
686720
687721 ruleTest(rule,
688- If (AssertTrue (Literal .create(true , BooleanType )), Literal (1 ), Literal (2 )),
689- If (Cast (AssertTrue (Literal .create(true , BooleanType )), BooleanType ), Literal (1 ), Literal (2 )))
722+ If (AssertTrue (trueLit), Literal (1 ), Literal (2 )),
723+ If (Cast (AssertTrue (trueLit), BooleanType ), Literal (1 ), Literal (2 )))
724+
725+ ruleTest(rule,
726+ If (AssertTrue (falseLit), Literal (1 ), Literal (2 )),
727+ If (Cast (AssertTrue (falseLit), BooleanType ), Literal (1 ), Literal (2 )))
728+
729+ ruleTest(rule,
730+ If (trueLit, intLit, doubleLit),
731+ If (trueLit, Cast (intLit, DoubleType ), doubleLit))
732+
733+ ruleTest(rule,
734+ If (trueLit, floatLit, doubleLit),
735+ If (trueLit, Cast (floatLit, DoubleType ), doubleLit))
736+
737+ ruleTest(rule,
738+ If (trueLit, floatLit, decimalLit),
739+ If (trueLit, Cast (floatLit, DoubleType ), Cast (decimalLit, DoubleType )))
740+
741+ ruleTest(rule,
742+ If (falseLit, stringLit, doubleLit),
743+ If (falseLit, stringLit, Cast (doubleLit, StringType )))
690744
691745 ruleTest(rule,
692- If (AssertTrue ( Literal .create( false , BooleanType )), Literal ( 1 ), Literal ( 2 ) ),
693- If (Cast (AssertTrue ( Literal .create( false , BooleanType )), BooleanType ), Literal ( 1 ), Literal ( 2 ) ))
746+ If (trueLit, timestampLit, stringLit ),
747+ If (trueLit, Cast (timestampLit, StringType ), stringLit ))
694748 }
695749
696750 test(" type coercion for CaseKeyWhen" ) {
0 commit comments