Skip to content

Commit b3f8c7b

Browse files
authored
Merge pull request #320 from asottile/new_class_super_v2
Rewrite old-style-class super calls
2 parents 4130821 + 1a76829 commit b3f8c7b

5 files changed

Lines changed: 206 additions & 62 deletions

File tree

pyupgrade/_ast_helpers.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
import ast
22
import warnings
3+
from typing import Any
34
from typing import Container
45
from typing import Dict
6+
from typing import Iterable
57
from typing import Set
8+
from typing import Tuple
9+
from typing import Type
610
from typing import Union
711

812
from tokenize_rt import Offset
@@ -57,3 +61,43 @@ def is_async_listcomp(node: ast.ListComp) -> bool:
5761
any(gen.is_async for gen in node.generators) or
5862
contains_await(node)
5963
)
64+
65+
66+
def _all_isinstance(
67+
vals: Iterable[Any],
68+
tp: Union[Type[Any], Tuple[Type[Any], ...]],
69+
) -> bool:
70+
return all(isinstance(v, tp) for v in vals)
71+
72+
73+
def _fields_same(n1: ast.AST, n2: ast.AST) -> bool:
74+
for (a1, v1), (a2, v2) in zip(ast.iter_fields(n1), ast.iter_fields(n2)):
75+
# ignore ast attributes, they'll be covered by walk
76+
if a1 != a2:
77+
return False
78+
elif _all_isinstance((v1, v2), ast.AST):
79+
continue
80+
elif _all_isinstance((v1, v2), (list, tuple)):
81+
if len(v1) != len(v2):
82+
return False
83+
# ignore sequences which are all-ast, they'll be covered by walk
84+
elif _all_isinstance(v1, ast.AST) and _all_isinstance(v2, ast.AST):
85+
continue
86+
elif v1 != v2:
87+
return False
88+
elif v1 != v2:
89+
return False
90+
return True
91+
92+
93+
def targets_same(node1: ast.AST, node2: ast.AST) -> bool:
94+
for t1, t2 in zip(ast.walk(node1), ast.walk(node2)):
95+
# ignore `ast.Load` / `ast.Store`
96+
if _all_isinstance((t1, t2), ast.expr_context):
97+
continue
98+
elif type(t1) != type(t2):
99+
return False
100+
elif not _fields_same(t1, t2):
101+
return False
102+
else:
103+
return True

pyupgrade/_plugins/legacy.py

Lines changed: 39 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,19 @@
22
import collections
33
import contextlib
44
import functools
5-
from typing import Any
65
from typing import Dict
76
from typing import Generator
87
from typing import Iterable
98
from typing import List
109
from typing import Set
1110
from typing import Tuple
12-
from typing import Type
13-
from typing import Union
1411

1512
from tokenize_rt import Offset
1613
from tokenize_rt import Token
1714
from tokenize_rt import tokens_to_src
1815

1916
from pyupgrade._ast_helpers import ast_to_offset
17+
from pyupgrade._ast_helpers import targets_same
2018
from pyupgrade._data import register
2119
from pyupgrade._data import State
2220
from pyupgrade._data import TokenFunc
@@ -26,6 +24,7 @@
2624
from pyupgrade._token_helpers import find_token
2725

2826
FUNC_TYPES = (ast.Lambda, ast.FunctionDef, ast.AsyncFunctionDef)
27+
NON_LAMBDA_FUNC_TYPES = (ast.FunctionDef, ast.AsyncFunctionDef)
2928

3029

3130
def _fix_yield(i: int, tokens: List[Token]) -> None:
@@ -36,44 +35,13 @@ def _fix_yield(i: int, tokens: List[Token]) -> None:
3635
tokens[i:block.end] = [Token('CODE', f'yield from {container}\n')]
3736

3837

39-
def _all_isinstance(
40-
vals: Iterable[Any],
41-
tp: Union[Type[Any], Tuple[Type[Any], ...]],
42-
) -> bool:
43-
return all(isinstance(v, tp) for v in vals)
44-
45-
46-
def _fields_same(n1: ast.AST, n2: ast.AST) -> bool:
47-
for (a1, v1), (a2, v2) in zip(ast.iter_fields(n1), ast.iter_fields(n2)):
48-
# ignore ast attributes, they'll be covered by walk
49-
if a1 != a2:
50-
return False
51-
elif _all_isinstance((v1, v2), ast.AST):
52-
continue
53-
elif _all_isinstance((v1, v2), (list, tuple)):
54-
if len(v1) != len(v2):
55-
return False
56-
# ignore sequences which are all-ast, they'll be covered by walk
57-
elif _all_isinstance(v1, ast.AST) and _all_isinstance(v2, ast.AST):
58-
continue
59-
elif v1 != v2:
60-
return False
61-
elif v1 != v2:
62-
return False
63-
return True
64-
65-
66-
def _targets_same(target: ast.AST, yield_value: ast.AST) -> bool:
67-
for t1, t2 in zip(ast.walk(target), ast.walk(yield_value)):
68-
# ignore `ast.Load` / `ast.Store`
69-
if _all_isinstance((t1, t2), ast.expr_context):
70-
continue
71-
elif type(t1) != type(t2):
72-
return False
73-
elif not _fields_same(t1, t2):
74-
return False
75-
else:
76-
return True
38+
def _is_simple_base(base: ast.AST) -> bool:
39+
return (
40+
isinstance(base, ast.Name) or (
41+
isinstance(base, ast.Attribute) and
42+
_is_simple_base(base.value)
43+
)
44+
)
7745

7846

7947
class Scope:
@@ -92,6 +60,7 @@ class Visitor(ast.NodeVisitor):
9260
def __init__(self) -> None:
9361
self._scopes: List[Scope] = []
9462
self.super_offsets: Set[Offset] = set()
63+
self.old_super_offsets: Set[Tuple[Offset, str]] = set()
9564
self.yield_offsets: Set[Offset] = set()
9665

9766
@contextlib.contextmanager
@@ -137,7 +106,6 @@ def visit_Call(self, node: ast.Call) -> None:
137106
len(node.args) == 2 and
138107
isinstance(node.args[0], ast.Name) and
139108
isinstance(node.args[1], ast.Name) and
140-
# there are at least two scopes
141109
len(self._scopes) >= 2 and
142110
# the second to last scope is the class in arg1
143111
isinstance(self._scopes[-2].node, ast.ClassDef) and
@@ -148,6 +116,29 @@ def visit_Call(self, node: ast.Call) -> None:
148116
node.args[1].id == self._scopes[-1].node.args.args[0].arg
149117
):
150118
self.super_offsets.add(ast_to_offset(node))
119+
elif (
120+
# base.funcname(funcarg1, ...)
121+
isinstance(node.func, ast.Attribute) and
122+
len(node.args) >= 1 and
123+
isinstance(node.args[0], ast.Name) and
124+
len(self._scopes) >= 2 and
125+
# last stack is a function whose first argument is the first
126+
# argument of this function
127+
isinstance(self._scopes[-1].node, NON_LAMBDA_FUNC_TYPES) and
128+
node.func.attr == self._scopes[-1].node.name and
129+
node.func.attr != '__new__' and
130+
len(self._scopes[-1].node.args.args) >= 1 and
131+
node.args[0].id == self._scopes[-1].node.args.args[0].arg and
132+
# the function is an attribute of the contained class name
133+
isinstance(self._scopes[-2].node, ast.ClassDef) and
134+
len(self._scopes[-2].node.bases) == 1 and
135+
_is_simple_base(self._scopes[-2].node.bases[0]) and
136+
targets_same(
137+
self._scopes[-2].node.bases[0],
138+
node.func.value,
139+
)
140+
):
141+
self.old_super_offsets.add((ast_to_offset(node), node.func.attr))
151142

152143
self.generic_visit(node)
153144

@@ -159,7 +150,7 @@ def visit_For(self, node: ast.For) -> None:
159150
isinstance(node.body[0], ast.Expr) and
160151
isinstance(node.body[0].value, ast.Yield) and
161152
node.body[0].value.value is not None and
162-
_targets_same(node.target, node.body[0].value.value) and
153+
targets_same(node.target, node.body[0].value.value) and
163154
not node.orelse
164155
):
165156
offset = ast_to_offset(node)
@@ -198,5 +189,10 @@ def visit_Module(
198189
for offset in visitor.super_offsets:
199190
yield offset, super_func
200191

192+
for offset, func_name in visitor.old_super_offsets:
193+
template = f'super().{func_name}({{rest}})'
194+
callback = functools.partial(find_and_replace_call, template=template)
195+
yield offset, callback
196+
201197
for offset in visitor.yield_offsets:
202198
yield offset, _fix_yield

tests/ast_helpers_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import ast
2+
3+
from pyupgrade._ast_helpers import _fields_same
4+
from pyupgrade._ast_helpers import targets_same
5+
6+
7+
def test_targets_same():
8+
assert targets_same(ast.parse('global a, b'), ast.parse('global a, b'))
9+
assert not targets_same(ast.parse('global a'), ast.parse('global b'))
10+
11+
12+
def _get_body(expr):
13+
body = ast.parse(expr).body[0]
14+
assert isinstance(body, ast.Expr)
15+
return body.value
16+
17+
18+
def test_fields_same():
19+
assert not _fields_same(_get_body('x'), _get_body('1'))

tests/features/super_test.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,107 @@ def test_fix_super_noop(s):
122122
)
123123
def test_fix_super(s, expected):
124124
assert _fix_plugins(s, settings=Settings(min_version=(3,))) == expected
125+
126+
127+
@pytest.mark.parametrize(
128+
's',
129+
(
130+
pytest.param(
131+
'class C(B):\n'
132+
' def f(self):\n'
133+
' B.f(notself)\n',
134+
id='old style super, first argument is not first function arg',
135+
),
136+
pytest.param(
137+
'class C(B1, B2):\n'
138+
' def f(self):\n'
139+
' B1.f(self)\n',
140+
# TODO: is this safe to rewrite? I don't think so
141+
id='old-style super, multiple inheritance first class',
142+
),
143+
pytest.param(
144+
'class C(B1, B2):\n'
145+
' def f(self):\n'
146+
' B2.f(self)\n',
147+
# TODO: is this safe to rewrite? I don't think so
148+
id='old-style super, multiple inheritance not-first class',
149+
),
150+
pytest.param(
151+
'class C(Base):\n'
152+
' def f(self):\n'
153+
' return [Base.f(self) for _ in ()]\n',
154+
id='super in comprehension',
155+
),
156+
pytest.param(
157+
'class C(Base):\n'
158+
' def f(self):\n'
159+
' def g():\n'
160+
' Base.f(self)\n'
161+
' g()\n',
162+
id='super in nested functions',
163+
),
164+
pytest.param(
165+
'class C(not_simple()):\n'
166+
' def f(self):\n'
167+
' not_simple().f(self)\n',
168+
id='not a simple base',
169+
),
170+
pytest.param(
171+
'class C(a().b):\n'
172+
' def f(self):\n'
173+
' a().b.f(self)\n',
174+
id='non simple attribute base',
175+
),
176+
pytest.param(
177+
'class C:\n'
178+
' @classmethod\n'
179+
' def make(cls, instance):\n'
180+
' ...\n'
181+
'class D(C):\n'
182+
' def find(self):\n'
183+
' return C.make(self)\n',
184+
),
185+
pytest.param(
186+
'class C(tuple):\n'
187+
' def __new__(cls, arg):\n'
188+
' return tuple.__new__(cls, (arg,))\n',
189+
id='super() does not work properly for __new__',
190+
),
191+
),
192+
)
193+
def test_old_style_class_super_noop(s):
194+
assert _fix_plugins(s, settings=Settings(min_version=(3,))) == s
195+
196+
197+
@pytest.mark.parametrize(
198+
('s', 'expected'),
199+
(
200+
(
201+
'class C(B):\n'
202+
' def f(self):\n'
203+
' B.f(self)\n'
204+
' B.f(self, arg, arg)\n',
205+
'class C(B):\n'
206+
' def f(self):\n'
207+
' super().f()\n'
208+
' super().f(arg, arg)\n',
209+
),
210+
pytest.param(
211+
'class C(B):\n'
212+
' def f(self, a):\n'
213+
' B.f(\n'
214+
' self,\n'
215+
' a,\n'
216+
' )\n',
217+
218+
'class C(B):\n'
219+
' def f(self, a):\n'
220+
' super().f(\n'
221+
' a,\n'
222+
' )\n',
223+
id='multi-line super call',
224+
),
225+
),
226+
)
227+
def test_old_style_class_super(s, expected):
228+
assert _fix_plugins(s, settings=Settings(min_version=(3,))) == expected

tests/features/yield_from_test.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
1-
import ast
2-
31
import pytest
42

53
from pyupgrade._data import Settings
64
from pyupgrade._main import _fix_plugins
7-
from pyupgrade._plugins.legacy import _fields_same
8-
from pyupgrade._plugins.legacy import _targets_same
95

106

117
@pytest.mark.parametrize(
@@ -215,18 +211,3 @@ def test_fix_yield_from(s, expected):
215211
)
216212
def test_fix_yield_from_noop(s):
217213
assert _fix_plugins(s, settings=Settings(min_version=(3,))) == s
218-
219-
220-
def test_targets_same():
221-
assert _targets_same(ast.parse('global a, b'), ast.parse('global a, b'))
222-
assert not _targets_same(ast.parse('global a'), ast.parse('global b'))
223-
224-
225-
def _get_body(expr):
226-
body = ast.parse(expr).body[0]
227-
assert isinstance(body, ast.Expr)
228-
return body.value
229-
230-
231-
def test_fields_same():
232-
assert not _fields_same(_get_body('x'), _get_body('1'))

0 commit comments

Comments
 (0)