@@ -548,8 +548,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
548548 castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast))
549549 }
550550
551- // three function arguments are: child.primitive, result.primitive and result.isNull
552- // it returns the code snippets to be put in null safe evaluation region
551+ // The function arguments are: `input`, ` result` and `resultIsNull`. We don't need `inputIsNull`
552+ // in parameter list, because the returned code will be put in null safe evaluation region.
553553 private [this ] type CastFunction = (String , String , String ) => String
554554
555555 private [this ] def nullSafeCastFunction (
@@ -584,15 +584,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
584584 throw new SparkException (s " Cannot cast $from to $to. " )
585585 }
586586
587- // Since we need to cast child expressions recursively inside ComplexTypes, such as Map's
587+ // Since we need to cast input expressions recursively inside ComplexTypes, such as Map's
588588 // Key and Value, Struct's field, we need to name out all the variable names involved in a cast.
589- private [this ] def castCode (ctx : CodegenContext , childPrim : String , childNull : String ,
590- resultPrim : String , resultNull : String , resultType : DataType , cast : CastFunction ): String = {
589+ private [this ] def castCode (ctx : CodegenContext , input : String , inputIsNull : String ,
590+ result : String , resultIsNull : String , resultType : DataType , cast : CastFunction ): String = {
591591 s """
592- boolean $resultNull = $childNull ;
593- ${ctx.javaType(resultType)} $resultPrim = ${ctx.defaultValue(resultType)};
594- if (! $childNull ) {
595- ${cast(childPrim, resultPrim, resultNull )}
592+ boolean $resultIsNull = $inputIsNull ;
593+ ${ctx.javaType(resultType)} $result = ${ctx.defaultValue(resultType)};
594+ if (! $inputIsNull ) {
595+ ${cast(input, result, resultIsNull )}
596596 }
597597 """
598598 }
@@ -1014,8 +1014,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
10141014 case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx)
10151015 }
10161016 val rowClass = classOf [GenericInternalRow ].getName
1017- val result = ctx.freshName(" result " )
1018- val tmpRow = ctx.freshName(" tmpRow " )
1017+ val tmpResult = ctx.freshName(" tmpResult " )
1018+ val tmpInput = ctx.freshName(" tmpInput " )
10191019
10201020 val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) =>
10211021 val fromFieldPrim = ctx.freshName(" ffp" )
@@ -1024,37 +1024,33 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
10241024 val toFieldNull = ctx.freshName(" tfn" )
10251025 val fromType = ctx.javaType(from.fields(i).dataType)
10261026 s """
1027- boolean $fromFieldNull = $tmpRow .isNullAt( $i);
1027+ boolean $fromFieldNull = $tmpInput .isNullAt( $i);
10281028 if ( $fromFieldNull) {
1029- $result .setNullAt( $i);
1029+ $tmpResult .setNullAt( $i);
10301030 } else {
10311031 $fromType $fromFieldPrim =
1032- ${ctx.getValue(tmpRow , from.fields(i).dataType, i.toString)};
1032+ ${ctx.getValue(tmpInput , from.fields(i).dataType, i.toString)};
10331033 ${castCode(ctx, fromFieldPrim,
10341034 fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)}
10351035 if ( $toFieldNull) {
1036- $result .setNullAt( $i);
1036+ $tmpResult .setNullAt( $i);
10371037 } else {
1038- ${ctx.setColumn(result , to.fields(i).dataType, i, toFieldPrim)};
1038+ ${ctx.setColumn(tmpResult , to.fields(i).dataType, i, toFieldPrim)};
10391039 }
10401040 }
10411041 """
10421042 }
1043- val fieldsEvalCodes = if (ctx.currentVars == null ) {
1044- ctx.splitExpressions(
1045- expressions = fieldsEvalCode,
1046- funcName = " castStruct" ,
1047- arguments = (" InternalRow" , tmpRow) :: (rowClass, result) :: Nil )
1048- } else {
1049- fieldsEvalCode.mkString(" \n " )
1050- }
1043+ val fieldsEvalCodes = ctx.splitExpressions(
1044+ expressions = fieldsEvalCode,
1045+ funcName = " castStruct" ,
1046+ arguments = (" InternalRow" , tmpInput) :: (rowClass, tmpResult) :: Nil )
10511047
1052- (c, evPrim, evNull ) =>
1048+ (input, result, resultIsNull ) =>
10531049 s """
1054- final $rowClass $result = new $rowClass( ${fieldsCasts.length});
1055- final InternalRow $tmpRow = $c ;
1050+ final $rowClass $tmpResult = new $rowClass( ${fieldsCasts.length});
1051+ final InternalRow $tmpInput = $input ;
10561052 $fieldsEvalCodes
1057- $evPrim = $result ;
1053+ $result = $tmpResult ;
10581054 """
10591055 }
10601056
0 commit comments