🐛 Describe the bug
This means that the compiled autograd state getter (in this case, the check for whether compiled autograd is enabled) must be traceable by Dynamo. That becomes a problem once we want to support multithreading and make the state thread-local, because the fallback is a call into a C++ binding to retrieve the TLS, which causes a graph break: #137821
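For context, here is a hypothetical sketch of what a thread-local variant of the state could look like; the binding name is made up, and the only point is that the fallback path leaves pure Python and therefore graph breaks under Dynamo:

```python
# Hypothetical sketch (not PyTorch's actual implementation): thread-local
# compiled autograd state whose slow path goes through a C++ binding.
import threading
import torch

_tls = threading.local()

def compiled_autograd_enabled_tls() -> bool:
    # Fast path: the flag was already copied into this thread's TLS.
    cached = getattr(_tls, "enabled", None)
    if cached is not None:
        return cached
    # Slow path: query native thread-local state through a C extension call
    # (the name below is made up). Dynamo cannot trace into the extension,
    # so hitting this branch inside torch.compile'd code graph breaks.
    return torch._C._get_compiled_autograd_tls()  # hypothetical binding
```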
Repro: make the getter graph break, e.g.

# torch/distributed/_composable/fsdp/_fsdp_common.py
def compiled_autograd_enabled():
    if torch.compiler.is_compiling():
        import torch._dynamo.compiled_autograd as ca
        print("graph break")  # add a graph break
        return ca.compiled_autograd_enabled or ca.in_compiled_autograd_region
    else:
        return False

then run:

python test/distributed/_composable/fsdp/test_fully_shard_compile.py TestFullyShardCompile.test_simple_mlp_fullgraph_backend_aot_eager
There are a few options:
- cache the check outside of the compiled region
- make FSDP2 threads always copy the TLS
Caching the check is a good idea: looking up the live value can be expensive, and it can lead to subtle errors if different parts of FSDP2 observe different state.
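A minimal sketch of the caching option, assuming a module-level cache and a refresh hook called from eager code (names like `refresh_compiled_autograd_cache` are illustrative, not existing APIs):

```python
# Sketch of caching the compiled autograd check outside the compiled region:
# the cache is populated in eager code, and compiled code only reads a bool.
import torch._dynamo.compiled_autograd as ca

_cached_ca_enabled: bool = False

def refresh_compiled_autograd_cache() -> None:
    """Called from eager code (e.g. when FSDP2 prepares backward), never from
    inside a compiled region, so any graph-breaking lookup stays out of tracing."""
    global _cached_ca_enabled
    _cached_ca_enabled = ca.compiled_autograd_enabled or ca.in_compiled_autograd_region

def compiled_autograd_enabled() -> bool:
    # Dynamo only reads a module-level Python bool here; no C++ TLS binding,
    # so no graph break, and every part of FSDP2 sees the same cached state.
    return _cached_ca_enabled
```

The main design question with this approach is when to refresh the cache so it cannot go stale relative to the live compiled autograd state.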
Versions
main
cc @zhaojuanmao @mrshenli @rohan-varma @awgu @fegin @kwen2501 @chauhang @ezyang @penguinwu @yf225