Skip to content

Commit ff97457

Browse files
committed
Abstract out into DivisionArithmetic
1 parent 237a1ad commit ff97457

File tree

1 file changed

+47
-78
lines changed
  • sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions

1 file changed

+47
-78
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 47 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -207,20 +207,12 @@ case class Multiply(left: Expression, right: Expression)
207207
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
208208
}
209209

210-
@ExpressionDescription(
211-
usage = "a _FUNC_ b - Divides a by b.",
212-
extended = "> SELECT 3 _FUNC_ 2;\n 1.5")
213-
case class Divide(left: Expression, right: Expression)
214-
extends BinaryArithmetic with NullIntolerant {
215-
216-
override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType)
217-
218-
override def symbol: String = "/"
219-
override def decimalMethod: String = "$div"
210+
abstract class DivisionArithmetic extends BinaryArithmetic with NullIntolerant {
220211
override def nullable: Boolean = true
221212

222213
private lazy val div: (Any, Any) => Any = dataType match {
223214
case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
215+
case i: IntegralType => i.integral.asInstanceOf[Integral[Any]].quot
224216
}
225217

226218
override def eval(input: InternalRow): Any = {
@@ -237,119 +229,96 @@ case class Divide(left: Expression, right: Expression)
237229
}
238230
}
239231

232+
// Used by doGenCode
233+
def divide(eval1: ExprCode, eval2: ExprCode, javaType: String): String
234+
def isZero(eval2: ExprCode): String
235+
240236
/**
241237
* Special case handling due to division by 0 => null.
242238
*/
243239
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
244240
val eval1 = left.genCode(ctx)
245241
val eval2 = right.genCode(ctx)
246-
val isZero = if (dataType.isInstanceOf[DecimalType]) {
247-
s"${eval2.value}.isZero()"
248-
} else {
249-
s"${eval2.value} == 0"
250-
}
242+
val isZeroCheck = isZero(eval2)
251243
val javaType = ctx.javaType(dataType)
252-
val divide = if (dataType.isInstanceOf[DecimalType]) {
253-
s"${eval1.value}.$decimalMethod(${eval2.value})"
254-
} else {
255-
s"($javaType)(${eval1.value} $symbol ${eval2.value})"
256-
}
244+
val division = divide(eval1, eval2, javaType)
257245
if (!left.nullable && !right.nullable) {
258246
ev.copy(code = s"""
259247
${eval2.code}
260248
boolean ${ev.isNull} = false;
261249
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
262-
if ($isZero) {
250+
if ($isZeroCheck) {
263251
${ev.isNull} = true;
264252
} else {
265253
${eval1.code}
266-
${ev.value} = $divide;
254+
${ev.value} = $division;
267255
}""")
268256
} else {
269257
ev.copy(code = s"""
270258
${eval2.code}
271259
boolean ${ev.isNull} = false;
272260
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
273-
if (${eval2.isNull} || $isZero) {
261+
if (${eval2.isNull} || $isZeroCheck) {
274262
${ev.isNull} = true;
275263
} else {
276264
${eval1.code}
277265
if (${eval1.isNull}) {
278266
${ev.isNull} = true;
279267
} else {
280-
${ev.value} = $divide;
268+
${ev.value} = $division;
281269
}
282270
}""")
283271
}
284272
}
285273
}
286274

275+
@ExpressionDescription(
276+
usage = "a _FUNC_ b - Divides a by b.",
277+
extended = "> SELECT 3 _FUNC_ 2;\n 1.5")
278+
case class Divide(left: Expression, right: Expression)
279+
extends DivisionArithmetic {
280+
281+
override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType)
282+
283+
override def symbol: String = "/"
284+
override def decimalMethod: String = "$div"
285+
286+
// Used by doGenCode
287+
override def divide(eval1: ExprCode, eval2: ExprCode, javaType: String): String = {
288+
if (dataType.isInstanceOf[DecimalType]) {
289+
s"${eval1.value}.$decimalMethod(${eval2.value})"
290+
} else {
291+
s"($javaType)(${eval1.value} $symbol ${eval2.value})"
292+
}
293+
}
294+
295+
override def isZero(eval2: ExprCode): String = {
296+
if (dataType.isInstanceOf[DecimalType]) {
297+
s"${eval2.value}.isZero()"
298+
} else {
299+
s"${eval2.value} == 0"
300+
}
301+
}
302+
}
303+
287304
@ExpressionDescription(
288305
usage = "a _FUNC_ b - Divides a by b.",
289306
extended = "> SELECT 3 _FUNC_ 2;\n 1")
290307
case class IntegerDivide(left: Expression, right: Expression)
291-
extends BinaryArithmetic with NullIntolerant {
308+
extends DivisionArithmetic {
292309

293310
override def inputType: AbstractDataType = IntegralType
294311

295312
override def symbol: String = "div"
296313
override def decimalMethod: String = "/"
297-
override def nullable: Boolean = true
298314

299-
private lazy val div: (Any, Any) => Any = dataType match {
300-
case i: IntegralType => i.integral.asInstanceOf[Integral[Any]].quot
301-
}
302-
303-
override def eval(input: InternalRow): Any = {
304-
val input2 = right.eval(input)
305-
if (input2 == null || input2 == 0) {
306-
null
307-
} else {
308-
val input1 = left.eval(input)
309-
if (input1 == null) {
310-
null
311-
} else {
312-
div(input1, input2)
313-
}
314-
}
315+
// Used by doGenCode
316+
override def divide(eval1: ExprCode, eval2: ExprCode, javaType: String): String = {
317+
s"($javaType)(${eval1.value} $decimalMethod (${eval2.value}))"
315318
}
316319

317-
/**
318-
* Special case handling due to division by 0 => null.
319-
*/
320-
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
321-
val eval1 = left.genCode(ctx)
322-
val eval2 = right.genCode(ctx)
323-
val isZero = s"${eval2.value} == 0"
324-
val javaType = ctx.javaType(dataType)
325-
val divide = s"($javaType)(${eval1.value} $decimalMethod (${eval2.value}))"
326-
if (!left.nullable && !right.nullable) {
327-
ev.copy(code = s"""
328-
${eval2.code}
329-
boolean ${ev.isNull} = false;
330-
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
331-
if ($isZero) {
332-
${ev.isNull} = true;
333-
} else {
334-
${eval1.code}
335-
${ev.value} = $divide;
336-
}""")
337-
} else {
338-
ev.copy(code = s"""
339-
${eval2.code}
340-
boolean ${ev.isNull} = false;
341-
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
342-
if (${eval2.isNull} || $isZero) {
343-
${ev.isNull} = true;
344-
} else {
345-
${eval1.code}
346-
if (${eval1.isNull}) {
347-
${ev.isNull} = true;
348-
} else {
349-
${ev.value} = $divide;
350-
}
351-
}""")
352-
}
320+
override def isZero(eval2: ExprCode): String = {
321+
s"${eval2.value} == 0"
353322
}
354323
}
355324

0 commit comments

Comments
 (0)