Fix FP16 fastAtomicAdd for one case where tensor start address is not 32 bit aligned #44642
Conversation
jjsjann123 left a comment:
Good catch!
Codecov Report
```
@@            Coverage Diff             @@
##           master   #44642      +/-   ##
==========================================
- Coverage   67.98%   67.98%   -0.01%
==========================================
  Files         384      384
  Lines       49567    49567
==========================================
- Hits        33697    33696       -1
- Misses      15870    15871       +1
```
Continue to review full report at Codecov.
facebook-github-bot left a comment:
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Fix FP16 fastAtomicAdd for one case where tensor start address is not 32 bit aligned (#44642)

Summary: For #44206 and #42218, I'd like to update trilinear interpolate backward and grid_sample backward to use `fastAtomicAdd`. As a prelude, I spotted a UB risk in `fastAtomicAdd`. I think the existing code incurs a misaligned `__half2` atomicAdd when `index` is odd and `tensor` is not 32-bit aligned (`index % 2 == 1` and `reinterpret_cast<std::uintptr_t>(tensor) % sizeof(__half2) == 1`). In this case we think we're `!low_bit` and go down the `!low_bit` code path, but in fact we are `low_bit`. It appears the discussion on the original fastAtomicAdd PR (#21879) did not consider that case explicitly. I wanted to push my tentative fix for discussion ASAP. cc @jjsjann123 and @mkolod as original authors of `fastAtomicAdd`. (I'm also curious why we need to `reinterpret_cast<std::uintptr_t>(tensor...` for the address modding, but that's minor.)

Pull Request resolved: #44642
Reviewed By: mruberry
Differential Revision: D23699820
Pulled By: ngimel
fbshipit-source-id: 0db57150715ebb45e6a1fb36897e46f00d61defd
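For readers outside the PR thread, below is a minimal CUDA sketch of the alignment logic at issue. It illustrates the idea behind the fix and is not the exact upstream PyTorch code; the helper name `fastSpecializedAtomicAdd` follows the upstream utility, but the surrounding kernel plumbing is omitted and the comments paraphrase the bug report. The failure mode: with a base address of, say, 0x1002 and `index == 3`, the element sits at 0x1008, which is 32-bit aligned even though the index is odd, so any check that combines index parity with base-pointer alignment separately misclassifies the element and issues a misaligned `__half2` atomicAdd. The fix is to test the alignment of the element address `tensor + index` directly.

```cuda
#include <cuda_fp16.h>
#include <cstdint>

// Sketch of a fastAtomicAdd-style FP16 helper (assumes sm_70+ so the
// scalar __half atomicAdd fallback is available). Key point of the fix:
// derive alignment from the address of tensor[index] itself, rather than
// from index parity and the base pointer's alignment separately.
__device__ __forceinline__ void fastSpecializedAtomicAdd(
    __half* tensor, size_t index, size_t numel, __half value) {
  __half* target_addr = tensor + index;

  // True when tensor[index] starts on a 32-bit (__half2) boundary.
  bool low_byte =
      reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__half2) == 0;

  if (low_byte && index < (numel - 1)) {
    // tensor[index] is the low half of an aligned __half2: add into .x,
    // leaving the neighboring element unchanged by adding zero to .y.
    __half2 value2;
    value2.x = value;
    value2.y = __float2half_rn(0.0f);
    atomicAdd(reinterpret_cast<__half2*>(target_addr), value2);
  } else if (!low_byte && index > 0) {
    // tensor[index] is the high half: step back one element so the
    // __half2 access is aligned, and add into .y.
    __half2 value2;
    value2.x = __float2half_rn(0.0f);
    value2.y = value;
    atomicAdd(reinterpret_cast<__half2*>(target_addr - 1), value2);
  } else {
    // Boundary elements fall back to a plain scalar atomicAdd.
    atomicAdd(target_addr, value);
  }
}
```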