
[fx] Unable to deepcopy a graph produced by fx's split_module #138207

@kshitij12345


🐛 Describe the bug

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.split_module import split_module
import copy

def foo(x):
    x.add_(1)
    return None

g = make_fx(foo, tracing_mode="fake")(torch.randn(3,))
g.print_readable()
copy.deepcopy(g)  # This works

def cb(node):
    return 1

# sp_gm contains a submodule whose graph has no output node.
sp_gm = split_module(g, None, cb)
sp_gm.print_readable()
copy.deepcopy(sp_gm)  # This fails
```
Error

```
Traceback (most recent call last):
  File "/opt/pytorch/lightning-thunder/test.py", line 20, in <module>
    copy.deepcopy(sp_gm)  # This fails
    ^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 143, in deepcopy
    y = copier(memo)
        ^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 883, in __deepcopy__
    fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo))
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 143, in deepcopy
    y = copier(memo)
        ^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 883, in __deepcopy__
    fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo))
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 143, in deepcopy
    y = copier(memo)
        ^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph.py", line 976, in __deepcopy__
    assert isinstance(output_vals, tuple)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
```

Reason

The error is raised here:

pytorch/torch/fx/graph.py

Lines 974 to 976 in 20af56d

```python
output_vals = g.graph_copy(self, val_map=memo, return_output_node=True)
g._codegen = copy.deepcopy(self._codegen)
assert isinstance(output_vals, tuple)
```

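This happens because `split_module` produces a submodule whose graph has no `output` node. `graph_copy` only returns early when it reaches an `output` node, so for such a graph it falls through to `return None`, and the `isinstance(output_vals, tuple)` assertion fails.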

Implementation of `graph_copy`

pytorch/torch/fx/graph.py

Lines 955 to 962 in 20af56d

```python
for node in g.nodes:
    if node in val_map:
        continue
    if node.op == 'output':
        rv = map_arg(node.args[0], lambda n: val_map[n])
        return rv if not return_output_node else (rv, node)
    val_map[node] = self.node_copy(node, lambda n : val_map[n])
return None
```
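To see the `None` return in isolation, here is a minimal sketch that builds a `torch.fx.Graph` by hand, with and without an `output` node, and passes it to `graph_copy` with `return_output_node=True`. The graph contents (a single `torch.add`) and the helper name `build_graph` are arbitrary choices for illustration; the only behavior relied on is the loop quoted above.

```python
import torch
import torch.fx


def build_graph(with_output: bool) -> torch.fx.Graph:
    """Build a tiny graph x -> torch.add(x, 1), optionally ending in an output node."""
    graph = torch.fx.Graph()
    x = graph.placeholder("x")
    y = graph.call_function(torch.add, (x, 1))
    if with_output:
        graph.output(y)
    return graph


for with_output in (True, False):
    src = build_graph(with_output)
    dst = torch.fx.Graph()
    # graph_copy walks src.nodes and only returns early at an 'output' node;
    # otherwise it falls through to `return None`.
    result = dst.graph_copy(src, val_map={}, return_output_node=True)
    print(f"with_output={with_output}: {type(result).__name__}")
```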

Solution

I can think of two fixes. The first is to relax the assertion so that it accepts either a tuple or `None`, which would let graphs without an `output` node be deep-copied.
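A minimal sketch of what the relaxed check could look like, written as a standalone helper around the public `Graph.graph_copy` and `Graph.output` methods rather than as a patch to `Graph.__deepcopy__`. The helper name `copy_graph_allow_missing_output` is hypothetical, and unlike the real `__deepcopy__` (which, per the excerpt above, also deep-copies `_codegen`), it copies only the nodes and the output value.

```python
import torch
import torch.fx


def copy_graph_allow_missing_output(graph: torch.fx.Graph) -> torch.fx.Graph:
    """Copy `graph`, tolerating a missing 'output' node (hypothetical sketch of fix 1)."""
    new_graph = torch.fx.Graph()
    result = new_graph.graph_copy(graph, val_map={}, return_output_node=True)
    # Relaxed check: a tuple means an output node was found, None means there was none.
    assert result is None or isinstance(result, tuple)
    if result is not None:
        output_val, old_output_node = result
        # Recreate the output node, carrying over its type annotation.
        new_graph.output(output_val, type_expr=old_output_node.type)
    return new_graph


# A graph with no output node, like the submodule split_module produced in the repro.
graph = torch.fx.Graph()
x = graph.placeholder("x")
graph.call_function(torch.add, (x, 1))

copied = copy_graph_allow_missing_output(graph)
print([node.op for node in copied.nodes])
```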

The second applies if the invariant is that every FX graph must have an `output` node: in that case, `split_module` should be updated to always emit one.

Versions

main - f415855

cc @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv

Metadata

Assignees: no one assigned
Labels: module: fx, module: fx.passes, triaged
Type: none
Projects: none
Milestone: none
Relationships: none yet
Development: no branches or pull requests
