cdist performance improvement for euclidean distance #25799
Conversation
Awesome! Just for completeness, here is a self-contained version of our current distance function in gpytorch, which is about 2x faster than the original snippet I posted in that issue:

```python
def fast_cdist(x1, x2):
    adjustment = x1.mean(-2, keepdim=True)
    x1 = x1 - adjustment
    x2 = x2 - adjustment  # x1 and x2 should be identical in all dims except -2 at this point
    # Compute squared distance matrix using quadratic expansion
    # But be clever and do it with a single matmul call
    x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
    x1_pad = torch.ones_like(x1_norm)
    x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
    x2_pad = torch.ones_like(x2_norm)
    x1_ = torch.cat([-2. * x1, x1_norm, x1_pad], dim=-1)
    x2_ = torch.cat([x2, x2_pad, x2_norm], dim=-1)
    res = x1_.matmul(x2_.transpose(-2, -1))
    # Zero out negative values
    res.clamp_min_(1e-30).sqrt_()
    return res
```

This uses the same basic strategy as the original method I'd written, but does everything in a single matmul call. That being said, it has some drawbacks for inclusion upstream in pytorch proper over my original snippet: namely, (1) the additional working-memory overhead, and (2) since it's the same basic strategy, a proper implementation of the first snippet I'd linked may erase the gap anyway.
@jacobrgardner Thanks, let me check the performance metrics.

Great work! FYI, I just added a reference to an implementation for the batched version in #15253 (comment) cc: @lerks
aten/src/ATen/native/Distance.cpp (Outdated)

```cpp
} else {
  result = at::baddbmm(x2_norm.transpose(-2, -1), x1, x2.transpose(-2, -1), 1, -2);
}
result.add_(x1_norm);
```
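The diff above computes the batched squared distances as `beta * x2_norm^T + alpha * (x1 @ x2^T)` with `beta=1, alpha=-2`, then adds `x1_norm`. A hedged NumPy sketch of that same arithmetic (not the actual ATen kernel) checked against the direct formula:

```python
import numpy as np

rng = np.random.default_rng(1)
b, n, m, d = 2, 4, 5, 3                         # batch, rows of x1, rows of x2, feature dim
x1 = rng.standard_normal((b, n, d))
x2 = rng.standard_normal((b, m, d))
x1_norm = (x1 ** 2).sum(-1, keepdims=True)      # shape (b, n, 1)
x2_norm = (x2 ** 2).sum(-1, keepdims=True)      # shape (b, m, 1)
# baddbmm-style step: ||x2||^2 broadcast over rows, minus 2 * x1 @ x2^T
result = x2_norm.swapaxes(-2, -1) - 2.0 * (x1 @ x2.swapaxes(-2, -1))
result += x1_norm                               # broadcasts over columns
# Direct pairwise squared distances for comparison
direct = ((x1[:, :, None, :] - x2[:, None, :, :]) ** 2).sum(-1)
assert np.allclose(result, direct)
```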
Implementations in #15253 had a clamp call filtering negative values, and it is necessary: it is possible to get negative squared distances with this approach because of floating-point accuracy.
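A small NumPy illustration of why the clamp matters (a sketch of the defensive pattern, not the ATen code): with the expanded form, entries whose true squared distance is zero can come out as tiny negative values, and `sqrt` of those would produce NaN.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))
# Quadratic expansion of the pairwise squared distances of x with itself
sq = (x**2).sum(1)[:, None] - 2 * (x @ x.T) + (x**2).sum(1)[None, :]
# The diagonal is mathematically zero, but cancellation may leave values
# like -1e-16; clamping at zero makes the subsequent sqrt safe.
d = np.sqrt(np.clip(sq, 0.0, None))
assert np.all(np.isfinite(d))
assert np.all(d >= 0)
```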
yep, I changed this
`torch.cdist` has been a pain for a long time; it's buggy and slow. A more fundamental issue is that we use `torch.cdist(x1, x2).pow(2)` in the cdist code path: https://github.com/cornellius-gp/gpytorch/blob/master/gpytorch/kernels/kernel.py#L35. While the squared distance is differentiable even if the points are the same, this is not true for the distance itself. But autograd doesn't know that, so it can completely mess things up if there are repeated points, or if we test at a train point. Let's kill it for now until a reasonable implementation (including of the squared distances) is available. There is work on improving perf here: pytorch/pytorch#25799; will add a request for returning the squared distances.
@ifedan, can we also add a

> While the squared distance is differentiable at zero, this is not true for the distance. But autograd doesn't know that, so doing
@pytorchbot retest this please
@ifedan any thoughts on being able to return the squared (or more generally That is, have an option to return This is not just as easy as squaring the result of cdist, as cdist is not differentiable at zero.
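The non-differentiability at zero that this request works around can be seen directly. The gradient of the Euclidean distance `||x - y||` is `(x - y) / ||x - y||`, which divides by zero at coincident points, while the gradient of the squared distance, `2 * (x - y)`, is well defined everywhere. A minimal NumPy sketch (hand-written gradients, standing in for what autograd would compute):

```python
import numpy as np

x = np.array([1.0, 2.0])
y = np.array([1.0, 2.0])   # coincident points
diff = x - y
d = np.sqrt((diff ** 2).sum())
with np.errstate(divide="ignore", invalid="ignore"):
    grad_d = diff / d      # gradient of the distance: 0/0 -> nan at zero
grad_d2 = 2 * diff         # gradient of the squared distance: exactly 0
assert np.isnan(grad_d).all()
assert np.all(grad_d2 == 0)
```

This is why squaring cdist's output afterwards doesn't help: the NaN appears inside the backward pass of the sqrt before the square can cancel it.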
Did you move cdist to functional.py because optional kwargs are not handled well in native_functions?
houseroad
left a comment
Adding an optional parameter with a default value is backward compatible. Looks good to me.
facebook-github-bot
left a comment
@ifedan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@pytorchbot retest this please
Summary: jacobrgardner pytorch/pytorch#15253 (comment) proposed a way to speed up the euclidean distance calculation. This PR is an implementation of this solution for the normal and batched versions. Also, simonepri provided performance metrics in pytorch/pytorch#15253 (comment). The current implementation is faster than jacobrgardner's approach. Pull Request resolved: pytorch/pytorch#25799 Differential Revision: D17964982 Pulled By: ifedan fbshipit-source-id: bf7bd0dbfca51fd39e667da55139347480f30a2f
What do you guys think about https://stackoverflow.com/questions/61241523/replacing-torch-cdist-function-to-eliminate-gpu-out-of-memory-runtime-error ?
This (or very similar) was implemented in #31167
@ngimel are you saying that I should install the nightly version of pytorch for #31167? @jacobrgardner As for #25799 (comment), how to replace
You can install either the nightly or 1.5.