
Conversation

Contributor

@jerryzh168 jerryzh168 commented Jun 27, 2019

Stack:
    :black_circle:  #22285 [module] add mutation support for forward_pre_hook and forward_hook  💛

Previously, forward hooks were expected to return None. This PR adds support for overwriting the input and output in forward_pre_hook and forward_hook; this is used to implement inserting quant/dequant function calls around forward functions.
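To make the mechanism concrete, here is a minimal, torch-free sketch of the hook dispatch this PR enables. `MiniModule` and the scale/unscale hooks are illustrative stand-ins (assumptions, not the actual `nn.Module` code): a forward_pre_hook may return a replacement input and a forward_hook may return a replacement output.

```python
# Stand-in sketch (no torch dependency) of the hook-dispatch behavior
# described in this PR. All names here are hypothetical.

class MiniModule:
    def __init__(self):
        self._forward_pre_hooks = {}
        self._forward_hooks = {}

    def register_forward_pre_hook(self, hook):
        self._forward_pre_hooks[len(self._forward_pre_hooks)] = hook

    def register_forward_hook(self, hook):
        self._forward_hooks[len(self._forward_hooks)] = hook

    def forward(self, x):
        return x * 2

    def __call__(self, *input):
        for hook in self._forward_pre_hooks.values():
            result = hook(self, input)
            if result is not None:
                # Wrap a single returned value so forward(*input) still works.
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        output = self.forward(*input)
        for hook in self._forward_hooks.values():
            result = hook(self, input, output)
            if result is not None:
                output = result
        return output

m = MiniModule()
# "quant"/"dequant" stand-ins: scale on the way in, unscale on the way out.
m.register_forward_pre_hook(lambda mod, inp: inp[0] * 10)  # returns a single value
m.register_forward_hook(lambda mod, inp, out: out / 10)
print(m(3))  # 3*10 -> forward doubles -> 60 -> /10 -> 6.0
```

Note that the pre-hook above returns a bare value, not a tuple; the auto-wrap in `__call__` is exactly the behavior debated in the review below.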

Differential Revision: D16022491
Differential Version: 85687669
@pytorchbot pytorchbot added the module: nn Related to torch.nn label Jun 27, 2019
Contributor

@raghuramank100 raghuramank100 left a comment

Please check if behavior for backward is correct.

for hook in self._forward_pre_hooks.values():
-   hook(self, input)
+   result = hook(self, input)
+   if result:
Contributor

This seems wrong? We either need to clearly document that you're supposed to return a tuple from this, or we have to auto-wrap the input in a tuple if a single value has been returned from the hook. Think about the case when someone returns a single tensor.
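A quick illustration of the failure mode in plain Python (`forward` and `pre_hook` are hypothetical stand-ins for the module code): the caller always does `forward(*input)`, so `input` must stay a tuple.

```python
# Hypothetical stand-ins showing why a bare return value must be wrapped.

def forward(x):
    return x + 1

def pre_hook(module, input):
    return input[0] * 2   # user returns a single value, not a tuple

input = (5,)
result = pre_hook(None, input)

try:
    forward(*result)      # TypeError: argument after * must be an iterable
except TypeError:
    pass

# Auto-wrapping restores the tuple invariant:
if not isinstance(result, tuple):
    result = (result,)
print(forward(*result))   # 11
```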

Contributor Author

Sure, I'll wrap the input.

Contributor Author

Will there be a case where the forward function accepts a tuple as input?

Contributor

Yes, this is very much possible. If we want to go down the wrapping path, then I'd create a singleton tuple `if not isinstance(result, tuple)`.

Contributor Author

So if the forward function accepts a tuple as input, should we wrap the tuple again in another singleton tuple?

Contributor

Well what we're doing is a bit magical and in general hard to do, but I think it's also pretty consistent and user-friendly. We can document the edge cases, but I don't think they will be very common.
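The edge case under discussion can be sketched in plain Python (`forward` and `pre_hook` are illustrative stand-ins): when `forward` takes a tuple as its single argument, a hook that returns that tuple bare is ambiguous, so the hook has to wrap the tuple itself.

```python
def forward(pair):            # expects one tuple-valued argument
    a, b = pair
    return a + b

def pre_hook(module, input):
    # To replace the single tuple argument, the hook must return a
    # singleton tuple containing it; a bare (1, 2) would be unpacked
    # into two positional arguments by forward(*input).
    return ((1, 2),)

bad = (1, 2)                  # looks like "two positional args" to the caller
try:
    forward(*bad)
except TypeError:
    pass                      # forward() takes 1 positional argument

input = ((3, 4),)
result = pre_hook(None, input)
if not isinstance(result, tuple):
    result = (result,)        # already a tuple, so left as-is
print(forward(*result))       # 3
```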

Differential Revision: D16022491
Differential Version: 85728558
Differential Revision: D16022491
Differential Version: 85811636
Contributor

@apaszke apaszke left a comment

Looks good. Minor nits.

        input = result
    else:
        if result is not None:
            input = (result,)
Contributor

nit: I think this is slightly more readable:

if result is not None:
    if not isinstance(result, tuple):
        result = (result,)
    input = result

- The hook should not modify the input.
+ The hook can modify the input. User can either return a tuple or a
+ single modified value in the hook. We will wrap the value into a tuple
+ if a single value is returned.
Contributor

... if a single value is returned (unless that value is already a tuple).

Differential Revision: D16022491
Differential Version: 85845553
@facebook-github-bot
Contributor

This pull request has been merged in 577c04c.

xzhu1900 pushed a commit to xzhu1900/pytorch that referenced this pull request Jul 5, 2019
…2285)

Summary:
Pull Request resolved: pytorch#22285

Previously, forward hooks were expected to return None. This PR adds support for overwriting the input and output in `forward_pre_hook` and `forward_hook`; this is used to implement inserting quant/dequant function calls around forward functions.

Differential Revision: D16022491

fbshipit-source-id: 02340080745f22c8ea8a2f80c2c08e3a88e37253
@ezyang ezyang deleted the export-D16022491 branch July 19, 2019 15:54

Labels

Merged module: nn Related to torch.nn

7 participants