Skip to content

Commit 7ce4b82

Browse files
committed
cast for struct can split code even with whole stage codegen
1 parent a8af4da commit 7ce4b82

File tree

1 file changed

+24
-28
lines changed
  • sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions

1 file changed

+24
-28
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -548,8 +548,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
548548
castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast))
549549
}
550550

551-
// three function arguments are: child.primitive, result.primitive and result.isNull
552-
// it returns the code snippets to be put in null safe evaluation region
551+
// The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull`
552+
// in the parameter list, because the returned code will be put in a null-safe evaluation region.
553553
private[this] type CastFunction = (String, String, String) => String
554554

555555
private[this] def nullSafeCastFunction(
@@ -584,15 +584,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
584584
throw new SparkException(s"Cannot cast $from to $to.")
585585
}
586586

587-
// Since we need to cast child expressions recursively inside ComplexTypes, such as Map's
587+
// Since we need to cast input expressions recursively inside ComplexTypes, such as Map's
588588
// Key and Value, Struct's fields, we need to explicitly name all the variables involved in a cast.
589-
private[this] def castCode(ctx: CodegenContext, childPrim: String, childNull: String,
590-
resultPrim: String, resultNull: String, resultType: DataType, cast: CastFunction): String = {
589+
private[this] def castCode(ctx: CodegenContext, input: String, inputIsNull: String,
590+
result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = {
591591
s"""
592-
boolean $resultNull = $childNull;
593-
${ctx.javaType(resultType)} $resultPrim = ${ctx.defaultValue(resultType)};
594-
if (!$childNull) {
595-
${cast(childPrim, resultPrim, resultNull)}
592+
boolean $resultIsNull = $inputIsNull;
593+
${ctx.javaType(resultType)} $result = ${ctx.defaultValue(resultType)};
594+
if (!$inputIsNull) {
595+
${cast(input, result, resultIsNull)}
596596
}
597597
"""
598598
}
@@ -1014,8 +1014,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
10141014
case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx)
10151015
}
10161016
val rowClass = classOf[GenericInternalRow].getName
1017-
val result = ctx.freshName("result")
1018-
val tmpRow = ctx.freshName("tmpRow")
1017+
val tmpResult = ctx.freshName("tmpResult")
1018+
val tmpInput = ctx.freshName("tmpInput")
10191019

10201020
val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) =>
10211021
val fromFieldPrim = ctx.freshName("ffp")
@@ -1024,37 +1024,33 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
10241024
val toFieldNull = ctx.freshName("tfn")
10251025
val fromType = ctx.javaType(from.fields(i).dataType)
10261026
s"""
1027-
boolean $fromFieldNull = $tmpRow.isNullAt($i);
1027+
boolean $fromFieldNull = $tmpInput.isNullAt($i);
10281028
if ($fromFieldNull) {
1029-
$result.setNullAt($i);
1029+
$tmpResult.setNullAt($i);
10301030
} else {
10311031
$fromType $fromFieldPrim =
1032-
${ctx.getValue(tmpRow, from.fields(i).dataType, i.toString)};
1032+
${ctx.getValue(tmpInput, from.fields(i).dataType, i.toString)};
10331033
${castCode(ctx, fromFieldPrim,
10341034
fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)}
10351035
if ($toFieldNull) {
1036-
$result.setNullAt($i);
1036+
$tmpResult.setNullAt($i);
10371037
} else {
1038-
${ctx.setColumn(result, to.fields(i).dataType, i, toFieldPrim)};
1038+
${ctx.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)};
10391039
}
10401040
}
10411041
"""
10421042
}
1043-
val fieldsEvalCodes = if (ctx.currentVars == null) {
1044-
ctx.splitExpressions(
1045-
expressions = fieldsEvalCode,
1046-
funcName = "castStruct",
1047-
arguments = ("InternalRow", tmpRow) :: (rowClass, result) :: Nil)
1048-
} else {
1049-
fieldsEvalCode.mkString("\n")
1050-
}
1043+
val fieldsEvalCodes = ctx.splitExpressions(
1044+
expressions = fieldsEvalCode,
1045+
funcName = "castStruct",
1046+
arguments = ("InternalRow", tmpInput) :: (rowClass, tmpResult) :: Nil)
10511047

1052-
(c, evPrim, evNull) =>
1048+
(input, result, resultIsNull) =>
10531049
s"""
1054-
final $rowClass $result = new $rowClass(${fieldsCasts.length});
1055-
final InternalRow $tmpRow = $c;
1050+
final $rowClass $tmpResult = new $rowClass(${fieldsCasts.length});
1051+
final InternalRow $tmpInput = $input;
10561052
$fieldsEvalCodes
1057-
$evPrim = $result;
1053+
$result = $tmpResult;
10581054
"""
10591055
}
10601056

0 commit comments

Comments
 (0)