@@ -628,53 +628,16 @@ def test_scope_warning(self, logger):
628628 level = "warning" ,
629629 )
630630
631- def test_struct_type_annotation (self ):
632- tests = {
633- ("SELECT STRUCT(1 AS col)" , "spark" ): "STRUCT<col INT>" ,
634- ("SELECT STRUCT(1 AS col, 2.5 AS row)" , "spark" ): "STRUCT<col INT, row DOUBLE>" ,
635- ("SELECT STRUCT(1)" , "bigquery" ): "STRUCT<INT>" ,
636- (
637- "SELECT STRUCT(1 AS col, 2.5 AS row, struct(3.5 AS inner_col, 4 AS inner_row) AS nested_struct)" ,
638- "spark" ,
639- ): "STRUCT<col INT, row DOUBLE, nested_struct STRUCT<inner_col DOUBLE, inner_row INT>>" ,
640- (
641- "SELECT STRUCT(1 AS col, 2.5, ARRAY[1, 2, 3] AS nested_array, 'foo')" ,
642- "bigquery" ,
643- ): "STRUCT<col INT, DOUBLE, nested_array ARRAY<INT>, VARCHAR>" ,
644- ("SELECT STRUCT(1, 2.5, 'bar')" , "spark" ): "STRUCT<INT, DOUBLE, VARCHAR>" ,
645- ('SELECT STRUCT(1 AS "CaseSensitive")' , "spark" ): 'STRUCT<"CaseSensitive" INT>' ,
646- ("SELECT STRUCT_PACK(a := 1, b := 2.5)" , "duckdb" ): "STRUCT<a INT, b DOUBLE>" ,
647- ("SELECT ROW(1, 2.5, 'foo')" , "presto" ): "STRUCT<INT, DOUBLE, VARCHAR>" ,
648- }
649-
650- for (sql , dialect ), target_type in tests .items ():
651- with self .subTest (sql ):
652- expression = annotate_types (parse_one (sql , read = dialect ))
653- assert expression .expressions [0 ].is_type (target_type )
654-
655- def test_literal_type_annotation (self ):
656- tests = {
657- "SELECT 5" : exp .DataType .Type .INT ,
658- "SELECT 5.3" : exp .DataType .Type .DOUBLE ,
659- "SELECT 'bla'" : exp .DataType .Type .VARCHAR ,
660- "5" : exp .DataType .Type .INT ,
661- "5.3" : exp .DataType .Type .DOUBLE ,
662- "'bla'" : exp .DataType .Type .VARCHAR ,
663- }
664-
665- for sql , target_type in tests .items ():
666- expression = annotate_types (parse_one (sql ))
667- self .assertEqual (expression .find (exp .Literal ).type .this , target_type )
668-
669- def test_boolean_type_annotation (self ):
670- tests = {
671- "SELECT TRUE" : exp .DataType .Type .BOOLEAN ,
672- "FALSE" : exp .DataType .Type .BOOLEAN ,
673- }
631+ def test_annotate_types (self ):
632+ for i , (meta , sql , expected ) in enumerate (
633+ load_sql_fixture_pairs ("optimizer/annotate_types.sql" ), start = 1
634+ ):
635+ title = meta .get ("title" ) or f"{ i } , { sql } "
636+ dialect = meta .get ("dialect" )
637+ result = parse_and_optimize (annotate_types , sql , dialect )
674638
675- for sql , target_type in tests .items ():
676- expression = annotate_types (parse_one (sql ))
677- self .assertEqual (expression .find (exp .Boolean ).type .this , target_type )
639+ with self .subTest (title ):
640+ self .assertEqual (result .type .sql (), exp .DataType .build (expected ).sql ())
678641
679642 def test_cast_type_annotation (self ):
680643 expression = annotate_types (parse_one ("CAST('2020-01-01' AS TIMESTAMPTZ(9))" ))
0 commit comments