Skip to content

Commit d6cfb41

Browse files
authored
feat(spark, databricks)!: Support for DATE_ADD functions (#3609)
* feat(spark, databricks): Support for DATE_ADD functions * PR Feedback 1 * PR Feedback 2
1 parent 664ae5c commit d6cfb41

File tree

4 files changed

+60
-5
lines changed

4 files changed

+60
-5
lines changed

sqlglot/dialects/spark.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,21 @@ def _build_datediff(args: t.List) -> exp.Expression:
4141
)
4242

4343

44+
def _build_dateadd(args: t.List) -> exp.Expression:
    """Build an expression for Spark3/Databricks DATE_ADD / DATEADD / TIMESTAMPADD.

    The 2-arg form adds whole days; the 3-arg form takes an explicit unit.
    """
    value = seq_get(args, 1)

    if len(args) != 2:
        # DATE_ADD / DATEADD / TIMESTAMPADD(unit, value integer, expr)
        # https://docs.databricks.com/en/sql/language-manual/functions/date_add3.html
        return exp.TimestampAdd(this=seq_get(args, 2), expression=value, unit=seq_get(args, 0))

    # DATE_ADD(startDate, numDays INTEGER)
    # https://docs.databricks.com/en/sql/language-manual/functions/date_add.html
    return exp.TsOrDsAdd(
        this=seq_get(args, 0), expression=value, unit=exp.Literal.string("DAY")
    )
57+
58+
4459
def _normalize_partition(e: exp.Expression) -> exp.Expression:
4560
"""Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)"""
4661
if isinstance(e, str):
@@ -50,6 +65,30 @@ def _normalize_partition(e: exp.Expression) -> exp.Expression:
5065
return e
5166

5267

68+
def _dateadd_sql(self: Spark.Generator, expression: exp.TsOrDsAdd | exp.TimestampAdd) -> str:
    """Generate Spark3/Databricks SQL for a TsOrDsAdd or TimestampAdd node."""
    is_ts_or_ds = isinstance(expression, exp.TsOrDsAdd)

    if not expression.unit or (is_ts_or_ds and expression.text("unit").upper() == "DAY"):
        # Coming from Hive/Spark2 DATE_ADD or roundtripping the 2-arg version of Spark3/DB
        return self.func("DATE_ADD", expression.this, expression.expression)

    sql = self.func(
        "DATE_ADD",
        unit_to_var(expression),
        expression.expression,
        expression.this,
    )

    if is_ts_or_ds:
        # The 3 arg version of DATE_ADD produces a timestamp in Spark3/DB but possibly not
        # in other dialects
        return_type = expression.return_type
        if not return_type.is_type(exp.DataType.Type.TIMESTAMP, exp.DataType.Type.DATETIME):
            sql = f"CAST({sql} AS {return_type})"

    return sql
90+
91+
5392
class Spark(Spark2):
5493
class Tokenizer(Spark2.Tokenizer):
5594
RAW_STRINGS = [
@@ -62,6 +101,9 @@ class Parser(Spark2.Parser):
62101
FUNCTIONS = {
63102
**Spark2.Parser.FUNCTIONS,
64103
"ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue),
104+
"DATE_ADD": _build_dateadd,
105+
"DATEADD": _build_dateadd,
106+
"TIMESTAMPADD": _build_dateadd,
65107
"DATEDIFF": _build_datediff,
66108
"TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"),
67109
"TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"),
@@ -111,9 +153,8 @@ class Generator(Spark2.Generator):
111153
exp.PartitionedByProperty: lambda self,
112154
e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
113155
exp.StartsWith: rename_func("STARTSWITH"),
114-
exp.TimestampAdd: lambda self, e: self.func(
115-
"DATEADD", unit_to_var(e), e.expression, e.this
116-
),
156+
exp.TsOrDsAdd: _dateadd_sql,
157+
exp.TimestampAdd: _dateadd_sql,
117158
exp.TryCast: lambda self, e: (
118159
self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)
119160
),

tests/dialects/test_bigquery.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -619,9 +619,9 @@ def test_bigquery(self):
619619
'SELECT TIMESTAMP_ADD(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL 10 MINUTE)',
620620
write={
621621
"bigquery": "SELECT TIMESTAMP_ADD(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL 10 MINUTE)",
622-
"databricks": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
622+
"databricks": "SELECT DATE_ADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
623623
"mysql": "SELECT DATE_ADD(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL 10 MINUTE)",
624-
"spark": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
624+
"spark": "SELECT DATE_ADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
625625
},
626626
)
627627
self.validate_all(

tests/dialects/test_redshift.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,9 @@ def test_redshift(self):
281281
"redshift": "SELECT DATEADD(MONTH, 18, '2008-02-28')",
282282
"snowflake": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS TIMESTAMP))",
283283
"tsql": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS DATETIME2))",
284+
"spark": "SELECT DATE_ADD(MONTH, 18, '2008-02-28')",
285+
"spark2": "SELECT ADD_MONTHS('2008-02-28', 18)",
286+
"databricks": "SELECT DATE_ADD(MONTH, 18, '2008-02-28')",
284287
},
285288
)
286289
self.validate_all(

tests/dialects/test_spark.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,7 @@ def test_spark(self):
563563
"SELECT DATE_ADD(my_date_column, 1)",
564564
write={
565565
"spark": "SELECT DATE_ADD(my_date_column, 1)",
566+
"spark2": "SELECT DATE_ADD(my_date_column, 1)",
566567
"bigquery": "SELECT DATE_ADD(CAST(CAST(my_date_column AS DATETIME) AS DATE), INTERVAL 1 DAY)",
567568
},
568569
)
@@ -675,6 +676,16 @@ def test_spark(self):
675676
"spark": "SELECT ARRAY_SORT(x)",
676677
},
677678
)
679+
self.validate_all(
680+
"SELECT DATE_ADD(MONTH, 20, col)",
681+
read={
682+
"spark": "SELECT TIMESTAMPADD(MONTH, 20, col)",
683+
},
684+
write={
685+
"spark": "SELECT DATE_ADD(MONTH, 20, col)",
686+
"databricks": "SELECT DATE_ADD(MONTH, 20, col)",
687+
},
688+
)
678689

679690
def test_bool_or(self):
680691
self.validate_all(

0 commit comments

Comments
 (0)