Adding MSELoss, KLDivLoss and BCELoss to C++ front-end #27156
Conversation
yf225 left a comment
@ShahriarSS Thanks a lot for the great work and my sincere apologies for the delay. I left some comments.
```cpp
// ============================================================================

/// Options for a KLDiv loss module.
using KLDivLossOptions = L1LossOptions;
```
I think we should define all the new loss options explicitly instead of aliasing L1LossOptions, because L1LossOptions could take a different default reduction value or get a new options arg one day, and we might not want those changes to affect the other loss options.
```cpp
    const Tensor& target,
    const MSELossOptions& options) {
  return torch::mse_loss(self, target, options.reduction());
}
```
It would be awesome to add tests for the new functionals as well :D
```cpp
// ============================================================================

/// Options for a BCE loss module.
using BCELossOptions = L1LossOptions;
```
I think BCELossOptions can take weight and reduction as constructor args, so it's not strictly equivalent to L1LossOptions.
```cpp
}

Tensor KLDivLossImpl::forward(const Tensor& input, const Tensor& target) {
  return torch::kl_div(input, target, options.reduction());
```
I think we should call F::kl_div here
```cpp
}

Tensor MSELossImpl::forward(const Tensor& input, const Tensor& target) {
  return torch::mse_loss(input, target, options.reduction());
```
ditto for F::mse_loss
```cpp
namespace functional {

inline Tensor l1_loss(
    const Tensor& self,
```
nit: input instead of self, to match Python version better :)
```cpp
}

inline Tensor kl_div(
    const Tensor& self,
```
ditto: self -> input
```cpp
}

inline Tensor mse_loss(
    const Tensor& self,
```
ditto: self -> input
```cpp
inline Tensor hinge_embedding_loss(
    const Tensor& x1,
    const Tensor& x2,
    const Tensor& self,
```
ditto: self -> input
```cpp
    const Tensor& x2,
    const Tensor& self,
    const Tensor& target,
    const HingeEmbeddingLossOptions& options) {
```
ditto for options = {}
Summary:
C++ API `Module::register_parameter` should accept undefined Tensor as parameter, which is equivalent to `module.register_parameter("param", None)` in Python API.
This unblocks #26082 and #27156.
Pull Request resolved: #27948
Differential Revision: D17931739
Pulled By: yf225
fbshipit-source-id: 21bdfc88e66e3dc39f3caf608a6a3de48c510fa9
@pytorchbot rebase this please
yf225 left a comment
@ShahriarSS Thanks a lot for the awesome work! I left some minor comments.
```cpp
/// Creates a criterion that measures the Binary Cross Entropy
/// between the target and the output.
struct TORCH_API BCELossImpl : Module {
```
I think all `Impl` classes need to subclass from `public Cloneable<ImplName>` and implement the `void reset() override` method, otherwise `module->clone()` won't work on them.
```cpp
BCELossOptions(
    Tensor weight = {},
    Reduction::Reduction reduction = Reduction::Mean)
    : weight_(weight), reduction_(reduction) {}
```
I think as a convention we should only provide a non-default constructor when the options has non-optional arguments or when the options has only one argument. For BCELossOptions I think we can follow the design of HingeEmbeddingLossOptions by providing defaults to weight and reduction and removing the non-default constructor.
```cpp
}

void BCELossImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::BCELoss";
```
It would be fantastic to add tests for the pretty_prints as well :D
To better match the Python version, we might need torch::nn::BCELoss()
```cpp
    : options(options_) {}

void KLDivLossImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::KLDivLoss";
```
To better match the Python version, we might need torch::nn::KLDivLoss()
```cpp
MSELossImpl::MSELossImpl(const MSELossOptions& options_) : options(options_) {}

void MSELossImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::MSELoss";
```
To better match the Python version, we might need torch::nn::MSELoss()
@yf225 Should we do anything in
```cpp
// ============================================================================

BCELossImpl::BCELossImpl(const BCELossOptions& options_) : options(options_) {
  register_parameter("weight", options.weight());
```
Thanks for the catch! Yes I think we should move this into reset()
Also it seems that it should be a buffer, not a parameter, based on the Python version:
```python
self.register_buffer('weight', weight)
```
yf225 left a comment
Thanks a lot for the awesome work @ShahriarSS!
facebook-github-bot left a comment
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot left a comment
@yf225 is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: This PR adds `MSELoss`, `KLDivLoss` and `BCELoss`. The tests for `BCELoss` fail with the following error:

```
unknown file: Failure
C++ exception with description "autograd_meta() INTERNAL ASSERT FAILED at /home/shahriar/Contrib/pytorch/c10/core/TensorImpl.h:533, please report a bug to PyTorch. set_requires_grad is not implemented for Tensor (set_requires_grad at /home/shahriar/Contrib/pytorch/c10/core/TensorImpl.h:533)
```

Pull Request resolved: pytorch#27156
Differential Revision: D17960323
Pulled By: yf225
fbshipit-source-id: 84b8431064f2f573679c03a8d7994e3e2f81a4d1