[C++ API parity] Multi-Label Soft Margin loss #27669
Conversation
yf225
left a comment
Thanks a lot for the fantastic work @CarMiranda! I left some very minor comments.
```cpp
// ============================================================================

MultiLabelSoftMarginLossImpl::MultiLabelSoftMarginLossImpl(
    const torch::nn::MultiLabelSoftMarginLossOptions& options_)
```
clang-tidy seems to complain about this line, and we can do the following to get around it:
Suggested change:
```diff
-    const torch::nn::MultiLabelSoftMarginLossOptions& options_)
+    const torch::nn::MultiLabelSoftMarginLossOptions& options_) // NOLINT(modernize-pass-by-value)
```
@pytorchbot rebase this please
```cpp
/// Creates a criterion that optimizes a multi-label one-versus-all
/// loss based on max-entropy, between input :math:`x` and target :math:`y` of size
/// :math:`(N, C)`.
struct TORCH_API MultiLabelSoftMarginLossImpl : Module {
```
My sincere apologies for missing this earlier: for all torch::nn layers we will need to subclass them from Cloneable in the following way:
Suggested change:
```diff
-struct TORCH_API MultiLabelSoftMarginLossImpl : Module {
+struct TORCH_API MultiLabelSoftMarginLossImpl : public torch::nn::Cloneable<MultiLabelSoftMarginLossImpl> {
```
otherwise module->clone() won't work on them. I think for all currently open PRs we should change this, and I will fix the modules that are already in the codebase (e.g. L1Loss).
Okay, no problem! Thanks for the review!
facebook-github-bot
left a comment
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: In accordance with pytorch#25883, I added the `MultiLabelSoftMarginLoss` module and `multilabel_soft_margin_loss` functional. It looks like there isn't a C++ ATen implementation of `multilabel_soft_margin_loss`, so I translated the Python version, which does not rely on a C/C++ backend either.

Pull Request resolved: pytorch#27669
Differential Revision: D17907608
Pulled By: yf225
fbshipit-source-id: ccb02951e009973c2adbe604593ce929f10c39eb
In accordance with #25883, I added the `MultiLabelSoftMarginLoss` module and `multilabel_soft_margin_loss` functional. It looks like there isn't a C++ ATen implementation of `multilabel_soft_margin_loss`, so I translated the Python version, which does not rely on a C/C++ backend either.