-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[JIT] Cleanup special handling of Containers, allowing custom forwards #28988
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
…tom forwards" [ghstack-poisoned]
…tom forwards" [ghstack-poisoned]
…tom forwards" Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods. [ghstack-poisoned]
| v = sub(v) | ||
| self.assertEqual(o, v) | ||
|
|
||
| with self.assertRaisesRegex(Exception, "object is not iterable"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does this need to run under optimized_execution(False)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know, i just put this test where the rest of the functionality is being exposed. I suspect it's not needed.
torch/jit/_recursive.py
Outdated
| if hasattr(nn_module, 'forward'): | ||
| if getattr(nn_module.forward, "__func__", None) == torch.nn.Module.forward: | ||
| forward_func = getattr(nn_module.forward, "__func__", None) | ||
| if forward_func == torch.nn.Module.forward or forward_func == Sequential.forward: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about ModuleList and ModuleDict here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ModuleList and ModuleDict don't have a forward, it is already covered by forward_func == torch.nn.Module.forward case
| """ | ||
| concrete_type = torch._C.ConcreteModuleType() | ||
| concrete_type.add_pyclass(type(nn_module)) | ||
| if isinstance(nn_module, (torch.nn.ModuleDict, torch.jit._ConstModuleDict)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this will break in a case like this, due to #29092
class S(Sequential):
....
class M(Module):
def __init__(self):
super().__init__()
self.x = torch.jit.script(S())
torch.jit.script(M())There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking into this example, however #29092 should be fixed on its own right as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This works:
class Inner(torch.nn.Module):
def forward(self, x):
return x + 10
class CustomSequential(nn.Sequential):
def __init__(self):
super(CustomSequential, self).__init__(
nn.ReLU(), Inner())
def forward(self, x):
x = x + 3
for mod in self:
x = mod(x)
return x - 5
class M(nn.Module):
def __init__(self):
super().__init__()
self.x = torch.jit.script(CustomSequential())
def forward(self, input):
return self.x.forward(input)
self.checkModule(M(), (torch.tensor(.5),))
torch/jit/__init__.py
Outdated
| # define magic methods here as a shim to the correct attribute. | ||
| def forward_magic_method(self, method_name, *args, **kwargs): | ||
| self_method = getattr(self, method_name) | ||
| if self_method == getattr(RecursiveScriptModule, method_name): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Won't this break __dir__ calls for normal ScriptModules?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What are you worried about breaking ? Shouldn't the behavior inherit from ModuleList, which does the following:
keys = super(ModuleList, self).__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
which this currently calls into, and which the previous implementation of ModuleList copy-pasta'd
suo
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great change, this simplifies a lot of things! I had a few suggestions to help make it more generic.
torch/jit/_recursive.py
Outdated
| # In order to continue to expose module container functions to python, | ||
| # we add on the methods that we had previously exposed in the previous | ||
| # version of this api. | ||
| if isinstance(nn_module, (ModuleList, Sequential)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be awesome of we could eliminate any special handling for ModuleContainers here at all. The generic mechanism we want here is "make sure this Python method on nn_module shows up on the script_module". One way to do this it to have an internal decorator called like @_copy_to_script (we need a better name), which specifies that a given method should be copied to the new Python ScriptModule. These decorators can go on all the container modules, and ScriptModule creation can just process the decorators and do the copy without encoding special-case policies.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These are special-cased methods that exist only so that we continue to expose the same api for these built-in containers as we did previously. I don't know if it makes sense to generalize it. Another possibility we can at least consider is to break the api for the constant modules. We'd only be breaking it for their usage as scripted modules within python, not within torchscript.
torch/jit/__init__.py
Outdated
| "For purely script modules use my_script_module.save(<filename>) instead.") | ||
|
|
||
| # Python magic methods do lookups on an object's class type, instead of looking up | ||
| # the desugared attribute. In order to continue to expose the magic methods |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I was confused by the word "desguared" here. I think it would be clearer to talk about methods defined on classes vs. methods defined on instances.
torch/jit/__init__.py
Outdated
| def __len__(self): | ||
| return self.forward_magic_method("__len__") | ||
|
|
||
| def __dir__(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nn.Module implements __dir__, so I think it breaks for regular modules.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So in the case of __dir__ we should always @_copy_to_script so to speak ? or what should be done here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since RecursiveScriptModule inherits from nn.Module, __dir__ is always defined on RecursiveScriptModules. This makes forward_magic_method problematic, because it takes over the default behavior. I think the "generate a new type" strategy would work to fix this, since it would just further override __dir__.
…tom forwards"
Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.
EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?
[ghstack-poisoned]
…tom forwards"
Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.
EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?
[ghstack-poisoned]
…tom forwards"
Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.
EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?
[ghstack-poisoned]
…tom forwards"
Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.
EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?
[ghstack-poisoned]
|
@suo as far as I can tell this only has python 2 errors... which maybe doesn't matter anymore ? Can i get another review. EDIT: python 2 errors should be fixed |
…tom forwards"
Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.
EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?
[ghstack-poisoned]
…tom forwards"
Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.
EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?
[ghstack-poisoned]
…tom forwards"
Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.
EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?
[ghstack-poisoned]
…tom forwards"
Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.
EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?
[ghstack-poisoned]
…tom forwards"
Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.
EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?
[ghstack-poisoned]
suo
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm, thanks for this!
…tom forwards"
Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.
EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?
[ghstack-poisoned]
…tom forwards"
Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.
EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?
Differential Revision: [D18402821](https://our.internmc.facebook.com/intern/diff/D18402821)
[ghstack-poisoned]
…tom forwards"
Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.
EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?
Differential Revision: [D18402821](https://our.internmc.facebook.com/intern/diff/D18402821)
[ghstack-poisoned]
…tom forwards"
Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.
EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?
Differential Revision: [D18402821](https://our.internmc.facebook.com/intern/diff/D18402821)
[ghstack-poisoned]
…tom forwards"
Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.
EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?
Differential Revision: [D18402821](https://our.internmc.facebook.com/intern/diff/D18402821)
[ghstack-poisoned]
ghstack-source-id: 87ff1fa Pull Request resolved: pytorch/pytorch#28988
Stack from ghstack:
Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.
EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in
__constants__("did you forget to add it to Constants"). This PR scripts it even if it is not in__constants__. Is that what we want?Differential Revision: D18402821