-
Notifications
You must be signed in to change notification settings - Fork 26.3k
C++ MaxPool Module #24860
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
C++ MaxPool Module #24860
Conversation
yf225
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ShahriarSS Thanks and it looks fantastic! I left some very minor comments. We might also want to wait until #23852 is merged, so that MaxPool can be the first module that achieves Python/C++ API parity =D
| TORCH_ARG(ExpandingArray<D>, padding) = 0; | ||
|
|
||
| /// a parameter that controls the stride of elements in the window | ||
| TORCH_ARG(ExpandingArray<D>, dialation) = 1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: dialation -> dilation
|
|
||
| /// Applies maxpool over a 1-D input. | ||
| /// See https://pytorch.org/docs/master/nn.html#torch.nn.MaxPool1d to learn | ||
| /// about the exact behavior of this module. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might want to write the formulas here similar to https://pytorch.org/docs/stable/nn.html#maxpool1d (we can either do it in this PR or in a follow-up PR).
test/cpp/api/modules.cpp
Outdated
| torch::Tensor s = y.sum(); | ||
|
|
||
| s.backward(); | ||
| std::cout << y.sizes() << std::endl; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: need to remove print statement
| std::cout << y.sizes() << std::endl; | ||
| ASSERT_EQ(y.ndimension(), 3); | ||
| ASSERT_EQ(s.ndimension(), 0); | ||
| ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2})); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Besides checking ndimension() and sizes(), can we also check that the values of y are what we expected? (we might need to set value of x in a specific way in order to test this)
| s.backward(); | ||
| ASSERT_EQ(y.ndimension(), 3); | ||
| ASSERT_EQ(s.ndimension(), 0); | ||
| ASSERT_EQ(y.sizes(), torch::IntArrayRef({2, 2, 2})); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto for checking value of y.
test/cpp/api/modules.cpp
Outdated
| ASSERT_EQ(s.ndimension(), 0); | ||
| for (auto i = 0; i < 3; i++) { | ||
| ASSERT_EQ(y.size(i), 2); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto for checking value of y.
test/cpp/api/modules.cpp
Outdated
| ASSERT_EQ(s.ndimension(), 0); | ||
| for (auto i = 0; i < 4; i++) { | ||
| ASSERT_EQ(y.size(i), 2); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto for checking value of y.
|
@yf225 There were some problems and one test wasn't implemented. I fixed them so please take a look. |
|
@pytorchbot rebase this please |
|
There's nothing to do! This branch is already up to date with master (4fac61a). (To learn more about this bot, see Bot commands.) |
yf225
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot @ShahriarSS ! I will add the parity test in a follow-up PR.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
This PR makes the following improvements to C++ API parity test harness:
1. Remove `options_args` since we can get the list of options from the Python module constructor args.
2. Add test for mapping `int` or `tuple` in Python module constructor args to `ExpandingArray` in C++ module options.
3. Use regex to split up e.g. `(1, {2, 3}, 4)` into `['1', '{2, 3}', '4']` for `cpp_default_constructor_args`.
4. Add options arg accessor tests in `_test_torch_nn_module_ctor_args`.
We will be able to merge #24160 and #24860 after these improvements.
Pull Request resolved: #25828
Differential Revision: D17266197
Pulled By: yf225
fbshipit-source-id: 96d0d4a2fcc4b47cd1782d4df2c9bac107dec3f9
|
@pytorchbot rebase this please |
| options.stride_, | ||
| options.padding_, | ||
| options.dilation_, | ||
| options.ceil_mode_); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ShahriarSS We might be missing options.return_indices_ here (and the forward calls for 2d and 3d), I will push a commit to add it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. But I don't think that torch::max_poolxd uses it. That's why I didn't include it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes @yf225 here is the documentation:
return_indices: if ``True``, will return the max indices along with the outputs.
Useful for :class:`torch.nn.MaxUnpool2d` later
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah got it, thanks a lot for the catch!
This reverts commit 2311cd5.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
cc. @glaringlee |
No description provided.