|
4 | 4 | any children, and the output of the lexer. |
5 | 5 | """ |
6 | 6 |
|
7 | | -from typing import Any, FrozenSet, List, Optional, Tuple |
| 7 | +from typing import Any, Callable, Dict, FrozenSet, List, Optional, Tuple, Union, cast |
8 | 8 | from uuid import uuid4 |
9 | 9 |
|
| 10 | +import regex as re |
| 11 | + |
10 | 12 | from sqlfluff.core.parser.markers import PositionMarker |
11 | 13 | from sqlfluff.core.parser.segments.base import BaseSegment, SourceFix |
12 | 14 |
|
@@ -35,6 +37,9 @@ def __init__( |
35 | 37 | trim_chars: Optional[Tuple[str, ...]] = None, |
36 | 38 | source_fixes: Optional[List[SourceFix]] = None, |
37 | 39 | uuid: Optional[int] = None, |
| 40 | + quoted_value: Optional[Tuple[str, Union[int, str]]] = None, |
| 41 | + escape_replacements: Optional[List[Tuple[str, str]]] = None, |
| 42 | + casefold: Optional[Callable[[str], str]] = None, |
38 | 43 | ): |
39 | 44 | """Initialise raw segment. |
40 | 45 |
|
@@ -69,6 +74,10 @@ def __init__( |
69 | 74 | self.representation = "<{}: ({}) {!r}>".format( |
70 | 75 | self.__class__.__name__, self.pos_marker, self.raw |
71 | 76 | ) |
| 77 | + self.quoted_value = quoted_value |
| 78 | + self.escape_replacements = escape_replacements |
| 79 | + self.casefold = casefold |
| 80 | + self._raw_value: str = self._raw_normalized() |
72 | 81 |
|
73 | 82 | def __repr__(self) -> str: |
74 | 83 | # This is calculated at __init__, because all elements are immutable |
@@ -171,6 +180,40 @@ def raw_trimmed(self) -> str: |
171 | 180 | return raw_buff |
172 | 181 | return raw_buff |
173 | 182 |
|
| 183 | + def _raw_normalized(self) -> str: |
| 184 | + """Returns the string of the raw content's value. |
| 185 | +
|
| 186 | + E.g. This removes leading and trailing quote characters, removes escapes |
| 187 | +
|
| 188 | + Return: |
| 189 | + str: The raw content's value |
| 190 | + """ |
| 191 | + raw_buff = self.raw |
| 192 | + if self.quoted_value: |
| 193 | + _match = re.match(self.quoted_value[0], raw_buff) |
| 194 | + if _match: |
| 195 | + _group_match = _match.group(self.quoted_value[1]) |
| 196 | + if isinstance(_group_match, str): |
| 197 | + raw_buff = _group_match |
| 198 | + if self.escape_replacements: |
| 199 | + for old, new in self.escape_replacements: |
| 200 | + raw_buff = re.sub(old, new, raw_buff) |
| 201 | + return raw_buff |
| 202 | + |
| 203 | + def raw_normalized(self, casefold: bool = True) -> str: |
| 204 | + """Returns a normalized string of the raw content. |
| 205 | +
|
| 206 | + E.g. This removes leading and trailing quote characters, removes escapes, |
| 207 | + optionally casefolds to the dialect's casing |
| 208 | +
|
| 209 | + Return: |
| 210 | + str: The normalized version of the raw content |
| 211 | + """ |
| 212 | + raw_buff = self._raw_value |
| 213 | + if self.casefold and casefold: |
| 214 | + raw_buff = self.casefold(raw_buff) |
| 215 | + return raw_buff |
| 216 | + |
174 | 217 | def stringify( |
175 | 218 | self, ident: int = 0, tabsize: int = 4, code_only: bool = False |
176 | 219 | ) -> str: |
@@ -223,9 +266,38 @@ def edit( |
223 | 266 | instance_types=self.instance_types, |
224 | 267 | trim_start=self.trim_start, |
225 | 268 | trim_chars=self.trim_chars, |
| 269 | + quoted_value=self.quoted_value, |
| 270 | + escape_replacements=self.escape_replacements, |
| 271 | + casefold=self.casefold, |
226 | 272 | source_fixes=source_fixes or self.source_fixes, |
227 | 273 | ) |
228 | 274 |
|
| 275 | + def _get_raw_segment_kwargs(self) -> Dict[str, Any]: |
| 276 | + return { |
| 277 | + "quoted_value": self.quoted_value, |
| 278 | + "escape_replacements": self.escape_replacements, |
| 279 | + "casefold": self.casefold, |
| 280 | + } |
| 281 | + |
| 282 | + # ################ CLASS METHODS |
| 283 | + |
| 284 | + @classmethod |
| 285 | + def from_result_segments( |
| 286 | + cls, |
| 287 | + result_segments: Tuple[BaseSegment, ...], |
| 288 | + segment_kwargs: Dict[str, Any], |
| 289 | + ) -> "RawSegment": |
| 290 | + """Create a RawSegment from result segments.""" |
| 291 | + assert len(result_segments) == 1 |
| 292 | + raw_seg = cast("RawSegment", result_segments[0]) |
| 293 | + new_segment_kwargs = raw_seg._get_raw_segment_kwargs() |
| 294 | + new_segment_kwargs.update(segment_kwargs) |
| 295 | + return cls( |
| 296 | + raw=raw_seg.raw, |
| 297 | + pos_marker=raw_seg.pos_marker, |
| 298 | + **new_segment_kwargs, |
| 299 | + ) |
| 300 | + |
229 | 301 |
|
230 | 302 | __all__ = [ |
231 | 303 | "PositionMarker", |
|
0 commit comments