Fix the issue when NHWC Tensor has height or width larger than max CUDA grid #28931
Conversation
    maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z));
const dim3 block(block_x, block_y, block_z);
int grid_x = nbatch;
int grid_y = cuda::ATenCeilDiv(safe_downcast<int, int64_t>(outputWidth), block_y*BLOCK_STRIDE);
We shouldn't need multiple kernel launches, since we are already striding along width/height inside the kernel.
I would discard all these changes and simply change this line (and grid_z as well):
int grid_y = std::min<int>(
at::cuda::getCurrentDeviceProperties()->maxGridSize[1],
cuda::ATenCeilDiv(safe_downcast<int, int64_t>(outputWidth), block_y*BLOCK_STRIDE));
Please do NOT do the same thing for grid_x, as the kernel does not loop over the batch dimension. The fix you have here could be used for that.
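For reference, the reason clamping the grid is sufficient: the kernel already walks width/height with a grid-stride loop, so a grid smaller than the problem just means more iterations per block. A minimal sketch of that pattern, with illustrative names rather than the actual pooling kernel:

__global__ void stride_over_plane(const float* in, float* out, int height, int width) {
  // Each block starts at its grid coordinate and advances by the full grid
  // extent, so any height/width is covered even when grid_y/grid_z were
  // clamped to maxGridSize[1]/maxGridSize[2].
  for (int h = blockIdx.z * blockDim.z + threadIdx.z; h < height; h += gridDim.z * blockDim.z) {
    for (int w = blockIdx.y * blockDim.y + threadIdx.y; w < width; w += gridDim.y * blockDim.y) {
      out[h * width + w] = in[h * width + w];  // placeholder body
    }
  }
}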
I updated the grid and removed the slicing, but I still see a performance degradation in the backward function.
Hmmm, a perf regression doesn't really make sense here. It's basically the same kernel.
Not sure if it's related to that typo I pointed out in the other comment.
Changing the grid size would affect the occupancy and maybe the cache hit rate? I need to think about it further before I can make any decisive call. But let's fix the typo and hope the problem goes away.
    maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z));
const dim3 block(block_x, block_y, block_z);
int grid_x = nbatch;
int grid_y = cuda::ATenCeilDiv(safe_downcast<int, int64_t>(inputWidth), block_y*BLOCK_STRIDE);
Same thing as I commented on the forward pass.
Updated
    cuda::ATenCeilDiv(safe_downcast<int, int64_t>(outputWidth), block_y*BLOCK_STRIDE));
int grid_z = std::min<int>(
    at::cuda::getCurrentDeviceProperties()->maxGridSize[2],
    cuda::ATenCeilDiv(safe_downcast<int, int64_t>(outputHeight), block_y*BLOCK_STRIDE));
Typo here! block_y should be block_z
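With the typo fixed, the line would read (only the block dimension changes from the quoted code):

int grid_z = std::min<int>(
    at::cuda::getCurrentDeviceProperties()->maxGridSize[2],
    cuda::ATenCeilDiv(safe_downcast<int, int64_t>(outputHeight), block_z*BLOCK_STRIDE));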
changed
    cuda::ATenCeilDiv(safe_downcast<int, int64_t>(inputWidth), block_y*BLOCK_STRIDE));
int grid_z = std::min<int>(
    at::cuda::getCurrentDeviceProperties()->maxGridSize[2],
    cuda::ATenCeilDiv(safe_downcast<int, int64_t>(inputHeight), block_y*BLOCK_STRIDE));
block_y -> block_z
changed
LGTM. Did that few-line change actually take a hit on performance, or is there high variance between runs?
facebook-github-bot
left a comment
@ifedan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
…da grid (#28931)
Summary: When NHWC Tensor has height or width larger than max CUDA grid size, max_pool fails with error code 0. The example is: pytorch/pytorch#28714. This change should limit grid size to the CUDA max possible size and chunk the input to be able to process it.
Pull Request resolved: pytorch/pytorch#28931
Differential Revision: D18358892
Pulled By: ifedan
fbshipit-source-id: 2fd65448bd644f1588a0e208edaaea5bcb6a7d52


When an NHWC Tensor has a height or width larger than the max CUDA grid size, max_pool fails with error code 0.
The example is: #28714
This change limits the grid size to the CUDA maximum and chunks the input so it can still be processed.
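Concretely, the launch-side pattern the review converged on clamps the width/height grid dimensions to the device limit and leaves the batch dimension alone; a sketch assembled from the lines quoted in the review (a fragment for illustration, not the exact final diff):

const dim3 block(block_x, block_y, block_z);
int grid_x = nbatch;  // the kernel does not loop over batch, so grid_x is not clamped
int grid_y = std::min<int>(
    at::cuda::getCurrentDeviceProperties()->maxGridSize[1],
    cuda::ATenCeilDiv(safe_downcast<int, int64_t>(outputWidth), block_y*BLOCK_STRIDE));
int grid_z = std::min<int>(
    at::cuda::getCurrentDeviceProperties()->maxGridSize[2],
    cuda::ATenCeilDiv(safe_downcast<int, int64_t>(outputHeight), block_z*BLOCK_STRIDE));
const dim3 grid(grid_x, grid_y, grid_z);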