-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[inductor][cpp][gemm] support k slicing for static shapes #130821
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130821
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 7d7befe with merge base 1614891 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
| return False | ||
| if self.is_dynamic_M: | ||
| # TODO(jgong5): perhaps use size hint to decide? | ||
| return True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
since num_k_slices is 1 for dynamic M, so, anyway k-slicing will not work for this case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But I don't want to generate k-slicing related code when it is dynamic M for now.
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
|
The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
|
The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Stack from ghstack (oldest at bottom):
This PR provides the initial support for k-slicing (i.e. parallel reduction along k-dim) of CPP GEMM template. Only static shapes are supported now. When k-slicing is enabled, there would be extra temporary buffers allocated to hold the intermediate results and an extra barrier after initial GEMM compute by each thread, i.e. each thread first stores the GEMM result to temporary accumulation buffers (pointed by
local_buf_ptrswhich is an array of pointers pointing to accumulation buffers), followed by a reduction along k-slices, epilogue computes and store to the final outputY. In each k-slicing thread group, the reduction along k-slices and epilogue computes are conducted in parallel along M-dim. The algorithm is designed to reduce the synchronization overhead as much as possible.The k-slicing is enabled when blocking on M and N is unable to occupy all threads. Since k-slicing doesn't always bring benefit, an extra configuration is added to enable it (disable by default). We need to identify a good heuristics in the future to enable k-slicing by default.
Performance numbers with 64x4096x64, 64x10000x64, 64x20000x64 as examples on 60-core SPR as examples. As you can see, the perf of k-slicing is only better than non-k-slicing when K is large enough.
Without k-slicing
AUTOTUNE linear_unary(64x4096, 64x4096, 64)
cpp_packed_gemm_0 0.0108 ms 100.0%
_linear_pointwise 0.0431 ms 25.1%
AUTOTUNE linear_unary(64x10000, 64x10000, 64)
cpp_packed_gemm_0 0.0272 ms 100.0%
_linear_pointwise 0.0892 ms 30.5%
AUTOTUNE linear_unary(64x20000, 64x20000, 64)
cpp_packed_gemm_0 0.0781 ms 100.0%
_linear_pointwise 0.1693 ms 46.1%
With k-slicing:
AUTOTUNE linear_unary(64x4096, 64x4096, 64)
cpp_packed_gemm_0 0.0260 ms 100.0%
_linear_pointwise 0.0444 ms 58.5%
AUTOTUNE linear_unary(64x10000, 64x10000, 64)
cpp_packed_gemm_0 0.0275 ms 100.0%
_linear_pointwise 0.0893 ms 30.8%
AUTOTUNE linear_unary(64x20000, 64x20000, 64)
cpp_packed_gemm_0 0.0284 ms 100.0%
_linear_pointwise 0.1686 ms 16.8%
cc @voznesenskym @penguinwu @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang