Skip to content

Commit 4f1691a

Browse files
authored
feat: allow qualify to also annotate on the fly for unnest support (#3316)
1 parent ef84f17 commit 4f1691a

File tree

4 files changed: +71 -55 lines changed

4 files changed: +71 -55 lines changed

sqlglot/optimizer/annotate_types.py

Lines changed: 53 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -351,54 +351,56 @@ def _set_type(
351351

352352
def annotate(self, expression: E) -> E:
353353
for scope in traverse_scope(expression):
354-
selects = {}
355-
for name, source in scope.sources.items():
356-
if not isinstance(source, Scope):
357-
continue
358-
if isinstance(source.expression, exp.UDTF):
359-
values = []
360-
361-
if isinstance(source.expression, exp.Lateral):
362-
if isinstance(source.expression.this, exp.Explode):
363-
values = [source.expression.this.this]
364-
elif isinstance(source.expression, exp.Unnest):
365-
values = [source.expression]
366-
else:
367-
values = source.expression.expressions[0].expressions
368-
369-
if not values:
370-
continue
371-
372-
selects[name] = {
373-
alias: column
374-
for alias, column in zip(
375-
source.expression.alias_column_names,
376-
values,
377-
)
378-
}
354+
self.annotate_scope(scope)
355+
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
356+
357+
def annotate_scope(self, scope: Scope) -> None:
358+
selects = {}
359+
for name, source in scope.sources.items():
360+
if not isinstance(source, Scope):
361+
continue
362+
if isinstance(source.expression, exp.UDTF):
363+
values = []
364+
365+
if isinstance(source.expression, exp.Lateral):
366+
if isinstance(source.expression.this, exp.Explode):
367+
values = [source.expression.this.this]
368+
elif isinstance(source.expression, exp.Unnest):
369+
values = [source.expression]
379370
else:
380-
selects[name] = {
381-
select.alias_or_name: select for select in source.expression.selects
382-
}
371+
values = source.expression.expressions[0].expressions
383372

384-
# First annotate the current scope's column references
385-
for col in scope.columns:
386-
if not col.table:
373+
if not values:
387374
continue
388375

389-
source = scope.sources.get(col.table)
390-
if isinstance(source, exp.Table):
391-
self._set_type(col, self.schema.get_column_type(source, col))
392-
elif source:
393-
if col.table in selects and col.name in selects[col.table]:
394-
self._set_type(col, selects[col.table][col.name].type)
395-
elif isinstance(source.expression, exp.Unnest):
396-
self._set_type(col, source.expression.type)
397-
398-
# Then (possibly) annotate the remaining expressions in the scope
399-
self._maybe_annotate(scope.expression)
400-
401-
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
376+
selects[name] = {
377+
alias: column
378+
for alias, column in zip(
379+
source.expression.alias_column_names,
380+
values,
381+
)
382+
}
383+
else:
384+
selects[name] = {
385+
select.alias_or_name: select for select in source.expression.selects
386+
}
387+
388+
# First annotate the current scope's column references
389+
for col in scope.columns:
390+
if not col.table:
391+
continue
392+
393+
source = scope.sources.get(col.table)
394+
if isinstance(source, exp.Table):
395+
self._set_type(col, self.schema.get_column_type(source, col))
396+
elif source:
397+
if col.table in selects and col.name in selects[col.table]:
398+
self._set_type(col, selects[col.table][col.name].type)
399+
elif isinstance(source.expression, exp.Unnest):
400+
self._set_type(col, source.expression.type)
401+
402+
# Then (possibly) annotate the remaining expressions in the scope
403+
self._maybe_annotate(scope.expression)
402404

403405
def _maybe_annotate(self, expression: E) -> E:
404406
if id(expression) in self._visited:
@@ -601,7 +603,13 @@ def _annotate_explode(self, expression: exp.Explode) -> exp.Explode:
601603
def _annotate_unnest(self, expression: exp.Unnest) -> exp.Unnest:
602604
self._annotate_args(expression)
603605
child = seq_get(expression.expressions, 0)
604-
self._set_type(expression, child and seq_get(child.type.expressions, 0))
606+
607+
if child and child.is_type(exp.DataType.Type.ARRAY):
608+
expr_type = seq_get(child.type.expressions, 0)
609+
else:
610+
expr_type = None
611+
612+
self._set_type(expression, expr_type)
605613
return expression
606614

607615
def _annotate_struct_value(

sqlglot/optimizer/optimizer.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import inspect
34
import typing as t
45

56
import sqlglot
@@ -85,7 +86,7 @@ def optimize(
8586
optimized = exp.maybe_parse(expression, dialect=dialect, copy=True)
8687
for rule in rules:
8788
# Find any additional rule parameters, beyond `expression`
88-
rule_params = rule.__code__.co_varnames
89+
rule_params = inspect.getfullargspec(rule).args
8990
rule_kwargs = {
9091
param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
9192
}

sqlglot/optimizer/qualify_columns.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
from sqlglot.dialects.dialect import Dialect, DialectType
88
from sqlglot.errors import OptimizeError
99
from sqlglot.helper import seq_get, SingleValuedMapping
10-
from sqlglot.optimizer.annotate_types import annotate_types
10+
from sqlglot.optimizer.annotate_types import TypeAnnotator
1111
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
1212
from sqlglot.optimizer.simplify import simplify_parens
1313
from sqlglot.schema import Schema, ensure_schema
@@ -49,8 +49,10 @@ def qualify_columns(
4949
- Currently only handles a single PIVOT or UNPIVOT operator
5050
"""
5151
schema = ensure_schema(schema)
52+
annotator = TypeAnnotator(schema)
5253
infer_schema = schema.empty if infer_schema is None else infer_schema
53-
pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
54+
dialect = Dialect.get_or_raise(schema.dialect)
55+
pseudocolumns = dialect.PSEUDOCOLUMNS
5456

5557
for scope in traverse_scope(expression):
5658
resolver = Resolver(scope, schema, infer_schema=infer_schema)
@@ -74,6 +76,9 @@ def qualify_columns(
7476
_expand_group_by(scope)
7577
_expand_order_by(scope, resolver)
7678

79+
if dialect == "bigquery":
80+
annotator.annotate_scope(scope)
81+
7782
return expression
7883

7984

@@ -660,11 +665,8 @@ def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequenc
660665
# directly select a struct field in a query.
661666
# this handles the case where the unnest is statically defined.
662667
if self.schema.dialect == "bigquery":
663-
expression = source.expression
664-
annotate_types(expression)
665-
666-
if expression.is_type(exp.DataType.Type.STRUCT):
667-
for k in expression.type.expressions: # type: ignore
668+
if source.expression.is_type(exp.DataType.Type.STRUCT):
669+
for k in source.expression.type.expressions: # type: ignore
668670
columns.append(k.name)
669671
else:
670672
columns = source.expression.named_selects

tests/fixtures/optimizer/qualify_columns.sql

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -494,8 +494,13 @@ SELECT x AS x, y AS y FROM UNNEST([1, 2]) AS x WITH OFFSET AS y;
494494

495495
# dialect: bigquery
496496
# execute: false
497-
select x, a, x.a from unnest([STRUCT(1 AS a)]) as x;
498-
SELECT x AS x, a AS a, x.a AS a FROM UNNEST([STRUCT(1 AS a)]) AS x;
497+
select x, a, x.a from unnest([STRUCT(1 AS a)]) as x CROSS JOIN m;
498+
SELECT x AS x, a AS a, x.a AS a FROM UNNEST([STRUCT(1 AS a)]) AS x CROSS JOIN m AS m;
499+
500+
# dialect: bigquery
501+
# execute: false
502+
WITH cte AS (SELECT [STRUCT(1 AS a)] AS x) select a, x, m.a from cte, UNNEST(x) AS m CROSS JOIN n;
503+
WITH cte AS (SELECT [STRUCT(1 AS a)] AS x) SELECT a AS a, cte.x AS x, m.a AS a FROM cte AS cte, UNNEST(cte.x) AS m CROSS JOIN n AS n;
499504

500505
# dialect: presto
501506
SELECT x.a, i.b FROM x CROSS JOIN UNNEST(SPLIT(CAST(b AS VARCHAR), ',')) AS i(b);

0 commit comments

Comments
 (0)