[TRTLLM-11091][feat] Add tunable nvfp4 quantize with additional FlashInfer backend#12126
📝 Walkthrough
The PR introduces a tunable FP4 quantization framework that automatically selects between TRTLLM and FlashInfer backends. It adds conditional dispatch logic, a kernel runner for profiling and tactic selection via AutoTuner, updates the public interface to use the tunable variant, and includes test coverage validating outputs across both backends.
Sequence Diagram
sequenceDiagram
participant Input as Input Tensor
participant Tunable as tunable_fp4_quantize
participant AutoTuner as AutoTuner
participant Fp4Runner as Fp4QuantKernelRunner
participant Dispatch as _fp4_quantize_dispatch
participant TRTLLM as TRTLLM Backend
participant FlashInfer as FlashInfer Backend
participant Output as Output (Quantized)
Input->>Tunable: input, scale, vector_size, flag
Tunable->>AutoTuner: Select best tactic
AutoTuner->>Fp4Runner: Profile backends
Fp4Runner->>TRTLLM: Benchmark TRTLLM
Fp4Runner->>FlashInfer: Benchmark FlashInfer (if available)
Fp4Runner-->>AutoTuner: Return recommended tactic
AutoTuner-->>Tunable: Tactic selected
Tunable->>Dispatch: Dispatch with tactic
alt FlashInfer tactic selected
Dispatch->>FlashInfer: Execute quantization
FlashInfer-->>Dispatch: Result (FlashInfer format)
Dispatch->>Dispatch: Reshape to TRTLLM format
else TRTLLM tactic selected or fallback
Dispatch->>TRTLLM: Execute quantization
TRTLLM-->>Dispatch: Result (TRTLLM format)
end
Dispatch-->>Output: Quantized tensor
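The dispatch step in the diagram can be sketched in plain Python. This is a minimal illustration, not the PR's implementation: `Tactic`, `trtllm_backend`, `flashinfer_backend`, `fp4_quantize_dispatch`, and the toy integer quantizer are all hypothetical stand-ins, and the "FlashInfer format" is assumed here to be a nested scale-factor layout that gets flattened to match the TRTLLM result.

```python
from enum import Enum
from typing import Callable, List, Optional, Tuple

FP4_MAX = 6.0  # largest magnitude representable in FP4 (E2M1)

# (quantized codes, flat per-block scale factors): a toy stand-in for tensors
QuantResult = Tuple[List[int], List[float]]


class Tactic(Enum):
    TRTLLM = 0
    FLASHINFER = 1


def _toy_quantize(values: List[float], vector_size: int) -> QuantResult:
    """Toy block quantizer: one scale per block of `vector_size` values."""
    codes: List[int] = []
    scales: List[float] = []
    for start in range(0, len(values), vector_size):
        block = values[start:start + vector_size]
        # An all-zero block would give scale 0.0, so fall back to 1.0.
        scale = max(abs(v) for v in block) / FP4_MAX or 1.0
        scales.append(scale)
        codes.extend(round(v / scale) for v in block)
    return codes, scales


def trtllm_backend(values: List[float], vector_size: int) -> QuantResult:
    """Hypothetical stand-in for the TRTLLM kernel (flat scale layout)."""
    return _toy_quantize(values, vector_size)


def flashinfer_backend(values: List[float], vector_size: int):
    """Hypothetical stand-in for FlashInfer (assumed nested scale layout)."""
    codes, scales = _toy_quantize(values, vector_size)
    return codes, [[s] for s in scales]


def fp4_quantize_dispatch(
    values: List[float],
    vector_size: int,
    tactic: Tactic,
    flashinfer: Optional[Callable] = None,
) -> QuantResult:
    """Run the selected backend; fall back to TRTLLM if FlashInfer is absent."""
    if tactic is Tactic.FLASHINFER and flashinfer is not None:
        codes, nested = flashinfer(values, vector_size)
        # Reshape the FlashInfer-style result into the TRTLLM-style layout.
        return codes, [s for row in nested for s in row]
    return trtllm_backend(values, vector_size)
```

Callers always receive the same layout regardless of which tactic ran, which is what lets the tunable op replace the original one without changes downstream.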
Estimated code review effort
🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/unittest/_torch/thop/parallel/test_fp4_quantize_flashinfer.py`:
- Around line 121-137: The tests call torch.ops.trtllm.tunable_fp4_quantize and
unpack act_fp4, act_sf but never validate act_sf; add assertions that act_sf
matches the reference ref_sf (from torch.ops.trtllm.fp4_quantize) including
shape, dtype/torch tensor type and element-wise equality (use
torch.testing.assert_close with atol=0, rtol=0 or appropriate tensor equality)
so regressions in scale values/layout are caught; update both occurrences (the
block using tunable_fp4_quantize and the similar block at lines 156-186) to
assert shape and value equality for act_sf vs ref_sf, referencing the
act_sf/ref_sf variables and ensuring consistency with
NVFP4LinearMethod._input_prepare forwarding into
nvfp4_gemm/nvfp4_gemm_allreduce.
- Around line 93-103: The test currently compares only a shared prefix of scale
factor tensors, allowing mismatched lengths to pass; change the check to first
assert the flattened shapes are equal (compare trtllm_sf_flat.shape ==
fi_sf_flat.shape) and then call torch.testing.assert_close on the entire tensors
(trtllm_sf_flat and fi_sf_flat) with atol=0 and rtol=0 so the full scale tensor
is validated rather than just a prefix; update the error message to mention both
shape and length mismatches and use the existing variable names trtllm_sf_flat
and fi_sf_flat to locate the assertion.
- Around line 34-35: The class-level `@pytest.mark.skipif` on
TestFp4QuantizeFlashinfer is hiding the TRTLLM fallback paths; remove the
class-level skip and instead apply the skip only to the tests that truly require
FlashInfer: add `@pytest.mark.skipif(not HAS_FLASHINFER, reason="flashinfer not
available")` to the test_tunable_fp4_quantize_op method and to the
FlashInfer-only branch inside test_tunable_fp4_quantize_with_autotune (or split
that branch into a separate test decorated with the skip); keep
TestFp4QuantizeFlashinfer as a normal unittest.TestCase so the default-to-TRTLLM
behavior (tested by test_tunable_fp4_quantize_with_autotune and its
non-FlashInfer branch) runs when FlashInfer is absent.
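Two of the findings above (validating the whole scale tensor rather than a shared prefix, and skipping per test instead of per class) can be sketched without any GPU dependency. This is an illustration only: `assert_same_scales` and `TestFp4QuantizeSketch` are hypothetical names, plain lists stand in for tensors, and the standard-library `unittest.skipUnless` is used in place of `pytest.mark.skipif`.

```python
import importlib.util
import unittest

# True only when the optional flashinfer package can be imported.
HAS_FLASHINFER = importlib.util.find_spec("flashinfer") is not None


def assert_same_scales(ref_sf, act_sf):
    """Compare full scale-factor sequences: length first, then every element."""
    if len(ref_sf) != len(act_sf):
        raise AssertionError(
            f"scale-factor length mismatch: {len(ref_sf)} vs {len(act_sf)}")
    for i, (ref, act) in enumerate(zip(ref_sf, act_sf)):
        if ref != act:  # exact match, mirroring atol=0 and rtol=0
            raise AssertionError(f"scale factor {i} differs: {ref} vs {act}")


class TestFp4QuantizeSketch(unittest.TestCase):
    def test_default_to_trtllm(self):
        # Runs everywhere, so the fallback path stays covered.
        assert_same_scales([1.0, 0.5], [1.0, 0.5])

    @unittest.skipUnless(HAS_FLASHINFER, "flashinfer not available")
    def test_flashinfer_path(self):
        # Only this test depends on the optional backend.
        import flashinfer  # noqa: F401
```

A prefix-only check such as `all(r == a for r, a in zip(ref_sf, act_sf))` returns true for `[1.0, 0.5]` against `[1.0, 0.5, 0.25]`, because `zip` stops at the shorter input. The explicit length check is what catches that case.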
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 7d97e1c3-5f72-4dc3-ac2f-1522a05fec91
📒 Files selected for processing (3)
- tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
- tensorrt_llm/_torch/modules/linear.py
- tests/unittest/_torch/thop/parallel/test_fp4_quantize_flashinfer.py
/bot run
PR_Github #38645 [ run ] triggered by Bot. Commit:
PR_Github #38645 [ run ] completed with state
Signed-off-by: Zhenhua Wang <[email protected]>
Signed-off-by: Yiyun Lu <[email protected]>
…Infer backend (NVIDIA#12126)
Signed-off-by: Chang Liu <[email protected]>
Signed-off-by: Yiyun Lu <[email protected]>
Signed-off-by: Zhenhua Wang <[email protected]>
Co-authored-by: Yiyun Lu <[email protected]>
Co-authored-by: Zhenhua Wang <[email protected]>
Summary
- Add `Fp4QuantKernelRunner` and the `trtllm::tunable_fp4_quantize` custom op, which autoselects between the TRTLLM CUDA and FlashInfer FP4 quantization kernels based on profiled performance
- Update `NVFP4LinearMethod._input_prepare` to use the tunable op for activation quantization
Details
The design follows the existing `Fp8QuantKernelRunner` pattern:
- `register_fake` for `torch.compile` compatibility
The visual_gen pipeline benefits automatically since its warmup already runs within the `autotune()` context (pipeline_loader.py L229), which will profile the FP4 quantize kernel for NVFP4-quantized diffusion models.
Test plan
- `pytest tests/unittest/_torch/thop/parallel/test_fp4_quantize_flashinfer.py -v`: bitwise output comparison + tunable op + autotune context tests
🤖 Generated with Claude Code
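The profile-then-cache idea behind the tunable op can be sketched with the standard library alone. This is a simplified model, not TensorRT-LLM's `AutoTuner`: `ToyTuner`, `choose`, and this `autotune` context manager are hypothetical, and the assumed behavior is that candidates are timed only inside the tuning context, the fastest is cached per key, and a default is returned on a cache miss outside that context.

```python
import time
from contextlib import contextmanager
from typing import Callable, Dict, Hashable


class ToyTuner:
    """Times candidate kernels and remembers the fastest one per key."""

    def __init__(self) -> None:
        self.cache: Dict[Hashable, str] = {}
        self.tuning = False

    def choose(
        self,
        key: Hashable,
        candidates: Dict[str, Callable[[], object]],
        default: str,
        repeats: int = 3,
    ) -> str:
        if key in self.cache:
            return self.cache[key]  # reuse an earlier profiling result
        if not self.tuning:
            return default  # no profiling outside the tuning context
        timings: Dict[str, float] = {}
        for name, run in candidates.items():
            run()  # warm-up call, excluded from the measurement
            start = time.perf_counter()
            for _ in range(repeats):
                run()
            timings[name] = time.perf_counter() - start
        best = min(timings, key=timings.get)
        self.cache[key] = best
        return best


@contextmanager
def autotune(tuner: ToyTuner):
    """Enable profiling for the duration of the block (e.g. a warmup pass)."""
    tuner.tuning = True
    try:
        yield tuner
    finally:
        tuner.tuning = False
```

In this model a warmup pass that runs inside `autotune(tuner)` populates the cache once, and later calls with the same key reuse the chosen tactic without re-profiling.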