[jit] Add support for recursive compilation on Modules #20708
Conversation
What would it take to get rid of …

Similar to #19439 we could just make everything that we see in …
This change deserves a design document before we land any code. There are some key questions which I am not clear on; we should come up with a consistent design (across weak and strong modules) that we look at first.
Force-pushed from dd9f511 to e260151
zdevito left a comment:
This is a good start. I have a few inline comments, but generally I think we should merge it and then discuss the next steps.
torch/_jit_internal.py
Outdated
    saving. This decorator explicitly marks that a method should be included
    even if it is not called from Python.
    """
    class_dict = get_class_attribute_dict(fn, frames_up=2)
I feel like this would be more idiomatic if it just set a __torchscript_export__ property on fn, and the class then searched its own __dict__ for things marked export. Searching the frames for what you hope is the class dict is not great. What happens when we add @property support and someone chains @export @property def ...? It would break.
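A minimal sketch of the attribute-based approach being suggested (the names export and exported_methods are illustrative, not the actual torch API): the decorator tags the function object directly, and the class later scans its own __dict__ for tagged methods instead of walking the caller's stack frames.

```python
def export(fn):
    # Mark the function itself; no frame inspection needed.
    fn.__torchscript_export__ = True
    return fn


def exported_methods(cls):
    # The class discovers exports from its own __dict__.
    return sorted(
        name
        for name, attr in cls.__dict__.items()
        if getattr(attr, "__torchscript_export__", False)
    )


class MyModule:
    @export
    def helper(self, x):
        return x + 1

    def forward(self, x):
        return self.helper(x)
```

With this shape the decorator never has to guess which frame holds the class dict; discovery happens when the class is actually compiled.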
def _make_strong_submodule(field, module, parent):
    if field not in parent._modules:
How does this happen? Are there nn.Modules in module that are not in the _modules list?
This was more of a sanity check, though it's possible someone does self.__dict__["secret_module"] = nn.Linear().
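A plain-Python illustration of the edge case being described (the Module class here is a hypothetical stand-in for nn.Module's registration behavior, not the torch implementation): assigning through __dict__ bypasses __setattr__, so the submodule never lands in _modules.

```python
class Module:
    """Stand-in mimicking nn.Module's submodule registration."""

    def __init__(self):
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        # Normal attribute assignment registers submodules.
        if isinstance(value, Module):
            self._modules[name] = value
        object.__setattr__(self, name, value)


parent = Module()
parent.child = Module()                      # goes through __setattr__: registered
parent.__dict__["secret_module"] = Module()  # bypasses __setattr__: NOT registered
```

So the `field not in parent._modules` check only fires for modules smuggled in this way.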
class WeakScriptModuleProxy(ScriptModule):
    def __init__(self, original, stubs):
        # Guards behavior of __setattr__ and __getattr__ so ScriptModule
        # __init__ can run correctly
We need a plan to get rid of WeakScriptModuleProxy. Eventually all ScriptModules will just always be this class, and they will no longer be proxies at all. It is difficult to understand all the components we have after this patch, so we need some follow-up effort merging the concepts.
Most of the functionality here can operate directly on a ScriptModule instance instead of constructing a WeakScriptModuleProxy, so deleting it should be pretty painless. There is still the question of what something like this would do:
scripted_linear = torch.jit.script(nn.Linear(10, 10))
isinstance(scripted_linear, nn.Linear)  # true or false?

Intuitively I feel like it should be False since it's constructing a new module, but I can see people doing something like this, which could lead to confusing behavior:

class A(nn.Module):
    def __init__(self, mod):
        self.module = mod
        if isinstance(self.module, nn.Linear):
            ...

A(torch.jit.script(B()))

constants_set = set(getattr(original, "__constants__", []))
self.__dict__["_constants_set"] = {}

if not hasattr(original, '_parameters'):
We need to document better exactly what happens when we create a ScriptModule from a Module: what gets copied, what is translated in some way, and what is ignored.
# TODO: need to handle this more generally when non-tensor attributes added to module
object.__setattr__(self, name, item)
elif isinstance(item, Parameter) or (isinstance(item, Module) and item is not self):
elif item is self:
I can't follow this code here. I'm not sure why it decides to ignore some things and not others.
torch/jit/__init__.py
Outdated
exported = tuple(mod.__torchscript_export__)
methods = methods + exported

if mod in _jit_internal.weak_modules:
It is possible we no longer want to cache this. Should calling script twice on the same nn.Module return the same module? What if you mutate something in the original class (e.g. do some model surgery) and call it again?
I agree, it should make a fresh copy each time to stay with the idea of it being a "pure" operation. I think that should be in a follow-up diff, though.
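A toy illustration of the caching hazard under discussion (script here is a hypothetical stand-in that snapshots the module's attributes instead of actually compiling): because the result is memoized on the original object, model surgery after the first call is silently ignored.

```python
_cache = {}


def script(mod):
    # Stand-in for compilation: snapshot the module's attributes,
    # but only the first time we see this object.
    if id(mod) not in _cache:
        _cache[id(mod)] = dict(mod.__dict__)
    return _cache[id(mod)]


class M:
    def __init__(self):
        self.weight = 1


m = M()
first = script(m)
m.weight = 2          # "model surgery" after the first script() call
second = script(m)    # returns the stale cached snapshot, not weight=2
```

Making script() produce a fresh copy each call avoids this class of surprise, at the cost of no longer returning an identical object for repeated calls.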
torch/_jit_internal.py
Outdated
return fn


def ignore(maybe_fn=None, *, drop_on_export=False):
Keyword-only arguments are not supported in Python 2.
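The usual Python-2-compatible workaround is to accept **kwargs and pop the keyword out by hand, rejecting anything unexpected. A sketch (the _torchscript_ignore / _drop_on_export attribute names are illustrative, not the actual torch internals):

```python
def ignore(maybe_fn=None, **kwargs):
    # Emulate `*, drop_on_export=False` without Python-3-only syntax.
    drop_on_export = kwargs.pop("drop_on_export", False)
    if kwargs:
        raise TypeError("unexpected keyword arguments: %r" % sorted(kwargs))

    def decorator(fn):
        fn._torchscript_ignore = True
        fn._drop_on_export = drop_on_export
        return fn

    if maybe_fn is not None:
        return decorator(maybe_fn)   # used bare: @ignore
    return decorator                 # used with args: @ignore(drop_on_export=True)


@ignore
def f(x):
    return x


@ignore(drop_on_export=True)
def g(x):
    return x
```

The explicit pop-and-check keeps the "keyword-only" contract (positional use of drop_on_export is impossible) while running on both Python 2 and 3.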
Unlanding now
Following on #19747, this implements most of the torch.jit.script() changes laid out in #20939. Still to do:
- ScriptMethod (so only @export-ed methods and forward are compiled)
- forward on a submodule doesn't work

Differential Revision: D15560490