Conversation
Signed-off-by: Liqun Fu <[email protected]>
Signed-off-by: liqunfu <[email protected]>
What is the unit of Mean Duration? Is it in milliseconds (ms)?
…rios, clear interleaved attribute Signed-off-by: liqunfu <[email protected]>
In microseconds. The results were from another PR that I am working on, but it cannot be mixed with this one due to the timing of the release.
Added an mlas test for RoPE. It only runs in the x64 build because that is the only machine I have. Tests for other scenarios can be enabled later.
Signed-off-by: liqunfu <[email protected]>
…com/microsoft/onnxruntime into liqun/Intel-ROPE-kernel-to-use-AVX2
Signed-off-by: liqunfu <[email protected]>
yihonglyu
left a comment
Could you add a microbenchmark under onnxruntime\test\mlas\bench or another suitable location and collect the performance metrics with and without the patch?
Signed-off-by: liqunfu <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
added mlas_rope benchmark
Signed-off-by: liqunfu <[email protected]>
onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2_fp32.cpp
Signed-off-by: liqunfu <[email protected]>
### Description

Credit to [chethanpk](https://github.com/chethanpk), who provided the RoPE embedding implementation in a patch. The patch is in the first commit of this PR.

I have been confirming the performance improvement from this code change. My analysis is based on phi-3-mini-4k-instruct-int4-int8-blklen32.

The benchmark from onnxruntime-genai does not show a clear improvement. This is because GQA takes only a small portion of the whole model (<10%), and RoPE within GQA takes only a small portion of GQA (12%). The following profiles, with and without AVX2, show the cost of RoPE dropping from 82.42 to 18.86, so I still recommend merging this PR.

With AVX2 RoPE:
Name: GroupQueryAttention_rotary, Mean Duration: 18.86, Percentage: 3.16%

Plain C++ RoPE:
Name: GroupQueryAttention_rotary, Mean Duration: 82.42, Percentage: 12.20%

mlas benchmark:

| dim | interleaved | baseline | new |
|-----|-------------|----------|-----|
| 128 | false | 735 | 18.1 |
| 256 | false | 1470 | 31.7 |
| 512 | false | 2938 | 59.2 |
| 1024 | false | 5876 | 81.5 |
| 128 | true | 368 | 23.1 |
| 256 | true | 735 | 34.3 |
| 512 | true | 1470 | 62.0 |
| 1024 | true | 2937 | 125 |

---------

Signed-off-by: Liqun Fu <[email protected]>
Signed-off-by: liqunfu <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This fixes an issue with the _mm256_maskload_ps intrinsic used in the remainder-handling logic introduced in #23694. The core of the problem is that _mm256_maskload_ps (and its store equivalent) can read beyond the masked elements: even if the mask correctly specifies that only, say, 3 floats should be loaded, the intrinsic may still read the full 32 bytes (8 floats) from the provided memory address. The invalid access occurs when one of the buffers (input, sin_data, or cos_data) ends near the boundary of a memory page and the masked-off part of the 32-byte read falls onto an unmapped page, causing a segmentation fault (invalid access).

The solution: use a scalar remainder loop. The simplest, safest, and most robust fix is to replace the masked AVX remainder logic with a plain scalar loop. This is the exact strategy already used by the RopeKernel_Avx2_fp16_Impl functions, which are safe from this bug. The performance impact of this change is negligible, as the loop only processes the final 1-15 elements.

---------

Co-authored-by: Copilot <[email protected]>