[DO NOT REVIEW] Channels last perf test #25102
Conversation
Added cudnn nhwc support for: 1. batch norm 2. convolution 3. convolution_transpose
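For context, a minimal sketch (shapes and dtypes are arbitrary and this is not the PR's benchmark script; it assumes a CUDA build with cuDNN) of the kind of call the NHWC convolution path is meant to serve:

```python
import torch
import torch.nn as nn

# Illustrative only: run a convolution with channels_last (NHWC-strided)
# input and weights so the cuDNN NHWC path can be exercised.
conv = nn.Conv2d(64, 128, kernel_size=3, padding=1).cuda().half()
conv = conv.to(memory_format=torch.channels_last)

x = torch.randn(32, 64, 56, 56, device="cuda", dtype=torch.half)
x = x.to(memory_format=torch.channels_last)

out = conv(x)
# With a channels_last input, the output typically keeps NHWC strides.
print(out.is_contiguous(memory_format=torch.channels_last))
```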
`suggest_memory_format` has an ambiguous meaning in two cases: 1. a tensor in NCHW where C == 1, where we could use the stride of C as a hint to the intended memory format; 2. a tensor in NCHW where H == W == 1, where there is no way to identify the intended memory format from the strides. Currently we fall back to NCHW whenever we see a contiguous tensor, which avoids the ambiguity in some of these special cases.
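A small sketch of the two ambiguous shapes described above (sizes are illustrative; the strides in the comments are what stock PyTorch produces for them):

```python
import torch

# Case 1: C == 1. Both layouts describe the same dense memory, but the
# stride of the channel dim differs and can serve as a hint to the intent.
a = torch.empty(2, 1, 4, 4)                                      # NCHW strides
b = torch.empty(2, 1, 4, 4, memory_format=torch.channels_last)   # NHWC strides
print(a.stride(), b.stride())   # (16, 16, 4, 1) vs (16, 1, 4, 1)

# Case 2: H == W == 1. The channel stride is 1 in both layouts, so the
# hint from case 1 no longer distinguishes them.
c = torch.empty(2, 8, 1, 1)
d = torch.empty(2, 8, 1, 1, memory_format=torch.channels_last)
print(c.stride(), d.stride())   # (8, 1, 1, 1) vs (8, 1, 8, 8)
```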
The old implementation assumed `is_channels_last_contiguous_` to be mutually exclusive with `is_contiguous_`, which is not true. Properly set the flag by checking strides.
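For illustration, a tensor that should satisfy both contiguity checks at once (arbitrary shape; any 4-D tensor whose spatial dims have size 1 behaves this way):

```python
import torch

# A dense tensor can be contiguous in both the NCHW and the channels_last
# sense at the same time, e.g. when H == W == 1.
x = torch.empty(2, 8, 1, 1)
print(x.is_contiguous())                                    # expected: True
print(x.is_contiguous(memory_format=torch.channels_last))   # expected: True
```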
…e_empty_supports_memory_format
Initial kernel support added for the optimized NHWC tensor. TODO: the backwards kernel currently spits out a tensor with NHWC strides. Unfortunately autograd restores the grad to contiguous (in either the copy or the add), which makes real perf tuning annoying to do, since I cannot easily measure end-to-end time in my Python script. My current kernel is blazing fast compared to the original NCHW kernel in fp16, since I avoided atomicAdd. I'll finish perf tuning after we merge future PRs expanding NHWC support in the core.
This reverts commit c7ece81.
@jjsjann123 @ifedan FYI
BTW, see pytorch/torch/csrc/autograd/functions/accumulate_grad.cpp, lines 42 to 59 (at fc7f4e4):
This clones the grad if it's not contiguous. In the benchmarking for my PR, that permutation kills the perf gain from NHWC :/ Not sure how it works out now as our
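To make the concern concrete, here is a rough, illustrative timing of just the clone-to-NCHW-contiguous step on a channels_last gradient (arbitrary shape, not the PR's benchmark; assumes a CUDA device):

```python
import torch

# Illustrative: the extra permutation copy that forcing an NHWC gradient
# back to NCHW-contiguous adds on top of the NHWC backward kernel itself.
grad = torch.randn(64, 128, 56, 56, device="cuda", dtype=torch.half)
grad = grad.to(memory_format=torch.channels_last)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start.record()
nchw_grad = grad.contiguous()        # forces a full NCHW copy
end.record()
torch.cuda.synchronize()
print(f"permutation copy: {start.elapsed_time(end):.3f} ms")
```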
The previous kernel did not stride over the channel dimension, and it uses shared memory to store temporary results (to break a data dependency and expose more parallelism). That resulted in requesting more resources than are available. Fix: added striding over C to reduce shared-memory usage per CTA.
Updated the cuDNN API for batch norm, enabling the extended API which provides a semi-persistent batch-norm kernel with better performance on the NHWC layout. TODO: I made adjustments to the API as well as to BN in the JIT IR, but I haven't fully tested the JIT part yet. I should verify that in the final PR.
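As a hedged usage sketch (not the PR's test; shapes are arbitrary and a CUDA build with cuDNN is assumed), this is the kind of workload the NHWC batch-norm path is aimed at: BatchNorm2d over a channels_last fp16 input routed through cuDNN:

```python
import torch
import torch.nn as nn

# Illustrative: BatchNorm2d on a channels_last fp16 input so the cuDNN
# NHWC batch-norm path can be exercised.
bn = nn.BatchNorm2d(128).cuda()
x = torch.randn(64, 128, 56, 56, device="cuda", dtype=torch.half)
x = x.to(memory_format=torch.channels_last)

with torch.backends.cudnn.flags(enabled=True, benchmark=True):
    out = bn(x)

print(out.dtype, out.is_contiguous(memory_format=torch.channels_last))
```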
This PR combines the changes of the #23403 subtasks.