[TRTLLM-11547][feat] Add Qwen3.5 MTP support.#12646
Conversation
📝 WalkthroughWalkthroughThis pull request introduces Multi-TP (MTP) support for the Qwen3Next model by implementing checkpoint remapping for MTP layers, refactoring the Gated Delta Net implementation into a dedicated module, and enhancing speculative execution handling with improved position ID management for multi-axis RoPE variants. Changes
Estimated code review effort🎯 4 (Complex) | ⏱️ ~50 minutes 🚥 Pre-merge checks | ✅ 1 | ❌ 2❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 6
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/models/modeling_qwen3_next.py`:
- Around line 655-660: The code calls allgather(...) inside the
enable_lm_head_tp_in_adp branch (see enable_lm_head_tp_in_adp,
mapping_lm_head_tp, create_lm_head_tp_mapping) but allgather is not imported,
causing a NameError; fix by adding allgather to the import list from the
distributed module (the same place other distributed helpers are imported) so
that allgather is available when mapping_lm_head_tp and hidden_states are
all-gathered.
In `@tensorrt_llm/_torch/models/modeling_speculative.py`:
- Around line 798-800: The draft-mode code paths need the same Qwen3/Qwen3.5
handling as the single-engine path: update MTPDraftModel.__init__ to accept
"qwen3_5_moe_text" (in addition to "qwen3_next") and map it to the same MTP
class (Qwen3NextMTP imported from .modeling_qwen3_next), and extend
MTPDraftModelForCausalLM.load_weights to include the Qwen3/Qwen3.5 branch so the
matcher that handles qwen3_next also handles qwen3_5_moe_text; ensure any
spec_dec_mode.is_mtp_eagle() / two-engine checks treat both names symmetrically.
In `@tensorrt_llm/_torch/modules/mamba/gdn_mixer.py`:
- Around line 107-162: There is a duplicate Triton kernel definition named
fused_qkvzba_split_reshape_cat_kernel that overwrites the earlier one and causes
Ruff F811; remove the second definition (the entire function starting at the
later occurrence) so only the original fused_qkvzba_split_reshape_cat_kernel
remains; locate the duplicate by searching for the function name
fused_qkvzba_split_reshape_cat_kernel and delete the later block (including its
signature and body) to resolve the redefinition error and restore CI.
In `@tensorrt_llm/_torch/pyexecutor/_util.py`:
- Around line 1038-1049: The spec-layer extension currently appends spec entries
to hybrid_layer_mask/mamba_layer_mask without considering the caller-provided
layer_mask; update the logic so that if a layer_mask argument is provided you
first apply it to the existing hybrid_layer_mask and mamba_layer_mask (e.g.,
element-wise AND/zip) before computing num_layers and before extending with spec
layers, and ensure you align lengths (or pad/truncate) before extending with
get_num_spec_layers(spec_config) so the per-manager masks remain correct for
one-model separate draft KV mode.
In `@tensorrt_llm/_torch/pyexecutor/model_engine.py`:
- Around line 1525-1531: The slice assignment to inputs['position_ids'] using
previous_batch_tokens and previous_pos_id_offsets_cuda has closing bracket
indentation that triggers Flake8 E123; locate the block around
previous_batch_tokens > 0 in method/model where inputs['position_ids'] is
modified and reformat the expression so the bracketed index and the added value
align on the same indentation level (e.g., put the full slice [0,
num_ctx_tokens:num_ctx_tokens + previous_batch_tokens] on one line or keep the
opening bracket on the same column as the closing bracket) and ensure the
continuation of + (self.previous_pos_id_offsets_cuda[:previous_batch_tokens]) is
indented consistently; update the lines referencing previous_batch_tokens,
inputs['position_ids'], and previous_pos_id_offsets_cuda to satisfy E123.
In `@tensorrt_llm/_torch/speculative/mtp.py`:
- Around line 1038-1058: The reshape fails for padded 3D MRoPE batches because
the new branch in _select_mtp_position_ids/position_ids_gen assumes the last
axis is already truncated, but SpecDecOneEngineForCausalLM.forward still slices
position_ids as position_ids[:, :attn_metadata.num_tokens] (slicing the second
axis), leaving the time axis padded; fix by ensuring position_ids are truncated
on their last axis before this code runs — either update
SpecDecOneEngineForCausalLM.forward to slice position_ids on the final dimension
(e.g., position_ids[..., :attn_metadata.num_tokens]) or, inside the block that
computes position_ids_gen in _select_mtp_position_ids, explicitly slice/truncate
the last dimension to attn_metadata.num_tokens (or detect padded_num_tokens and
trim) before reshaping position_ids_gen and computing position_ids_delta so the
subsequent reshapes on position_ids_gen and position_ids_gen.flatten(...) cannot
encounter padded tokens.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 71fda09c-e57a-48bd-8259-effc37d874ac
📒 Files selected for processing (8)
tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.pytensorrt_llm/_torch/models/modeling_qwen3_next.pytensorrt_llm/_torch/models/modeling_speculative.pytensorrt_llm/_torch/modules/mamba/gdn_mixer.pytensorrt_llm/_torch/pyexecutor/_util.pytensorrt_llm/_torch/pyexecutor/model_engine.pytensorrt_llm/_torch/speculative/mtp.pytests/unittest/_torch/speculative/test_mtp.py
340868c to
5b02d2a
Compare
|
/bot run |
|
PR_Github #41353 [ run ] triggered by Bot. Commit: |
|
PR_Github #41353 [ run ] completed with state
|
5b02d2a to
0d450fd
Compare
0d450fd to
83415bd
Compare
83415bd to
e5ffbad
Compare
|
/bot run --disable-fail-fast |
|
PR_Github #41659 [ run ] triggered by Bot. Commit: |
e5ffbad to
4b09e78
Compare
|
PR_Github #41659 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #48731 [ run ] triggered by Bot. Commit: |
|
PR_Github #48731 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
a856878 to
7d1ce7f
Compare
|
/bot run --disable-fail-fast |
|
PR_Github #48827 [ run ] triggered by Bot. Commit: |
|
PR_Github #48827 [ run ] completed with state |
7d1ce7f to
0ff516c
Compare
|
/bot run --disable-fail-fast |
|
PR_Github #49128 [ run ] triggered by Bot. Commit: |
|
PR_Github #49128 [ run ] completed with state
|
0ff516c to
1e3e3d3
Compare
|
/bot run --disable-fail-fast |
|
PR_Github #49319 [ run ] triggered by Bot. Commit: |
Signed-off-by: nv-guomingz <[email protected]>
1e3e3d3 to
1f83515
Compare
|
/bot run |
|
PR_Github #49374 [ run ] triggered by Bot. Commit: |
|
PR_Github #49319 [ run ] completed with state |
|
PR_Github #49374 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #49453 [ run ] triggered by Bot. Commit: |
|
PR_Github #49453 [ run ] completed with state |
Signed-off-by: nv-guomingz <[email protected]>
Signed-off-by: nv-guomingz <[email protected]>
Summary by CodeRabbit
Release Notes
New Features
Refactor
Tests
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.