add OpInfo for torch.nn.functional.nll_loss
#63854
Conversation
CI summary (Dr. CI): as of commit 93bd050, there are no failures yet.
Failures on CUDA look real: failures also happen in slow mode, i.e. setting …
@pmeier what does the error look like with …?
The same as the middle part of the message above, although here the differences are visible before the 5th decimal place.
I think the failures are due to non-determinism that stems from reducing the output to a scalar. Setting …
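A torch-free sketch (with hypothetical values, not taken from the PR) of why reducing to a scalar can vary between runs: floating-point addition is not associative, so a parallel CUDA reduction that sums the per-element losses in a different order than the reference implementation can produce a slightly different scalar, which a strict comparison or gradcheck then flags.

```python
# Floating-point summation is order-dependent: the same values summed in a
# different order (as a parallel reduction may do) give a different result.
vals = [0.1, 1e16, -1e16, 0.1]

# Sequential left-to-right sum: 0.1 is absorbed by 1e16, then cancelled.
left_to_right = sum(vals)

# Reordered sum, pairing the large magnitudes first.
reordered = (vals[1] + vals[2]) + (vals[0] + vals[3])

# The two reduction orders disagree even though the inputs are identical.
assert left_to_right != reordered
print(left_to_right, reordered)
```

The gap here is deliberately exaggerated; in the PR's failures the discrepancies are small (visible only around the 5th decimal place on CUDA), but the mechanism is the same.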
```python
# (shape_2d, dict()),
# ((*shape_2d, 3, 3), dict()),
# (shape_2d, dict(weight=True)),
# (shape_2d, dict(ignore_index=1)),
# (shape_2d, dict(reduction="mean")),
# (shape_2d, dict(reduction="sum")),
```
Enabling any of these sample inputs leads to gradcheck failures. They all have in common that `reduction="mean"` is the default value and thus a reduction is performed; `reduction="none"` uses a different code path and works fine. cc @albanD
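For context, a miniature torch-free sketch of what `torch.autograd.gradcheck` does for `nll_loss` with `reduction="mean"`: compare an analytical gradient against central finite differences. All names and values below are illustrative, not from the PR.

```python
def nll_mean(x, targets):
    """NLL loss with mean reduction over log-probabilities x (list of rows)."""
    n = len(targets)
    return -sum(x[i][t] for i, t in enumerate(targets)) / n

def analytical_grad(x, targets):
    """d(loss)/dx: -1/N at each (i, target_i), zero elsewhere."""
    n = len(targets)
    g = [[0.0] * len(row) for row in x]
    for i, t in enumerate(targets):
        g[i][t] = -1.0 / n
    return g

def numerical_grad(x, targets, eps=1e-6):
    """Central finite differences, perturbing one entry at a time."""
    g = [[0.0] * len(row) for row in x]
    for i, row in enumerate(x):
        for c in range(len(row)):
            orig = row[c]
            row[c] = orig + eps
            hi = nll_mean(x, targets)
            row[c] = orig - eps
            lo = nll_mean(x, targets)
            row[c] = orig
            g[i][c] = (hi - lo) / (2 * eps)
    return g

x = [[-0.2, -1.7], [-2.3, -0.1]]  # pretend log-probabilities, 2 samples x 2 classes
targets = [0, 1]
a = analytical_grad(x, targets)
n = numerical_grad(x, targets)
assert all(abs(a[i][c] - n[i][c]) < 1e-5 for i in range(2) for c in range(2))
```

When the reduction is non-deterministic across runs, the numerical side of this comparison is evaluated on slightly different values each time, which is one way such a check can fail even though the analytical formula is correct.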
Codecov Report

```
@@ Coverage Diff @@
## master   #63854   +/- ##
==========================================
- Coverage   66.85%   66.84%   -0.01%
==========================================
  Files         695      695
  Lines       90722    90748     +26
==========================================
+ Hits        60649    60664     +15
- Misses      30073    30084     +11
```
Closed in favor of #64203.
Addresses pytorch/functorch#78.
cc @albanD @mruberry @jbschlosser @VitalyFedyunin @walterddr