fix cosine_similarity #18168
Conversation
facebook-github-bot left a comment
@ailzhang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@pytorchbot retest this please
@pytorchbot rebase this please
ezyang left a comment
Nice! I didn't know this was a trick you could do. Do you know how it's justified?
Oh, I see Sam's comment now. If you wanna be nice, put it in a comment :)
facebook-github-bot left a comment
@ailzhang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
aten/src/ATen/native/Distance.cpp (Outdated)
- return w12.div_((w1 * w2).clamp_min_(eps));
+ Tensor w1 = at::sum(x1 * x1, dim);
+ Tensor w2 = at::sum(x2 * x2, dim);
+ Tensor n12 = (w1 * w2).sqrt_().clamp_min(eps);
arguably we should use rsqrt here for better precision :)
This makes the precision worse.
On CPU, rsqrt(x) is implemented as 1/sqrt(x), so that is now three operations with rounding instead of two. (There is no std::rsqrt in C++, and the x86 rsqrt instructions are low-precision.)
With CUDA it's a little different, but sqrt(x) is generally more precise on modern GPUs.
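A minimal scalar sketch of the rounding argument (plain C++ with hypothetical values, not the ATen kernel itself): dividing by sqrt involves two rounded operations, while multiplying by an rsqrt that is itself computed as 1/sqrt involves three.

#include <cmath>
#include <cstdio>

int main() {
  // Hypothetical values standing in for w12 and w1 * w2.
  float w12 = 0.3f;
  float n = 7.0f;

  // Two rounded steps: sqrt, then divide.
  float via_sqrt = w12 / std::sqrt(n);

  // Three rounded steps: sqrt, reciprocal, then multiply,
  // which is what a CPU rsqrt implemented as 1/sqrt amounts to.
  float via_rsqrt = w12 * (1.0f / std::sqrt(n));

  std::printf("via sqrt:  %.9g\nvia rsqrt: %.9g\n", via_sqrt, via_rsqrt);
  return 0;
}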
Thanks @colesbury! I will send a patch to fix this.
aten/src/ATen/native/Distance.cpp (Outdated)
- return w12.div_((w1 * w2).clamp_min_(eps));
+ Tensor w1 = at::sum(x1 * x1, dim);
+ Tensor w2 = at::sum(x2 * x2, dim);
+ Tensor n12 = (w1 * w2).rsqrt_().clamp_min(eps);
Thanks! Although you probably now want to do either clamp_min(eps * eps) before rsqrt or clamp_max(1.0 / eps) after rsqrt.
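A scalar sketch of why the clamp has to move once rsqrt is used (plain C++ with hypothetical eps and inputs, not the ATen code): applying clamp_min after rsqrt only bounds the result from below, so a zero denominator still yields inf; clamping the squared norms with eps * eps before the rsqrt, or applying clamp_max(1 / eps) afterwards, keeps it bounded.

#include <cmath>
#include <cstdio>

int main() {
  const float eps = 1e-8f;
  float n = 0.0f;  // hypothetical w1 * w2 for an all-zero input

  // Problematic order: rsqrt(0) is inf and clamp_min cannot cap it.
  float bad = std::fmax(1.0f / std::sqrt(n), eps);

  // Option 1: clamp the squared norms to at least eps * eps first,
  // so the reciprocal norm is at most 1 / eps.
  float fix1 = 1.0f / std::sqrt(std::fmax(n, eps * eps));

  // Option 2: cap the reciprocal norm with 1 / eps afterwards.
  float fix2 = std::fmin(1.0f / std::sqrt(n), 1.0f / eps);

  std::printf("bad:  %g\nfix1: %g\nfix2: %g\n", bad, fix1, fix2);
  return 0;
}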
Ah nice catch! Thanks!
facebook-github-bot left a comment
@ailzhang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: fixes #18057 according to colesbury's suggestion. Thanks! cc: ezyang
Pull Request resolved: pytorch/pytorch#18168
Differential Revision: D14520953
Pulled By: ailzhang
fbshipit-source-id: 970e6cfb482d857a81721ec1d0ee4a4df84a0450
fixes #18057 according to @colesbury's suggestion. Thanks!
cc: @ezyang
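For reference, a sketch of how the pieces discussed in this thread fit together, written against the libtorch C++ API. This is a reconstruction for illustration only: the function name cosine_similarity_sketch and the driver in main are made up here, and the body follows the review comments rather than the merged diff itself.

#include <torch/torch.h>
#include <iostream>

// Work with squared norms and clamp their product with eps * eps
// before taking the reciprocal square root, as suggested in the review.
torch::Tensor cosine_similarity_sketch(const torch::Tensor& x1,
                                       const torch::Tensor& x2,
                                       int64_t dim, double eps) {
  torch::Tensor w12 = at::sum(x1 * x2, dim);
  torch::Tensor w1 = at::sum(x1 * x1, dim);
  torch::Tensor w2 = at::sum(x2 * x2, dim);
  // Clamping the product of squared norms keeps the rsqrt bounded by 1 / eps.
  torch::Tensor n12 = (w1 * w2).clamp_min(eps * eps).rsqrt();
  return w12 * n12;
}

int main() {
  torch::Tensor a = torch::rand({4, 8});
  torch::Tensor b = torch::rand({4, 8});
  // Compare the sketch against the built-in operator.
  std::cout << cosine_similarity_sketch(a, b, /*dim=*/1, /*eps=*/1e-8) << "\n";
  std::cout << torch::cosine_similarity(a, b, /*dim=*/1, /*eps=*/1e-8) << "\n";
  return 0;
}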