@@ -207,20 +207,12 @@ case class Multiply(left: Expression, right: Expression)
207207 protected override def nullSafeEval (input1 : Any , input2 : Any ): Any = numeric.times(input1, input2)
208208}
209209
210- @ ExpressionDescription (
211- usage = " a _FUNC_ b - Divides a by b." ,
212- extended = " > SELECT 3 _FUNC_ 2;\n 1.5" )
213- case class Divide (left : Expression , right : Expression )
214- extends BinaryArithmetic with NullIntolerant {
215-
216- override def inputType : AbstractDataType = TypeCollection (DoubleType , DecimalType )
217-
218- override def symbol : String = " /"
219- override def decimalMethod : String = " $div"
210+ abstract class DivisionArithmetic extends BinaryArithmetic with NullIntolerant {
220211 override def nullable : Boolean = true
221212
222213 private lazy val div : (Any , Any ) => Any = dataType match {
223214 case ft : FractionalType => ft.fractional.asInstanceOf [Fractional [Any ]].div
215+ case i : IntegralType => i.integral.asInstanceOf [Integral [Any ]].quot
224216 }
225217
226218 override def eval (input : InternalRow ): Any = {
@@ -237,119 +229,96 @@ case class Divide(left: Expression, right: Expression)
237229 }
238230 }
239231
232+ // Used by doGenCode
233+ def divide (eval1 : ExprCode , eval2 : ExprCode , javaType : String ): String
234+ def isZero (eval2 : ExprCode ): String
235+
240236 /**
241237 * Special case handling due to division by 0 => null.
242238 */
243239 override def doGenCode (ctx : CodegenContext , ev : ExprCode ): ExprCode = {
244240 val eval1 = left.genCode(ctx)
245241 val eval2 = right.genCode(ctx)
246- val isZero = if (dataType.isInstanceOf [DecimalType ]) {
247- s " ${eval2.value}.isZero() "
248- } else {
249- s " ${eval2.value} == 0 "
250- }
242+ val isZeroCheck = isZero(eval2)
251243 val javaType = ctx.javaType(dataType)
252- val divide = if (dataType.isInstanceOf [DecimalType ]) {
253- s " ${eval1.value}. $decimalMethod( ${eval2.value}) "
254- } else {
255- s " ( $javaType)( ${eval1.value} $symbol ${eval2.value}) "
256- }
244+ val division = divide(eval1, eval2, javaType)
257245 if (! left.nullable && ! right.nullable) {
258246 ev.copy(code = s """
259247 ${eval2.code}
260248 boolean ${ev.isNull} = false;
261249 $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
262- if ( $isZero ) {
250+ if ( $isZeroCheck ) {
263251 ${ev.isNull} = true;
264252 } else {
265253 ${eval1.code}
266- ${ev.value} = $divide ;
254+ ${ev.value} = $division ;
267255 } """ )
268256 } else {
269257 ev.copy(code = s """
270258 ${eval2.code}
271259 boolean ${ev.isNull} = false;
272260 $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
273- if ( ${eval2.isNull} || $isZero ) {
261+ if ( ${eval2.isNull} || $isZeroCheck ) {
274262 ${ev.isNull} = true;
275263 } else {
276264 ${eval1.code}
277265 if ( ${eval1.isNull}) {
278266 ${ev.isNull} = true;
279267 } else {
280- ${ev.value} = $divide ;
268+ ${ev.value} = $division ;
281269 }
282270 } """ )
283271 }
284272 }
285273}
286274
275+ @ ExpressionDescription (
276+ usage = " a _FUNC_ b - Divides a by b." ,
277+ extended = " > SELECT 3 _FUNC_ 2;\n 1.5" )
278+ case class Divide (left : Expression , right : Expression )
279+ extends DivisionArithmetic {
280+
281+ override def inputType : AbstractDataType = TypeCollection (DoubleType , DecimalType )
282+
283+ override def symbol : String = " /"
284+ override def decimalMethod : String = " $div"
285+
286+ // Used by doGenCode
287+ override def divide (eval1 : ExprCode , eval2 : ExprCode , javaType : String ): String = {
288+ if (dataType.isInstanceOf [DecimalType ]) {
289+ s " ${eval1.value}. $decimalMethod( ${eval2.value}) "
290+ } else {
291+ s " ( $javaType)( ${eval1.value} $symbol ${eval2.value}) "
292+ }
293+ }
294+
295+ override def isZero (eval2 : ExprCode ): String = {
296+ if (dataType.isInstanceOf [DecimalType ]) {
297+ s " ${eval2.value}.isZero() "
298+ } else {
299+ s " ${eval2.value} == 0 "
300+ }
301+ }
302+ }
303+
287304@ ExpressionDescription (
288305 usage = " a _FUNC_ b - Divides a by b." ,
289306 extended = " > SELECT 3 _FUNC_ 2;\n 1" )
290307case class IntegerDivide (left : Expression , right : Expression )
291- extends BinaryArithmetic with NullIntolerant {
308+ extends DivisionArithmetic {
292309
293310 override def inputType : AbstractDataType = IntegralType
294311
295312 override def symbol : String = " div"
296313 override def decimalMethod : String = " /"
297- override def nullable : Boolean = true
298314
299- private lazy val div : (Any , Any ) => Any = dataType match {
300- case i : IntegralType => i.integral.asInstanceOf [Integral [Any ]].quot
301- }
302-
303- override def eval (input : InternalRow ): Any = {
304- val input2 = right.eval(input)
305- if (input2 == null || input2 == 0 ) {
306- null
307- } else {
308- val input1 = left.eval(input)
309- if (input1 == null ) {
310- null
311- } else {
312- div(input1, input2)
313- }
314- }
315+ // Used by doGenCode
316+ override def divide (eval1 : ExprCode , eval2 : ExprCode , javaType : String ): String = {
317+ s " ( $javaType)( ${eval1.value} $decimalMethod ( ${eval2.value})) "
315318 }
316319
317- /**
318- * Special case handling due to division by 0 => null.
319- */
320- override def doGenCode (ctx : CodegenContext , ev : ExprCode ): ExprCode = {
321- val eval1 = left.genCode(ctx)
322- val eval2 = right.genCode(ctx)
323- val isZero = s " ${eval2.value} == 0 "
324- val javaType = ctx.javaType(dataType)
325- val divide = s " ( $javaType)( ${eval1.value} $decimalMethod ( ${eval2.value})) "
326- if (! left.nullable && ! right.nullable) {
327- ev.copy(code = s """
328- ${eval2.code}
329- boolean ${ev.isNull} = false;
330- $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
331- if ( $isZero) {
332- ${ev.isNull} = true;
333- } else {
334- ${eval1.code}
335- ${ev.value} = $divide;
336- } """ )
337- } else {
338- ev.copy(code = s """
339- ${eval2.code}
340- boolean ${ev.isNull} = false;
341- $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
342- if ( ${eval2.isNull} || $isZero) {
343- ${ev.isNull} = true;
344- } else {
345- ${eval1.code}
346- if ( ${eval1.isNull}) {
347- ${ev.isNull} = true;
348- } else {
349- ${ev.value} = $divide;
350- }
351- } """ )
352- }
320+ override def isZero (eval2 : ExprCode ): String = {
321+ s " ${eval2.value} == 0 "
353322 }
354323}
355324
0 commit comments