|
15 | 15 | from sqlglot.helper import seq_get |
16 | 16 | from sqlglot.parser import OPTIONS_TYPE, build_coalesce |
17 | 17 | from sqlglot.tokens import TokenType |
| 18 | +from sqlglot.errors import ParseError |
18 | 19 |
|
19 | 20 | if t.TYPE_CHECKING: |
20 | 21 | from sqlglot._typing import E |
@@ -205,6 +206,57 @@ def _parse_json_array(self, expr_type: t.Type[E], **kwargs) -> E: |
205 | 206 | ) |
206 | 207 |
|
207 | 208 | def _parse_hint(self) -> t.Optional[exp.Hint]: |
| 209 | + start_index = self._index |
| 210 | + should_fallback_to_string = False |
| 211 | + |
| 212 | + if not self._match(TokenType.HINT): |
| 213 | + return None |
| 214 | + |
| 215 | + hints = [] |
| 216 | + |
| 217 | + try: |
| 218 | + for hint in iter( |
| 219 | + lambda: self._parse_csv( |
| 220 | + lambda: self._parse_hint_function_call() or self._parse_var(upper=True), |
| 221 | + ), |
| 222 | + [], |
| 223 | + ): |
| 224 | + hints.extend(hint) |
| 225 | + except ParseError: |
| 226 | + should_fallback_to_string = True |
| 227 | + |
| 228 | + if not self._match_pair(TokenType.STAR, TokenType.SLASH): |
| 229 | + should_fallback_to_string = True |
| 230 | + |
| 231 | + if should_fallback_to_string: |
| 232 | + self._retreat(start_index) |
| 233 | + return self._parse_hint_fallback_to_string() |
| 234 | + |
| 235 | + return self.expression(exp.Hint, expressions=hints) |
| 236 | + |
| 237 | + def _parse_hint_function_call(self) -> t.Optional[exp.Expression]: |
| 238 | + if not self._curr or not self._next or self._next.token_type != TokenType.L_PAREN: |
| 239 | + return None |
| 240 | + |
| 241 | + this = self._curr.text |
| 242 | + |
| 243 | + self._advance(2) |
| 244 | + args = self._parse_hint_args() |
| 245 | + this = self.expression(exp.Anonymous, this=this, expressions=args) |
| 246 | + self._match_r_paren(this) |
| 247 | + return this |
| 248 | + |
| 249 | + def _parse_hint_args(self): |
| 250 | + args = [] |
| 251 | + result = self._parse_var() |
| 252 | + |
| 253 | + while result: |
| 254 | + args.append(result) |
| 255 | + result = self._parse_var() |
| 256 | + |
| 257 | + return args |
| 258 | + |
| 259 | + def _parse_hint_fallback_to_string(self) -> t.Optional[exp.Hint]: |
208 | 260 | if self._match(TokenType.HINT): |
209 | 261 | start = self._curr |
210 | 262 | while self._curr and not self._match_pair(TokenType.STAR, TokenType.SLASH): |
@@ -271,6 +323,7 @@ class Generator(generator.Generator): |
271 | 323 | LAST_DAY_SUPPORTS_DATE_PART = False |
272 | 324 | SUPPORTS_SELECT_INTO = True |
273 | 325 | TZ_TO_WITH_TIME_ZONE = True |
| 326 | + QUERY_HINT_SEP = " " |
274 | 327 |
|
275 | 328 | TYPE_MAPPING = { |
276 | 329 | **generator.Generator.TYPE_MAPPING, |
@@ -370,3 +423,23 @@ def into_sql(self, expression: exp.Into) -> str: |
370 | 423 | return f"{self.seg(into)} {self.sql(expression, 'this')}" |
371 | 424 |
|
372 | 425 | return f"{self.seg(into)} {self.expressions(expression)}" |
| 426 | + |
| 427 | + def hint_sql(self, expression: exp.Hint) -> str: |
| 428 | + expressions = [] |
| 429 | + |
| 430 | + for expression in expression.expressions: |
| 431 | + if isinstance(expression, exp.Anonymous): |
| 432 | + formatted_args = self._format_hint_function_args(*expression.expressions) |
| 433 | + expressions.append(f"{self.sql(expression, 'this')}({formatted_args})") |
| 434 | + else: |
| 435 | + expressions.append(self.sql(expression)) |
| 436 | + |
| 437 | + return f" /*+ {self.expressions(sqls=expressions, sep=self.QUERY_HINT_SEP).strip()} */" |
| 438 | + |
| 439 | + def _format_hint_function_args(self, *args: t.Optional[str | exp.Expression]) -> str: |
| 440 | + arg_sqls = tuple(self.sql(arg) for arg in args) |
| 441 | + if self.pretty and self.too_wide(arg_sqls): |
| 442 | + return self.indent( |
| 443 | + "\n" + "\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True |
| 444 | + ) |
| 445 | + return " ".join(arg_sqls) |
0 commit comments