[None][perf] Integrate the flashinfer gdn prefill kernel for qwen3.5#13644
Conversation
7fc33eb to
e6f5624
Compare
|
/bot run --disable-fail-fast |
e6f5624 to
f73c47a
Compare
|
PR_Github #47115 [ run ] triggered by Bot. Commit: |
📝 WalkthroughWalkthroughThis PR introduces a FlashInfer-backed adapter for the Gated Delta Net (GDN) chunk attention operator. A new ChangesFlashInfer GDN Adapter and Integration
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes 🚥 Pre-merge checks | ✅ 3 | ❌ 2❌ Failed checks (2 warnings)
✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Tip 💬 Introducing Slack Agent: The best way for teams to turn conversations into code.Slack Agent is built on CodeRabbit's deep understanding of your code, so your team can collaborate across the entire SDLC without losing context.
Built for teams:
One agent for your entire SDLC. Right inside Slack. Comment |
There was a problem hiding this comment.
🧹 Nitpick comments (5)
tests/unittest/_torch/modules/mamba/test_flashinfer_chunk_gdn.py (3)
10-12: 💤 Low valueDrop
from __future__and use the Python 3.10+listbuilt-in.
from typing import Listshould be replaced by the built-inlisttype, andfrom __future__ import annotationsis unnecessary.♻️ Suggested change
-from __future__ import annotations - -from typing import List - import pytest import torch`@torch.no_grad`() def _make_inputs( - seq_lens: List[int], + seq_lens: list[int],Based on learnings: Python 3.10+ is required throughout the codebase and
from __future__ import annotationsis not needed. As per coding guidelines: "Prefer using built-in typeslist,dict,tupleinstead of legacytyping.List."Also applies to: 39-39
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/modules/mamba/test_flashinfer_chunk_gdn.py` around lines 10 - 12, Remove the unnecessary future import line "from __future__ import annotations" and replace any use of the typing alias "List" with the built-in "list" type; specifically delete the import "from typing import List" and update all type annotations in this test (including the other occurrence around the original line 39) from "List[...]" to "list[...]" so the file uses Python 3.10+ built-ins and no future import.
1-30: Missing perf test coverage for a[perf]-tagged kernel change.This PR swaps the prefill attention kernel path (Triton → FlashInfer) for Qwen3.5 GDN, which is explicitly performance-sensitive. The added tests are all unit/parity tests and do not assert any throughput or latency improvement. Per QA guidelines for PRs touching attention kernels, please verify:
- Is there an entry in
tests/integration/test_lists/test-db/l0_perf.yml(or the appropriate per-GPUl0_*.yml) that will catch a FlashInfer GDN prefill regression pre-merge?- If no such entry exists, consider adding a perf test in
tests/integration/defs/perf/test_perf_sanity.pyto establish a latency baseline for Qwen3.5 prefill underTLLM_USE_FLASHINFER_GDN_PREFILL=1vs=0.QA list updates to
llm_function_core.txtare not required for these unit tests alone, but the absence of any performance assertion means a future regression in the FlashInfer path would not be caught in CI.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/modules/mamba/test_flashinfer_chunk_gdn.py` around lines 1 - 30, Add perf coverage for the FlashInfer GDN prefill path: either add an entry for the new FlashInfer-prefill case to the appropriate L0 perf test list (tests/integration/test_lists/test-db/l0_perf.yml or per-GPU l0_*.yml) that will run with TLLM_USE_FLASHINFER_GDN_PREFILL=1, or add a simple latency baseline test in tests/integration/defs/perf/test_perf_sanity.py that measures Qwen3.5 prefill latency for Qwen3NextGatedDeltaNet.forward_extend with TLLM_USE_FLASHINFER_GDN_PREFILL toggled between 1 and 0; ensure the new perf test targets the same input shapes exercised by the unit tests so regressions in the FlashInfer prefill path are caught in CI.
60-65: 💤 Low valueFix ruff RUF005: prefer iterable unpacking over list concatenation.
♻️ Suggested fix
cu = torch.tensor( - [0] + list(torch.tensor(seq_lens).cumsum(0).tolist()), + [0, *torch.tensor(seq_lens).cumsum(0).tolist()], dtype=torch.int64, device=device, )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/modules/mamba/test_flashinfer_chunk_gdn.py` around lines 60 - 65, Replace the list concatenation used to build cu with iterable unpacking: instead of torch.tensor([0] + list(torch.tensor(seq_lens).cumsum(0).tolist()), ...), construct cu via torch.tensor((0, *torch.tensor(seq_lens).cumsum(0).tolist()), dtype=torch.int64, device=device). Update the expression that creates cu (the variable returned alongside q, k, v, g, beta) to use the tuple unpacking form to satisfy RUF005.tensorrt_llm/_torch/modules/fla/flashinfer_chunk.py (2)
35-35: 💤 Low valueUse Python 3.10+ built-in type syntax per coding guidelines.
from typing import Optional, TupleandTuple[torch.Tensor, Optional[torch.Tensor]]should use the modern style. Thefrom __future__ import annotationsimport is also redundant for Python 3.10+.♻️ Suggested change
-from __future__ import annotations - -from typing import Optional, Tuple - import torch from tensorrt_llm._torch.modules.fla.l2norm import l2norm_fwddef chunk_gated_delta_rule( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, beta: torch.Tensor, - scale: Optional[float] = None, - initial_state: Optional[torch.Tensor] = None, - initial_state_indices: Optional[torch.Tensor] = None, + scale: float | None = None, + initial_state: torch.Tensor | None = None, + initial_state_indices: torch.Tensor | None = None, inplace_indexed_state_update: bool = False, output_final_state: bool = False, - cu_seqlens: Optional[torch.Tensor] = None, + cu_seqlens: torch.Tensor | None = None, head_first: bool = False, use_qk_l2norm_in_kernel: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, torch.Tensor | None]:As per coding guidelines: "Prefer using built-in types
list,dict,tupleinstead of legacytyping.List,typing.Dict,typing.Tuple; use|syntax instead oftyping.Union."Also applies to: 61-61
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/modules/fla/flashinfer_chunk.py` at line 35, Replace legacy typing usage and redundant future import: remove "from typing import Optional, Tuple" (and any "from __future__ import annotations") and update type annotations that use Tuple[...] and Optional[...] to modern Python 3.10+ syntax, e.g., change "Tuple[torch.Tensor, Optional[torch.Tensor]]" to "tuple[torch.Tensor, torch.Tensor | None]" (or use "torch.Tensor | None" for optional parts). Update all occurrences (e.g., the annotation referenced around the function/method using that tuple return type) to use built-in "tuple" and the "|" union operator.
133-134: 💤 Low value
state_bufpre-allocation assumesD_k == D_v.
head_size = q3.shape[2]captures the key head dimension (D_k), but the last two dims of FlashInfer's state are(D_v, D_k). Usinghead_sizefor both silently produces a wrong buffer shape whenD_k ≠ D_v.For Qwen3.5
D_k == D_v == 128, so this is currently safe — but it's a silent fragility worth addressing:♻️ Suggested fix
+v_head_size = v3.shape[2] head_size = q3.shape[2] num_seqs = cu_seqlens.shape[0] - 1 output_buf = q3.new_empty(total_seq_len, num_o_heads, v_head_size) -state_buf = q3.new_empty(num_seqs, num_o_heads, head_size, head_size, dtype=torch.float32) +state_buf = q3.new_empty(num_seqs, num_o_heads, v_head_size, head_size, dtype=torch.float32)
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@tensorrt_llm/_torch/modules/fla/flashinfer_chunk.py`:
- Line 35: Replace legacy typing usage and redundant future import: remove "from
typing import Optional, Tuple" (and any "from __future__ import annotations")
and update type annotations that use Tuple[...] and Optional[...] to modern
Python 3.10+ syntax, e.g., change "Tuple[torch.Tensor, Optional[torch.Tensor]]"
to "tuple[torch.Tensor, torch.Tensor | None]" (or use "torch.Tensor | None" for
optional parts). Update all occurrences (e.g., the annotation referenced around
the function/method using that tuple return type) to use built-in "tuple" and
the "|" union operator.
In `@tests/unittest/_torch/modules/mamba/test_flashinfer_chunk_gdn.py`:
- Around line 10-12: Remove the unnecessary future import line "from __future__
import annotations" and replace any use of the typing alias "List" with the
built-in "list" type; specifically delete the import "from typing import List"
and update all type annotations in this test (including the other occurrence
around the original line 39) from "List[...]" to "list[...]" so the file uses
Python 3.10+ built-ins and no future import.
- Around line 1-30: Add perf coverage for the FlashInfer GDN prefill path:
either add an entry for the new FlashInfer-prefill case to the appropriate L0
perf test list (tests/integration/test_lists/test-db/l0_perf.yml or per-GPU
l0_*.yml) that will run with TLLM_USE_FLASHINFER_GDN_PREFILL=1, or add a simple
latency baseline test in tests/integration/defs/perf/test_perf_sanity.py that
measures Qwen3.5 prefill latency for Qwen3NextGatedDeltaNet.forward_extend with
TLLM_USE_FLASHINFER_GDN_PREFILL toggled between 1 and 0; ensure the new perf
test targets the same input shapes exercised by the unit tests so regressions in
the FlashInfer prefill path are caught in CI.
- Around line 60-65: Replace the list concatenation used to build cu with
iterable unpacking: instead of torch.tensor([0] +
list(torch.tensor(seq_lens).cumsum(0).tolist()), ...), construct cu via
torch.tensor((0, *torch.tensor(seq_lens).cumsum(0).tolist()), dtype=torch.int64,
device=device). Update the expression that creates cu (the variable returned
alongside q, k, v, g, beta) to use the tuple unpacking form to satisfy RUF005.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 5b4de1b8-1c30-4d25-8206-5bf3a3e7ba90
📒 Files selected for processing (3)
tensorrt_llm/_torch/modules/fla/flashinfer_chunk.pytensorrt_llm/_torch/modules/mamba/gdn_mixer.pytests/unittest/_torch/modules/mamba/test_flashinfer_chunk_gdn.py
|
PR_Github #47115 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #47203 [ run ] triggered by Bot. Commit: |
|
PR_Github #47203 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #47286 [ run ] triggered by Bot. Commit: |
|
PR_Github #47286 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #47394 [ run ] triggered by Bot. Commit: |
|
PR_Github #47394 [ run ] completed with state
|
f73c47a to
23bf4bf
Compare
|
/bot run --disable-fail-fast |
|
PR_Github #47459 [ run ] triggered by Bot. Commit: |
23bf4bf to
626e893
Compare
|
/bot reuse-pipeline |
|
PR_Github #50140 [ reuse-pipeline ] triggered by Bot. Commit: |
|
PR_Github #50140 [ reuse-pipeline ] completed with state |
7d89656 to
3f32586
Compare
|
/bot run --disable-fail-fast |
|
PR_Github #50179 [ run ] triggered by Bot. Commit: |
Signed-off-by: nv-guomingz <[email protected]>
3f32586 to
e0df0be
Compare
|
PR_Github #50179 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #50220 [ run ] triggered by Bot. Commit: |
|
PR_Github #50220 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #50239 [ run ] triggered by Bot. Commit: |
|
PR_Github #50239 [ run ] completed with state |
…VIDIA#13644) Signed-off-by: nv-guomingz <[email protected]>
Summary by CodeRabbit
New Features
Tests
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.