Optimize nn.Module __call__ fast path for dynamo#95931
Optimize nn.Module __call__ fast path for dynamo#95931wconstab wants to merge 9 commits intogh/wconstab/119/basefrom
Conversation
Idea 1: Set a permanent flag once any hook is registered, so dynamo
only has to guard on the value of this flag.
+ pretty easy to do and gets the perf back for folks not using hooks
- no way to recover perf if you install and then remove a hook
Idea 2: maintain the same type of flag, but keep it up to date whenever
removing hooks.
- difficult to update on hook removal, since 'RemovableHandle' maintains
no reference to the nnModule (to twiddle its flag), nor does it know
whether the hooks dict it owns is local or global
--> could we extend RemovableHandle to know this missing info, or would
extra refs to nnModule be a problem for some reason?
Any other ideas?
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/95931
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 FailuresAs of commit 23a8968: NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Idea 1: Set a permanent flag once any hook is registered, so dynamo
only has to guard on the value of this flag.
+ pretty easy to do and gets the perf back for folks not using hooks
- no way to recover perf if you install and then remove a hook
Idea 2: maintain the same type of flag, but keep it up to date whenever
removing hooks.
- difficult to update on hook removal, since 'RemovableHandle' maintains
no reference to the nnModule (to twiddle its flag), nor does it know
whether the hooks dict it owns is local or global
--> could we extend RemovableHandle to know this missing info, or would
extra refs to nnModule be a problem for some reason?
Any other ideas?
[ghstack-poisoned]
Idea 1: Set a permanent flag once any hook is registered, so dynamo
only has to guard on the value of this flag.
+ pretty easy to do and gets the perf back for folks not using hooks
- no way to recover perf if you install and then remove a hook
Idea 2: maintain the same type of flag, but keep it up to date whenever
removing hooks.
- difficult to update on hook removal, since 'RemovableHandle' maintains
no reference to the nnModule (to twiddle its flag), nor does it know
whether the hooks dict it owns is local or global
--> could we extend RemovableHandle to know this missing info, or would
extra refs to nnModule be a problem for some reason?
Any other ideas?
[ghstack-poisoned]
Idea 1: Set a permanent flag once any hook is registered, so dynamo
only has to guard on the value of this flag.
+ pretty easy to do and gets the perf back for folks not using hooks
- no way to recover perf if you install and then remove a hook
Idea 2: maintain the same type of flag, but keep it up to date whenever
removing hooks.
- difficult to update on hook removal, since 'RemovableHandle' maintains
no reference to the nnModule (to twiddle its flag), nor does it know
whether the hooks dict it owns is local or global
--> could we extend RemovableHandle to know this missing info, or would
extra refs to nnModule be a problem for some reason?
Any other ideas?
[ghstack-poisoned]
|
Hi Will, can you please do a rebase? I just changed the runner type. Going forward linux.gcp.a100.large instances would be used instead of linux.gcp.a100 which is only used by the inductor job. |
Idea 1: Set a permanent flag once any hook is registered, so dynamo
only has to guard on the value of this flag.
+ pretty easy to do and gets the perf back for folks not using hooks
- no way to recover perf if you install and then remove a hook
Idea 2: maintain the same type of flag, but keep it up to date whenever
removing hooks.
- difficult to update on hook removal, since 'RemovableHandle' maintains
no reference to the nnModule (to twiddle its flag), nor does it know
whether the hooks dict it owns is local or global
--> could we extend RemovableHandle to know this missing info, or would
extra refs to nnModule be a problem for some reason?
Any other ideas?
[ghstack-poisoned]
This PR attempts just to fix the guards overhead introduced by dynamo tracing hooks.
Is this a crazy thing to do? I almost gave up and then I realized it might actually work fine. Trial by CI at this point...
It can and maybe should be followed by a wider change proposed by voznesenskym to optimize specialized nnmodules by 'observing' any user mutations and directly invalidating the root guard, obviating the need to install other nnmodule guards. (But this observer change seems more involved...)
Idea 1: Set a permanent flag once any hook is registered, so dynamo
only has to guard on the value of this flag.
(+) pretty easy to do and gets the perf back for folks not using hooks
(-) no way to recover perf if you install and then remove a hook
Idea 2: maintain the same type of flag, but keep it up to date whenever
removing hooks.
(-) difficult to update on hook removal, since 'RemovableHandle' maintains
no reference to the nnModule (to twiddle its flag), nor does it know
whether the hooks dict it owns is local or global
--> could we extend RemovableHandle to know this missing info, or would
extra refs to nnModule be a problem for some reason?
Any other ideas?
[ghstack-poisoned]
This PR attempts just to fix the guards overhead introduced by dynamo tracing hooks.
Is this a crazy thing to do? I almost gave up and then I realized it might actually work fine. Trial by CI at this point...
It can and maybe should be followed by a wider change proposed by voznesenskym to optimize specialized nnmodules by 'observing' any user mutations and directly invalidating the root guard, obviating the need to install other nnmodule guards. (But this observer change seems more involved...)
Idea 1: Set a permanent flag once any hook is registered, so dynamo
only has to guard on the value of this flag.
(+) pretty easy to do and gets the perf back for folks not using hooks
(-) no way to recover perf if you install and then remove a hook
Idea 2: maintain the same type of flag, but keep it up to date whenever
removing hooks.
(-) difficult to update on hook removal, since 'RemovableHandle' maintains
no reference to the nnModule (to twiddle its flag), nor does it know
whether the hooks dict it owns is local or global
--> could we extend RemovableHandle to know this missing info, or would
extra refs to nnModule be a problem for some reason?
Any other ideas?
[ghstack-poisoned]
Idea 1: Set a permanent flag once any hook is registered, so dynamo
only has to guard on the value of this flag.
+ pretty easy to do and gets the perf back for folks not using hooks
- no way to recover perf if you install and then remove a hook
Idea 2: maintain the same type of flag, but keep it up to date whenever
removing hooks.
- difficult to update on hook removal, since 'RemovableHandle' maintains
no reference to the nnModule (to twiddle its flag), nor does it know
whether the hooks dict it owns is local or global
--> could we extend RemovableHandle to know this missing info, or would
extra refs to nnModule be a problem for some reason?
Any other ideas?
ghstack-source-id: 3f7f70b
Pull Request resolved: #95931
|
|
||
| """ | ||
| handle = hooks.RemovableHandle(self._backward_pre_hooks) | ||
| handle = hooks.RemovableHandle(self._backward_pre_hooks, module=self) |
There was a problem hiding this comment.
What's the module needed for?
There was a problem hiding this comment.
it's for calling module._update_has_hooks when RemovableHandle removes a hook from a module
| else weakref.ref(OrderedDict() if state[2] is None else state[2]) | ||
| ) | ||
| # TODO can we actually restore module_ref after unpickling? Do we care? | ||
| self.module_ref = None |
There was a problem hiding this comment.
The handles are always dead post pickle, it looks like, so None here looks correct
ezyang
left a comment
There was a problem hiding this comment.
Assuming the perf is good, this LGTM
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
|
The merge job was canceled. If you believe this is a mistake,then you can re trigger it through pytorch-bot. |
|
@pytorchbot merge -f"Flaky ci (macos timeout)" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
|
My proposal is to also allow the user set/express no-hooks-will-be-added flag on a module. This should enable more aggressive inplace optimization. |
This PR optimizes the guards overhead introduced by dynamo tracing module forward hooks. It can and maybe should be followed by a wider change proposed by @voznesenskym to optimize specialized nnmodules by 'observing' any user mutations and directly invalidating the root guard, obviating the need to install other nnmodule guards. (But this observer change seems more involved...) Idea: maintain a flag, and keep it up to date whenever adding or removing hooks. Use the flag rather than dict checks to enter the call fast path. - need to extend RemovableHandle to keep a ref to nnModule so it can update the flag on removal. - also need to handle the flag in ScriptModule which still uses the python call impl when called from python. Pull Request resolved: pytorch/pytorch#95931 Approved by: https://github.com/ezyang, https://github.com/voznesenskym
can you say more about what optimization this 'no-hooks' flag would enable? I believe if hooks are not used, the only impact their infra has is that dynamo adds guards that ensure no hooks are present. This should still allow any optimizations downstream to assume there will be no hooks. My PR basically reduces the number of guards for the no-hook case, but more importantly, it avoids any 'type check' guards which are slower to run, and only uses 'bool' constant guards which are fast. |
|
@wconstab There's a long discussion in #23756 and in #23756 (comment) Basically, if we promise that no hooks will be called, PyTorch can safely fuse and do inplace optimizations, especially with invertible modules |
| if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks | ||
| or _global_backward_pre_hooks or _global_backward_hooks | ||
| or _global_forward_hooks or _global_forward_pre_hooks): | ||
| # this function, and just call forward. It's slow for dynamo to guard on the state |
There was a problem hiding this comment.
Do you have benchmark for that? I'm really curious how checking two bools is significantly faster than checking if dicts are empty?
Do you really care about 10ns here??
In [1]: a, b, c, d = {}, {}, {}, {}
In [2]: e, f = False, False
In [3]: %timeit a or b or c or d
22.9 ns ± 0.202 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
In [4]: %timeit e or f
11.1 ns ± 0.0778 ns per loop (mean ± std. dev. of 7 runs, 100,000,000 loops each)
There was a problem hiding this comment.
yes, i ran benchmarks. the issue really came to dynamo having to run additional check_type_id() on each dict in addition to confirming each dict is empty.
the dynamo guard for len(dict) is somewhat slow, but the dynamo guard for bool(dict) is fast.
There was a problem hiding this comment.
But don't you have to also call check_type_id for boolean fields?
the dynamo guard for len(dict) is somewhat slow, but the dynamo guard for bool(dict) is fast.
So would moving dynamo to properly use the bool(dict) guard solve this?
| # this function, and just call forward. It's slow for dynamo to guard on the state | ||
| # of all these hook dicts individually, so instead it can guard on 2 bools and we just | ||
| # have to promise to keep them up to date when hooks are added or removed via official means. | ||
| if not self._has_hooks and not _has_global_hooks: |
There was a problem hiding this comment.
It's not ok to assume that _has_hooks is available as some people pickle their Module directly. You need to update the __setstate__ function to properly populate this field if it is not in the serialized state.
There was a problem hiding this comment.
ok, good point. i believe i should add a test case for pickle/unpickle and ensure that setstate correctly enforces that _update_has_hooks gets installed and run.
would that address your concern?
There was a problem hiding this comment.
fixing setstate will work yes.
Unfortunately testing this is challenging indeed :/
| self.id = RemovableHandle.next_id | ||
|
|
||
| # TODO: we don't pickle/unpickle this field, which means the 'update_has_hooks' | ||
| # functionality (which is an optimization) decays after pickling. Can we fix this? |
There was a problem hiding this comment.
It is NOT an optimization, you rely on it to be correct to have the right behavior! That sounds bad?
There was a problem hiding this comment.
Here's why i thought of it as an optimization (maybe i missed a case?)
- you save a TS module that has hooks and you pickle a RemovableHandle associated with a hook
- load the scriptmodule and the handle
3a) run the module: the module correctly recomputes the value of '_has_hooks' on load
3b) remove the hook using the unpickled handle and then run the module: unclear if remove even works by itself, after save/load? i doubt it. But regardless, if the module still thinks it '_has_hooks' and the hook is removed, all that happens is the fast path will be skipped and the slow path will check the hooks dicts
There was a problem hiding this comment.
I mean that it is not an optimization because if this field become stale, you will not run user hooks leading to silent correctness issues.
So we must have this field to be 100% accurate all the time.
| if len(state) < 3 | ||
| else weakref.ref(OrderedDict() if state[2] is None else state[2]) | ||
| ) | ||
| # TODO can we actually restore module_ref after unpickling? Do we care? |
There was a problem hiding this comment.
Please do not merge PRs with random TODOs in utils or nn. This code is exercised in many ways compared to dynamo code and this kind of things really matter.
There was a problem hiding this comment.
Yes see comments below. There are clear BC issues. I think there is some mis-categorization of this as an "optimization" (implying that it's fine if we're not doing a perfect job) even though it is actually critical for correctness (and so we have to do a perfect job here).
| module = self.module_ref() | ||
| if module is not None: | ||
| module._update_has_hooks() | ||
| torch.nn.modules.module._update_has_global_hooks() |
There was a problem hiding this comment.
@albanD this is where I call _update_has_global_hooks()
Can you open up your spreadsheet? :). |
This reverts commit 2604533. [ghstack-poisoned]
…96242) Reverting due to concerns over silent unsoundness (skipped hooks) if users have directly added hooks dicts without using official torch APIs. This reverts commit 2604533. Pull Request resolved: #96242 Approved by: https://github.com/albanD
This PR optimizes the guards overhead introduced by dynamo tracing module forward hooks. It can and maybe should be followed by a wider change proposed by @voznesenskym to optimize specialized nnmodules by 'observing' any user mutations and directly invalidating the root guard, obviating the need to install other nnmodule guards. (But this observer change seems more involved...) Idea: maintain a flag, and keep it up to date whenever adding or removing hooks. Use the flag rather than dict checks to enter the call fast path. - need to extend RemovableHandle to keep a ref to nnModule so it can update the flag on removal. - also need to handle the flag in ScriptModule which still uses the python call impl when called from python. Pull Request resolved: pytorch#95931 Approved by: https://github.com/ezyang, https://github.com/voznesenskym
This PR optimizes the guards overhead introduced by dynamo tracing module forward hooks.
It can and maybe should be followed by a wider change proposed by @voznesenskym to optimize specialized nnmodules by 'observing' any user mutations and directly invalidating the root guard, obviating the need to install other nnmodule guards. (But this observer change seems more involved...)
Stack from ghstack (oldest at bottom):
Idea: maintain a flag, and keep it up to date whenever adding or
removing hooks. Use the flag rather than dict checks to enter the call fast path.
cc @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire