[C++ API parity] Smooth L1 loss #27661
Conversation
yf225
left a comment
Thanks so much for the awesome work @CarMiranda! I left some minor comments.
/// 'none': no reduction will be applied, 'mean': the sum of the output will
/// be divided by the number of elements in the output, 'sum': the output will
/// be summed. Default: 'mean'
TORCH_ARG(Reduction::Reduction, reduction) = Reduction::Mean;
For options that take only one argument, it would be awesome to use a design similar to https://github.com/pytorch/pytorch/pull/27435/files#diff-1e16803d96151e8f6eb7a327394b6584R88-R94 (i.e. adding an implicit constructor and using the constructor to set the default for reduction).
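A minimal sketch of the suggested pattern (assuming the TORCH_ARG macro declares the underlying reduction_ member, as in the hunk above):

struct SmoothL1LossOptions {
  /* implicit */ SmoothL1LossOptions(
      Reduction::Reduction reduction = Reduction::Mean)
      : reduction_(reduction) {}

  TORCH_ARG(Reduction::Reduction, reduction);
};

With this, an enum value converts to the options type directly, e.g. SmoothL1Loss loss(/*reduction=*/Reduction::None);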
inline Tensor smooth_l1_loss(
    const Tensor& input,
    const Tensor& target,
    const SmoothL1LossOptions& options) {
It would be awesome to give options a default value of {} (and add a test case for the "no options" use case), similar to the other PRs :D
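A sketch of the suggested change (signature taken from the hunk above; the body is assumed for illustration, only the = {} default is new):

inline Tensor smooth_l1_loss(
    const Tensor& input,
    const Tensor& target,
    const SmoothL1LossOptions& options = {}) {
  return torch::smooth_l1_loss(input, target, options.reduction());
}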
}

Tensor SmoothL1LossImpl::forward(const Tensor& input, const Tensor& target) {
  return torch::smooth_l1_loss(input, target, options.reduction());
It would be awesome to call the functional form to match the Python version even better:
-  return torch::smooth_l1_loss(input, target, options.reduction());
+  return F::smooth_l1_loss(input, target, options);
}
} else {
  std::vector<Tensor> expanded_tensors = torch::broadcast_tensors({input, target});
  ret = torch::smooth_l1_loss(expanded_tensors[0], expanded_tensors[1], options.reduction());
nit: it would be awesome to align the indentation with the first branch of the if block :D
const Tensor& input,
const Tensor& target,
const SmoothL1LossOptions& options) {

nit: we could remove the newline here
const torch::nn::SmoothL1LossOptions& options_) : options(options_) {}

void SmoothL1LossImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::SmoothL1Loss(reduction=" << options.reduction() << ")";
It would be awesome to add a test for this as well :D
As part of the pretty_print test, we can also test that SmoothL1Loss(/*reduction=*/Reduction::Mean) (i.e. the implicit constructor of SmoothL1LossOptions) works :D
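A hypothetical sketch of such a test (names are illustrative; the expected string assumes the operator<< above streams the raw enum value, with Reduction::Mean == 1):

TEST_F(ModulesTest, PrettyPrintSmoothL1Loss) {
  ASSERT_EQ(
      c10::str(SmoothL1Loss(/*reduction=*/Reduction::Mean)),
      "torch::nn::SmoothL1Loss(reduction=1)");
}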
EDIT: Actually, I am just now realizing (while working with the python API) that reductions are never printed... So I guess I will just remove this.
Hi @yf225, first of all, thanks for the reviews!
I had deliberately left the pretty_print tests out, since there were no such tests for L1Loss, which has the same single parameter (reduction) as SmoothL1Loss, MultiMarginLoss and SoftMarginLoss. I guessed this was due to #25883, where you pointed out:
there are still a lot of preparation work that needs to be done (e.g. fixing how we handle enums in C++ API #15149).
and these losses use the Reduction::Reduction enum... Should I use a switch statement like the one below?
void SmoothL1LossImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::SmoothL1Loss(reduction=";
  switch (options.reduction()) {
    case Reduction::None:
      stream << "Reduction::None";
      break;
    case Reduction::Mean:
      stream << "Reduction::Mean";
      break;
    case Reduction::Sum:
      stream << "Reduction::Sum";
      break;
  }
  stream << ")";
}
test/cpp/api/functional.cpp (outdated)
auto input = torch::tensor({{0.1}, {1.2}, {4.7}}, torch::requires_grad());
auto target = torch::tensor({{0.}, {1.}, {5.}}, torch::kFloat);
auto output =
    F::smooth_l1_loss(input, target, SmoothL1LossOptions());
We can remove SmoothL1LossOptions() in this line to test the "no options" use case :D
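i.e. (assuming the defaulted options argument suggested earlier), the call would become:

auto output = F::smooth_l1_loss(input, target);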
test/cpp/api/modules.cpp (outdated)
}

TEST_F(ModulesTest, SmoothL1LossNoReduction) {
  SmoothL1Loss loss(SmoothL1LossOptions().reduction(Reduction::None));
We could write this as
-  SmoothL1Loss loss(SmoothL1LossOptions().reduction(Reduction::None));
+  SmoothL1Loss loss(/*reduction=*/Reduction::None);
to also test the implicit conversion for SmoothL1LossOptions :D
Hi @yf225! Taking your previous reviews into account, I wrote these four tests for each single-optional-parameter loss module/functional pair (see the sketch after this list):
- explicit instantiation (in modules.cpp, *DefaultOptions)
- explicit instantiation + changing the parameter value (in modules.cpp, *NoReduction, this test case)
- omitting the optional parameter in the functional (in functional.cpp, *DefaultOptions)
- implicit instantiation as a parameter in the functional (in functional.cpp, *NoReduction)
Do you see any other test that should be run for these modules/functionals and their options? (I ask this here because you already corrected the others hehe)
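A hypothetical condensed sketch of those four test shapes (tensor values and assertion bodies are illustrative, not copied from the PR):

TEST_F(ModulesTest, SmoothL1LossDefaultOptions) {
  SmoothL1Loss loss{SmoothL1LossOptions()};  // explicit instantiation
}

TEST_F(ModulesTest, SmoothL1LossNoReduction) {
  // explicit instantiation + non-default parameter value
  SmoothL1Loss loss(SmoothL1LossOptions().reduction(Reduction::None));
}

TEST_F(FunctionalTest, SmoothL1LossDefaultOptions) {
  auto input = torch::ones({3}, torch::requires_grad());
  auto target = torch::zeros({3});
  auto output = F::smooth_l1_loss(input, target);  // optional parameter omitted
}

TEST_F(FunctionalTest, SmoothL1LossNoReduction) {
  auto input = torch::ones({3}, torch::requires_grad());
  auto target = torch::zeros({3});
  // implicit SmoothL1LossOptions construction from the enum value
  auto output = F::smooth_l1_loss(input, target, /*options=*/Reduction::None);
}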
yf225
left a comment
@CarMiranda Thanks so much for the great work! I think the tests we have now are perfect, and I will merge this after CI passes :D
facebook-github-bot
left a comment
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
I will rebase this on top of #28523 after it's merged, to get around the gcc warning problem.
facebook-github-bot
left a comment
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@CarMiranda I'm wondering, are you currently working on the RNNCell / LSTMCell / GRUCell modules? Thanks a lot for your help!
In accordance with #25883, I added the SmoothL1Loss module and smooth_l1_loss functional.
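For reference, a minimal usage sketch of the new API (values illustrative; assumes the default-options overloads discussed in the review, and that the Reduction enum is reachable as torch::Reduction):

#include <torch/torch.h>

int main() {
  namespace F = torch::nn::functional;

  auto input = torch::tensor({0.1, 1.2, 4.7}, torch::requires_grad());
  auto target = torch::tensor({0.0, 1.0, 5.0});

  // Module form: default options (reduction = Reduction::Mean).
  torch::nn::SmoothL1Loss loss{torch::nn::SmoothL1LossOptions()};
  auto mean_loss = loss(input, target);

  // Functional form, overriding the reduction via the implicit
  // SmoothL1LossOptions conversion discussed above.
  auto elementwise_loss = F::smooth_l1_loss(
      input, target, /*options=*/torch::Reduction::None);

  mean_loss.backward();
}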