@ShahriarRezghi ShahriarRezghi commented Oct 25, 2019

This is nn::ModuleDict for the C++ front-end.

@yf225 please take a look and tell me if the PR has problems. Once you confirm the design, I will add the documentation and the tests.
@yf225 is it possible to use AnyModule here? It would require a secondary OrderedDict besides children_.

@yf225 yf225 left a comment

@ShahriarSS Thanks a lot for the awesome work! I left some initial comments.

/// its own.
void reset() override {}

/// Pretty prints the `ModuleList` module into the given `stream`.

nit: ModuleList -> ModuleDict


/// Pretty prints the `ModuleList` module into the given `stream`.
void pretty_print(std::ostream& stream) const override {
  stream << "torch::nn::ModuleDict";

Curious: would the actual pretty-print function print the contents of the ModuleDict as well? The Python version seems to do so:

>>> import torch
>>> from torch import nn
>>> a = nn.ModuleDict({
...                         'conv': nn.Conv2d(10, 10, 3),
...                         'pool': nn.MaxPool2d(3)
...                 })
>>> a
ModuleDict(
  (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
)
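
A minimal sketch of what printing the contents could look like, written against plain standard C++ rather than LibTorch. The `Module`, `Conv2d`, `MaxPool2d`, `ModuleDictSketch`, and `render` names are hypothetical stand-ins introduced only for this illustration, and the printed strings are placeholders, not LibTorch's actual output.

```cpp
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a module base class; only what the sketch needs.
struct Module {
  virtual ~Module() = default;
  virtual void pretty_print(std::ostream& stream) const = 0;
};

struct Conv2d : Module {
  void pretty_print(std::ostream& stream) const override {
    stream << "Conv2d(10, 10, kernel_size=[3, 3])";  // placeholder text
  }
};

struct MaxPool2d : Module {
  void pretty_print(std::ostream& stream) const override {
    stream << "MaxPool2d(kernel_size=3)";  // placeholder text
  }
};

// Hypothetical dictionary that prints one "(key): child" line per entry,
// in insertion order, loosely mirroring the Python repr quoted above.
struct ModuleDictSketch : Module {
  std::vector<std::pair<std::string, std::shared_ptr<Module>>> children;

  void pretty_print(std::ostream& stream) const override {
    stream << "ModuleDict(\n";
    for (const auto& item : children) {
      stream << "  (" << item.first << "): ";
      item.second->pretty_print(stream);
      stream << "\n";
    }
    stream << ")";
  }
};

// Helper that renders any module to a string.
std::string render(const Module& module) {
  std::ostringstream out;
  module.pretty_print(out);
  return out.str();
}
```

Calling `render` on a dictionary holding a "conv" and a "pool" entry produces a header line, one indented line per child, and a closing parenthesis.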

}

/// Returns true if the `ModuleDict` contains no elements.
bool is_empty() const noexcept {

Should this function be called empty? I think ModuleDict's API should mirror std::map, and std::map has empty rather than is_empty. (We should probably also deprecate ModuleList::is_empty and Sequential::is_empty and redirect them to empty, so that they are more in line with the API of std::vector.)
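
One way the rename plus deprecation could look, as a stand-alone sketch assuming C++14 for the `[[deprecated]]` attribute. `ModuleDictSketch` and the empty `Module` placeholder are hypothetical, and a `std::map` stands in for whatever container the real class ends up using.

```cpp
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <utility>

struct Module {};  // hypothetical placeholder type

class ModuleDictSketch {
 public:
  /// Returns true if the dictionary holds no entries (mirrors std::map::empty).
  bool empty() const noexcept { return children_.empty(); }

  /// Old spelling, kept so existing callers still compile; forwards to empty().
  [[deprecated("use empty() instead")]]
  bool is_empty() const noexcept { return empty(); }

  std::size_t size() const noexcept { return children_.size(); }

  void insert(std::string key, std::shared_ptr<Module> value) {
    children_.emplace(std::move(key), std::move(value));
  }

 private:
  std::map<std::string, std::shared_ptr<Module>> children_;
};
```

Callers of `is_empty()` would get a compiler warning pointing them at `empty()`, while behaviour stays identical.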

children_.erase(name);
return value;
}

Should we add keys and values?

    def keys(self):
        r"""Return an iterable of the ModuleDict keys.
        """
        return self._modules.keys()
    def values(self):
        r"""Return an iterable of the ModuleDict values.
        """
        return self._modules.values()
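
A rough C++ counterpart of the two Python methods, as a stand-alone sketch. `ModuleDictSketch` is hypothetical and stores its entries in a vector of pairs to keep insertion order; it does not check for duplicate keys. If the underlying container already offers key and value accessors, the methods could simply forward to them.

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Module {};  // hypothetical placeholder type

class ModuleDictSketch {
 public:
  void insert(std::string key, std::shared_ptr<Module> value) {
    children_.emplace_back(std::move(key), std::move(value));
  }

  /// Returns the keys in insertion order.
  std::vector<std::string> keys() const {
    std::vector<std::string> result;
    result.reserve(children_.size());
    for (const auto& item : children_) {
      result.push_back(item.first);
    }
    return result;
  }

  /// Returns the stored module pointers in insertion order.
  std::vector<std::shared_ptr<Module>> values() const {
    std::vector<std::shared_ptr<Module>> result;
    result.reserve(children_.size());
    for (const auto& item : children_) {
      result.push_back(item.second);
    }
    return result;
  }

 private:
  std::vector<std::pair<std::string, std::shared_ptr<Module>>> children_;
};
```

Returning vectors by value copies the keys and the shared pointers, which is simple but not free; iterator-based views would avoid the copies at the cost of a larger interface.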

/// `other` is already present in this `ModuleDict`, an exception is thrown.
void update(const ModuleDictImpl& other) {
  children_.update(other.children_);
}

I think we can combine these two update methods with

  void update(ModuleDictImpl other) {
    children_.update(std::move(other.children_));
  }

which would be optimal in both "copying" and "moving" cases.
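
To make the "copying" versus "moving" claim concrete, here is a stand-alone sketch of the by-value parameter pattern. `CountingDict` is a hypothetical type that counts how often it is copy-constructed, with `int` values standing in for modules; the sketch does not check for duplicate keys.

```cpp
#include <string>
#include <utility>
#include <vector>

// Hypothetical dictionary that counts copy constructions, so the cost of
// each call style can be observed.
struct CountingDict {
  static int copies;
  std::vector<std::pair<std::string, int>> items;

  CountingDict() = default;
  CountingDict(const CountingDict& other) : items(other.items) { ++copies; }
  CountingDict(CountingDict&& other) noexcept = default;

  // By-value parameter: an lvalue argument is copied once into `other`,
  // an rvalue argument is moved into it. Either way the entries are then
  // moved out of `other`, so one overload serves both cases.
  void update(CountingDict other) {
    for (auto& item : other.items) {
      items.push_back(std::move(item));
    }
  }
};

int CountingDict::copies = 0;
```

Passing an lvalue costs exactly one copy, and passing `std::move(x)` costs none, which matches what two separate `const&` and `&&` overloads would achieve.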

Also I believe we should call register_module to register all the newly added modules.

const std::string& key,
const std::shared_ptr<Module>& value) {
return children_.insert(key, value);
}

Same for this function: we can likely use

  std::shared_ptr<Module>& insert(
      const std::string& key,
      std::shared_ptr<Module> value)

and then do std::move inside the function.

Also I believe we should call register_module to register the newly added module.
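
The two suggestions for insert can be combined roughly as follows. This is a stand-alone sketch: the `Module` base class and its `register_module` here are simplified, hypothetical stand-ins (the real LibTorch signatures differ), and `ModuleDictSketch` and `Linear` exist only for illustration.

```cpp
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified base class; not the real torch::nn::Module.
class Module {
 public:
  virtual ~Module() = default;

 protected:
  // Stores a child under a unique name and returns a reference to the stored
  // pointer. The reference is invalidated by later insertions in this sketch.
  std::shared_ptr<Module>& register_module(
      std::string name,
      std::shared_ptr<Module> module) {
    for (const auto& child : children_) {
      if (child.first == name) {
        throw std::invalid_argument("duplicate submodule name: " + name);
      }
    }
    children_.emplace_back(std::move(name), std::move(module));
    return children_.back().second;
  }

  std::vector<std::pair<std::string, std::shared_ptr<Module>>> children_;
};

class ModuleDictSketch : public Module {
 public:
  // `value` is taken by value and moved on, so callers pay for a reference
  // count increment only when they pass an lvalue.
  std::shared_ptr<Module>& insert(
      const std::string& key,
      std::shared_ptr<Module> value) {
    return register_module(key, std::move(value));
  }

  std::size_t size() const noexcept { return children_.size(); }
};

struct Linear : Module {};
```

Routing every insertion through the registration function keeps validation, such as the duplicate-name check, in one place instead of repeating it in each mutating method.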


const std::shared_ptr<Module>& operator[](const std::string& key) const {
  return children_[key];
}

@yf225 yf225 Oct 27, 2019

It's a bit concerning that we have these two operator[] overloads for ModuleDict but not ModuleList or Sequential (they only have std::shared_ptr<Module> operator[](size_t index) const). I wondered if it would make sense to do a few things:

  1. Add ModuleList::operator[] and Sequential::operator[] overloads for both std::shared_ptr<Module>& and const std::shared_ptr<Module>& return types
  2. Add test for lvalue assignment for Sequential / ModuleList / ModuleDict's operator[] method, e.g. module_list[0] = SomeModule()
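
A stand-alone sketch of the overload pair and the lvalue assignment it enables. `ModuleDictSketch`, `Linear`, and `ReLU` are hypothetical types, and throwing `std::out_of_range` for a missing key is an assumption of the sketch, not documented LibTorch behaviour.

```cpp
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

struct Module {
  virtual ~Module() = default;
};
struct Linear : Module {};
struct ReLU : Module {};

class ModuleDictSketch {
 public:
  void insert(std::string key, std::shared_ptr<Module> value) {
    children_.emplace_back(std::move(key), std::move(value));
  }

  // Non-const overload: the returned reference can be assigned through.
  std::shared_ptr<Module>& operator[](const std::string& key) {
    for (auto& item : children_) {
      if (item.first == key) {
        return item.second;
      }
    }
    throw std::out_of_range("no submodule named '" + key + "'");
  }

  // Const overload: read-only access to the stored pointer.
  const std::shared_ptr<Module>& operator[](const std::string& key) const {
    for (const auto& item : children_) {
      if (item.first == key) {
        return item.second;
      }
    }
    throw std::out_of_range("no submodule named '" + key + "'");
  }

 private:
  std::vector<std::pair<std::string, std::shared_ptr<Module>>> children_;
};
```

With both overloads in place, `dict["act"] = std::make_shared<ReLU>();` replaces the stored module, while the same expression on a `const` dictionary fails to compile. A test for the lvalue case can assign through `operator[]` and then read the entry back through a const reference.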

TORCH_MODULE(ModuleDict);

} // namespace nn
} // namespace torch

Regarding the data structure for storing the modules, do you think we can use torch::OrderedDict? I think AnyModule is more about supporting calling forward on a module without needing to know its concrete type, and we don't have this requirement for ModuleDict.

private:
// Friend classes.

friend class ModuleDictImpl;

Is there a specific reason we need to add this?

OrderedDict<std::string, std::shared_ptr<Module>>::ConstIterator;

ModuleDictImpl() = default;

It would be awesome to add a ModuleDictImpl constructor that takes a torch::OrderedDict as input, so that we can do something similar to:

torch::OrderedDict<torch::nn::Module> modules = {{"a", ModuleA()}, {"b", ModuleB()}};
auto moduledict = torch::nn::ModuleDict(modules);
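
Note that torch::OrderedDict is templated on both a key type and a value type, so the declaration above reads as shorthand. The shape of such a constructor can be sketched without LibTorch as follows; `ModuleDictSketch`, its `Item` alias, `ModuleA`, and `ModuleB` are hypothetical, and a vector of pairs stands in for the ordered dictionary.

```cpp
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

struct Module {
  virtual ~Module() = default;
};
struct ModuleA : Module {};
struct ModuleB : Module {};

class ModuleDictSketch {
 public:
  using Item = std::pair<std::string, std::shared_ptr<Module>>;

  ModuleDictSketch() = default;

  // Builds the dictionary from a list of (name, module) pairs, going through
  // insert() so duplicate names are rejected the same way as later inserts.
  explicit ModuleDictSketch(std::vector<Item> items) {
    for (auto& item : items) {
      insert(std::move(item.first), std::move(item.second));
    }
  }

  void insert(std::string key, std::shared_ptr<Module> value) {
    if (contains(key)) {
      throw std::invalid_argument("duplicate key: " + key);
    }
    children_.emplace_back(std::move(key), std::move(value));
  }

  bool contains(const std::string& key) const {
    for (const auto& item : children_) {
      if (item.first == key) {
        return true;
      }
    }
    return false;
  }

  std::size_t size() const noexcept { return children_.size(); }

 private:
  std::vector<Item> children_;
};
```

Usage then looks close to the snippet in the comment: build a list such as `{{"a", std::make_shared<ModuleA>()}, {"b", std::make_shared<ModuleB>()}}` and pass it to the constructor.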

@ezyang ezyang added the triaged label Feb 3, 2020
facebook-github-bot pushed a commit that referenced this pull request Jun 29, 2020
Summary:
This diff contains the implementation of the C++ API for ParameterDict from #25883; refer to #36904 and #28652.
Pull Request resolved: #40654

Test Plan: Add unit test in this diff

Differential Revision: D22273265

Pulled By: glaringlee

fbshipit-source-id: 9134a92c95eacdd53d5b24470d5f7edbeb40a488
@meganset

Is torch::nn::ModuleDict still being added to the C++ api?

Thanks.

@meganset meganset mentioned this pull request Oct 6, 2020
@facebook-github-bot

Hi @ShahriarSS!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@github-actions

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label May 28, 2022
@github-actions github-actions bot closed this Jun 27, 2022
Labels: open source, Stale, triaged

7 participants