
Conversation

@emcastillo
Collaborator

@emcastillo emcastillo commented Dec 12, 2019

Fixes an issue with `cdist` backward calculation for large inputs in the Euclidean case.

The grid size used when launching the kernel exceeded the 2^16 limit for the second grid dimension, resulting in `RuntimeError: CUDA error: invalid configuration argument`.

Code to reproduce:

```
h, w, d = 800, 1216, 12
n = 133
A = torch.randn(n, d).cuda()
B = torch.randn(h, w, d).cuda()
A.requires_grad = True
B.requires_grad = True

B = B.reshape(-1, d).contiguous()
dist = torch.cdist(A, B)
loss = dist.sum()
loss.backward()
```
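For intuition, here is a rough back-of-the-envelope check of why this configuration trips the limit (a sketch only: it assumes the backward kernel maps the rows of B onto the second grid dimension, which simplifies the actual launch logic):

```
# Hypothetical sanity check, not the actual kernel-launch code.
h, w = 800, 1216
rows_of_B = h * w              # 972,800 rows after the reshape
cuda_grid_dim_limit = 65535    # maximum size of gridDim.y / gridDim.z
print(rows_of_B > cuda_grid_dim_limit)  # True -> "invalid configuration argument"
```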

Thanks to @tkerola for the bug report, the reproduction, and for suggesting a solution.

@emcastillo emcastillo changed the title from "Change cdist kernel grid to avoid CUDA error" to "Change cdist kernel grid parameter to avoid CUDA invalid configuration error" Dec 12, 2019
@tkerola

tkerola commented Dec 12, 2019

I think this will solve #27209 as well.

@tkerola tkerola left a comment

Just some small comments.

@kostmo
Member

kostmo commented Dec 12, 2019

💊 CircleCI build failures summary and remediations

As of commit 7ab0f56:

  • 2/2 failures introduced in this PR

Detailed failure analysis

One may explore the probable reasons each build failed interactively on the Dr. CI website.

🕵️ 2 new failures recognized by patterns

The following build failures do not appear to be due to upstream breakage:

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_test (1/2)

Step: "Test" (full log | pattern match details)

Feb 25 03:09:12 RuntimeError: test_quantization failed!
Feb 25 03:09:12 Ran 36 tests in 57.272s 
Feb 25 03:09:12  
Feb 25 03:09:12 FAILED (errors=1, skipped=1) 
Feb 25 03:09:12  
Feb 25 03:09:12 Generating XML reports... 
Feb 25 03:09:12 Traceback (most recent call last): 
Feb 25 03:09:12   File "test/run_test.py", line 486, in <module> 
Feb 25 03:09:12     main() 
Feb 25 03:09:12   File "test/run_test.py", line 479, in main 
Feb 25 03:09:12     raise RuntimeError(message) 
Feb 25 03:09:12 RuntimeError: test_quantization failed! 
Feb 25 03:09:12 + cleanup 
Feb 25 03:09:12 + retcode=1 
Feb 25 03:09:12 + set +x 
Feb 25 03:09:12 =================== sccache compilation log =================== 
Feb 25 03:09:12 =========== If your build fails, please take a look at the log above for possible reasons =========== 
Feb 25 03:09:12 Compile requests                  7 
Feb 25 03:09:12 Compile requests executed         6 
Feb 25 03:09:12 Cache hits                        0 
Feb 25 03:09:12 Cache misses                      6 
Feb 25 03:09:12 Cache timeouts                    0 

See CircleCI build pytorch_linux_xenial_cuda10_1_cudnn7_py3_NO_AVX_NO_AVX2_test (2/2)

Step: "Test" (full log | pattern match details)

Feb 25 04:56:42 RuntimeError: test_quantization failed!
Feb 25 04:56:42 Ran 36 tests in 59.569s 
Feb 25 04:56:42  
Feb 25 04:56:42 FAILED (errors=1, skipped=1) 
Feb 25 04:56:42  
Feb 25 04:56:42 Generating XML reports... 
Feb 25 04:56:42 Traceback (most recent call last): 
Feb 25 04:56:42   File "test/run_test.py", line 486, in <module> 
Feb 25 04:56:42     main() 
Feb 25 04:56:42   File "test/run_test.py", line 479, in main 
Feb 25 04:56:42     raise RuntimeError(message) 
Feb 25 04:56:42 RuntimeError: test_quantization failed! 
Feb 25 04:56:43 + cleanup 
Feb 25 04:56:43 + retcode=1 
Feb 25 04:56:43 + set +x 
Feb 25 04:56:43 =================== sccache compilation log =================== 
Feb 25 04:56:43 =========== If your build fails, please take a look at the log above for possible reasons =========== 
Feb 25 04:56:43 Compile requests                32 
Feb 25 04:56:43 Compile requests executed       11 
Feb 25 04:56:43 Cache hits                       1 
Feb 25 04:56:43 Cache misses                    10 
Feb 25 04:56:43 Cache timeouts                   0 

This comment was automatically generated by Dr. CI. Follow this link to opt out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker.

This comment has been revised 36 times.

@emcastillo emcastillo force-pushed the fix_cdist_backward branch 3 times, most recently from 108d618 to bc353c4 on December 17, 2019 09:07
@ngimel ngimel self-requested a review December 20, 2019 03:06
@ngimel ngimel added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Dec 20, 2019
Collaborator

@ngimel ngimel left a comment

Please add a test for the case you are fixing. Also, this will still break when m exceeds roughly 32 million, but that's better than before.

@ptrblck
Collaborator

ptrblck commented Dec 24, 2019

Similar issue in pdist, as reported here: #31593 (comment)

@emcastillo Let me know if you want to fix both methods in the same PR, or if I should take care of pdist.

@emcastillo
Collaborator Author

emcastillo commented Dec 25, 2019

I am trying to write the test, but the required matrix sizes for it to fail are quite big, resulting in the test failing with an out-of-memory error when running the functional checks. How should I proceed in this case?

@ptrblck you can take care of pdist :)

@ngimel
Collaborator

ngimel commented Dec 26, 2019

@emcastillo it looks like you are hitting #24345, and it looks like it was never resolved. You can either go back to whatever cdist implementation existed before PyTorch 1.2 (you'd need to add batching support, because it did not exist before PyTorch 1.2), which supposedly did not use as much memory, or you can, at least for the Euclidean distance, let PyTorch figure out the backward pass itself by calling the necessary matrix multiplies (#31599); that would take care of most practical cases. Non-Euclidean distances would still throw an error.
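For illustration, a minimal sketch of the kind of matrix-multiply composition being suggested (this is not the code that eventually landed; the function name is made up, and the clamp is added for numerical safety):

```
import torch

def euclidean_cdist_via_matmul(x1, x2):
    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, built only from differentiable
    # primitives (sum, matmul, sqrt), so autograd can derive the backward pass.
    x1_sq = x1.pow(2).sum(dim=-1, keepdim=True)   # (..., n, 1)
    x2_sq = x2.pow(2).sum(dim=-1, keepdim=True)   # (..., m, 1)
    sq_dist = x1_sq - 2 * x1.matmul(x2.transpose(-2, -1)) + x2_sq.transpose(-2, -1)
    return sq_dist.clamp_min(0).sqrt()
```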

@emcastillo
Collaborator Author

Thanks for the advice!
I will try to let PyTorch do the backward pass itself for Euclidean distances.
I am still fairly new to the PyTorch codebase, so it will probably take me a while to figure out; sorry for the time it is likely to take.

Happy new year

@ngimel
Collaborator

ngimel commented Jan 6, 2020

Happy New Year! A similar thing is done for adaptive_avg_pooling (https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp#L324-L340): in some cases it has a device-independent, differentiable implementation, and in those cases it needs no device-specific dispatch or gradient formula; in the general case, _adaptive_avg_pool2d has device-specific dispatch and device-specific backward formulas (also look in native_functions.yaml for adaptive_avg_pool2d and _adaptive_avg_pool2d).
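A rough, runnable Python-level sketch of that pattern (illustrative only: the real logic is in C++ and registered through native_functions.yaml; the function name is made up, and the even-divisibility check is one plausible trigger for the composite path):

```
import torch
import torch.nn.functional as F

def adaptive_avg_pool2d_sketch(input, output_size):
    ih, iw = input.shape[-2:]
    oh, ow = output_size
    # Device-independent path: when the output evenly divides the input, the op
    # reduces to a plain avg_pool2d, which is differentiable, so autograd
    # derives the backward pass and no device-specific gradient formula is needed.
    if ih % oh == 0 and iw % ow == 0:
        return F.avg_pool2d(input, kernel_size=(ih // oh, iw // ow))
    # General path: fall back to the dedicated op, which has device-specific
    # dispatch and its own backward formula.
    return F.adaptive_avg_pool2d(input, output_size)
```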

@Zhaoyi-Yan

Zhaoyi-Yan commented Jan 6, 2020

I hit this problem in PyTorch v1.3.1. In my case I need to compute the similarity between a matrix of shape (N, 2500, 256) and another matrix of shape (N, 2500, 256); will this fix handle that? I am also not sure whether it will be cherry-picked into v1.4, or whether a workaround exists in the meantime.

Edit: N is a small number, e.g. 8.

@emcastillo
Collaborator Author

@ngimel, sorry for the delay; I can finally start working on this.

What I understood from reading the code and the links you pointed me to is that I need to define a new generic function for cdist. This function should be registered in aten/src/ATen/native/native_functions.yaml without dispatch logic, and it should not be defined in tools/autograd/derivatives.yaml.

This function should be the main entry point when cdist is executed. For non-Euclidean distances it should call the current cdist, which has a backward mapped, so autograd dispatches to the specific implementations; for the Euclidean case it should call the matrix-multiplication-based approach, which has no backward function defined, so autograd can derive the backward pass itself.

Please correct me if I am wrong (which I most likely am 😂)

@ngimel
Collaborator

ngimel commented Feb 24, 2020

@emcastillo please address @ailzhang's comment and rebase. Thanks!

@ailzhang
Contributor

ailzhang commented Feb 25, 2020 via email

@emcastillo
Collaborator Author

emcastillo commented Feb 25, 2020

Check removed and rebased!
Thanks for all the help!!

@emcastillo
Collaborator Author

@ngimel I think the failures are not related to my changes. Can you please confirm?
Thanks

Contributor

@ailzhang ailzhang left a comment

Thanks!

Contributor

@facebook-github-bot facebook-github-bot left a comment

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ngimel
Collaborator

ngimel commented Feb 25, 2020

Thanks, the XLA failures are probably not related to your changes, but @ailzhang would know more.

@ailzhang
Contributor

@ngimel The currently failing tests are quantization tests ;)

@emcastillo emcastillo deleted the fix_cdist_backward branch February 26, 2020 02:42
@facebook-github-bot
Contributor

@ngimel merged this pull request in a836c4c.

hczhu pushed a commit that referenced this pull request Feb 28, 2020
Summary:
Fixes an issue with `cdist` backward calculation for large inputs for the euclidean case.

The grid size when launching the kernel exceeded the 2^16 limit for the second dimension, resulting in `RuntimeError: CUDA error: invalid configuration argument`

Code to reproduce:

```
h, w, d = 800, 1216, 12
n = 133
A = torch.randn(n, d).cuda()
B = torch.randn(h, w, d).cuda()
A.requires_grad = True
B.requires_grad = True

B = B.reshape(-1, d).contiguous()
dist = torch.cdist(A, B)
loss = dist.sum()
loss.backward()
```

Thanks to tkerola for the bug report, reproduction and suggesting a solution.
Pull Request resolved: #31167

Differential Revision: D20035605

Pulled By: ngimel

fbshipit-source-id: ae28ba4b549ee07a8bd937bb1de2438dc24eaa17
ttumiel pushed a commit to ttumiel/pytorch that referenced this pull request Mar 4, 2020
Summary:
Fixes an issue with `cdist` backward calculation for large inputs for the euclidean case.

The grid size when launching the kernel exceeded the 2^16 limit for the second dimension, resulting in `RuntimeError: CUDA error: invalid configuration argument`

Code to reproduce:

```
h, w, d = 800, 1216, 12
n = 133
A = torch.randn(n, d).cuda()
B = torch.randn(h, w, d).cuda()
A.requires_grad = True
B.requires_grad = True

B = B.reshape(-1, d).contiguous()
dist = torch.cdist(A, B)
loss = dist.sum()
loss.backward()
```

Thanks to tkerola for the bug report, reproduction and suggesting a solution.
Pull Request resolved: pytorch#31167

Differential Revision: D20035605

Pulled By: ngimel

fbshipit-source-id: ae28ba4b549ee07a8bd937bb1de2438dc24eaa17
@connorlee77

How can I update my version of torch to get this change?

@ngimel
Collaborator

ngimel commented Mar 4, 2020

You can get the nightly packages by following the instructions on pytorch.org.

@RuABraun

RuABraun commented Feb 2, 2021

I'm still getting an error with PyTorch 1.7.1 and the nightly build.

It's hard to reproduce, though (it is definitely caused by cdist, but I can't reproduce it with a 5-line example).


Labels

Merged · open source · triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
