22import collections
33import contextlib
44import functools
5- from typing import Any
65from typing import Dict
76from typing import Generator
87from typing import Iterable
98from typing import List
109from typing import Set
1110from typing import Tuple
12- from typing import Type
13- from typing import Union
1411
1512from tokenize_rt import Offset
1613from tokenize_rt import Token
1714from tokenize_rt import tokens_to_src
1815
1916from pyupgrade ._ast_helpers import ast_to_offset
17+ from pyupgrade ._ast_helpers import targets_same
2018from pyupgrade ._data import register
2119from pyupgrade ._data import State
2220from pyupgrade ._data import TokenFunc
2624from pyupgrade ._token_helpers import find_token
2725
2826FUNC_TYPES = (ast .Lambda , ast .FunctionDef , ast .AsyncFunctionDef )
27+ NON_LAMBDA_FUNC_TYPES = (ast .FunctionDef , ast .AsyncFunctionDef )
2928
3029
3130def _fix_yield (i : int , tokens : List [Token ]) -> None :
@@ -36,44 +35,13 @@ def _fix_yield(i: int, tokens: List[Token]) -> None:
3635 tokens [i :block .end ] = [Token ('CODE' , f'yield from { container } \n ' )]
3736
3837
39- def _all_isinstance (
40- vals : Iterable [Any ],
41- tp : Union [Type [Any ], Tuple [Type [Any ], ...]],
42- ) -> bool :
43- return all (isinstance (v , tp ) for v in vals )
44-
45-
46- def _fields_same (n1 : ast .AST , n2 : ast .AST ) -> bool :
47- for (a1 , v1 ), (a2 , v2 ) in zip (ast .iter_fields (n1 ), ast .iter_fields (n2 )):
48- # ignore ast attributes, they'll be covered by walk
49- if a1 != a2 :
50- return False
51- elif _all_isinstance ((v1 , v2 ), ast .AST ):
52- continue
53- elif _all_isinstance ((v1 , v2 ), (list , tuple )):
54- if len (v1 ) != len (v2 ):
55- return False
56- # ignore sequences which are all-ast, they'll be covered by walk
57- elif _all_isinstance (v1 , ast .AST ) and _all_isinstance (v2 , ast .AST ):
58- continue
59- elif v1 != v2 :
60- return False
61- elif v1 != v2 :
62- return False
63- return True
64-
65-
66- def _targets_same (target : ast .AST , yield_value : ast .AST ) -> bool :
67- for t1 , t2 in zip (ast .walk (target ), ast .walk (yield_value )):
68- # ignore `ast.Load` / `ast.Store`
69- if _all_isinstance ((t1 , t2 ), ast .expr_context ):
70- continue
71- elif type (t1 ) != type (t2 ):
72- return False
73- elif not _fields_same (t1 , t2 ):
74- return False
75- else :
76- return True
38+ def _is_simple_base (base : ast .AST ) -> bool :
39+ return (
40+ isinstance (base , ast .Name ) or (
41+ isinstance (base , ast .Attribute ) and
42+ _is_simple_base (base .value )
43+ )
44+ )
7745
7846
7947class Scope :
@@ -92,6 +60,7 @@ class Visitor(ast.NodeVisitor):
9260 def __init__ (self ) -> None :
9361 self ._scopes : List [Scope ] = []
9462 self .super_offsets : Set [Offset ] = set ()
63+ self .old_super_offsets : Set [Tuple [Offset , str ]] = set ()
9564 self .yield_offsets : Set [Offset ] = set ()
9665
9766 @contextlib .contextmanager
@@ -137,7 +106,6 @@ def visit_Call(self, node: ast.Call) -> None:
137106 len (node .args ) == 2 and
138107 isinstance (node .args [0 ], ast .Name ) and
139108 isinstance (node .args [1 ], ast .Name ) and
140- # there are at least two scopes
141109 len (self ._scopes ) >= 2 and
142110 # the second to last scope is the class in arg1
143111 isinstance (self ._scopes [- 2 ].node , ast .ClassDef ) and
@@ -148,6 +116,29 @@ def visit_Call(self, node: ast.Call) -> None:
148116 node .args [1 ].id == self ._scopes [- 1 ].node .args .args [0 ].arg
149117 ):
150118 self .super_offsets .add (ast_to_offset (node ))
119+ elif (
120+ # base.funcname(funcarg1, ...)
121+ isinstance (node .func , ast .Attribute ) and
122+ len (node .args ) >= 1 and
123+ isinstance (node .args [0 ], ast .Name ) and
124+ len (self ._scopes ) >= 2 and
125+ # last stack is a function whose first argument is the first
126+ # argument of this function
127+ isinstance (self ._scopes [- 1 ].node , NON_LAMBDA_FUNC_TYPES ) and
128+ node .func .attr == self ._scopes [- 1 ].node .name and
129+ node .func .attr != '__new__' and
130+ len (self ._scopes [- 1 ].node .args .args ) >= 1 and
131+ node .args [0 ].id == self ._scopes [- 1 ].node .args .args [0 ].arg and
132+ # the function is an attribute of the contained class name
133+ isinstance (self ._scopes [- 2 ].node , ast .ClassDef ) and
134+ len (self ._scopes [- 2 ].node .bases ) == 1 and
135+ _is_simple_base (self ._scopes [- 2 ].node .bases [0 ]) and
136+ targets_same (
137+ self ._scopes [- 2 ].node .bases [0 ],
138+ node .func .value ,
139+ )
140+ ):
141+ self .old_super_offsets .add ((ast_to_offset (node ), node .func .attr ))
151142
152143 self .generic_visit (node )
153144
@@ -159,7 +150,7 @@ def visit_For(self, node: ast.For) -> None:
159150 isinstance (node .body [0 ], ast .Expr ) and
160151 isinstance (node .body [0 ].value , ast .Yield ) and
161152 node .body [0 ].value .value is not None and
162- _targets_same (node .target , node .body [0 ].value .value ) and
153+ targets_same (node .target , node .body [0 ].value .value ) and
163154 not node .orelse
164155 ):
165156 offset = ast_to_offset (node )
@@ -198,5 +189,10 @@ def visit_Module(
198189 for offset in visitor .super_offsets :
199190 yield offset , super_func
200191
192+ for offset , func_name in visitor .old_super_offsets :
193+ template = f'super().{ func_name } ({{rest}})'
194+ callback = functools .partial (find_and_replace_call , template = template )
195+ yield offset , callback
196+
201197 for offset in visitor .yield_offsets :
202198 yield offset , _fix_yield
0 commit comments