[webgpu] Use components for VxAttentionScore #23726
Conversation
For phi3.5-gqa-static sum_long (>1000 tokens) on Meteor Lake.
Before:
300 tokens in 27.0 sec, e2e: 11.1 tps, prompt: 212.4 tps, gen: 14.2 tps, ttft: 5.85 sec
After:
300 tokens in 23.0 sec, e2e: 13.0 tps, prompt: 248.9 tps, gen: 16.6 tps, ttft: 4.99 sec
@sushraja-msft @guschmue This PR applies the flash attention logic so that the softmax can be merged into the last-stage shader; it's used to optimize the generation shader. Since https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc#L413 hasn't been applied to GQA yet, this PR also benefits the prefill time. As a next step, I can move/reuse CopyKVCache so that the generation shader can also combine QKT into one shader.
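For context, the flash attention trick referenced above is the "online softmax": by carrying a running max and running normalizer, the softmax normalization can be folded into the final weighted-sum stage instead of needing its own full-pass shader. The following is a minimal Python sketch of that idea only; the function and variable names are illustrative and do not mirror the actual onnxruntime WebGPU shader code.

```python
import math

def online_softmax_weighted_sum(scores, values):
    """Accumulate sum(softmax(scores)[i] * values[i]) in one streaming pass.

    Keeps a running max (for numerical stability) and a running
    normalizer, rescaling the partial accumulator whenever the max
    grows, so the final division is the only separate step.
    """
    running_max = -math.inf
    normalizer = 0.0
    acc = 0.0
    for s, v in zip(scores, values):
        new_max = max(running_max, s)
        scale = math.exp(running_max - new_max)  # rescale old contributions
        weight = math.exp(s - new_max)           # weight of the new element
        normalizer = normalizer * scale + weight
        acc = acc * scale + weight * v
        running_max = new_max
    # The normalization happens here, in the "last stage".
    return acc / normalizer

scores = [0.5, 2.0, -1.0, 3.0]
values = [1.0, 2.0, 3.0, 4.0]
print(online_softmax_weighted_sum(scores, values))
```

This matches a conventional two-pass softmax-then-dot-product to floating-point precision, which is why the merge does not change the attention output.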
@sushraja-msft @guschmue, I have restored this PR to only add the components support and will put the flash attention support in a separate PR. Please take a look, thanks.
@guschmue It seems that the 2 failing CIs are unrelated to my changes. Can this PR be merged?
yeap |