Conversation

@jjsjann123
Collaborator

Summary:
Update launch configs for TensorIterator gpu_reduce_kernel. Enable a flexible
block dimension to improve efficiency for reduction cases with a small fast
dimension.

Previously, TensorIterator launched blocks with a fixed 32x16 thread configuration.
For cases like:

import torch
torch.randn(2**20, 4, device='cuda').sum(0)

The fixed launch config does not handle coalesced memory access efficiently in such cases.

The updated launch config enables a flexible block dimension. Combined with an
improved reduction scheme (using flexible vertical / horizontal reduction
instead of the limited warp / block reduction in the old code), it ensures an
optimal memory access pattern even when reducing over a dimension with a small
stride.

Possible future improvements:

  1. Precise dynamic shared memory allocation.
  2. Using warp shuffle for vertical (block_y) reduction.
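
To make the coalescing issue concrete, here is a minimal sketch (a hypothetical helper, not the actual launch-config code from this PR) of how a flexible block shape can be chosen so that threadIdx.x covers the small fast dimension while the rest of the block walks the reduction dimension:

  // Hypothetical sketch, not the PR's actual code: make threadIdx.x span the
  // contiguous (fast) dimension so consecutive threads touch consecutive
  // addresses; spend the remaining thread budget on y, which strides over
  // the large reduction dimension.
  #include <cuda_runtime.h>

  dim3 pick_block_dim(int fast_dim_size, int max_threads = 512) {
    int bx = 1;
    while (bx < fast_dim_size && bx < 32) bx *= 2;  // cap x at one warp
    int by = max_threads / bx;                      // remaining budget on y
    return dim3(bx, by);
  }

  // For torch.randn(2**20, 4).sum(0) the fast dimension has size 4, giving a
  // 4x128 block; the old fixed 32x16 shape left most x lanes without useful
  // coalesced work when the fast dimension was this small.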

@jjsjann123
Collaborator Author

This is my monkey benchmarking. Notice the perf gain for small fast-changing dimensions.

[benchmark results image: perf]

@jjsjann123
Collaborator Author

For visibility @ngimel @umanwizard @colesbury

@umanwizard umanwizard self-requested a review January 24, 2019 21:09
Contributor

@facebook-github-bot facebook-github-bot left a comment


@umanwizard has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@umanwizard
Contributor

We are getting this error on some internal builds:

aten/src/ATen/native/cuda/Reduce.cuh:34:12: error: right shift count >= width of type [-Werror=shift-count-overflow]
   n |= (n >> 32);
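
For context, this warning comes from the standard next-power-of-two bit-smearing idiom: the (n >> 32) step is only meaningful when n is 64 bits wide, and once n became a 32-bit int the shift count equals the type width, which is undefined behavior. A minimal sketch of the 32-bit variant (illustrative, not the exact Reduce.cuh code):

  // Round a 32-bit int up to the next power of two by smearing the high bit
  // downward. For a 32-bit type the smear stops at (n >> 16); the (n >> 32)
  // step belongs only in the 64-bit variant of this idiom.
  static inline int next_pow2(int n) {
    n--;
    n |= n >> 1;
    n |= n >> 2;
    n |= n >> 4;
    n |= n >> 8;
    n |= n >> 16;  // no (n >> 32): shifting an int by its width is UB
    return n + 1;
  }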

Contributor

@umanwizard umanwizard left a comment


see above comment

@jjsjann123
Collaborator Author

My bad.
I thought I had updated that two commits ago when I changed the data type from int64 to int. Guess it didn't happen.

Contributor

@facebook-github-bot facebook-github-bot left a comment


@umanwizard has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@apaszke
Contributor

apaszke commented Jan 31, 2019

I'm somewhat worried about the slowdowns. Why are we getting rid of warp reductions? Aren't those supposed to be fast?

@jjsjann123
Collaborator Author

jjsjann123 commented Jan 31, 2019

I'm not getting rid of warp reductions. They are kept where necessary, e.g. for the old launch config with a 32x16 block dimension: https://github.com/pytorch/pytorch/pull/16224/files#diff-662693ef7b7f32fa32d7179b6614fc16R379

warp_reduce was renamed to block_x_reduce because we now have a flexible block dimension. For cases where blockDim.x > 32, the reduction requires a hybrid of shared-memory reduction and warp reduction, because not all threads along x are in the same warp (see the sketch below).
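
A minimal sketch of that hybrid (illustrative names, not the PR's exact code), assuming blockDim.x is a power of two and at least one warp wide:

  // Hybrid block_x reduction sketch: shared must hold blockDim.x * blockDim.y
  // elements, and blockDim.x is assumed to be a power of two and >= 32.
  template <typename T>
  __device__ T block_x_reduce(T val, T* shared) {
    int tid = threadIdx.y * blockDim.x + threadIdx.x;
    // Phase 1: shared-memory tree reduction until one warp's worth of partial
    // sums per row remains; needed because with blockDim.x > 32 the x lanes
    // span several warps and cannot shuffle to each other.
    for (int offset = blockDim.x / 2; offset >= 32; offset /= 2) {
      shared[tid] = val;
      __syncthreads();
      if (threadIdx.x < offset) val += shared[tid + offset];
      __syncthreads();
    }
    // Phase 2: the surviving 32 lanes of each row now sit in a single warp,
    // so the rest of the reduction uses register-to-register shuffles.
    for (int offset = 16; offset > 0; offset /= 2)
      val += __shfl_down_sync(0xffffffff, val, offset);
    return val;  // the final sum is valid in the thread with threadIdx.x == 0
  }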

@jjsjann123
Collaborator Author

The test failure doesn't seem to be relevant. Merging ToT to see if it goes away.

Contributor

@facebook-github-bot facebook-github-bot left a comment


@soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Feb 7, 2019
Pull Request resolved: pytorch/pytorch#16224

Differential Revision: D13806753

Pulled By: soumith

fbshipit-source-id: 37e45c7767b5748cf9ecf894fad306e040e2f79f
@soumith
Contributor

soumith commented Feb 8, 2019

Just FYI @jjsjann123: this is being reverted. After this PR landed, we are seeing "illegal memory exception" errors and RuntimeError: Creating MTGP constants failed. at caffe2/aten/src/THC/THCTensorRandom.cu:33 (when using torch.manual_seed) in hard-to-repro internal workloads.

@jjsjann123
Collaborator Author

I think I figured out the issue: I overlooked the old heuristics for inter-block vs. inter-warp reduction. Surprised this wasn't caught by the CI tests.

Working on a fix. I will also add tests that cover it.

jjsjann123 added a commit to jjsjann123/pytorch that referenced this pull request Feb 13, 2019
update:
  1. global_reduce now checks should_block_y_reduce first. This avoids
     enabling global_reduce without block_y_reduce, which led to accessing
     shared memory during the global reduction without an allocation (a
     sketch of the corrected ordering follows below).
  2. Updated the block_y_reduce heuristics, improving perf on tiny tensors.
  3. Added a test case covering the old cases where an illegal memory access
     could occur.

  TensorIterator cuda launch configs update (pytorch#16224)
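
For reference, here is a minimal sketch of the corrected ordering from item 1 above (the names are illustrative, not the PR's exact code). The point is that the inter-block (global) reduction path accesses shared memory that is only allocated when the block_y path is active, so global_reduce must be conditioned on block_y_reduce:

  // Illustrative sketch of the fixed heuristic ordering; names are hypothetical.
  struct LaunchConfig {
    int block_y;         // threads cooperating along the reduction dimension
    int ctas_per_output; // thread blocks cooperating on one output element

    bool should_block_y_reduce() const { return block_y > 1; }

    bool should_global_reduce() const {
      // The global (inter-block) path reuses the shared-memory buffer that is
      // only allocated for block_y reduction, so it must never be enabled on
      // its own.
      return should_block_y_reduce() && ctas_per_output > 1;
    }
  };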
facebook-github-bot pushed a commit that referenced this pull request Feb 14, 2019
    Pull Request resolved: #16224
Pull Request resolved: #17040

Differential Revision: D14078295

Pulled By: umanwizard

fbshipit-source-id: ecc55054a5a4035e731f0196d633412225c3b06c
zdevito pushed a commit to zdevito/ATen that referenced this pull request Feb 14, 2019
zou3519 added a commit to zou3519/pytorch that referenced this pull request Mar 12, 2019