-
Notifications
You must be signed in to change notification settings - Fork 26.3k
C++ ModuleDict #28652
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
C++ ModuleDict #28652
Conversation
yf225
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ShahriarSS Thanks a lot for the awesome work! I left some initial comments.
| /// its own. | ||
| void reset() override {} | ||
|
|
||
| /// Pretty prints the `ModuleList` module into the given `stream`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: ModuleList -> ModuleDict
|
|
||
| /// Pretty prints the `ModuleList` module into the given `stream`. | ||
| void pretty_print(std::ostream& stream) const override { | ||
| stream << "torch::nn::ModuleDict"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Curious would the actual pretty print function prints the content of the ModuleDict as well? For the Python version it seems to do so:
>>> import torch
>>> from torch import nn
>>> a = nn.ModuleDict({
... 'conv': nn.Conv2d(10, 10, 3),
... 'pool': nn.MaxPool2d(3)
... })
>>> a
ModuleDict(
(conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
(pool): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
)| } | ||
|
|
||
| /// Returns true if the `ModuleDict` contains no elements. | ||
| bool is_empty() const noexcept { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this function be called empty? I think ModuleDict's API should mirror std::map, and std::map has empty instead of is_empty. (We should probably deprecate ModuleList::is_empty and Sequential::is_empty as well and redirect them to the empty function, so that they are more in-line with the API of std::vector.)
| children_.erase(name); | ||
| return value; | ||
| } | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we add keys and values?
def keys(self):
r"""Return an iterable of the ModuleDict keys.
"""
return self._modules.keys() def values(self):
r"""Return an iterable of the ModuleDict values.
"""
return self._modules.values()| /// `other` is already present in this `ModuleDict`, an exception is thrown. | ||
| void update(const ModuleDictImpl& other) { | ||
| children_.update(other.children_); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can combine these two update methods with
void update(ModuleDictImpl other) {
children_.update(std::move(other.children_));
}which would be optimal in both "copying" and "moving" cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also I believe we should call register_module to register all the newly added modules.
| const std::string& key, | ||
| const std::shared_ptr<Module>& value) { | ||
| return children_.insert(key, value); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same for this function: we can likely use
std::shared_ptr<Module>& insert(
const std::string& key,
std::shared_ptr<Module> value)and then do std::move inside the function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also I believe we should call register_module to register the newly added module.
|
|
||
| const std::shared_ptr<Module>& operator[](const std::string& key) const { | ||
| return children_[key]; | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a bit concerning that we have these two operator[] overloads for ModuleDict but not ModuleList or Sequential (they only have std::shared_ptr<Module> operator[](size_t index) const). I wondered if it would make sense to do a few things:
- Add
ModuleList::operator[]andSequential::operator[]overloads for bothstd::shared_ptr<Module>&andconst std::shared_ptr<Module>&return types - Add test for lvalue assignment for
Sequential/ModuleList/ModuleDict'soperator[]method, e.g.module_list[0] = SomeModule()
| TORCH_MODULE(ModuleDict); | ||
|
|
||
| } // namespace nn | ||
| } // namespace torch |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Regarding the data structure for storing the modules, do you think we can use torch::OrderedDict? I think AnyModule is more about supporting calling forward on a module without needing to know its concrete type, and we don't have this requirement for ModuleDict.
| private: | ||
| // Friend classes. | ||
|
|
||
| friend class ModuleDictImpl; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a specific reason we need to add this?
| OrderedDict<std::string, std::shared_ptr<Module>>::ConstIterator; | ||
|
|
||
| ModuleDictImpl() = default; | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be awesome to add a ModuleDictImpl constructor that takes a torch::OrderedDict as input, so that we can do something similar to:
torch::OrderedDict<torch::nn::Module> modules = {{"a", ModuleA()}, {"b", ModuleB()}};
auto moduledict = torch::nn::ModuleDict(modules);|
Is torch::nn::ModuleDict still being added to the C++ api? Thanks. |
|
Hi @ShahriarSS! Thank you for your pull request and welcome to our community. Action RequiredIn order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you. ProcessIn order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with If you have received this in error or have any questions, please contact us at [email protected]. Thanks! |
|
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as |
This is
nn::ModuleDictfor the C++ front-end.@yf225 please take a look and tell me if the PR has problems. once you confirm the design I will add the documentations and the tests.
@yf225 is it possible to use
AnyModulehere? It would require a secondaryOrderedDictbesideschildren_.