@@ -553,41 +553,68 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
553553 foldables.nonEmpty && others.length < 2
554554 }
555555
556+ // Not all UnaryExpression can be pushed into (if / case) branches, e.g. Alias.
557+ private def supportedUnaryExpression (e : UnaryExpression ): Boolean = e match {
558+ case _ : IsNull | _ : IsNotNull => true
559+ case _ : UnaryMathExpression | _ : Abs | _ : Bin | _ : Factorial | _ : Hex => true
560+ case _ : String2StringExpression | _ : Ascii | _ : Base64 | _ : BitLength | _ : Chr | _ : Length =>
561+ true
562+ case _ : CastBase => true
563+ case _ : GetDateField | _ : LastDay => true
564+ case _ : ExtractIntervalPart => true
565+ case _ : ArraySetLike => true
566+ case _ : ExtractValue => true
567+ case _ => false
568+ }
569+
570+ // Not all BinaryExpression can be pushed into (if / case) branches.
571+ private def supportedBinaryExpression (e : BinaryExpression ): Boolean = e match {
572+ case _ : BinaryComparison | _ : StringPredicate | _ : StringRegexExpression => true
573+ case _ : BinaryArithmetic => true
574+ case _ : BinaryMathExpression => true
575+ case _ : AddMonths | _ : DateAdd | _ : DateAddInterval | _ : DateDiff | _ : DateSub => true
576+ case _ : FindInSet | _ : RoundBase => true
577+ case _ => false
578+ }
579+
556580 def apply (plan : LogicalPlan ): LogicalPlan = plan transform {
557581 case q : LogicalPlan => q transformExpressionsUp {
558- case a : Alias => a // Skip an alias.
559582 case u @ UnaryExpression (i @ If (_, trueValue, falseValue))
560- if atMostOneUnfoldable(Seq (trueValue, falseValue)) =>
583+ if supportedUnaryExpression(u) && atMostOneUnfoldable(Seq (trueValue, falseValue)) =>
561584 i.copy(
562585 trueValue = u.withNewChildren(Array (trueValue)),
563586 falseValue = u.withNewChildren(Array (falseValue)))
564587
565588 case u @ UnaryExpression (c @ CaseWhen (branches, elseValue))
566- if atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
589+ if supportedUnaryExpression(u) && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
567590 c.copy(
568591 branches.map(e => e.copy(_2 = u.withNewChildren(Array (e._2)))),
569592 elseValue.map(e => u.withNewChildren(Array (e))))
570593
571594 case b @ BinaryExpression (i @ If (_, trueValue, falseValue), right)
572- if right.foldable && atMostOneUnfoldable(Seq (trueValue, falseValue)) =>
595+ if supportedBinaryExpression(b) && right.foldable &&
596+ atMostOneUnfoldable(Seq (trueValue, falseValue)) =>
573597 i.copy(
574598 trueValue = b.withNewChildren(Array (trueValue, right)),
575599 falseValue = b.withNewChildren(Array (falseValue, right)))
576600
577601 case b @ BinaryExpression (left, i @ If (_, trueValue, falseValue))
578- if left.foldable && atMostOneUnfoldable(Seq (trueValue, falseValue)) =>
602+ if supportedBinaryExpression(b) && left.foldable &&
603+ atMostOneUnfoldable(Seq (trueValue, falseValue)) =>
579604 i.copy(
580605 trueValue = b.withNewChildren(Array (left, trueValue)),
581606 falseValue = b.withNewChildren(Array (left, falseValue)))
582607
583608 case b @ BinaryExpression (c @ CaseWhen (branches, elseValue), right)
584- if right.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
609+ if supportedBinaryExpression(b) && right.foldable &&
610+ atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
585611 c.copy(
586612 branches.map(e => e.copy(_2 = b.withNewChildren(Array (e._2, right)))),
587613 elseValue.map(e => b.withNewChildren(Array (e, right))))
588614
589615 case b @ BinaryExpression (left, c @ CaseWhen (branches, elseValue))
590- if left.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
616+ if supportedBinaryExpression(b) && left.foldable &&
617+ atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
591618 c.copy(
592619 branches.map(e => e.copy(_2 = b.withNewChildren(Array (left, e._2)))),
593620 elseValue.map(e => b.withNewChildren(Array (left, e))))
0 commit comments