[TRTLLM-11127][feat] add W4A8_MXFP4_FP8 MoE unit test support #13401
📝 Walkthrough: The changes introduce support for a new quantization algorithm
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/unittest/_torch/modules/moe/quantize_utils.py`:
- Around line 1683-1700: The test's GatedMLP reference path is supposed to use
bf16 but currently passes dtype=dtype into super().__init__ and later calls
.to(dtype=self.dtype,...), which lets fp16 tests reuse the fp16 path; change the
constructor call to force bfloat16 (e.g., pass dtype=torch.bfloat16 or
equivalent bf16 constant) so the experts are built in bf16 (keep ModelConfig()
without quant_config), and ensure any subsequent .to(...) that currently
converts expert submodules does not convert them back to the test dtype (instead
convert only the outer wrapper or skip converting the expert weights). Apply the
same change at the other occurrence noted (lines 1729-1743).
In `@tests/unittest/_torch/modules/moe/test_moe_module.py`:
- Around line 1125-1145: Move the known-bad predicate that skips TRTLLM +
W4A8_MXFP4_FP8 large configs out of the local block guarding
test_configurable_moe_single_gpu() and into the shared test-parameter filtering
used by generate_multi_gpu_test_params(); specifically, extract the conditional
referencing quant_algo == QuantAlgo.W4A8_MXFP4_FP8, moe_backend ==
MoeBackendType.TRTLLM.value, model_config.num_experts >= 60, and
model_config.intermediate_size >= 1408 and apply it in the centralized/shared
skip/filter function so both single-GPU and multi-GPU matrices (including tests
produced by generate_multi_gpu_test_params()) skip this known-bad configuration.
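A minimal sketch of the centralized filter this comment asks for, assuming plain strings in place of `QuantAlgo.W4A8_MXFP4_FP8` and `MoeBackendType.TRTLLM.value`. The names `MoeShape`, `known_bad_large_config`, and `filter_params` are hypothetical and do not come from the test suite; only the four conditions in the predicate are taken from the comment above.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

# Stand-ins for QuantAlgo.W4A8_MXFP4_FP8 and MoeBackendType.TRTLLM.value (assumed values).
W4A8_MXFP4_FP8 = "W4A8_MXFP4_FP8"
TRTLLM = "TRTLLM"


@dataclass(frozen=True)
class MoeShape:
    """Hypothetical subset of the model config that the predicate needs."""
    num_experts: int
    intermediate_size: int


def known_bad_large_config(quant_algo: str, moe_backend: str,
                           shape: MoeShape) -> Optional[str]:
    """Return a skip reason for the known-bad configuration, else None."""
    if (quant_algo == W4A8_MXFP4_FP8 and moe_backend == TRTLLM
            and shape.num_experts >= 60 and shape.intermediate_size >= 1408):
        return "TRTLLM + W4A8_MXFP4_FP8 large config is known-bad"
    return None


Params = Tuple[str, str, MoeShape]


def filter_params(candidates: List[Params]) -> List[Params]:
    """Shared filter: both the single-GPU and multi-GPU generators call this."""
    return [c for c in candidates if known_bad_large_config(*c) is None]


candidates = [
    (W4A8_MXFP4_FP8, TRTLLM, MoeShape(60, 1408)),     # known-bad, dropped
    (W4A8_MXFP4_FP8, TRTLLM, MoeShape(8, 512)),       # small shape, kept
    (W4A8_MXFP4_FP8, "CUTLASS", MoeShape(60, 1408)),  # other backend, kept
]
kept = filter_params(candidates)
```

Keeping the predicate in one function means a change to the thresholds, or the removal of the skip, applies to every test matrix at once.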
📒 Files selected for processing (4)
- tests/unittest/_torch/modules/moe/moe_test_utils.py
- tests/unittest/_torch/modules/moe/quantize_utils.py
- tests/unittest/_torch/modules/moe/test_moe_backend.py
- tests/unittest/_torch/modules/moe/test_moe_module.py
The kernel-faithful MXFP4FP8RefGatedMLPFusedMoE (static FP8 round-trip on FC1/FC2 inputs) brings the generic top_k=1 cases and gpt-oss-style top_k>=2 cases inside tolerance for the TRTLLM Gen backend.

The single remaining gap is the triple intersection W4A8_MXFP4_FP8 + top_k=1 + swiglu_gptoss_style=True: with only one routed token per expert, the load-time static per-tensor FP8 scales no longer cover the FC2 activation range that the gpt-oss SwiGLU '*(linear + 1)' term produces, and ref vs kernel diverge ~92-94%.

Mirror the existing W4A8_MXFP4_MXFP8 + gpt-oss + top_k=1 skip and the original PR NVIDIA#13401 design-limitation rationale, but only on this narrow intersection. The broader top_k=1 coverage and the gpt-oss top_k>=2 coverage stay enabled.

Signed-off-by: xxi <[email protected]>
Adds W4A8_MXFP4_FP8 coverage to the PyTorch MoE unit tests for both
the CUTLASS and TRTLLM-Gen fused MoE methods.
Test infrastructure
- quantize_utils.py:
* Extend get_test_quant_params for W4A8_MXFP4_FP8 with backend-specific
alignment (CUTLASS 128/128, TRTLLM 128/512).
* Loosen MXFP4MXFP8QuantizeUtil.create_weights to accept both
W4A8_MXFP4_MXFP8 and W4A8_MXFP4_FP8.
* Add MXFP4FP8QuantizeUtil for per-expert per-tensor input-scale
population.
* Add a dedicated MXFP4FP8RefGatedMLPFusedMoE reference: the original
GatedMLP+W4A8MXFP4FP8LinearMethod path was off by ~50-70x because
W4A8MXFP4FP8LinearMethod.apply passes a dynamic per-tensor FP8
input_scale into trtllm::w4a8_mxfp4_fp8_gemm (wired to
FP4GemmType.W4A8_MXFP4_MXFP8, expects per-block scales). The new
reference dequantizes MXFP4 weights at load time via
trtllm.mxfp4_dequantize_unswizzled and emulates the kernel's static
per-tensor FP8 activation round-trip on FC1 / FC2 inputs, mirroring
W4A8MXFP4FP8TRTLLMGenFusedMoEMethod.load_quant_scales /
W4A8MXFP4FP8CutlassFusedMoEMethod.quantize_input.
- moe_test_utils.py:
* Wire W4A8_MXFP4_FP8 into trtllm_gen_quant_algos and the
is_mxfp4_variant auto-pad set; extend the 128-alignment skip to
both MXFP4 A8 variants.
* Add a parallel_mode kwarg to should_skip_trtllm so DEP can be
distinguished from TEP (both have moe_tp_size=1) for kernel-bug
skips.
- test_moe_backend.py / test_moe_module.py: register the new algo in
QUANT_ALGOS_TO_TEST / QUANT_ALGOS, extend the MXFP4 weight-prep path
in prepare_weights_from_backend, and forward swiglu_gptoss_style /
parallel_mode to should_skip_trtllm in the multi-GPU param generator.
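
The static per-tensor FP8 activation round-trip that the new reference emulates can be illustrated in pure Python. This is a sketch under stated assumptions, not the code in quantize_utils.py: it assumes the E4M3 format (3 mantissa bits, maximum magnitude 448), saturating conversion, and an input scale defined as amax / 448 that is fixed before the forward pass. The helper names `round_to_e4m3`, `static_input_scale`, and `fp8_round_trip` are hypothetical.

```python
import math
from typing import List

E4M3_MAX = 448.0  # largest finite magnitude representable in FP8 E4M3


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest E4M3 value (ties to even), saturating at +/-448."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    a = min(abs(x), E4M3_MAX)
    _, e = math.frexp(a)     # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)     # -6 is the smallest normal exponent; below it, subnormals
    step = 2.0 ** (exp - 3)  # 3 mantissa bits give 8 steps per binade
    return sign * round(a / step) * step


def static_input_scale(calibration: List[float]) -> float:
    """Per-tensor scale fixed ahead of time (assumed convention: amax / 448)."""
    return max(abs(v) for v in calibration) / E4M3_MAX


def fp8_round_trip(values: List[float], input_scale: float) -> List[float]:
    """Quantize to FP8 with a static scale, then dequantize back."""
    return [round_to_e4m3(v / input_scale) * input_scale for v in values]


scale = static_input_scale([-2.0, 4.0])      # covers the range [-4, 4]
in_range = fp8_round_trip([1.0], scale)      # small rounding error only
saturated = fp8_round_trip([8.0], scale)     # clamps to about 4.0
```

Because the scale is static, any activation outside the range it was derived from saturates instead of being rescaled, which is the behaviour a dynamic per-tensor scale would hide.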
Documented kernel-bug skips (each tied to an NVBug; remove once fixed)
- NVBUG-6178914 - W4A8_MXFP4_FP8 + swiglu_gptoss_style on TRTLLMGen
produces output ~600-800x smaller than the bf16 reference across all
top_k / shapes / parallel modes; same kernel matches reference under
default SwiGLU and CUTLASS matches reference under gpt-oss SwiGLU,
isolating the bug to the TRTLLM-Gen gpt-oss SwiGLU epilogue scale
path (10 cases).
- NVBUG-6178915 - W4A8_MXFP4_FP8 + parallel_mode==DEP +
NVLINK_ONE_SIDED / NVLINK_TWO_SIDED MoeAllReduce produces ~97%
mismatch under default SwiGLU; same kernel + quant passes under
TTP+NVLINK and DEP+DEEPEP / DEP+DEEPEPLOWLATENCY, isolating the bug
to the DEP+NVLINK MoeAllReduce combine path consuming the FP8
per-tensor-scale output (4 cases). Excludes the gpt-oss subset
already covered by NVBUG-6178914.
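
The two skips above can be sketched as a single predicate that returns the NVBug-tagged reason. This is an illustration only: the function name `kernel_bug_skip_reason`, its signature, and the string values are hypothetical and are not the signature of should_skip_trtllm. It assumes the caller has already selected the TRTLLM-Gen backend.

```python
from typing import Optional

# Communication methods named in NVBUG-6178915 (string stand-ins, assumed).
NVLINK_METHODS = {"NVLINK_ONE_SIDED", "NVLINK_TWO_SIDED"}


def kernel_bug_skip_reason(quant_algo: str,
                           swiglu_gptoss_style: bool,
                           parallel_mode: str,
                           comm_method: Optional[str] = None) -> Optional[str]:
    """Return the NVBug-tagged skip reason for a TRTLLM-Gen case, else None."""
    if quant_algo != "W4A8_MXFP4_FP8":
        return None
    # Checked first, so the gpt-oss subset is never attributed to NVBUG-6178915.
    if swiglu_gptoss_style:
        return "NVBUG-6178914: gpt-oss SwiGLU epilogue scale path"
    # parallel_mode is needed because DEP and TEP both have moe_tp_size=1.
    if parallel_mode == "DEP" and comm_method in NVLINK_METHODS:
        return "NVBUG-6178915: DEP+NVLINK MoeAllReduce combine path"
    return None


reason_gptoss = kernel_bug_skip_reason("W4A8_MXFP4_FP8", True, "DEP", "NVLINK_ONE_SIDED")
reason_dep = kernel_bug_skip_reason("W4A8_MXFP4_FP8", False, "DEP", "NVLINK_TWO_SIDED")
reason_deepep = kernel_bug_skip_reason("W4A8_MXFP4_FP8", False, "DEP", "DEEPEP")
```

Returning the bug ID in the reason string makes each skip easy to find and remove once the corresponding NVBug is fixed.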
Verified on GB200 (OCI): backend test 4 passed / 230 deselected;
module test 16 skipped / 634 deselected / 0 failed (subsequent runs
under the new skips: 0 failed across single + multi-GPU matrices).
Signed-off-by: xxi <[email protected]>
Add W4A8_MXFP4_FP8 coverage to both test_moe_backend.py and test_moe_module.py, supporting the CUTLASS and TRTLLM-gen fused MoE methods.
Key pieces
Documented skips (each with rationale in code)
Verified on GB200 (OCI): backend test 4 passed / 230 deselected; module test 16 skipped / 634 deselected / 0 failed.