Description
Context
We would like to add the following two APIs to the C++ frontend:
- `torch::nn::MultiMarginLoss`, which is the equivalent of the Python API `torch.nn.MultiMarginLoss`.
- `torch::nn::functional::multi_margin_loss`, which is the equivalent of the Python API `torch.nn.functional.multi_margin_loss`.
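For reference, the Python documentation for `torch.nn.MultiMarginLoss` defines the unweighted loss for one sample, with a 1-D score vector $x$ and target class index $y$, as shown on the first line (the sum runs over all classes $i \neq y$). The second line works through the documentation's own example with the default `margin` = 1 and `p` = 1:

```math
\begin{aligned}
\text{loss}(x, y) &= \frac{1}{\text{x.size}(0)} \sum_{i \neq y} \max\bigl(0,\ \text{margin} - x[y] + x[i]\bigr)^{p} \\
\text{loss}\bigl((0.1, 0.2, 0.4, 0.8),\ 3\bigr) &= \tfrac{1}{4}\bigl[(1 - 0.8 + 0.1) + (1 - 0.8 + 0.2) + (1 - 0.8 + 0.4)\bigr] = \tfrac{1.3}{4} = 0.325
\end{aligned}
```

With the default `reduction`, these per-sample losses are averaged over the minibatch.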
Steps
- Add `torch::nn::MultiMarginLossOptions` to `torch/csrc/api/include/torch/nn/options/loss.h` (add this file if it doesn't exist), which should include the following parameters (based on https://pytorch.org/docs/stable/nn.html#torch.nn.MultiMarginLoss):
  - `TORCH_ARG(int64_t, p) = 1`
  - `TORCH_ARG(double, margin) = 1.0`
  - `TORCH_ARG(Tensor, weight) = Tensor()`
  - `TORCH_ARG(Reduction::Reduction, reduction) = Reduction::Mean`

  NOTE: please make sure to add the same comments for the parameters as in the Python version. For example, for parameter `p` we should say "Has a default value of 1. 1 and 2 are the only supported values."
- Add `torch::nn::functional::multi_margin_loss(...)` in `torch/csrc/api/include/torch/nn/functional/loss.h` (add this file if it doesn't exist). The function should have the following signature:
  ```cpp
  namespace torch {
  namespace nn {
  namespace functional {
  inline Tensor multi_margin_loss(
      const Tensor& input,
      const Tensor& target,
      const MultiMarginLossOptions& options = {}) {
    ...
  }
  } // namespace functional
  } // namespace nn
  } // namespace torch
  ```
- Add `torch::nn::MultiMarginLoss` in `torch/csrc/api/include/torch/nn/modules/loss.h` (add this file if it doesn't exist). The module's constructor should take `torch::nn::MultiMarginLossOptions` as input (and store it in its internal `options` field), the module's `reset()` method should register `weight` as a buffer (same logic as `self.register_buffer('weight', weight)` in `pytorch/torch/nn/modules/loss.py`, line 20 at commit 5cac738), and the module's `forward` method should have the following signature:
  ```cpp
  Tensor forward(
      const Tensor& input,
      const Tensor& target) {
    // Should call torch::nn::functional::multi_margin_loss with all inputs and the module's options
  }
  ```
- Add a test for `torch::nn::MultiMarginLoss` in `test/cpp/api/modules.cpp`. We can just check that the output values for some simple cases are as expected.
- Add a test for `torch::nn::functional::multi_margin_loss(...)` in `test/cpp/api/functional.cpp`, using essentially the same tests as for `torch::nn::MultiMarginLoss`.
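Expected values for the two test steps can be worked out by hand, or with a small reference computation like the hypothetical, libtorch-free sketch below. Everything in it is an assumption for illustration: `sketch::MultiMarginLossOptions`, `sketch::Reduction`, and `sketch::multi_margin_loss` are not part of the real `torch::nn` API, `std::vector<double>` stands in for `Tensor`, and the `weight` option is omitted. It follows the unweighted per-sample formula from the Python documentation (hinge terms over the non-target classes, each raised to `p`, divided by the number of classes) and assumes `Mean` averages the per-sample losses over the batch while `Sum` adds them.

```cpp
// Hypothetical, libtorch-free sketch for computing expected test values.
// None of these names are part of the real torch::nn API.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

namespace sketch {

// Stand-in for the Reduction enum used by the real options.
enum class Reduction { None, Mean, Sum };

// Stand-in for torch::nn::MultiMarginLossOptions. Setters return *this so
// calls can be chained, mirroring the builder style of TORCH_ARG options.
// The `weight` option is deliberately omitted from this sketch.
class MultiMarginLossOptions {
 public:
  // Has a default value of 1. 1 and 2 are the only supported values.
  MultiMarginLossOptions& p(std::int64_t value) { p_ = value; return *this; }
  std::int64_t p() const { return p_; }

  // Has a default value of 1.
  MultiMarginLossOptions& margin(double value) { margin_ = value; return *this; }
  double margin() const { return margin_; }

  // Has a default value of Mean.
  MultiMarginLossOptions& reduction(Reduction value) { reduction_ = value; return *this; }
  Reduction reduction() const { return reduction_; }

 private:
  std::int64_t p_ = 1;
  double margin_ = 1.0;
  Reduction reduction_ = Reduction::Mean;
};

// Unweighted loss for one sample: class scores `x` and target class index `y`.
inline double multi_margin_loss_row(const std::vector<double>& x, std::int64_t y,
                                    const MultiMarginLossOptions& options) {
  const std::int64_t p = options.p();
  if (p != 1 && p != 2) {
    throw std::invalid_argument("only p == 1 or p == 2 is supported");
  }
  assert(y >= 0 && static_cast<std::size_t>(y) < x.size());
  const double x_y = x[static_cast<std::size_t>(y)];
  double sum = 0.0;
  for (std::size_t i = 0; i < x.size(); ++i) {
    if (static_cast<std::int64_t>(i) == y) continue;  // skip the target class
    const double z = options.margin() - x_y + x[i];
    if (z > 0.0) sum += (p == 1) ? z : z * z;  // hinge term, raised to p
  }
  return sum / static_cast<double>(x.size());  // divide by the number of classes
}

// Batched loss: `input` holds N rows of class scores, `target` N class indices.
// Returns one value for Mean/Sum, or N per-sample values for None.
inline std::vector<double> multi_margin_loss(
    const std::vector<std::vector<double>>& input,
    const std::vector<std::int64_t>& target,
    const MultiMarginLossOptions& options = {}) {
  assert(input.size() == target.size());
  std::vector<double> losses;
  losses.reserve(input.size());
  for (std::size_t n = 0; n < input.size(); ++n) {
    losses.push_back(multi_margin_loss_row(input[n], target[n], options));
  }
  if (options.reduction() == Reduction::None) return losses;
  double total = 0.0;
  for (double loss : losses) total += loss;
  if (options.reduction() == Reduction::Mean && !losses.empty()) {
    total /= static_cast<double>(losses.size());
  }
  return {total};
}

}  // namespace sketch
```

For the single-sample example from the Python docs (input `{0.1, 0.2, 0.4, 0.8}`, target `3`, default options) this gives 0.325; with `.p(2)` it gives (0.09 + 0.16 + 0.36) / 4 = 0.1525. Hand-checkable values like these are what the module and functional tests could assert against.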
Helpful Resources
There are quite a few PRs for adding new functionals / new modules for the C++ API (the list of PRs is in #25883), which can serve as great references. Also please ping @yf225 on this issue if you encounter any problems.
How do I claim this feature request?
Please comment in this issue if you are interested in working on it.
cc @yf225