Skip to content

Commit c49cefa

Browse files
authored
Feat(mysql): support STRAIGHT_JOIN (#3623)
* Feat(mysql): support STRAIGHT_JOIN * PR feedback
1 parent caa3051 commit c49cefa

File tree

7 files changed

+26
-4
lines changed

7 files changed

+26
-4
lines changed

sqlglot/dialects/dialect.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,14 @@ def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[
177177

178178
klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
179179

180+
if enum not in ("", "doris", "mysql"):
181+
klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
182+
TokenType.STRAIGHT_JOIN,
183+
}
184+
klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
185+
TokenType.STRAIGHT_JOIN,
186+
}
187+
180188
if not klass.SUPPORTS_SEMI_ANTI_JOIN:
181189
klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
182190
TokenType.ANTI,

sqlglot/generator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1970,7 +1970,9 @@ def join_sql(self, expression: exp.Join) -> str:
19701970

19711971
return f", {this_sql}"
19721972

1973-
op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
1973+
if op_sql != "STRAIGHT_JOIN":
1974+
op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
1975+
19741976
return f"{self.seg(op_sql)} {this_sql}{match_cond}{on_sql}"
19751977

19761978
def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str:

sqlglot/parser.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -588,11 +588,12 @@ class Parser(metaclass=_Parser):
588588
}
589589

590590
JOIN_KINDS = {
591+
TokenType.ANTI,
592+
TokenType.CROSS,
591593
TokenType.INNER,
592594
TokenType.OUTER,
593-
TokenType.CROSS,
594595
TokenType.SEMI,
595-
TokenType.ANTI,
596+
TokenType.STRAIGHT_JOIN,
596597
}
597598

598599
JOIN_HINTS: t.Set[str] = set()
@@ -3106,7 +3107,7 @@ def _parse_join(
31063107
index = self._index
31073108
method, side, kind = self._parse_join_parts()
31083109
hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None
3109-
join = self._match(TokenType.JOIN)
3110+
join = self._match(TokenType.JOIN) or (kind and kind.token_type == TokenType.STRAIGHT_JOIN)
31103111

31113112
if not skip_join_token and not join:
31123113
self._retreat(index)

sqlglot/tokens.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,7 @@ class TokenType(AutoName):
361361
SORT_BY = auto()
362362
START_WITH = auto()
363363
STORAGE_INTEGRATION = auto()
364+
STRAIGHT_JOIN = auto()
364365
STRUCT = auto()
365366
TABLE_SAMPLE = auto()
366367
TAG = auto()
@@ -765,6 +766,7 @@ class Tokenizer(metaclass=_Tokenizer):
765766
"SOME": TokenType.SOME,
766767
"SORT BY": TokenType.SORT_BY,
767768
"START WITH": TokenType.START_WITH,
769+
"STRAIGHT_JOIN": TokenType.STRAIGHT_JOIN,
768770
"TABLE": TokenType.TABLE,
769771
"TABLESAMPLE": TokenType.TABLE_SAMPLE,
770772
"TEMP": TokenType.TEMPORARY,

tests/dialects/test_duckdb.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@ def test_duckdb(self):
1818
"WITH _data AS (SELECT [STRUCT(1 AS a, 2 AS b), STRUCT(2 AS a, 3 AS b)] AS col) SELECT col.b FROM _data, UNNEST(_data.col) AS col WHERE col.a = 1",
1919
)
2020

21+
self.validate_all(
22+
"SELECT straight_join",
23+
write={
24+
"duckdb": "SELECT straight_join",
25+
"mysql": "SELECT `straight_join`",
26+
},
27+
)
2128
self.validate_all(
2229
"SELECT CAST('2020-01-01 12:05:01' AS TIMESTAMP)",
2330
read={

tests/dialects/test_mysql.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def test_ddl(self):
117117
)
118118

119119
def test_identity(self):
120+
self.validate_identity("SELECT e.* FROM e STRAIGHT_JOIN p ON e.x = p.y")
120121
self.validate_identity("ALTER TABLE test_table ALTER COLUMN test_column SET DEFAULT 1")
121122
self.validate_identity("SELECT DATE_FORMAT(NOW(), '%Y-%m-%d %H:%i:00.0000')")
122123
self.validate_identity("SELECT @var1 := 1, @var2")

tests/fixtures/identity.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -872,3 +872,4 @@ SELECT name
872872
SELECT copy
873873
SELECT rollup
874874
SELECT unnest
875+
SELECT * FROM a STRAIGHT_JOIN b

0 commit comments

Comments
 (0)