-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Add nn::Flatten to C++ Frontend #28072
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
yf225
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mrsalehi Thanks a lot for the great work! I left some comments.
| } | ||
|
|
||
| Tensor FlattenImpl::forward(const Tensor& input) { | ||
| return input.flatten(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In Python version, torch.nn.Flatten takes start_dim and end_dim as optional constructor arguments, and we'd need to do the same for the C++ version.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We would also need to add tests for the optional constructor arguments start_dim and end_dim as well. Thanks a lot for your help!
|
@yf225 Sorry for my delay in making the changes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mrsalehi Thanks a lot for the update! I left some comments. For running the tests of a C++ module, we can run ./build/bin/test_api --gtest_filter=ModulesTest* --gtest_stack_trace_depth=10 --gmock_verbose=info in the PyTorch root folder after building PyTorch locally :D
|
|
||
| /// Options for the `Flatten` module. | ||
| struct TORCH_API FlattenOptions { | ||
| FlattenOptions(int64_t start_dim, int64_t end_dim); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think for options that has two or more optional arguments, we usually don't provide the non-default constructor (e.g. CosineEmbeddingLossOptions), and we should likely remove the constructor here
| } | ||
|
|
||
| Tensor FlattenImpl::forward(const Tensor& input) { | ||
| return torch::flatten(input, options.start_dim(), options.end_dim()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might be able to call
| return torch::flatten(input, options.start_dim(), options.end_dim()); | |
| return input.flatten(input, options.start_dim(), options.end_dim()); |
to match the Python version even better :D
| /// A placeholder for Flatten operator | ||
| class TORCH_API FlattenImpl : public Cloneable<FlattenImpl> { | ||
| public: | ||
| FlattenImpl(int64_t start_dim, int64_t end_dim) : FlattenImpl(FlattenOptions(start_dim, end_dim)) {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can remove this constructor and only keep explicit FlattenImpl(const FlattenOptions& options_);
|
Hi @yf225! |
I believe this should be what we are looking for: pytorch/torch/csrc/api/include/torch/nn/modules/loss.h Lines 184 to 186 in 5fbec1f
Thanks a lot for catching the issue! Yes I think the pretty print is not consistent now and we are thinking about whether we should print the full set of options or strictly follow the Python version (which does not print the full set of options in all cases). For now I think we can print |
yf225
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mrsalehi Thanks so much for the awesome work! :D
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
Adds torch::nn::Flatten module support for the C++ API.
Issue: #25883
Reviewer: @yf225