torch.fx.replace_pattern doesn't work with untraceable wrapped functions #66197

@fmassa

🐛 Bug

torch.fx.replace_pattern doesn't seem to work when the pattern is an untraceable function that has been wrapped with torch.fx.wrap.

It looks like replace_pattern does not pass wrapped functions through. When the pattern is itself a wrapped function, its body is traced fully, as if it had never been wrapped.
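The difference between calling a wrapped function and tracing it directly can be seen with a traceable stand-in. This is a minimal sketch: `helper`, `caller` and `call_targets` are hypothetical names, and `helper` is given a traceable body so that both traces succeed. Tracing `caller` records `helper` as one leaf `call_function` node, while tracing `helper` itself as the root runs its body and records the `+` as `operator.add`. The traceback below shows replace_pattern doing the latter with `symbolic_trace(pattern)`.

```python
import operator

import torch
import torch.fx


# Hypothetical stand-in for my_func: wrapped the same way, but with a
# traceable body so that tracing it directly also succeeds.
@torch.fx.wrap
def helper(x):
    return x + 1


# Hypothetical caller that invokes the wrapped function.
def caller(x):
    return helper(x)


def call_targets(gm):
    """Return the targets of all call_function nodes in a GraphModule."""
    return [n.target for n in gm.graph.nodes if n.op == "call_function"]


# helper is called from inside the traced function, so FX records a
# single leaf call_function node whose target is helper itself.
from_caller = call_targets(torch.fx.symbolic_trace(caller))

# helper is the root being traced, so FX runs its body and records the
# `+` as operator.add; torch.fx.wrap has no effect in this case.
as_root = call_targets(torch.fx.symbolic_trace(helper))

print(from_caller == [helper], as_root == [operator.add])
```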

I'm running PyTorch 1.11.0.dev20210930 (a nightly build from last week).

To Reproduce

```python
import torch
import torch.fx

# untraceable function
@torch.fx.wrap
def my_func(x):
    a = torch.empty(x.shape)  # not FX-traceable
    a.bernoulli_(0.5)
    return x


def my_func_2(x):
    return x * 2


# my model
def m(x):
    return x + my_func(x)

# we can properly trace it
mm = torch.fx.symbolic_trace(m)
# and we can also run it properly
mm(torch.rand(2))

print(mm.code)

# but this fails
torch.fx.replace_pattern(mm, my_func, my_func_2)

print("After replacement")
print(mm.code)
```

This is the output of the snippet, which fails at the replace_pattern call:

```python
torch.fx._symbolic_trace.wrap("__main___my_func")

def forward(self, x):
    my_func = __main___my_func(x)
    add = x + my_func;  x = my_func = None
    return add
```

```
Traceback (most recent call last):
  File "testing.py", line 25, in <module>
    torch.fx.replace_pattern(mm, my_func, identity)
  File "/Users/fmassa/anaconda3/lib/python3.8/site-packages/torch/fx/subgraph_rewriter.py", line 251, in replace_pattern
    pattern_graph = symbolic_trace(pattern).graph
  File "/Users/fmassa/anaconda3/lib/python3.8/site-packages/torch/fx/_symbolic_trace.py", line 907, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
  File "/Users/fmassa/anaconda3/lib/python3.8/site-packages/torch/fx/_symbolic_trace.py", line 615, in trace
    self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
  File "testing.py", line 7, in my_func
    a = torch.empty(x.shape)
TypeError: empty(): argument 'size' (position 1) must be tuple of ints, not Attribute
```

This is the same type of error we would get if we tried to symbolically trace a model that calls my_func without wrapping my_func first.
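A possible workaround, sketched here under the assumption that the subgraph matcher compares `call_function` targets by identity, is to pass replace_pattern a small function that *calls* my_func instead of my_func itself. Tracing such a caller records my_func as a single leaf node, which can then match the my_func node already in `mm`. The names `pattern` and `replacement` are hypothetical helpers, not part of the FX API.

```python
import torch
import torch.fx


# Same untraceable function as in the report above
@torch.fx.wrap
def my_func(x):
    a = torch.empty(x.shape)  # not FX-traceable
    a.bernoulli_(0.5)
    return x


def my_func_2(x):
    return x * 2


def m(x):
    return x + my_func(x)


# Hypothetical pattern/replacement wrappers: because my_func is *called*
# from inside `pattern`, tracing `pattern` records it as one leaf
# call_function node instead of tracing through its body.
def pattern(x):
    return my_func(x)


def replacement(x):
    return my_func_2(x)


mm = torch.fx.symbolic_trace(m)
# replace_pattern returns the list of matches it rewrote
matches = torch.fx.replace_pattern(mm, pattern, replacement)

print(len(matches))
print(mm.code)
```

Keeping the pattern a one-line caller means the only nodes FX records for it are the placeholder and the leaf call to my_func, so the match is exactly the node the report wants to replace.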

Originally reported by @datumbox

cc @ezyang @SherlockNoMad

Labels: module: fx, triaged
