ptrblck commented Jun 27, 2019

This PR activates faster depthwise convolution kernels for Volta and Turing GPUs using cudnn >= 7600.
The script to benchmark the current PyTorch master branch and this PR branch can be found here: https://gist.github.com/ptrblck/4590cf20721d8f43296c9903abd4a774
(50 warmup iterations, 1000 iterations for timing)

I've used #3265 to create a similar benchmark and added a few additional setups.
Since the results are quite long, I've uploaded them to a spreadsheet here: https://docs.google.com/spreadsheets/d/13ByXcqg7LQUr3DVG3XpLwnJ-CXg3GUZJ3puyTMw9n2I/edit?usp=sharing
Times are given in ms per iteration.
We've benchmarked this PR on a DGX1 using V100 GPUs.

The current workload check in check_cudnn_depthwise_workload is quite long and could be moved to another file if desired.

CC @ngimel (Thanks for the support while benchmarking it ;) )

int w = input.size(3); // same as h
int ch = input.size(1);
int bs = input.size(0);
int k = weight.size(2); // kernel size

Collaborator:
you never use k in this function

ptrblck (Author):
You are right! That's some dead code and I'll remove it.


li-roy commented Jun 28, 2019

@pytorchbot rebase this please


soumith commented Jun 29, 2019

@pytorchbot rebase this please

zhangguanheng66 commented:

@pytorchbot rebase this please

pytorchbot commented:

Sorry, only maintainers are authorized to rebase other people's PRs. Feel free to try again on one of your PRs!

(To learn more about this bot, see Bot commands.)

zhangguanheng66 commented:

@pytorchbot retest this please


ngimel commented Jul 1, 2019

@ptrblck the Windows build failure looks real; apparently it does not like "and".


ptrblck commented Jul 2, 2019

@ngimel Thanks for the information!
I wasn't aware that alternative tokens might cause trouble on Windows.
I've updated it to &&.


ptrblck commented Jul 2, 2019

@pytorchbot retest this please


ngimel commented Jul 3, 2019

@pytorchbot rebase this please

facebook-github-bot left a comment:

@izdeby has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

auto ConvParams::use_cudnn_depthwise(
const at::Tensor& input, const at::Tensor& weight) const -> bool {
#if AT_CUDNN_ENABLED()
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();

Collaborator:

you should not be calling getCurrentDeviceProperties() directly here; instead, add CUDAHooks::supportsDepthwiseConvolutionWithCuDNN to cuda/detail/CUDAHooks.cpp, like it's currently done for CUDAHooks::supportsDilatedConvolutionWithCuDNN(). That would also allow you to avoid the AT_CUDNN_ENABLED macro.

pytorchbot added the merge-this-please label Jul 8, 2019

facebook-github-bot left a comment:

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Jul 9, 2019
Summary:
This PR activates faster depthwise convolution kernels for Volta and Turing GPUs using cudnn >= 7600.
The script to benchmark the current PyTorch master branch and this PR branch can be found [here](https://gist.github.com/ptrblck/4590cf20721d8f43296c9903abd4a774).
(50 warmup iterations, 1000 iterations for timing)

I've used pytorch/pytorch#3265 to create a similar benchmark and added a few additional setups.
Since the results are quite long, I've uploaded them in a spreadsheet [here](https://docs.google.com/spreadsheets/d/13ByXcqg7LQUr3DVG3XpLwnJ-CXg3GUZJ3puyTMw9n2I/edit?usp=sharing).
Times are given in ms per iteration.
We've benchmarked this PR on a DGX1 using V100 GPUs.

The current workload check in `check_cudnn_depthwise_workload` is quite long and can be moved to another file, if wanted.

CC ngimel (Thanks for the support while benchmarking it ;) )
Pull Request resolved: pytorch/pytorch#22302

Differential Revision: D16115057

Pulled By: ezyang

fbshipit-source-id: bad184658518e73b4d6b849d77e408f5a7a757de

@ezyang merged this pull request in a3346e1.

ptrblck deleted the cudnn branch July 14, 2019 22:44

yaysummeriscoming commented:

Hi, I'm very excited this is in :). I'm able to reproduce the individual depthwise convolution tests as presented in the spreadsheet: very impressive gains of up to 400%.

I decided to test this with MobileNet V2, but unfortunately I'm only seeing speedups of ~10%. My understanding is that depthwise convolution is the slowest link when training such lightweight networks, so this seems quite low to me?

I’ve modified the test script to use MobileNet V2 here:
https://gist.github.com/yaysummeriscoming/88ae59bc5b7ba8581ea396d8ce87d28f

I had a go at profiling with the autograd profiler, but that doesn't distinguish between pointwise and depthwise convolutions.

Any pointers? Could it be that cuDNN isn’t optimised for pointwise convolutions with large input/output channel ratios?

(I’m assuming this is the right place to ask this, please correct me if not)


jph00 commented Oct 25, 2019

@ptrblck thanks for this great PR and helpful benchmarking. I've created a couple of summary tables of the benchmarks that might be helpful. Here's speedup by kernel size and stride, by height/width:

[image: summary table of speedup by kernel size and stride, by height/width]

And here's the details of h/w by num channels, for just the stride one and kernel size 3 rows:

[image: summary table of h/w by number of channels, for stride 1 and kernel size 3]

Have you tried benchmarking 5x5 convs? They are used a lot in EfficientNet, so it would be great if they're fast...

cc @ngimel


jph00 commented Oct 26, 2019

I just tried 5x5 convs and it appears they are not optimized for tensor cores: they ran at about the same speed in fp16 and fp32.


ngimel commented Oct 26, 2019

IIRC, cudnn only had fast implementations for 1x1 and 3x3, so just enabling it for 5x5 is unlikely to dramatically speed things up.


jph00 commented Oct 26, 2019

Thanks @ngimel . Is there any plan to add a 5x5 implementation? If not, could I twist your arm to create such a plan... ;)


ngimel commented Oct 26, 2019

Not mine, I'm not with nvidia anymore :-)


jph00 commented Oct 26, 2019

Oh yes so I see! Welcome to Facebook then :)

andravin commented:

Hi @ptrblck and @ngimel, just to clarify a point that is causing some confusion: although this patch uses cuDNN kernels for depthwise convolution on Volta, do those kernels actually use tensor cores?

Depthwise convolution is basically planar convolution nested inside of diagonal matrix multiplication. I do not see how that could be made faster with 8m x 8n x 4k matrix multiplication fragments. It seems like one of the input matrices would be diagonal, and (at most) 1/8th of the matrix elements would be nonzero, reducing the effective arithmetic throughput of the tensor cores to the level of fp32 core arithmetic throughput.


ngimel commented Feb 21, 2020

@andravin you are right, those kernels don't use tensor cores.

andravin commented:

@ngimel do they use hfma2? fp16 accumulation might be OK for 3x3 depth-wise.

Otherwise I am stumped why these kernels are Volta only. Also, P100 had hfma2.

facebook-github-bot pushed a commit that referenced this pull request Jun 22, 2020
Summary:
Follow-up of #38044. Thanks ptrblck and mcarilli for the help discussing the changes!

Could fix #37725 by skipping the depthwise-workload check introduced in #22302. This PR also relaxed dilated convolution for channels-last.

The testing script is https://gist.github.com/xwang233/82a707f69bb710cb612349280a2c5f41. About 387k conv arguments were tested and no cudnn exception was thrown.

cc ngimel VitalyFedyunin ptrblck mcarilli
Pull Request resolved: #38904

Differential Revision: D22155797

Pulled By: VitalyFedyunin

fbshipit-source-id: 81b5736cec67ea263029121521c6acafd9dddba6
facebook-github-bot pushed a commit that referenced this pull request Oct 23, 2021
Summary:
There are multiple improvements to depthwise convolution speed in cudnn between 7.6 and 8.2, since #22302.
This PR aims to harvest all of these improvements by enabling more cudnn kernels. The workload-checking logic can also be simplified now.
To keep the change simple, I kept things before cudnn 8.2 unchanged.

Similar to #22302, I used a script [here](https://gist.github.com/FDecaYed/e8ba98a95cd33697df2ace86fdb44897) to benchmark. Both runs are using cudnn 8.2.
One enhancement I made to the script is switching to event-based timing. With warmup kernels filling the launch queue ahead of time, this should give accurate kernel timing even in CPU launch-bound cases.

Here are the A100 and V100 results sorted by speedup.
[Book1.xlsx](https://github.com/pytorch/pytorch/files/6530371/Book1.xlsx)

Result highlights:
Newly enabled 5x5 cudnn kernels show up to 6x speedup.
Close to half of the tested sizes show >10% speedup.
Fixed some corner cases that previously caused 15-20x slowdowns.
Only a handful of cases (~10 out of >1000) slow down.

Pull Request resolved: #58749

Reviewed By: bdhirsh

Differential Revision: D31613199

Pulled By: ngimel

fbshipit-source-id: 883b58facad67ccd51dc9ab539368b4738d40398
langong347 pushed a commit that referenced this pull request Oct 25, 2021

Labels: merge-this-please, Merged, module: cuda, module: internals, open source
