Skip to content

Commit d331e56

Browse files
authored
fix(redshift)!: Normalize time units in their full singular form (#3652)
* fix(redshift): Normalize time units in their full singular form * fix make style * PR Feedback 1 * Move DATE_PART_MAPPING to Dialect * Add EPOCH test case
1 parent e8cab58 commit d331e56

File tree

4 files changed

+144
-94
lines changed

4 files changed

+144
-94
lines changed

sqlglot/dialects/dialect.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,11 @@ class Dialect(metaclass=_Dialect):
316316
) SELECT c FROM y;
317317
"""
318318

319+
COPY_PARAMS_ARE_CSV = True
320+
"""
321+
Whether COPY statement parameters are separated by comma or whitespace
322+
"""
323+
319324
# --- Autofilled ---
320325

321326
tokenizer_class = Tokenizer
@@ -347,6 +352,100 @@ class Dialect(metaclass=_Dialect):
347352
UNICODE_START: t.Optional[str] = None
348353
UNICODE_END: t.Optional[str] = None
349354

355+
DATE_PART_MAPPING = {
356+
"Y": "YEAR",
357+
"YY": "YEAR",
358+
"YYY": "YEAR",
359+
"YYYY": "YEAR",
360+
"YR": "YEAR",
361+
"YEARS": "YEAR",
362+
"YRS": "YEAR",
363+
"MM": "MONTH",
364+
"MON": "MONTH",
365+
"MONS": "MONTH",
366+
"MONTHS": "MONTH",
367+
"D": "DAY",
368+
"DD": "DAY",
369+
"DAYS": "DAY",
370+
"DAYOFMONTH": "DAY",
371+
"DAY OF WEEK": "DAYOFWEEK",
372+
"WEEKDAY": "DAYOFWEEK",
373+
"DOW": "DAYOFWEEK",
374+
"DW": "DAYOFWEEK",
375+
"WEEKDAY_ISO": "DAYOFWEEKISO",
376+
"DOW_ISO": "DAYOFWEEKISO",
377+
"DW_ISO": "DAYOFWEEKISO",
378+
"DAY OF YEAR": "DAYOFYEAR",
379+
"DOY": "DAYOFYEAR",
380+
"DY": "DAYOFYEAR",
381+
"W": "WEEK",
382+
"WK": "WEEK",
383+
"WEEKOFYEAR": "WEEK",
384+
"WOY": "WEEK",
385+
"WY": "WEEK",
386+
"WEEK_ISO": "WEEKISO",
387+
"WEEKOFYEARISO": "WEEKISO",
388+
"WEEKOFYEAR_ISO": "WEEKISO",
389+
"Q": "QUARTER",
390+
"QTR": "QUARTER",
391+
"QTRS": "QUARTER",
392+
"QUARTERS": "QUARTER",
393+
"H": "HOUR",
394+
"HH": "HOUR",
395+
"HR": "HOUR",
396+
"HOURS": "HOUR",
397+
"HRS": "HOUR",
398+
"M": "MINUTE",
399+
"MI": "MINUTE",
400+
"MIN": "MINUTE",
401+
"MINUTES": "MINUTE",
402+
"MINS": "MINUTE",
403+
"S": "SECOND",
404+
"SEC": "SECOND",
405+
"SECONDS": "SECOND",
406+
"SECS": "SECOND",
407+
"MS": "MILLISECOND",
408+
"MSEC": "MILLISECOND",
409+
"MSECS": "MILLISECOND",
410+
"MSECOND": "MILLISECOND",
411+
"MSECONDS": "MILLISECOND",
412+
"MILLISEC": "MILLISECOND",
413+
"MILLISECS": "MILLISECOND",
414+
"MILLISECON": "MILLISECOND",
415+
"MILLISECONDS": "MILLISECOND",
416+
"US": "MICROSECOND",
417+
"USEC": "MICROSECOND",
418+
"USECS": "MICROSECOND",
419+
"MICROSEC": "MICROSECOND",
420+
"MICROSECS": "MICROSECOND",
421+
"USECOND": "MICROSECOND",
422+
"USECONDS": "MICROSECOND",
423+
"MICROSECONDS": "MICROSECOND",
424+
"NS": "NANOSECOND",
425+
"NSEC": "NANOSECOND",
426+
"NANOSEC": "NANOSECOND",
427+
"NSECOND": "NANOSECOND",
428+
"NSECONDS": "NANOSECOND",
429+
"NANOSECS": "NANOSECOND",
430+
"EPOCH_SECOND": "EPOCH",
431+
"EPOCH_SECONDS": "EPOCH",
432+
"EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
433+
"EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
434+
"EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
435+
"TZH": "TIMEZONE_HOUR",
436+
"TZM": "TIMEZONE_MINUTE",
437+
"DEC": "DECADE",
438+
"DECS": "DECADE",
439+
"DECADES": "DECADE",
440+
"MIL": "MILLENIUM",
441+
"MILS": "MILLENIUM",
442+
"MILLENIA": "MILLENIUM",
443+
"C": "CENTURY",
444+
"CENT": "CENTURY",
445+
"CENTS": "CENTURY",
446+
"CENTURIES": "CENTURY",
447+
}
448+
350449
@classmethod
351450
def get_or_raise(cls, dialect: DialectType) -> Dialect:
352451
"""
@@ -1062,6 +1161,25 @@ def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[
10621161
return exp.Var(this=default) if default else None
10631162

10641163

1164+
@t.overload
1165+
def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var:
1166+
pass
1167+
1168+
1169+
@t.overload
1170+
def map_date_part(
1171+
part: t.Optional[exp.Expression], dialect: DialectType = Dialect
1172+
) -> t.Optional[exp.Expression]:
1173+
pass
1174+
1175+
1176+
def map_date_part(part, dialect: DialectType = Dialect):
1177+
mapped = (
1178+
Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None
1179+
)
1180+
return exp.var(mapped) if mapped else part
1181+
1182+
10651183
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
10661184
trunc_curr_date = exp.func("date_trunc", "month", expression.this)
10671185
plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")

sqlglot/dialects/redshift.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
json_extract_segments,
1313
no_tablesample_sql,
1414
rename_func,
15+
map_date_part,
1516
)
1617
from sqlglot.dialects.postgres import Postgres
1718
from sqlglot.helper import seq_get
@@ -23,7 +24,11 @@
2324

2425
def _build_date_delta(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
2526
def _builder(args: t.List) -> E:
26-
expr = expr_type(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
27+
expr = expr_type(
28+
this=seq_get(args, 2),
29+
expression=seq_get(args, 1),
30+
unit=map_date_part(seq_get(args, 0)),
31+
)
2732
if expr_type is exp.TsOrDsAdd:
2833
expr.set("return_type", exp.DataType.build("TIMESTAMP"))
2934

sqlglot/dialects/snowflake.py

Lines changed: 6 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
timestamptrunc_sql,
2222
timestrtotime_sql,
2323
var_map_sql,
24+
map_date_part,
2425
)
2526
from sqlglot.helper import flatten, is_float, is_int, seq_get
2627
from sqlglot.tokens import TokenType
@@ -75,7 +76,7 @@ def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
7576

7677
def _build_datediff(args: t.List) -> exp.DateDiff:
7778
return exp.DateDiff(
78-
this=seq_get(args, 2), expression=seq_get(args, 1), unit=_map_date_part(seq_get(args, 0))
79+
this=seq_get(args, 2), expression=seq_get(args, 1), unit=map_date_part(seq_get(args, 0))
7980
)
8081

8182

@@ -84,7 +85,7 @@ def _builder(args: t.List) -> E:
8485
return expr_type(
8586
this=seq_get(args, 2),
8687
expression=seq_get(args, 1),
87-
unit=_map_date_part(seq_get(args, 0)),
88+
unit=map_date_part(seq_get(args, 0)),
8889
)
8990

9091
return _builder
@@ -143,97 +144,9 @@ def _parse(self: Snowflake.Parser) -> exp.Show:
143144
return _parse
144145

145146

146-
DATE_PART_MAPPING = {
147-
"Y": "YEAR",
148-
"YY": "YEAR",
149-
"YYY": "YEAR",
150-
"YYYY": "YEAR",
151-
"YR": "YEAR",
152-
"YEARS": "YEAR",
153-
"YRS": "YEAR",
154-
"MM": "MONTH",
155-
"MON": "MONTH",
156-
"MONS": "MONTH",
157-
"MONTHS": "MONTH",
158-
"D": "DAY",
159-
"DD": "DAY",
160-
"DAYS": "DAY",
161-
"DAYOFMONTH": "DAY",
162-
"WEEKDAY": "DAYOFWEEK",
163-
"DOW": "DAYOFWEEK",
164-
"DW": "DAYOFWEEK",
165-
"WEEKDAY_ISO": "DAYOFWEEKISO",
166-
"DOW_ISO": "DAYOFWEEKISO",
167-
"DW_ISO": "DAYOFWEEKISO",
168-
"YEARDAY": "DAYOFYEAR",
169-
"DOY": "DAYOFYEAR",
170-
"DY": "DAYOFYEAR",
171-
"W": "WEEK",
172-
"WK": "WEEK",
173-
"WEEKOFYEAR": "WEEK",
174-
"WOY": "WEEK",
175-
"WY": "WEEK",
176-
"WEEK_ISO": "WEEKISO",
177-
"WEEKOFYEARISO": "WEEKISO",
178-
"WEEKOFYEAR_ISO": "WEEKISO",
179-
"Q": "QUARTER",
180-
"QTR": "QUARTER",
181-
"QTRS": "QUARTER",
182-
"QUARTERS": "QUARTER",
183-
"H": "HOUR",
184-
"HH": "HOUR",
185-
"HR": "HOUR",
186-
"HOURS": "HOUR",
187-
"HRS": "HOUR",
188-
"M": "MINUTE",
189-
"MI": "MINUTE",
190-
"MIN": "MINUTE",
191-
"MINUTES": "MINUTE",
192-
"MINS": "MINUTE",
193-
"S": "SECOND",
194-
"SEC": "SECOND",
195-
"SECONDS": "SECOND",
196-
"SECS": "SECOND",
197-
"MS": "MILLISECOND",
198-
"MSEC": "MILLISECOND",
199-
"MILLISECONDS": "MILLISECOND",
200-
"US": "MICROSECOND",
201-
"USEC": "MICROSECOND",
202-
"MICROSECONDS": "MICROSECOND",
203-
"NS": "NANOSECOND",
204-
"NSEC": "NANOSECOND",
205-
"NANOSEC": "NANOSECOND",
206-
"NSECOND": "NANOSECOND",
207-
"NSECONDS": "NANOSECOND",
208-
"NANOSECS": "NANOSECOND",
209-
"EPOCH": "EPOCH_SECOND",
210-
"EPOCH_SECONDS": "EPOCH_SECOND",
211-
"EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
212-
"EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
213-
"EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
214-
"TZH": "TIMEZONE_HOUR",
215-
"TZM": "TIMEZONE_MINUTE",
216-
}
217-
218-
219-
@t.overload
220-
def _map_date_part(part: exp.Expression) -> exp.Var:
221-
pass
222-
223-
224-
@t.overload
225-
def _map_date_part(part: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
226-
pass
227-
228-
229-
def _map_date_part(part):
230-
mapped = DATE_PART_MAPPING.get(part.name.upper()) if part else None
231-
return exp.var(mapped) if mapped else part
232-
233-
234147
def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
235148
trunc = date_trunc_to_time(args)
236-
trunc.set("unit", _map_date_part(trunc.args["unit"]))
149+
trunc.set("unit", map_date_part(trunc.args["unit"]))
237150
return trunc
238151

239152

@@ -367,7 +280,7 @@ class Parser(parser.Parser):
367280
),
368281
"IFF": exp.If.from_arg_list,
369282
"LAST_DAY": lambda args: exp.LastDay(
370-
this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1))
283+
this=seq_get(args, 0), unit=map_date_part(seq_get(args, 1))
371284
),
372285
"LISTAGG": exp.GroupConcat.from_arg_list,
373286
"MEDIAN": lambda args: exp.PercentileCont(
@@ -541,7 +454,7 @@ def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]:
541454

542455
self._match(TokenType.COMMA)
543456
expression = self._parse_bitwise()
544-
this = _map_date_part(this)
457+
this = map_date_part(this)
545458
name = this.name.upper()
546459

547460
if name.startswith("EPOCH"):

tests/dialects/test_redshift.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,12 @@ def test_redshift(self):
259259
"postgres": "COALESCE(a, b, c, d)",
260260
},
261261
)
262+
263+
self.validate_identity(
264+
"DATEDIFF(days, a, b)",
265+
"DATEDIFF(DAY, a, b)",
266+
)
267+
262268
self.validate_all(
263269
"DATEDIFF('day', a, b)",
264270
write={
@@ -300,6 +306,14 @@ def test_redshift(self):
300306
},
301307
)
302308

309+
self.validate_all(
310+
"SELECT EXTRACT(EPOCH FROM CURRENT_DATE)",
311+
write={
312+
"snowflake": "SELECT DATE_PART(EPOCH, CURRENT_DATE)",
313+
"redshift": "SELECT EXTRACT(EPOCH FROM CURRENT_DATE)",
314+
},
315+
)
316+
303317
def test_identity(self):
304318
self.validate_identity("LISTAGG(DISTINCT foo, ', ')")
305319
self.validate_identity("CREATE MATERIALIZED VIEW orders AUTO REFRESH YES AS SELECT 1")

0 commit comments

Comments
 (0)