Closed
Labels: module: fx, module: fx.passes, triaged
Description
🐛 Describe the bug
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.split_module import split_module
import copy

def foo(x):
    x.add_(1)
    return None

g = make_fx(foo, tracing_mode="fake")(torch.randn(3,))
g.print_readable()
copy.deepcopy(g)  # This works

def cb(node):
    return 1

# sp_gm contains a sub-graph with no output node.
sp_gm = split_module(g, None, cb)
sp_gm.print_readable()
copy.deepcopy(sp_gm)  # This fails

Error
Traceback (most recent call last):
File "/opt/pytorch/lightning-thunder/test.py", line 20, in <module>
copy.deepcopy(sp_gm) # This fails
^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 143, in deepcopy
y = copier(memo)
^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 883, in __deepcopy__
fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 143, in deepcopy
y = copier(memo)
^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 883, in __deepcopy__
fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 143, in deepcopy
y = copier(memo)
^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph.py", line 976, in __deepcopy__
assert isinstance(output_vals, tuple)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError

Reason -
The error occurs here, in torch/fx/graph.py (lines 974 to 976 in 20af56d):

output_vals = g.graph_copy(self, val_map=memo, return_output_node=True)
g._codegen = copy.deepcopy(self._codegen)
assert isinstance(output_vals, tuple)

This happens because split_module produces a subgraph without an "output" node, so graph_copy falls through its loop and returns None instead of a tuple.
Graph Copy Impl (lines 955 to 962 in 20af56d):

for node in g.nodes:
    if node in val_map:
        continue
    if node.op == 'output':
        rv = map_arg(node.args[0], lambda n: val_map[n])
        return rv if not return_output_node else (rv, node)
    val_map[node] = self.node_copy(node, lambda n : val_map[n])
return None

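To make the fall-through concrete without needing torch installed, here is a toy model of the same control flow. `ToyNode` and `toy_graph_copy` are hypothetical stand-ins invented for illustration, not the real torch.fx classes. They only mirror the shape of the loop quoted above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyNode:
    """Hypothetical stand-in for torch.fx.Node: only an op kind and a name."""
    op: str
    name: str


def toy_graph_copy(nodes, return_output_node=False):
    """Mirror the control flow of the graph_copy loop (not the real API)."""
    val_map = {}
    for node in nodes:
        if node.op == "output":
            rv = node.name
            return rv if not return_output_node else (rv, node)
        val_map[node] = node.name  # stands in for node_copy
    # No 'output' node was seen, so the loop falls through to here.
    return None


with_output = [
    ToyNode("placeholder", "x"),
    ToyNode("call_function", "add_"),
    ToyNode("output", "output"),
]
without_output = with_output[:-1]  # same graph, minus the output node

result_with = toy_graph_copy(with_output, return_output_node=True)
result_without = toy_graph_copy(without_output, return_output_node=True)

print(isinstance(result_with, tuple))  # a (value, node) pair
print(result_without)                  # nothing was returned from the loop
```

With an output node the caller gets a tuple. Without one it gets None, which is the value that trips the `isinstance(output_vals, tuple)` assert.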
Solution -
I can think of two fixes. Either we update the assert above to accept a tuple or None, so that graphs without an "output" node can be deep-copied.
Or, if the invariant is that every FX graph has an output node, we update split_module so that each subgraph it produces has one.
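A rough sketch of what each option would check, written against duck-typed toy objects rather than real torch.fx graphs. Both helper names (`unpack_copy_result`, `has_output_node`) are hypothetical and are not part of PyTorch. The only assumption about the real API is that graph nodes expose an `op` attribute, as the graph_copy snippet above shows.

```python
from types import SimpleNamespace


def unpack_copy_result(output_vals):
    """Option 1 (hypothetical): accept None as well as a (value, node) tuple."""
    assert output_vals is None or isinstance(output_vals, tuple)
    if output_vals is None:
        # The copied graph had no output node, so there is nothing to re-emit.
        return None, None
    return output_vals


def has_output_node(graph):
    """Option 2 (hypothetical check): does any node in the graph have op == 'output'?"""
    return any(node.op == "output" for node in graph.nodes)


# Toy stand-ins for graphs: any object with a `nodes` iterable works here.
out_node = SimpleNamespace(op="output")
good_graph = SimpleNamespace(nodes=[SimpleNamespace(op="placeholder"), out_node])
bad_graph = SimpleNamespace(nodes=[SimpleNamespace(op="placeholder")])

print(has_output_node(good_graph))
print(has_output_node(bad_graph))
print(unpack_copy_result(None))
```

The first helper tolerates the missing node at copy time. The second could be used as a validation step after splitting, if the output-node invariant is meant to hold everywhere.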
Versions
main - f415855