[module] add mutation support for forward_pre_hook and forward_hook #22285
Conversation
Differential Revision: D16022491 Differential Version: 85687669
raghuramank100 left a comment:
Please check if behavior for backward is correct.
torch/nn/modules/module.py (outdated):

```diff
 for hook in self._forward_pre_hooks.values():
-    hook(self, input)
+    result = hook(self, input)
+    if result:
```
This seems wrong? We either need to clearly document that you're supposed to return a tuple from this, or we have to auto-wrap the input in a tuple if a single value has been returned from the hook. Think about the case when someone returns a single tensor.
sure, I'll wrap the input
Will there be a case where the forward function accepts a tuple as input?
Yes, this is very much possible. If we want to go down the wrapping path, then I'd create a singleton tuple `if not isinstance(result, tuple)`.
So if the forward function accepts a tuple as input, should we wrap the tuple again in another singleton tuple?
Well what we're doing is a bit magical and in general hard to do, but I think it's also pretty consistent and user-friendly. We can document the edge cases, but I don't think they will be very common.
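To make the edge case under discussion concrete, here is an illustrative sketch (the `wrap` helper is hypothetical, not the actual module.py code) of the singleton-tuple wrapping semantics:

```python
# Hypothetical helper mirroring the wrapping rule discussed above.
# `input` to forward() is always a tuple of positional arguments.
def wrap(result):
    # A single value is wrapped into a singleton tuple; a tuple is
    # passed through unchanged and unpacked into positional args.
    return result if isinstance(result, tuple) else (result,)

# Common case: forward(self, x) and a hook returning a single tensor-like
# value -- the value is wrapped, so forward receives it as one argument.
assert wrap("x") == ("x",)

# Edge case: forward(self, pair) where `pair` is itself a tuple. The hook
# must return ((a, b),) explicitly, because wrap((a, b)) == (a, b) would
# be unpacked into two positional arguments.
assert wrap((("a", "b"),)) == (("a", "b"),)
```

This is the "magical but consistent" behavior referred to above: convenient for the common single-value case, with a documented escape hatch for tuple-valued arguments.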
Differential Revision: D16022491 Differential Version: 85728558
Differential Revision: D16022491 Differential Version: 85811636
apaszke left a comment:
Looks good. Minor nits.
torch/nn/modules/module.py (outdated):

```python
    input = result
else:
    if result is not None:
        input = (result,)
```
nit: I think this is slightly more readable:

```python
if result is not None:
    if not isinstance(result, tuple):
        result = (result,)
    input = result
```
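For illustration, the suggested control flow can be exercised in isolation (the `apply_pre_hook_result` helper below is hypothetical, extracted here only to show the three possible outcomes):

```python
# Hypothetical extraction of the suggested pre-hook result handling.
def apply_pre_hook_result(result, input):
    # `input` is the current tuple of positional args to forward().
    if result is not None:
        if not isinstance(result, tuple):
            result = (result,)   # wrap a single returned value
        input = result           # a returned tuple is used as-is
    return input

assert apply_pre_hook_result(None, (1, 2)) == (1, 2)    # hook observed only
assert apply_pre_hook_result(3, (1, 2)) == (3,)         # single value wrapped
assert apply_pre_hook_result((3, 4), (1, 2)) == (3, 4)  # tuple passed through
```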
torch/nn/modules/module.py (outdated):

```diff
-The hook should not modify the input.
+The hook can modify the input. User can either return a tuple or a
+single modified value in the hook. We will wrap the value into a tuple
+if a single value is returned.
```
... if a single value is returned (unless that value is already a tuple).
Differential Revision: D16022491 Differential Version: 85845553
This pull request has been merged in 577c04c.
…2285) Summary: Pull Request resolved: pytorch#22285. Previously, forward hooks were expected to return None. This PR adds support for overwriting input and output in `forward_pre_hook` and `forward_hook`; this is used to implement inserting quant/dequant function calls around forward functions. Differential Revision: D16022491 fbshipit-source-id: 02340080745f22c8ea8a2f80c2c08e3a88e37253
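As an illustration of how this mutation support enables the quant/dequant use case, here is a minimal pure-Python sketch (no torch dependency; `MiniModule` and the fake hooks are hypothetical stand-ins, not PyTorch code) mimicking the hook dispatch this PR adds:

```python
# Hypothetical miniature of nn.Module's hook dispatch after this PR.
class MiniModule:
    def __init__(self):
        self._forward_pre_hooks = []
        self._forward_hooks = []

    def forward(self, x):
        return x + 1

    def __call__(self, *input):
        # Pre-hooks may mutate the input: a returned single value is
        # wrapped into a tuple, a returned tuple is used as-is.
        for hook in self._forward_pre_hooks:
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        output = self.forward(*input)
        # Forward hooks may replace the output.
        for hook in self._forward_hooks:
            result = hook(self, input, output)
            if result is not None:
                output = result
        return output

def fake_quant(module, input):
    # Pretend "quantization": round the single input value.
    return round(input[0], 0)

def fake_dequant(module, input, output):
    # Pretend "dequantization": convert the output back to float.
    return output * 1.0

m = MiniModule()
m._forward_pre_hooks.append(fake_quant)
m._forward_hooks.append(fake_dequant)
print(m(2.4))  # round(2.4) + 1 = 3.0
```

The real implementation lives in `Module.__call__` in torch/nn/modules/module.py; the sketch only shows the shape of the dispatch that lets hooks wrap forward() with quant/dequant calls.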
Stack:
#22285 [module] add mutation support for forward_pre_hook and forward_hook