Conversation

@akyrola commented Jun 3, 2019

When optimizing sparse tensor coalesce, I noticed that this kernel was taking the bulk of the time (see PR #21214). It is used (at least) in the sparse tensor constructor to validate that the index tensor's min/max indices are valid.

This PR rewrites the kernel using a CUB reduction, achieving about a 16x speedup. With my benchmark for coalesce, nvprof showed the following before:

```
#  GPU activities:   45.47%  2.42669s       101  24.027ms  23.862ms  28.968ms  void kernelTransformReduceInnermostDimIndex<long, long, MinValuePair<long, long>>(long*, long*, long*, unsigned int, unsigned int, thrust::pair<long, long>, long)
#                    45.41%  2.42386s       101  23.999ms  23.857ms  28.944ms  void kernelTransformReduceInnermostDimIndex<long, long, MaxValuePair<long, long>>(long*, long*, long*, unsigned int, unsigned int, thrust::pair<long, long>, long)
```

... after this:

```
 GPU activities:   19.50%  154.92ms       101  1.5338ms  1.5285ms  1.5987ms  void kernelTransformReduceInnermostDimIndex<long, long, MinValuePair<long, long>>(long*, long*, long*, unsigned int, unsigned int, thrust::pair<long, long>, long)
                   19.45%  154.52ms       101  1.5299ms  1.5247ms  1.5933ms  void kernelTransformReduceInnermostDimIndex<long, long, MaxValuePair<long, long>>(long*, long*, long*, unsigned int, unsigned int, thrust::pair<long, long>, long)
```

Test: test/torch.py and test/sparse.py pass.
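
For context, here is a minimal sketch of the kind of CUB block reduction this change is built around: each block strides over rows, each thread scans part of a row, and cub::BlockReduce combines the partial (value, index) results. Everything here (names, the 256-thread constant, the exact layout) is an illustrative assumption, not the code actually added by the PR:

```
#include <cub/block/block_reduce.cuh>
#include <climits>

// Illustrative placeholders only; the PR's real kernel and identifiers differ.
constexpr int kThreadsPerBlock = 256;

struct IndexedMin {
  long value;
  long index;
};

struct MinOp {
  __device__ IndexedMin operator()(const IndexedMin& a, const IndexedMin& b) const {
    return (b.value < a.value) ? b : a;
  }
};

// Reduce each row of a [rows x cols] tensor to its (min value, argmin) pair.
__global__ void row_min_with_index(const long* in, long* out_val, long* out_idx,
                                   unsigned int rows, unsigned int cols) {
  using BlockReduce = cub::BlockReduce<IndexedMin, kThreadsPerBlock>;
  __shared__ typename BlockReduce::TempStorage temp;

  // Grid-stride over rows so a capped grid can still cover all of them.
  for (unsigned int row = blockIdx.x; row < rows; row += gridDim.x) {
    IndexedMin local = {LONG_MAX, -1};
    // Each thread scans a strided subset of the row.
    for (unsigned int col = threadIdx.x; col < cols; col += blockDim.x) {
      long v = in[row * (size_t)cols + col];
      if (v < local.value) {
        local.value = v;
        local.index = (long)col;
      }
    }
    // Combine the per-thread partials; only thread 0 receives the block result.
    IndexedMin row_min = BlockReduce(temp).Reduce(local, MinOp());
    if (threadIdx.x == 0) {
      out_val[row] = row_min.value;
      out_idx[row] = row_min.index;
    }
    __syncthreads();  // TempStorage is reused on the next row iteration.
  }
}
```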

@pytorchbot added the module: cuda and module: operators labels on Jun 3, 2019
@facebook-github-bot left a comment:

@akyrola has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@akyrola force-pushed the akyrola/cubreduce branch from 9a4d78b to 00c3710 on June 3, 2019 at 20:46
@facebook-github-bot left a comment:

@akyrola has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ezyang commented Jun 5, 2019

Pasting some prior discussion here:


@wesolwsk: If you can point me to a workflow using this operator, we can check the SM occupancy before and after.

Incidentally, the macro for the number of threads per block is set much too high for PyTorch (1024). I previously found that for Caffe2, lowering it from 512 to 128 had a positive effect on performance for use cases with small to medium-sized inputs, so it may be worth trying for PyTorch.


@akyrola: I don't have a workflow; this is just a random point optimization from a bootcamp task. But below, I ran nvprof to get "achieved occupancy", which I think is the same thing?

Here is my benchmark script:

```
import torch
from random import *
import time
import timeit

n = 1000000
I = torch.tensor([[randint(0, 200) for _ in range(3)] for _ in range(n)]).t().cuda()
V = torch.randn(n).cuda()
size = torch.Size([2000, 1000, 1000])

def fn():
    torch.cuda.synchronize()
    S = torch.sparse_coo_tensor(I, V, size)
    torch.cuda.synchronize()

t = timeit.repeat(fn, repeat=100, number=1)
print(min(t))
```

Run under nvprof to collect the occupancy metric:

```
nvprof --metrics achieved_occupancy python sparse_bench.py
```

Before:

```
0.05
==2041189== Profiling application: python sparse_bench.py
==2041189== Profiling result:
==2041189== Metric result:
Invocations                               Metric Name                        Metric Description         Min         Max         Avg
Device "Tesla M40 (0)"
    Kernel: void kernelTransformReduceInnermostDimIndex<long, long, MaxValuePair<long, long>>(long*, long*, long*, unsigned int, unsigned int, thrust::pair<long, long>, long)
        100                        achieved_occupancy                        Achieved Occupancy    0.249997    0.249998    0.249998
    Kernel: void kernelTransformReduceInnermostDimIndex<long, long, MinValuePair<long, long>>(long*, long*, long*, unsigned int, unsigned int, thrust::pair<long, long>, long)
        100                        achieved_occupancy                        Achieved Occupancy    0.249998    0.249998    0.249998
times:
            Type  Time(%)      Time     Calls       Avg       Min       Max  Name
 GPU activities:   50.02%  2.53547s       100  25.355ms  25.243ms  29.036ms  void kernelTransformReduceInnermostDimIndex<long, long, MinValuePair<long, long>>(long*, long*, long*, unsigned int, unsigned int, thrust::pair<long, long>, long)
                   49.92%  2.53042s       100  25.304ms  25.223ms  28.981ms  void kernelTransformReduceInnermostDimIndex<long, long, MaxValuePair<long, long>>(long*, long*, long*, unsigned int, unsigned int, thrust::pair<long, long>, long)
```

After:

```
0.03
==2032478== Profiling application: python sparse_bench.py
==2032478== Profiling result:
==2032478== Metric result:
Invocations                               Metric Name                        Metric Description         Min         Max         Avg
Device "Tesla M40 (0)"
    Kernel: void kernelTransformReduceInnermostDimIndex<long, long, MaxValuePair<long, long>>(long*, long*, long*, unsigned int, unsigned int, thrust::pair<long, long>, long)
        100                        achieved_occupancy                        Achieved Occupancy    0.124988    0.124990    0.124989
    Kernel: void kernelTransformReduceInnermostDimIndex<long, long, MinValuePair<long, long>>(long*, long*, long*, unsigned int, unsigned int, thrust::pair<long, long>, long)
        100                        achieved_occupancy                        Achieved Occupancy    0.124988    0.124999    0.124989
times:
            Type  Time(%)      Time     Calls       Avg       Min       Max  Name
 GPU activities:   49.63%  164.96ms       100  1.6496ms  1.5968ms  2.0850ms  void kernelTransformReduceInnermostDimIndex<long, long, MinValuePair<long, long>>(long*, long*, long*, unsigned int, unsigned int, thrust::pair<long, long>, long)
                   49.47%  164.45ms       100  1.6445ms  1.5961ms  1.8330ms  void kernelTransformReduceInnermostDimIndex<long, long, MaxValuePair<long, long>>(long*, long*, long*, unsigned int, unsigned int, thrust::pair<long, long>, long)
```

Interestingly, the occupancy is lower in the new version, but it runs 15x faster. Occupancy in isolation is not a good metric, because a very inefficient algorithm can have high occupancy.
Let me do some additional testing to ensure correctness, although I think the existing tests cover it well.


I tried with 128 threads, and the performance is 2x slower on my (admittedly too simple) benchmark. I'll keep it at 256 but introduce a define.


Reading your post (very impressive result!), indeed 128 would be better for smaller inputs... I guess it would be best to vary the number based on the input size, but that's too much work for this case.


@wesolwsk Which input sizes have you tested? Can you try a variety of input sizes (e.g. 100, 10000, 1000000)? For the last one, scaling the number of blocks with the number of rows may work better than limiting the number of blocks to 1024.


From @wesolwsk: Can we assume anything about the size of the innermost dimension compared to the other dimensions? Optimal implementations will differ depending on which is larger.


From @wesolwsk: Can you try without this limit? 1024 blocks yields an occupancy of about 13% on Volta for short rows. Maybe you can get better performance for large input sizes if you just set the number of blocks to the row count. It would also allow you to get rid of the outer for loop, which could help a little more.
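
For illustration, a hedged sketch of that launch-configuration idea: scale the grid with the row count instead of capping it at 1024 blocks. The kernel name, launcher name, and 256-thread constant are placeholders carried over from the sketch earlier in this thread, not the PR's actual code:

```
#include <algorithm>
#include <cuda_runtime.h>

// Forward declaration of the illustrative kernel sketched earlier in the thread.
__global__ void row_min_with_index(const long* in, long* out_val, long* out_idx,
                                   unsigned int rows, unsigned int cols);

void launch_row_min_with_index(const long* in, long* out_val, long* out_idx,
                               unsigned int rows, unsigned int cols) {
  constexpr int kThreadsPerBlock = 256;
  int device = 0;
  cudaGetDevice(&device);
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device);
  // Roughly one block per row, bounded only by the hardware grid limit, so
  // many short rows still expose enough blocks to fill the GPU.
  unsigned int blocks = std::min<unsigned int>(rows, (unsigned int)prop.maxGridSize[0]);
  row_min_with_index<<<blocks, kThreadsPerBlock>>>(in, out_val, out_idx, rows, cols);
}
```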

Review comment: Do you know how long this test takes to run? If it's more than a few seconds, we should mark it slowTest.

Reply from @akyrola (author): It is fast, some tens of milliseconds.

@ngimel commented Jun 5, 2019

I'm generally uneasy about improvements to legacy reductions. Ideally, all reduction operators should be ported to the TensorIterator reduction, which achieves close to SOL (speed-of-light) bandwidth on a wide variety of parameters. max/min/max_index/min_index currently don't go through TensorIterator, but they should.
Having multiple reduction paths, some using CUB and some not, is hard to maintain.

@ezyang added and removed the facebook label on Jun 5, 2019
@akyrola commented Jun 5, 2019

[Screenshot, 2019-06-05: performance comparison of the old and new kernels across input shapes]

Indeed @wesolwsk, the speed depends a lot on the shape. The old version is faster when there are many rows. I was doing this in the context of sparse tensors, where the number of rows in the indices tensor is always very small (it is the number of tensor dimensions). But we should definitely develop an alternative kernel for the "skinny" tensors.

But given what @ngimel says, perhaps I should drop this? We do, though, have a lot of CUB usage on the Caffe2 side.

@wesolwsk commented Jun 5, 2019

@akyrola, nice perf comparison. Maybe we can keep both implementations for now but select between them based on the input dimensions (it looks like the new one is better for rows of size > 1000). If TensorIterator reductions can handle both cases well, that would be even better.
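
A rough sketch of what that shape-based dispatch could look like. This is purely illustrative: the launcher declarations are placeholders, and the 1000-element threshold is just the crossover visible in the comparison above, not a tuned value from this PR:

```
// Placeholder declarations standing in for the two existing code paths.
void launch_cub_row_reduce(const long* in, long* out_val, long* out_idx,
                           unsigned int rows, unsigned int cols);
void launch_legacy_row_reduce(const long* in, long* out_val, long* out_idx,
                              unsigned int rows, unsigned int cols);

// Pick a kernel based on the innermost (reduced) dimension: the CUB version
// wins for long rows, the legacy kernel for many short rows.
void reduce_innermost_min(const long* in, long* out_val, long* out_idx,
                          unsigned int rows, unsigned int cols) {
  constexpr unsigned int kCubThreshold = 1000;  // assumed crossover point
  if (cols > kCubThreshold) {
    launch_cub_row_reduce(in, out_val, out_idx, rows, cols);
  } else {
    launch_legacy_row_reduce(in, out_val, out_idx, rows, cols);
  }
}
```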

Summary:
When optimizing sparse tensor coalesce, I noticed that this kernel was taking the bulk of the time (see PR pytorch#21214). It is used (at least) in the sparse tensor constructor to validate that the index tensor's min/max indices are valid.

This PR rewrites the kernel using a CUB reduction, achieving about a 16x speedup. With my benchmark for coalesce, nvprof showed the following before:
```
#  GPU activities:   45.47%  2.42669s       101  24.027ms  23.862ms  28.968ms  void kernelTransformReduceInnermostDimIndex<long, long, MinValuePair<long, long>>(long*, long*, long*, unsigned int, unsigned int, thrust::pair<long, long>, long)
#                    45.41%  2.42386s       101  23.999ms  23.857ms  28.944ms  void kernelTransformReduceInnermostDimIndex<long, long, MaxValuePair<long, long>>(long*, long*, long*, unsigned int, unsigned int, thrust::pair<long, long>, long)
```

... after this:

```
 GPU activities:   19.50%  154.92ms       101  1.5338ms  1.5285ms  1.5987ms  void kernelTransformReduceInnermostDimIndex<long, long, MinValuePair<long, long>>(long*, long*, long*, unsigned int, unsigned int, thrust::pair<long, long>, long)
                   19.45%  154.52ms       101  1.5299ms  1.5247ms  1.5933ms  void kernelTransformReduceInnermostDimIndex<long, long, MaxValuePair<long, long>>(long*, long*, long*, unsigned int, unsigned int, thrust::pair<long, long>, long)
```
Pull Request resolved: pytorch#21295

Differential Revision: D15606873

fbshipit-source-id: e5bc86933efa44c36c3b2942114a04c20abd7700
@akyrola force-pushed the akyrola/cubreduce branch from 00c3710 to 1c8c84b on June 5, 2019 at 18:42
@ngimel commented Jun 5, 2019

The TensorIterator sum reduction achieves higher bandwidth everywhere than the current min/max reduction (and, looking at the speed-up numbers, even than the improved min reduction). min/max numbers should be similar to sum if min/max are moved to use TensorIterator. Benchmarking script and output on V100:

```
import torch
def bench(size, fn):
   x=torch.ones(size, device='cuda', dtype = torch.long)
   torch.cuda.synchronize()
   import time
   nrep = 100
   start = time.time()
   for i in range(nrep):
      fn(x,1)
   torch.cuda.synchronize()
   end = time.time()
   return ((end-start)/nrep)


print("rows elems bwmin bwsum timemin timesum speedup")
for row in [1,3,10,100]:
    for nelem in [1000,10000,100000,1000000]:
        size = [row,nelem]
        timemin = bench(size, torch.min)
        bwmin = row*nelem*8/timemin*1e-9
        timesum = bench(size, torch.sum)
        bwsum = row*nelem*8/timesum*1e-9
        print('{:8d} {:8d} {:.2f} {:.2f} {:3e} {:3e} {:.2f}'.format(row, nelem, bwmin, bwsum, timemin,  timesum,  timemin/timesum))
for row in [1000,10000,100000,1000000]:
    for nelem in [1,3,10,100]:
        size = [row,nelem]
        timemin = bench(size, torch.min)
        bwmin = row*nelem*8/timemin*1e-9
        timesum = bench(size, torch.sum)
        bwsum = row*nelem*8/timesum*1e-9
        print('{:8d} {:8d} {:.2f} {:.2f} {:3e} {:3e} {:.2f}'.format(row, nelem, bwmin, bwsum, timemin,  timesum,  timemin/timesum))
```

Output:

```
rows elems bwmin bwsum timemin timesum speedup
       1     1000 0.46 0.90 1.726866e-05 8.919239e-06 1.94
       1    10000 0.64 10.52 1.256108e-04 7.603168e-06 16.52
       1   100000 0.71 39.73 1.130674e-03 2.013445e-05 56.16
       1  1000000 0.41 460.28 1.960913e-02 1.738071e-05 1128.21
       3     1000 1.89 3.19 1.270533e-05 7.522106e-06 1.69
       3    10000 2.24 32.11 1.073050e-04 7.474422e-06 14.36
       3   100000 2.28 117.60 1.052678e-03 2.040863e-05 51.58
       3  1000000 1.18 594.66 2.036736e-02 4.035950e-05 504.65
      10     1000 5.94 10.71 1.346827e-05 7.469654e-06 1.80
      10    10000 7.12 106.45 1.124167e-04 7.514954e-06 14.96
      10   100000 3.76 255.15 2.127333e-03 3.135443e-05 67.85
      10  1000000 3.76 733.80 2.126475e-02 1.090217e-04 195.05
     100     1000 52.95 105.55 1.510859e-05 7.579327e-06 1.99
     100    10000 32.83 528.50 2.436876e-04 1.513720e-05 16.10
     100   100000 33.53 750.09 2.385638e-03 1.066542e-04 22.37
     100  1000000 33.25 774.75 2.405822e-02 1.032586e-03 23.30
    1000        1 0.93 1.08 8.587837e-06 7.379055e-06 1.16
    1000        3 2.89 3.19 8.299351e-06 7.529259e-06 1.10
    1000       10 9.60 10.44 8.330345e-06 7.660389e-06 1.09
    1000      100 95.98 105.15 8.335114e-06 7.607937e-06 1.10
   10000        1 9.59 10.65 8.344650e-06 7.510185e-06 1.11
   10000        3 28.67 32.11 8.370876e-06 7.474422e-06 1.12
   10000       10 96.09 106.76 8.325577e-06 7.493496e-06 1.11
   10000      100 508.86 508.40 1.572132e-05 1.573563e-05 1.00
  100000        1 26.06 17.80 3.070354e-05 4.495382e-05 0.68
  100000        3 83.52 310.79 2.873421e-05 7.722378e-06 3.72
  100000       10 251.74 429.03 3.177881e-05 1.864672e-05 1.70
  100000      100 665.75 722.70 1.201653e-04 1.106954e-04 1.09
 1000000        1 28.58 18.26 2.799034e-04 4.380512e-04 0.64
 1000000        3 81.95 487.73 2.928448e-04 4.920721e-05 5.95
 1000000       10 270.94 610.78 2.952671e-04 1.309800e-04 2.25
 1000000      100 660.15 759.30 1.211839e-03 1.053598e-03 1.15
```

@akyrola commented Jun 5, 2019

@ngimel The current min/max also returns the index (i.e. it is both min and argmin); are these kinds of reductions already supported by TensorIterator?

@akyrola commented Jun 5, 2019

Btw, the benchmark script you pasted might be susceptible to startup overheads? I think the first time a kernel is run, it takes much longer. I was using the "timeit" module to avoid this, by taking the minimum over repetitions.

My tests were done on an old M40.

@ngimel commented Jun 5, 2019

> @ngimel The current min/max also returns the index (i.e. it is both min and argmin); are these kinds of reductions already supported by TensorIterator?

Yes, it's pretty flexible; you just have to provide the necessary functors (e.g. I'm pretty sure it currently returns both mean and std from a single reduction pass, using the right reduction functor).
I don't think there are significant overheads associated just with the first kernel run (it's the other way around: subsequent repetitions likely show lower times because of spurious cache hits), and the CUDA context is initialized and the input data is created outside the timed region, but I did not specifically check.

@ngimel commented Jun 5, 2019

https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/SharedReduceOps.h can be useful to look at.
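
To make that concrete, here is a hand-rolled (value, index) combine in the spirit of the functors in SharedReduceOps.h. It only illustrates how min and argmin fall out of a single associative combine step; the struct and method names are assumptions, not the actual TensorIterator reduction interface:

```
#include <thrust/pair.h>

template <typename scalar_t, typename index_t>
struct MinWithIndexOp {
  using arg_t = thrust::pair<scalar_t, index_t>;  // running (value, index)

  // Fold one element into the accumulator.
  __host__ __device__ arg_t reduce(arg_t acc, scalar_t value, index_t idx) const {
    return combine(acc, arg_t(value, idx));
  }

  // Merge two partial results. Ties keep the lower index so the answer does
  // not depend on the reduction order.
  __host__ __device__ arg_t combine(arg_t a, arg_t b) const {
    if (b.first < a.first || (b.first == a.first && b.second < a.second)) {
      return b;
    }
    return a;
  }
};
```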

@akyrola commented Jun 5, 2019

OK, let me work on the TensorIterator approach, perhaps later this week. I can't promise, though, so if someone wants to do it instead, please do.

@facebook-github-bot:

Hi @akyrola!

Thank you for your pull request.

We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@pytorchbot:

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
Stale pull requests will automatically be closed 30 days after being marked Stale

@pytorchbot added and removed the Stale label on Apr 12, 2022
@github-actions:

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions bot added the Stale label on Jun 11, 2022
@github-actions bot closed this on Jul 11, 2022
Labels: cla signed, module: cuda, open source, Stale
