Skip to content

Commit 2224881

Browse files
committed
fix: case when / if should ignore null types
1 parent 6b21bba commit 2224881

File tree

3 files changed

+68
-47
lines changed

3 files changed

+68
-47
lines changed

sqlglot/optimizer/annotate_types.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,8 @@ def _annotate_by_args(
505505
last_datatype = expr_type
506506
break
507507

508-
last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type)
508+
if not expr_type.is_type(exp.DataType.Type.NULL, exp.DataType.Type.UNKNOWN):
509+
last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type)
509510

510511
self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)
511512

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
5;
2+
INT;
3+
4+
5.3;
5+
DOUBLE;
6+
7+
'bla';
8+
VARCHAR;
9+
10+
True;
11+
bool;
12+
13+
false;
14+
bool;
15+
16+
null;
17+
null;
18+
CASE WHEN x THEN NULL ELSE 1 END;
19+
INT;
20+
21+
CASE WHEN x THEN 1 ELSE NULL END;
22+
INT;
23+
24+
IF(true, 1, null);
25+
INT;
26+
27+
IF(true, null, 1);
28+
INT;
29+
30+
STRUCT(1 AS col);
31+
STRUCT<col INT>;
32+
33+
STRUCT(1 AS col, 2.5 AS row);
34+
STRUCT<col INT, row DOUBLE>;
35+
36+
STRUCT(1);
37+
STRUCT<INT>;
38+
39+
STRUCT(1 AS col, 2.5 AS row, struct(3.5 AS inner_col, 4 AS inner_row) AS nested_struct);
40+
STRUCT<col INT, row DOUBLE, nested_struct STRUCT<inner_col DOUBLE, inner_row INT>>;
41+
42+
STRUCT(1 AS col, 2.5, ARRAY[1, 2, 3] AS nested_array, 'foo');
43+
STRUCT<col INT, DOUBLE, nested_array ARRAY<INT>, VARCHAR>;
44+
45+
STRUCT(1, 2.5, 'bar');
46+
STRUCT<INT, DOUBLE, VARCHAR>;
47+
48+
STRUCT(1 AS "CaseSensitive");
49+
STRUCT<"CaseSensitive" INT>;
50+
51+
# dialect: duckdb
52+
STRUCT_PACK(a := 1, b := 2.5);
53+
STRUCT<a INT, b DOUBLE>;
54+
55+
# dialect: presto
56+
ROW(1, 2.5, 'foo');
57+
STRUCT<INT, DOUBLE, VARCHAR>;

tests/test_optimizer.py

Lines changed: 9 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -628,53 +628,16 @@ def test_scope_warning(self, logger):
628628
level="warning",
629629
)
630630

631-
def test_struct_type_annotation(self):
632-
tests = {
633-
("SELECT STRUCT(1 AS col)", "spark"): "STRUCT<col INT>",
634-
("SELECT STRUCT(1 AS col, 2.5 AS row)", "spark"): "STRUCT<col INT, row DOUBLE>",
635-
("SELECT STRUCT(1)", "bigquery"): "STRUCT<INT>",
636-
(
637-
"SELECT STRUCT(1 AS col, 2.5 AS row, struct(3.5 AS inner_col, 4 AS inner_row) AS nested_struct)",
638-
"spark",
639-
): "STRUCT<col INT, row DOUBLE, nested_struct STRUCT<inner_col DOUBLE, inner_row INT>>",
640-
(
641-
"SELECT STRUCT(1 AS col, 2.5, ARRAY[1, 2, 3] AS nested_array, 'foo')",
642-
"bigquery",
643-
): "STRUCT<col INT, DOUBLE, nested_array ARRAY<INT>, VARCHAR>",
644-
("SELECT STRUCT(1, 2.5, 'bar')", "spark"): "STRUCT<INT, DOUBLE, VARCHAR>",
645-
('SELECT STRUCT(1 AS "CaseSensitive")', "spark"): 'STRUCT<"CaseSensitive" INT>',
646-
("SELECT STRUCT_PACK(a := 1, b := 2.5)", "duckdb"): "STRUCT<a INT, b DOUBLE>",
647-
("SELECT ROW(1, 2.5, 'foo')", "presto"): "STRUCT<INT, DOUBLE, VARCHAR>",
648-
}
649-
650-
for (sql, dialect), target_type in tests.items():
651-
with self.subTest(sql):
652-
expression = annotate_types(parse_one(sql, read=dialect))
653-
assert expression.expressions[0].is_type(target_type)
654-
655-
def test_literal_type_annotation(self):
656-
tests = {
657-
"SELECT 5": exp.DataType.Type.INT,
658-
"SELECT 5.3": exp.DataType.Type.DOUBLE,
659-
"SELECT 'bla'": exp.DataType.Type.VARCHAR,
660-
"5": exp.DataType.Type.INT,
661-
"5.3": exp.DataType.Type.DOUBLE,
662-
"'bla'": exp.DataType.Type.VARCHAR,
663-
}
664-
665-
for sql, target_type in tests.items():
666-
expression = annotate_types(parse_one(sql))
667-
self.assertEqual(expression.find(exp.Literal).type.this, target_type)
668-
669-
def test_boolean_type_annotation(self):
670-
tests = {
671-
"SELECT TRUE": exp.DataType.Type.BOOLEAN,
672-
"FALSE": exp.DataType.Type.BOOLEAN,
673-
}
631+
def test_annotate_types(self):
632+
for i, (meta, sql, expected) in enumerate(
633+
load_sql_fixture_pairs("optimizer/annotate_types.sql"), start=1
634+
):
635+
title = meta.get("title") or f"{i}, {sql}"
636+
dialect = meta.get("dialect")
637+
result = parse_and_optimize(annotate_types, sql, dialect)
674638

675-
for sql, target_type in tests.items():
676-
expression = annotate_types(parse_one(sql))
677-
self.assertEqual(expression.find(exp.Boolean).type.this, target_type)
639+
with self.subTest(title):
640+
self.assertEqual(result.type.sql(), exp.DataType.build(expected).sql())
678641

679642
def test_cast_type_annotation(self):
680643
expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))"))

0 commit comments

Comments
 (0)