softshrink nan fixes #138421
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138421
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
❌ 1 New Failure: as of commit 00f402a with merge base f4ee5a2, the following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Fixed it for the MPS device as well.
Added a test to check that NaN inputs produce NaN outputs for softshrink. I'd be happy to receive some feedback on this. It's my first time contributing, so any feedback is welcome.
cc @mikaylagawarecki do you know who would be a good reviewer?
Any updates on this?
Could you format with clang-format? The indentation differs from before, which makes it hard to identify the real changes.
I couldn't run clang-format; I got an error. But I fixed the indentation manually. If you could point me to how to get clang-format working (which version is needed, or if I'm missing something), I can run it.
Use clang-format 17 |
Why multiply?
nan * 0 -> nan. Otherwise 0
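The multiply trick relies on IEEE-754 semantics. A minimal sketch in plain Python (no torch) of why `x * 0` keeps NaN but zeroes out ordinary values:

```python
import math

# IEEE-754: NaN times anything (even 0) is NaN, while any finite
# value times 0 is 0. This is why multiplying by 0 in the fallback
# branch propagates NaN without affecting ordinary inputs.
print(math.nan * 0)              # nan
print(3.5 * 0)                   # 0.0
print(math.isnan(math.nan * 0))  # True
```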
I managed to run it and only ran it on
Any updates?
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
@cyyever Can you run the workflow?
bump
I guess we can merge? Is there anything else needed from me?
- self_val_t0 = (self_val > lambdVec) & (self_val - lambdVec);
- self_val_t1 = (self_val < -lambd_val) & (self_val + lambdVec);
+ self_val_t0 = ((self_val > lambdVec) | (self_val.isnan())) & (self_val - lambdVec);
+ self_val_t1 = ((self_val < -lambd_val) | (self_val.isnan())) & (self_val + lambdVec);
I don't think these changes can propagate NaN.
The code will propagate NaN values correctly. Previously, the comparison self_val > lambdVec always returned False when the input was NaN, because any comparison with NaN evaluates to False. This meant self_val - lambdVec wasn't propagating NaN values and the result instead defaulted to 0. The mask will now properly detect NaN inputs, allowing self_val - lambdVec to return NaN (since adding or subtracting with NaN results in NaN).
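To illustrate, here is a hypothetical scalar sketch in Python (not the actual SIMD kernel) of the patched branch selection, where the comparison mask is widened with an isnan() check:

```python
import math

def softshrink_scalar(x: float, lambd: float = 0.5) -> float:
    # Widening the ">" mask with isnan() routes NaN inputs into the
    # "x - lambd" branch; NaN minus anything is NaN, so it propagates.
    # Without the isnan() check, "nan > lambd" is False and the value
    # would fall through to the 0 branch, silently dropping the NaN.
    if x > lambd or math.isnan(x):
        return x - lambd
    if x < -lambd:
        return x + lambd
    return 0.0

print(softshrink_scalar(1.0))       # 0.5
print(softshrink_scalar(-1.0))      # -0.5
print(softshrink_scalar(0.2))       # 0.0
print(softshrink_scalar(math.nan))  # nan
```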
Your explanation makes sense. Do you know where the changes to the vectorized path are tested?
There are tests which compare the output of the compiled softshrink with the functional version (built from just basic ops). The functional version is here; I also changed it to use multiplication so that NaN values propagate:
pytorch/torch/_refs/nn/functional/__init__.py
Lines 501 to 516 in 643b337
@aten.softshrink.default.py_impl(DispatchKey.Autograd)
@register_decomposition(aten.softshrink)
@out_wrapper()
def softshrink(a: TensorLikeType, lambd: float = 0.5):
    # Formula for reference,
    # softshrink(x) = x - lambd if x > lambd
    #              = x + lambd if x < -lambd
    #              = 0 otherwise
    torch._check(
        lambd >= 0,
        lambda: f"lambda must be greater or equal to 0, but found to be {lambd}",
    )
    # We implement this in one torch.where to generate better code in the backward
    # see https://github.com/pytorch/pytorch/pull/107052#discussion_r1293748211
    # If none of the expressions pass we multiply by 0 for dealing with nans and infs
    return torch.where(torch.abs(a) > lambd, a - torch.sign(a) * lambd, a * 0)
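As a sanity check of that decomposition, here is a NumPy mirror (a sketch, assuming np.where/np.sign behave like their torch counterparts for this formula) showing the a * 0 fallback carrying NaN through:

```python
import numpy as np

def softshrink_ref(a: np.ndarray, lambd: float = 0.5) -> np.ndarray:
    # Mirror of the torch.where-based decomposition: where |a| > lambd
    # take a - sign(a) * lambd; otherwise a * 0, which is 0 for finite
    # values and NaN for NaN inputs, so NaN propagates.
    return np.where(np.abs(a) > lambd, a - np.sign(a) * lambd, a * 0)

x = np.array([1.5, -2.0, 0.1, np.nan])
print(softshrink_ref(x))  # [ 1.  -1.5  0.   nan]
```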
From what I understood from reading the code, the functional version is run with different input shapes as well as different devices and dtypes (selected by PYTORCH_OPINFO_SAMPLE_INPUT_INDEX). The shapes are defined in core.py of opinfo:
pytorch/torch/testing/_internal/opinfo/core.py
Lines 1987 to 1999 in 723498a
shapes = (
    # tensors with no elements
    (0,),
    (1, 0, 3),
    # zero dim (scalar) tensor
    (),
    # small 1D tensor
    (20,),
    # medium 1D tensor
    (812,),
    # large 2D tensor
    (1029, 917),
)
One such test is:
PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=14 PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_UBSAN=1 python test/test_ops.py TestCommonCPU.test_python_ref_torch_fallback__refs_nn_functional_softshrink_cpu_float32
or, for CUDA and float16 (just another example):
PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=14 python test/test_ops.py TestCommonCUDA.test_python_ref_torch_fallback__refs_nn_functional_softshrink_cuda_float16
Amazing, thank you for clarifying. And to double-check: these sample inputs have NaNs in them?
Yes, the UnaryUfuncInfo class has an argument handles_complex_extremal_values, which defaults to True; the extremal values are (from the comment on that same line):
# whether the op correctly handles extremal values (like nan/inf)
I added printing of the inputs while running this test, for example:
PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=14 PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_UBSAN=1 python test/test_ops.py TestCommonCPU.test_python_ref__refs_nn_functional_softshrink_cpu_float16
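A tiny hypothetical sketch (not the actual OpInfo machinery) of how extremal values like NaN/inf can be mixed into otherwise ordinary sample inputs for this kind of test:

```python
import math
import random

# Hypothetical illustration: mix extremal floats (nan/inf) into a list
# of ordinary sample values, similar in spirit to how extremal-value
# tests exercise ops with nan/inf alongside normal inputs.
extremal = [math.nan, math.inf, -math.inf]
values = [random.uniform(-2.0, 2.0) for _ in range(17)] + extremal
random.shuffle(values)

print(any(math.isnan(v) for v in values))  # True
print(any(math.isinf(v) for v in values))  # True
```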
Any other comments apart from the one above?
mikaylagawarecki left a comment:
Thanks!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed, the first few of them being: linux-binary-manywheel / manywheel-py3_9-cuda12_1-test / test. Details for the Dev Infra team: raised by workflow job.
Hmm, not sure what to do with this failing check: I don't think it was introduced by this PR 🤔
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 1 check: linux-binary-manywheel / manywheel-py3_9-cuda12_1-test / test. Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes pytorch#138385. Currently contains fixes for CPU and CUDA. Will add fixes for MPS as well soon if my Mac can build it from source. (Had some issues with building it on my Linux PC due to limited memory.) Pull Request resolved: pytorch#138421. Approved by: https://github.com/mikaylagawarecki

Fixes #138385 .
Currently contains fixes for cpu and cuda. Will add fixes to mps as well soon if my mac can build it from source.(Had some issues with building it on my linux pc due to limited memory)
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10