@@ -41,6 +41,21 @@ def _build_datediff(args: t.List) -> exp.Expression:
4141 )
4242
4343
def _build_dateadd(args: t.List) -> exp.Expression:
    """Parse Spark 3 / Databricks DATE_ADD, DATEADD and TIMESTAMPADD calls.

    The 2-arg form adds whole days and maps to `exp.TsOrDsAdd`; the 3-arg
    form takes an explicit unit and maps to `exp.TimestampAdd`.
    """
    value = seq_get(args, 1)

    if len(args) == 2:
        # DATE_ADD(startDate, numDays INTEGER)
        # https://docs.databricks.com/en/sql/language-manual/functions/date_add.html
        return exp.TsOrDsAdd(
            this=seq_get(args, 0), expression=value, unit=exp.Literal.string("DAY")
        )

    # DATE_ADD / DATEADD / TIMESTAMPADD(unit, value integer, expr)
    # https://docs.databricks.com/en/sql/language-manual/functions/date_add3.html
    return exp.TimestampAdd(this=seq_get(args, 2), expression=value, unit=seq_get(args, 0))
57+
58+
4459def _normalize_partition (e : exp .Expression ) -> exp .Expression :
4560 """Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)"""
4661 if isinstance (e , str ):
@@ -50,6 +65,30 @@ def _normalize_partition(e: exp.Expression) -> exp.Expression:
5065 return e
5166
5267
def _dateadd_sql(self: Spark.Generator, expression: exp.TsOrDsAdd | exp.TimestampAdd) -> str:
    """Render TsOrDsAdd / TimestampAdd as Spark 3 / Databricks DATE_ADD SQL.

    Unit-less (or day-unit TsOrDsAdd) expressions round-trip to the 2-arg
    DATE_ADD; everything else uses the 3-arg unit form, with a CAST added
    when the expected return type is not already a timestamp/datetime.
    """
    if not expression.unit or (
        isinstance(expression, exp.TsOrDsAdd) and expression.text("unit").upper() == "DAY"
    ):
        # Coming from Hive/Spark2 DATE_ADD or roundtripping the 2-arg version of Spark3/DB
        return self.func("DATE_ADD", expression.this, expression.expression)

    this = self.func(
        "DATE_ADD",
        unit_to_var(expression),
        expression.expression,
        expression.this,
    )

    if isinstance(expression, exp.TsOrDsAdd):
        # The 3 arg version of DATE_ADD produces a timestamp in Spark3/DB but possibly not
        # in other dialects, so cast back to the type callers expect.
        return_type = expression.return_type
        if not return_type.is_type(exp.DataType.Type.TIMESTAMP, exp.DataType.Type.DATETIME):
            # Fix: the mangled source emitted stray spaces inside the SQL text
            # ("CAST( x AS DATE )"); the literal must have no extra whitespace.
            this = f"CAST({this} AS {return_type})"

    return this
90+
91+
5392class Spark (Spark2 ):
5493 class Tokenizer (Spark2 .Tokenizer ):
5594 RAW_STRINGS = [
@@ -62,6 +101,9 @@ class Parser(Spark2.Parser):
62101 FUNCTIONS = {
63102 ** Spark2 .Parser .FUNCTIONS ,
64103 "ANY_VALUE" : _build_with_ignore_nulls (exp .AnyValue ),
104+ "DATE_ADD" : _build_dateadd ,
105+ "DATEADD" : _build_dateadd ,
106+ "TIMESTAMPADD" : _build_dateadd ,
65107 "DATEDIFF" : _build_datediff ,
66108 "TIMESTAMP_LTZ" : _build_as_cast ("TIMESTAMP_LTZ" ),
67109 "TIMESTAMP_NTZ" : _build_as_cast ("TIMESTAMP_NTZ" ),
@@ -111,9 +153,8 @@ class Generator(Spark2.Generator):
111153 exp .PartitionedByProperty : lambda self ,
112154 e : f"PARTITIONED BY { self .wrap (self .expressions (sqls = [_normalize_partition (e ) for e in e .this .expressions ], skip_first = True ))} " ,
113155 exp .StartsWith : rename_func ("STARTSWITH" ),
114- exp .TimestampAdd : lambda self , e : self .func (
115- "DATEADD" , unit_to_var (e ), e .expression , e .this
116- ),
156+ exp .TsOrDsAdd : _dateadd_sql ,
157+ exp .TimestampAdd : _dateadd_sql ,
117158 exp .TryCast : lambda self , e : (
118159 self .trycast_sql (e ) if e .args .get ("safe" ) else self .cast_sql (e )
119160 ),
0 commit comments