Skip to content

Commit f798696

Browse files
James Reedfacebook-github-bot
authored andcommitted
[FX] Delete values after their last use (#48631)
Summary: Pull Request resolved: #48631 Test Plan: Imported from OSS Reviewed By: zdevito Differential Revision: D25235981 Pulled By: jamesr66a fbshipit-source-id: f79d8873d3ad1ad90b5bd6367fc6119925f116e9
1 parent cff1ff7 commit f798696

File tree

1 file changed

+42
-10
lines changed

1 file changed

+42
-10
lines changed

torch/fx/graph.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,35 @@ def type_repr(o : Any):
487487
type_repr(sub_type)
488488
return typename
489489

490-
for node in self.nodes:
490+
491+
# Run through reverse nodes and record the first instance of a use
492+
# of a given node. This represents the *last* use of the node in the
493+
# execution order of the program, which we will use to free unused
494+
# values
495+
node_to_last_use : Dict[Node, Node] = {}
496+
user_to_last_uses : Dict[Node, List[Node]] = {}
497+
498+
def register_last_uses(n : Node, user : Node):
499+
if n not in node_to_last_use:
500+
node_to_last_use[n] = user
501+
user_to_last_uses.setdefault(user, []).append(n)
502+
503+
for node in reversed(self.nodes):
504+
map_arg(node.args, lambda n: register_last_uses(n, node))
505+
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
506+
507+
def delete_unused_values(user : Node):
508+
"""
509+
Delete values after their last use. This ensures that values that are
510+
not used in the remainder of the code are freed and the memory usage
511+
of the code is optimal.
512+
"""
513+
nodes_to_delete = user_to_last_uses.get(user, [])
514+
if len(nodes_to_delete):
515+
to_delete_str = ' = '.join([n.name for n in nodes_to_delete] + ['None'])
516+
body.append(f'{to_delete_str}\n')
517+
518+
def emit_node(node : Node):
491519
if node.op == 'placeholder':
492520
assert isinstance(node.target, str)
493521
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
@@ -496,20 +524,20 @@ def type_repr(o : Any):
496524
raw_name = node.target.replace('*', '')
497525
if raw_name != node.name:
498526
body.append(f'{node.name} = {raw_name}\n')
499-
continue
527+
return
500528
elif node.op == 'call_method':
501529
assert isinstance(node.target, str)
502530
body.append(
503531
f'{node.name} = {_format_target(repr(node.args[0]), node.target)}'
504532
f'({_format_args(node.args[1:], node.kwargs)})\n')
505-
continue
533+
return
506534
elif node.op == 'call_function':
507535
assert callable(node.target)
508536
# pretty print operators
509537
if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
510538
assert isinstance(node.args, tuple)
511539
body.append(f'{node.name} = {magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}\n')
512-
continue
540+
return
513541
qualified_name = get_qualified_name(node.target)
514542
register_modules_used(qualified_name)
515543
if qualified_name == 'getattr' and \
@@ -518,24 +546,28 @@ def type_repr(o : Any):
518546
node.args[1].isidentifier():
519547
# pretty print attribute access
520548
body.append(f'{node.name} = {_format_target(repr(node.args[0]), node.args[1])}\n')
521-
continue
549+
return
522550
body.append(f'{node.name} = {qualified_name}({_format_args(node.args, node.kwargs)})\n')
523-
continue
551+
return
524552
elif node.op == 'call_module':
525553
assert isinstance(node.target, str)
526554
body.append(f'{node.name} = {_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})\n')
527-
continue
555+
return
528556
elif node.op == 'get_attr':
529557
assert isinstance(node.target, str)
530558
body.append(f'{node.name} = {_format_target(root_module, node.target)}\n')
531-
continue
559+
return
532560
elif node.op == 'output':
533561
if node.type is not None:
534562
maybe_return_annotation = f" -> {type_repr(node.type)}"
535-
body.append(f'return {repr(node.args[0])}')
536-
continue
563+
body.append(f'return {repr(node.args[0])}\n')
564+
return
537565
raise NotImplementedError(f'node: {node.op} {node.target}')
538566

567+
for node in self.nodes:
568+
emit_node(node)
569+
delete_unused_values(node)
570+
539571
# repr() for inf and nan floating point values aren't parseable by
540572
# python as literals. Explicitly import the names from the `math` module.
541573
import_strs = [f'import {name}' for name in sorted(modules_used)]

0 commit comments

Comments
 (0)