
Conversation

@Valentine233 (Collaborator) commented Jul 22, 2024

Add an int8 woq mm pattern for llama, which successfully hits all 675 woq linears.

Implementation

Differences with previous patterns (a sketch of the matched computation follows this list):

  • scale is fp32 instead of bf16.
  • no reshape for x and output.
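
For context, here is a minimal standalone sketch (not the actual pattern-matcher code in this PR) of the decomposed computation such a woq pattern targets and the fused ATen op it maps to. The shapes are llama-like ones discussed later in the thread, and casting the scale in the fused call is an assumption about what the ATen kernel expects, not something stated in this PR.

```python
import torch

# Illustrative llama-like shapes: [M, K] activations, [N, K] int8 weight, per-output-channel scale.
M, K, N = 128, 4096, 4096
x = torch.randn(M, K, dtype=torch.bfloat16)                  # bf16 activations, used as-is (no reshape)
w_int8 = torch.randint(-128, 127, (N, K), dtype=torch.int8)  # weight-only-quantized weight
scale = torch.rand(N, dtype=torch.float32)                   # fp32 scale, per this PR

# Decomposed form the pattern is meant to match:
decomposed = torch.nn.functional.linear(x, w_int8.to(x.dtype)) * scale

# Fused form after the match; casting the scale to the activation dtype is an
# assumption here (see the fp32-vs-bf16 discussion further down in the thread).
fused = torch.ops.aten._weight_int8pack_mm(x, w_int8, scale.to(x.dtype))
```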

Performance

Performance data on llama, with 1 NUMA node, in freezing mode:

  • Before pattern match:
    ---------- Summary: ----------
    inference-latency: 4.810 sec.
    first-token-latency: 0.280 sec.
    rest-token-latency: 0.146 sec.
    P90-rest-token-latency: 0.148 sec.

  • After pattern match:
    ---------- Summary: ----------
    inference-latency: 2.537 sec.
    first-token-latency: 0.964 sec.
    rest-token-latency: 0.051 sec.
    P90-rest-token-latency: 0.052 sec.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov

pytorch-bot commented Jul 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/131310

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f3b3961 with merge base c4bf400:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@leslie-fang-intel (Collaborator) commented:

Why is the first-token latency even worse after this PR?

@leslie-fang-intel left a review comment:

Just curious: why has the data type of the scales changed from bf16 to fp32? Using fp32 here incurs the overhead of two conversions, and I'm not sure how much overhead it will induce. Have you tested the UT added in TorchAO? Can it hit the pattern matcher now?

@albanD added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Jul 23, 2024
@Valentine233 (Collaborator, Author) commented:

Summary

  1. The fp32-related discussions above are expected to be resolved by ao#534 ([int8 woq] make the scale type the same as the input for bf16 autocast).
  2. The first-token regression needs further analysis.

A comment from @sanchitintel was marked as outdated.

@Valentine233 (Collaborator, Author) commented:

> Hi @Valentine233, can you please add info about the precise benchmark you used, perhaps from your bash history? I used the following command, but the benchmark goes on for a long time, and only posts E2E latency. I'm guessing computing per-token latency requires modifying the transformers library. Please confirm, thanks!
>
> numactl --membind=0 --cpunodebind=0 -C 0-31 python run_llm.py -m "meta-llama/Llama-2-7b-hf" --device cpu --dtype bf16 --output_dir $(pwd)/tmp --batch-size 256 --ws-total-cores 32 --ws-cores-per-instance 32 --weight-dtype INT8 --benchmark --torchao --inductor --num-iter 10 --num-warmup 1 --profile
>
> I'm also seeing this warning. Please advise if it's normal. Thanks!
>
> W0804 17:20:34.220000 2951501 torch/_dynamo/convert_frame.py:828] [0/8] torch._dynamo hit config.cache_size_limit (8)
> W0804 17:20:34.220000 2951501 torch/_dynamo/convert_frame.py:828] [0/8]    function: 'forward' (transformers/src/transformers/models/llama/modeling_llama.py:639)
> W0804 17:20:34.220000 2951501 torch/_dynamo/convert_frame.py:828] [0/8]    last reason: 0/0: ___check_obj_id(L['past_key_values'], 94511674602464)
> W0804 17:20:34.220000 2951501 torch/_dynamo/convert_frame.py:828] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
> W0804 17:20:34.220000 2951501 torch/_dynamo/convert_frame.py:828] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.

This is the script to run on one node:

numactl -C 56-111 -m 1 python ../../../../../../models/language_modeling/pytorch/llama/inference/cpu/run_llm.py --benchmark --num-warmup 1 --num-iter 2 --token-latency --dtype 'bf16' -m 'meta-llama/Llama-2-7b-hf' --max-new-tokens 32 --input-tokens 32 --batch-size 1 --weight-only-quant --torchao --weight-dtype INT8

I can't remember about the warning.

@sanchitintel (Collaborator) commented:

After offline discussion with @Valentine233, it turns out that the runtime error I'm encountering at my end (`torch._dynamo.exc.InternalTorchDynamoError: 'PlainAQTLayout' object has no attribute 'layout_type'`) is similar to pytorch/ao#534 (comment).

@Valentine233, is it possible for you to rebase your local PyTorch repo & verify whether or not you're also encountering the same problem? Thanks

@Valentine233 (Collaborator, Author) commented:

Update on the first-token regression:

The first token has a regression because aten::_weight_int8pack_mm is slower than the decomposed computation for certain input shapes. cc @mingfeima @jgong5 @leslie-fang-intel

First token

For the first token, the input shapes are x: [128, 4096], w: [4096, 4096], scale: [4096].
Profiling:
  • With woq pattern: (profiling screenshot)
  • Without woq pattern: (profiling screenshot)

Next token

For the next token, the input shapes are x: [4, 4096], w: [4096, 4096], scale: [4096].
Profiling:
  • With woq pattern: (profiling screenshot)
  • Without woq pattern: (profiling screenshot)

Reproduce

You could reproduce the result by running the UT added in this PR.
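
For anyone without the UT handy, here is a rough standalone micro-benchmark sketch (not the PR's actual unit test) comparing the fused ATen op with the decomposed computation for the two shapes above; absolute numbers depend on the machine, thread count, and memory binding.

```python
import torch
from torch.utils.benchmark import Timer

def bench(M, K=4096, N=4096, dtype=torch.bfloat16):
    # Random data with the shapes from the profiling runs above.
    x = torch.randn(M, K, dtype=dtype)
    w = torch.randint(-128, 127, (N, K), dtype=torch.int8)
    scale = torch.rand(N, dtype=dtype)
    env = {"torch": torch, "x": x, "w": w, "scale": scale}
    fused = Timer("torch.ops.aten._weight_int8pack_mm(x, w, scale)", globals=env).blocked_autorange()
    decomposed = Timer("torch.nn.functional.linear(x, w.to(x.dtype)) * scale", globals=env).blocked_autorange()
    print(f"M={M}: fused {fused.median * 1e6:.1f} us, decomposed {decomposed.median * 1e6:.1f} us")

bench(128)  # first-token (prefill) shape
bench(4)    # next-token (decode) shape
```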

A comment from @sanchitintel was marked as off-topic.

@sanchitintel (Collaborator) commented Aug 8, 2024

While auto-tuning would be 1.5x faster for the first token, it too would have resulted in a regression for the first token.
The shape x: [128, 4096], w: [4096, 4096], scale: [4096] seems to do better if the weights are dequantized upfront.

However, the overall performance order (higher is better) is:

Auto-tuning int8 WoQ GEMM (another PR that needs this PR for LLaMA2) > ATen int8 WoQ GEMM kernel (this PR) > Dequantizing weights upfront
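
For clarity, here is a hedged sketch of the "dequantizing weights upfront" baseline referred to above: the int8 weight is converted to the activation dtype once (e.g. at freezing/load time) with the scale folded in, so the hot path becomes a plain bf16 GEMM.

```python
import torch

w_int8 = torch.randint(-128, 127, (4096, 4096), dtype=torch.int8)
scale = torch.rand(4096, dtype=torch.bfloat16)

# Done once and cached (e.g. during constant folding at freezing time), not per call:
w_dequant = w_int8.to(torch.bfloat16) * scale.unsqueeze(1)

x = torch.randn(128, 4096, dtype=torch.bfloat16)  # large-M first-token shape favors this path
y = torch.nn.functional.linear(x, w_dequant)      # hot path is a plain bf16 GEMM
```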

pytorchmergebot pushed a commit that referenced this pull request Aug 10, 2024
## Summary

As part of #125683, this PR modifies existing CPU GEMM cpp template & micro-kernel template to enable int8 WoQ GEMM auto-tuning with AVX2, AVX512 & AMX ISAs (the latter is only available on Xeon 4th generation & beyond).

WoQ GEMM takes FP16/BF16 activations, int8 weights, and scale of the same dtype as activations.
The operation is equivalent to `torch.nn.functional.linear(x, w.to(x.dtype)) * scale`, which is essentially what the ATen op `torch.ops.aten._weight_int8pack_mm` currently does (except that weights are not cached by it). Weights will be considered constant & cached, so this implementation is suitable for inference, and not QAT. `scale` is supported as a `mul` epilogue.

Only BF16 activations are supported in this PR because, for FP16 & FP32, the weight is dequantized during the constant-folding pass of freezing, and after auto-tuning, performance with a large `M` dimension may then be better than either torch.ops.aten._weight_int8pack_mm or the WoQ micro-kernel support introduced in this PR, which dequantizes `w` within the micro-kernel.
While even BF16 activations with a large `M` dimension may benefit from dequantizing `w` beforehand, for now they use the WoQ support in the GEMM templates for auto-tuning; a subsequent PR will add logic for deciding whether or not to dequantize weights beforehand.

### Performance
#### AMX
Op-level speedup due to AMX micro-kernel (selected during auto-tuning) on 32 physical cores of Intel(R) Xeon(R) Platinum 8468H (of Xeon 4th generation series, codenamed Sapphire Rapids) vs. ATen kernel `torch.ops.aten._weight_int8pack_mm`. Intel OpenMP & tcmalloc were preloaded.

In a few cases with an odd `K`, the implementation added in this PR may not perform as well as the ATen kernel. This is unrelated to this PR, though, since `test_linear_amx` exhibits similar datapoints. In those cases, the AMX micro-kernel might be slower than the AVX512 micro-kernel, so if such sets of shapes are used for auto-tuning, either the AVX512 micro-kernel implementation or the ATen kernel would be chosen instead.

Benchmarked with unit-tests.

Tabular data at https://gist.github.com/sanchitintel/294811a86c8ff6b867c668ae2107c405?permalink_comment_id=5142442#gistcomment-5142442

The AVX512 micro-kernel was disabled to collect data for AMX micro-kernel.

#### AVX2/AVX512 micro-kernels

Tabular data at https://gist.github.com/sanchitintel/52b5fa9c66f791be19e48e2aa6423dc4?permalink_comment_id=5142437#gistcomment-5142437

### Follow-up
1. int4 WoQ GEMM micro-kernel will also be added in a separate PR.
2. A subsequent PR would add logic for deciding whether or not to dequantize weights beforehand.

E2E perf measurement should be done with #131310.

Pull Request resolved: #131887
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
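
As a usage note, here is a hedged sketch of how the cpp GEMM template auto-tuning described in the commit message above is typically exercised. The inductor config flags below are the usual freezing/max-autotune knobs, but exact names, defaults, and backend strings may differ across PyTorch versions.

```python
import torch
import torch._inductor.config as inductor_config

inductor_config.freezing = True                           # treat weights as constants (inference only)
inductor_config.max_autotune = True                       # enable GEMM auto-tuning
inductor_config.max_autotune_gemm_backends = "CPP,ATEN"   # compete cpp template against the ATen kernel

class WoqLinear(torch.nn.Module):
    def __init__(self, K=4096, N=4096):
        super().__init__()
        self.register_buffer("w", torch.randint(-128, 127, (N, K), dtype=torch.int8))
        self.register_buffer("scale", torch.rand(N, dtype=torch.bfloat16))

    def forward(self, x):
        # Decomposed int8 WoQ linear; the scale ends up as a mul epilogue after the GEMM.
        return torch.nn.functional.linear(x, self.w.to(x.dtype)) * self.scale

with torch.no_grad():
    mod = torch.compile(WoqLinear().eval())
    out = mod(torch.randn(128, 4096, dtype=torch.bfloat16))
```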
pytorchmergebot pushed a commit that referenced this pull request Aug 14, 2024
@sanchitintel (Collaborator) commented Aug 19, 2024

Potential fix - since Xeon CPUs are used on machines with a large amount of RAM, we can optionally cache both quantized & dequantized weights for large values of M (based on some heuristic), and lower the corresponding FX pattern to a custom function (like the way quantized_decomposed custom functions have been defined) that accepts both quantized & dequantized weights. Then that custom function should be lowered to a template-based auto-tuning implementation of WoQ GEMM.

In such a case, the auto-tuning GEMM template should allow a fallback that could be used for large values of M, and use the cached dequantized weights if a large value of M would be encountered at runtime.
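
To make the idea concrete, here is a hedged sketch of such a wrapper; the class name, the M threshold, and the dispatch logic are hypothetical illustrations, not an existing PyTorch API.

```python
import torch

M_THRESHOLD = 64  # hypothetical cutoff; a real heuristic would be tuned per platform

class CachedWoqLinear(torch.nn.Module):
    def __init__(self, w_int8: torch.Tensor, scale: torch.Tensor):
        super().__init__()
        self.register_buffer("w_int8", w_int8)
        self.register_buffer("scale", scale)
        # Dequantized copy cached up front, trading RAM for large-M (prefill) speed.
        self.register_buffer("w_dequant", w_int8.to(scale.dtype) * scale.unsqueeze(1))

    def forward(self, x):
        if x.shape[0] >= M_THRESHOLD:
            # Large M (e.g. first token / prefill): plain GEMM on the cached dequantized weight.
            return torch.nn.functional.linear(x, self.w_dequant)
        # Small M (decoding): fused int8 WoQ kernel.
        return torch.ops.aten._weight_int8pack_mm(x, self.w_int8, self.scale)

# Usage with the shapes discussed earlier in the thread:
lin = CachedWoqLinear(torch.randint(-128, 127, (4096, 4096), dtype=torch.int8),
                      torch.rand(4096, dtype=torch.bfloat16))
print(lin(torch.randn(128, 4096, dtype=torch.bfloat16)).shape)  # prefill path
print(lin(torch.randn(4, 4096, dtype=torch.bfloat16)).shape)    # decode path
```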

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 26, 2024
github-actions bot (Contributor) commented:

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

github-actions bot added the Stale label Oct 18, 2024
github-actions bot closed this Nov 17, 2024
github-actions bot deleted the woq_mm_cpu branch December 18, 2024