[jit] Implement more of the nn.Module API #28828
This updates torch::script::Module to more closely match the behavior of nn.Module. In particular, it implements the (optionally recursive) iterators that retrieve submodules, parameters, and buffers, and makes their names match the Python versions. This also removes the individual accessors for Parameter, Module, Buffer, etc. and replaces them with a single `attr` function, which is equivalent to writing `a.foo` in Python (`setattr` emulates `a.foo = v`). As we build out the user-facing API for TorchScript values, this will end up matching how an attribute is accessed on general objects. This PR preserves the Python bindings for script::Module by emulating the old API at the binding level. A follow-up will clean up the usage to more directly match the C++ API.
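To make the API shape described above concrete, here is a minimal sketch in standard C++. It is a toy model and does not use libtorch: `ToyModule`, `ToyValue`, and the `collect` helper are hypothetical names invented for illustration. Only `attr`, `setattr`, and `hasattr` come from this PR, and `named_parameters` is assumed here because it is the Python `nn.Module` name the description says the iterators should match.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

struct ToyModule;

// Toy attribute value (hypothetical): a number standing in for a parameter,
// or a submodule when `module` is non-null.
struct ToyValue {
  double number;
  std::shared_ptr<ToyModule> module;
};

struct ToyModule {
  // Slots kept in insertion order.
  std::vector<std::pair<std::string, ToyValue>> slots;

  const ToyValue* find(const std::string& name) const {
    for (const auto& slot : slots) {
      if (slot.first == name) return &slot.second;
    }
    return nullptr;
  }

  bool hasattr(const std::string& name) const { return find(name) != nullptr; }

  // Single generic accessor, analogous to `a.foo` in Python.
  const ToyValue& attr(const std::string& name) const {
    const ToyValue* v = find(name);
    if (!v) throw std::runtime_error("no attribute: " + name);
    return *v;
  }

  // Analogous to `a.foo = v` in Python: overwrite or append.
  void setattr(const std::string& name, ToyValue v) {
    for (auto& slot : slots) {
      if (slot.first == name) {
        slot.second = std::move(v);
        return;
      }
    }
    slots.emplace_back(name, std::move(v));
  }

  // Appends (qualified name, value) for every parameter, descending into
  // submodules only when `recurse` is true.
  void collect(const std::string& prefix, bool recurse,
               std::vector<std::pair<std::string, double>>& out) const {
    for (const auto& slot : slots) {
      const std::string qualified =
          prefix.empty() ? slot.first : prefix + "." + slot.first;
      if (!slot.second.module) {
        out.emplace_back(qualified, slot.second.number);
      } else if (recurse) {
        slot.second.module->collect(qualified, recurse, out);
      }
    }
  }

  std::vector<std::pair<std::string, double>> named_parameters(
      bool recurse = true) const {
    std::vector<std::pair<std::string, double>> out;
    collect("", recurse, out);
    return out;
  }
};
```

In this sketch the recursive names are dotted paths such as `linear.weight`, which is how Python's `named_parameters()` reports them. The real iterators in the PR are lazy, while this toy version eagerly builds a vector for brevity.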
ZolotukhinM left a comment:
I have two comments:
- It seems that the old API for accessing attributes checked the type of the attribute (module/buffer/parameter), while the new API returns a generic value. We probably need to add the removed checks back at the call sites to preserve the behavior (I commented on one such site, but there are probably more).
- The code for the recursive iterators is hard to understand. I know what it's supposed to do, but it's difficult to follow even with that knowledge. Can we please add some comments (and update the old ones)? Some classes, like `Policy` et al., would also benefit from brief comments.
    if (!module_.hasattr(key)) {
      return false;
    }
    archive.module_ = module_.attr(key).toModule();
Wouldn't we crash here if the attribute is there, but it's not a module? In the original code we returned false in this case.
Yeah, I can make this more specific.
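The concern in this thread can be sketched with a toy model in standard C++. Nothing here is the actual fix made in the PR: `ToyModule`, `ToyValue`, and `tryGetSubmodule` are hypothetical names, and the sketch only shows the general idea of treating "attribute exists but is not a module" the same as "attribute missing", as the original code reportedly did.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>

struct ToyModule;

// Toy attribute value (hypothetical): a submodule when `module` is non-null,
// otherwise a plain number.
struct ToyValue {
  double number = 0.0;
  std::shared_ptr<ToyModule> module;
  bool isModule() const { return module != nullptr; }
};

struct ToyModule {
  std::map<std::string, ToyValue> slots;
};

// Returns true and writes `out` only when `key` names a submodule.
// A missing attribute and an attribute of the wrong kind both return false,
// so the caller never attempts an invalid conversion.
bool tryGetSubmodule(const ToyModule& parent, const std::string& key,
                     std::shared_ptr<ToyModule>& out) {
  auto it = parent.slots.find(key);
  if (it == parent.slots.end()) {
    return false;  // attribute is missing
  }
  if (!it->second.isModule()) {
    return false;  // attribute exists but is not a module
  }
  out = it->second.module;
  return true;
}
```

The second check is the one the generic accessor no longer performs on the caller's behalf.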
torch/csrc/jit/script/module.h (outdated)
    struct NameValue {
      std::string name;
      IValue value;
    struct Frame {
The name `Frame` is already overloaded in various contexts. Are we sure we want to use it here as well?
I'll change it to something less ambiguous.
ZolotukhinM left a comment:
Overall looks good, please find some comments inline.
      script::Module observer = observer_module.clone();
      std::string observer_name = "_observer_" + std::to_string(uid_++);
    - while (module.find_module(observer_name)) {
    + while (module.hasattr(observer_name)) {
Ouch, there was a bug previously! Good thing that it's gonna be fixed now.
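One plausible reading of the bug mentioned here is that a module-only lookup cannot see a name already taken by a parameter, buffer, or other attribute, so a generated observer name could collide with it. The sketch below illustrates that reading with a toy name table in standard C++. `ToyNames` and `freshObserverName` are hypothetical, and the loop body (regenerating the name from the counter) is an assumption, since the snippet above shows only the loop condition.

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical stand-in for a module's attribute table (names only).
struct ToyNames {
  std::set<std::string> modules;  // submodule names
  std::set<std::string> others;   // parameters, buffers, plain attributes

  // Sees submodules only.
  bool find_module(const std::string& n) const { return modules.count(n) > 0; }

  // Sees every attribute, whatever its kind.
  bool hasattr(const std::string& n) const {
    return modules.count(n) > 0 || others.count(n) > 0;
  }
};

// Generates "_observer_<uid>" names until one is unused by any attribute.
std::string freshObserverName(const ToyNames& m, int& uid) {
  std::string name = "_observer_" + std::to_string(uid++);
  while (m.hasattr(name)) {
    name = "_observer_" + std::to_string(uid++);
  }
  return name;
}
```

With `find_module` in the loop condition instead, a non-module attribute named `_observer_0` would not be detected and the generated name would clash with it.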
      // Queue submodules for processing
    - for (const script::NameModule& submodule : current.get_modules()) {
    -   worklist.push(submodule.module);
    + for (const script::NameModule& submodule : current.named_children()) {
Nit: we can probably use children instead of named_children here.
      InsertPrepackUnpack(graph);
    - for (script::NameModule m : module.get_modules()) {
    -   InsertPrepackUnpack(m.module);
    + for (script::NameModule m : module.named_children()) {
Ditto.
      FoldPrepackedWeightIntoModule(
          module, method.name(), linear_params_module, conv_params_module);
    - for (script::NameModule m : module.get_modules()) {
    + for (script::NameModule m : module.named_children()) {
Ditto.
      slot_dict_impl<Policy>(self.module_object()).setattr(name, std::move(value));
    }

    static py::object get_generic(Module& self, const std::string& name) {
Should this also be templatized by Policy?
My next patch is going to remove all of those and replace them with attr, so I didn't bother with it here. All of these are private methods in our ScriptModule implementation that are already guarded.
torch/csrc/jit/script/module.h (outdated)
        IValue v) {
      std::string name;
      if (frames.size() == 1) {
        name = (frames.back().i_ == -1) ? "" : nameFragment(frames.back());
Do we really have to special case size==1? It looks like it can be handled just fine by the loop below.
It's to avoid the overhead of allocating an ostringstream and copying the string twice in the very common non-recursive case.
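The trade-off in this reply can be sketched as follows, in standard C++. This is a simplified illustration and not the PR's code: `qualifiedName` is a hypothetical function that takes already-computed name fragments, and it omits the special root case (index `-1`) discussed in this thread.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Joins per-level name fragments with '.', e.g. {"encoder", "weight"}
// becomes "encoder.weight".
std::string qualifiedName(const std::vector<std::string>& fragments) {
  if (fragments.empty()) {
    return "";
  }
  if (fragments.size() == 1) {
    // Fast path for the non-recursive case: return the single fragment
    // directly, without constructing a stream.
    return fragments.back();
  }
  // General path: build the dotted name through a stream.
  std::ostringstream ss;
  for (std::size_t i = 0; i < fragments.size(); ++i) {
    if (i > 0) {
      ss << ".";
    }
    ss << fragments[i];
  }
  return ss.str();
}
```

Both paths produce the same string for a single fragment, so the special case is purely an optimization, which is the point made in the reply.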
torch/csrc/jit/script/module.h (outdated)
    if (i > 0) {
      ss << ".";
    }
    ss << nameFragment(frames[i]);
What if frames[i].i_ == -1? Should we assert that it's not the case?
The getAttributeName is going to assert in that case, so I didn't add another one. It's not a bug because only the top-level frame can have this.
torch/csrc/jit/script/module.h (outdated)
    (type_ && module_.entity_type(i_) != *type_)) {
    ++i_;
    // return_module() is a corner case where instead of returning a submodule
    // of root, we are return root itself, because we are iterating modules(),
Typo: "are return"
torch/csrc/jit/script/module.h (outdated)
    // return_module() is a corner case where instead of returning a submodule
    // of root, we are return root itself, because we are iterating modules(),
    // which contains the root module itself.
    // It is represented with a single Frame whose index is -1.
Nit: since Frame was renamed to SlotCursor, we need to update all its references accordingly.
    }
    // the last traversal action advanced beyond the number of slots in the
    // module so continue the iteration in the parent.
    if (top().i_ >= int64_t(top().module_.num_slots())) {
Should we do it while we're beyond the number of slots instead of if? I.e. shall we pop back until we reach a valid position or do we intentionally want to pop back only once?
That would be correct too, but I went with this way to make the components easier to understand: next() does 1 step, while_not_valid is responsible for repeating.
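The split described in this reply (one function that takes exactly one traversal step, and another that repeats it until the position is valid) can be sketched with a toy tree in standard C++. Everything here is hypothetical and simplified: `Tree`, `Cursor`, and `LeafWalker` are invented for illustration, and "valid" is assumed to mean "the top cursor points at a non-module slot".

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

// Toy module tree: a slot is a leaf when `sub` is null, else a submodule.
struct Tree {
  struct Slot {
    std::string name;
    std::shared_ptr<Tree> sub;
  };
  std::vector<Slot> slots;
};

// Position inside one module: which module, and which slot index.
struct Cursor {
  const Tree* module;
  std::size_t i;
};

class LeafWalker {
 public:
  explicit LeafWalker(const Tree& root) {
    stack_.push_back({&root, 0});
    whileNotValid();
  }
  bool done() const { return stack_.empty(); }
  void advance() {
    next();
    whileNotValid();
  }
  // Dotted name built from the slot each cursor currently points at.
  std::string name() const {
    std::string out;
    for (const Cursor& c : stack_) {
      if (!out.empty()) out += ".";
      out += c.module->slots[c.i].name;
    }
    return out;
  }

 private:
  // Takes exactly one traversal step: pop, descend, or move right.
  void next() {
    Cursor& top = stack_.back();
    if (top.i >= top.module->slots.size()) {
      // Ran past the last slot: resume in the parent, after the submodule.
      stack_.pop_back();
      if (!stack_.empty()) ++stack_.back().i;
      return;
    }
    const Tree::Slot& slot = top.module->slots[top.i];
    if (slot.sub) {
      stack_.push_back({slot.sub.get(), 0});  // descend into the submodule
      return;
    }
    ++top.i;  // step over a leaf
  }
  bool valid() const {
    const Cursor& top = stack_.back();
    return top.i < top.module->slots.size() && !top.module->slots[top.i].sub;
  }
  // Repeats single steps until a leaf is reached or the walk is finished.
  void whileNotValid() {
    while (!stack_.empty() && !valid()) next();
  }
  std::vector<Cursor> stack_;
};
```

Because `next()` pops at most one level per call, walking out of several nested, exhausted submodules relies on the loop in `whileNotValid()`, which matches the division of responsibility described above.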
Summary: Pull Request resolved: pytorch/pytorch#28828 This updates torch::script::Module to more closely match the behavior of nn.Module. In particular, it implements the (optionally recursive) iterators that retrieve submodules, parameters, and buffers, and makes their names match the Python versions. This also removes the individual accessors for Parameter, Module, Buffer, etc. and replaces them with a single `attr` function, which is equivalent to writing `a.foo` in Python (`setattr` emulates `a.foo = v`). As we build out the user-facing API for TorchScript values, this will end up matching how an attribute is accessed on general objects. This PR preserves the Python bindings for script::Module by emulating the old API at the binding level. A follow-up will clean up the usage to more directly match the C++ API. Test Plan: Imported from OSS Differential Revision: D18197611 Pulled By: zdevito fbshipit-source-id: 7ee4dcbb258605d1c988314b05d938423f1ccee5
Not an easy fix after all. Reverting #29208 now.
Test Plan: revert-hammer Differential Revision: D18350353 Original commit changeset: 2026c8ab7650 fbshipit-source-id: 401f34cb276c3ea34a5439de4c3415969a04ab2a