Skip to content

Optimize nn.Module __call__ fast path for dynamo#95931

Closed
wconstab wants to merge 9 commits intogh/wconstab/119/basefrom
gh/wconstab/119/head
Closed

Optimize nn.Module __call__ fast path for dynamo#95931
wconstab wants to merge 9 commits intogh/wconstab/119/basefrom
gh/wconstab/119/head

Conversation

@wconstab
Copy link
Contributor

@wconstab wconstab commented Mar 3, 2023

This PR optimizes the guards overhead introduced by dynamo tracing module forward hooks.

It can and maybe should be followed by a wider change proposed by @voznesenskym to optimize specialized nnmodules by 'observing' any user mutations and directly invalidating the root guard, obviating the need to install other nnmodule guards. (But this observer change seems more involved...)

Stack from ghstack (oldest at bottom):

Idea: maintain a flag, and keep it up to date whenever adding or
removing hooks. Use the flag rather than dict checks to enter the call fast path.

  • need to extend RemovableHandle to keep a ref to nnModule so it can update the flag on removal.
  • also need to handle the flag in ScriptModule which still uses the python call impl when called from python.

cc @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire

Idea 1: Set a permanent flag once any hook is registered, so dynamo
only has to guard on the value of this flag.
  + pretty easy to do and gets the perf back for folks not using hooks
  - no way to recover perf if you install and then remove a hook

Idea 2: maintain the same type of flag, but keep it up to date whenever
removing hooks.
  - difficult to update on hook removal, since 'RemovableHandle' maintains
    no reference to the nnModule (to twiddle its flag), nor does it know
    whether the hooks dict it owns is local or global
  --> could we extend RemovableHandle to know this missing info, or would
      extra refs to nnModule be a problem for some reason?

Any other ideas?

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Mar 3, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/95931

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failures

As of commit 23a8968:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Idea 1: Set a permanent flag once any hook is registered, so dynamo
only has to guard on the value of this flag.
  + pretty easy to do and gets the perf back for folks not using hooks
  - no way to recover perf if you install and then remove a hook

Idea 2: maintain the same type of flag, but keep it up to date whenever
removing hooks.
  - difficult to update on hook removal, since 'RemovableHandle' maintains
    no reference to the nnModule (to twiddle its flag), nor does it know
    whether the hooks dict it owns is local or global
  --> could we extend RemovableHandle to know this missing info, or would
      extra refs to nnModule be a problem for some reason?

Any other ideas?

[ghstack-poisoned]
Idea 1: Set a permanent flag once any hook is registered, so dynamo
only has to guard on the value of this flag.
  + pretty easy to do and gets the perf back for folks not using hooks
  - no way to recover perf if you install and then remove a hook

Idea 2: maintain the same type of flag, but keep it up to date whenever
removing hooks.
  - difficult to update on hook removal, since 'RemovableHandle' maintains
    no reference to the nnModule (to twiddle its flag), nor does it know
    whether the hooks dict it owns is local or global
  --> could we extend RemovableHandle to know this missing info, or would
      extra refs to nnModule be a problem for some reason?

Any other ideas?

[ghstack-poisoned]
@wconstab wconstab added topic: not user facing topic category ciflow/inductor-perf-test-nightly Trigger nightly inductor perf tests labels Mar 3, 2023
Idea 1: Set a permanent flag once any hook is registered, so dynamo
only has to guard on the value of this flag.
  + pretty easy to do and gets the perf back for folks not using hooks
  - no way to recover perf if you install and then remove a hook

Idea 2: maintain the same type of flag, but keep it up to date whenever
removing hooks.
  - difficult to update on hook removal, since 'RemovableHandle' maintains
    no reference to the nnModule (to twiddle its flag), nor does it know
    whether the hooks dict it owns is local or global
  --> could we extend RemovableHandle to know this missing info, or would
      extra refs to nnModule be a problem for some reason?

Any other ideas?

[ghstack-poisoned]
@weiwangmeta
Copy link
Contributor

Hi Will, can you please do a rebase? I just changed the runner type. Going forward linux.gcp.a100.large instances would be used instead of linux.gcp.a100 which is only used by the inductor job.

Idea 1: Set a permanent flag once any hook is registered, so dynamo
only has to guard on the value of this flag.
  + pretty easy to do and gets the perf back for folks not using hooks
  - no way to recover perf if you install and then remove a hook

Idea 2: maintain the same type of flag, but keep it up to date whenever
removing hooks.
  - difficult to update on hook removal, since 'RemovableHandle' maintains
    no reference to the nnModule (to twiddle its flag), nor does it know
    whether the hooks dict it owns is local or global
  --> could we extend RemovableHandle to know this missing info, or would
      extra refs to nnModule be a problem for some reason?

Any other ideas?

[ghstack-poisoned]
@wconstab wconstab requested review from ezyang and voznesenskym March 3, 2023 05:07
@wconstab wconstab changed the title Try (badly) to make a dynamo fast path for hooks Make a dynamo fast path for hooks Mar 3, 2023
@wconstab wconstab requested a review from ngimel March 3, 2023 05:08
This PR attempts just to fix the guards overhead introduced by dynamo tracing hooks.

Is this a crazy thing to do?  I almost gave up and then I realized it might actually work fine.  Trial by CI at this point...

It can and maybe should be followed by a wider change proposed by voznesenskym to optimize specialized nnmodules by 'observing' any user mutations and directly invalidating the root guard, obviating the need to install other nnmodule guards.  (But this observer change seems more involved...)


Idea 1: Set a permanent flag once any hook is registered, so dynamo
only has to guard on the value of this flag.
  (+) pretty easy to do and gets the perf back for folks not using hooks
  (-) no way to recover perf if you install and then remove a hook

Idea 2: maintain the same type of flag, but keep it up to date whenever
removing hooks.
  (-) difficult to update on hook removal, since 'RemovableHandle' maintains
    no reference to the nnModule (to twiddle its flag), nor does it know
    whether the hooks dict it owns is local or global
  --> could we extend RemovableHandle to know this missing info, or would
      extra refs to nnModule be a problem for some reason?

Any other ideas?

[ghstack-poisoned]
This PR attempts just to fix the guards overhead introduced by dynamo tracing hooks.

Is this a crazy thing to do?  I almost gave up and then I realized it might actually work fine.  Trial by CI at this point...

It can and maybe should be followed by a wider change proposed by voznesenskym to optimize specialized nnmodules by 'observing' any user mutations and directly invalidating the root guard, obviating the need to install other nnmodule guards.  (But this observer change seems more involved...)


Idea 1: Set a permanent flag once any hook is registered, so dynamo
only has to guard on the value of this flag.
  (+) pretty easy to do and gets the perf back for folks not using hooks
  (-) no way to recover perf if you install and then remove a hook

Idea 2: maintain the same type of flag, but keep it up to date whenever
removing hooks.
  (-) difficult to update on hook removal, since 'RemovableHandle' maintains
    no reference to the nnModule (to twiddle its flag), nor does it know
    whether the hooks dict it owns is local or global
  --> could we extend RemovableHandle to know this missing info, or would
      extra refs to nnModule be a problem for some reason?

Any other ideas?

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Mar 3, 2023
Idea 1: Set a permanent flag once any hook is registered, so dynamo
only has to guard on the value of this flag.
  + pretty easy to do and gets the perf back for folks not using hooks
  - no way to recover perf if you install and then remove a hook

Idea 2: maintain the same type of flag, but keep it up to date whenever
removing hooks.
  - difficult to update on hook removal, since 'RemovableHandle' maintains
    no reference to the nnModule (to twiddle its flag), nor does it know
    whether the hooks dict it owns is local or global
  --> could we extend RemovableHandle to know this missing info, or would
      extra refs to nnModule be a problem for some reason?

Any other ideas?

ghstack-source-id: 3f7f70b
Pull Request resolved: #95931

"""
handle = hooks.RemovableHandle(self._backward_pre_hooks)
handle = hooks.RemovableHandle(self._backward_pre_hooks, module=self)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the module needed for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's for calling module._update_has_hooks when RemovableHandle removes a hook from a module

else weakref.ref(OrderedDict() if state[2] is None else state[2])
)
# TODO can we actually restore module_ref after unpickling? Do we care?
self.module_ref = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The handles are always dead post pickle, it looks like, so None here looks correct

Copy link
Contributor

@ezyang ezyang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming the perf is good, this LGTM

@wconstab
Copy link
Contributor Author

wconstab commented Mar 4, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 4, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

The merge job was canceled. If you believe this is a mistake,then you can re trigger it through pytorch-bot.

@wconstab wconstab changed the title Make a dynamo fast path for hooks Optimize NNModule __call__ fast path for dynamo Mar 4, 2023
@wconstab wconstab changed the title Optimize NNModule __call__ fast path for dynamo Optimize nnModule __call__ fast path for dynamo Mar 4, 2023
@wconstab wconstab changed the title Optimize nnModule __call__ fast path for dynamo Optimize nn.Module __call__ fast path for dynamo Mar 4, 2023
@wconstab
Copy link
Contributor Author

wconstab commented Mar 4, 2023

@pytorchbot merge -f"Flaky ci (macos timeout)"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@vadimkantorov
Copy link
Contributor

My proposal is to also allow the user set/express no-hooks-will-be-added flag on a module. This should enable more aggressive inplace optimization.

cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 5, 2023
This PR optimizes the guards overhead introduced by dynamo tracing module forward hooks.

It can and maybe should be followed by a wider change proposed by @voznesenskym to optimize specialized nnmodules by 'observing' any user mutations and directly invalidating the root guard, obviating the need to install other nnmodule guards.  (But this observer change seems more involved...)

Idea: maintain a flag, and keep it up to date whenever adding or
removing hooks. Use the flag rather than dict checks to enter the call fast path.
  - need to extend RemovableHandle to keep a ref to nnModule so it can update the flag on removal.
  - also need to handle the flag in ScriptModule which still uses the python call impl when called from python.

Pull Request resolved: pytorch/pytorch#95931
Approved by: https://github.com/ezyang, https://github.com/voznesenskym
@wconstab
Copy link
Contributor Author

wconstab commented Mar 6, 2023

My proposal is to also allow the user set/express no-hooks-will-be-added flag on a module. This should enable more aggressive inplace optimization.

can you say more about what optimization this 'no-hooks' flag would enable?

I believe if hooks are not used, the only impact their infra has is that dynamo adds guards that ensure no hooks are present. This should still allow any optimizations downstream to assume there will be no hooks.

My PR basically reduces the number of guards for the no-hook case, but more importantly, it avoids any 'type check' guards which are slower to run, and only uses 'bool' constant guards which are fast.

@vadimkantorov
Copy link
Contributor

@wconstab There's a long discussion in #23756 and in #23756 (comment)

Basically, if we promise that no hooks will be called, PyTorch can safely fuse and do inplace optimizations, especially with invertible modules

if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
or _global_backward_pre_hooks or _global_backward_hooks
or _global_forward_hooks or _global_forward_pre_hooks):
# this function, and just call forward. It's slow for dynamo to guard on the state
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have benchmark for that? I'm really curious how checking two bools is significantly faster than checking if dicts are empty?
Do you really care about 10ns here??

In [1]: a, b, c, d = {}, {}, {}, {}

In [2]: e, f = False, False

In [3]: %timeit a or b or c or d
22.9 ns ± 0.202 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [4]: %timeit e or f
11.1 ns ± 0.0778 ns per loop (mean ± std. dev. of 7 runs, 100,000,000 loops each)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, i ran benchmarks. the issue really came to dynamo having to run additional check_type_id() on each dict in addition to confirming each dict is empty.

the dynamo guard for len(dict) is somewhat slow, but the dynamo guard for bool(dict) is fast.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But don't you have to also call check_type_id for boolean fields?

the dynamo guard for len(dict) is somewhat slow, but the dynamo guard for bool(dict) is fast.

So would moving dynamo to properly use the bool(dict) guard solve this?

# this function, and just call forward. It's slow for dynamo to guard on the state
# of all these hook dicts individually, so instead it can guard on 2 bools and we just
# have to promise to keep them up to date when hooks are added or removed via official means.
if not self._has_hooks and not _has_global_hooks:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not ok to assume that _has_hooks is available as some people pickle their Module directly. You need to update the __setstate__ function to properly populate this field if it is not in the serialized state.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, good point. i believe i should add a test case for pickle/unpickle and ensure that setstate correctly enforces that _update_has_hooks gets installed and run.

would that address your concern?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixing setstate will work yes.
Unfortunately testing this is challenging indeed :/

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix via: #96224

self.id = RemovableHandle.next_id

# TODO: we don't pickle/unpickle this field, which means the 'update_has_hooks'
# functionality (which is an optimization) decays after pickling. Can we fix this?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is NOT an optimization, you rely on it to be correct to have the right behavior! That sounds bad?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's why i thought of it as an optimization (maybe i missed a case?)

  1. you save a TS module that has hooks and you pickle a RemovableHandle associated with a hook
  2. load the scriptmodule and the handle
    3a) run the module: the module correctly recomputes the value of '_has_hooks' on load
    3b) remove the hook using the unpickled handle and then run the module: unclear if remove even works by itself, after save/load? i doubt it. But regardless, if the module still thinks it '_has_hooks' and the hook is removed, all that happens is the fast path will be skipped and the slow path will check the hooks dicts

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean that it is not an optimization because if this field become stale, you will not run user hooks leading to silent correctness issues.
So we must have this field to be 100% accurate all the time.

if len(state) < 3
else weakref.ref(OrderedDict() if state[2] is None else state[2])
)
# TODO can we actually restore module_ref after unpickling? Do we care?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please do not merge PRs with random TODOs in utils or nn. This code is exercised in many ways compared to dynamo code and this kind of things really matter.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@albanD can you confirm, is there a change needed here or just delete the TODO? (my bad leaving the TODO in, but i interpreted @ezyang's comment below as the latter)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes see comments below. There are clear BC issues. I think there is some mis-categorization of this as an "optimization" (implying that it's fine if we're not doing a perfect job) even though it is actually critical for correctness (and so we have to do a perfect job here).

module = self.module_ref()
if module is not None:
module._update_has_hooks()
torch.nn.modules.module._update_has_global_hooks()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@albanD this is where I call _update_has_global_hooks()

@gchanan
Copy link
Contributor

gchanan commented Mar 7, 2023

i see about the same 2-3% geomean speedup that i saw as a regression initially with hooks support

Can you open up your spreadsheet? :).

wconstab added a commit that referenced this pull request Mar 8, 2023
wconstab added a commit that referenced this pull request Mar 8, 2023
This reverts commit 2604533.

ghstack-source-id: a8e14d8
Pull Request resolved: #96242
pytorchmergebot pushed a commit that referenced this pull request Mar 10, 2023
…96242)

Reverting due to concerns over silent unsoundness (skipped hooks) if users have directly added hooks dicts without using official torch APIs.

This reverts commit 2604533.

Pull Request resolved: #96242
Approved by: https://github.com/albanD
ydwu4 added a commit to ydwu4/pytorch that referenced this pull request Mar 13, 2023
This PR optimizes the guards overhead introduced by dynamo tracing module forward hooks.

It can and maybe should be followed by a wider change proposed by @voznesenskym to optimize specialized nnmodules by 'observing' any user mutations and directly invalidating the root guard, obviating the need to install other nnmodule guards.  (But this observer change seems more involved...)

Idea: maintain a flag, and keep it up to date whenever adding or
removing hooks. Use the flag rather than dict checks to enter the call fast path.
  - need to extend RemovableHandle to keep a ref to nnModule so it can update the flag on removal.
  - also need to handle the flag in ScriptModule which still uses the python call impl when called from python.

Pull Request resolved: pytorch#95931
Approved by: https://github.com/ezyang, https://github.com/voznesenskym
@facebook-github-bot facebook-github-bot deleted the gh/wconstab/119/head branch June 8, 2023 19:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/inductor-perf-test-nightly Trigger nightly inductor perf tests ciflow/trunk Trigger trunk jobs on your pull request module: dynamo Reverted topic: not user facing topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

9 participants