
Conversation

@thorjohnsen
Contributor

Adds persistent CUDA kernels that speed up SoftMax applied over the fast dimension, i.e. torch.nn.Softmax(dim=-1) and torch.nn.LogSoftmax(dim=-1). When the size of the fast dimension is <= 1024, this code is 2-10x faster than the current code; the speedup is higher for smaller sizes. The kernels work for half, float and double tensors with 1024 or fewer elements in the fast dimension. Numerical accuracy is on par with the current code: relative error, computed against the CPU implementation, is ~1e-8 for float tensors and ~1e-17 for double tensors.

The attached image shows kernel time in µs for torch.nn.Softmax(dim=-1) applied to a half precision tensor of shape [16384, n], with n plotted along the horizontal axis. Similar uplifts can be seen for the backward pass and for LogSoftmax.

[Image: plot of kernel time in µs vs. n for the benchmark described above]
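For concreteness, a minimal libtorch sketch of the call being benchmarked (not part of the PR; kernel times would be collected separately with a profiler):

```cpp
#include <torch/torch.h>
#include <cuda_runtime.h>

int main() {
  const int64_t n = 256;  // swept along the horizontal axis in the plot
  // Half-precision tensor of shape [16384, n] on the GPU.
  auto x = torch::randn({16384, n},
                        torch::dtype(torch::kHalf).device(torch::kCUDA));
  // Softmax / LogSoftmax over the fast (last) dimension -- the path the
  // new persistent kernels accelerate for sizes <= 1024.
  auto y  = torch::softmax(x, /*dim=*/-1);
  auto ly = torch::log_softmax(x, /*dim=*/-1);
  cudaDeviceSynchronize();  // make sure the kernels have finished
  return 0;
}
```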

@pytorchbot added the "module: cuda" and "module: operators" labels on May 22, 2019
@soumith
Contributor

soumith commented May 23, 2019

this is really cool!

@ngimel
Collaborator

ngimel commented May 23, 2019

Test failures are real.

@thorjohnsen
Contributor Author

Yeah. They are intermittent. I ran one of the failed tests (test_softmax_dtype) two times in a row. It failed the first time and passed the second time. I'm looking into it.

@thorjohnsen
Contributor Author

I'm pretty sure I introduced this bug during the code cleanup: the input arrays were no longer being properly initialized. I fixed it; hopefully the tests will pass now.

@ngimel (Collaborator) left a comment

Some small changes, overall looks good.

}
}

constexpr uint32_t FULL_MASK = 0xffffffff;
Collaborator:

You don't need to define FULL_MASK; WARP_SHFL_XOR uses the full mask by default.

Contributor Author:

I will remove FULL_MASK.

// Warp Softmax forward
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

// WARP_BATCH number of batches.
Collaborator:

This is a misleading comment - is it the number of samples?

Contributor Author:

I kind of inherited the comments. I agree they are not very clear; I'll try to improve that.


// WARP_BATCH number of batches.
// WARP_ITERATIONS The number of iterations required for one warp to iterate over all data.
// WARP_SIZE number of elements working on a single batch, has to be a power of two.
Collaborator:

number of threads working on a single sample?

for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    acc_t val[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
Collaborator:

Do the 2 loops here provide perf benefits, or can they be fused like this?

val = WARP_SHFL_XOR(..)
max_value[i] = <ternary operator>

If 2 loops are indeed necessary, a comment might be in order.
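Spelled out, the fused form would look something like this (a sketch using max_value, WARP_BATCH, WARP_SIZE and acc_t from the kernel above; not taken from the final code):

```cuda
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        // shuffle in the partner lane's running max and combine in one step
        acc_t val = WARP_SHFL_XOR(max_value[i], offset, WARP_SIZE);
        max_value[i] = (max_value[i] > val) ? max_value[i] : val;
    }
}
```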


// reduction sum
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
Collaborator:

I think it would make sense to have a __device__ function for warp reduce, as it is used a few times. It could also handle max reduction.
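A sketch of what such a helper could look like (the names warp_reduce, Add and Max are illustrative, not taken from the PR; WARP_SHFL_XOR is the existing shuffle wrapper):

```cuda
// Combine functors so one reduction routine covers both the sum and max passes.
template <typename acc_t>
struct Add { __device__ __forceinline__ acc_t operator()(acc_t a, acc_t b) const { return a + b; } };

template <typename acc_t>
struct Max { __device__ __forceinline__ acc_t operator()(acc_t a, acc_t b) const { return a < b ? b : a; } };

// Butterfly reduction across the warp for WARP_BATCH independent samples.
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template <typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
    ReduceOp<acc_t> r;
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            acc_t b = WARP_SHFL_XOR(sum[i], offset, WARP_SIZE);
            sum[i] = r(sum[i], b);
        }
    }
}
```

The max reduction above would then become warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value); and the sum reduction warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);.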

kernel<<<blocks, threads>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
return true;
}
return false;
Collaborator:

Might make sense to do a TORCH_INTERNAL_ASSERT on softmax_elements <= 1024.

Contributor Author:

That's actually a better idea than asserting on the return value from dispatch_softmax. Will do.
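Something along these lines in the host-side dispatch (exact message and placement are illustrative):

```cuda
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024,
    "dispatch_softmax only supports softmax_elements <= 1024, got ",
    softmax_elements);
```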

} else if (softmax_elements <= 1024) {
    // compute function index. there's a function for each power of two size up to 1024.
    int log2_elements = 0;
    while ((1 << log2_elements) < softmax_elements) ++log2_elements;
Collaborator:

@ezyang, @mcarilli the functions for next power of 2 are often needed - does it make sense to make those helper functions generally available?
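For reference, the computation shown above factored into a reusable helper (the name log2_ceil is illustrative):

```cuda
// Smallest exponent e such that (1 << e) >= value, i.e. log2 of the next power of two.
__host__ __device__ __forceinline__ int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
    return log2_value;
}

// In the dispatch above this would read:
//   int log2_elements = log2_ceil(softmax_elements);
//   int next_power_of_two = 1 << log2_elements;
```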

}

// reduction sum
constexpr uint32_t FULL_MASK = 0xffffffff;
Collaborator:

same comment about FULL_MASK

}

// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
Collaborator:

It looks like all these computations are shared between forward/backward, so maybe they can be abstracted away?

Contributor Author:

Ok.
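A hedged sketch of what factoring the shared launch-geometry computation out of the forward and backward dispatchers might look like (the helper name and the batches_per_warp heuristic are assumptions, not taken from the PR):

```cuda
struct SoftmaxLaunchConfig {
    dim3 threads;
    int blocks;
};

// Derive block/thread shape from the padded softmax size and the batch count.
inline SoftmaxLaunchConfig softmax_launch_config(int next_power_of_two, int batch_count) {
    constexpr int threads_per_block = 128;  // use 128 threads per block to maximize gpu utilization
    int warp_size = next_power_of_two < 32 ? next_power_of_two : 32;
    int batches_per_warp = next_power_of_two <= 128 ? 2 : 1;  // assumed heuristic
    int warps_per_block = threads_per_block / warp_size;
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    return {dim3(warp_size, warps_per_block, 1), blocks};
}
```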


Tensor log_softmax_cuda(const Tensor &input, const int64_t dim, const bool half_to_float){
-  return host_softmax<LogSoftMaxForwardEpilogue>(input, dim, half_to_float);
+  return host_softmax<LogSoftMaxForwardEpilogue,true>(input, dim, half_to_float);
Collaborator:

I'm not too happy with 2 template arguments saying the same thing (I'm log! I'm not!), but since I did not come up with a not too ugly way to get rid of it, I'll let it slide.

@thorjohnsen
Contributor Author

One of the softmax tests is still failing on accuracy, but it looks like that is due to a kludge in PyTorch. The test that fails runs softmax on a double precision tensor but sees float-like accuracy (i.e. relative error ~1e-8). The test was compiled with -D__HIP_PLATFORM_HCC__=1. My code uses WARP_SHFL_XOR for intra-warp reductions; if you look at the place where it is defined, there is a comment saying "HIP does not support double" and a specialization of WARP_SHFL_XOR that casts the value to a float before calling __shfl_xor. Because of this, no test that uses WARP_SHFL_XOR will ever pass the double precision accuracy tests in test_nn.py @iotamudelta.

#ifdef __HIP_PLATFORM_HCC__
//To handle ambiguity, add a type double version.
__device__ __forceinline__ double WARP_SHFL_XOR(double value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) {
    //(HIP doesn't support double)
    return (double) __shfl_xor((float) value, laneMask, width);
}
#endif

@iotamudelta
Contributor

@thorjohnsen A double __shfl_xor(double, int, int) is defined for ROCm. This specialization should be removable.
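With the specialization gone, the generic path handles double directly. A sketch of the shuffle wrapper as it would then read (assuming the existing templated WARP_SHFL_XOR in THCDeviceUtils.cuh; reconstructed from memory, not quoted from the PR):

```cuda
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask,
                                           int width = warpSize,
                                           unsigned int mask = 0xffffffff) {
#ifndef __HIP_PLATFORM_HCC__
  // CUDA: the sync variant takes an explicit lane mask.
  return __shfl_xor_sync(mask, value, laneMask, width);
#else
  // ROCm: __shfl_xor has a double overload, so no float cast is needed.
  return __shfl_xor(value, laneMask, width);
#endif
}
```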

@thorjohnsen
Contributor Author

71 tests passed, 4 failed. After inspection, I don't think any of the 4 failed tests are related to this PR, but please don't take my word for it. Removing the double specialization of WARP_SHFL_XOR as you suggested fixed the failing Softmax test, and a couple of other unrelated unit tests are also passing now, so we might have accidentally fixed a bug @iotamudelta. I will push the new code with your suggested improvements tomorrow morning @ngimel; then I think we should be ready for the merge.

@ngimel
Collaborator

ngimel commented May 29, 2019

My comments are addressed, thank you.

@thorjohnsen
Contributor Author

The 3 failed tests appear unrelated to this PR. One test failed with an ImportError for torch.onnx.symbolic_helper, the second failed with ImportError: undefined symbol: _ZN2at19NonVariableTypeMode10is_enabledEv, and the third failed because of a java.io.IOException: Backing channel 'JNLP4-connect connection from 209.249.227.2/209.249.227.2:43242' is disconnected. I believe this PR is ready to be merged.

@facebook-github-bot (Contributor) left a comment

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request May 31, 2019
Summary: same as the PR description and plot above.
Pull Request resolved: pytorch/pytorch#20827

Differential Revision: D15582509

Pulled By: ezyang

fbshipit-source-id: 65805db37487cebbc4ceefb1a1bd486d24745f80
@facebook-github-bot
Contributor

@ezyang merged this pull request in e098878.
