Skip to content

Commit 4ea79bd

Browse files
committed
[SPARK-29684][SQL] Support divide/multiply for interval types
1 parent 4d302cb commit 4ea79bd

File tree

5 files changed

+170
-2
lines changed

5 files changed

+170
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -858,6 +858,12 @@ object TypeCoercion {
858858
SubtractTimestamps(l, Cast(r, TimestampType))
859859
case Subtract(l @ DateType(), r @ TimestampType()) =>
860860
SubtractTimestamps(Cast(l, TimestampType), r)
861+
case Divide(l @ CalendarIntervalType(), r) => IntervalDivide(l, Cast(r, DecimalType(28, 9)))
862+
case Multiply(l @ CalendarIntervalType(), r) =>
863+
IntervalMultiply(l, Cast(r, DecimalType(28, 9)))
864+
case Multiply(l, r @ CalendarIntervalType()) =>
865+
IntervalMultiply(r, Cast(l, DecimalType(28, 9)))
866+
861867
}
862868
}
863869

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import org.apache.spark.sql.AnalysisException
3030
import org.apache.spark.sql.catalyst.InternalRow
3131
import org.apache.spark.sql.catalyst.expressions.codegen._
3232
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
33-
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, TimestampFormatter}
33+
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils, TimestampFormatter}
3434
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
3535
import org.apache.spark.sql.internal.SQLConf
3636
import org.apache.spark.sql.types._
@@ -2164,3 +2164,66 @@ case class SubtractDates(left: Expression, right: Expression)
21642164
}
21652165
}
21662166

2167+
// scalastyle:off line.size.limit
2168+
@ExpressionDescription(
2169+
usage = "expr1 _FUNC_ expr2 - Divide interval value `expr1` by `expr2`. It returns NULL if `expr2` is 0 or NULL.",
2170+
examples = """
2171+
Examples:
2172+
> SELECT interval '1 year 2 month' / 3.0;
2173+
interval 4 months 2 weeks 6 days
2174+
""",
2175+
since = "3.0.0")
2176+
// scalastyle:on line.size.limit
2177+
case class IntervalDivide(left: Expression, right: Expression)
2178+
extends BinaryExpression with ImplicitCastInputTypes {
2179+
2180+
override def dataType: DataType = CalendarIntervalType
2181+
2182+
override def inputTypes: Seq[AbstractDataType] = Seq(CalendarIntervalType, DecimalType)
2183+
2184+
override def nullSafeEval(interval: Any, divisor: Any): Any = {
2185+
IntervalUtils.divide(interval.asInstanceOf[CalendarInterval],
2186+
divisor.asInstanceOf[Decimal])
2187+
}
2188+
2189+
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
2190+
defineCodeGen(ctx, ev, (interval, divisor) => {
2191+
val iu = IntervalUtils.getClass.getName.stripSuffix("$")
2192+
s"$iu.divide($interval, $divisor)"
2193+
})
2194+
}
2195+
}
2196+
2197+
// scalastyle:off line.size.limit
2198+
@ExpressionDescription(
2199+
usage = "expr1 _FUNC_ expr2 - Multiply interval value `expr1` by `expr2`. It returns NULL if `expr2` is 0 or NULL.",
2200+
examples = """
2201+
Examples:
2202+
> SELECT interval '4 months 2 weeks 6 days' * 3.0;
2203+
interval 1 years 8 weeks 4 days
2204+
> SELECT 3.0 * interval '4 months 2 weeks 6 days';
2205+
interval 1 years 8 weeks 4 days
2206+
""",
2207+
since = "3.0.0")
2208+
// scalastyle:on line.size.limit
2209+
case class IntervalMultiply(left: Expression, right: Expression)
2210+
extends BinaryExpression with ImplicitCastInputTypes {
2211+
2212+
override def dataType: DataType = CalendarIntervalType
2213+
2214+
override def inputTypes: Seq[AbstractDataType] = Seq(CalendarIntervalType, DecimalType)
2215+
2216+
override def nullSafeEval(interval: Any, divisor: Any): Any = {
2217+
IntervalUtils.multiply(interval.asInstanceOf[CalendarInterval],
2218+
divisor.asInstanceOf[Decimal])
2219+
}
2220+
2221+
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
2222+
defineCodeGen(ctx, ev, (interval, divisor) => {
2223+
val iu = IntervalUtils.getClass.getName.stripSuffix("$")
2224+
s"$iu.multiply($interval, $divisor)"
2225+
})
2226+
}
2227+
}
2228+
2229+

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,22 @@ object IntervalUtils {
9696
Decimal(result, 18, 6)
9797
}
9898

99+
def divide(interval: CalendarInterval, divisor: Decimal): CalendarInterval = {
100+
if (divisor == Decimal.ZERO || divisor == null) return null
101+
val months = Decimal(interval.months) / divisor
102+
val milliseconds = (Decimal(interval.microseconds) / divisor +
103+
months.remainder(Decimal.ONE) * Decimal(MICROS_PER_MONTH)).toLong
104+
new CalendarInterval(months.toInt, milliseconds.toLong)
105+
}
106+
107+
def multiply(interval: CalendarInterval, multiplier: Decimal): CalendarInterval = {
108+
if (multiplier == null) return null
109+
val months = Decimal(interval.months) * multiplier
110+
val milliseconds = (Decimal(interval.microseconds) * multiplier +
111+
months.remainder(Decimal.ONE) * Decimal(MICROS_PER_MONTH)).toLong
112+
new CalendarInterval(months.toInt, milliseconds.toLong)
113+
}
114+
99115
/**
100116
* Converts a string to [[CalendarInterval]] case-insensitively.
101117
*

sql/core/src/test/resources/sql-tests/inputs/datetime.sql

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,14 @@ select date '2001-10-01' - 7;
3636
select date '2001-10-01' - date '2001-09-28';
3737
select date'2020-01-01' - timestamp'2019-10-06 10:11:12.345678';
3838
select timestamp'2019-10-06 10:11:12.345678' - date'2020-01-01';
39+
40+
select interval '1 year 2 month' / null;
41+
select interval '1 year 2 month' / 0;
42+
select interval '1 year 2 month' / 3;
43+
select interval '1 year 2 month' / 3.0;
44+
45+
SELECT interval '4 months 2 weeks 6 days' * null;
46+
SELECT interval '4 months 2 weeks 6 days' * 0;
47+
SELECT interval '4 months 2 weeks 6 days' * 3;
48+
SELECT interval '4 months 2 weeks 6 days' * 3.0;
49+
SELECT 3.0 * interval '4 months 2 weeks 6 days';

sql/core/src/test/resources/sql-tests/results/datetime.sql.out

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 17
2+
-- Number of queries: 26
33

44

55
-- !query 0
@@ -145,3 +145,75 @@ select timestamp'2019-10-06 10:11:12.345678' - date'2020-01-01'
145145
struct<subtracttimestamps(TIMESTAMP('2019-10-06 10:11:12.345678'), CAST(DATE '2020-01-01' AS TIMESTAMP)):interval>
146146
-- !query 16 output
147147
interval -12 weeks -2 days -14 hours -48 minutes -47 seconds -654 milliseconds -322 microseconds
148+
149+
150+
-- !query 17
151+
select interval '1 year 2 month' / null
152+
-- !query 17 schema
153+
struct<intervaldivide(interval 1 years 2 months, CAST(NULL AS DECIMAL(28,9))):interval>
154+
-- !query 17 output
155+
NULL
156+
157+
158+
-- !query 18
159+
select interval '1 year 2 month' / 0
160+
-- !query 18 schema
161+
struct<intervaldivide(interval 1 years 2 months, CAST(0 AS DECIMAL(28,9))):interval>
162+
-- !query 18 output
163+
NULL
164+
165+
166+
-- !query 19
167+
select interval '1 year 2 month' / 3
168+
-- !query 19 schema
169+
struct<intervaldivide(interval 1 years 2 months, CAST(3 AS DECIMAL(28,9))):interval>
170+
-- !query 19 output
171+
interval 4 months 2 weeks 6 days
172+
173+
174+
-- !query 20
175+
select interval '1 year 2 month' / 3.0
176+
-- !query 20 schema
177+
struct<intervaldivide(interval 1 years 2 months, CAST(3.0 AS DECIMAL(28,9))):interval>
178+
-- !query 20 output
179+
interval 4 months 2 weeks 6 days
180+
181+
182+
-- !query 21
183+
SELECT interval '4 months 2 weeks 6 days' * null
184+
-- !query 21 schema
185+
struct<intervalmultiply(interval 4 months 2 weeks 6 days, CAST(NULL AS DECIMAL(28,9))):interval>
186+
-- !query 21 output
187+
NULL
188+
189+
190+
-- !query 22
191+
SELECT interval '4 months 2 weeks 6 days' * 0
192+
-- !query 22 schema
193+
struct<intervalmultiply(interval 4 months 2 weeks 6 days, CAST(0 AS DECIMAL(28,9))):interval>
194+
-- !query 22 output
195+
interval 0 microseconds
196+
197+
198+
-- !query 23
199+
SELECT interval '4 months 2 weeks 6 days' * 3
200+
-- !query 23 schema
201+
struct<intervalmultiply(interval 4 months 2 weeks 6 days, CAST(3 AS DECIMAL(28,9))):interval>
202+
-- !query 23 output
203+
interval 1 years 8 weeks 4 days
204+
205+
206+
-- !query 24
207+
SELECT interval '4 months 2 weeks 6 days' * 3.0
208+
-- !query 24 schema
209+
struct<intervalmultiply(interval 4 months 2 weeks 6 days, CAST(3.0 AS DECIMAL(28,9))):interval>
210+
-- !query 24 output
211+
interval 1 years 8 weeks 4 days
212+
213+
214+
-- !query 25
215+
SELECT 3.0 * interval '4 months 2 weeks 6 days'
216+
-- !query 25 schema
217+
struct<intervalmultiply(interval 4 months 2 weeks 6 days, CAST(3.0 AS DECIMAL(28,9))):interval>
218+
-- !query 25 output
219+
interval 1 years 8 weeks 4 days

0 commit comments

Comments
 (0)