@@ -59,7 +59,10 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target',
5959 self .target = target # for method/module/function, the name of the method/module/function/attr
6060 # being invoked, e.g add, layer1, or torch.add
6161
62- self ._uses : Dict [Node , None ] = {}
62+ # All `Node`-valued inputs. Key is the Node, value is don't-care.
63+ # The public API for this is `all_input_nodes`, this private attribute
64+ # should not be accessed directly.
65+ self ._input_nodes : Dict [Node , None ] = {}
6366 self ._update_args_kwargs (map_arg (args , lambda x : x ), map_arg (kwargs , lambda x : x )) # type: ignore
6467
6568 # All of the nodes that use the value produced by this Node
@@ -191,10 +194,7 @@ def all_input_nodes(self) -> List['Node']:
191194 List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this
192195 ``Node``, in that order.
193196 """
194- all_nodes : List ['Node' ] = []
195- map_arg (self .args , lambda n : all_nodes .append (n ))
196- map_arg (self .kwargs , lambda n : all_nodes .append (n ))
197- return all_nodes
197+ return list (self ._input_nodes .keys ())
198198
199199 def _update_args_kwargs (self , new_args : Tuple [Argument , ...], new_kwargs : Dict [str , Argument ]):
200200 """
@@ -203,14 +203,14 @@ def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict
203203 self ._args = new_args
204204 self ._kwargs = new_kwargs
205205
206- for old_use in self ._uses .keys ():
206+ for old_use in self ._input_nodes .keys ():
207207 old_use .users .pop (self )
208208
209- self ._uses = {}
210- map_arg (self ._args , lambda n : self ._uses .setdefault (n ))
211- map_arg (self ._kwargs , lambda n : self ._uses .setdefault (n ))
209+ self ._input_nodes = {}
210+ map_arg (self ._args , lambda n : self ._input_nodes .setdefault (n ))
211+ map_arg (self ._kwargs , lambda n : self ._input_nodes .setdefault (n ))
212212
213- for new_use in self ._uses .keys ():
213+ for new_use in self ._input_nodes .keys ():
214214 new_use .users .setdefault (self )
215215
216216 def __repr__ (self ) -> str :
0 commit comments