|
25 | 25 | from jinja2 import DebugUndefined |
26 | 26 | from jinja2.sandbox import SandboxedEnvironment |
27 | 27 | from sqlalchemy.engine.interfaces import Dialect |
| 28 | +from sqlalchemy.sql.expression import bindparam |
28 | 29 | from sqlalchemy.types import String |
29 | 30 |
|
30 | 31 | from superset.constants import LRU_CACHE_MAX_SIZE |
@@ -396,23 +397,39 @@ def validate_template_context( |
396 | 397 | return validate_context_types(context) |
397 | 398 |
|
398 | 399 |
|
399 | | -def where_in(values: list[Any], mark: str = "'") -> str: |
400 | | - """ |
401 | | - Given a list of values, build a parenthesis list suitable for an IN expression. |
| 400 | +class WhereInMacro: # pylint: disable=too-few-public-methods |
| 401 | + def __init__(self, dialect: Dialect): |
| 402 | + self.dialect = dialect |
402 | 403 |
|
403 | | - >>> where_in([1, "b", 3]) |
404 | | - (1, 'b', 3) |
| 404 | + def __call__(self, values: list[Any], mark: Optional[str] = None) -> str: |
| 405 | + """ |
| 406 | + Given a list of values, build a parenthesis list suitable for an IN expression. |
405 | 407 |
|
406 | | - """ |
| 408 | + >>> from sqlalchemy.dialects import mysql |
| 409 | + >>> where_in = WhereInMacro(dialect=mysql.dialect()) |
| 410 | + >>> where_in([1, "Joe's", 3]) |
| 411 | + (1, 'Joe''s', 3) |
407 | 412 |
|
408 | | - def quote(value: Any) -> str: |
409 | | - if isinstance(value, str): |
410 | | - value = value.replace(mark, mark * 2) |
411 | | - return f"{mark}{value}{mark}" |
412 | | - return str(value) |
| 413 | + """ |
| 414 | + binds = [bindparam(f"value_{i}", value) for i, value in enumerate(values)] |
| 415 | + string_representations = [ |
| 416 | + str( |
| 417 | + bind.compile( |
| 418 | + dialect=self.dialect, compile_kwargs={"literal_binds": True} |
| 419 | + ) |
| 420 | + ) |
| 421 | + for bind in binds |
| 422 | + ] |
| 423 | + joined_values = ", ".join(string_representations) |
| 424 | + result = f"({joined_values})" |
| 425 | + |
| 426 | + if mark: |
| 427 | + result += ( |
| 428 | + "\n-- WARNING: the `mark` parameter was removed from the `where_in` " |
| 429 | + "macro for security reasons\n" |
| 430 | + ) |
413 | 431 |
|
414 | | - joined_values = ", ".join(quote(value) for value in values) |
415 | | - return f"({joined_values})" |
| 432 | + return result |
416 | 433 |
|
417 | 434 |
|
418 | 435 | class BaseTemplateProcessor: |
@@ -448,7 +465,7 @@ def __init__( |
448 | 465 | self.set_context(**kwargs) |
449 | 466 |
|
450 | 467 | # custom filters |
451 | | - self._env.filters["where_in"] = where_in |
| 468 | + self._env.filters["where_in"] = WhereInMacro(database.get_dialect()) |
452 | 469 |
|
453 | 470 | def set_context(self, **kwargs: Any) -> None: |
454 | 471 | self._context.update(kwargs) |
|
0 commit comments