
[ARM CPU] Enable FP16 kernels for GQA op #23746

Merged
fajin-corp merged 11 commits into main from
fajin/gqa-integrate
Feb 20, 2025

Conversation


@fajin-corp fajin-corp commented Feb 19, 2025

Description

  • Enable hgemm and softmax fp16 kernels for GQA
  • Add intra-loop parallelism to the RoPE fp16 kernel
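Intra-loop parallelism here means partitioning the work inside a single RoPE application across worker threads, rather than parallelizing only over batches or heads. A minimal Python sketch of the idea (a hypothetical partitioning for illustration only, not the actual MLAS fp16 kernel):

```python
import math
from concurrent.futures import ThreadPoolExecutor

def rope_rotate(x, positions, theta=10000.0):
    """Apply rotary position embedding to x (a list of per-token vectors).

    Each token's vector is treated as (even, odd) pairs that are rotated
    by a position-dependent angle -- the standard RoPE formulation.
    Assumes an even embedding dimension.
    """
    dim = len(x[0])
    out = [row[:] for row in x]

    def rotate_token(t):  # work unit: one token (the intra-loop split point)
        pos = positions[t]
        for i in range(0, dim, 2):
            angle = pos / (theta ** (i / dim))
            c, s = math.cos(angle), math.sin(angle)
            a, b = x[t][i], x[t][i + 1]
            out[t][i] = a * c - b * s
            out[t][i + 1] = a * s + b * c

    # Partition the token loop across a thread pool instead of running it
    # serially -- the "intra-loop parallelism" idea in miniature.
    with ThreadPoolExecutor(max_workers=4) as pool:
        list(pool.map(rotate_token, range(len(x))))
    return out
```

In the real kernel the split would be over contiguous element blocks with fp16 SIMD lanes; the thread pool here just stands in for the runtime's thread pool.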

Benchmarking models

  • float32: phi-3 cpu accuracy level 0 (https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32)
  • float16: phi-3 gpu accuracy level 0 (https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/tree/main/cuda/cuda-int4-rtn-block-32)
Note:

  • Both fp32 and fp16 models share the same model structure and operator settings.
  • GQA takes ~15% of the runtime.
  • prompt length 256, token generation length 512

Linux (Ubuntu 24.04), Standard D16pls v5 (16 vcpus, 32 GiB memory)

| | fp32 (tps) | old fp16 (tps) | new fp16 (tps) | new fp16 vs old fp16 | new fp16 vs fp32 |
|--|--|--|--|--|--|
| prompt processing | 31.22 | 44.24 | 46.29 | +4.6% | +48.25% |
| token generation | 4.75 | 7.2 | 7.95 | +10.39% | +67.43% |
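The percentage columns are plain throughput ratios; recomputing them from the table's tps values reproduces them to within rounding error (small differences come from the tps values themselves being rounded):

```python
def pct_speedup(new_tps, base_tps):
    """Percent throughput gain of new_tps over base_tps."""
    return (new_tps / base_tps - 1.0) * 100.0

# Prompt-processing row: new fp16 (46.29 tps) vs old fp16 and vs fp32.
print(round(pct_speedup(46.29, 44.24), 2))  # ~4.63  (table: +4.6%)
print(round(pct_speedup(46.29, 31.22), 2))  # ~48.27 (table: +48.25%)
# Token-generation row: new fp16 (7.95 tps).
print(round(pct_speedup(7.95, 7.2), 2))     # ~10.42 (table: +10.39%)
print(round(pct_speedup(7.95, 4.75), 2))    # ~67.37 (table: +67.43%)
```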

Motivation and Context

Speed up GQA on FP16

@fajin-corp fajin-corp requested a review from a team as a code owner February 19, 2025 00:02

@github-actions github-actions bot left a comment

You can commit the suggested changes from lintrunner.

@fajin-corp fajin-corp merged commit 2d33ee9 into main Feb 20, 2025
96 of 98 checks passed
@fajin-corp fajin-corp deleted the fajin/gqa-integrate branch February 20, 2025 17:38
guschmue pushed a commit that referenced this pull request Mar 6, 2025
### Description
 - Enable hgemm and softmax fp16 kernels for GQA
 - Add intra-loop parallelism to the RoPE fp16 kernel

__Benchmarking models__
- float32: [phi-3 cpu accuracy level
0](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32)
- float16: [phi-3 gpu accuracy level
0](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/tree/main/cuda/cuda-int4-rtn-block-32)

Note: 
- Both fp32 and fp16 models share the same model structure and operator
settings.
- GQA takes ~15% of the runtime.
- prompt length 256, token generation length 512

Linux (ubuntu 24.04) Standard D16pls v5 (16 vcpus, 32 GiB memory)
| | fp32 (tps) | old fp16 (tps) | new fp16 (tps) | new fp16 vs old fp16 | new fp16 vs fp32 |
|--|--|--|--|--|--|
| prompt processing | 31.22 | 44.24 | 46.29 | +4.6% | +48.25% |
| token generation | 4.75  | 7.2 | 7.95 | +10.39% | +67.43% |

### Motivation and Context
Speed up GQA on FP16
snnn pushed a commit that referenced this pull request Mar 10, 2025
ashrit-ms pushed a commit that referenced this pull request Mar 17, 2025