max_pool2d cuda should have channel last optimized kernels [Performance improvement] #24872
Conversation
This reverts commit c7ece81.
Two quick things on the benchmarking:
I'm quite occupied until tomorrow afternoon. Let me take a closer look then.
Oops, my bad, I messed up the legend on backwards earlier :/ Inferring from the perf, we must have already gotten rid of the permutation? Looks like we are doing well except for certain problem sizes where we are getting destroyed by the stock kernel. Let me take another pass tomorrow afternoon.
Yes, I think we can improve even more here, but the idea for now is to have a stable version that is not worse than NCHW.
```cpp
static const int BACKWARD_THREADS = 256;

template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS)
__global__ void MaxPoolForwardNHWC(const int nthreads, const scalar_t* bottom_data,
```
nthreads and num are not used any more. Same with MaxPoolBackwardNHWC.
Fixed
```cpp
for(int c = threadIdx.x; c < channels; c+= blockDim.x) {
  scalar_t val = ptr_input[c];
  scalar_t maxval = out_cached[2 * c];
  if ((ScalarConvert<scalar_t, accscalar_t>::to(val) > maxval) || THCNumerics<scalar_t>::isnan(val)) {
```
pytorch/aten/src/THC/THCNumerics.cuh (lines 403 to 407 in b2f6e2b):
```cpp
// DEPRECATED: use static_cast in kernels instead of scalar_cast
template <typename T, typename U>
__host__ __device__ T scalar_cast(U u) {
  return ScalarConvert<U, T>::to(u);
}
```
We might just use static_cast
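A minimal sketch of that substitution, assuming the same types and names as the snippet above (the `update_max` helper is hypothetical, added only to make the fragment self-contained):

```cpp
// Hypothetical helper illustrating the suggestion: replace the deprecated
// ScalarConvert<scalar_t, accscalar_t>::to(val) with a plain static_cast.
template <typename scalar_t, typename accscalar_t>
__device__ __forceinline__ void update_max(accscalar_t* out_cached, int c, scalar_t val) {
  accscalar_t maxval = out_cached[2 * c];
  if ((static_cast<accscalar_t>(val) > maxval) || THCNumerics<scalar_t>::isnan(val)) {
    out_cached[2 * c] = static_cast<accscalar_t>(val);
  }
}
```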
Fixed
```cpp
const scalar_t *ptr_input = bottom_data + ih * in_stride_h + iw * in_stride_w;
for(int c = threadIdx.x; c < channels; c+= blockDim.x) {
  scalar_t val = ptr_input[c];
  scalar_t maxval = out_cached[2 * c];
```
NIT: maxval doesn't seem to be necessary here.
Fixed
```cpp
scalar_t maxval = out_cached[2 * c];
if ((ScalarConvert<scalar_t, accscalar_t>::to(val) > maxval) || THCNumerics<scalar_t>::isnan(val)) {
  out_cached[2 * c] = ScalarConvert<scalar_t, accscalar_t>::to(val);
  out_cached[2 * c + 1] = ih * width + iw;
```
This is dangerous: we should not convert the index to scalar_t.
For the fp16 kernel, the mantissa is only 10 bits, so we only have a range of 1023 (to be fair, on both sides, but the index is only going to be positive). Any index beyond that will give us an error here.
We need to change the earlier line to something like:
```cpp
extern __shared__ int smem[];
int *out_mask_cached = smem;
scalar_t *out_cached = reinterpret_cast<scalar_t*>(&out_mask_cached[channels*height*width]);
```
This will also require updating the allocation of dynamic shared memory in the launch config.
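For illustration, the launch-side allocation for that two-region layout could be sized roughly as below. The kernel name comes from the snippet above; the grid/block/stream variables and the `channels * height * width` element count mirror the suggested layout and are assumptions, not the PR's exact code:

```cpp
// Sketch only: dynamic shared memory must now cover both the int index region
// and the scalar_t value region carved out by the reinterpret_cast above.
size_t shmem_bytes = channels * height * width * (sizeof(int) + sizeof(scalar_t));
MaxPoolForwardNHWC<scalar_t, accscalar_t>
    <<<grid, block, shmem_bytes, at::cuda::getCurrentCUDAStream()>>>(
        /* kernel arguments elided */);
```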
Fixed
```cpp
for (int ih = istartH; ih < iendH; ih+=blockDim.z) {
  for (int iw = istartW; iw < iendW; iw+=blockDim.y) {
    int phstart, phend, pwstart, pwend;
    if (stride_h == 1) {
```
NIT: Looks like we are using the same logic as the NCHW kernel here. Maybe we want to combine the code and avoid keeping two copies of the same logic, so that a later cleanup would be easier.
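For illustration, the shared piece could be a small device helper along these lines; the helper name and the exact formula are assumptions based on the usual Caffe-style pooling backward index math, not the PR's actual code:

```cpp
// Hypothetical helper: for an input index along one spatial dimension, compute
// the half-open range [pstart, pend) of pooled outputs whose window covers it.
__device__ __forceinline__ void pooling_output_range(
    int index, int pad, int kernel, int stride, int pooled_size,
    int& pstart, int& pend) {
  pstart = (index + pad < kernel) ? 0 : (index + pad - kernel) / stride + 1;
  pend = min((index + pad) / stride + 1, pooled_size);
}
```

Both the NCHW and NHWC backward kernels could then call it once for the height dimension and once for the width dimension instead of duplicating the branches.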
Fixed
facebook-github-bot
left a comment
@ifedan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
```cpp
      output_data, indices_data);
  break;
}
default: AT_ERROR("Unsupported memory format. Supports only ChannelsLast, Contiguous");
```
```cpp
// Deprecated alias; this alias was deprecated because it represents extra API
// surface that makes it hard for people to understand what macro to use.
// Use TORCH_CHECK(false, ...) or TORCH_INTERNAL_ASSERT(false, ...) to
// unconditionally fail at a line of code.
#define AT_ERROR(...)                                                        \
  do {                                                                       \
    ::c10::detail::deprecated_AT_ERROR();                                    \
    C10_EXPAND_MSVC_WORKAROUND(TORCH_CHECK(false, ::c10::str(__VA_ARGS__))); \
  } while (false)
```
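Following that guidance, the suggested replacement for the default branch would look roughly like:

```cpp
default:
  TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
```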
Fixed
VitalyFedyunin
left a comment
LGTM! Not approving to avoid an accidental merge (we want to benchmark everything first in a separate branch).
Please rebase; we are getting ready to land it.
VitalyFedyunin
left a comment
Please add a double backward test.
```cpp
      gradInput_data);
  break;
}
default: AT_ERROR("Unsupported memory format. Supports only ChannelsLast, Contiguous");
```
TORCH_CHECK, please.
Fixed
```cpp
auto indices_view = indices.view(size);
return grad.contiguous().view(size).gather(-1, indices_view).view(indices.sizes());
const auto memory_format = indices.suggest_memory_format();
return grad.contiguous(memory_format).view(size).gather(-1, indices_view).view(indices.sizes());
```
We need to add a test for double backward.
It will be checked through this magic parameter: check_with_channels_last=True.
VitalyFedyunin
left a comment
Feel free to land on Monday as soon as all tests are green.
…e improvement] (#24872) Summary: max_pool2d_with_indices_cuda and max_pool2d_with_indices_backward_cuda should have channel last optimized kernels(pytorch/pytorch#23815) Pull Request resolved: pytorch/pytorch#24872 Differential Revision: D16964577 Pulled By: ifedan fbshipit-source-id: 296dfef8e511a7ae2ed423e34e902d5401b3becb




max_pool2d_with_indices_cuda and max_pool2d_with_indices_backward_cuda should have channel last optimized kernels (#23815)