Skip to content

[TRTLLM-11091][feat] Add tunable nvfp4 quantize with additional FlashInfer backend#12126

Merged
luyiyun1021 merged 10 commits into
NVIDIA:mainfrom
chang-l:feat/tunable-fp4-quantize-flashinfer
Apr 14, 2026
Merged

[TRTLLM-11091][feat] Add tunable nvfp4 quantize with additional FlashInfer backend#12126
luyiyun1021 merged 10 commits into
NVIDIA:mainfrom
chang-l:feat/tunable-fp4-quantize-flashinfer

Conversation

@chang-l
Copy link
Copy Markdown
Collaborator

@chang-l chang-l commented Mar 11, 2026

Summary

  • Add Fp4QuantKernelRunner and trtllm::tunable_fp4_quantize custom op that autoselects between TRTLLM CUDA and FlashInfer FP4 quantization kernels based on profiled performance
  • Update NVFP4LinearMethod._input_prepare to use the tunable op for activation quantization
  • Add unit tests verifying bitwise-identical output between backends and correct behavior under autotune context

Details

The design follows the existing Fp8QuantKernelRunner pattern:

  • During autotune warmup, both TRTLLM and FlashInfer FP4 quantize kernels are profiled for actual input shapes
  • The AutoTuner caches the fastest backend per shape with zero inference-time overhead
  • Falls back to TRTLLM when FlashInfer is unavailable or outside autotune context
  • Includes register_fake for torch.compile compatibility

The visual_gen pipeline benefits automatically since its warmup already runs within autotune() context (pipeline_loader.py L229), which will profile the FP4 quantize kernel for NVFP4-quantized diffusion models.

Test plan

  • pytest tests/unittest/_torch/thop/parallel/test_fp4_quantize_flashinfer.py -v — bitwise output comparison + tunable op + autotune context tests
  • Existing NVFP4 tests should pass unchanged (backward compatible — defaults to TRTLLM without autotune)
  • Visual gen NVFP4 pipeline warmup exercises the new tunable path

🤖 Generated with Claude Code

Summary by CodeRabbit

Release Notes

  • New Features

    • FP4 quantization now supports automatic backend selection between TRTLLM and FlashInfer backends for optimized performance.
    • Introduced tunable FP4 quantization framework that dynamically selects the best-performing backend at runtime.
  • Tests

    • Added comprehensive test suite validating FP4 quantization across multiple shapes, data types, and backends.

@chang-l chang-l requested review from a team as code owners March 11, 2026 23:23
@chang-l chang-l requested review from HuiGao-NV and yizhang-nv March 11, 2026 23:23
@chang-l chang-l changed the title [None][feat] Add tunable FP4 quantize with FlashInfer backend [TRTLLM-11091][feat] Add tunable FP4 quantize with FlashInfer backend Mar 11, 2026
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Mar 11, 2026

📝 Walkthrough

Walkthrough

The PR introduces a tunable FP4 quantization framework that automatically selects between TRTLLM and FlashInfer backends. It adds conditional dispatch logic, a kernel runner for profiling and tactic selection via AutoTuner, updates the public interface to use the tunable variant, and includes comprehensive test coverage validating outputs across both backends.

Changes

Cohort / File(s) Summary
FP4 Quantization Tunable Framework
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py, tensorrt_llm/_torch/modules/linear.py
Introduces conditional FP4 quantization dispatch between TRTLLM and FlashInfer, adds Fp4QuantKernelRunner for backend profiling and tactic selection, exposes tunable_fp4_quantize custom op with AutoTuner integration, provides fake implementation for torch.compile compatibility, and updates linear.py to use the new tunable variant.
FP4 Quantization Test Suite
tests/unittest/_torch/thop/parallel/test_fp4_quantize_flashinfer.py
New test module validating FP4 quantization outputs across TRTLLM and FlashInfer backends with parameterized shapes and dtypes, verifying exact equality of packed outputs, testing tunable_fp4_quantize operation, and asserting autotune context compatibility with graceful fallback if FlashInfer unavailable.

Sequence Diagram

sequenceDiagram
    participant Input as Input Tensor
    participant Tunable as tunable_fp4_quantize
    participant AutoTuner as AutoTuner
    participant Fp4Runner as Fp4QuantKernelRunner
    participant Dispatch as _fp4_quantize_dispatch
    participant TRTLLM as TRTLLM Backend
    participant FlashInfer as FlashInfer Backend
    participant Output as Output (Quantized)

    Input->>Tunable: input, scale, vector_size, flag
    Tunable->>AutoTuner: Select best tactic
    AutoTuner->>Fp4Runner: Profile backends
    Fp4Runner->>TRTLLM: Benchmark TRTLLM
    Fp4Runner->>FlashInfer: Benchmark FlashInfer (if available)
    Fp4Runner-->>AutoTuner: Return recommended tactic
    AutoTuner-->>Tunable: Tactic selected
    Tunable->>Dispatch: Dispatch with tactic
    alt FlashInfer tactic selected
        Dispatch->>FlashInfer: Execute quantization
        FlashInfer-->>Dispatch: Result (FlashInfer format)
        Dispatch->>Dispatch: Reshape to TRTLLM format
    else TRTLLM tactic selected or fallback
        Dispatch->>TRTLLM: Execute quantization
        TRTLLM-->>Dispatch: Result (TRTLLM format)
    end
    Dispatch-->>Output: Quantized tensor
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 61.54% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description check ✅ Passed The PR description is well-structured, clearly explaining the feature, design rationale, and test plan with all required template sections.
Title check ✅ Passed The title accurately and specifically describes the main change: adding a tunable FP4 quantization operator that supports both TRTLLM and FlashInfer backends.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/unittest/_torch/thop/parallel/test_fp4_quantize_flashinfer.py`:
- Around line 121-137: The tests call torch.ops.trtllm.tunable_fp4_quantize and
unpack act_fp4, act_sf but never validate act_sf; add assertions that act_sf
matches the reference ref_sf (from torch.ops.trtllm.fp4_quantize) including
shape, dtype/torch tensor type and element-wise equality (use
torch.testing.assert_close with atol=0, rtol=0 or appropriate tensor equality)
so regressions in scale values/layout are caught; update both occurrences (the
block using tunable_fp4_quantize and the similar block at lines 156-186) to
assert shape and value equality for act_sf vs ref_sf, referencing the
act_sf/ref_sf variables and ensuring consistency with
NVFP4LinearMethod._input_prepare forwarding into
nvfp4_gemm/nvfp4_gemm_allreduce.
- Around line 93-103: The test currently compares only a shared prefix of scale
factor tensors, allowing mismatched lengths to pass; change the check to first
assert the flattened shapes are equal (compare trtllm_sf_flat.shape ==
fi_sf_flat.shape) and then call torch.testing.assert_close on the entire tensors
(trtllm_sf_flat and fi_sf_flat) with atol=0 and rtol=0 so the full scale tensor
is validated rather than just a prefix; update the error message to mention both
shape and length mismatches and use the existing variable names trtllm_sf_flat
and fi_sf_flat to locate the assertion.
- Around line 34-35: The class-level `@pytest.mark.skipif` on
TestFp4QuantizeFlashinfer is hiding the TRTLLM fallback paths; remove the
class-level skip and instead apply the skip only to the tests that truly require
FlashInfer: add `@pytest.mark.skipif`(not HAS_FLASHINFER, reason="flashinfer not
available") to the test_tunable_fp4_quantize_op method and to the
FlashInfer-only branch inside test_tunable_fp4_quantize_with_autotune (or split
that branch into a separate test decorated with the skip); keep
TestFp4QuantizeFlashinfer as a normal unittest.TestCase so the default-to-TRTLLM
behavior (tested by test_tunable_fp4_quantize_with_autotune and its
non-FlashInfer branch) runs when FlashInfer is absent.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 7d97e1c3-5f72-4dc3-ac2f-1522a05fec91

📥 Commits

Reviewing files that changed from the base of the PR and between 7479423 and d7e7372.

📒 Files selected for processing (3)
  • tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
  • tensorrt_llm/_torch/modules/linear.py
  • tests/unittest/_torch/thop/parallel/test_fp4_quantize_flashinfer.py

Comment thread tests/unittest/_torch/thop/parallel/test_fp4_quantize_flashinfer.py Outdated
Comment thread tests/unittest/_torch/thop/parallel/test_fp4_quantize_flashinfer.py Outdated
Comment thread tests/unittest/_torch/thop/parallel/test_fp4_quantize_flashinfer.py
@chang-l
Copy link
Copy Markdown
Collaborator Author

chang-l commented Mar 12, 2026

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #38645 [ run ] triggered by Bot. Commit: 88ac6eb Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #38645 [ run ] completed with state SUCCESS. Commit: 88ac6eb
/LLM/main/L0_MergeRequest_PR pipeline #29975 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@chang-l
Copy link
Copy Markdown
Collaborator Author

chang-l commented Mar 12, 2026

/bot run

@chang-l chang-l changed the title [TRTLLM-11091][feat] Add tunable FP4 quantize with FlashInfer backend [TRTLLM-11091][feat] Add tunable nvfp4 quantize with additional FlashInfer backend Mar 12, 2026
@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #38670 [ run ] triggered by Bot. Commit: 2a31200 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #38670 [ run ] completed with state SUCCESS. Commit: 2a31200
/LLM/main/L0_MergeRequest_PR pipeline #29994 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@chang-l
Copy link
Copy Markdown
Collaborator Author

chang-l commented Mar 13, 2026

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #38885 [ run ] triggered by Bot. Commit: 2a31200 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #38885 [ run ] completed with state SUCCESS. Commit: 2a31200
/LLM/main/L0_MergeRequest_PR pipeline #30193 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@chang-l
Copy link
Copy Markdown
Collaborator Author

chang-l commented Mar 13, 2026

/bot run

1 similar comment
@chang-l
Copy link
Copy Markdown
Collaborator Author

chang-l commented Mar 13, 2026

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #38893 [ run ] triggered by Bot. Commit: 2a31200 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #38893 [ run ] completed with state FAILURE. Commit: 2a31200
/LLM/main/L0_MergeRequest_PR pipeline #30201 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@chang-l
Copy link
Copy Markdown
Collaborator Author

chang-l commented Mar 13, 2026

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #38900 [ run ] triggered by Bot. Commit: 2a31200 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #38900 [ run ] completed with state FAILURE. Commit: 2a31200
/LLM/main/L0_MergeRequest_PR pipeline #30207 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@chang-l
Copy link
Copy Markdown
Collaborator Author

chang-l commented Mar 13, 2026

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #38910 [ run ] triggered by Bot. Commit: 2a31200 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #38910 [ run ] completed with state FAILURE. Commit: 2a31200
/LLM/main/L0_MergeRequest_PR pipeline #30218 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@chang-l
Copy link
Copy Markdown
Collaborator Author

chang-l commented Mar 16, 2026

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #39123 [ run ] triggered by Bot. Commit: c508220 Link to invocation

@luyiyun1021
Copy link
Copy Markdown
Collaborator

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #43011 [ run ] triggered by Bot. Commit: 89aa60c Link to invocation

@luyiyun1021
Copy link
Copy Markdown
Collaborator

/bot run --disable-fail-fast

@luyiyun1021 luyiyun1021 requested a review from QiJune April 13, 2026 09:49
@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #43016 [ run ] triggered by Bot. Commit: 89aa60c Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #43011 [ run ] completed with state ABORTED. Commit: 89aa60c

Link to invocation

Comment thread tensorrt_llm/_torch/visual_gen/config.py
@chang-l
Copy link
Copy Markdown
Collaborator Author

chang-l commented Apr 13, 2026

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #43083 [ run ] triggered by Bot. Commit: 89aa60c Link to invocation

@chang-l
Copy link
Copy Markdown
Collaborator Author

chang-l commented Apr 13, 2026

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #43112 [ run ] triggered by Bot. Commit: 89aa60c Link to invocation

@chang-l
Copy link
Copy Markdown
Collaborator Author

chang-l commented Apr 14, 2026

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #43134 [ run ] triggered by Bot. Commit: 89aa60c Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #43134 [ run ] completed with state SUCCESS. Commit: 89aa60c
/LLM/main/L0_MergeRequest_PR pipeline #33767 completed with status: 'SUCCESS'

CI Report

Link to invocation

Comment thread tensorrt_llm/_torch/modules/linear.py
@luyiyun1021
Copy link
Copy Markdown
Collaborator

/bot run --disable-fail-fast

@zhenhuaw-me
Copy link
Copy Markdown
Member

/bot reuse-pipeline

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #43179 [ run ] triggered by Bot. Commit: aaee20b Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

@luyiyun1021
Copy link
Copy Markdown
Collaborator

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #43180 [ reuse-pipeline ] triggered by Bot. Commit: 96d4942 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #43184 [ run ] triggered by Bot. Commit: 96d4942 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #43180 [ reuse-pipeline ] completed with state SUCCESS. Commit: 96d4942
Reusing PR_Github #43134 for commit 96d4942

Link to invocation

@luyiyun1021 luyiyun1021 merged commit 1480140 into NVIDIA:main Apr 14, 2026
5 checks passed
chienchunhung pushed a commit to chienchunhung/TensorRT-LLM that referenced this pull request Apr 16, 2026
…Infer backend (NVIDIA#12126)

Signed-off-by: Chang Liu <[email protected]>
Signed-off-by: Yiyun Lu <[email protected]>
Signed-off-by: Zhenhua Wang <[email protected]>
Co-authored-by: Yiyun Lu <[email protected]>
Co-authored-by: Zhenhua Wang <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants