
Commit 43fe7f4

Support ANSI SQL intervals by the aggregate function avg
1 parent 978cd0b · commit 43fe7f4

File tree

3 files changed, +54 -5 lines changed

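As a quick illustration of what the commit enables (a hypothetical spark-shell snippet, not taken from the commit; the DataFrame and column names are made up), avg can now be applied to year-month and day-time interval columns and returns values of the same interval type:

// Hypothetical spark-shell session; assumes an active `spark` session.
import java.time.{Duration, Period}
import org.apache.spark.sql.functions.avg
import spark.implicits._

val df = Seq(
  (Period.ofMonths(10), Duration.ofDays(10)),
  (Period.ofMonths(1), Duration.ofDays(1))).toDF("ym", "dt")

// Before this change, analysis rejected interval inputs because Average only accepted
// NumericType; afterwards the aggregated columns keep YearMonthIntervalType and
// DayTimeIntervalType.
df.select(avg($"ym"), avg($"dt")).printSchema()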

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala

Lines changed: 14 additions & 4 deletions
@@ -21,7 +21,6 @@ import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, FunctionRegistr
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.UnaryLike
-import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._

 @ExpressionDescription(
@@ -40,10 +39,15 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit

   override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")

-  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))

-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForNumericExpr(child.dataType, "function average")
+  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
+    case YearMonthIntervalType | DayTimeIntervalType | NullType => TypeCheckResult.TypeCheckSuccess
+    case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
+    case other => TypeCheckResult.TypeCheckFailure(
+      s"function average requires numeric or interval types, not ${other.catalogString}")
+  }

   override def nullable: Boolean = true

@@ -53,11 +57,15 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(p, s) =>
       DecimalType.bounded(p + 4, s + 4)
+    case _: YearMonthIntervalType => YearMonthIntervalType
+    case _: DayTimeIntervalType => DayTimeIntervalType
     case _ => DoubleType
   }

   private lazy val sumDataType = child.dataType match {
     case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
+    case _: YearMonthIntervalType => YearMonthIntervalType
+    case _: DayTimeIntervalType => DayTimeIntervalType
     case _ => DoubleType
   }

@@ -82,6 +90,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
     case _: DecimalType =>
       DecimalPrecision.decimalAndDecimal(
         Divide(sum, count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
+    case _: YearMonthIntervalType => DivideYMInterval(sum, count)
+    case _: DayTimeIntervalType => DivideDTInterval(sum, count)
     case _ =>
       Divide(sum.cast(resultType), count.cast(resultType), failOnError = false)
   }
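The evaluateExpression change above routes interval inputs to DivideYMInterval and DivideDTInterval rather than the numeric Divide. As a rough mental model only (plain Scala over java.time, not the Spark expressions, and it assumes truncating month division, which happens to match the expected rows in the tests below):

import java.time.{Duration, Period}

// Year-month intervals are averaged as a whole number of months; day-time intervals keep
// their sub-day precision.
def avgYearMonth(ps: Seq[Period]): Period =
  Period.ofMonths((ps.map(_.toTotalMonths).sum / ps.size).toInt)

def avgDayTime(ds: Seq[Duration]): Duration =
  ds.reduce(_ plus _).dividedBy(ds.size.toLong)

avgYearMonth(Seq(10, 1, -3, 21).map(Period.ofMonths(_)))   // P7M, the ungrouped result in the test below
avgDayTime(Seq(Duration.ofDays(-6), Duration.ofDays(-5)))  // PT-132H, i.e. -5 days -12 hours (the class-3 row)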

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala

Lines changed: 2 additions & 1 deletion
@@ -159,7 +159,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertError(Min(Symbol("mapField")), "min does not support ordering on type")
     assertError(Max(Symbol("mapField")), "max does not support ordering on type")
     assertError(Sum(Symbol("booleanField")), "function sum requires numeric or interval types")
-    assertError(Average(Symbol("booleanField")), "function average requires numeric type")
+    assertError(Average(Symbol("booleanField")),
+      "function average requires numeric or interval types")
   }

   test("check types for others") {

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 38 additions & 0 deletions
@@ -1151,6 +1151,44 @@ class DataFrameAggregateSuite extends QueryTest
     }
     assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
   }
+
+  test("SPARK-34837: Support ANSI SQL intervals by the aggregate function `avg`") {
+    val df = Seq((1, Period.ofMonths(10), Duration.ofDays(10)),
+      (2, Period.ofMonths(1), Duration.ofDays(1)),
+      (2, null, null),
+      (3, Period.ofMonths(-3), Duration.ofDays(-6)),
+      (3, Period.ofMonths(21), Duration.ofDays(-5)))
+      .toDF("class", "year-month", "day-time")
+
+    val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
+      (Period.ofMonths(10), Duration.ofDays(10)))
+      .toDF("year-month", "day-time")
+
+    val avgDF = df.select(avg($"year-month"), avg($"day-time"))
+    checkAnswer(avgDF, Row(Period.ofMonths(7), Duration.ofDays(0)))
+    assert(find(avgDF.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
+    assert(avgDF.schema == StructType(Seq(StructField("avg(year-month)", YearMonthIntervalType),
+      StructField("avg(day-time)", DayTimeIntervalType))))
+
+    val avgDF2 = df.groupBy($"class").agg(avg($"year-month"), avg($"day-time"))
+    checkAnswer(avgDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) ::
+      Row(2, Period.ofMonths(1), Duration.ofDays(1)) ::
+      Row(3, Period.ofMonths(9), Duration.ofDays(-5).plusHours(-12)) :: Nil)
+    assert(find(avgDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
+    assert(avgDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
+      StructField("avg(year-month)", YearMonthIntervalType),
+      StructField("avg(day-time)", DayTimeIntervalType))))
+
+    val error = intercept[SparkException] {
+      checkAnswer(df2.select(avg($"year-month")), Nil)
+    }
+    assert(error.toString contains "java.lang.ArithmeticException: integer overflow")
+
+    val error2 = intercept[SparkException] {
+      checkAnswer(df2.select(avg($"day-time")), Nil)
+    }
+    assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
+  }
 }

 case class B(c: Option[Double])
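The two overflow assertions rely on how ANSI intervals are represented internally: a year-month interval is an Int number of months and a day-time interval is a Long number of microseconds, consistent with the "integer overflow" and "long overflow" messages asserted above. A minimal stand-alone sketch of the same failure mode in plain JDK arithmetic (not Spark's sum implementation):

// Year-month sum: Int.MaxValue months plus 10 more months cannot fit in an Int.
Math.addExact(Int.MaxValue, 10)  // java.lang.ArithmeticException: integer overflow

// Day-time sum: 106751991 days is already near Long.MaxValue microseconds, so 10 more days overflow.
val microsPerDay = 24L * 60 * 60 * 1000 * 1000
Math.addExact(106751991L * microsPerDay, 10L * microsPerDay)  // java.lang.ArithmeticException: long overflow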
