|
1 | 1 | """Implementation of Rule ST04.""" |
2 | 2 |
|
3 | | -from sqlfluff.core.parser import NewlineSegment, WhitespaceSegment |
| 3 | +from typing import List |
| 4 | + |
| 5 | +from sqlfluff.core.parser import BaseSegment, Indent, NewlineSegment, WhitespaceSegment |
4 | 6 | from sqlfluff.core.rules import BaseRule, LintFix, LintResult, RuleContext |
5 | 7 | from sqlfluff.core.rules.crawlers import SegmentSeekerCrawler |
6 | | -from sqlfluff.utils.functional import FunctionalContext, sp |
| 8 | +from sqlfluff.utils.functional import FunctionalContext, Segments, sp |
7 | 9 | from sqlfluff.utils.reflow.reindent import construct_single_indent |
8 | 10 |
|
9 | 11 |
|
@@ -53,6 +55,7 @@ def _eval(self, context: RuleContext) -> LintResult: |
53 | 55 | assert segment.select(sp.is_type("case_expression")) |
54 | 56 | case1_children = segment.children() |
55 | 57 | case1_first_case = case1_children.first(sp.is_keyword("CASE")).get() |
| 58 | + assert case1_first_case |
56 | 59 | case1_first_when = case1_children.first( |
57 | 60 | sp.is_type("when_clause", "else_clause") |
58 | 61 | ).get() |
@@ -104,36 +107,138 @@ def _eval(self, context: RuleContext) -> LintResult: |
104 | 107 | case1_to_delete = case1_children.select( |
105 | 108 | start_seg=case1_last_when, stop_seg=case1_else_clause_seg |
106 | 109 | ) |
| 110 | + # Restore any comments that were deleted |
| 111 | + after_last_comment_index = ( |
| 112 | + case1_to_delete.find(case1_to_delete.last(sp.is_comment()).get()) + 1 |
| 113 | + ) |
| 114 | + case1_comments_to_restore = case1_to_delete.select( |
| 115 | + stop_seg=case1_to_delete.get(after_last_comment_index) |
| 116 | + ) |
| 117 | + after_else_comment = case1_else_clause.children().select( |
| 118 | + select_if=sp.is_type("newline", "comment", "whitespace"), |
| 119 | + stop_seg=case1_else_expressions.get(), |
| 120 | + ) |
107 | 121 |
|
108 | 122 | # Delete the nested "CASE" expression. |
109 | | - fixes = case1_to_delete.apply(lambda seg: LintFix.delete(seg)) |
| 123 | + fixes = case1_to_delete.apply(LintFix.delete) |
110 | 124 |
|
111 | 125 | tab_space_size: int = context.config.get("tab_space_size", ["indentation"]) |
112 | 126 | indent_unit: str = context.config.get("indent_unit", ["indentation"]) |
113 | 127 |
|
114 | 128 | # Determine the indentation to use when we move the nested "WHEN" |
115 | 129 | # and "ELSE" clauses, based on the indentation of case1_last_when. |
116 | 130 | # If no whitespace segments found, use default indent. |
117 | | - indent = ( |
118 | | - case1_children.select(stop_seg=case1_last_when) |
119 | | - .reversed() |
120 | | - .first(sp.is_type("whitespace")) |
| 131 | + when_indent_str = self._get_indentation( |
| 132 | + case1_children, case1_last_when, tab_space_size, indent_unit |
121 | 133 | ) |
122 | | - indent_str = ( |
123 | | - "".join(seg.raw for seg in indent) |
124 | | - if indent |
125 | | - else construct_single_indent(indent_unit, tab_space_size) |
| 134 | + # Again determine indentation, but matching the "CASE"/"END" level. |
| 135 | + end_indent_str = self._get_indentation( |
| 136 | + case1_children, case1_first_case, tab_space_size, indent_unit |
126 | 137 | ) |
127 | 138 |
|
128 | 139 | # Move the nested "when" and "else" clauses after the last outer |
129 | 140 | # "when". |
130 | | - nested_clauses = case2.children(sp.is_type("when_clause", "else_clause")) |
131 | | - create_after_last_when = nested_clauses.apply( |
132 | | - lambda seg: [NewlineSegment(), WhitespaceSegment(indent_str), seg] |
| 141 | + nested_clauses = case2.children( |
| 142 | + sp.is_type("when_clause", "else_clause", "newline", "comment", "whitespace") |
133 | 143 | ) |
134 | | - segments = [item for sublist in create_after_last_when for item in sublist] |
| 144 | + |
| 145 | + # Rebuild the nested case statement. |
| 146 | + # Any comments after the last outer "WHEN" that were deleted |
| 147 | + segments = list(case1_comments_to_restore) |
| 148 | + # Any comments between the "ELSE" and nested "CASE" |
| 149 | + segments += self._rebuild_spacing(when_indent_str, after_else_comment) |
| 150 | + # The nested "WHEN", "ELSE" or "comments", with logical spacing |
| 151 | + segments += self._rebuild_spacing(when_indent_str, nested_clauses) |
135 | 152 | fixes.append(LintFix.create_after(case1_last_when, segments, source=segments)) |
136 | 153 |
|
137 | 154 | # Delete the outer "else" clause. |
138 | 155 | fixes.append(LintFix.delete(case1_else_clause_seg)) |
| 156 | + # Add spacing for any comments that may exist after the nested `END` |
| 157 | + # but only on that same line. |
| 158 | + fixes += self._nested_end_trailing_comment( |
| 159 | + case1_children, case1_else_clause_seg, end_indent_str |
| 160 | + ) |
139 | 161 | return LintResult(case2[0], fixes=fixes) |
| 162 | + |
def _get_indentation(
    self,
    parent_segments: Segments,
    segment: BaseSegment,
    tab_space_size: int,
    indent_unit: str,
) -> str:
    """Work out the indent string to use at the level of ``segment``.

    Only the siblings that come before ``segment`` in ``parent_segments``
    are inspected. If the nearest preceding whitespace segment has raw text
    longer than one character, that text is reused verbatim. Otherwise the
    result is a single configured indent (built by
    ``construct_single_indent``) repeated ``indent_val + 1`` times, where
    ``indent_val`` is read from the last preceding ``indent`` segment when
    it is an ``Indent`` instance; with no such segment the repeat count
    is 1.

    This is a best-effort guess: the input may not be evenly indented, and
    the layout rules, if run, would tidy up the result.

    Args:
        parent_segments: Sibling segments that contain ``segment``.
        segment: The segment whose indentation should be matched.
        tab_space_size: Passed through to ``construct_single_indent``.
        indent_unit: Passed through to ``construct_single_indent``.

    Returns:
        The whitespace string to place in front of rebuilt segments.
    """
    preceding = parent_segments.select(stop_seg=segment)

    # Prefer whatever whitespace already sits in front of the segment.
    # A one-character run is not reused (presumably it is a separator
    # rather than a real indent -- TODO confirm).
    nearest_ws = preceding.reversed().first(sp.is_type("whitespace"))
    ws_seg = nearest_ws.get() if nearest_ws else None
    if ws_seg and len(ws_seg.raw) > 1:
        return "".join(ws.raw for ws in nearest_ws)

    # Fall back to a computed indent, scaled by the closest indent marker.
    depth = 1
    indent_markers = preceding.last(sp.is_type("indent"))
    marker = indent_markers.get() if indent_markers else None
    if marker and isinstance(marker, Indent):
        depth = marker.indent_val + 1

    return construct_single_indent(indent_unit, tab_space_size) * depth
| 197 | + |
def _nested_end_trailing_comment(
    self,
    case1_children: Segments,
    case1_else_clause_seg: BaseSegment,
    end_indent_str: str,
) -> List[LintFix]:
    """Build fixes that push a comment trailing the nested `END` onto a new line.

    Only segments between ``case1_else_clause_seg`` and the next newline
    are considered. Whitespace found before the first comment in that run
    is deleted, and a newline plus ``end_indent_str`` is inserted directly
    in front of that comment. If the run holds no comment, only the
    whitespace deletions are returned.

    Args:
        case1_children: Child segments of the outer case expression.
        case1_else_clause_seg: The outer "ELSE" clause segment; the scan
            starts after it.
        end_indent_str: Indentation to place before the relocated comment.

    Returns:
        The list of fixes to apply (possibly empty).
    """
    # Everything following the else clause up to the first newline.
    same_line = case1_children.select(
        start_seg=case1_else_clause_seg,
        loop_while=sp.not_(sp.is_type("newline")),
    )

    # Drop the whitespace that sits ahead of the first comment.
    gap_whitespace = same_line.select(
        sp.is_whitespace(), loop_while=sp.not_(sp.is_comment())
    )
    fixes = gap_whitespace.apply(LintFix.delete)

    comment = same_line.first(sp.is_comment()).get()
    if not comment:
        return fixes

    segments = [NewlineSegment(), WhitespaceSegment(end_indent_str)]
    fixes.append(LintFix.create_before(comment, segments, segments))
    return fixes
| 217 | + |
def _rebuild_spacing(
    self, indent_str: str, nested_clauses: Segments
) -> List[BaseSegment]:
    """Re-emit clauses and comments with normalised line breaks.

    Walks ``nested_clauses`` in order and produces a flat list in which:

    * every ``when_clause`` / ``else_clause`` is preceded by a newline and
      ``indent_str``;
    * a comment that followed a newline in the input is likewise placed on
      its own line with ``indent_str``;
    * a comment that did not follow a newline stays inline, preceded by
      the raw text of the whitespace seen just before it (an empty string
      if there was none);
    * the original newline and whitespace segments themselves are dropped.

    If the first non-whitespace segment is a comment, it is treated as if
    it followed a newline.

    Args:
        indent_str: Indentation placed after each inserted newline.
        nested_clauses: Clause, comment, newline and whitespace segments
            to re-space.

    Returns:
        The rebuilt list of segments.
    """
    rebuilt: List[BaseSegment] = []
    # A leading comment is handled as though it started on a fresh line.
    on_new_line = nested_clauses.first(sp.not_(sp.is_whitespace())).any(
        sp.is_comment()
    )
    pending_ws = ""

    for seg in nested_clauses:
        if seg.is_type("when_clause", "else_clause") or (
            on_new_line and seg.is_comment
        ):
            # Clause, or comment that owned its line: start a new line.
            rebuilt.extend((NewlineSegment(), WhitespaceSegment(indent_str), seg))
        elif seg.is_type("newline"):
            on_new_line = True
            pending_ws = ""
            continue
        elif seg.is_comment:
            # Reaching here implies no newline preceded it: keep it inline.
            rebuilt.extend((WhitespaceSegment(pending_ws), seg))
        else:
            if seg.is_whitespace:
                # Remember the spacing, but keep the newline flag untouched.
                pending_ws = seg.raw
            continue

        # Something was emitted: clear the carried state.
        on_new_line = False
        pending_ws = ""

    return rebuilt
0 commit comments