Skip to content

Commit 3ab6dfb

Browse files
authored
fix(clickhouse)!: Generalize COLUMNS(...) APPLY (#4161)
* fix(clickhouse): Generalize COLUMNS(...) APPLY
* Merge exp.UnpackColumns into exp.Columns
1 parent f6d3bdd commit 3ab6dfb

File tree

6 files changed

+40
-7
lines changed

6 files changed

+40
-7
lines changed

sqlglot/dialects/clickhouse.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -396,6 +396,7 @@ class Parser(parser.Parser):
396396
**parser.Parser.FUNCTION_PARSERS,
397397
"ARRAYJOIN": lambda self: self.expression(exp.Explode, this=self._parse_expression()),
398398
"QUANTILE": lambda self: self._parse_quantile(),
399+
"COLUMNS": lambda self: self._parse_columns(),
399400
}
400401

401402
FUNCTION_PARSERS.pop("MATCH")
@@ -776,6 +777,14 @@ def _parse_expression(self) -> t.Optional[exp.Expression]:
776777

777778
return this
778779

780+
def _parse_columns(self) -> exp.Expression:
781+
this: exp.Expression = self.expression(exp.Columns, this=self._parse_lambda())
782+
783+
while self._next and self._match_text_seq(")", "APPLY", "("):
784+
self._match(TokenType.R_PAREN)
785+
this = exp.Apply(this=this, expression=self._parse_var(any_token=True))
786+
return this
787+
779788
class Generator(generator.Generator):
780789
QUERY_HINTS = False
781790
STRUCT_DELIMITER = ("(", ")")

sqlglot/expressions.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4984,6 +4984,10 @@ class ToNumber(Func):
49844984
}
49854985

49864986

4987+
class Columns(Func):
4988+
arg_types = {"this": True, "unpack": False}
4989+
4990+
49874991
# https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16#syntax
49884992
class Convert(Func):
49894993
arg_types = {"this": True, "expression": True, "style": False}
@@ -6244,10 +6248,6 @@ class UnixToTimeStr(Func):
62446248
pass
62456249

62466250

6247-
class UnpackColumns(Func):
6248-
pass
6249-
6250-
62516251
class Uuid(Func):
62526252
_sql_names = ["UUID", "GEN_RANDOM_UUID", "GENERATE_UUID", "UUID_STRING"]
62536253

sqlglot/generator.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -192,7 +192,6 @@ class Generator(metaclass=_Generator):
192192
exp.TransientProperty: lambda *_: "TRANSIENT",
193193
exp.Union: lambda self, e: self.set_operations(e),
194194
exp.UnloggedProperty: lambda *_: "UNLOGGED",
195-
exp.UnpackColumns: lambda self, e: f"*{self.sql(e.this)}",
196195
exp.Uuid: lambda *_: "UUID()",
197196
exp.UppercaseColumnConstraint: lambda *_: "UPPERCASE",
198197
exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
@@ -4372,3 +4371,10 @@ def grantprincipal_sql(self, expression: exp.GrantPrincipal):
43724371
kind = f"{kind} " if kind else ""
43734372

43744373
return f"{kind}{this}"
4374+
4375+
def columns_sql(self, expression: exp.Columns):
4376+
func = self.function_fallback_sql(expression)
4377+
if expression.args.get("unpack"):
4378+
func = f"*{func}"
4379+
4380+
return func

sqlglot/parser.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7400,9 +7400,12 @@ def _parse_normalize(self) -> exp.Normalize:
74007400
form=self._match(TokenType.COMMA) and self._parse_var(),
74017401
)
74027402

7403-
def _parse_star_ops(self) -> exp.Star | exp.UnpackColumns:
7403+
def _parse_star_ops(self) -> t.Optional[exp.Expression]:
74047404
if self._match_text_seq("COLUMNS", "(", advance=False):
7405-
return exp.UnpackColumns(this=self._parse_function())
7405+
this = self._parse_function()
7406+
if isinstance(this, exp.Columns):
7407+
this.set("unpack", True)
7408+
return this
74067409

74077410
return self.expression(
74087411
exp.Star,

tests/dialects/test_clickhouse.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -525,6 +525,9 @@ def test_clickhouse(self):
525525
"SELECT COLUMNS('[jk]') APPLY(toString) APPLY(length) APPLY(max) FROM columns_transformers"
526526
)
527527
self.validate_identity("SELECT * APPLY(sum), COLUMNS('col') APPLY(sum) APPLY(avg) FROM t")
528+
self.validate_identity(
529+
"SELECT * FROM ABC WHERE hasAny(COLUMNS('.*field') APPLY(toUInt64) APPLY(to), (SELECT groupUniqArray(toUInt64(field))))"
530+
)
528531
self.validate_identity("SELECT col apply", "SELECT col AS apply")
529532

530533
def test_clickhouse_values(self):

tests/dialects/test_duckdb.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -846,6 +846,18 @@ def test_duckdb(self):
846846
"SELECT id, STRUCT_PACK(*COLUMNS('m\\d')) AS measurements FROM many_measurements",
847847
"""SELECT id, {'_0': *COLUMNS('m\\d')} AS measurements FROM many_measurements""",
848848
)
849+
self.validate_identity("SELECT COLUMNS(c -> c LIKE '%num%') FROM numbers")
850+
self.validate_identity(
851+
"SELECT MIN(COLUMNS(* REPLACE (number + id AS number))), COUNT(COLUMNS(* EXCLUDE (number))) FROM numbers"
852+
)
853+
self.validate_identity("SELECT COLUMNS(*) + COLUMNS(*) FROM numbers")
854+
self.validate_identity("SELECT COLUMNS('(id|numbers?)') FROM numbers")
855+
self.validate_identity(
856+
"SELECT COALESCE(COLUMNS(['a', 'b', 'c'])) AS result FROM (SELECT NULL AS a, 42 AS b, TRUE AS c)"
857+
)
858+
self.validate_identity(
859+
"SELECT COALESCE(*COLUMNS(['a', 'b', 'c'])) AS result FROM (SELECT NULL AS a, 42 AS b, TRUE AS c)"
860+
)
849861

850862
def test_array_index(self):
851863
with self.assertLogs(helper_logger) as cm:

0 commit comments

Comments
 (0)