[jit] Support nn.GRU in script #23266
Conversation
```
def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
    # type: (Tensor, Optional[Tensor], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tensor]  # noqa
    if hx is None:
```
It's not so nice that all of this extra code needs to be added in, but I'm sure there are reasons
Yeah, that's kind of unfortunate. I originally wanted to script RNNBase directly instead of GRU separately, but the dict used to dispatch to the different modules is a dict(str, function), and JIT dictionaries do not support functions/modules as values yet, so we could only do it separately.
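For context, a sketch of the dispatch dict in question (approximately what `torch/nn/modules/rnn.py` had on master at the time; treat the exact contents as an assumption). TorchScript dictionaries can hold values like `Tensor`, `int`, or `str`, but not functions or modules, so this lookup table cannot be compiled:

```
import torch
from torch import _VF

# Mode-string -> kernel dispatch used by RNNBase.forward in eager mode.
# dict(str, function) values are what the JIT cannot represent.
_rnn_impls = {
    'GRU': _VF.gru,
    'RNN_TANH': _VF.rnn_tanh,
    'RNN_RELU': _VF.rnn_relu,
}
```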
The code duplication isn't great, but I don't think we have a way around it. Looks good; I'd say give overriding the implementation (comment below) a shot, but if that doesn't work then this is fine as is.
test/test_jit.py
Outdated
```
self.x = torch.nn.GRU(5, 5)

@torch.jit.script_method
def forward(self, input: PackedSequence) -> Tuple[PackedSequence, torch.Tensor]:
```
This needs to use Python 2 syntax
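For reference, a minimal sketch of the Python 2 compatible spelling, which uses a `# type:` comment instead of annotations (mirroring the `forward_impl` excerpt above):

```
@torch.jit.script_method
def forward(self, input):
    # type: (PackedSequence) -> Tuple[PackedSequence, Tensor]
    return self.x(input)
```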
torch/nn/modules/rnn.py
Outdated
```
self.check_forward_args(input, hx, batch_sizes)
if batch_sizes is None:
    result = _VF.gru(input, hx, self._get_flat_weights(), self.bias, self.num_layers,
```
We could do it where there's something like `RNNBase.run_impl`, where in RNNBase it's something like

```
def run_impl(self, input_batch, sizes...):
    impl = self.impls(self.mode)
    return impl(...)
```

and override that in GRU to directly call `_VF.gru`; the JIT should only see the overridden version:

```
def run_impl(self, ...):
    return _VF.gru(...)
```

Then we could still re-use most of RNNBase and just have the weirdness be concentrated around the actual problem.
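A minimal runnable sketch of this suggestion, with toy stand-ins for the real `_VF` kernels (class and function names here are illustrative, not the shipped implementation):

```
import torch

def _gru_impl(x):
    return x + 1  # toy stand-in for _VF.gru

def _rnn_tanh_impl(x):
    return torch.tanh(x)  # toy stand-in for _VF.rnn_tanh

# The dict TorchScript cannot compile: its values are functions.
_impls = {'GRU': _gru_impl, 'RNN_TANH': _rnn_tanh_impl}

class RNNBase(torch.nn.Module):
    def __init__(self, mode):
        super(RNNBase, self).__init__()
        self.mode = mode

    def run_impl(self, x):
        # Generic lookup by mode string -- fine in eager, not scriptable.
        return _impls[self.mode](x)

class GRU(RNNBase):
    def __init__(self):
        super(GRU, self).__init__('GRU')

    def run_impl(self, x):
        # Direct call: a scripted GRU would only see this override,
        # so the unscriptable dict lookup never reaches the compiler.
        return _gru_impl(x)
```

The rest of RNNBase (argument checking, weight handling, and so on) would stay shared, with only the kernel dispatch concentrated in the override.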
test/test_jit.py
Outdated
```
def forward(self, input: PackedSequence) -> Tuple[PackedSequence, torch.Tensor]:
    return self.x(input)

eager_out = self.runAndSaveRNG(lambda x: torch.nn.GRU(5, 5)(x), (input,))[0]
```
Can you also test the non-PackedSequence version?
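A hedged sketch of what that Tensor-input check could look like; the input shape and the `scripted` module name are assumptions, while `runAndSaveRNG` is the helper already used in the surrounding test:

```
# Plain Tensor input of shape (seq_len, batch, input_size).
input = torch.randn(4, 3, 5)
eager_out = self.runAndSaveRNG(lambda x: torch.nn.GRU(5, 5)(x), (input,))[0]
script_out = self.runAndSaveRNG(lambda x: scripted(x), (input,))[0]  # `scripted` is hypothetical
self.assertEqual(eager_out, script_out)
```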
apaszke left a comment
Eh, can't we add union types to script? Honestly this is a significant regression in the code quality. I understand that it's an important use case for the JIT, but this is just unreasonable.
@apaszke nn.LSTM (a subclass of RNNBase) already uses this forward dispatch mechanism on master now; I am just moving this to …
I'm not saying it doesn't require careful consideration, but this is basically destroying the codebase. The fact that we've already regressed is not an argument supporting further regressions. If you see something that can't be expressed but is important, then make it expressible instead of hacking your way through the path of least resistance. I can't imagine explaining what the hell is going on there to potential contributors...
I agree this isn't very readable. Can you add comments to the places you think someone wouldn't be able to understand without any context on the JIT? After you add the comments I think we should land. It is important that we are able to script all of our nn Modules, and it is confusing to have scriptable LSTMs but not GRUs (especially after advertising on our blog). Union types would be a significant drain on developer productivity and time, potentially confusing to users (e.g. invariant lists), and would likely require other significant changes (like removing list specializations). We haven't had any issues from users asking for them either. I think in follow-up PRs we should look into what we can do on the JIT side to make this easier to express and more readable.
I want to point out that what we advertised on the blog was optimizing custom RNNs, not the nn Module versions of RNNs.
Another possibility is to only change GRU and not RNNBase (or maybe add a shared class that LSTM & GRU inherit from).
@apaszke I agree that this code is somewhat less readable than the code before. I could look into how to support Union types in script; I'm not sure how significant the change would be. I first did this PR on the GRU class directly without modifying RNNBase (see here), then I realized there's too much code duplication, which is not good either, so I changed RNNBase directly. If you think that was a better approach, or a shared class that LSTM & GRU inherit from as @eellison suggested, I can make the change.
Having the change be limited to a single class would be an improvement. It should also be accompanied by a big comment describing why the code looks like this, and saying that it's a transitional state only (e.g. we wanted to make it in time for the release). Finally, the statement in the comment saying that it's only temporary should be true, meaning that a fix for this should be on a roadmap.
cc @zdevito |
… "[jit] Support nn.GRU in script" Differential Revision: [D16466586](https://our.internmc.facebook.com/intern/diff/D16466586)
Summary: Pull Request resolved: #23885

This is a series of PRs that will allow us to support adding [padding to conv](#22484) and also reduce the friction of adding method overloads that was brought up in #23266.

This PR compiles only one branch of an `if` statement when the condition is an `isinstance` check. This is consistent with what mypy does; it does not report errors in a branch that can be statically determined to be unreachable.

```
def foo(x):
    # type: (int) -> int
    if isinstance(x, str):
        return x["1"]
    return x + 1

reveal_type(foo)  # no error, shows int -> int
```

Test Plan: Imported from OSS

Differential Revision: D16697092

Pulled By: eellison

fbshipit-source-id: d3eb4925cd16d551515ac6ff620a69897dbec130
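To tie this back to scripting, a hedged sketch of the kind of code this enables, assuming `isinstance`-based type refinement works in script as the PR describes (this exact behavior at the time is an assumption):

```
import torch
from typing import Optional

@torch.jit.script
def unwrap(x: Optional[torch.Tensor]) -> torch.Tensor:
    # The isinstance check refines Optional[Tensor] to Tensor in the
    # true branch; the compiler skips a branch it can prove
    # unreachable instead of erroring on it.
    if isinstance(x, torch.Tensor):
        return x
    return torch.zeros(1)
```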
Summary: Pull Request resolved: #23886

This is a series of PRs that will allow us to support adding [padding to conv](#22484) and also reduce the friction of adding method overloads that was brought up in #23266.

Support for overloaded functions following the specification in [PEP 484](https://www.python.org/dev/peps/pep-0484/#function-method-overloading). The usage is:

```
@torch.jit.overload
def add(x: int, y: int) -> int: ...
@torch.jit.overload
def add(x: float, y: float) -> float: ...

def add(x, y):
    return x + y
```

Follow-up PRs:
- Add the same API for methods
- A couple of cleanups for functions:
  - don't require default params to be specified on the overload as well
  - potentially error if an invocation could be matched to multiple overloads; right now it just chooses the first one (mypy currently does the same thing)

Test Plan: Imported from OSS

Differential Revision: D16694863

Pulled By: eellison

fbshipit-source-id: f94f2933bc1c97fa58f31846acfe962b0630068c
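A hedged usage sketch under the decorator name used in the message above; the in-tree spelling was underscore-prefixed (`torch.jit._overload`), so treat the exact name as an assumption:

```
import torch

@torch.jit.overload  # shipped spelling may be torch.jit._overload
def add(x: int, y: int) -> int: ...
@torch.jit.overload
def add(x: float, y: float) -> float: ...

def add(x, y):
    return x + y

@torch.jit.script
def caller() -> float:
    a = add(1, 2)      # resolves to the (int, int) -> int overload
    b = add(1.0, 2.0)  # resolves to the (float, float) -> float overload
    return a + b
```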
Summary: Pull Request resolved: #24259

Follow-up to #23886, this adds the same overload API specified in PEP 484 to module methods, to reduce the friction of adding method overloads that was brought up in #23266. The usage is:

```
@torch.jit.overload
def add(self, y: int) -> int: ...
@torch.jit.overload
def add(self, y: float) -> float: ...

def add():
    ...
```

Test Plan: Imported from OSS

Differential Revision: D16921304

Pulled By: eellison

fbshipit-source-id: 784e2f26f7ca9a330a434a603c86b53725c3dc71
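To connect this back to the GRU discussion above, a hedged sketch of how method overloads could express a forward that accepts either a Tensor or a PackedSequence. The decorator name follows the message (the shipped spelling may be `torch.jit._overload_method`), and the wrapper class and layout are illustrative only, not the shipped rnn.py:

```
import torch
from torch.nn.utils.rnn import PackedSequence
from typing import Tuple

class GRUWrapper(torch.nn.Module):
    def __init__(self):
        super(GRUWrapper, self).__init__()
        self.gru = torch.nn.GRU(5, 5)

    @torch.jit.overload  # shipped spelling may be torch.jit._overload_method
    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: ...

    @torch.jit.overload
    def forward(self, input: PackedSequence) -> Tuple[PackedSequence, torch.Tensor]: ...

    def forward(self, input):
        # Single implementation; when the module is scripted, the JIT
        # picks the matching declared signature at each call site.
        return self.gru(input)
```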