-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[jit] Support recursive ModuleList / Sequential #21306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Can you include 'jit' in the name so that it is grepable and will be invoked when run_tests.py --jit is run |
|
discussed in person, the test is still in |
Adds support for recursively compiling `nn.Sequential` and `nn.ModuleList`. When either is used, it is converted to a `jit._ConstModuleList` or `jit._ConstSequential` as necessary. Due to this, we don't need to add it to `__constants__` since it's made constant on demand. This PR also moves the recursive script tests out to their own class `TestRecursiveScript` (the added test is called `test_iterable_modules`)
| elif isinstance(module, ModuleList): | ||
| return _ConstModuleList(module) | ||
| else: | ||
| raise RuntimeError("Only nn.ModuleList and nn.Sequential can be made " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like you would want to do the error checking that it's not a Sequential or ModuleList upfront so that you don't get some weird error when you try to enumerate through them
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This shouldn't be called on modules that aren't enumerable (it's only used in _convert_to_script_module and the type check is there)
| `('forward',)`. Methods accessed in forward are scripted on demand if | ||
| `_enable_recursive_script()` is used. | ||
| """ | ||
| if isinstance(mod, (ModuleList, Sequential)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why isn't this change not guarded by a torch._C._jit_recursive_script check
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's only 1 guard with that check in script() before it calls _convert_to_script_module, adding it everywhere would be cumbersome to delete later
eellison
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems good, couple more comments.
| continue | ||
|
|
||
| if not torch._C._jit_recursive_script(): | ||
| # For recursive script, these are constantified after |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe add test for an example of when these are constantified and then attempted to be modified, if that applies
| new_strong_submodule = _convert_to_script_module(module) | ||
|
|
||
| # Install the ScriptModule on the python side | ||
| parent._modules._python_modules[field] = new_strong_submodule |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did this change ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It will have a nn.Sequential until it is used and compiled to a _ConstSequential and installs that as a submodule of the ._c ScriptModule. This keeps the Python ScriptModule in sync by adding it there
Adds support for recursively compiling
nn.Sequentialandnn.ModuleList. When either is used, it is converted to ajit._ConstModuleListorjit._ConstSequentialas necessary. Due tothis, we don't need to add it to
__constants__since it's madeconstant on demand.
This PR also moves the recursive script tests out to their own class
TestRecursiveScript(the added test is calledtest_iterable_modules)Differential Revision: D15611738