[inductor pattern][cpu] add a new int8 woq mm pattern #131310
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/131310. Note: links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit f3b3961 with merge base c4bf400. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Why is the first-token latency even worse after this PR?
Just curious: why has the data type of the scales changed from bf16 to fp32? Using fp32 here adds the overhead of two dtype conversions, and it's not clear how much overhead that induces. Have you tested the UT added in TorchAO? Can it hit the pattern matcher now?
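For illustration, numpy's type-promotion rules show where the two conversions in the question above would come from. This is a sketch, not the PR's code; numpy lacks bfloat16, so `float16` stands in for bf16 activations:

```python
import numpy as np

# float16 stands in for bf16 activations/weights (numpy has no bfloat16).
x = np.ones((2, 4), dtype=np.float16)
w = np.ones((4, 3), dtype=np.float16)
scale_fp32 = np.ones(3, dtype=np.float32)  # the scale dtype in question

y = x @ w                            # fp16 GEMM result
y_scaled = y * scale_fp32            # promoted to fp32: first conversion
y_out = y_scaled.astype(np.float16)  # downcast back: second conversion

assert y_scaled.dtype == np.float32
assert y_out.dtype == np.float16
```

With a bf16 scale, the epilogue multiply would stay in the activation dtype and neither conversion would be needed.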
This is the script to run on one node:
After offline discussion with @Valentine233, it turns out that the runtime error I'm encountering at my end (`torch._dynamo.exc.InternalTorchDynamoError: 'PlainAQTLayout' object has no attribute 'layout_type'`) is similar to pytorch/ao#534 (comment). @Valentine233, is it possible for you to rebase your local PyTorch repo & verify whether or not you're also encountering the same problem? Thanks
Update for the first-token regression:

- **First token**: the input shapes are `x: [128, 4096]`, `w: [4096, 4096]`, `scale: [4096]`.
- **Next token**: the input shapes are `x: [4, 4096]`, `w: [4096, 4096]`, `scale: [4096]`.
- **Reproduce**: you can reproduce the result by running the UT added in this PR.
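The shape difference quoted above can be quantified directly: the prompt (first token) GEMM has `M=128` while the decode (next token) GEMM has `M=4`, so the first-token matmul does 32x the multiply-accumulates. A quick check:

```python
# Shapes taken from the comment above; first token processes the whole
# prompt (M=128), subsequent tokens decode with M=4.
M_first, M_next, K, N = 128, 4, 4096, 4096

macs_first = M_first * K * N  # multiply-accumulates for the first-token GEMM
macs_next = M_next * K * N    # multiply-accumulates for a next-token GEMM

print(macs_first // macs_next)  # -> 32
```

This is why a kernel choice that is optimal for the small-`M` decode GEMM can still regress the compute-bound first-token GEMM.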
While auto-tuning would be 1.5x faster for the first token, it would still have resulted in a first-token regression. However, the overall performance order (higher is better) is: auto-tuning int8 WoQ GEMM (another PR that needs this PR for LLaMA2) > ATen int8 WoQ GEMM kernel (this PR) > dequantizing weights upfront.
## Summary

As part of #125683, this PR modifies the existing CPU GEMM cpp template & micro-kernel template to enable int8 WoQ GEMM auto-tuning with AVX2, AVX512 & AMX ISAs (the latter is only available on Xeon 4th generation & beyond).

WoQ GEMM takes FP16/BF16 activations, int8 weights, and a scale of the same dtype as the activations. The operation is equivalent to `torch.nn.functional.linear(x, w.to(x.dtype)) * scale`, which is essentially what the ATen op `torch.ops.aten._weight_int8pack_mm` currently does (except that weights are not cached by it). Weights are treated as constant & cached, so this implementation is suitable for inference, not QAT. `scale` is supported as a `mul` epilogue.

Only BF16 activations are supported in this PR because for FP16 & FP32, the weight is dequantized during the constant-folding pass of freezing, and then after auto-tuning, performance with a large `M` dimension may be better than either `torch.ops.aten._weight_int8pack_mm` or the WoQ micro-kernel support introduced in this PR, which dequantizes `w` within the micro-kernel. While even BF16 activations with a large `M` dimension may benefit from dequantizing `w` beforehand, for now they use the WoQ support in GEMM templates for auto-tuning; a subsequent PR will add logic for deciding whether or not to dequantize weights beforehand.

### Performance

#### AMX

Op-level speedup due to the AMX micro-kernel (selected during auto-tuning) on 32 physical cores of Intel(R) Xeon(R) Platinum 8468H (Xeon 4th generation, codenamed Sapphire Rapids) vs. the ATen kernel `torch.ops.aten._weight_int8pack_mm`. Intel OpenMP & tcmalloc were preloaded. In a few cases with an odd `K`, the implementation added in this PR may not perform as well as the ATen kernel; this is unrelated to this PR, though, since `test_linear_amx` also exhibits similar datapoints. In those cases, the AMX micro-kernel might be slower than the AVX512 micro-kernel, so if such sets of shapes are used for auto-tuning, either the AVX512 micro-kernel implementation or the ATen kernel would be chosen instead. Benchmarked with unit tests. Tabular data at https://gist.github.com/sanchitintel/294811a86c8ff6b867c668ae2107c405?permalink_comment_id=5142442#gistcomment-5142442. The AVX512 micro-kernel was disabled to collect data for the AMX micro-kernel.

#### AVX2/AVX512 micro-kernels

Tabular data at https://gist.github.com/sanchitintel/52b5fa9c66f791be19e48e2aa6423dc4?permalink_comment_id=5142437#gistcomment-5142437

### Follow-up

1. An int4 WoQ GEMM micro-kernel will also be added in a separate PR.
2. A subsequent PR will add logic for deciding whether or not to dequantize weights beforehand. E2E perf measurement should be done with #131310.

Pull Request resolved: #131887
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
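The summary describes the WoQ op as `torch.nn.functional.linear(x, w.to(x.dtype)) * scale`. A minimal numpy sketch of that math (not the PR's kernel; `float16` stands in for bf16, and the sizes here are made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
M, K, N = 8, 16, 4

x = rng.standard_normal((M, K)).astype(np.float16)            # activations
w_int8 = rng.integers(-128, 128, size=(N, K), dtype=np.int8)  # [N, K], nn.Linear layout
scale = rng.standard_normal(N).astype(np.float16)             # same dtype as activations

# Dequantize weights to the activation dtype, plain GEMM, then `mul` epilogue:
# equivalent to torch.nn.functional.linear(x, w.to(x.dtype)) * scale.
y = (x @ w_int8.astype(np.float16).T) * scale

assert y.shape == (M, N)
assert y.dtype == np.float16
```

The micro-kernel version does the dequantization of `w` inside the kernel rather than materializing the full dequantized weight, which is what saves memory versus dequantizing upfront.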
Potential fix: since Xeon CPUs are used on machines with a large amount of RAM, we can optionally cache both quantized & dequantized weights for large values of `M`. In such a case, the auto-tuning GEMM template should allow a fallback that could be used for large values of `M`.
Add an int8 WoQ mm pattern for llama, which successfully hits all 675 WoQ linears.
Implementation
Differences with previous patterns:
Performance
Performance data on llama, with 1 NUMA node, freezing mode:
Before pattern match:
---------- Summary: ----------
inference-latency: 4.810 sec.
first-token-latency: 0.280 sec.
rest-token-latency: 0.146 sec.
P90-rest-token-latency: 0.148 sec.
After pattern match:
---------- Summary: ----------
inference-latency: 2.537 sec.
first-token-latency: 0.964 sec.
rest-token-latency: 0.051 sec.
P90-rest-token-latency: 0.052 sec.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov