
Conversation

@lezcano
Collaborator

@lezcano lezcano commented Sep 7, 2021

Stack from ghstack:

Fixes #64256
It also fixes an inconsistent treatment of the case reduction = "mean"
when the whole target is equal to ignore_index. It now returns NaN
in this case, consistently with what it returns when computing the mean
over an empty tensor.

We add tests for all these cases.

Differential Revision: D31116297

cc @ezyang @gchanan
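
For illustration, a minimal sketch of the behaviour described above (the shapes and the ignore_index value are arbitrary, chosen only for this example):

import torch
import torch.nn.functional as F

log_probs = torch.randn(4, 3).log_softmax(dim=1)
target = torch.full((4,), 2, dtype=torch.long)   # every target equals ignore_index

# With this PR, an all-ignored batch under reduction="mean" yields NaN ...
print(F.nll_loss(log_probs, target, ignore_index=2, reduction="mean"))  # tensor(nan)

# ... consistently with the mean over an empty tensor.
print(torch.empty(0).mean())                                            # tensor(nan)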

@facebook-github-bot
Contributor

facebook-github-bot commented Sep 7, 2021


💊 CI failures summary and remediations

As of commit 3594ae9 (more details on the Dr. CI page):


  • 3/3 failures introduced in this PR

🕵️ 3 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_xenial_py3_clang7_asan_test2 (1/3)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

Sep 30 12:13:18 SUMMARY: UndefinedBehaviorSanit.../jenkins/workspace/aten/src/ATen/Utils.cpp:20:3 in
Sep 30 12:13:18     #9 0x557157cfb8f2 in PyEval_EvalCode /home/builder/ktietz/cos6/ci_cos6/python_1622833237666/work/Python/ceval.c:731
Sep 30 12:13:18     #10 0x557157d63cd5 in run_mod /home/builder/ktietz/cos6/ci_cos6/python_1622833237666/work/Python/pythonrun.c:1025
Sep 30 12:13:18     #11 0x557157d65d5d in PyRun_StringFlags /home/builder/ktietz/cos6/ci_cos6/python_1622833237666/work/Python/pythonrun.c:949
Sep 30 12:13:18     #12 0x557157d65dbb in PyRun_SimpleStringFlags /home/builder/ktietz/cos6/ci_cos6/python_1622833237666/work/Python/pythonrun.c:445
Sep 30 12:13:18     #13 0x557157d66926 in run_command /home/builder/ktietz/cos6/ci_cos6/python_1622833237666/work/Modules/main.c:301
Sep 30 12:13:18     #14 0x557157d66926 in Py_Main /home/builder/ktietz/cos6/ci_cos6/python_1622833237666/work/Modules/main.c:749
Sep 30 12:13:18     #15 0x557157ca0196 in main /home/builder/ktietz/cos6/ci_cos6/python_1622833237666/work/Programs/python.c:69
Sep 30 12:13:18     #16 0x7f18f42bc83f in __libc_start_main /build/glibc-S7Ft5T/glibc-2.23/csu/../csu/libc-start.c:291
Sep 30 12:13:18     #17 0x557157d3033d in _start (/opt/conda/bin/python3.6+0x1a733d)
Sep 30 12:13:18 
Sep 30 12:13:18 SUMMARY: UndefinedBehaviorSanitizer: undefined-behavior /var/lib/jenkins/workspace/aten/src/ATen/Utils.cpp:20:3 in 
Sep 30 12:13:18 + retcode=1
Sep 30 12:13:18 + set -e
Sep 30 12:13:18 + return 1
Sep 30 12:13:18 + [[ pytorch-linux-xenial-py3-clang7-asan-test2 == *-NO_AVX-* ]]
Sep 30 12:13:18 + [[ '' == \n\o\g\p\u\_\N\O\_\A\V\X ]]
Sep 30 12:13:18 + [[ pytorch-linux-xenial-py3-clang7-asan-test2 == *-NO_AVX2-* ]]
Sep 30 12:13:18 + [[ '' == \n\o\g\p\u\_\N\O\_\A\V\X\2 ]]
Sep 30 12:13:18 + [[ pytorch-linux-xenial-py3-clang7-asan-test2 == *-NO_AVX512-* ]]
Sep 30 12:13:18 + [[ '' == \n\o\g\p\u\_\N\O\_\A\V\X\5\1\2 ]]
Sep 30 12:13:18 ++ mktemp

See CircleCI build pytorch_xla_linux_bionic_py3_6_clang9_test (2/3)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

Sep 30 13:05:16 AssertionError: False is not tr...erence with rtol=0.001 and atol=0.001 is only nan!
Sep 30 13:05:16   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 371, in instantiated_test
Sep 30 13:05:16     result = test(self, **param_kwargs)
Sep 30 13:05:16   File "/var/lib/jenkins/workspace/xla/test/../../test/test_nn.py", line 17343, in test_nll_loss_total_weight_is_zero
Sep 30 13:05:16     helper([2, 3])
Sep 30 13:05:16   File "/var/lib/jenkins/workspace/xla/test/../../test/test_nn.py", line 17340, in helper
Sep 30 13:05:16     self.assertEqual(F.nll_loss(input, target, weight, reduction="mean").item(), float("nan"))
Sep 30 13:05:16   File "/var/lib/jenkins/workspace/xla/test/pytorch_test_base.py", line 607, in assertEqual
Sep 30 13:05:16     return DeviceTypeTestBase.assertEqual(self, x, y, *args, **kwargs)
Sep 30 13:05:16   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 1956, in assertEqual
Sep 30 13:05:16     super().assertTrue(result, msg=self._get_assert_msg(msg, debug_msg=debug_msg))
Sep 30 13:05:16 AssertionError: False is not true : Scalars failed to compare as equal! Comparing 0.0 and nan gives a difference of nan, but the allowed difference with rtol=0.001 and atol=0.001 is only nan!
Sep 30 13:05:16 
Sep 30 13:05:16 ----------------------------------------------------------------------
Sep 30 13:05:16 Ran 279 tests in 866.197s
Sep 30 13:05:16 
Sep 30 13:05:16 FAILED (failures=2, skipped=160)
Sep 30 13:05:16 
Sep 30 13:05:16 Generating XML reports...
Sep 30 13:05:16 Generated XML report: test-reports/python-unittest/test.......test.test_nn/TEST-TestNNDeviceTypeXLA-20210930125050.xml
Sep 30 13:05:16 + cleanup
Sep 30 13:05:16 + retcode=1

See GitHub Actions build linux-xenial-py3.6-gcc5.4 / test (backwards_compat, 1, 1, linux.2xlarge) (3/3)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2021-09-30T15:10:54.8026130Z The PR is introduc...m to confirm whether this change is wanted or not.
2021-09-30T15:10:54.8012147Z processing existing schema:  alltoall_base(__torch__.torch.classes.dist_c10d.ProcessGroup _0, Tensor _1, Tensor _2, int[] _3, int[] _4) -> (__torch__.torch.classes.dist_c10d.Work _0)
2021-09-30T15:10:54.8013571Z processing existing schema:  alltoall(__torch__.torch.classes.dist_c10d.ProcessGroup _0, Tensor[] _1, Tensor[] _2) -> (__torch__.torch.classes.dist_c10d.Work _0)
2021-09-30T15:10:54.8014939Z processing existing schema:  send(__torch__.torch.classes.dist_c10d.ProcessGroup _0, Tensor[] _1, int _2, int _3) -> (__torch__.torch.classes.dist_c10d.Work _0)
2021-09-30T15:10:54.8016450Z processing existing schema:  recv(__torch__.torch.classes.dist_c10d.ProcessGroup _0, Tensor[] _1, int _2, int _3) -> (__torch__.torch.classes.dist_c10d.Work _0)
2021-09-30T15:10:54.8017942Z processing existing schema:  recv_anysource(__torch__.torch.classes.dist_c10d.ProcessGroup _0, Tensor[] _1, int _2) -> (__torch__.torch.classes.dist_c10d.Work _0)
2021-09-30T15:10:54.8019263Z processing existing schema:  barrier(__torch__.torch.classes.dist_c10d.ProcessGroup _0) -> (__torch__.torch.classes.dist_c10d.Work _0)
2021-09-30T15:10:54.8020382Z processing existing schema:  __init__(__torch__.torch.classes.dist_c10d.frontend _0) -> (NoneType _0)
2021-09-30T15:10:54.8021863Z processing existing schema:  new_process_group_helper(__torch__.torch.classes.dist_c10d.frontend _0, int _1, int _2, int[] _3, str _4, __torch__.torch.classes.dist_c10d.Store _5, str? _6, int _7) -> (__torch__.torch.classes.dist_c10d.ProcessGroup _0)
2021-09-30T15:10:54.8023489Z processing existing schema:  get_process_group_by_name(__torch__.torch.classes.dist_c10d.frontend _0, str _1) -> (__torch__.torch.classes.dist_c10d.ProcessGroup _0)
2021-09-30T15:10:54.8024908Z processing existing schema:  get_name_of_process_group(__torch__.torch.classes.dist_c10d.frontend _0, __torch__.torch.classes.dist_c10d.ProcessGroup _1) -> (str _0)
2021-09-30T15:10:54.8026130Z The PR is introducing backward incompatible changes to the operator library. Please contact PyTorch team to confirm whether this change is wanted or not. 
2021-09-30T15:10:54.8026770Z 
2021-09-30T15:10:54.8027082Z Broken ops: [
2021-09-30T15:10:54.8028125Z 	aten::_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, int input_dtype) -> (Tensor)
2021-09-30T15:10:54.8029363Z 	aten::_log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, int input_dtype, *, Tensor(a!) out) -> (Tensor(a!))
2021-09-30T15:10:54.8030759Z 	aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, int input_dtype) -> (Tensor)
2021-09-30T15:10:54.8031853Z 	aten::_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, int input_dtype, *, Tensor(a!) grad_input) -> (Tensor(a!))
2021-09-30T15:10:54.8032437Z ]
2021-09-30T15:10:54.8032771Z + cleanup
2021-09-30T15:10:54.8033070Z + retcode=1
2021-09-30T15:10:54.8033343Z + set +x

1 job timed out:

  • pytorch_linux_xenial_py3_clang7_asan_test2


lezcano added a commit that referenced this pull request Sep 7, 2021
It also fixes an inconsistent treatment of the case `reduction = "mean"`
when the whole target is equal to `ignore_index`. It now returns `NaN`
in this case, consistently with what it returns when computing the mean
over an empty tensor.

We add tests for all these cases.

ghstack-source-id: a88c848
Pull Request resolved: #64572
@lezcano lezcano changed the title Fixes https://github.com/pytorch/pytorch/issues/64256 Fix nll_backward for negative weights Sep 7, 2021
@lezcano lezcano requested a review from jbschlosser September 7, 2021 16:14
@lezcano
Collaborator Author

lezcano commented Sep 8, 2021

This has a UBSan problem that I'll fix with an explicit instantiation of a quiet_NaN, and an XLA test is also failing, but it's basically ready for a review @jbschlosser

Contributor

@jbschlosser jbschlosser left a comment


Thanks for the fix! :)

Since the behavior here wrt reduction='mean' is different now, it's possible that some edge cases in the wild might break.

Example: a batch full of targets that have weight=0. Before, this would return 0, but now it will return NaN. According to the math in the docs, NaN is correct, but that math is contested in #61309, and returning NaN instead of 0 could be surprising for the same reasons that the way we calculate the mean is surprising to some.
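
(Illustrative only, not code from the PR: a hypothetical snippet of the zero-weight edge case described above.)

import torch
import torch.nn.functional as F

log_probs = torch.randn(2, 3).log_softmax(dim=1)
target = torch.tensor([1, 1])            # every sample hits the zero-weight class
weight = torch.tensor([1.0, 0.0, 1.0])

# reduction="mean" divides by the summed weights of the kept targets, which is 0 here:
# before this PR the result was 0.0, after it the result is NaN.
print(F.nll_loss(log_probs, target, weight=weight, reduction="mean"))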

Further, if we change the behavior, we'd need to make corresponding changes in XLA and deal with the UBSAN failure, as evidenced by the failing checks. These aren't hard to do; just requires some coordination with the XLA team and creating a NaN explicitly instead of doing 0/0.

However, rather than go through with all of this, I think I'd prefer to keep this PR to the minimal fix necessary to correctly handle negative weights for now. We can possibly address the BC-breaking parts in a future PR to keep things contained. Wdyt about this?

@lezcano
Collaborator Author

lezcano commented Sep 10, 2021

About the BC-breaking: I don't think this is BC-breaking so much as a bug fix. As you mentioned, this is the behaviour in the docs. What's more, at the moment the behaviour is inconsistent across sizes / devices, which is not desirable. I say we fix this.

@jbschlosser
Contributor

@lezcano Sounds good, I'm convinced, especially in light of the fact that the current behavior is inconsistent across sizes / devices. Let's move forward with the fix.

@JackCaoG FYI it looks like we need similar changes in XLA to get the tests passing.

@JackCaoG
Collaborator

Hi @jbschlosser, sounds good. @lezcano Feel free to make the change on the PyTorch side; the XLA CI in this PR should fail, and you can open an issue on the pt/xla side. We will follow up.

@jbschlosser
Contributor

Otherwise, just need the UBSAN fix and we're good to go :)

Fixes #64256
It also fixes an inconsistent treatment of the case `reduction = "mean"`
when the whole target is equal to `ignore_index`. It now returns `NaN`
in this case, consistently with what it returns when computing the mean
over an empty tensor.

We add tests for all these cases.

[ghstack-poisoned]
lezcano added a commit that referenced this pull request Sep 21, 2021
It also fixes an inconsistent treatment of the case `reduction = "mean"`
when the whole target is equal to `ignore_index`. It now returns `NaN`
in this case, consistently with what it returns when computing the mean
over an empty tensor.

We add tests for all these cases.

ghstack-source-id: 6b859cf
Pull Request resolved: #64572
@lezcano
Collaborator Author

lezcano commented Sep 21, 2021

@jbschlosser I fixed the UBSAN failures. Now the only CI failures remaining are those from XLA (expected) and a couple that complain about the base of this branch. This should be ready.

@JackCaoG
Collaborator

@lezcano The pt/xla CI failure seems to be related to a merge conflict. Any chance you can rebase and surface the true error? If you could let me know which test you expect to fail, I might be able to repro the failure myself too.

Fixes #64256
It also fixes an inconsistent treatment of the case `reduction = "mean"`
when the whole target is equal to `ignore_index`. It now returns `NaN`
in this case, consistently with what it returns when computing the mean
over an empty tensor.

We add tests for all these cases.

[ghstack-poisoned]
lezcano added a commit that referenced this pull request Sep 21, 2021
It also fixes an inconsistent treatment of the case `reduction = "mean"`
when the whole target is equal to `ignore_index`. It now returns `NaN`
in this case, consistently with what it returns when computing the mean
over an empty tensor.

We add tests for all these cases.

ghstack-source-id: e20651a
Pull Request resolved: #64572
@lezcano
Collaborator Author

lezcano commented Sep 21, 2021

@JackCaoG rebased.

@JackCaoG
Collaborator

@lezcano The pt/xla test passed; is the new behavior not being tested by the existing test?

@lezcano
Collaborator Author

lezcano commented Sep 22, 2021

@JackCaoG This PR adds new tests, and note that they are failing in an XLA run now.
@jbschlosser all the other runs are passing, this should be ready to go.

Contributor

@jbschlosser jbschlosser left a comment


LGTM! Thanks :)

@JackCaoG
Collaborator

@lezcano Thanks, I will take care of it.

@jbschlosser
Contributor

@jbschlosser has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@JackCaoG
Collaborator

@jbschlosser I will update here when the xla change is ready

@jbschlosser
Contributor

> About the BC-breaking: I don't think this is BC-breaking so much as a bug fix. As you mentioned, this is the behaviour in the docs. What's more, at the moment the behaviour is inconsistent across sizes / devices, which is not desirable. I say we fix this.

Hm, I'm now running into internal tests that explicitly depend on the old behavior of returning 0.0 for the "all ignored" case, which makes me hesitate again about changing this behavior. This is for an internal loss function based on NLLLoss that explicitly claims to support missing labels - returning NaNs here doesn't seem nice for this case.

If PyTorch had a nice way to indicate "no gradients" when computing the mean of an empty set, we wouldn't have to choose between NaN (breaks training) and 0.0 (allows training to continue, but not mathematically correct and may have undesirable side effects).

cc @cpuhrsch: FYI NLLLoss / CrossEntropyLoss are a couple more cases where we allow ad-hoc "masking" via ignore_index.

I do agree it's undesirable to have inconsistency across sizes / devices - do you mind pointing out the current discrepancies and maybe we could at least fix those?

@lezcano
Collaborator Author

lezcano commented Sep 23, 2021

Note that it supports missing labels, but if you give it something with no labels at all, then the reduction should return NaN, as the mean does not have an identity element.

Also, having "no gradients" would not solve this, as the NaNs appear in the forward. The backward may return whatever, as the gradient at NaN is not defined.

@jbschlosser jbschlosser added the module: bc-breaking (Related to a BC-breaking change) label Sep 23, 2021
lezcano added a commit that referenced this pull request Sep 29, 2021
It also fixes an inconsistent treatment of the case `reduction = "mean"`
when the whole target is equal to `ignore_index`. It now returns `NaN`
in this case, consistently with what it returns when computing the mean
over an empty tensor.

We add tests for all these cases.

ghstack-source-id: f154576
Pull Request resolved: #64572
@albanD
Collaborator

albanD commented Sep 29, 2021

@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@albanD
Collaborator

albanD commented Sep 29, 2021

The xla test failure seems real. Do we need a sister PR on the xla repo?

@JackCaoG
Collaborator

@albanD I will work on a fix. Already have a prototype, just need to fix some additional failure on the xla end.

@albanD
Collaborator

albanD commented Sep 29, 2021

I'm not very familiar with the joint landing process.
Does that mean that this one can be landed now (or soon, when internal tests are done)? Or should I wait for your green light?

@JackCaoG
Collaborator

@albanD Please give me 1-2 days; I will update here when the XLA PR is ready. After that, I will merge the XLA PR once the PyTorch PR is merged.

@JackCaoG
Collaborator

JackCaoG commented Sep 29, 2021

@albanD @lezcano For nll_loss, should the grad remain 0 if the forward pass returns NaN? Here is an example:

import torch
import torch.nn as nn

device = "cpu"  # illustrative; use an XLA device to reproduce the second output below

m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss(ignore_index=1)

input = torch.randn(5, 2, requires_grad=True, device=device)
target = torch.ones(5, device=device).to(torch.int64)
output = loss(m(input), target)
print(output)
output.backward()
print(input.grad)

after my change, CPU returns

tensor(nan, grad_fn=<NllLossBackward0>)
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]])

and xla returns

tensor(nan, device='xla:0', grad_fn=<NllLossBackward0>)
tensor([[nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan],
        [nan, nan]], device='xla:0')

I can update the xla behavior but want to make sure this is intended.

@lezcano
Collaborator Author

lezcano commented Sep 30, 2021

I would say that a grad of zero is fine, as it is the fastest one to return. The function is not differentiable at these points (it's not even well-defined) so we are free to return anything we want. As such, we can choose to return whatever's fastest for the computations.

In this case, we can return zeros if all the elements in target were ignored, and inf or nan if the sum of the weights was zero (inf if the grad_output was not zero). I'll push an update making sure that this is what we do :)
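
For reference, a rough Python sketch of the gradient convention described in this comment; this is illustrative pseudocode for the intended semantics, not the actual ATen kernel:

import torch

def nll_loss_mean_grad_sketch(grad_output, log_probs, target, weight, ignore_index):
    # Gradient of nll_loss(..., reduction="mean") w.r.t. log_probs, following the
    # convention above: zeros if every target is ignored; inf or nan (via division
    # by zero) if the summed weight of the kept targets is zero.
    grad = torch.zeros_like(log_probs)
    kept = target != ignore_index
    if not kept.any():
        return grad                            # all targets ignored: cheap zero gradient
    w = weight[target] * kept                  # per-sample weights, 0 where ignored
    total_weight = w.sum()                     # can be 0, e.g. with negative weights
    rows = kept.nonzero(as_tuple=True)[0]
    grad[rows, target[kept]] = -w[kept] * grad_output / total_weight
    return grad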

Fixes #64256
It also fixes an inconsistent treatment of the case `reduction = "mean"`
when the whole target is equal to `ignore_index`. It now returns `NaN`
in this case, consistently with what it returns when computing the mean
over an empty tensor.

We add tests for all these cases.

Differential Revision: [D31116297](https://our.internmc.facebook.com/intern/diff/D31116297)

cc ezyang gchanan

[ghstack-poisoned]
lezcano added a commit that referenced this pull request Sep 30, 2021
It also fixes an inconsistent treatment of the case `reduction = "mean"`
when the whole target is equal to `ignore_index`. It now returns `NaN`
in this case, consistently with what it returns when computing the mean
over an empty tensor.

We add tests for all these cases.

ghstack-source-id: dcb4907
Pull Request resolved: #64572
@lezcano
Collaborator Author

lezcano commented Sep 30, 2021

For what it's worth, I have updated the code for the derivatives to make it more homogeneous. It should now be obvious that all the paths in both CPU and CUDA return what was described in #64572 (comment).

@albanD
Collaborator

albanD commented Sep 30, 2021

Thanks!

@albanD
Collaborator

albanD commented Sep 30, 2021

@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@albanD
Collaborator

albanD commented Sep 30, 2021

@JackCaoG ok all the internal tests are good. So ready on my end. Let me know when it's good for you!

@JackCaoG
Collaborator

JackCaoG commented Oct 1, 2021

@albanD The XLA PR is ready at pytorch/xla#3144 (review). I will merge it after this PR is merged.

@albanD
Collaborator

albanD commented Oct 1, 2021

Thanks!
Just clicked land.

@facebook-github-bot facebook-github-bot deleted the gh/Lezcano/17/head branch October 5, 2021 14:17
wconstab added a commit that referenced this pull request Oct 12, 2021
Pytorch nll behavior has changed - now, where the sum of the weight
values is zero, the output will be NaN.  Our LTC tests must be updated
accordingly.

- pytorch/pytorch changed NLL behavior
  (#64572)
- pytorch/xla introduced a corresponding change to test utils
  (pytorch/xla#3144)
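
(A hypothetical sketch of the kind of test-expectation update implied by this commit; names and shapes are illustrative, not the actual LTC test code.)

import math
import torch
import torch.nn.functional as F

def expect_nan_when_total_weight_is_zero(device="cpu"):
    log_probs = torch.randn(2, 3, device=device).log_softmax(dim=1)
    target = torch.tensor([1, 1], device=device)
    weight = torch.tensor([1.0, 0.0, 1.0], device=device)
    loss = F.nll_loss(log_probs, target, weight=weight, reduction="mean")
    assert math.isnan(loss.item())             # the old expectation was 0.0
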
pytorchmergebot pushed a commit that referenced this pull request Sep 10, 2024
This PR allows results from `nn_loss` to be `nan`, which is the same behavior as with CUDA and CPU #64572 (comment).

Fixes #134431

Ref #64572 #119108
Pull Request resolved: #135434
Approved by: https://github.com/malfet
yushangdi pushed a commit that referenced this pull request Sep 12, 2024
This PR allows results from `nn_loss` to be `nan`, which is the same behavior as with CUDA and CPU #64572 (comment).

Fixes #134431

Ref #64572 #119108
Pull Request resolved: #135434
Approved by: https://github.com/malfet
tolleybot pushed a commit to tolleybot/pytorch that referenced this pull request Sep 14, 2024
This PR allows results from `nn_loss` to be `nan`, which is the same behavior as with CUDA and CPU pytorch#64572 (comment).

Fixes pytorch#134431

Ref pytorch#64572 pytorch#119108
Pull Request resolved: pytorch#135434
Approved by: https://github.com/malfet
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
This PR allows results from `nn_loss` to be `nan`, which is the same behavior as with CUDA and CPU pytorch#64572 (comment).

Fixes pytorch#134431

Ref pytorch#64572 pytorch#119108
Pull Request resolved: pytorch#135434
Approved by: https://github.com/malfet