-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Implement C++ API torch::nn::TripletMarginLoss #27525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
yf225
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@PyExtreme Thanks so much for the awesome work! I left some minor comments.
| void TripletMarginLossImpl::pretty_print(std::ostream& stream) const { | ||
| stream << "torch::nn::TripletMarginLoss(p=" << options.p() << | ||
| ", margin=" << options.margin() << | ||
| ", weight=" << options.swap() << |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mean to write swap= instead of weight=? (We will need std::boolalpha for printing boolean value in C++.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also it would be awesome to print in margin=..., p=..., eps=..., swap=..., reduction=... order :D
| stream << "torch::nn::TripletMarginLoss(p=" << options.p() << | ||
| ", margin=" << options.margin() << | ||
| ", weight=" << options.swap() << | ||
| ", reduction=" << options.reduction() << ")"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be awesome to add a test for the pretty_print result :D
| const Tensor& positive, | ||
| const Tensor& negative); | ||
|
|
||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: can remove one extra line here
| const Tensor& anchor, | ||
| const Tensor& positive, | ||
| const Tensor& negative, | ||
| const TripletMarginLossOptions& options) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since all TripletMarginLossOptions's arguments are optional, we can provide an empty default value for options here:
| const TripletMarginLossOptions& options) { | |
| const TripletMarginLossOptions& options = {}) { |
to support the triplet_margin_loss(anchor, positive, negative) use case. (We should add a test case for that as well.)
Hi @yf225 , Here is the C++ frontend API TripletMarginLoss implementation and tests #27197 . Could you please review it?
Secondly, the tests got skipped. I ran pytest test/test_cpp_api_parity.py -k Loss -v , and the L1Loss test passed but the others were skipped...
Thanks