-
Notifications
You must be signed in to change notification settings - Fork 26.3k
C++ API parity: Upsample #28413
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
C++ API parity: Upsample #28413
Conversation
yf225
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jon-tow Thanks so much for the awesome work and I really appreciated it. :D I left some minor comments
I feel that we can probably relax the requirement on this and allow using |
…uild_variables.py`
|
As always, thank you for the help and guidance @yf225 :) . pytorch/torch/csrc/api/src/nn/modules/upsampling.cpp Lines 25 to 35 in 276873b
|
yf225
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jon-tow Thanks so much for the awesome work! I left some minor comments :D It would be awesome if we could mark Implementation Parity as YES for for Upsample in test/cpp_api_parity/parity-tracker.md as well. Thanks again!
| InterpolateOptions() | ||
| .size(options.size()) | ||
| .scale_factor(options.scale_factor()) | ||
| .mode(decltype(InterpolateOptions().mode())(options.mode())) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
curious does .mode(options.mode()) work?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The problem here is that UpsampleOptions::mode_t and InterpolateOptions::mode_t differ by one variant member: torch::kArea (in Python "area"). According to the Python torch.nn.Upsample documentation, "area" is not an acceptable mode option, as "area" is commonly used as a decimation/downsampling mechanism. It happens to be passable because it is not checked in the module's __init__.
Would you like me to add torch::kArea as a variant member in Upsample::mode_t? This would make your suggestion possible. :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah that makes sense, thanks for the catch! One thing that's a bit worrying to me is that I am not sure if the conversion from UpsampleOptions::mode_t to InterpolateOptions::mode_t is always well-defined :( Maybe we can create a temporary variable and use if-branching to set the value to be passed into F::interpolate explicitly :D
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: The TORCH_ARG-macro carries the private: access modifier over to the next line so that any following non-TORCH_ARG declaration gets "secretly" privatized. In this case, it hides
typedef c10::variant<...> mode_t (and similarly other option variants such as the reduction_ts) from users.
I had to place the InterpolateOptions::mode_t type alias declaration before any TORCH_ARG to expose the type publicly. This allowed me to create a temporary InterpolateOptions::mode_t mode; and ergonomically make the necessary update based on your suggestion.
Just thought I'd mention it because it may be useful to let users have access to these types: modes, reductions, etc. :D. Otherwise, decltype seems to be the only way to get this type information. Let me know what you think! :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks so much for the catch! Yes I think we should establish the convention throughout the codebase that typedef c10::variant<...> mode_t should always happen in the first line of any Options that uses c10::variant (It is not enforced everywhere in the codebase right now, I opened a PR for PadOptions in #28760) :D
|
We might need to resolve the conflict with the changes in upstream |
yf225
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jon-tow Thanks so much for the awesome work! :D
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Based on the discussion in #28413 (comment), putting anything that's not tagged as `public:` under a `TORCH_ARG` line would hide it under `private:`. To get around this problem, we should move the `mode_t` declaration at the top of the PadOptions declaration. Pull Request resolved: #28760 Differential Revision: D18165117 Pulled By: yf225 fbshipit-source-id: cf39c0a893822264cd6a64cd887729afcd84dbd0
|
No problem, @yf225! Thanks for your time and advice. Glad to lend a hand :) |
Adds
interpolatefunctional andUpsamplemodule support for the C++ API.Issue: #25883
Reviewer: @yf225