Skip to content

Commit 4722f99

Browse files
Enhancement: Improved Identifiers - casefolding, quoted values, and basic escaping (#5726)
Co-authored-by: Alan Cruickshank <[email protected]> Co-authored-by: Alan Cruickshank <[email protected]>
1 parent cbddd6d commit 4722f99

42 files changed

Lines changed: 1677 additions & 111 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/sqlfluff/core/default_config.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,9 @@ aliasing = explicit
315315
# Aliasing preference for columns
316316
aliasing = explicit
317317

318+
[sqlfluff:rules:aliasing.unused]
319+
alias_case_check = dialect
320+
318321
[sqlfluff:rules:aliasing.length]
319322
min_alias_length = None
320323
max_alias_length = None

src/sqlfluff/core/parser/match_result.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,13 @@
1616
Tuple,
1717
Type,
1818
Union,
19-
cast,
2019
)
2120

2221
from sqlfluff.core.helpers.slice import slice_length
2322
from sqlfluff.core.parser.markers import PositionMarker
2423

2524
if TYPE_CHECKING: # pragma: no cover
26-
from sqlfluff.core.parser.segments import BaseSegment, MetaSegment, RawSegment
25+
from sqlfluff.core.parser.segments import BaseSegment, MetaSegment
2726

2827

2928
def _get_point_pos_at_idx(
@@ -287,19 +286,7 @@ def apply(self, segments: Tuple["BaseSegment", ...]) -> Tuple["BaseSegment", ...
287286
return result_segments
288287

289288
# Otherwise construct the subsegment
290-
new_seg: "BaseSegment"
291-
if self.matched_class.class_is_type("raw"):
292-
_raw_type = cast(Type["RawSegment"], self.matched_class)
293-
assert len(result_segments) == 1
294-
# TODO: Should this be a generic method on BaseSegment and RawSegment?
295-
# It feels a little strange to be this specific here.
296-
new_seg = _raw_type(
297-
raw=result_segments[0].raw,
298-
pos_marker=result_segments[0].pos_marker,
299-
**self.segment_kwargs,
300-
)
301-
else:
302-
new_seg = self.matched_class(
303-
segments=result_segments, **self.segment_kwargs
304-
)
289+
new_seg: "BaseSegment" = self.matched_class.from_result_segments(
290+
result_segments, self.segment_kwargs
291+
)
305292
return (new_seg,)

src/sqlfluff/core/parser/parsers.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
from abc import abstractmethod
7-
from typing import Any, Collection, Dict, Optional, Sequence, Tuple, Type
7+
from typing import Any, Callable, Collection, Dict, Optional, Sequence, Tuple, Type
88
from uuid import uuid4
99

1010
import regex
@@ -31,6 +31,7 @@ def __init__(
3131
optional: bool = False,
3232
# The following kwargs are passed on to the segment:
3333
trim_chars: Optional[Tuple[str, ...]] = None,
34+
casefold: Optional[Callable[[str], str]] = None,
3435
) -> None:
3536
self.raw_class = raw_class
3637
# Store instance_types rather than just type to allow
@@ -39,6 +40,7 @@ def __init__(
3940
self._instance_types: Tuple[str, ...] = (type or raw_class.type,)
4041
self.optional = optional
4142
self._trim_chars = trim_chars
43+
self._casefold = casefold
4244
# Generate a cache key
4345
self._cache_key = uuid4().hex
4446

@@ -63,6 +65,8 @@ def _match_at(self, idx: int) -> MatchResult:
6365
segment_kwargs["instance_types"] = self._instance_types
6466
if self._trim_chars:
6567
segment_kwargs["trim_chars"] = self._trim_chars
68+
if self._casefold:
69+
segment_kwargs["casefold"] = self._casefold
6670
return MatchResult(
6771
matched_slice=slice(idx, idx + 1),
6872
matched_class=self.raw_class,
@@ -80,6 +84,7 @@ def __init__(
8084
type: Optional[str] = None,
8185
optional: bool = False,
8286
trim_chars: Optional[Tuple[str, ...]] = None,
87+
casefold: Optional[Callable[[str], str]] = None,
8388
) -> None:
8489
"""Initialize a new instance of the class.
8590
@@ -89,6 +94,7 @@ def __init__(
8994
type (Optional[str]): The type of the instance.
9095
optional (bool): Whether the instance is optional.
9196
trim_chars (Optional[Tuple[str, ...]]): The characters to trim.
97+
casefold (Optional[Callable[[str], str]]): The default casing used.
9298
9399
Returns:
94100
None
@@ -102,6 +108,7 @@ def __init__(
102108
raw_class=raw_class,
103109
optional=optional,
104110
trim_chars=trim_chars,
111+
casefold=casefold,
105112
)
106113
# NOTE: We override the instance types after initialising the base
107114
# class. We want to ensure that re-matching is possible by ensuring that
@@ -167,6 +174,7 @@ def __init__(
167174
type: Optional[str] = None,
168175
optional: bool = False,
169176
trim_chars: Optional[Tuple[str, ...]] = None,
177+
casefold: Optional[Callable[[str], str]] = None,
170178
):
171179
self.template = template.upper()
172180
# Create list version upfront to avoid recreating it multiple times.
@@ -176,6 +184,7 @@ def __init__(
176184
type=type,
177185
optional=optional,
178186
trim_chars=trim_chars,
187+
casefold=casefold,
179188
)
180189

181190
def __repr__(self) -> str:
@@ -217,6 +226,7 @@ def __init__(
217226
type: Optional[str] = None,
218227
optional: bool = False,
219228
trim_chars: Optional[Tuple[str, ...]] = None,
229+
casefold: Optional[Callable[[str], str]] = None,
220230
):
221231
self.templates = {template.upper() for template in templates}
222232
# Create list version upfront to avoid recreating it multiple times.
@@ -226,6 +236,7 @@ def __init__(
226236
type=type,
227237
optional=optional,
228238
trim_chars=trim_chars,
239+
casefold=casefold,
229240
)
230241

231242
def __repr__(self) -> str:
@@ -268,6 +279,7 @@ def __init__(
268279
optional: bool = False,
269280
anti_template: Optional[str] = None,
270281
trim_chars: Optional[Tuple[str, ...]] = None,
282+
casefold: Optional[Callable[[str], str]] = None,
271283
):
272284
# Store the optional anti-template
273285
self.template = template
@@ -280,13 +292,14 @@ def __init__(
280292
type=type,
281293
optional=optional,
282294
trim_chars=trim_chars,
295+
casefold=casefold,
283296
)
284297

285298
def __repr__(self) -> str:
286299
return f"<RegexParser: {self.template!r}>"
287300

288301
def simple(
289-
cls, parse_context: ParseContext, crumbs: Optional[Tuple[str, ...]] = None
302+
self, parse_context: ParseContext, crumbs: Optional[Tuple[str, ...]] = None
290303
) -> None:
291304
"""Does this matcher support a uppercase hash matching route?
292305

src/sqlfluff/core/parser/segments/base.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,10 @@ def get_raw_segments(self) -> List["RawSegment"]:
930930
"""Iterate raw segments, mostly for searching."""
931931
return [item for s in self.segments for item in s.raw_segments]
932932

933+
def raw_normalized(self, casefold: bool = True) -> str:
    """Join the normalized values of all raw segments beneath this one.

    Args:
        casefold (bool): Passed through unchanged to the ``raw_normalized``
            method of every raw segment.

    Returns:
        str: The per-segment normalized values, concatenated in the order
        given by ``get_raw_segments()``.
    """
    pieces = [raw_seg.raw_normalized(casefold) for raw_seg in self.get_raw_segments()]
    return "".join(pieces)
936+
933937
def iter_segments(
934938
self, expanding: Optional[Sequence[str]] = None, pass_through: bool = False
935939
) -> Iterator["BaseSegment"]:
@@ -1229,6 +1233,15 @@ def edit(
12291233
"""Stub."""
12301234
raise NotImplementedError()
12311235

1236+
@classmethod
def from_result_segments(
    cls,
    result_segments: Tuple[BaseSegment, ...],
    segment_kwargs: Dict[str, Any],
) -> "BaseSegment":
    """Build an instance of this class around a tuple of matched segments.

    Args:
        result_segments: The segments handed to the constructor as its
            ``segments`` argument.
        segment_kwargs: Extra keyword arguments forwarded to the
            constructor as-is.

    Returns:
        BaseSegment: The newly constructed segment.
    """
    instance = cls(segments=result_segments, **segment_kwargs)
    return instance
1244+
12321245

12331246
class UnparsableSegment(BaseSegment):
12341247
"""This is a segment which can't be parsed. It indicates a error during parsing."""

src/sqlfluff/core/parser/segments/keyword.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""The KeywordSegment class."""
22

3-
from typing import List, Optional, Tuple
3+
from typing import Callable, List, Optional, Tuple, Union
44

55
from sqlfluff.core.parser.markers import PositionMarker
66
from sqlfluff.core.parser.segments.base import SourceFix
@@ -24,13 +24,20 @@ def __init__(
2424
instance_types: Tuple[str, ...] = (),
2525
source_fixes: Optional[List[SourceFix]] = None,
2626
trim_chars: Optional[Tuple[str, ...]] = None,
27+
quoted_value: Optional[Tuple[str, Union[int, str]]] = None,
28+
escape_replacements: Optional[List[Tuple[str, str]]] = None,
29+
casefold: Optional[Callable[[str], str]] = None,
2730
):
2831
"""If no other name is provided we extrapolate it from the raw."""
2932
super().__init__(
3033
raw=raw,
3134
pos_marker=pos_marker,
3235
instance_types=instance_types,
3336
source_fixes=source_fixes,
37+
trim_chars=trim_chars,
38+
quoted_value=quoted_value,
39+
escape_replacements=escape_replacements,
40+
casefold=casefold,
3441
)
3542

3643
def edit(

src/sqlfluff/core/parser/segments/raw.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
any children, and the output of the lexer.
55
"""
66

7-
from typing import Any, FrozenSet, List, Optional, Tuple
7+
from typing import Any, Callable, Dict, FrozenSet, List, Optional, Tuple, Union, cast
88
from uuid import uuid4
99

10+
import regex as re
11+
1012
from sqlfluff.core.parser.markers import PositionMarker
1113
from sqlfluff.core.parser.segments.base import BaseSegment, SourceFix
1214

@@ -35,6 +37,9 @@ def __init__(
3537
trim_chars: Optional[Tuple[str, ...]] = None,
3638
source_fixes: Optional[List[SourceFix]] = None,
3739
uuid: Optional[int] = None,
40+
quoted_value: Optional[Tuple[str, Union[int, str]]] = None,
41+
escape_replacements: Optional[List[Tuple[str, str]]] = None,
42+
casefold: Optional[Callable[[str], str]] = None,
3843
):
3944
"""Initialise raw segment.
4045
@@ -69,6 +74,10 @@ def __init__(
6974
self.representation = "<{}: ({}) {!r}>".format(
7075
self.__class__.__name__, self.pos_marker, self.raw
7176
)
77+
self.quoted_value = quoted_value
78+
self.escape_replacements = escape_replacements
79+
self.casefold = casefold
80+
self._raw_value: str = self._raw_normalized()
7281

7382
def __repr__(self) -> str:
7483
# This is calculated at __init__, because all elements are immutable
@@ -171,6 +180,40 @@ def raw_trimmed(self) -> str:
171180
return raw_buff
172181
return raw_buff
173182

183+
def _raw_normalized(self) -> str:
    """Compute the value of the raw content, before any casefolding.

    If ``quoted_value`` is set, its first element is matched as a regex
    against the start of the raw string and its second element selects
    the group to keep (e.g. to drop surrounding quote characters). If the
    pattern does not match, or the selected group did not take part in
    the match, the raw string is kept as it is.

    Each ``(pattern, replacement)`` pair in ``escape_replacements`` is
    then applied in order with a regex substitution (e.g. to remove
    escapes).

    Return:
        str: The raw content's value
    """
    value = self.raw
    if self.quoted_value:
        found = re.match(self.quoted_value[0], value)
        # An unmatched optional group yields None, hence the str check.
        inner = found.group(self.quoted_value[1]) if found else None
        if isinstance(inner, str):
            value = inner
    for pattern, replacement in self.escape_replacements or ():
        value = re.sub(pattern, replacement, value)
    return value
202+
203+
def raw_normalized(self, casefold: bool = True) -> str:
    """Return the normalized value of the raw content, optionally casefolded.

    The value comes from ``self._raw_value``, which ``__init__`` assigns
    from ``_raw_normalized()``. The segment's ``casefold`` callable is
    applied to it only when one is set and the ``casefold`` argument is
    true.

    Args:
        casefold (bool): Whether to apply the segment's casefold function.

    Return:
        str: The normalized version of the raw content
    """
    if not (self.casefold and casefold):
        return self._raw_value
    return self.casefold(self._raw_value)
216+
174217
def stringify(
175218
self, ident: int = 0, tabsize: int = 4, code_only: bool = False
176219
) -> str:
@@ -223,9 +266,38 @@ def edit(
223266
instance_types=self.instance_types,
224267
trim_start=self.trim_start,
225268
trim_chars=self.trim_chars,
269+
quoted_value=self.quoted_value,
270+
escape_replacements=self.escape_replacements,
271+
casefold=self.casefold,
226272
source_fixes=source_fixes or self.source_fixes,
227273
)
228274

275+
def _get_raw_segment_kwargs(self) -> Dict[str, Any]:
    """Return this segment's normalization settings as constructor kwargs.

    Returns:
        Dict[str, Any]: The current ``quoted_value``,
        ``escape_replacements`` and ``casefold`` attributes, keyed by
        attribute name.
    """
    attr_names = ("quoted_value", "escape_replacements", "casefold")
    return {attr: getattr(self, attr) for attr in attr_names}
281+
282+
# ################ CLASS METHODS
283+
284+
@classmethod
def from_result_segments(
    cls,
    result_segments: Tuple[BaseSegment, ...],
    segment_kwargs: Dict[str, Any],
) -> "RawSegment":
    """Create a RawSegment from a single matched segment.

    The new segment reuses the ``raw`` and ``pos_marker`` of the matched
    segment. Its normalization settings (``quoted_value``,
    ``escape_replacements``, ``casefold``) start from the matched
    segment's own; any key also present in ``segment_kwargs`` takes the
    value given there.

    Args:
        result_segments: Must contain exactly one segment.
        segment_kwargs: Keyword arguments for the constructor; they
            override the settings carried over from the matched segment.

    Returns:
        RawSegment: The newly constructed segment.
    """
    assert len(result_segments) == 1
    source = cast("RawSegment", result_segments[0])
    # Later entries win, so explicit kwargs override inherited settings.
    merged_kwargs = {**source._get_raw_segment_kwargs(), **segment_kwargs}
    return cls(raw=source.raw, pos_marker=source.pos_marker, **merged_kwargs)
300+
229301

230302
__all__ = [
231303
"PositionMarker",

src/sqlfluff/core/rules/config_info.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,16 @@
195195
"in the file."
196196
),
197197
},
198+
"alias_case_check": {
199+
"validation": [
200+
"dialect",
201+
"case_insensitive",
202+
"quoted_cs_naked_upper",
203+
"quoted_cs_naked_lower",
204+
"case_sensitive",
205+
],
206+
"definition": "How to handle comparison casefolding in an alias.",
207+
},
198208
"min_alias_length": {
199209
"validation": range(1000),
200210
"definition": (

0 commit comments

Comments
 (0)