Skip to content

[None][perf] Fuse sigmoid+mul+add shared-expert combine into one Trit…#14306

Merged
nv-guomingz merged 1 commit into
NVIDIA:mainfrom
nv-guomingz:user/guomingz/qwen3.5_fusion
May 20, 2026
Merged

[None][perf] Fuse sigmoid+mul+add shared-expert combine into one Trit…#14306
nv-guomingz merged 1 commit into
NVIDIA:mainfrom
nv-guomingz:user/guomingz/qwen3.5_fusion

Conversation

@nv-guomingz
Copy link
Copy Markdown
Collaborator

@nv-guomingz nv-guomingz commented May 19, 2026

…on kernel for qwen3.5

Replaces the three pointwise kernels used to combine the shared-expert
output back into the routed-expert output of Qwen3-Next / Qwen3.5 MoE
blocks with a single Triton kernel.

End-to-end serving impact (Qwen3.5-397B-A17B-NVFP4, 4x B300, TP=4 EP=4,
enable_attention_dp=true, max_seq_len=10240, trtllm-serve + benchmark_
serving, random ids, ignore-eos):

ISL/OSL | bs | Total tok/s | TPOT mean (ms)
--------+-----+---------------------+-------------------------
1k/1k | 256 | 19061 -> 19366 (+1.6%) | 25.12 -> 24.77 (-1.4%)
1k/1k | 32 | 3754 -> 3869 (+3.1%) | 16.54 -> 16.09 (-2.7%)
8k/1k | 256 | 49899 -> 50749 (+1.7%) | 39.28 -> 38.65 (-1.6%)
8k/1k | 32 | 14777 -> 15126 (+2.4%) | 18.07 -> 17.64 (-2.4%)

Summary by CodeRabbit

  • Improvements

    • Optimized Qwen3Next mixture-of-experts gating mechanism for improved inference performance through fused kernel operations.
  • Tests

    • Added comprehensive unit tests validating shared-expert gating across diverse tensor shapes, data types, and edge cases.

Review Change Stack

@nv-guomingz nv-guomingz requested review from a team as code owners May 19, 2026 12:19
@nv-guomingz nv-guomingz requested review from QiJune and syuoni May 19, 2026 12:19
@nv-guomingz nv-guomingz force-pushed the user/guomingz/qwen3.5_fusion branch from e4147c5 to c83f0c4 Compare May 19, 2026 12:20
@nv-guomingz
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 19, 2026

📝 Walkthrough

Walkthrough

This PR introduces a Triton-based fused kernel that combines sigmoid gating and shared-expert output merging into a single operation, integrates it into the Qwen3-Next MoE model, and validates it with comprehensive tests covering dtypes, memory semantics, and edge cases.

Changes

MoE Shared-Expert Fused Operation

Layer / File(s) Summary
Fused sigmoid-gate-mul-add Triton kernel
tensorrt_llm/_torch/modules/fused_shared_expert.py
Adds a Triton JIT kernel computing final_hidden_states + sigmoid(gate_logits) * shared_expert_output with masked loads, eviction hints, and fp32 intermediate math. Wrapper function validates tensor shapes/dtypes, supports in-place vs explicit output modes, enforces last-dimension contiguity, and launches the kernel over a (num_tokens, num_col_blocks) grid with power-of-two block sizes.
MoE shared-expert merge integration
tensorrt_llm/_torch/models/modeling_qwen3_next.py
Updates imports to add fused_sigmoid_gate_mul_add, modifies _compute_shared_output to return (shared_expert_output, shared_expert_gate_logits) as a tuple, and replaces manual gating logic with the fused kernel invocation. Handles tensor-parallel allreduce when required.
Test suite: correctness, semantics, and edge cases
tests/unittest/_torch/modules/test_fused_shared_expert.py
Provides fp32 reference implementation and validates kernel correctness across multiple token counts, hidden sizes, and dtypes (fp32/fp16/bf16). Tests in-place aliasing, explicit output buffers, non-contiguous strides, zero-token handling, edge-case hidden dimensions, and large-hidden multi-block paths.

Sequence Diagram

sequenceDiagram
  participant ModelMerge as Merge Logic
  participant SharedBranch as Shared Expert
  participant Fused as fused_sigmoid_gate_mul_add
  participant AllReduce
  ModelMerge->>SharedBranch: _compute_shared_output
  SharedBranch-->>ModelMerge: shared_output, gate_logits
  ModelMerge->>Fused: final_hidden_states + shared_output + gate_logits
  Fused-->>ModelMerge: merged result
  alt tp_size > 1 and not enable_attention_dp
    ModelMerge->>AllReduce: allreduce(merged result)
    AllReduce-->>ModelMerge: synchronized result
  end
  ModelMerge-->>ModelMerge: continue
Loading

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly identifies the main change: fusing sigmoid, multiply, and add operations into a single Triton kernel for Qwen3.5 shared-expert combining, which is the primary focus of the PR.
Description check ✅ Passed The PR description explains the change and provides quantitative impact metrics, but lacks explicit sections for Test Coverage and PR Checklist confirmation as specified in the template.
Docstring Coverage ✅ Passed No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #49186 [ run ] triggered by Bot. Commit: c83f0c4 Link to invocation

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (3)
tests/unittest/_torch/modules/test_fused_shared_expert.py (2)

144-202: ⚡ Quick win

Cover the non-contiguous input copy path too.

fused_sigmoid_gate_mul_add() explicitly normalizes final_hidden_states and shared_expert_output with .contiguous() when their last dimension is strided, but every test here still passes contiguous inputs. A small case with strided inputs plus an explicit output= buffer would lock down that branch.

Suggested test shape
+def test_fused_sigmoid_gate_mul_add_non_contig_inputs():
+    torch.manual_seed(0)
+    num_tokens, hidden = 32, 4096
+    dtype = torch.bfloat16
+    device = "cuda"
+
+    final_base = torch.randn(num_tokens, hidden * 2, dtype=dtype, device=device)
+    shared_base = torch.randn(num_tokens, hidden * 2, dtype=dtype, device=device)
+    final = final_base[:, ::2]
+    shared = shared_base[:, ::2]
+    gate = torch.randn(num_tokens, 1, dtype=dtype, device=device)
+    output = torch.empty(num_tokens, hidden, dtype=dtype, device=device)
+
+    assert final.stride(-1) != 1
+    assert shared.stride(-1) != 1
+
+    expected = _reference(final, gate, shared)
+    out = fused_sigmoid_gate_mul_add(final, gate, shared, output=output)
+    torch.testing.assert_close(out, expected, atol=2e-2, rtol=2e-2)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unittest/_torch/modules/test_fused_shared_expert.py` around lines 144 -
202, Add a new unit test that exercises the non-contiguous input copy path for
fused_sigmoid_gate_mul_add by passing non-contiguous final_hidden_states and
shared_expert_output (e.g., create them via slicing or empty_strided so their
last dimension has stride != 1) along with an explicit non-contiguous output
buffer, then call fused_sigmoid_gate_mul_add(final, gate, shared, output=output)
and assert the returned tensor uses the provided output buffer (out.data_ptr()
== output.data_ptr()) and matches the reference (_reference(final, gate,
shared)) within the existing tolerances; this will trigger the code paths that
call .contiguous() on final_hidden_states and shared_expert_output and lock down
that branch.

42-218: Please make sure this lands in perf regression coverage too.

These unit tests are good for correctness, but the PR’s value is throughput/TPOT on Qwen3.5 serving. I don’t see a matching perf-list update in this diff, so please either add or point to the existing entry that exercises this fused path in tests/integration/test_lists/test-db/l0_perf.yml and the relevant tests/integration/test_lists/qa/llm_perf_*.yml.

As per coding guidelines, performance-sensitive paths should have pre-merge and QA perf coverage, and functional-only tests are not enough when a kernel rewrite can regress latency/throughput.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unittest/_torch/modules/test_fused_shared_expert.py` around lines 42 -
218, Add perf-regression coverage for the new fused kernel: update
tests/integration/test_lists/test-db/l0_perf.yml (or add a new entry) to include
a workload that exercises fused_sigmoid_gate_mul_add (the same shape used in
test_fused_sigmoid_gate_mul_add_qwen35_shape: hidden=4096, B=64) and add or
reference the corresponding QA perf list in tests/integration/test_lists/qa/
(e.g. llm_perf_*.yml) so the fused path is included in both pre-merge and QA
runs; ensure the perf entry names and payload mirror the unit test’s shape and
dtype (bfloat16) and reference the test module
tests/unittest/_torch/modules/test_fused_shared_expert.py or the
fused_sigmoid_gate_mul_add kernel so CI will run throughput/TPOT measurements.
tensorrt_llm/_torch/models/modeling_qwen3_next.py (1)

241-253: 🏗️ Heavy lift

Add a distributed regression test for this fused TP path.

The new unit suite validates the standalone kernel, but this branch is where the integration risk lives: allocate_output(...) produces a strided buffer, the fused kernel writes into it, and then the result is immediately all-reduced. A regression here would still pass all current tests.

As per coding guidelines, assess whether new/changed tests cover happy path, important edge cases, and failure modes relevant to the feature or fix.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/models/modeling_qwen3_next.py` around lines 241 - 253,
Add a distributed regression test that exercises the fused TP path where
allocate_output returns a strided buffer and the fused kernel writes into it
before an immediate all-reduce: create a test that instantiates the model with
mapping.tp_size > 1 and enable_attention_dp = False, invoke the branch that
calls torch.ops.trtllm.allocate_output, fused_sigmoid_gate_mul_add, and
self.allreduce (use shared_expert_outputs inputs so shared_expert_output and
shared_expert_gate_logits are non-trivial), and assert the final_hidden_states
after the allreduce matches a reference (e.g., CPU or non-fused fallback) across
ranks; include variants for different tp_size values and an edge case with
non-contiguous/strided output buffers to catch stride-handling bugs, and run the
test under torch.distributed (or the repo’s distributed test harness) so it
executes the full allocate->fused-kernel->allreduce integration.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@tensorrt_llm/_torch/models/modeling_qwen3_next.py`:
- Around line 241-253: Add a distributed regression test that exercises the
fused TP path where allocate_output returns a strided buffer and the fused
kernel writes into it before an immediate all-reduce: create a test that
instantiates the model with mapping.tp_size > 1 and enable_attention_dp = False,
invoke the branch that calls torch.ops.trtllm.allocate_output,
fused_sigmoid_gate_mul_add, and self.allreduce (use shared_expert_outputs inputs
so shared_expert_output and shared_expert_gate_logits are non-trivial), and
assert the final_hidden_states after the allreduce matches a reference (e.g.,
CPU or non-fused fallback) across ranks; include variants for different tp_size
values and an edge case with non-contiguous/strided output buffers to catch
stride-handling bugs, and run the test under torch.distributed (or the repo’s
distributed test harness) so it executes the full
allocate->fused-kernel->allreduce integration.

In `@tests/unittest/_torch/modules/test_fused_shared_expert.py`:
- Around line 144-202: Add a new unit test that exercises the non-contiguous
input copy path for fused_sigmoid_gate_mul_add by passing non-contiguous
final_hidden_states and shared_expert_output (e.g., create them via slicing or
empty_strided so their last dimension has stride != 1) along with an explicit
non-contiguous output buffer, then call fused_sigmoid_gate_mul_add(final, gate,
shared, output=output) and assert the returned tensor uses the provided output
buffer (out.data_ptr() == output.data_ptr()) and matches the reference
(_reference(final, gate, shared)) within the existing tolerances; this will
trigger the code paths that call .contiguous() on final_hidden_states and
shared_expert_output and lock down that branch.
- Around line 42-218: Add perf-regression coverage for the new fused kernel:
update tests/integration/test_lists/test-db/l0_perf.yml (or add a new entry) to
include a workload that exercises fused_sigmoid_gate_mul_add (the same shape
used in test_fused_sigmoid_gate_mul_add_qwen35_shape: hidden=4096, B=64) and add
or reference the corresponding QA perf list in tests/integration/test_lists/qa/
(e.g. llm_perf_*.yml) so the fused path is included in both pre-merge and QA
runs; ensure the perf entry names and payload mirror the unit test’s shape and
dtype (bfloat16) and reference the test module
tests/unittest/_torch/modules/test_fused_shared_expert.py or the
fused_sigmoid_gate_mul_add kernel so CI will run throughput/TPOT measurements.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 0cf9b3b2-bd0a-4e2c-a1d9-9d48e47020b9

📥 Commits

Reviewing files that changed from the base of the PR and between 989671b and c83f0c4.

📒 Files selected for processing (3)
  • tensorrt_llm/_torch/models/modeling_qwen3_next.py
  • tensorrt_llm/_torch/modules/fused_shared_expert.py
  • tests/unittest/_torch/modules/test_fused_shared_expert.py

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #49186 [ run ] completed with state SUCCESS. Commit: c83f0c4
/LLM/main/L0_MergeRequest_PR pipeline #38863 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@nv-guomingz
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #49275 [ run ] triggered by Bot. Commit: c83f0c4 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #49275 [ run ] completed with state FAILURE. Commit: c83f0c4
/LLM/main/L0_MergeRequest_PR pipeline #38940 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@nv-guomingz
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #49337 [ run ] triggered by Bot. Commit: c83f0c4 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #49337 [ run ] completed with state SUCCESS. Commit: c83f0c4
/LLM/main/L0_MergeRequest_PR pipeline #38994 completed with status: 'SUCCESS'

CI Report

Link to invocation

@nv-guomingz nv-guomingz enabled auto-merge (squash) May 20, 2026 07:57
Copy link
Copy Markdown
Collaborator

@QiJune QiJune left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@nv-guomingz nv-guomingz merged commit bb57a83 into NVIDIA:main May 20, 2026
11 checks passed
@nv-guomingz nv-guomingz deleted the user/guomingz/qwen3.5_fusion branch May 21, 2026 03:14
xxi-nv pushed a commit to xxi-nv/TensorRT-LLM that referenced this pull request May 22, 2026
bmarimuthu-nv pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request May 28, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants