Skip to content

Commit d1c4e92

Browse files
committed
RuntimeReplaceable
1 parent 483affb commit d1c4e92

File tree

3 files changed

+29
-36
lines changed

3 files changed

+29
-36
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,12 @@ package org.apache.spark.sql.catalyst.expressions
2020
import org.apache.spark.sql.catalyst.InternalRow
2121
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
2222
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
23-
import org.apache.spark.sql.types.{AbstractDataType, DataType}
24-
25-
private[catalyst] abstract class TryEval extends Expression with NullIntolerant {
26-
protected def internalExpression: Expression
23+
import org.apache.spark.sql.types.DataType
2724

25+
private[catalyst] case class TryEval(child: Expression)
26+
extends UnaryExpression with NullIntolerant {
2827
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
29-
val childGen = internalExpression.genCode(ctx)
28+
val childGen = child.genCode(ctx)
3029
ev.copy(code = code"""
3130
boolean ${ev.isNull} = true;
3231
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -41,17 +40,18 @@ private[catalyst] abstract class TryEval extends Expression with NullIntolerant
4140

4241
override def eval(input: InternalRow): Any =
4342
try {
44-
internalExpression.eval(input)
43+
child.eval(input)
4544
} catch {
4645
case _: Exception =>
4746
null
4847
}
4948

50-
override def dataType: DataType = internalExpression.dataType
49+
override def dataType: DataType = child.dataType
5150

5251
override def nullable: Boolean = true
5352

54-
override def children: Seq[Expression] = internalExpression.children
53+
override protected def withNewChildInternal(newChild: Expression): Expression =
54+
copy(child = newChild)
5555
}
5656

5757
@ExpressionDescription(
@@ -63,21 +63,21 @@ private[catalyst] abstract class TryEval extends Expression with NullIntolerant
6363
""",
6464
since = "3.2.0",
6565
group = "math_funcs")
66-
case class TryAdd(left: Expression, right: Expression) extends TryEval with ImplicitCastInputTypes {
66+
case class TryAdd(left: Expression, right: Expression, child: Expression)
67+
extends RuntimeReplaceable {
68+
def this(left: Expression, right: Expression) =
69+
this(left, right, TryEval(Add(left, right, failOnError = true)))
6770

68-
protected override def internalExpression: Expression =
69-
Add(left: Expression, right: Expression, failOnError = true)
71+
override def flatArguments: Iterator[Any] = Iterator(left, right)
7072

71-
override def prettyName: String = "try_add"
73+
override def exprsReplaced: Seq[Expression] = Seq(left, right)
7274

73-
override def inputTypes: Seq[AbstractDataType] =
74-
internalExpression.asInstanceOf[ExpectsInputTypes].inputTypes
75+
override def prettyName: String = "try_add"
7576

76-
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
77-
copy(left = newChildren(0), right = newChildren(1))
77+
override protected def withNewChildInternal(newChild: Expression): Expression =
78+
this.copy(child = newChild)
7879
}
7980

80-
8181
// scalastyle:off line.size.limit
8282
@ExpressionDescription(
8383
usage = "_FUNC_(expr1, expr2) - Returns `expr1`/`expr2`. It always performs floating point division. Its result is always null if `expr2` is 0.",
@@ -91,17 +91,17 @@ case class TryAdd(left: Expression, right: Expression) extends TryEval with Impl
9191
since = "3.2.0",
9292
group = "math_funcs")
9393
// scalastyle:on line.size.limit
94-
case class TryDivide(left: Expression, right: Expression)
95-
extends TryEval with ImplicitCastInputTypes {
94+
case class TryDivide(left: Expression, right: Expression, child: Expression)
95+
extends RuntimeReplaceable {
96+
def this(left: Expression, right: Expression) =
97+
this(left, right, TryEval(Divide(left, right, failOnError = true)))
9698

97-
protected override def internalExpression: Expression =
98-
Divide(left, right, failOnError = true)
99+
override def flatArguments: Iterator[Any] = Iterator(left, right)
99100

100-
override def prettyName: String = "try_divide"
101+
override def exprsReplaced: Seq[Expression] = Seq(left, right)
101102

102-
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
103-
copy(left = newChildren(0), right = newChildren(1))
103+
override def prettyName: String = "try_divide"
104104

105-
override def inputTypes: Seq[AbstractDataType] =
106-
internalExpression.asInstanceOf[ExpectsInputTypes].inputTypes
105+
override protected def withNewChildInternal(newChild: Expression): Expression =
106+
this.copy(child = newChild)
107107
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryEvalSuite.scala

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.SparkFunSuite
21-
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FUNC_ALIAS
2221

2322
class TryEvalSuite extends SparkFunSuite with ExpressionEvalHelper {
2423
test("try_add") {
@@ -29,11 +28,8 @@ class TryEvalSuite extends SparkFunSuite with ExpressionEvalHelper {
2928
).foreach { case (a, b, expected) =>
3029
val left = Literal(a)
3130
val right = Literal(b)
32-
val input = TryAdd(left, right)
31+
val input = TryEval(Add(left, right, failOnError = true))
3332
checkEvaluation(input, expected)
34-
input.setTagValue(FUNC_ALIAS, "try_add")
35-
assert(input.toString == s"try_add(${left.toString}, ${right.toString})")
36-
assert(input.sql == s"try_add(${left.sql}, ${right.sql})")
3733
}
3834
}
3935

@@ -45,11 +41,8 @@ class TryEvalSuite extends SparkFunSuite with ExpressionEvalHelper {
4541
).foreach { case (a, b, expected) =>
4642
val left = Literal(a)
4743
val right = Literal(b)
48-
val input = TryDivide(left, right)
44+
val input = TryEval(Divide(left, right, failOnError = true))
4945
checkEvaluation(input, expected)
50-
input.setTagValue(FUNC_ALIAS, "try_divide")
51-
assert(input.toString == s"try_divide(${left.toString}, ${right.toString})")
52-
assert(input.sql == s"try_divide(${left.sql}, ${right.sql})")
5346
}
5447
}
5548
}

sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ NULL
4545
-- !query
4646
SELECT try_divide(1, 0.5)
4747
-- !query schema
48-
struct<try_divide((1 / 0.5)):decimal(8,6)>
48+
struct<try_divide(1, 0.5):decimal(8,6)>
4949
-- !query output
5050
2.000000
5151

0 commit comments

Comments
 (0)