Cuda persistent softmax #20827
Conversation
this is really cool!

Test failures are real.

Yeah. They are intermittent. I ran one of the failed tests (test_softmax_dtype) twice in a row; it failed the first time and passed the second time. I'm looking into it.

I'm pretty sure I introduced this bug during the code cleanup: the input arrays were no longer being properly initialized. I fixed it; hopefully the tests will pass now.
ngimel left a comment:
Some small changes, overall looks good.
```cpp
constexpr uint32_t FULL_MASK = 0xffffffff;
```
You don't need to define FULL_MASK; WARP_SHFL_XOR uses the full mask by default.
I will remove FULL_MASK.
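For reference, a sketch of the wrapper's shape (modeled on THCDeviceUtils.cuh; treat the exact signature as an assumption):

```cpp
// The wrapper's mask parameter already defaults to all 32 lanes, so a
// caller-side FULL_MASK constant is redundant.
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask,
                                           int width = warpSize,
                                           unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
    return __shfl_xor_sync(mask, value, laneMask, width);
#else
    return __shfl_xor(value, laneMask, width);
#endif
}
```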
```cpp
// Warp Softmax forward
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// WARP_BATCH number of batches.
```
This is a misleading comment - it is the number of samples?
I kind of inherited the comments. I agree they are not very clear; I'll try to improve that.
```cpp
// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
```
number of threads working on a single sample?
```cpp
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    acc_t val[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
```
Do the 2 loops here provide perf benefits, or can they be fused, like:
val = WARP_SHFL_XOR(..)
max_value[i] = <ternary operator>
If the 2 loops are indeed necessary, a comment might be in order.
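For reference, a sketch of the fused single-loop variant being suggested; this is a fragment of the kernel body, reusing max_value, WARP_BATCH, and WARP_SIZE from the diff above, with the shuffle's optional arguments left at their defaults:

```cpp
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        // shuffle and max in one pass instead of two separate loops
        acc_t val = WARP_SHFL_XOR(max_value[i], offset);
        max_value[i] = max_value[i] > val ? max_value[i] : val;
    }
}
```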
```cpp
// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
```
I think it would make sense to have a __device__ function for warp reduce, as it is used a few times. It could also handle max reduction.
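A minimal sketch of such a helper, assuming the existing WARP_SHFL_XOR wrapper; warp_reduce, Add, and Max are hypothetical names:

```cpp
// Hypothetical functors so one helper covers both sum and max reductions.
template <typename T>
struct Add {
    __device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
};

template <typename T>
struct Max {
    __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? b : a; }
};

// Reduces WARP_BATCH per-thread partials across the warp in place.
template <typename acc_t, int WARP_BATCH, int WARP_SIZE,
          template <typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* x) {
    ReduceOp<acc_t> op;
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            x[i] = op(x[i], WARP_SHFL_XOR(x[i], offset));
        }
    }
}

// e.g. warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
//      warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
```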
```cpp
    kernel<<<blocks, threads>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
    return true;
}
return false;
```
might make sense to do a TORCH_INTERNAL_ASSERT on softmax_elements <= 1024.
That's actually a better idea than asserting on the return value from dispatch_softmax. Will do.
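A one-line sketch of the agreed check (the placement and message text are illustrative):

```cpp
// Hedged sketch: guard the persistent-kernel path at the call site instead
// of checking dispatch_softmax's return value.
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024,
                      "softmax kernel requires softmax_elements <= 1024");
```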
```cpp
} else if (softmax_elements <= 1024) {
    // compute function index. there's a function for each power of two size up to 1024.
    int log2_elements = 0;
    while ((1 << log2_elements) < softmax_elements) ++log2_elements;
```
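For context, a hedged sketch of the power-of-two dispatch that this index feeds; the kernel name and template parameters are illustrative, not the PR's exact code:

```cpp
// log2_elements selects a kernel instantiation specialized for the next
// power of two >= softmax_elements.
switch (log2_elements) {
    case 0: // next power of two == 1
        softmax_warp_forward<input_t, output_t, acc_t, 0>
            <<<blocks, threads>>>(dst, src, batch_count,
                                  softmax_elements_stride, softmax_elements);
        break;
    // ... cases 1 through 9 follow the same pattern ...
    case 10: // next power of two == 1024
        softmax_warp_forward<input_t, output_t, acc_t, 10>
            <<<blocks, threads>>>(dst, src, batch_count,
                                  softmax_elements_stride, softmax_elements);
        break;
}
```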
```cpp
}

// reduction sum
constexpr uint32_t FULL_MASK = 0xffffffff;
```
same comment about FULL_MASK
```cpp
}

// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
```
It looks like all these computations are shared between forward/backward, so maybe they can be abstracted away?
Ok.
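A minimal sketch of how the shared launch-geometry math could be factored out; the helper name, struct, and the exact batches-per-warp heuristic are assumptions, not the PR's code:

```cpp
#include <cuda_runtime.h>

// Hypothetical helper: forward and backward need the same grid math, so it
// could live in one place.
struct SoftmaxLaunchGeometry {
    dim3 blocks;
    dim3 threads;
};

inline SoftmaxLaunchGeometry softmax_launch_geometry(int batch_count,
                                                     int log2_elements) {
    const int next_power_of_two = 1 << log2_elements;
    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    const int warp_size = next_power_of_two < 32 ? next_power_of_two : 32;
    const int batches_per_warp = next_power_of_two <= 128 ? 2 : 1;
    const int warps_per_block = threads_per_block / warp_size;
    const int batches_per_block = warps_per_block * batches_per_warp;
    const int num_blocks =
        (batch_count + batches_per_block - 1) / batches_per_block;
    return {dim3(num_blocks), dim3(warp_size, warps_per_block)};
}
```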
```diff
 Tensor log_softmax_cuda(const Tensor &input, const int64_t dim, const bool half_to_float){
-  return host_softmax<LogSoftMaxForwardEpilogue>(input, dim, half_to_float);
+  return host_softmax<LogSoftMaxForwardEpilogue,true>(input, dim, half_to_float);
```
I'm not too happy with 2 template arguments saying the same thing (I'm log! I'm not!), but since I did not come up with a not-too-ugly way to get rid of it, I'll let it slide.
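For illustration, one hedged way the duplication could be folded away with a trait; all names here are hypothetical, and this assumes the epilogue is usable as a plain type (if it is a class template, the trait would take a template-template parameter instead):

```cpp
#include <type_traits>

struct SoftMaxForwardEpilogue;    // assumed tag types, for illustration only
struct LogSoftMaxForwardEpilogue;

// Derive the "is log" flag from the epilogue type itself so host_softmax
// needs only the one template argument.
template <typename Epilogue>
struct is_log_softmax : std::false_type {};

template <>
struct is_log_softmax<LogSoftMaxForwardEpilogue> : std::true_type {};

// host_softmax<Epilogue> could then read is_log_softmax<Epilogue>::value
// internally instead of taking a second, redundant bool parameter.
```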
One of the softmax tests is still failing on accuracy, but it looks like that is due to a kluge in PyTorch. The failing test runs softmax on a double precision tensor but sees float-like accuracy (i.e. relative error ~1e-8). The test was compiled with -D__HIP_PLATFORM_HCC__=1. My code uses WARP_SHFL_XOR for intra-warp reductions; if you look at where it is defined, there is a comment saying "HIP does not support double" and a specialization of WARP_SHFL_XOR that casts the value to a float before calling __shfl_xor. Because of this, no test that uses WARP_SHFL_XOR will ever pass the double precision accuracy tests in test_nn.py @iotamudelta. (See pytorch/aten/src/THC/THCDeviceUtils.cuh, lines 63 to 68 at 31e2d20.)
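A hedged sketch of the pattern described above (not the verbatim THCDeviceUtils.cuh source): the round-trip through float is what caps double reductions at roughly single precision.

```cpp
#ifdef __HIP_PLATFORM_HCC__
// "HIP does not support double": the value is squeezed through float,
// silently capping double-precision reductions at ~1e-8 relative error.
__device__ __forceinline__ double WARP_SHFL_XOR(double value, int laneMask,
                                                int width = warpSize) {
    return (double)__shfl_xor((float)value, laneMask, width);
}
#endif
```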
@thorjohnsen A double __shfl_xor(double, int, int) is defined for ROCm. This specialization should be removable.
71 tests passed, 4 failed. After inspection, I don't think any of the 4 failed tests are related to this PR, but please don't take my word for it. Removing the double specialization of WARP_SHFL_XOR like you suggested fixed the failing softmax test, and a couple of other unrelated unit tests are also passing now, so we might have accidentally fixed a bug @iotamudelta. I will push the new code with your suggested improvements tomorrow morning @ngimel; then I think we should be ready for the merge.
My comments are addressed, thank you.
The 3 failed tests appear unrelated to this PR. One test failed with an ImportError on torch.onnx.symbolic_helper, the 2nd failed with ImportError: undefined symbol: _ZN2at19NonVariableTypeMode10is_enabledEv, and the 3rd failed because of a java.io.IOException: Backing channel 'JNLP4-connect connection from 209.249.227.2/209.249.227.2:43242' is disconnected. I believe this PR is ready to be merged.
facebook-github-bot left a comment:
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Adds persistent cuda kernels that speed up SoftMax applied over the fast dimension, i.e. torch.nn.Softmax(dim=-1) and torch.nn.LogSoftmax(dim=-1). When the size is <= 1024, this code is 2-10x faster than the current code; the speedup is higher for smaller sizes. This code works for half, float and double tensors with 1024 or fewer elements in the fast dimension. Numerical accuracy is on par with the current code, i.e. relative error is ~1e-8 for float tensors and ~1e-17 for double tensors. Relative error was computed against the CPU code.

The attached image shows kernel time in us for torch.nn.Softmax(dim=-1) applied to a half precision tensor of shape [16384, n], with n plotted along the horizontal axis. Similar uplifts can be seen for the backward pass and for LogSoftmax.

Pull Request resolved: pytorch/pytorch#20827
Differential Revision: D15582509
Pulled By: ezyang
fbshipit-source-id: 65805db37487cebbc4ceefb1a1bd486d24745f80