6 changes: 2 additions & 4 deletions test/distributed/_composable/fsdp/test_fully_shard_compile.py
@@ -423,6 +423,8 @@ def run_iters(
torch.manual_seed(42)
losses = []
for i in range(n_iter):
# Eager warmup for 1 iteration, so that all FSDP2 lazy initialization is done in eager mode
torch.compiler.set_stance("force_eager" if i < 1 else "default")
inp = input_creation_fn()
loss = fwd_bwd_func(inp)
losses.append(loss.item())
@@ -433,8 +435,6 @@ def test_compiled():
def test_compiled():
model, optim = model_init_fn()
fwd_bwd_fn = functools.partial(fwd_bwd, model)
# FSDP2 does lazy init using 1st run, so run it once to init using eager mode
run_iters(fwd_bwd_fn, optim, n_iter=1)

counters.clear()
with self._remove_fsdp2_unsharded_param_graph_input_usage_with_optional_checks(
@@ -463,8 +463,6 @@ def test_compiled():
def test_eager():
model, optim = model_init_fn()
fwd_bwd_fn = functools.partial(fwd_bwd, model)
# FSDP2 does lazy init using 1st run, so run it once to init using eager mode
run_iters(fwd_bwd_fn, optim, n_iter=1)

res = run_iters(fwd_bwd_fn, optim)
return res
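A minimal sketch of the warmup pattern the updated test relies on (not code from this PR): `torch.compiler.set_stance("force_eager")` keeps the first iteration in eager so FSDP2-style lazy initialization finishes before any compiled iteration runs. The model, optimizer, and step function below are illustrative stand-ins.

```python
import torch

# Illustrative toy model/optimizer; not the FSDP2 test model from this PR.
model = torch.nn.Linear(8, 8)
optim = torch.optim.SGD(model.parameters(), lr=0.01)

@torch.compile
def train_step(inp):
    loss = model(inp).sum()
    loss.backward()
    return loss

for i in range(3):
    # Iteration 0 runs eagerly ("force_eager"); later iterations use the
    # default stance, so train_step is actually compiled from iteration 1 on.
    torch.compiler.set_stance("force_eager" if i < 1 else "default")
    loss = train_step(torch.randn(4, 8))
    optim.step()
    optim.zero_grad()
```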
42 changes: 25 additions & 17 deletions torch/_dynamo/compiled_autograd.py
@@ -523,24 +523,32 @@ def set_node_origin(

@contextlib.contextmanager
def enable(compiler_fn):
# we need to import this, because user might not have imported it if they directly use this context manager
# we need to lazily import it, because of circular dependencies
import torch._inductor.cudagraph_trees
from torch._dynamo import eval_frame

prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(
functools.partial(AutogradCompilerInstance, compiler_fn)
)
if snapshot_verbose_logging_enabled():
torch._C._dynamo.compiled_autograd.set_verbose_logger(cpp_verbose_log_fn)
global compiled_autograd_enabled
compiled_autograd_enabled = True
try:
with torch.autograd.set_multithreading_enabled(False):
yield
finally:
if not prior:
compiled_autograd_enabled = False
torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)
if eval_frame._stance.stance == "force_eager":
# If the user explicitly sets the Dynamo stance to "force_eager", we want Compiled Autograd
# to fall back to eager as well.
yield
return
else:
# we need to import this, because the user might not have imported it if they use this context manager directly
# we also need to import it lazily, because of circular dependencies
import torch._inductor.cudagraph_trees

prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(
functools.partial(AutogradCompilerInstance, compiler_fn)
)
if snapshot_verbose_logging_enabled():
torch._C._dynamo.compiled_autograd.set_verbose_logger(cpp_verbose_log_fn)
global compiled_autograd_enabled
compiled_autograd_enabled = True
try:
with torch.autograd.set_multithreading_enabled(False):
yield
finally:
if not prior:
compiled_autograd_enabled = False
torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)


@contextlib.contextmanager
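A hedged usage sketch of the new early-return path in `enable()`: if the Dynamo stance is "force_eager" when the context manager is entered, no autograd compiler is installed and backward runs as plain eager autograd. The `compiler_fn`, model, and input below are illustrative, not from the PR.

```python
import torch
import torch._dynamo.compiled_autograd as compiled_autograd

def compiler_fn(gm):
    # Typical compiled-autograd compiler_fn; the backend choice here is illustrative.
    return torch.compile(gm, backend="eager")

model = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)

torch.compiler.set_stance("force_eager")  # e.g. during an eager warmup iteration
with compiled_autograd.enable(compiler_fn):
    # With the stance forced to eager, enable() just yields: this backward is
    # ordinary eager autograd rather than a Compiled Autograd graph.
    model(x).sum().backward()
torch.compiler.set_stance("default")
```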