Skip to content

Commit 1a26bff

Browse files
authored
feat(snowflake): Transpile exp.SafeDivide (#4294)
1 parent ee266ef commit 1a26bff

File tree

4 files changed

+19
-8
lines changed

4 files changed

+19
-8
lines changed

sqlglot/dialects/dialect.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,10 +1017,10 @@ def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
10171017
return self.with_sql(expression)
10181018

10191019

1020-
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
1020+
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide, if_sql: str = "IF") -> str:
10211021
n = self.sql(expression, "this")
10221022
d = self.sql(expression, "expression")
1023-
return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
1023+
return f"{if_sql}(({d}) <> 0, ({n}) / ({d}), NULL)"
10241024

10251025

10261026
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:

sqlglot/dialects/snowflake.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
timestrtotime_sql,
2323
var_map_sql,
2424
map_date_part,
25+
no_safe_divide_sql,
2526
)
2627
from sqlglot.helper import flatten, is_float, is_int, seq_get
2728
from sqlglot.tokens import TokenType
@@ -830,6 +831,7 @@ class Generator(generator.Generator):
830831
_unnest_generate_date_array,
831832
]
832833
),
834+
exp.SafeDivide: lambda self, e: no_safe_divide_sql(self, e, "IFF"),
833835
exp.SHA: rename_func("SHA1"),
834836
exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
835837
exp.StartsWith: rename_func("STARTSWITH"),

tests/dialects/test_bigquery.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1533,6 +1533,21 @@ def test_bigquery(self):
15331533

15341534
self.validate_identity("SELECT * FROM a-b c", "SELECT * FROM a-b AS c")
15351535

1536+
self.validate_all(
1537+
"SAFE_DIVIDE(x, y)",
1538+
write={
1539+
"bigquery": "SAFE_DIVIDE(x, y)",
1540+
"duckdb": "IF((y) <> 0, (x) / (y), NULL)",
1541+
"presto": "IF((y) <> 0, (x) / (y), NULL)",
1542+
"trino": "IF((y) <> 0, (x) / (y), NULL)",
1543+
"hive": "IF((y) <> 0, (x) / (y), NULL)",
1544+
"spark2": "IF((y) <> 0, (x) / (y), NULL)",
1545+
"spark": "IF((y) <> 0, (x) / (y), NULL)",
1546+
"databricks": "IF((y) <> 0, (x) / (y), NULL)",
1547+
"snowflake": "IFF((y) <> 0, (x) / (y), NULL)",
1548+
},
1549+
)
1550+
15361551
def test_errors(self):
15371552
with self.assertRaises(TokenError):
15381553
transpile("'\\'", read="bigquery")

tests/dialects/test_duckdb.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -619,12 +619,6 @@ def test_duckdb(self):
619619
"spark": "ARRAY_SUM(ARRAY(1, 2))",
620620
},
621621
)
622-
self.validate_all(
623-
"IF((y) <> 0, (x) / (y), NULL)",
624-
read={
625-
"bigquery": "SAFE_DIVIDE(x, y)",
626-
},
627-
)
628622
self.validate_all(
629623
"STRUCT_PACK(x := 1, y := '2')",
630624
write={

0 commit comments

Comments
 (0)