FSDP2 checks live compiled autograd configs within compiled region #138177

@xmfan

🐛 Describe the bug

FSDP2 reads the live compiled autograd config from inside the compiled region. This means that the compiled autograd state getter (in this case, whether compiled autograd is enabled) must be traceable by Dynamo. This becomes a problem when we want to support multithreading and make the state thread-local: the fallback is to call a C++ binding to retrieve the TLS value, which causes a graph break (#137821).

Repro: make the getter cause a graph break, e.g.:

# torch/distributed/_composable/fsdp/_fsdp_common.py
def compiled_autograd_enabled():
    if torch.compiler.is_compiling():
        import torch._dynamo.compiled_autograd as ca
    
        print("graph break")  # add a graph break
        return ca.compiled_autograd_enabled or ca.in_compiled_autograd_region
    else:
        return False

Then run:

python test/distributed/_composable/fsdp/test_fully_shard_compile.py TestFullyShardCompile.test_simple_mlp_fullgraph_backend_aot_eager

There are a few options:

  • cache the check outside of the compiled region
  • have FSDP2 threads always copy TLS
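To illustrate the second option, here is a minimal sketch of copying a thread-local value into a spawned thread. It uses plain `threading.local` rather than the real compiled autograd state, and `get_enabled`, `set_enabled` and `run_in_thread` are hypothetical names made up for this example, not PyTorch APIs.

```python
import threading

# Hypothetical thread-local holder; stands in for the compiled autograd state.
_tls = threading.local()


def get_enabled() -> bool:
    # Threads that never set the value see the default (False).
    return getattr(_tls, "enabled", False)


def set_enabled(value: bool) -> None:
    _tls.enabled = value


def run_in_thread(fn, copy_tls: bool):
    # Read the value in the spawning thread, before the worker starts.
    parent_value = get_enabled()
    result = []

    def target():
        if copy_tls:
            # Propagate the parent's value into the worker's own TLS slot.
            set_enabled(parent_value)
        result.append(fn())

    t = threading.Thread(target=target)
    t.start()
    t.join()
    return result[0]


set_enabled(True)
without_copy = run_in_thread(get_enabled, copy_tls=False)
with_copy = run_in_thread(get_enabled, copy_tls=True)
```

Without the copy, the worker thread falls back to the default and disagrees with the thread that spawned it, which is the kind of mismatch this option is meant to avoid.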

Caching the check is a good idea: looking up the live value can be expensive, and reading it live can lead to weird errors if different parts of FSDP2 observe different state.
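A rough sketch of the caching idea, in plain Python so it runs without PyTorch. `_LiveState`, `live_compiled_autograd_enabled` and `CachedFlag` are hypothetical stand-ins for illustration only; where the snapshot would actually be taken in FSDP2 is not decided here.

```python
class _LiveState:
    # Hypothetical stand-in for the live compiled autograd config.
    enabled = False


def live_compiled_autograd_enabled() -> bool:
    # In the real code, this lookup is what would need to be Dynamo-traceable.
    return _LiveState.enabled


class CachedFlag:
    """Snapshot the flag once, eagerly, outside the compiled region."""

    def __init__(self) -> None:
        self._value = None

    def refresh(self) -> None:
        # Assumed to be called before entering the compiled region.
        self._value = live_compiled_autograd_enabled()

    def get(self) -> bool:
        if self._value is None:
            raise RuntimeError("refresh() must be called before get()")
        return self._value


flag = CachedFlag()
_LiveState.enabled = True
flag.refresh()
_LiveState.enabled = False  # the live state changes after the snapshot
snapshot = [flag.get(), flag.get()]  # every reader still sees the same value
```

Inside the compiled region, callers would only read the plain Python bool from `get()`, so all of them agree on one value even if the live state changes in between.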

Versions

main

cc @zhaojuanmao @mrshenli @rohan-varma @awgu @fegin @kwen2501 @chauhang @ezyang @penguinwu @yf225

Labels: module: compiled autograd, module: fsdp, oncall: pt2, triaged