[inductor][cpp] Add BMM kernel template for autotuning #129772
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/129772
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (4 Unrelated Failures) As of commit 0f9e245 with merge base 16ea0dd:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk.
UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and have been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
jgong5 left a comment:
Can you also share the perf results?
Thanks for the PR. Does the UT still fail in ...?
leslie-fang-intel left a comment:
Thanks for the PR. Added some comments; feel free to discuss.
const {{micro_gemm.get_common_options()['input_t']}}* W,
{{micro_gemm.get_common_options()['input_t']}}* Y
{%- if is_dynamic_M %},
const int64_t {{kernel.size(GemmOut, -2, unwrapped=True)}}
May relate to https://github.com/pytorch/pytorch/pull/129772/files#r1659488160: this is just a formal parameter here, right? Why do we need to get the actual symbolic size?
We just need the name of the symbol. The dynamic-shape variables ks0, ks1, etc. are declared in kernel.def_kernel(), so we don't know the names of these symbols beforehand. So I use the unwrapped call to get just the name of the symbol.
Since it's a formal parameter of the function threaded_mm, the name doesn't have to be ks0 or ks1, right? We only need to pass in the actual parameter ks0 when invoking this function, don't we? Then I guess this unwrapped parameter is not needed.
The problem is that I'm trying to make CppBmmTemplate work with minimal changes to CppPackedGemmTemplate. In the GEMM template we have const int64_t M = {{kernel.size(GemmOut, 0)}}. This evaluates to const int64_t M = ks0 if only M is dynamic, and to const int64_t M = ks1 if M and B are both dynamic.
So the issue is not that we can't pass this as a parameter; we could. I'm trying not to change the existing code in cpp_gemm_template.py, so having a version of the size() function that doesn't wrap everything in a static_cast C expression is helpful.
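As a toy illustration of the wrapped vs. unwrapped forms (a sketch only, not the actual kernel.size() implementation):

```python
import sympy

# Toy illustration (not the real kernel.size()): with unwrapped=True we emit only
# the symbol's name, so it can double as the formal parameter of threaded_mm,
# while the default form stays a C expression as used by the GEMM template.
def size_expr(sym: sympy.Symbol, unwrapped: bool = False) -> str:
    return sym.name if unwrapped else f"static_cast<int64_t>({sym.name})"

ks0 = sympy.Symbol("ks0")
print(size_expr(ks0))                  # static_cast<int64_t>(ks0)
print(size_expr(ks0, unwrapped=True))  # ks0
```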
I'm a bit confused. The sizes should be automatically generated via kernel.def_kernel along with the buffers needed by epilogue fusions. I don't think you have to manually add them.
The issue was that def_kernel was not being used to create the calls to the sub-routines (single_thread_mm and threaded_mm), so I was constructing those manually. I added a function def_kernel_with_name to CppTemplateKernel that lets you create a function definition for any function, not just for the main kernel.
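A toy sketch of the placeholder idea behind this (names and strings here are illustrative, not the exact CppTemplateKernel API): the body is rendered with a placeholder where the signature goes, and the placeholder is replaced once the final argument list, including any dynamic-shape symbols, is known.

```python
# Illustrative only: render first, substitute the finished signature later.
PLACEHOLDER = "<DEF_THREADED_MM>"

rendered_body = PLACEHOLDER + "\n{\n    // ... rendered from the shared GEMM template ...\n}\n"

args = ["const bfloat16* X", "const bfloat16* W", "bfloat16* Y", "const int64_t ks_b_index"]
signature = f"void threaded_mm({', '.join(args)})"
print(rendered_body.replace(PLACEHOLDER, signature))
```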
epilogue_creator=epilogue_creator,
name=name,
)
self.should_pack_weights = False
Does this mean pre-packing the weight? I guess for better performance we may still need to pack the weight (VNNI layout) of bmm inside the gemm template, @jgong5?
Packing weights into VNNI layout is also needed for correctness, but packing weights into a blocked layout is not a must for correctness. I'm not sure whether packing weights into a blocked layout at runtime would be good for performance. If there is not much data reuse, packing weights into a blocked layout is perhaps not optimal. I guess we need to explore and get more perf numbers.
I've done a big refactor so that the weight packing/blocking is now divided into several functions within CppGemmTemplate (renamed from CppPackedGemmTemplate), and some of those are overridden in CppBmmTemplate. We still need to explore, with perf numbers, exactly when we should be doing VNNI packing for non-constant weights. For now it's just done outside of the gemm loops.
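A minimal sketch of the hook structure this refactor aims at (method names follow the discussion and commit messages, e.g. prep_weight, and may not match the final code exactly):

```python
# Hedged sketch only: CppGemmTemplate owns the packing pipeline, and CppBmmTemplate
# overrides the pieces that should behave differently for non-constant BMM weights.
class CppGemmTemplate:
    def prep_weight(self, w: str) -> str:
        w = self.block_weight(w)  # blocked layout: a performance choice
        w = self.pack_vnni(w)     # VNNI interleaving: required for AMX/BF16 correctness
        return w

    def block_weight(self, w: str) -> str:
        return f"blocked({w})"

    def pack_vnni(self, w: str) -> str:
        return f"vnni({w})"


class CppBmmTemplate(CppGemmTemplate):
    def block_weight(self, w: str) -> str:
        # BMM weights are runtime tensors, not frozen constants; whether to block
        # them at runtime is still an open performance question, so skip for now.
        return w


print(CppGemmTemplate().prep_weight("W"))  # vnni(blocked(W))
print(CppBmmTemplate().prep_weight("W"))   # vnni(W)
```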
def stride(self, node: ir.Buffer, dim: int) -> str:
    return cexpr_index(self.rename_indexing(node.get_stride()[dim]))
    stride = node.get_stride() if hasattr(node, 'get_stride') else node.layout.stride
May I know in which case the node doesn't have the get_stride method? get_stride is a method of ir.Buffer.
The case is when an input is a ConcatKernel buffer containing multiple buffers. So the input to stride() is actually a SliceView, instead of an ir.Buffer. This occurs in hf_Reformer.
One problem: node inputs are not ir.Buffer objects? They are ir.ReinterpretView or ir.SliceView. ReinterpretView does have get_stride implemented, but not because it's an ir.Buffer. We may want a complete overhaul of the type hints in this class, since most things being passed are not actually ir.Buffer.
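A small sketch of the direction this suggests (the union members are an assumption based on the cases seen so far, not a complete list):

```python
from typing import Union

import torch._inductor.ir as ir

# Hypothetical widening of the annotation: template inputs can be views rather
# than buffers, so a union (or a shared protocol) is more honest than ir.Buffer.
TemplateInput = Union[ir.Buffer, ir.ReinterpretView, ir.SliceView]

def stride_at(node: TemplateInput, dim: int):
    # SliceView does not implement get_stride(), hence the layout fallback from the diff.
    return node.get_stride()[dim] if hasattr(node, "get_stride") else node.layout.stride[dim]
```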
start, end = parse_expr_with_index_symbols(_range)
sliced = L.slice_(sliced, dim, start, end, clamp=False)
assert isinstance(sliced.data, ir.ReinterpretView), sliced.data
assert isinstance(sliced.data, ir.ReinterpretView) or isinstance(sliced.data, ir.SliceView), sliced.data
Feels like we need this change, because the node of matrix B is not realized... I haven't had time to confirm.
See above.
start, end = parse_expr_with_index_symbols(_range)
sliced = L.slice_(sliced, dim, start, end, clamp=False)
assert isinstance(sliced.data, ir.ReinterpretView), sliced.data
if isinstance(sliced.data, ir.SliceView):
When does it happen?
inputs[1] = padded_w

@staticmethod
def _pack_weight(inputs, layout_or_out, micro_gemm):
Seems like a lot of duplication with cpp_gemm_template. Can we refactor/share code?
I did a significant refactor, separating pack_weight and blocking for VNNI and for IR nodes.
jgong5 left a comment:
Please fix CI failures.
placeholder = "<DEF_KERNEL>"
return self.def_function_with_name(self.kernel_name, placeholder, inputs, outputs, aliases)

def get_function_call(self, function_name: str, placeholder: str, indexer_dims: List[Any]=[], nodes=None) -> str:
Could you please add a function spec, especially for indexer_dims?
I did a rewrite of this function and moved it into CppBmmTemplate, since it was only used there. The args should be clearer now.
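A toy illustration of the call-generation step (not the real get_function_call signature): the BMM driver dispatches most batch entries to the single-thread sub-kernel and the remainder to the multi-threaded one, mirroring the generated cpp_bmm shown in the PR description.

```python
# Illustrative only: render the C++ driver loop that calls the two sub-kernels.
def render_bmm_driver(num_threads: int, args: str = "X, W, Y") -> str:
    return "\n".join([
        f"int64_t B_single_thread_block = (B / {num_threads}) * {num_threads};",
        f"#pragma omp parallel for num_threads({num_threads})",
        "for (int64_t b_start = 0; b_start < B_single_thread_block; ++b_start) {",
        f"    single_thread_mm({args}, b_start);",
        "}",
        "for (int64_t b_start = B_single_thread_block; b_start < B; ++b_start) {",
        f"    threaded_mm({args}, b_start);",
        "}",
    ])

print(render_bmm_driver(48))
```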
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as Stale.
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Rebase failed due to Command ...
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 mandatory check(s) failed. The first few are: ... Dig deeper by viewing the failures on hud.
commit 19af839 (Mitchell, Frost; Thu Dec 5 12:38:02 2024 -0800): Fix quantize amx bug
commit f85cb03 (Merge: 531a801 ca9aeed; Frost Mitchell; Thu Dec 5 10:29:50 2024 -0500): Merge branch 'pytorch:main' into bmm_microkernel
commit 531a801 (Merge: dc24018 fd35be2; Frost Mitchell; Wed Dec 4 12:52:16 2024 -0500): Merge branch 'pytorch:main' into bmm_microkernel
commit dc24018 (Mitchell, Frost; Tue Dec 3 11:17:45 2024 -0800): Fix for modular gemm template
commit 97d5a4f (Merge: 866108a 9125e91; Frost Mitchell; Tue Dec 3 14:13:27 2024 -0500): Merge branch 'main' into bmm_microkernel
commit 866108a (Merge: 176814f 1af69ee; Frost Mitchell; Fri Nov 22 08:31:50 2024 -0500): Merge branch 'main' into bmm_microkernel
commit 176814f (Frost Mitchell; Fri Nov 22 08:21:41 2024 -0500): Update torch/_inductor/codegen/cpp_gemm_template.py (Co-authored-by: Jiong Gong)
commit d491208 (Mitchell, Frost; Wed Nov 20 11:58:50 2024 -0800): Streamline prep_weight and comments
commit 6f1fa2f (Mitchell, Frost; Tue Nov 19 06:06:56 2024 -0800): Lint, comments
commit 9738fd2 (Mitchell, Frost; Mon Nov 18 19:28:25 2024 -0800): Change weight prep and comments
commit ff383c8 (Mitchell, Frost; Fri Nov 15 11:12:17 2024 -0800): Fix amp, blocking with freezing
commit 2763e0d (Mitchell, Frost; Wed Nov 13 06:17:03 2024 -0800): Lint, fix BMM blocking
commit a6a3d30 (Mitchell, Frost; Tue Nov 12 18:17:04 2024 -0800): Change blocking for BMM when non-contiguous
commit 9f1873d (Mitchell, Frost; Tue Nov 12 11:29:08 2024 -0800): Comments
commit 87a5b2d (Mitchell, Frost; Wed Nov 6 11:22:38 2024 -0800): Fix AMX block bug
commit 99d63c8 (Mitchell, Frost; Wed Nov 6 07:35:28 2024 -0800): Lint
commit 4b4cf95 (Mitchell, Frost; Tue Nov 5 04:59:20 2024 -0800): Enable W transposed and fix epilogues
commit da6f0ec (Mitchell, Frost; Mon Oct 28 08:42:51 2024 -0700): Enable BMM with permuted W
commit 0f58b3c (Mitchell, Frost; Tue Oct 15 13:58:51 2024 -0700): Test for squaring matrix, single input
commit cd1793b (Mitchell, Frost; Tue Oct 15 13:07:09 2024 -0700): Simplify, support binary epilogues, disable reshaped inputs
commit 273fe3c (Mitchell, Frost; Tue Oct 8 13:34:05 2024 -0700): Fix T5 view errors
commit 9b22e4b (Mitchell, Frost; Mon Oct 7 11:35:35 2024 -0700): Add numthreads for blocking
commit fb7d5d0 (Mitchell, Frost; Fri Jul 26 13:27:13 2024 -0700): Add function caller
commit 4e05d74 (Mitchell, Frost; Fri Jul 26 12:31:32 2024 -0700): lint, fix tests
commit 14f0e0c (Mitchell, Frost; Fri Jul 26 11:50:33 2024 -0700): Refactor prep_weight() to separate VNNI packing and blocking
commit 08cd48a (Mitchell, Frost; Thu Jul 25 19:11:56 2024 -0700): BMM template using Cpp GEMM
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 3 mandatory check(s) failed. The first few are: ... Dig deeper by viewing the failures on hud.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This PR adds the Cpp template for BMM, for FP32, FP16, and BF16. See pytorch#125683 for more background.

1. Adds the `CppBmmTemplate` class, which inherits from `CppPackedGemmTemplate`. Given a number of worker threads `num_threads` and batch size `B`, execute the Gemm kernel. For the first `B - (B % num_threads)` batch inputs, run one sub-gemm problem per thread. Then for the remaining `B % num_threads` sub-gemms, execute each subproblem using the parallelized Gemm kernel. To manage this code, the `GEMM_TEMPLATE` from `CppPackedGemmTemplate` is rendered two different times, once with a single thread and once including the parallel OMP pragma.
2. Adapts `CppPackedGemmTemplate` to allow for a child class. The `GEMM_TEMPLATE` is separated into different strings to allow for rendering by the child class. Slicing/indexing are adapted to allow for 3D BMM inputs. Additional methods `get_options()` and `_get_params_for_choices()` are added to reduce code duplication.

BMM within the `dlrm` benchmark has a single input buffer which is used for both X and W inputs. This is currently not supported in this PR.

### Performance

On Granite/Sapphire Rapids, the cpp_bmm template code uses AMX, which requires an expensive transpose operation, so the BMM op is rarely selected as faster than the existing external bmm kernel. As a result, speedup on SPR is identical with and without the BMM code. The pass rate matches the rates for main exactly.

#### Test Summary on Granite Rapids

| Test Scenario | Comp Item | Date | Compiler | torchbench | huggingface | timm_models |
| -- | -- | -- | -- | -- | -- | -- |
| Single Socket Multi-Threads | Pass Rate | gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61 |
| | | bmm + gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61 |
| | Geomean Speedup | gemm autotune | inductor | 2.15x | 1.91x | 2.52x |
| | | bmm + gemm autotune | inductor | 2.15x | 1.96x | 2.53x |
| Single Core Single-Thread | Pass Rate | gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61 |
| | | bmm + gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61 |
| | Geomean Speedup | inductor_locally_benchmark_586 | inductor | 2.43x | 1.56x | 2.60x |
| | | inductor_locally_benchmark_585 | inductor | 2.45x | 1.56x | 2.63x |

This is not the case on an older Skylake Xeon machine. For the BMM ops contained in torchbench models, bmm performance improves by 1.10-2.64x.
#### BF16 28-core Skylake Xeon

| Model | Inductor | GemmAutotune | Gemm+BMM Autotune |
| -------- | -------- | -------- | -------- |
| BERT_pytorch | 1.233x | 2.597x | 2.608x |
| hf_DistilBert | 1.128x | 2.242x | 2.368x |
| hf_Reformer | 1.124x | 1.419x | 1.590x |
| hf_T5_base | 1.012x | 1.257x | 1.382x |
| hf_T5_large | 1.085x | 2.228x | 2.345x |

## Example BMM Code
```
#include <c10/util/Unroll.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>

template <bool accum>
inline void cpp_bmm_micro_gemm_amx_kernel_32_2(
    AMXState& amx_state,
    const bfloat16* __restrict__ A, const bfloat16* __restrict__ B, float* __restrict__ C,
    int64_t K, int64_t lda, int64_t ldb, int64_t ldc, uint8_t tilecfg_rows
) {
  // TODO(jgong5): add prefetch hint for A, B, C
  auto loadconfig = [](const amx_tilecfg& cfg) { _tile_loadconfig(&cfg); };
  const auto last_k_offset = K / 32 * 32;
  const auto tail_k_size = K - last_k_offset;
  if C10_LIKELY (last_k_offset > 0) {
    amx_state.configure(tilecfg_rows, 64, 32 / 16, 2, loadconfig);
  } else {
    amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 2, loadconfig);
  }
  auto load_c = [&]() {
    _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float));
    _tile_loadd(1, C + 0 * ldc + 16, ldc * sizeof(float));
    _tile_loadd(2, C + 16 * ldc + 0, ldc * sizeof(float));
    _tile_loadd(3, C + 16 * ldc + 16, ldc * sizeof(float));
  };
  auto zero_c = [&]() {
    _tile_zero(0);
    _tile_zero(1);
    _tile_zero(2);
    _tile_zero(3);
  };
  if constexpr (accum) {
    load_c();
  } else {
    zero_c();
  }
  auto compute = [&](int k) {
    _tile_stream_loadd(4, A + 0 * lda + k, lda * sizeof(bfloat16));
    _tile_loadd(6, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16));
    _tile_dpbf16ps(0, 4, 6);
    _tile_loadd(7, B + k * ldb + 32, ldb * 2 * sizeof(bfloat16));
    _tile_dpbf16ps(1, 4, 7);
    _tile_stream_loadd(5, A + 16 * lda + k, lda * sizeof(bfloat16));
    _tile_dpbf16ps(2, 5, 6);
    _tile_dpbf16ps(3, 5, 7);
  };
  #pragma GCC unroll 4
  for (int k = 0; k < last_k_offset; k += 32) {
    compute(k);
  }
  auto store_c = [&]() {
    // store to C
    _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float));
    _tile_stored(1, C + 0 * ldc + 16, ldc * sizeof(float));
    _tile_stored(2, C + 16 * ldc + 0, ldc * sizeof(float));
    _tile_stored(3, C + 16 * ldc + 16, ldc * sizeof(float));
  };
  // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
  if C10_UNLIKELY (tail_k_size > 0) {
    if C10_LIKELY (last_k_offset > 0) {
      store_c();
      amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 2, loadconfig);
      load_c();
    }
    compute(last_k_offset);
  }
  store_c();
}

template <bool accum>
inline void cpp_bmm_micro_gemm_amx_kernel_16_2(
    AMXState& amx_state,
    const bfloat16* __restrict__ A, const bfloat16* __restrict__ B, float* __restrict__ C,
    int64_t K, int64_t lda, int64_t ldb, int64_t ldc, uint8_t tilecfg_rows
) {
  // TODO(jgong5): add prefetch hint for A, B, C
  auto loadconfig = [](const amx_tilecfg& cfg) { _tile_loadconfig(&cfg); };
  const auto last_k_offset = K / 32 * 32;
  const auto tail_k_size = K - last_k_offset;
  if C10_LIKELY (last_k_offset > 0) {
    amx_state.configure(tilecfg_rows, 64, 16 / 16, 2, loadconfig);
  } else {
    amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 2, loadconfig);
  }
  auto load_c = [&]() {
    _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float));
    _tile_loadd(1, C + 0 * ldc + 16, ldc * sizeof(float));
  };
  auto zero_c = [&]() {
    _tile_zero(0);
    _tile_zero(1);
  };
  if constexpr (accum) {
    load_c();
  } else {
    zero_c();
  }
  auto compute = [&](int k) {
    _tile_stream_loadd(2, A + 0 * lda + k, lda * sizeof(bfloat16));
    _tile_loadd(3, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16));
    _tile_dpbf16ps(0, 2, 3);
    _tile_loadd(4, B + k * ldb + 32, ldb * 2 * sizeof(bfloat16));
    _tile_dpbf16ps(1, 2, 4);
  };
  #pragma GCC unroll 4
  for (int k = 0; k < last_k_offset; k += 32) {
    compute(k);
  }
  auto store_c = [&]() {
    // store to C
    _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float));
    _tile_stored(1, C + 0 * ldc + 16, ldc * sizeof(float));
  };
  // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
  if C10_UNLIKELY (tail_k_size > 0) {
    if C10_LIKELY (last_k_offset > 0) {
      store_c();
      amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 2, loadconfig);
      load_c();
    }
    compute(last_k_offset);
  }
  store_c();
}

template <bool accum>
inline void cpp_bmm_micro_gemm(
    AMXState& amx_state,
    const bfloat16* __restrict__ A, const bfloat16* __restrict__ B, float* __restrict__ C,
    int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc
) {
  AOTI_TORCH_CHECK(N % 32 == 0, "N dimension must be multiple of 32");
  AOTI_TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2");
  // TODO(jgong5): loop unroll for M and N
  for (int64_t n = 0; n < N; n += 32) {
    for (int64_t m = 0; m < M; m += 32) {
      int64_t block_m = std::min<int64_t>(M - m, 32);
      int64_t m_tail = m;
      if (block_m >= 32) {
        cpp_bmm_micro_gemm_amx_kernel_32_2<accum>(
          amx_state, A + m * lda, B + n, C + m * ldc + n, K, lda, ldb, ldc, 16
        );
        block_m -= 32;
        m_tail += 32;
      } else if (block_m >= 16) {
        cpp_bmm_micro_gemm_amx_kernel_16_2<accum>(
          amx_state, A + m * lda, B + n, C + m * ldc + n, K, lda, ldb, ldc, 16
        );
        block_m -= 16;
        m_tail += 16;
      }
      if (block_m > 0) {
        cpp_bmm_micro_gemm_amx_kernel_16_2<accum>(
          amx_state, A + m_tail * lda, B + n, C + m_tail * ldc + n, K, lda, ldb, ldc, block_m
        );
      }
    }
  }
}

void threaded_mm(const bfloat16* X, const bfloat16* W, bfloat16* Y, const int64_t ks_b_index) {
  constexpr int64_t num_threads = 48;
  constexpr int64_t N = 64;
  constexpr int64_t K = 96;
  constexpr int64_t Mr = 32;
  constexpr int64_t Nr = 32;
  constexpr int64_t Kr = 32;
  constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
  constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
  constexpr int64_t M = static_cast<int64_t>(384L);
  constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
  constexpr int64_t Mt_blocks = 1;
  constexpr int64_t Nt_blocks = 1;
  constexpr int64_t Kt_blocks = 3;
  constexpr int64_t Mc_blocks = 1;
  constexpr int64_t Nc_blocks = 1;
  constexpr int64_t Kc_blocks = 3;
  constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
  constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
  constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
  constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
  constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
  // make sure all partitions are assigned
  AOTI_TORCH_CHECK(
    Mt_blocks * Nt_blocks * Kt_blocks * 48 >= Mr_blocks * Nr_blocks * Kr_blocks,
    "Not all partitions are assigned."
  );
  #pragma omp parallel num_threads(48)
  {
    const int tid = omp_get_thread_num();
    const int64_t k_group_id = tid / num_Kt_blocks;
    const int64_t k_slice_id = tid % num_Kt_blocks;
    const int64_t n_group_id = k_group_id / num_Nt_blocks;
    const int64_t n_slice_id = k_group_id % num_Nt_blocks;
    const int64_t k_block_start = k_slice_id * Kt_blocks;
    const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks);
    const int64_t n_block_start = n_slice_id * Nt_blocks;
    const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks);
    const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks);
    const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks);
    const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks;
    AMXState amx_state;
    auto _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr));
    auto local_acc_buf = _local_acc_buf.get();
    for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
      const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
      const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
      const int64_t m_start = mc * Mr;
      const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
      const int64_t m_size = m_end - m_start;
      for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
        const int64_t n_start = nc * Nr;
        const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
        const int64_t n_size = n_end - n_start;
        // NB: assume we pad N, nc_block_end won't exceed padded N here.
        const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
        if (_local_acc_buf == nullptr) {
          _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr));
          local_acc_buf = _local_acc_buf.get();
        }
        for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
          int64_t k_start = kc * Kr;
          int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
          for (int64_t nci = nc; nci < nc_block_end; nci++) {
            if (kc == k_block_start) {
              cpp_bmm_micro_gemm<static_cast<bool>(false)>(
                amx_state,
                &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
                &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
                &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                static_cast<int64_t>(Nr),
                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                static_cast<int64_t>(96L),
                static_cast<int64_t>(32L),
                static_cast<int64_t>(Nc_blocks*Nr)
              );
            } else {
              cpp_bmm_micro_gemm<static_cast<bool>(true)>(
                amx_state,
                &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
                &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
                &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                static_cast<int64_t>(Nr),
                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                static_cast<int64_t>(96L),
                static_cast<int64_t>(32L),
                static_cast<int64_t>(Nc_blocks*Nr)
              );
            }
          }
        }
        {
          {
            #pragma GCC ivdep
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L))
            {
              for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1+=static_cast<int64_t>(16L))
              {
                auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(16));
                auto tmp1 = at::vec::convert<bfloat16>(tmp0);
                tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(16));
              }
              for(int64_t x1=static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=(static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))) == 0 ? 1 : static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))))
              {
                auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
                auto tmp1 = at::vec::convert<bfloat16>(tmp0);
                tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
              }
            }
          }
        }
      }
    }
    amx_state.release([]() { _tile_release(); });
  }
}

void single_thread_mm(const bfloat16* X, const bfloat16* W, bfloat16* Y, const int64_t ks_b_index) {
  constexpr int64_t num_threads = 1;
  constexpr int64_t N = 64;
  constexpr int64_t K = 96;
  constexpr int64_t Mr = 32;
  constexpr int64_t Nr = 32;
  constexpr int64_t Kr = 32;
  constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
  constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
  constexpr int64_t M = static_cast<int64_t>(384L);
  constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
  constexpr int64_t Mt_blocks = 12;
  constexpr int64_t Nt_blocks = 2;
  constexpr int64_t Kt_blocks = 3;
  constexpr int64_t Mc_blocks = 12;
  constexpr int64_t Nc_blocks = 1;
  constexpr int64_t Kc_blocks = 3;
  constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
  constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
  constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
  constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
  constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
  // make sure all partitions are assigned
  AOTI_TORCH_CHECK(
    Mt_blocks * Nt_blocks * Kt_blocks * 1 >= Mr_blocks * Nr_blocks * Kr_blocks,
    "Not all partitions are assigned."
  );
  {
    constexpr int tid = 0;
    constexpr int64_t k_group_id = 0;
    constexpr int64_t k_slice_id = 0;
    constexpr int64_t n_group_id = 0;
    constexpr int64_t n_slice_id = 0;
    constexpr int64_t m_block_start = 0;
    constexpr int64_t n_block_start = 0;
    constexpr int64_t n_block_end = Nr_blocks;
    constexpr int64_t k_block_start = 0;
    constexpr int64_t k_block_end = Kr_blocks;
    constexpr int64_t num_Mc_blocks_per_thread = num_Mc_blocks;
    constexpr int64_t m_block_end = Mr_blocks;
    AMXState amx_state;
    auto _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr));
    auto local_acc_buf = _local_acc_buf.get();
    for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
      const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
      const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
      const int64_t m_start = mc * Mr;
      const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
      const int64_t m_size = m_end - m_start;
      for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
        const int64_t n_start = nc * Nr;
        const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
        const int64_t n_size = n_end - n_start;
        // NB: assume we pad N, nc_block_end won't exceed padded N here.
        const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
        if (_local_acc_buf == nullptr) {
          _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr));
          local_acc_buf = _local_acc_buf.get();
        }
        for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
          int64_t k_start = kc * Kr;
          int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
          for (int64_t nci = nc; nci < nc_block_end; nci++) {
            if (kc == k_block_start) {
              cpp_bmm_micro_gemm<static_cast<bool>(false)>(
                amx_state,
                &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
                &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
                &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                static_cast<int64_t>(Nr),
                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                static_cast<int64_t>(96L),
                static_cast<int64_t>(32L),
                static_cast<int64_t>(Nc_blocks*Nr)
              );
            } else {
              cpp_bmm_micro_gemm<static_cast<bool>(true)>(
                amx_state,
                &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
                &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
                &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                static_cast<int64_t>(Nr),
                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                static_cast<int64_t>(96L),
                static_cast<int64_t>(32L),
                static_cast<int64_t>(Nc_blocks*Nr)
              );
            }
          }
        }
        {
          {
            #pragma GCC ivdep
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L))
            {
              for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1+=static_cast<int64_t>(16L))
              {
                auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(16));
                auto tmp1 = at::vec::convert<bfloat16>(tmp0);
                tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(16));
              }
              for(int64_t x1=static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=(static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))) == 0 ? 1 : static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))))
              {
                auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
                auto tmp1 = at::vec::convert<bfloat16>(tmp0);
                tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
              }
            }
          }
        }
      }
    }
    amx_state.release([]() { _tile_release(); });
  }
}

extern "C"
void cpp_bmm(const bfloat16* X, const bfloat16* W, bfloat16* Y) {
  const int64_t B = static_cast<int64_t>(5L);
  constexpr int64_t num_threads = 48;
  int64_t B_single_thread_block = (B / num_threads) * num_threads;
  #pragma omp parallel for num_threads(48)
  for (int64_t b_start = 0; b_start < B_single_thread_block; ++b_start) {
    single_thread_mm(X, W, Y, b_start);
  }
  for (int64_t b_start = B_single_thread_block; b_start < B; ++b_start) {
    threaded_mm(X, W, Y, b_start);
  }
}
```
Pull Request resolved: pytorch#129772
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
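A hedged usage sketch for exercising the new template (the config knob values are assumptions about Inductor settings; adjust to your build):

```python
import torch
import torch._inductor.config as inductor_config

# Assumed settings: enable max-autotune and include the C++ GEMM/BMM backend.
inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "CPP,ATEN"

@torch.compile(mode="max-autotune")
def bmm(x, w):
    return torch.bmm(x, w)

# Shapes taken from the generated example above: B=5, M=384, K=96, N=64.
x = torch.randn(5, 384, 96, dtype=torch.bfloat16)
w = torch.randn(5, 96, 64, dtype=torch.bfloat16)
y = bmm(x, w)
```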
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @gujinghui @PenghuiCheng @jianyuh @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen @snadampal @mcarilli @ptrblck @leslie-fang-intel @EikanWang @voznesenskym @penguinwu @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @LucasLLC @MeetVadakkanchery @mhorowitz @pradeepfn @XilunWu @rec