-
Notifications
You must be signed in to change notification settings - Fork 26.3k
C++ Fold nn module #24160
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
C++ Fold nn module #24160
Conversation
|
Why does this PR have |
yf225
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for the PR @ShahriarSS! Overall it looks awesome. Once #23852 lands we can add the Fold module into the API parity whitelist, and then we can merge this PR.
test/cpp/api/modules.cpp
Outdated
| ASSERT_EQ(y.size(2), 4); | ||
| ASSERT_EQ(y.size(3), 5); | ||
|
|
||
| // TODO check numel of grad |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we plan to implement this part as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. I just forgot about it.
|
|
||
| /// Applies fold over a 3-D input. | ||
| /// See https://pytorch.org/docs/master/nn.html#torch.nn.Fold to learn about | ||
| /// the exact behavior of this module. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The documentation here looks great. We can experiment with porting all docs (including math formulas) from https://pytorch.org/docs/stable/nn.html#fold to here in this PR, or in a later PR if you prefer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll add the formulas here. We also need to do this for the other modules already implemented.
|
@ShahriarSS Regarding the |
|
@yf225 I removed the TODO from fold test because fold doesn't have weights of any kind. I think |
|
@ShahriarSS Thanks! I am working on adding API parity test for |
|
@pytorchbot rebase this please |
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
| ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/conv.cpp | ||
| ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/dropout.cpp | ||
| ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/embedding.cpp | ||
| ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/fold.cpp |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ShahriarSS the Facebook internal build seems to fail, and I am thinking we might need to add the fold.cpp entry into
pytorch/tools/build_variables.py
Lines 178 to 205 in 13292ec
| torch_cpp_srcs = [ | |
| "torch/csrc/api/src/cuda.cpp", # this just forwards stuff, no real CUDA | |
| "torch/csrc/api/src/data/datasets/mnist.cpp", | |
| "torch/csrc/api/src/data/samplers/distributed.cpp", | |
| "torch/csrc/api/src/data/samplers/random.cpp", | |
| "torch/csrc/api/src/data/samplers/sequential.cpp", | |
| "torch/csrc/api/src/data/samplers/stream.cpp", | |
| "torch/csrc/api/src/jit.cpp", | |
| "torch/csrc/api/src/nn/init.cpp", | |
| "torch/csrc/api/src/nn/module.cpp", | |
| "torch/csrc/api/src/nn/modules/batchnorm.cpp", | |
| "torch/csrc/api/src/nn/modules/conv.cpp", | |
| "torch/csrc/api/src/nn/modules/dropout.cpp", | |
| "torch/csrc/api/src/nn/modules/embedding.cpp", | |
| "torch/csrc/api/src/nn/modules/functional.cpp", | |
| "torch/csrc/api/src/nn/modules/linear.cpp", | |
| "torch/csrc/api/src/nn/modules/named_any.cpp", | |
| "torch/csrc/api/src/nn/modules/rnn.cpp", | |
| "torch/csrc/api/src/optim/adagrad.cpp", | |
| "torch/csrc/api/src/optim/adam.cpp", | |
| "torch/csrc/api/src/optim/lbfgs.cpp", | |
| "torch/csrc/api/src/optim/optimizer.cpp", | |
| "torch/csrc/api/src/optim/rmsprop.cpp", | |
| "torch/csrc/api/src/optim/serialize.cpp", | |
| "torch/csrc/api/src/optim/sgd.cpp", | |
| "torch/csrc/api/src/serialize/input-archive.cpp", | |
| "torch/csrc/api/src/serialize/output-archive.cpp", | |
| ] |
Summary:
This PR makes the following improvements to C++ API parity test harness:
1. Remove `options_args` since we can get the list of options from the Python module constructor args.
2. Add test for mapping `int` or `tuple` in Python module constructor args to `ExpandingArray` in C++ module options.
3. Use regex to split up e.g. `(1, {2, 3}, 4)` into `['1', '{2, 3}', '4']` for `cpp_default_constructor_args`.
4. Add options arg accessor tests in `_test_torch_nn_module_ctor_args`.
We will be able to merge #24160 and #24860 after these improvements.
Pull Request resolved: #25828
Differential Revision: D17266197
Pulled By: yf225
fbshipit-source-id: 96d0d4a2fcc4b47cd1782d4df2c9bac107dec3f9
|
@pytorchbot rebase this please |
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
No description provided.