-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[CPU] Improve QMoE kernel #25822
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CPU] Improve QMoE kernel #25822
Conversation
This reverts commit a0251ea.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can commit the suggested changes from lintrunner.
onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.cc
Outdated
Show resolved
Hide resolved
onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.cc
Outdated
Show resolved
Hide resolved
onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.cc
Outdated
Show resolved
Hide resolved
onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.cc
Outdated
Show resolved
Hide resolved
onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.cc
Outdated
Show resolved
Hide resolved
onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.cc
Outdated
Show resolved
Hide resolved
onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.cc
Outdated
Show resolved
Hide resolved
onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.cc
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can commit the suggested changes from lintrunner.
a39580c to
27c1c05
Compare
c9cdf68 to
a7978f8
Compare
tianleiwu
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good.
Future performance improvements:
(1) For decoding (number of tokens is 1), we shall have an optimized kernel to increase throughput.
(2) Consider to have a provider option to enable prepacking for dequantization.
(3) Use block quantization and MLAS n-bit gemm kernel to speed up.
This pull request introduces several improvements and refactorings to the quantized Mixture-of-Experts (QMoE) operator in ONNX Runtime, focusing on enhanced support for FP32 mode, improved SwiGLU activation handling, and better test coverage. The most important changes are grouped below by theme. - Added explicit registration and support for `QMoE` operator with both `MLFloat16` and `float` data types, enabling FP32 (non-quantized) mode in addition to quantized modes. This includes updates to kernel registration and schema/type constraints. [[1]](diffhunk://#diff-fd949b2a9885f634c37c2048da9e35d227ed20adf1d7baf5de488f304a78bde9L109-R110) [[2]](diffhunk://#diff-fd949b2a9885f634c37c2048da9e35d227ed20adf1d7baf5de488f304a78bde9L275-R277) [[3]](diffhunk://#diff-81f57d9adc2cce94f85a2949a895b7ff82efcc13d05e23ee6567661f0fecb7c0L1467-R1467) [[4]](diffhunk://#diff-81f57d9adc2cce94f85a2949a895b7ff82efcc13d05e23ee6567661f0fecb7c0L1548-R1548) - Refactored `ApplySwiGLUActivation` to accept configurable `activation_alpha` and `activation_beta` parameters, matching CUDA behavior and allowing flexibility in activation function tuning. Also, dropped support for non-interleaved memory layouts (now not implemented). [[1]](diffhunk://#diff-4e4afb8dcdade0abe18bd8bea68b148b4090cd86d60a1b1422c049960231737dR49-R60) [[2]](diffhunk://#diff-edb344a38502bba9a0083ab98e274ec1b5b2606639a61df7be474a600a7b99d2L29-R61) [[3]](diffhunk://#diff-f85806c745243652a0336da094126687a6c0d14b19fe760abe73df1d940dc4cbL12-R13) - Now reads `activation_alpha` and `activation_beta` attributes from operator parameters, defaulting to values appropriate for SwiGLU. - Refactored the QMoE operator to clarify separation between quantized and FP32 implementations, and restructured internal methods for better maintainability. Added template parameterization for data types and improved handling of expert weights and biases. [[1]](diffhunk://#diff-e54124baa488af74400fae0f0dbd5cf7d4f1e307c0a5ba0e9dc79622e1315cd5R13-R35) [[2]](diffhunk://#diff-e54124baa488af74400fae0f0dbd5cf7d4f1e307c0a5ba0e9dc79622e1315cd5L38-R55) [[3]](diffhunk://#diff-e54124baa488af74400fae0f0dbd5cf7d4f1e307c0a5ba0e9dc79622e1315cd5L58-L59) - Removed legacy shape/layout support in QMoE input validation, enforcing only the new memory layout for expert weights and improving consistency and forward compatibility. - Updated unit tests for QMoE to use correct zero-point values for quantized weights (e.g., 0x88 for int4, 128 for int8), ensuring that test cases accurately reflect expected zero-output behavior for zero weights. Also clarified comments and expected outputs for SwiGLU and quantized scenarios. [[1]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1340-R1349) [[2]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1379-R1380) [[3]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1404-R1413) [[4]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1525-R1538) These changes collectively improve the flexibility, correctness, and maintainability of the QMoE operator in ONNX Runtime. Unit test result ``` sRunning test: batch_size=1, sequence_length=8, quant_bits=4, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000372 .Running test: batch_size=1, sequence_length=8, quant_bits=8, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000392 .Running test: batch_size=1, sequence_length=32, quant_bits=4, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000470 .Running test: batch_size=1, sequence_length=32, quant_bits=8, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000442 .Running test: batch_size=4, sequence_length=8, quant_bits=4, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000470 .Running test: batch_size=4, sequence_length=8, quant_bits=8, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000442 .Running test: batch_size=4, sequence_length=32, quant_bits=4, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000609 .Running test: batch_size=4, sequence_length=32, quant_bits=8, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000702 . ---------------------------------------------------------------------- Ran 9 tests in 46.754s OK (skipped=1) ``` --------- Co-authored-by: Tianlei Wu <[email protected]>
This pull request introduces several improvements and refactorings to the quantized Mixture-of-Experts (QMoE) operator in ONNX Runtime, focusing on enhanced support for FP32 mode, improved SwiGLU activation handling, and better test coverage. The most important changes are grouped below by theme. ### Operator Registration and Type Support - Added explicit registration and support for `QMoE` operator with both `MLFloat16` and `float` data types, enabling FP32 (non-quantized) mode in addition to quantized modes. This includes updates to kernel registration and schema/type constraints. [[1]](diffhunk://#diff-fd949b2a9885f634c37c2048da9e35d227ed20adf1d7baf5de488f304a78bde9L109-R110) [[2]](diffhunk://#diff-fd949b2a9885f634c37c2048da9e35d227ed20adf1d7baf5de488f304a78bde9L275-R277) [[3]](diffhunk://#diff-81f57d9adc2cce94f85a2949a895b7ff82efcc13d05e23ee6567661f0fecb7c0L1467-R1467) [[4]](diffhunk://#diff-81f57d9adc2cce94f85a2949a895b7ff82efcc13d05e23ee6567661f0fecb7c0L1548-R1548) ### SwiGLU Activation Improvements - Refactored `ApplySwiGLUActivation` to accept configurable `activation_alpha` and `activation_beta` parameters, matching CUDA behavior and allowing flexibility in activation function tuning. Also, dropped support for non-interleaved memory layouts (now not implemented). [[1]](diffhunk://#diff-4e4afb8dcdade0abe18bd8bea68b148b4090cd86d60a1b1422c049960231737dR49-R60) [[2]](diffhunk://#diff-edb344a38502bba9a0083ab98e274ec1b5b2606639a61df7be474a600a7b99d2L29-R61) [[3]](diffhunk://#diff-f85806c745243652a0336da094126687a6c0d14b19fe760abe73df1d940dc4cbL12-R13) - Now reads `activation_alpha` and `activation_beta` attributes from operator parameters, defaulting to values appropriate for SwiGLU. ### QMoE Operator Implementation Refactor - Refactored the QMoE operator to clarify separation between quantized and FP32 implementations, and restructured internal methods for better maintainability. Added template parameterization for data types and improved handling of expert weights and biases. [[1]](diffhunk://#diff-e54124baa488af74400fae0f0dbd5cf7d4f1e307c0a5ba0e9dc79622e1315cd5R13-R35) [[2]](diffhunk://#diff-e54124baa488af74400fae0f0dbd5cf7d4f1e307c0a5ba0e9dc79622e1315cd5L38-R55) [[3]](diffhunk://#diff-e54124baa488af74400fae0f0dbd5cf7d4f1e307c0a5ba0e9dc79622e1315cd5L58-L59) ### Shape Checking and Layout - Removed legacy shape/layout support in QMoE input validation, enforcing only the new memory layout for expert weights and improving consistency and forward compatibility. ### Test and Documentation Updates - Updated unit tests for QMoE to use correct zero-point values for quantized weights (e.g., 0x88 for int4, 128 for int8), ensuring that test cases accurately reflect expected zero-output behavior for zero weights. Also clarified comments and expected outputs for SwiGLU and quantized scenarios. [[1]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1340-R1349) [[2]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1379-R1380) [[3]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1404-R1413) [[4]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1525-R1538) These changes collectively improve the flexibility, correctness, and maintainability of the QMoE operator in ONNX Runtime. Unit test result ``` sRunning test: batch_size=1, sequence_length=8, quant_bits=4, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000372 .Running test: batch_size=1, sequence_length=8, quant_bits=8, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000392 .Running test: batch_size=1, sequence_length=32, quant_bits=4, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000470 .Running test: batch_size=1, sequence_length=32, quant_bits=8, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000442 .Running test: batch_size=4, sequence_length=8, quant_bits=4, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000470 .Running test: batch_size=4, sequence_length=8, quant_bits=8, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000442 .Running test: batch_size=4, sequence_length=32, quant_bits=4, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000609 .Running test: batch_size=4, sequence_length=32, quant_bits=8, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000702 . ---------------------------------------------------------------------- Ran 9 tests in 46.754s OK (skipped=1) ``` --------- Co-authored-by: Tianlei Wu <[email protected]>
This pull request introduces several improvements and refactorings to the quantized Mixture-of-Experts (QMoE) operator in ONNX Runtime, focusing on enhanced support for FP32 mode, improved SwiGLU activation handling, and better test coverage. The most important changes are grouped below by theme. ### Operator Registration and Type Support - Added explicit registration and support for `QMoE` operator with both `MLFloat16` and `float` data types, enabling FP32 (non-quantized) mode in addition to quantized modes. This includes updates to kernel registration and schema/type constraints. [[1]](diffhunk://#diff-fd949b2a9885f634c37c2048da9e35d227ed20adf1d7baf5de488f304a78bde9L109-R110) [[2]](diffhunk://#diff-fd949b2a9885f634c37c2048da9e35d227ed20adf1d7baf5de488f304a78bde9L275-R277) [[3]](diffhunk://#diff-81f57d9adc2cce94f85a2949a895b7ff82efcc13d05e23ee6567661f0fecb7c0L1467-R1467) [[4]](diffhunk://#diff-81f57d9adc2cce94f85a2949a895b7ff82efcc13d05e23ee6567661f0fecb7c0L1548-R1548) ### SwiGLU Activation Improvements - Refactored `ApplySwiGLUActivation` to accept configurable `activation_alpha` and `activation_beta` parameters, matching CUDA behavior and allowing flexibility in activation function tuning. Also, dropped support for non-interleaved memory layouts (now not implemented). [[1]](diffhunk://#diff-4e4afb8dcdade0abe18bd8bea68b148b4090cd86d60a1b1422c049960231737dR49-R60) [[2]](diffhunk://#diff-edb344a38502bba9a0083ab98e274ec1b5b2606639a61df7be474a600a7b99d2L29-R61) [[3]](diffhunk://#diff-f85806c745243652a0336da094126687a6c0d14b19fe760abe73df1d940dc4cbL12-R13) - Now reads `activation_alpha` and `activation_beta` attributes from operator parameters, defaulting to values appropriate for SwiGLU. ### QMoE Operator Implementation Refactor - Refactored the QMoE operator to clarify separation between quantized and FP32 implementations, and restructured internal methods for better maintainability. Added template parameterization for data types and improved handling of expert weights and biases. [[1]](diffhunk://#diff-e54124baa488af74400fae0f0dbd5cf7d4f1e307c0a5ba0e9dc79622e1315cd5R13-R35) [[2]](diffhunk://#diff-e54124baa488af74400fae0f0dbd5cf7d4f1e307c0a5ba0e9dc79622e1315cd5L38-R55) [[3]](diffhunk://#diff-e54124baa488af74400fae0f0dbd5cf7d4f1e307c0a5ba0e9dc79622e1315cd5L58-L59) ### Shape Checking and Layout - Removed legacy shape/layout support in QMoE input validation, enforcing only the new memory layout for expert weights and improving consistency and forward compatibility. ### Test and Documentation Updates - Updated unit tests for QMoE to use correct zero-point values for quantized weights (e.g., 0x88 for int4, 128 for int8), ensuring that test cases accurately reflect expected zero-output behavior for zero weights. Also clarified comments and expected outputs for SwiGLU and quantized scenarios. [[1]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1340-R1349) [[2]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1379-R1380) [[3]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1404-R1413) [[4]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1525-R1538) These changes collectively improve the flexibility, correctness, and maintainability of the QMoE operator in ONNX Runtime. Unit test result ``` sRunning test: batch_size=1, sequence_length=8, quant_bits=4, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000372 .Running test: batch_size=1, sequence_length=8, quant_bits=8, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000392 .Running test: batch_size=1, sequence_length=32, quant_bits=4, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000470 .Running test: batch_size=1, sequence_length=32, quant_bits=8, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000442 .Running test: batch_size=4, sequence_length=8, quant_bits=4, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000470 .Running test: batch_size=4, sequence_length=8, quant_bits=8, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000442 .Running test: batch_size=4, sequence_length=32, quant_bits=4, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000609 .Running test: batch_size=4, sequence_length=32, quant_bits=8, use_swiglu=True, swiglu_interleaved=True Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000702 . ---------------------------------------------------------------------- Ran 9 tests in 46.754s OK (skipped=1) ``` --------- Co-authored-by: Tianlei Wu <[email protected]>
### Description Cherry-pick the following PRs: #25943 #25937 #25917 #25909 #25898 #25897 #25888 #25881 #25830 #25619 #25575 #25572 #25558 #25530 #25474 #25455 #25110 Also two dependent PRs for qMoE cpu: #25877 #25822 --------- Co-authored-by: xiaomsft <[email protected]> Co-authored-by: Xiaoyan Hu <[email protected]> Co-authored-by: Akshay Sonawane <[email protected]> Co-authored-by: Kunal Vaishnavi <[email protected]> Co-authored-by: Pradeep Sakhamoori <[email protected]> Co-authored-by: mingyue <[email protected]> Co-authored-by: Maximilian Müller <[email protected]> Co-authored-by: Adrian Lizarraga <[email protected]> Co-authored-by: Dmitri Smirnov <[email protected]> Co-authored-by: Emmanuel <[email protected]> Co-authored-by: Emmanuel Assumang <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: praneshgo <[email protected]> Co-authored-by: Hariharan Seshadri <[email protected]> Co-authored-by: Jing Fang <[email protected]> Co-authored-by: Ishwar Raut <[email protected]>
This pull request introduces several improvements and refactorings to the quantized Mixture-of-Experts (QMoE) operator in ONNX Runtime, focusing on enhanced support for FP32 mode, improved SwiGLU activation handling, and better test coverage. The most important changes are grouped below by theme.
Operator Registration and Type Support
QMoEoperator with bothMLFloat16andfloatdata types, enabling FP32 (non-quantized) mode in addition to quantized modes. This includes updates to kernel registration and schema/type constraints. [1] [2] [3] [4]SwiGLU Activation Improvements
ApplySwiGLUActivationto accept configurableactivation_alphaandactivation_betaparameters, matching CUDA behavior and allowing flexibility in activation function tuning. Also, dropped support for non-interleaved memory layouts (now not implemented). [1] [2] [3]activation_alphaandactivation_betaattributes from operator parameters, defaulting to values appropriate for SwiGLU.QMoE Operator Implementation Refactor
Shape Checking and Layout
Test and Documentation Updates
These changes collectively improve the flexibility, correctness, and maintainability of the QMoE operator in ONNX Runtime.
Unit test result