Conversation

@ifedan (Contributor) commented Aug 19, 2019

max_pool2d_with_indices_cuda and max_pool2d_with_indices_backward_cuda should have channels-last optimized kernels (#23815)

@pytorchbot added the module: cuda, module: internals, module: nn, and module: operators labels Aug 19, 2019
@ifedan (Contributor Author) commented Aug 20, 2019

Forward time: [benchmark chart]

@ifedan (Contributor Author) commented Aug 20, 2019

Backward time: [benchmark chart]

@jjsjann123 (Collaborator)

Two quick things on the benchmarking:

  1. Backwards looks really bad; are we measuring only the pooling backward kernel or the total backward time?
    We could try to hack the code path to remove the TensorIterator call that does the transpose/grad-accumulate, so we get a fair comparison.
  2. NIT: It's hard to see from the graph what's going on. Maybe we could list the relative speedup instead of the absolute time.

I'm quite occupied until tomorrow afternoon. Let me take a closer look then.

@ifedan (Contributor Author) commented Aug 20, 2019

> 1. are we measuring only the pooling backward kernel or the total backward time?

[benchmark charts]

@jjsjann123 (Collaborator)

Oops, my bad, I messed up the legend on backwards earlier :/ Inferring from the perf, we must have already gotten rid of the permutation?

Looks like we are doing well except for certain problem sizes where we are getting destroyed by the stock kernel. Let me take another pass tomorrow afternoon.

@ifedan (Contributor Author) commented Aug 20, 2019

> Oops, my bad, I messed up the legend on backwards earlier :/ Inferring from the perf, we must have already gotten rid of the permutation?

> Looks like we are doing well except for certain problem sizes where we are getting destroyed by the stock kernel. Let me take another pass tomorrow afternoon.

Yes, I think we can improve even more here, but the idea for now is to have a stable version that is not worse than NCHW.

static const int BACKWARD_THREADS = 256;
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS)
__global__ void MaxPoolForwardNHWC(const int nthreads, const scalar_t* bottom_data,
Collaborator

nthreads & num are not used any more. Same with MaxPoolBackwardNHWC.
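
For illustration, a minimal sketch (not the PR's final code; the parameter list is abbreviated and the output parameter names are assumptions) of the signature once the unused count argument is dropped:

// Hypothetical cleaned-up signature: with no grid-stride count, the kernel
// derives its work assignment from blockIdx/threadIdx alone.
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS)
__global__ void MaxPoolForwardNHWC(const scalar_t* bottom_data,
                                   /* shape/stride arguments elided */
                                   scalar_t* top_data, int64_t* top_mask);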

Contributor Author

Fixed

for(int c = threadIdx.x; c < channels; c+= blockDim.x) {
scalar_t val = ptr_input[c];
scalar_t maxval = out_cached[2 * c];
if ((ScalarConvert<scalar_t, accscalar_t>::to(val) > maxval) || THCNumerics<scalar_t>::isnan(val)) {
Collaborator

// DEPRECATED: use static_cast in kernels instead of scalar_cast
template <typename T, typename U>
__host__ __device__ T scalar_cast(U u) {
return ScalarConvert<U, T>::to(u);
}

We might just use static_cast.
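
A minimal sketch of the suggested substitution (the helper name is hypothetical): a plain static_cast performs the same widening conversion as the deprecated ScalarConvert/scalar_cast indirection.

// Before: ScalarConvert<scalar_t, accscalar_t>::to(val) > maxval
// After:  the same comparison with a plain static_cast.
template <typename scalar_t, typename accscalar_t>
__device__ __forceinline__ bool updates_max(scalar_t val, accscalar_t maxval) {
  return static_cast<accscalar_t>(val) > maxval;
}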

Contributor Author

Fixed

const scalar_t *ptr_input = bottom_data + ih * in_stride_h + iw * in_stride_w;
for(int c = threadIdx.x; c < channels; c+= blockDim.x) {
scalar_t val = ptr_input[c];
scalar_t maxval = out_cached[2 * c];
Collaborator

NIT: maxval doesn't seem to be necessary here.

Contributor Author

Fixed

scalar_t maxval = out_cached[2 * c];
if ((ScalarConvert<scalar_t, accscalar_t>::to(val) > maxval) || THCNumerics<scalar_t>::isnan(val)) {
out_cached[2 * c] = ScalarConvert<scalar_t, accscalar_t>::to(val);
out_cached[2 * c + 1] = ih * width + iw;
Collaborator

This is dangerous: we should not convert the index to scalar_t.
For the fp16 kernel, the mantissa is only 10 bits, so we have a range of 1023 (to be fair, on both sides, but the index is only going to be positive). Any index beyond that will give us errors here.

We need to change the line earlier to

extern __shared__ int smem[];
int *out_mask_cached = smem;
scalar_t *out_cached = reinterpret_cast<scalar_t*>(&out_mask_cached[channels*height*width]);

Collaborator

This will also require updating the allocation of dynamic shared memory in the launch config
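
To make the two comments above concrete, a minimal sketch (the kernel name and the cached_elems parameter are hypothetical; the real tile size comes from the block configuration) of splitting the dynamic shared memory into an exact int index buffer plus a scalar value buffer, and sizing the launch accordingly:

template <typename scalar_t>
__global__ void max_pool_nhwc_sketch(int cached_elems /*, ... */) {
  extern __shared__ int smem[];
  // Indices stay exact in their own int buffer; storing them in fp16 would
  // round flat indices (ih * width + iw) once they exceed the ~10-bit mantissa.
  int* out_mask_cached = smem;
  scalar_t* out_cached =
      reinterpret_cast<scalar_t*>(&out_mask_cached[cached_elems]);
  // ... pooling body elided: maxima go to out_cached, positions to out_mask_cached ...
}

// The launch config must now allocate dynamic shared memory for both buffers:
//   size_t smem_bytes = cached_elems * (sizeof(int) + sizeof(scalar_t));
//   max_pool_nhwc_sketch<scalar_t><<<grid, block, smem_bytes>>>(cached_elems, ...);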

Contributor Author

Fixed

for (int ih = istartH; ih < iendH; ih+=blockDim.z) {
for (int iw = istartW; iw < iendW; iw+=blockDim.y) {
int phstart, phend, pwstart, pwend;
if (stride_h == 1) {
Collaborator

NIT: Looks like we are using the same logic as the NCHW kernel here. Maybe we should combine the code and avoid keeping two copies of it, so that later cleanup is easier.
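
As a sketch of the kind of sharing the comment suggests (the helper name is hypothetical, and this is the general-stride case without the stride_h == 1 shortcut above): the range of pooled outputs whose window covers a given input coordinate is the same computation in the NCHW and NHWC backward kernels, so it can live in one __device__ helper used by both.

// For input coordinate i, compute the half-open range [pstart, pend) of output
// positions whose pooling window contains i.
__device__ __forceinline__ void pooling_window_range(
    int i, int pad, int kernel, int stride, int pooled_size,
    int& pstart, int& pend) {
  pstart = (i + pad < kernel) ? 0 : (i + pad - kernel) / stride + 1;
  pend = min((i + pad) / stride + 1, pooled_size);
}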

Contributor Author

Fixed

@facebook-github-bot (Contributor) left a comment

@ifedan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ifedan requested a review from VitalyFedyunin August 22, 2019 21:28
output_data, indices_data);
break;
}
default: AT_ERROR("Unsupported memory format. Supports only ChannelsLast, Contiguous");
Contributor

// Deprecated alias; this alias was deprecated because it represents extra API
// surface that makes it hard for people to understand what macro to use.
// Use TORCH_CHECK(false, ...) or TORCH_INTERNAL_ASSERT(false, ...) to
// unconditionally fail at a line of code.
#define AT_ERROR(...)                                                         \
  do {                                                                        \
    ::c10::detail::deprecated_AT_ERROR();                                     \
    C10_EXPAND_MSVC_WORKAROUND(TORCH_CHECK(false, ::c10::str(__VA_ARGS__)));  \
  } while (false)
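
Following the macro's own deprecation note, the default branch above would become something like this (a sketch of the suggested change, with the other cases elided):

switch (memory_format) {
  // ... ChannelsLast and Contiguous cases elided ...
  default:
    TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
}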

Contributor Author

Fixed

@VitalyFedyunin (Contributor) left a comment

LGTM! Not approving to avoid accidental merge (we want to benchmark everything first in the separate branch).

@facebook-github-bot (Contributor) left a comment

@ifedan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@VitalyFedyunin (Contributor)

Please rebase; we are getting ready to land it.

@VitalyFedyunin (Contributor) left a comment

Please add a double backward test.

gradInput_data);
break;
}
default: AT_ERROR("Unsupported memory format. Supports only ChannelsLast, Contiguous");
Contributor

TORCH_CHECK please

Contributor Author

Fixed

auto indices_view = indices.view(size);
return grad.contiguous().view(size).gather(-1, indices_view).view(indices.sizes());
const auto memory_format = indices.suggest_memory_format();
return grad.contiguous(memory_format).view(size).gather(-1, indices_view).view(indices.sizes());
Contributor

We need to add a test for double backward.

@ifedan (Contributor Author) commented Oct 18, 2019

It will be checked through this magic param: check_with_channels_last=True

@facebook-github-bot (Contributor) left a comment

@ifedan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@VitalyFedyunin (Contributor) left a comment

Feel free to land on Monday as soon as all tests are green.

@facebook-github-bot (Contributor) left a comment

@ifedan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Oct 21, 2019
…e improvement] (#24872)

Summary:
max_pool2d_with_indices_cuda and max_pool2d_with_indices_backward_cuda should have channels-last optimized kernels (pytorch/pytorch#23815)
Pull Request resolved: pytorch/pytorch#24872

Differential Revision: D16964577

Pulled By: ifedan

fbshipit-source-id: 296dfef8e511a7ae2ed423e34e902d5401b3becb
@facebook-github-bot (Contributor)

@ifedan merged this pull request in bc57967.

thiagocrepaldi pushed a commit to thiagocrepaldi/pytorch that referenced this pull request Feb 4, 2020
…e improvement] (pytorch#24872)

Summary:
max_pool2d_with_indices_cuda and max_pool2d_with_indices_backward_cuda should have channels-last optimized kernels (pytorch#23815)
Pull Request resolved: pytorch#24872

Differential Revision: D16964577

Pulled By: ifedan

fbshipit-source-id: 296dfef8e511a7ae2ed423e34e902d5401b3becb

Labels

Merged · module: cuda · module: internals · module: nn

6 participants