@@ -351,54 +351,56 @@ def _set_type(
351351
352352 def annotate (self , expression : E ) -> E :
353353 for scope in traverse_scope (expression ):
354- selects = {}
355- for name , source in scope .sources .items ():
356- if not isinstance (source , Scope ):
357- continue
358- if isinstance (source .expression , exp .UDTF ):
359- values = []
360-
361- if isinstance (source .expression , exp .Lateral ):
362- if isinstance (source .expression .this , exp .Explode ):
363- values = [source .expression .this .this ]
364- elif isinstance (source .expression , exp .Unnest ):
365- values = [source .expression ]
366- else :
367- values = source .expression .expressions [0 ].expressions
368-
369- if not values :
370- continue
371-
372- selects [name ] = {
373- alias : column
374- for alias , column in zip (
375- source .expression .alias_column_names ,
376- values ,
377- )
378- }
354+ self .annotate_scope (scope )
355+ return self ._maybe_annotate (expression ) # This takes care of non-traversable expressions
356+
357+ def annotate_scope (self , scope : Scope ) -> None :
358+ selects = {}
359+ for name , source in scope .sources .items ():
360+ if not isinstance (source , Scope ):
361+ continue
362+ if isinstance (source .expression , exp .UDTF ):
363+ values = []
364+
365+ if isinstance (source .expression , exp .Lateral ):
366+ if isinstance (source .expression .this , exp .Explode ):
367+ values = [source .expression .this .this ]
368+ elif isinstance (source .expression , exp .Unnest ):
369+ values = [source .expression ]
379370 else :
380- selects [name ] = {
381- select .alias_or_name : select for select in source .expression .selects
382- }
371+ values = source .expression .expressions [0 ].expressions
383372
384- # First annotate the current scope's column references
385- for col in scope .columns :
386- if not col .table :
373+ if not values :
387374 continue
388375
389- source = scope .sources .get (col .table )
390- if isinstance (source , exp .Table ):
391- self ._set_type (col , self .schema .get_column_type (source , col ))
392- elif source :
393- if col .table in selects and col .name in selects [col .table ]:
394- self ._set_type (col , selects [col .table ][col .name ].type )
395- elif isinstance (source .expression , exp .Unnest ):
396- self ._set_type (col , source .expression .type )
397-
398- # Then (possibly) annotate the remaining expressions in the scope
399- self ._maybe_annotate (scope .expression )
400-
401- return self ._maybe_annotate (expression ) # This takes care of non-traversable expressions
376+ selects [name ] = {
377+ alias : column
378+ for alias , column in zip (
379+ source .expression .alias_column_names ,
380+ values ,
381+ )
382+ }
383+ else :
384+ selects [name ] = {
385+ select .alias_or_name : select for select in source .expression .selects
386+ }
387+
388+ # First annotate the current scope's column references
389+ for col in scope .columns :
390+ if not col .table :
391+ continue
392+
393+ source = scope .sources .get (col .table )
394+ if isinstance (source , exp .Table ):
395+ self ._set_type (col , self .schema .get_column_type (source , col ))
396+ elif source :
397+ if col .table in selects and col .name in selects [col .table ]:
398+ self ._set_type (col , selects [col .table ][col .name ].type )
399+ elif isinstance (source .expression , exp .Unnest ):
400+ self ._set_type (col , source .expression .type )
401+
402+ # Then (possibly) annotate the remaining expressions in the scope
403+ self ._maybe_annotate (scope .expression )
402404
403405 def _maybe_annotate (self , expression : E ) -> E :
404406 if id (expression ) in self ._visited :
@@ -601,7 +603,13 @@ def _annotate_explode(self, expression: exp.Explode) -> exp.Explode:
601603 def _annotate_unnest (self , expression : exp .Unnest ) -> exp .Unnest :
602604 self ._annotate_args (expression )
603605 child = seq_get (expression .expressions , 0 )
604- self ._set_type (expression , child and seq_get (child .type .expressions , 0 ))
606+
607+ if child and child .is_type (exp .DataType .Type .ARRAY ):
608+ expr_type = seq_get (child .type .expressions , 0 )
609+ else :
610+ expr_type = None
611+
612+ self ._set_type (expression , expr_type )
605613 return expression
606614
607615 def _annotate_struct_value (
0 commit comments