
Conversation

@jjsjann123
Collaborator

  1. faster atomicAdd trick for fp16 backward kernel
  2. better launch configs for backward kernel
  3. removed unnecessary buffer initialization for forward kernel

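To make item 1 of the list above concrete, here is a minimal sketch of the fp16 fast-atomicAdd trick, written for this summary rather than copied from the PR: instead of issuing a scalar `__half` atomicAdd, the value is padded into an aligned `__half2` whose other lane is zero, so the hardware performs a single 32-bit atomic. The function name `fast_half_atomic_add` and the compute-capability assumption are mine; the actual ATen helper (`fastAtomicAdd`) covers more dtypes and dispatch details.

```cuda
// Sketch only: assumes compute capability 7.0+, so both __half and __half2
// atomicAdd are available in hardware.
#include <cuda_fp16.h>
#include <cstdint>

__device__ __forceinline__ void fast_half_atomic_add(__half* tensor,
                                                     int index,
                                                     int numel,
                                                     __half value) {
  __half* target = tensor + index;
  // Is `target` the low half of a 4-byte-aligned __half2 word?
  bool low_half =
      (reinterpret_cast<std::uintptr_t>(target) % sizeof(__half2)) == 0;
  if (low_half && index < numel - 1) {
    // Add {value, 0} to the 4-byte word starting at `target`.
    atomicAdd(reinterpret_cast<__half2*>(target),
              __halves2half2(value, __int2half_rz(0)));
  } else if (!low_half && index > 0) {
    // `target` is the high half: add {0, value} to the word one element back.
    atomicAdd(reinterpret_cast<__half2*>(target - 1),
              __halves2half2(__int2half_rz(0), value));
  } else {
    // First/last element with no safe partner lane: plain scalar atomicAdd.
    atomicAdd(target, value);
  }
}
```

The boundary branches fall back to a scalar atomicAdd so the zero-padded lane never touches memory outside the tensor.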
@pytorchbot added the module: cuda and module: operators labels Jun 17, 2019
@jjsjann123
Collaborator Author

[Image: benchmark table of forward/backward timings and speedups before and after this PR]

Perf numbers measured similarly to those in #21694.

  1. The faster atomicAdd helps fp16 performance a lot (compare the backward speedup for fp16 with fp32).
  2. The updated launch configs perform better with larger N*C and smaller spatial dimensions.
  3. Removing the unnecessary initialization reduced forward time.
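To illustrate the second observation, the snippet below is a hypothetical host-side heuristic, not the launch-config code from this PR: spreading the grid over both the spatial extent and the N*C dimension launches enough blocks to saturate the device when N*C is large but each spatial plane is small. The function name, the 256-thread block, and the 65535 cap are illustrative choices only.

```cuda
#include <cuda_runtime.h>
#include <algorithm>

// Hypothetical sketch of a 2-D launch config for a per-plane backward kernel.
inline void pick_backward_launch_config(int nc, int spatial,
                                        dim3* grid, dim3* block) {
  const int kThreads = 256;  // threads per block (illustrative)
  *block = dim3(kThreads);
  // Blocks along the spatial dimension; at least one even for tiny planes.
  int spatial_blocks = std::max(1, (spatial + kThreads - 1) / kThreads);
  // Put N*C on grid.y so that large-N*C / small-spatial shapes still launch
  // enough blocks to fill the device (grid.y is capped at 65535).
  int nc_blocks = std::min(nc, 65535);
  *grid = dim3(spatial_blocks, nc_blocks);
}
```

A kernel launched as `kernel<<<grid, block>>>(...)` would then use `blockIdx.y` to select the (n, c) plane and `blockIdx.x * blockDim.x + threadIdx.x` for the spatial position.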

@jjsjann123
Collaborator Author

cc'ing @ngimel @ezyang

@jjsjann123
Collaborator Author

The PyLinter error is on a Python script that doesn't exist; it seems unrelated to this change.

@ezyang ezyang requested review from ezyang and ngimel June 18, 2019 14:58
@ezyang
Contributor

ezyang commented Jun 18, 2019

Same deal, deferring to @ngimel here

Collaborator

@ngimel left a comment

Make sure grad_input is always contiguous (I think it is); in that case the fast_atomic argument won't be needed.

Contributor

@facebook-github-bot left a comment

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Jun 19, 2019
Summary:
1. faster atomicAdd trick for fp16 backward kernel
2. better launch configs for backward kernel
3. removed unnecessary buffer initialization for forward kernel
Pull Request resolved: pytorch/pytorch#21879

Differential Revision: D15898680

Pulled By: ezyang

fbshipit-source-id: 1fc81e6c078f1538d82e4f36921b630499eb504f
@facebook-github-bot
Contributor

@ezyang merged this pull request in 056a033.

facebook-github-bot pushed a commit that referenced this pull request Sep 15, 2020
… 32 bit aligned (#44642)

Summary:
For #44206 and #42218, I'd like to update trilinear interpolate backward and grid_sample backward to use `fastAtomicAdd`.

As a prelude, I spotted a UB risk in `fastAtomicAdd`. I think the existing code incurs a misaligned `__half2` atomicAdd when `index` is odd and `tensor` is not 32-bit aligned (`index % 2 == 1` and `(reinterpret_cast<std::uintptr_t>(tensor) % sizeof(__half2)) == 2`). In this case we think we're `!low_bit` and go down the `!low_bit` code path, but in fact we are `low_bit`. It appears the discussion on the original fastAtomicAdd PR (#21879) did not consider that case explicitly.

I wanted to push my tentative fix for discussion ASAP, cc'ing jjsjann123 and mkolod as the original authors of `fastAtomicAdd`. (I'm also curious why we need to `reinterpret_cast<std::uintptr_t>(tensor...` for the address modding, but that's minor.)

Pull Request resolved: #44642

Reviewed By: mruberry

Differential Revision: D23699820

Pulled By: ngimel

fbshipit-source-id: 0db57150715ebb45e6a1fb36897e46f00d61defd
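The snippet below is a small sketch of the alignment mismatch described in that commit message (illustrative names, not the exact ATen code): deriving "low half of a `__half2` word" from index parity plus the base pointer's alignment can disagree with the element's real alignment when the base is only 2-byte aligned, whereas checking the element's own address cannot.

```cuda
#include <cuda_fp16.h>
#include <cstdint>

// Pre-fix style check (sketch): even index AND 4-byte-aligned base pointer.
// If the base address % 4 == 2 and index is odd, this returns false even
// though tensor + index actually starts a 4-byte-aligned __half2 word, so a
// caller that trusts it would form a misaligned __half2* one element earlier.
__device__ __forceinline__ bool is_low_half_by_parity(const __half* tensor,
                                                      int index) {
  return (index % 2 == 0) &&
         (reinterpret_cast<std::uintptr_t>(tensor) % sizeof(__half2)) == 0;
}

// Post-fix style check (sketch): look at the element's own address instead.
__device__ __forceinline__ bool is_low_half_by_address(const __half* tensor,
                                                       int index) {
  return (reinterpret_cast<std::uintptr_t>(tensor + index) %
          sizeof(__half2)) == 0;
}
```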
xuzhao9 pushed a commit that referenced this pull request Sep 18, 2020
… 32 bit aligned (#44642)

facebook-github-bot pushed a commit that referenced this pull request Dec 3, 2020
Summary:
Fixes #44206

This PR basically follows the diff in #21879 for bilinear upsampling.

For the script provided in #44206, on my 2070 Super GPU, the total timings I got (in seconds):

| | non-amp | amp |
|---|---|---|
| before PR | 2.88 | 9.6 |
| after PR | 1.5 | 1.6 |

kernel time after PR
| | time | kernel |
| --- | --- | --- |
| non-amp | 0.37 ms | `void at::native::(anonymous namespace)::upsample_trilinear3d_backward_out_frame<float, float>(unsigned long, int, int, int, int, int, int, float, float, float, bool, float*, float const*) ` |
| amp | 0.61 ms | `void at::native::(anonymous namespace)::upsample_trilinear3d_backward_out_frame<c10::Half, float>(unsigned long, int, int, int, int, int, int, float, float, float, bool, c10::Half*, c10::Half const*)` |

Pull Request resolved: #48675

Reviewed By: bdhirsh

Differential Revision: D25284853

Pulled By: ngimel

fbshipit-source-id: 30f0d92e73050edd36013ce528d2e131effa3542
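For readers wondering why these backward kernels need atomics at all, here is a simplified 1-D sketch of the scatter pattern, my own example rather than the trilinear kernel named above: each output-gradient element adds weighted contributions to its neighbouring input positions, and several outputs can map onto the same input, so the accumulation must be atomic. In the ATen kernels discussed here, the plain `atomicAdd` calls below are where the `fastAtomicAdd` helper is swapped in.

```cuda
#include <cuda_runtime.h>

// Simplified 1-D linear-interpolation backward: scatter grad_out into grad_in.
// grad_in must be zero-initialized before the launch.
__global__ void linear_interp1d_backward_sketch(
    const float* __restrict__ grad_out,  // [out_w]
    float* __restrict__ grad_in,         // [in_w]
    int out_w, int in_w, float scale) {
  int o = blockIdx.x * blockDim.x + threadIdx.x;
  if (o >= out_w) return;

  // Source coordinate and the two bracketing input indices.
  float src = o * scale;
  int i0 = min(static_cast<int>(src), in_w - 1);
  int i1 = min(i0 + 1, in_w - 1);
  float w1 = src - i0;   // weight of the right neighbour
  float w0 = 1.0f - w1;  // weight of the left neighbour

  float g = grad_out[o];
  // Different output indices can hit the same i0/i1 -> atomic accumulation.
  atomicAdd(&grad_in[i0], w0 * g);
  atomicAdd(&grad_in[i1], w1 * g);
}
```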
shaibagon pushed a commit to shaibagon/pytorch that referenced this pull request Dec 3, 2020