Skip to content

Commit 5659c87

Browse files
fix: DB-specific quoting in Jinja macro (#25779)
1 parent ed14f36 commit 5659c87

File tree

2 files changed

+38
-16
lines changed

2 files changed

+38
-16
lines changed

superset/jinja_context.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from jinja2 import DebugUndefined
2626
from jinja2.sandbox import SandboxedEnvironment
2727
from sqlalchemy.engine.interfaces import Dialect
28+
from sqlalchemy.sql.expression import bindparam
2829
from sqlalchemy.types import String
2930

3031
from superset.constants import LRU_CACHE_MAX_SIZE
@@ -396,23 +397,39 @@ def validate_template_context(
396397
return validate_context_types(context)
397398

398399

399-
def where_in(values: list[Any], mark: str = "'") -> str:
400-
"""
401-
Given a list of values, build a parenthesis list suitable for an IN expression.
400+
class WhereInMacro: # pylint: disable=too-few-public-methods
401+
def __init__(self, dialect: Dialect):
402+
self.dialect = dialect
402403

403-
>>> where_in([1, "b", 3])
404-
(1, 'b', 3)
404+
def __call__(self, values: list[Any], mark: Optional[str] = None) -> str:
405+
"""
406+
Given a list of values, build a parenthesis list suitable for an IN expression.
405407
406-
"""
408+
>>> from sqlalchemy.dialects import mysql
409+
>>> where_in = WhereInMacro(dialect=mysql.dialect())
410+
>>> where_in([1, "Joe's", 3])
411+
(1, 'Joe''s', 3)
407412
408-
def quote(value: Any) -> str:
409-
if isinstance(value, str):
410-
value = value.replace(mark, mark * 2)
411-
return f"{mark}{value}{mark}"
412-
return str(value)
413+
"""
414+
binds = [bindparam(f"value_{i}", value) for i, value in enumerate(values)]
415+
string_representations = [
416+
str(
417+
bind.compile(
418+
dialect=self.dialect, compile_kwargs={"literal_binds": True}
419+
)
420+
)
421+
for bind in binds
422+
]
423+
joined_values = ", ".join(string_representations)
424+
result = f"({joined_values})"
425+
426+
if mark:
427+
result += (
428+
"\n-- WARNING: the `mark` parameter was removed from the `where_in` "
429+
"macro for security reasons\n"
430+
)
413431

414-
joined_values = ", ".join(quote(value) for value in values)
415-
return f"({joined_values})"
432+
return result
416433

417434

418435
class BaseTemplateProcessor:
@@ -448,7 +465,7 @@ def __init__(
448465
self.set_context(**kwargs)
449466

450467
# custom filters
451-
self._env.filters["where_in"] = where_in
468+
self._env.filters["where_in"] = WhereInMacro(database.get_dialect())
452469

453470
def set_context(self, **kwargs: Any) -> None:
454471
self._context.update(kwargs)

tests/unit_tests/jinja_context_test.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,22 @@
2020

2121
import pytest
2222
from pytest_mock import MockFixture
23+
from sqlalchemy.dialects import mysql
2324

2425
from superset.datasets.commands.exceptions import DatasetNotFoundError
25-
from superset.jinja_context import dataset_macro, where_in
26+
from superset.jinja_context import dataset_macro, WhereInMacro
2627

2728

2829
def test_where_in() -> None:
2930
"""
3031
Test the ``where_in`` Jinja2 filter.
3132
"""
33+
where_in = WhereInMacro(mysql.dialect())
3234
assert where_in([1, "b", 3]) == "(1, 'b', 3)"
33-
assert where_in([1, "b", 3], '"') == '(1, "b", 3)'
35+
assert where_in([1, "b", 3], '"') == (
36+
"(1, 'b', 3)\n-- WARNING: the `mark` parameter was removed from the "
37+
"`where_in` macro for security reasons\n"
38+
)
3439
assert where_in(["O'Malley's"]) == "('O''Malley''s')"
3540

3641

0 commit comments

Comments
 (0)