-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Implement C++ API torch::nn::MultiMarginLoss. #27424
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
yf225
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@CarMiranda Thanks so much for the great work! I left some very minor comments, and we should be able to merge it very soon :D
yf225
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for the changes @CarMiranda! It seems that there is test failure for the newly added tests, we might need to debug it locally to see what causes the test failure.
|
@CarMiranda We can use these steps to run the tests:
|
|
It looks like the errors came from precision, I just replaced 0.3056 by 0.305556 and that did the trick! |
yf225
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for the awesome work @CarMiranda! I will merge it after CI passes :D
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yf225 is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Hi yf225 , here is the C++ frontend API MultiMarginLoss implementation and tests pytorch#27198. Could you review it and tell me if it is okay? I am not entirely sure I used `c10::optional` correctly, but `options.weight()` resulted in a compilation error, so I went with `options.weight().value()` instead of `value_or()` to follow the logic in `torch.nn._WeightedLoss.register_buffer` (where one can pass a `None` value). Oh, and are the tests supposed to be skipped or did I do something wrong? I ran `pytest test/test_cpp_api_parity.py -k Loss -v` , and the `L1Loss` test passed but the others were skipped... Thank you for the review in any case! Pull Request resolved: pytorch#27424 Differential Revision: D17839963 Pulled By: yf225 fbshipit-source-id: f4b6012590cf22d56d42751c214df80cce717cb8
Hi @yf225 , here is the C++ frontend API MultiMarginLoss implementation and tests #27198. Could you review it and tell me if it is okay?
I am not entirely sure I used
c10::optionalcorrectly, butoptions.weight()resulted in a compilation error, so I went withoptions.weight().value()instead ofvalue_or()to follow the logic intorch.nn._WeightedLoss.register_buffer(where one can pass aNonevalue).Oh, and are the tests supposed to be skipped or did I do something wrong? I ran
pytest test/test_cpp_api_parity.py -k Loss -v, and theL1Losstest passed but the others were skipped...Thank you for the review in any case!