torch.fx.replace_pattern doesn't work with untraceable wrapped functions #66197
Open
Labels
module: fx, triaged
Description
🐛 Bug
torch.fx.replace_pattern fails when the pattern is an untraceable function that has been wrapped with torch.fx.wrap.
It looks like replace_pattern does not pass the wrapped function through as an opaque call; it symbolically traces the function's body instead.
I'm running on PyTorch 1.11.0.dev20210930 (from last week)
To Reproduce
import torch
import torch.fx
# untraceable function
@torch.fx.wrap
def my_func(x):
    a = torch.empty(x.shape)  # not FX-traceable
    a.bernoulli_(0.5)
    return x

def my_func_2(x):
    return x * 2

# my model
def m(x):
    return x + my_func(x)
# we can properly trace it
mm = torch.fx.symbolic_trace(m)
# and we can also run it properly
mm(torch.rand(2))
print(mm.code)
# but this fails
torch.fx.replace_pattern(mm, my_func, my_func_2)
print("After replacement")
print(mm.code)

This is the output of the snippet, which fails at replace_pattern:
torch.fx._symbolic_trace.wrap("__main___my_func")

def forward(self, x):
    my_func = __main___my_func(x)
    add = x + my_func; x = my_func = None
    return add

Traceback (most recent call last):
  File "testing.py", line 25, in <module>
    torch.fx.replace_pattern(mm, my_func, identity)
  File "/Users/fmassa/anaconda3/lib/python3.8/site-packages/torch/fx/subgraph_rewriter.py", line 251, in replace_pattern
    pattern_graph = symbolic_trace(pattern).graph
  File "/Users/fmassa/anaconda3/lib/python3.8/site-packages/torch/fx/_symbolic_trace.py", line 907, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
  File "/Users/fmassa/anaconda3/lib/python3.8/site-packages/torch/fx/_symbolic_trace.py", line 615, in trace
    self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
  File "testing.py", line 7, in my_func
    a = torch.empty(x.shape)
TypeError: empty(): argument 'size' (position 1) must be tuple of ints, not Attribute

This is the same type of error we would get if we tried to symbolically trace my_func without wrapping it first.
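A possible workaround, sketched here as an assumption rather than a confirmed fix: since the traceback shows replace_pattern calling symbolic_trace on the pattern itself, passing my_func directly makes it the root of the trace, and its body gets executed with proxies. Calling my_func from inside a thin pattern function should let the torch.fx.wrap patching intercept the call and record it as a single node. The names `pattern` and `replacement` below are hypothetical helpers, not part of the original snippet.

```python
import torch
import torch.fx


@torch.fx.wrap
def my_func(x):
    # Same untraceable function as in the report.
    a = torch.empty(x.shape)
    a.bernoulli_(0.5)
    return x


def my_func_2(x):
    return x * 2


def m(x):
    return x + my_func(x)


# Hypothetical helper: my_func is looked up as a global inside the traced
# function, so the wrap patching is expected to record it as one opaque
# call_function node instead of tracing into its body.
def pattern(x):
    return my_func(x)


# Hypothetical helper: my_func_2 is not wrapped, so it is traced through.
def replacement(x):
    return my_func_2(x)


mm = torch.fx.symbolic_trace(m)
matches = torch.fx.replace_pattern(mm, pattern, replacement)
print("After replacement")
print(mm.code)
```

If this behaves as assumed, the call to my_func in mm is replaced by the multiplication from my_func_2, so mm(x) computes x + x * 2. This only sidesteps the problem; passing a wrapped function directly as the pattern still fails as described above.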
Originally reported by @datumbox