Skip to content

Commit e80fe7f

Browse files
anijain2305pytorchmergebot
authored andcommitted
[dynamo][guards] Skip guards on empty nn module hooks (#138942)
This brings some unsoundness in guards. Earlier we were skipping empty nn module hooks dict guard only on inbuilt nn modules, but as seen in #138386, there could be still be significant guard overhead. With this PR, we reduce the guard eval latency from 420 us to 280 us (1.5x reduction). Pull Request resolved: #138942 Approved by: https://github.com/ezyang, https://github.com/jansel ghstack dependencies: #139040, #138954
1 parent 2aa5348 commit e80fe7f

File tree

3 files changed

+37
-3
lines changed

3 files changed

+37
-3
lines changed

test/dynamo/test_activation_checkpointing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,7 @@ def _factory_fn():
519519
mod_no_hook, backend, x, fullgraph=True, compiled_autograd=True
520520
)
521521

522+
torch._dynamo.reset()
522523
mod_with_hook, x, backend = _factory_fn()
523524
mod_with_hook.submod.register_forward_hook(my_post_forward_hook)
524525
mod_with_hook_fwd_outputs = set()

test/dynamo/test_hooks.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,41 @@ def forward(self, x):
796796

797797
self.assertEqual(cnts.frame_count, 1)
798798

799+
@torch._dynamo.config.patch(skip_nnmodule_hook_guards=False)
800+
def test_nnmodule_hook_guards(self):
801+
# Compile a model and then apply a hook
802+
803+
class Mod(torch.nn.Module):
804+
def __init__(self) -> None:
805+
super().__init__()
806+
self.linear = torch.nn.Linear(16, 16)
807+
808+
def forward(self, x):
809+
return self.linear(x)
810+
811+
cnts = torch._dynamo.testing.CompileCounter()
812+
813+
mod = Mod()
814+
815+
def fn(x):
816+
return mod(x)
817+
818+
opt_fn = torch.compile(fn, backend=cnts)
819+
820+
x = torch.ones(16, 16)
821+
opt_fn(x)
822+
823+
# Register a hook
824+
def forward_hook(self, inputs, out):
825+
return out * 2
826+
827+
mod.register_forward_hook(forward_hook)
828+
829+
ref = fn(x)
830+
res = opt_fn(x)
831+
self.assertEqual(ref, res)
832+
self.assertEqual(cnts.frame_count, 2)
833+
799834

800835
if __name__ == "__main__":
801836
from torch._dynamo.test_case import run_tests

torch/_dynamo/variables/nn_module.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,9 +1080,7 @@ def var_getattr(self, tx: "InstructionTranslator", name):
10801080
):
10811081
# For empty hooks, make an EMPTY_NN_MODULE_HOOKS_DICT. This allows us to control the installation of empty
10821082
# hooks guard via skip_nnmodule_hook_guards
1083-
if not tx.output.side_effects.has_pending_mutation_of_attr(
1084-
self, name
1085-
) and self.value.__module__.startswith(("torch.nn.", "torch.ao.")):
1083+
if not tx.output.side_effects.has_pending_mutation_of_attr(self, name):
10861084
hooks_dict = getattr(self.value, name)
10871085
if isinstance(hooks_dict, dict) and len(hooks_dict) == 0:
10881086
if self.source:

0 commit comments

Comments
 (0)