Skip to content

Conversation

@driazati
Copy link
Contributor

@driazati driazati commented Jun 3, 2019

Adds support for recursively compiling nn.Sequential and
nn.ModuleList. When either is used, it is converted to a
jit._ConstModuleList or jit._ConstSequential as necessary. Due to
this, we don't need to add it to __constants__ since it's made
constant on demand.

This PR also moves the recursive script tests out to their own class
TestRecursiveScript (the added test is called test_iterable_modules)

Differential Revision: D15611738

@pytorchbot pytorchbot added oncall: jit Add this issue/PR to JIT oncall triage queue module: pybind Related to our Python bindings / interactions with other Python libraries labels Jun 3, 2019
@eellison
Copy link
Contributor

eellison commented Jun 3, 2019

Can you include 'jit' in the name so that it is grepable and will be invoked when run_tests.py --jit is run

@driazati
Copy link
Contributor Author

driazati commented Jun 3, 2019

discussed in person, the test is still in test_jit.py so run_test.py --jit still runs it

Adds support for recursively compiling `nn.Sequential` and
`nn.ModuleList`. When either is used, it is converted to a
`jit._ConstModuleList` or `jit._ConstSequential` as necessary. Due to
this, we don't need to add it to `__constants__` since it's made
constant on demand.

This PR also moves the recursive script tests out to their own class
`TestRecursiveScript` (the added test is called `test_iterable_modules`)
@driazati driazati requested review from eellison and suo and removed request for suo June 3, 2019 22:03
@TheCodez
Copy link
Contributor

TheCodez commented Jun 4, 2019

@driazati does this also solve #20644 ?

@driazati
Copy link
Contributor Author

driazati commented Jun 4, 2019

@TheCodez it does in conjunction with #20939

elif isinstance(module, ModuleList):
return _ConstModuleList(module)
else:
raise RuntimeError("Only nn.ModuleList and nn.Sequential can be made "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like you would want to do the error checking that it's not a Sequential or ModuleList upfront so that you don't get some weird error when you try to enumerate through them

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be called on modules that aren't enumerable (it's only used in _convert_to_script_module and the type check is there)

`('forward',)`. Methods accessed in forward are scripted on demand if
`_enable_recursive_script()` is used.
"""
if isinstance(mod, (ModuleList, Sequential)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why isn't this change not guarded by a torch._C._jit_recursive_script check

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's only 1 guard with that check in script() before it calls _convert_to_script_module, adding it everywhere would be cumbersome to delete later

Copy link
Contributor

@eellison eellison left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems good, couple more comments.

continue

if not torch._C._jit_recursive_script():
# For recursive script, these are constantified after
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add test for an example of when these are constantified and then attempted to be modified, if that applies

@ezyang ezyang added facebook and removed facebook labels Jun 5, 2019
new_strong_submodule = _convert_to_script_module(module)

# Install the ScriptModule on the python side
parent._modules._python_modules[field] = new_strong_submodule
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did this change ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will have a nn.Sequential until it is used and compiled to a _ConstSequential and installs that as a submodule of the ._c ScriptModule. This keeps the Python ScriptModule in sync by adding it there

@facebook-github-bot
Copy link
Contributor

@driazati merged this pull request in 8a2985e.

@facebook-github-bot facebook-github-bot deleted the driazati/apimdo branch July 13, 2020 17:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Merged module: pybind Related to our Python bindings / interactions with other Python libraries oncall: jit Add this issue/PR to JIT oncall triage queue

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants