Skip to content

Commit 12afc4a

Browse files
author
James Reed
committed
[FX] Rename Node._uses and refactor Node.all_input_nodes
[ghstack-poisoned]
1 parent 7780069 commit 12afc4a

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

torch/fx/node.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,10 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target',
5959
self.target = target # for method/module/function, the name of the method/module/function/attr
6060
# being invoked, e.g add, layer1, or torch.add
6161

62-
self._uses : Dict[Node, None] = {}
62+
# All `Node`-valued inputs. Key is the Node, value is don't-care.
63+
# The public API for this is `all_input_nodes`, this private attribute
64+
# should not be accessed directly.
65+
self._input_nodes : Dict[Node, None] = {}
6366
self._update_args_kwargs(map_arg(args, lambda x: x), map_arg(kwargs, lambda x: x)) # type: ignore
6467

6568
# All of the nodes that use the value produced by this Node
@@ -191,10 +194,7 @@ def all_input_nodes(self) -> List['Node']:
191194
List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this
192195
``Node``, in that order.
193196
"""
194-
all_nodes : List['Node'] = []
195-
map_arg(self.args, lambda n: all_nodes.append(n))
196-
map_arg(self.kwargs, lambda n: all_nodes.append(n))
197-
return all_nodes
197+
return list(self._input_nodes.keys())
198198

199199
def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict[str, Argument]):
200200
"""
@@ -203,14 +203,14 @@ def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict
203203
self._args = new_args
204204
self._kwargs = new_kwargs
205205

206-
for old_use in self._uses.keys():
206+
for old_use in self._input_nodes.keys():
207207
old_use.users.pop(self)
208208

209-
self._uses = {}
210-
map_arg(self._args, lambda n: self._uses.setdefault(n))
211-
map_arg(self._kwargs, lambda n: self._uses.setdefault(n))
209+
self._input_nodes = {}
210+
map_arg(self._args, lambda n: self._input_nodes.setdefault(n))
211+
map_arg(self._kwargs, lambda n: self._input_nodes.setdefault(n))
212212

213-
for new_use in self._uses.keys():
213+
for new_use in self._input_nodes.keys():
214214
new_use.users.setdefault(self)
215215

216216
def __repr__(self) -> str:

0 commit comments

Comments
 (0)