Skip to content

Conversation

@eellison
Copy link
Contributor

@eellison eellison commented Oct 31, 2019

Stack from ghstack:

Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.

EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in __constants__ ("did you forget to add it to Constants"). This PR scripts it even if it is not in __constants__. Is that what we want?

Differential Revision: D18402821

@eellison eellison requested a review from apaszke as a code owner October 31, 2019 20:28
@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Oct 31, 2019
@eellison eellison requested a review from suo October 31, 2019 20:30
eellison pushed a commit that referenced this pull request Oct 31, 2019
eellison pushed a commit that referenced this pull request Oct 31, 2019
…tom forwards"


Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.

[ghstack-poisoned]
eellison pushed a commit that referenced this pull request Nov 1, 2019
@eellison eellison requested a review from driazati November 4, 2019 23:29
v = sub(v)
self.assertEqual(o, v)

with self.assertRaisesRegex(Exception, "object is not iterable"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this need to run under optimized_execution(False)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know, i just put this test where the rest of the functionality is being exposed. I suspect it's not needed.

if hasattr(nn_module, 'forward'):
if getattr(nn_module.forward, "__func__", None) == torch.nn.Module.forward:
forward_func = getattr(nn_module.forward, "__func__", None)
if forward_func == torch.nn.Module.forward or forward_func == Sequential.forward:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about ModuleList and ModuleDict here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ModuleList and ModuleDict don't have a forward, it is already covered by forward_func == torch.nn.Module.forward case

"""
concrete_type = torch._C.ConcreteModuleType()
concrete_type.add_pyclass(type(nn_module))
if isinstance(nn_module, (torch.nn.ModuleDict, torch.jit._ConstModuleDict)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this will break in a case like this, due to #29092

class S(Sequential):
   ....

class M(Module):
    def __init__(self):
        super().__init__()
        self.x = torch.jit.script(S())

torch.jit.script(M())

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking into this example, however #29092 should be fixed on its own right as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works:

        class Inner(torch.nn.Module):
            def forward(self, x):
                return x + 10

        class CustomSequential(nn.Sequential):
            def __init__(self):
                super(CustomSequential, self).__init__(
                    nn.ReLU(), Inner())

            def forward(self, x):
                x = x + 3
                for mod in self:
                    x = mod(x)
                return x - 5

        class M(nn.Module):
            def __init__(self):
                super().__init__()
                self.x = torch.jit.script(CustomSequential())

            def forward(self, input):
                return self.x.forward(input)

        self.checkModule(M(), (torch.tensor(.5),))

# define magic methods here as a shim to the correct attribute.
def forward_magic_method(self, method_name, *args, **kwargs):
self_method = getattr(self, method_name)
if self_method == getattr(RecursiveScriptModule, method_name):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't this break __dir__ calls for normal ScriptModules?

Copy link
Contributor Author

@eellison eellison Nov 5, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are you worried about breaking ? Shouldn't the behavior inherit from ModuleList, which does the following:

        keys = super(ModuleList, self).__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

which this currently calls into, and which the previous implementation of ModuleList copy-pasta'd

Copy link
Member

@suo suo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great change, this simplifies a lot of things! I had a few suggestions to help make it more generic.

# In order to continue to expose module container functions to python,
# we add on the methods that we had previously exposed in the previous
# version of this api.
if isinstance(nn_module, (ModuleList, Sequential)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be awesome of we could eliminate any special handling for ModuleContainers here at all. The generic mechanism we want here is "make sure this Python method on nn_module shows up on the script_module". One way to do this it to have an internal decorator called like @_copy_to_script (we need a better name), which specifies that a given method should be copied to the new Python ScriptModule. These decorators can go on all the container modules, and ScriptModule creation can just process the decorators and do the copy without encoding special-case policies.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are special-cased methods that exist only so that we continue to expose the same api for these built-in containers as we did previously. I don't know if it makes sense to generalize it. Another possibility we can at least consider is to break the api for the constant modules. We'd only be breaking it for their usage as scripted modules within python, not within torchscript.

"For purely script modules use my_script_module.save(<filename>) instead.")

# Python magic methods do lookups on an object's class type, instead of looking up
# the desugared attribute. In order to continue to expose the magic methods
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I was confused by the word "desguared" here. I think it would be clearer to talk about methods defined on classes vs. methods defined on instances.

def __len__(self):
return self.forward_magic_method("__len__")

def __dir__(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nn.Module implements __dir__, so I think it breaks for regular modules.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So in the case of __dir__ we should always @_copy_to_script so to speak ? or what should be done here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since RecursiveScriptModule inherits from nn.Module, __dir__ is always defined on RecursiveScriptModules. This makes forward_magic_method problematic, because it takes over the default behavior. I think the "generate a new type" strategy would work to fix this, since it would just further override __dir__.

…tom forwards"


Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.

EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?


[ghstack-poisoned]
eellison pushed a commit that referenced this pull request Nov 5, 2019
@eellison eellison requested a review from suo November 6, 2019 00:22
…tom forwards"


Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.

EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?


[ghstack-poisoned]
eellison pushed a commit that referenced this pull request Nov 6, 2019
@eellison eellison requested review from gchanan and zou3519 November 6, 2019 17:42
…tom forwards"


Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.

EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?


[ghstack-poisoned]
eellison pushed a commit that referenced this pull request Nov 6, 2019
…tom forwards"


Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.

EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?


[ghstack-poisoned]
eellison pushed a commit that referenced this pull request Nov 7, 2019
@eellison
Copy link
Contributor Author

eellison commented Nov 7, 2019

@suo as far as I can tell this only has python 2 errors... which maybe doesn't matter anymore ? Can i get another review.

EDIT: python 2 errors should be fixed

…tom forwards"


Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.

EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?


[ghstack-poisoned]
…tom forwards"


Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.

EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?


[ghstack-poisoned]
…tom forwards"


Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.

EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?


[ghstack-poisoned]
…tom forwards"


Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.

EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?


[ghstack-poisoned]
…tom forwards"


Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.

EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?


[ghstack-poisoned]
eellison pushed a commit that referenced this pull request Nov 8, 2019
Copy link
Member

@suo suo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, thanks for this!

…tom forwards"


Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.

EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?


[ghstack-poisoned]
…tom forwards"


Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.

EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?

Differential Revision: [D18402821](https://our.internmc.facebook.com/intern/diff/D18402821)

[ghstack-poisoned]
eellison pushed a commit that referenced this pull request Nov 8, 2019
…tom forwards"


Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.

EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?

Differential Revision: [D18402821](https://our.internmc.facebook.com/intern/diff/D18402821)

[ghstack-poisoned]
eellison added 2 commits November 12, 2019 09:05
…tom forwards"


Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.

EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?

Differential Revision: [D18402821](https://our.internmc.facebook.com/intern/diff/D18402821)

[ghstack-poisoned]
…tom forwards"


Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.

EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?

Differential Revision: [D18402821](https://our.internmc.facebook.com/intern/diff/D18402821)

[ghstack-poisoned]
@facebook-github-bot
Copy link
Contributor

@eellison merged this pull request in fbe90b6.

@facebook-github-bot facebook-github-bot deleted the gh/eellison/27/head branch November 16, 2019 15:16
xxtEchjovs44 pushed a commit to xxtEchjovs44/pytorch that referenced this pull request Jan 29, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Merged oncall: jit Add this issue/PR to JIT oncall triage queue

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants