
Conversation

@wanchaol (Collaborator) commented Jul 23, 2019

Stack from ghstack:

Differential Revision: D16466586


```
def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
    # type: (Tensor, Optional[Tensor], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tensor]  # noqa
    if hx is None:
```
@zou3519 (Contributor) commented Jul 24, 2019:

It's not so nice that all of this extra code needs to be added in, but I'm sure there are reasons

@wanchaol (Collaborator, Author) replied:

Yeah, that's kind of unfortunate. I originally wanted to script RNNBase directly instead of GRU separately, but the dict used to dispatch to the different modules is a Dict[str, function], and JIT dictionaries don't support functions/modules as values yet, so we could only do it separately.
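For context, a minimal sketch (all names here are illustrative, not the actual torch.nn internals) of the dispatch pattern in question: an eager-mode dict keyed by mode whose values are functions, which JIT dictionaries could not represent.

```
import torch

def rnn_tanh_impl(x: torch.Tensor) -> torch.Tensor:
    return torch.tanh(x)

def rnn_relu_impl(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x)

# Dict[str, function]: fine in eager Python, but not a value type JIT
# dictionaries supported, so this table could not live inside a scripted module.
_impls = {"RNN_TANH": rnn_tanh_impl, "RNN_RELU": rnn_relu_impl}

def dispatch(mode: str, x: torch.Tensor) -> torch.Tensor:
    # torch.jit.script(dispatch) would reject the lookup below, which is why
    # GRU had to be scripted separately instead of through a shared table.
    return _impls[mode](x)
```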

@driazati (Contributor) left a comment:

The code duplication isn't great, but I don't think we have a way around it. Looks good; I'd say give overriding the implementation (comment below) a shot, but if that doesn't work then this is fine as is.

test/test_jit.py (outdated):

```
self.x = torch.nn.GRU(5, 5)

@torch.jit.script_method
def forward(self, input: PackedSequence) -> Tuple[PackedSequence, torch.Tensor]:
```
A Contributor commented:

This needs to use Python 2 syntax
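For reference, a sketch of the Python 2 compatible form: the annotations move from the signature into a `# type:` comment, which is the syntax TorchScript accepted across both Python versions at the time (the wrapper class is illustrative, not the exact test):

```
from typing import Tuple

import torch
from torch.nn.utils.rnn import PackedSequence

class GRUWrapper(torch.jit.ScriptModule):
    def __init__(self):
        super(GRUWrapper, self).__init__()
        self.x = torch.nn.GRU(5, 5)

    @torch.jit.script_method
    def forward(self, input):
        # type: (PackedSequence) -> Tuple[PackedSequence, torch.Tensor]
        return self.x(input)
```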


```
self.check_forward_args(input, hx, batch_sizes)
if batch_sizes is None:
    result = _VF.gru(input, hx, self._get_flat_weights(), self.bias, self.num_layers,
```
@driazati (Contributor) commented Jul 26, 2019:

We could do it where there's something like `RNNBase.run_impl`, where in `RNNBase` it's something like

```
def run_impl(self, input_batch, sizes...):
    impl = self.impls(self.mode)
    return impl(...)
```

and override that in `GRU` to directly call `_VF.gru`; the JIT should only see the overridden version:

```
def run_impl(self, ...):
    return _VF.gru(...)
```

Then we could still re-use most of `RNNBase` and just have the weirdness be concentrated around the actual problem.
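A runnable toy version of that suggestion (pure Python, hypothetical names; the real override would call `_VF.gru` with the flat weights as in the diff above):

```
class RNNBaseSketch:
    """Base class funnels every mode through one overridable hook."""

    def __init__(self, mode: str):
        self.mode = mode

    def run_impl(self, x: float) -> float:
        # Generic string-keyed dispatch: the shape the JIT could not compile,
        # but it never has to see this version once a subclass overrides it.
        impls = {"gru": self._gru_impl, "lstm": self._lstm_impl}
        return impls[self.mode](x)

    def _gru_impl(self, x: float) -> float:
        return 2.0 * x

    def _lstm_impl(self, x: float) -> float:
        return 3.0 * x

    def forward(self, x: float) -> float:
        return self.run_impl(x)

class GRUSketch(RNNBaseSketch):
    def __init__(self):
        super().__init__("gru")

    def run_impl(self, x: float) -> float:
        # Direct call with no dict: the only version a compiler that
        # processes GRUSketch alone would need to understand.
        return self._gru_impl(x)

assert GRUSketch().forward(2.0) == 4.0
```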

test/test_jit.py (outdated):

```
def forward(self, input: PackedSequence) -> Tuple[PackedSequence, torch.Tensor]:
    return self.x(input)
```

```
eager_out = self.runAndSaveRNG(lambda x: torch.nn.GRU(5, 5)(x), (input,))[0]
```
A Contributor commented:

Can you also test the non-PackedSequence version?
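A minimal sketch of what that case could look like (hypothetical; it mirrors the snippet above rather than the test that actually landed):

```
import torch

gru = torch.nn.GRU(5, 5)
inp = torch.randn(3, 2, 5)   # (seq_len, batch, input_size), no packing
out, h_n = gru(inp)          # plain-Tensor path instead of PackedSequence
assert out.shape == (3, 2, 5) and h_n.shape == (1, 2, 5)
```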

@apaszke (Contributor) left a comment:

Eh, can't we add union types to script? Honestly this is a significant regression in the code quality. I understand that it's an important use case for the JIT, but this is just unreasonable.

@wanchaol (Collaborator, Author) replied:

> Eh, can't we add union types to script? Honestly this is a significant regression in the code quality. I understand that it's an important use case for the JIT, but this is just unreasonable.

@apaszke nn.LSTM (a subclass of RNNBase) already uses this forward-dispatch mechanism on master; I am just moving it into RNNBase directly instead of LSTM (I planned to remove that code in a follow-up PR), so the code-quality problem is already there. I agree that we need to add a Union type, but that needs careful consideration, since it changes a bunch of things like type refinement, etc.
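For context on the type-refinement point: `Optional` refinement via `is None` checks is the narrowing TorchScript already handled at the time, and a general `Union` would need the same flow-sensitive narrowing for arbitrary member types. A small example of the existing behavior:

```
from typing import Optional

import torch

@torch.jit.script
def bias_or_zero(x: Optional[torch.Tensor]) -> torch.Tensor:
    if x is None:
        return torch.zeros(1)
    # Here the compiler has refined x from Optional[Tensor] to Tensor.
    return x + 1
```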

@apaszke (Contributor) commented Jul 30, 2019:

I'm not saying it doesn't require careful consideration, but this is basically destroying the codebase. The fact that we've already regressed is not an argument supporting further regressions. If you see something can't be expressed but is important then make it expressible instead of hacking your way through the path of least resistance. I can't imagine explaining what the hell is going on there to potential contributors...

@eellison (Contributor) commented Jul 30, 2019:

I agree this isn't very readable. Can you add comments to the places you think someone wouldn't be able to understand without any context on the JIT?

After you add the comments I think we should land. It is important that we are able to script all of our nn Modules, and it is confusing to have scriptable LSTMs but not GRUs (especially after advertising on our blog). Union types would be a significant drain on developer productivity and time, potentially confusing to users (e.g. around invariant lists, as sketched below), and would likely require other significant changes (like removing list specializations). We haven't had users asking for them in issues either.
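(For context on the invariant-lists point, the standard variance example below is illustrative and not from this thread.)

```
from typing import List, Union

def widen(xs: List[Union[int, float]]) -> None:
    xs.append(0.5)

ints: List[int] = [1, 2]
# widen(ints)  # mypy rejects this: List is invariant, so List[int] is not a
#              # List[Union[int, float]], even though int is a member of the Union.
```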

I think in follow-up PRs we should look into what we can do on the JIT side to make this easier to express and more readable.

@zou3519 (Contributor) commented Jul 30, 2019:

> After you add the comments I think we should land. It is important that we are able to script all of our nn Modules & it is confusing to have scriptable LSTMs but not GRUs (especially after advertising on our blog).

I want to point out that what we advertised on the blog was optimizing custom RNNs, not the nn Module versions of RNNs.

@eellison (Contributor) commented:

Another possibility is to change only GRU and not RNNBase (or maybe add a shared class that LSTM & GRU inherit from).

@wanchaol (Collaborator, Author) replied:

> I'm not saying it doesn't require careful consideration, but this is basically destroying the codebase. The fact that we've already regressed is not an argument supporting further regressions. If you see something can't be expressed but is important then make it expressible instead of hacking your way through the path of least resistance. I can't imagine explaining what the hell is going on there to potential contributors...

@apaszke I agree that this code is not as readable as the code before. I could look into how to support a Union type in script; I'm not sure how significant the change would be.

I first did this PR on the GRU class directly without modifying RNNBase (see here), then realized there was too much code duplication, which isn't good either, so I changed RNNBase directly. If you think that was the better approach, or prefer a shared class that LSTM & GRU inherit from as @eellison suggested, I can make the change.

@apaszke (Contributor) commented Jul 31, 2019:

Having the change be limited to a single class would be an improvement. It should also be accompanied by a big comment describing why this looks like that, and saying it's a transition state only (e.g. we wanted to make it on time for the release). Finally, the statement in the comment saying that it's only temporary should be true, meaning that a fix for this should be on a roadmap.
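A sketch of the kind of transitional comment being requested (the wording is purely illustrative, not what actually landed):

```
# XXX: forward() is duplicated per RNN subclass only because TorchScript
# cannot yet express Union[Tensor, PackedSequence] inputs or hold a
# Dict[str, function] for mode dispatch. This is a transitional state
# until the JIT grows the needed type support; a proper fix is on the
# roadmap (see the overload / isinstance-branch-pruning follow-ups below).
```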

@apaszke (Contributor) commented Jul 31, 2019:

cc @zdevito

@zou3519 deleted the gh/wanchaol/35/head branch August 2, 2019 00:22
@facebook-github-bot (Contributor) commented:

@wanchaol merged this pull request in 9d2cc2c.

facebook-github-bot pushed a commit that referenced this pull request Aug 8, 2019
Summary:
Pull Request resolved: #23885

This is a series of PRs that will allow us to support adding [padding to conv](#22484) and also reduce the friction of adding method overloads that was brought up in #23266.

This PR compiles only one branch of an `if` statement when the condition is an `isinstance` check. This is consistent with what mypy does; it does not report errors if a branch can be determined statically to be unreachable.

```
def foo(x):
    # type: (int) -> int
    if isinstance(x, str):
        return x["1"]
    return x + 1

reveal_type(foo) # no error, shows int -> int
```

Test Plan: Imported from OSS

Differential Revision: D16697092

Pulled By: eellison

fbshipit-source-id: d3eb4925cd16d551515ac6ff620a69897dbec130
facebook-github-bot pushed a commit that referenced this pull request Aug 8, 2019
Summary:
Pull Request resolved: #23886

This is a series of PRs that will allow us to support adding [padding to conv](#22484) and also reduce the friction of adding method overloads that was brought up in #23266.

Support for overloaded functions following the specification in [PEP 484](https://www.python.org/dev/peps/pep-0484/#function-method-overloading).

The usage is:
```
@torch.jit.overload
def add(x: int, y: int) -> int: ...
@torch.jit.overload
def add(x: float, y: float) -> float: ...

def add(x, y):
    return x + y
```

Follow-up PRs:

- Add the same API for methods
- A couple of cleanups for functions:
     - don't require default params to be specified on the overload as well
     - potentially error if an invocation could be matched to multiple overloads; right now it just chooses the first one, which is what mypy currently does as well

Test Plan: Imported from OSS

Differential Revision: D16694863

Pulled By: eellison

fbshipit-source-id: f94f2933bc1c97fa58f31846acfe962b0630068c
facebook-github-bot pushed a commit that referenced this pull request Aug 20, 2019
Summary:
Pull Request resolved: #24259

Follow-up to #23886: add the same overload API specified in PEP 484 to module methods, to reduce the friction of adding method overloads that was brought up in #23266.

The usage is:
```
@torch.jit.overload
def add(self, y: int) -> int: ...
@torch.jit.overload
def add(self, y: float) -> float: ...

def add(self, y):
    ...
```

Test Plan: Imported from OSS

Differential Revision: D16921304

Pulled By: eellison

fbshipit-source-id: 784e2f26f7ca9a330a434a603c86b53725c3dc71

Labels

Merged · module: nn (Related to torch.nn) · oncall: jit (Add this issue/PR to JIT oncall triage queue)