[None][feat] Add fused allreduce+RMSNorm op and optional residual in …#12201
Conversation
…moe_finalize_allreduce Add AllReduceFusionOp.RMS_NORM (value=9) that fuses allreduce + RMSNorm in a single kernel without residual addition. This is useful for models where the residual connection is handled externally. Changes: - New kARRMSNorm fusion pattern in C++ allreduce kernels - launchResidualRmsNormKernel now dispatches on Residual template param - moe_finalize_allreduce accepts optional residual (Tensor? instead of Tensor) - MOE fused_op kernel skips residual load/add when residual_in is nullptr - Python AllReduceFusionOp.RMS_NORM enum and updated assertions - Unit tests for RMS_NORM pattern and moe_finalize with residual=None Signed-off-by: Fanrong Li <[email protected]>
|
/bot run |
|
PR_Github #38886 [ run ] triggered by Bot. Commit: |
|
PR_Github #38886 [ run ] completed with state
|
|
/bot run |
|
PR_Github #38958 [ run ] triggered by Bot. Commit: |
|
PR_Github #38958 [ run ] completed with state |
📝 WalkthroughWalkthroughChanges introduce a new RMS_NORM fusion operation for AllReduce with RMSNorm, adding corresponding enum values across kernel, C++, and Python layers. Residual input handling is made optional for RMS_NORM paths through updated dispatchers, kernels, and public API signatures. Test coverage is extended to validate the new RMS_NORM fusion path. Changes
Sequence DiagramsequenceDiagram
participant Python as Python API
participant AllReduceOp as AllReduceOp (C++)
participant Dispatcher as Kernel Dispatcher
participant KernelImpl as Kernel Implementation
participant Output as Result Buffer
Python->>AllReduceOp: call with fusion_op=RMS_NORM
AllReduceOp->>AllReduceOp: check residual presence
AllReduceOp->>Dispatcher: launchResidualRmsNormKernel<T, Residual>()
Dispatcher->>Dispatcher: template dispatch: Residual=true/false
Dispatcher->>KernelImpl: launch rms_norm_kernel_launcher<T, Bias, Residual, Weight>()
KernelImpl->>KernelImpl: compute norm_input = val + residual_in (if Residual=true)
KernelImpl->>KernelImpl: compute output = rms_norm(norm_input)
KernelImpl->>Output: write norm_out (and residual_out if Residual=true)
Output-->>AllReduceOp: return {norm_out} or {norm_out, residual_out}
AllReduceOp-->>Python: return fused result
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes 🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
📝 Coding Plan
Comment Tip You can customize the tone of the review comments and chat replies.Configure the |
There was a problem hiding this comment.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
cpp/tensorrt_llm/thop/allreduceOp.cpp (1)
1678-1752:⚠️ Potential issue | 🟠 MajorValidate the optional-input shapes before sizing the MoE finalize launch.
num_tokensis now inferred from whichever optional tensor happens to be present, but the code never checks thatresidual,shared_expert_output,expert_scale_factor, andexpanded_idx_to_permuted_idxagree on thatm, or thatnorm_weight.size(0)matches the hidden dim of the other inputs. A mismatched caller will sizenorm_outandallreduce_fusion_params.sizefrom one shape and then hand the kernel raw pointers with another, which is an out-of-bounds risk instead of a cleanTORCH_CHECK.Suggested fix
- int hidden_dim = norm_weight.size(0); + TORCH_CHECK(norm_weight.dim() == 1, "norm_weight must be 1D"); + int hidden_dim = norm_weight.size(0); + TORCH_CHECK(input.size(-1) == hidden_dim, "input hidden dim must match norm_weight"); int top_k = expanded_idx_to_permuted_idx.size(-1); @@ - int num_tokens; + TORCH_CHECK(expanded_idx_to_permuted_idx.dim() == 2, "expanded_idx_to_permuted_idx must be 2D"); + int num_tokens; if (residual.has_value()) { + TORCH_CHECK(residual.value().dim() == 2, "residual must be 2D"); + TORCH_CHECK( + residual.value().size(1) == hidden_dim, "residual hidden dim must match norm_weight"); num_tokens = residual.value().size(0); } else if (shared_expert_output.has_value()) { + TORCH_CHECK(shared_expert_output.value().dim() == 2, "shared_expert_output must be 2D"); + TORCH_CHECK(shared_expert_output.value().size(1) == hidden_dim, + "shared_expert_output hidden dim must match norm_weight"); num_tokens = shared_expert_output.value().size(0); } else @@ + TORCH_CHECK( + expanded_idx_to_permuted_idx.size(0) == num_tokens, "expanded_idx_to_permuted_idx token dim mismatch"); + if (shared_expert_output.has_value()) + { + TORCH_CHECK( + shared_expert_output.value().size(0) == num_tokens, "shared_expert_output token dim mismatch"); + } + if (expert_scale_factor.has_value()) + { + TORCH_CHECK(expert_scale_factor.value().dim() == 2, "expert_scale_factor must be 2D"); + TORCH_CHECK(expert_scale_factor.value().size(0) == num_tokens + && expert_scale_factor.value().size(1) == top_k, + "expert_scale_factor must have shape [num_tokens, top_k]"); + } + // size: num_token * hidden_dim allreduce_fusion_params.size = num_tokens * hidden_dim;🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/thop/allreduceOp.cpp` around lines 1678 - 1752, Compute and validate a consistent num_tokens and hidden_dim before allocating outputs or populating allreduce_fusion_params: check that norm_weight.size(0) equals hidden_dim (input.size(1) or as used), and if residual.has_value() ensure residual.value().dim() and residual.value().size(0) match num_tokens and residual.value().size(1) == hidden_dim; if shared_expert_output.has_value() ensure its size(0) == num_tokens and size(1) == hidden_dim; if expert_scale_factor.has_value() validate its shape/length aligns with num_tokens (or is broadcastable) and use TORCH_CHECK to fail with clear messages; also verify expanded_idx_to_permuted_idx.size(0) matches num_tokens before using it to set allreduce_fusion_params.size and before creating norm_out/residual_out so the kernel receives consistent pointers (refer to variables/functions: num_tokens, hidden_dim, norm_weight, input, residual, shared_expert_output, expert_scale_factor, expanded_idx_to_permuted_idx, allreduce_fusion_params, moefinalize_allreduce_fusion_op).
🧹 Nitpick comments (2)
cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.cu (1)
741-743: Consider addingrms_gammacheck for consistency.The
moefinalize_allreduce_fusion_opvalidatesallreduce_in,expanded_idx_to_permuted_idx, andtop_k, but unlikemoereduction_allreduce_fusion_op(line 454), it doesn't explicitly checkrms_gamma.Since
fused_op(line 139) unconditionally dereferencesparams.rms_gamma, a nullrms_gammawould cause a crash. Ifrms_gammais guaranteed non-null by the caller, this is fine—but adding an explicit check would improve defensive robustness and consistency withmoereduction_allreduce_fusion_op.♻️ Suggested fix to add rms_gamma check
- TLLM_CHECK(params.allreduce_in && params.expanded_idx_to_permuted_idx && params.top_k); + TLLM_CHECK(params.allreduce_in && params.expanded_idx_to_permuted_idx && params.top_k && params.rms_gamma);🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.cu` around lines 741 - 743, The moefinalize_allreduce_fusion_op validation is missing a null check for params.rms_gamma (moereduction_allreduce_fusion_op performs this check), yet fused_op later dereferences params.rms_gamma; add a TLLM_CHECK(params.rms_gamma) (or equivalent null/assert) in moefinalize_allreduce_fusion_op alongside the existing checks to ensure rms_gamma is non-null before fused_op uses it.tests/unittest/_torch/multi_gpu/test_allreduce.py (1)
285-286: Cover the actualresidual=NoneRMS_NORMcontract.The new
RMS_NORMcase still goes through the common harness that always materializes and passes a residual tensor. That only proves this path ignores residual when present; it never exercises the API change this PR is enabling:RMS_NORMwithresidual=None.Also applies to: 301-313
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/multi_gpu/test_allreduce.py` around lines 285 - 286, Add a test that exercises the RMS_NORM API with residual=None instead of only using the common harness that always materializes a residual: update the param set to include a pytest.param for AllReduceFusionOp.RMS_NORM that triggers the path where no residual tensor is passed (or add a separate test that directly calls the allreduce harness/function with AllReduceFusionOp.RMS_NORM and residual=None), and ensure the harness invocation for that case does not create or pass a dummy residual so the code path for residual=None is actually executed and validated.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@cpp/tensorrt_llm/kernels/customAllReduceKernels.cu`:
- Around line 1980-1990: The wrapper launchResidualRmsNormKernel currently
derives whether to use the residual specialization from
params.fusion_params.residual_buffer, which allows RESIDUAL_RMS_PREPOST_NORM to
run without a residual; change it so that if fusionOp ==
RESIDUAL_RMS_PREPOST_NORM you require params.fusion_params.residual_buffer to be
non-null (emit an error/ASSERT/log and dispatch the <T, true> specialization),
otherwise proceed to choose based on the buffer; reference
launchResidualRmsNormKernel, RESIDUAL_RMS_PREPOST_NORM, and
params.fusion_params.residual_buffer when making the check and error-handling so
the residual-mandatory contract is preserved.
In `@tensorrt_llm/functional.py`:
- Around line 3953-3955: The assertion currently allows
AllReduceFusionOp.RMS_NORM with residual == None but create_allreduce_plugin
always reads all_reduce_params.residual.trt_tensor, causing a crash; update
create_allreduce_plugin to guard access to all_reduce_params.residual.trt_tensor
(only read/use it if all_reduce_params.residual is not None) and handle the
RMS_NORM path without a residual (e.g., skip adding the residual input or pass a
null/empty tensor placeholder as the plugin expects); reference
AllReduceFusionOp, create_allreduce_plugin, and
all_reduce_params.residual.trt_tensor when locating the change.
In `@tests/unittest/_torch/multi_gpu/test_allreduce.py`:
- Around line 692-696: The zip(...) call used to pack arguments for
mpi_pool_executor.map (wrapping run_moe_finalize_no_residual_single_rank and
run_moe_finalize_allreduce_no_residual_op with fc2_output, shared_expert_output,
expanded_idx_to_permuted_idx, scale) should be made strict to avoid silent
truncation and satisfy Ruff B905; update the zip invocation to pass strict=True
so mismatched iterator lengths raise immediately when calling
mpi_pool_executor.map with run_moe_finalize_no_residual_single_rank.
---
Outside diff comments:
In `@cpp/tensorrt_llm/thop/allreduceOp.cpp`:
- Around line 1678-1752: Compute and validate a consistent num_tokens and
hidden_dim before allocating outputs or populating allreduce_fusion_params:
check that norm_weight.size(0) equals hidden_dim (input.size(1) or as used), and
if residual.has_value() ensure residual.value().dim() and
residual.value().size(0) match num_tokens and residual.value().size(1) ==
hidden_dim; if shared_expert_output.has_value() ensure its size(0) == num_tokens
and size(1) == hidden_dim; if expert_scale_factor.has_value() validate its
shape/length aligns with num_tokens (or is broadcastable) and use TORCH_CHECK to
fail with clear messages; also verify expanded_idx_to_permuted_idx.size(0)
matches num_tokens before using it to set allreduce_fusion_params.size and
before creating norm_out/residual_out so the kernel receives consistent pointers
(refer to variables/functions: num_tokens, hidden_dim, norm_weight, input,
residual, shared_expert_output, expert_scale_factor,
expanded_idx_to_permuted_idx, allreduce_fusion_params,
moefinalize_allreduce_fusion_op).
---
Nitpick comments:
In `@cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.cu`:
- Around line 741-743: The moefinalize_allreduce_fusion_op validation is missing
a null check for params.rms_gamma (moereduction_allreduce_fusion_op performs
this check), yet fused_op later dereferences params.rms_gamma; add a
TLLM_CHECK(params.rms_gamma) (or equivalent null/assert) in
moefinalize_allreduce_fusion_op alongside the existing checks to ensure
rms_gamma is non-null before fused_op uses it.
In `@tests/unittest/_torch/multi_gpu/test_allreduce.py`:
- Around line 285-286: Add a test that exercises the RMS_NORM API with
residual=None instead of only using the common harness that always materializes
a residual: update the param set to include a pytest.param for
AllReduceFusionOp.RMS_NORM that triggers the path where no residual tensor is
passed (or add a separate test that directly calls the allreduce
harness/function with AllReduceFusionOp.RMS_NORM and residual=None), and ensure
the harness invocation for that case does not create or pass a dummy residual so
the code path for residual=None is actually executed and validated.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: dc020840-5c2c-4541-b944-1a845987a5a2
📒 Files selected for processing (9)
cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cucpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.hcpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.cucpp/tensorrt_llm/kernels/customAllReduceKernels.cucpp/tensorrt_llm/kernels/customAllReduceKernels.hcpp/tensorrt_llm/thop/allreduceOp.cpptensorrt_llm/_torch/distributed/ops.pytensorrt_llm/functional.pytests/unittest/_torch/multi_gpu/test_allreduce.py
…n, add rms_gamma check - Guard residual.trt_tensor access in create_allreduce_plugin for RMS_NORM - Re-add max_token assertion in MoEAllReduce with None guard for optional residual - Add TLLM_CHECK(params.rms_gamma) in moefinalize_allreduce_fusion_op Signed-off-by: Fanrong Li <[email protected]>
|
/bot run |
|
PR_Github #39083 [ run ] triggered by Bot. Commit: |
|
PR_Github #39083 [ run ] completed with state
|
|
/bot run |
|
PR_Github #39103 [ run ] triggered by Bot. Commit: |
|
PR_Github #39103 [ run ] completed with state
|
|
/bot run |
|
PR_Github #39148 [ run ] triggered by Bot. Commit: |
|
PR_Github #39148 [ run ] completed with state
|
|
/bot run |
|
PR_Github #39171 [ run ] triggered by Bot. Commit: |
|
PR_Github #39171 [ run ] completed with state
|
|
/bot run |
|
PR_Github #39201 [ run ] triggered by Bot. Commit: |
|
PR_Github #39201 [ run ] completed with state
|
|
/bot run |
|
PR_Github #39247 [ run ] triggered by Bot. Commit: |
|
PR_Github #39247 [ run ] completed with state
|
|
/bot run |
|
PR_Github #39268 [ run ] triggered by Bot. Commit: |
|
PR_Github #39268 [ run ] completed with state
|
|
/bot run |
|
PR_Github #39361 [ run ] triggered by Bot. Commit: |
|
PR_Github #39361 [ run ] completed with state
|
|
/bot run |
|
PR_Github #39402 [ run ] triggered by Bot. Commit: |
|
PR_Github #39402 [ run ] completed with state
|
|
/bot run |
|
PR_Github #39419 [ run ] triggered by Bot. Commit: |
|
PR_Github #39419 [ run ] completed with state
|
|
/bot run |
1 similar comment
|
/bot run |
|
PR_Github #39448 [ run ] triggered by Bot. Commit: |
|
PR_Github #39448 [ run ] completed with state |
NVIDIA#12201) Signed-off-by: Fanrong Li <[email protected]>
NVIDIA#12201) Signed-off-by: Fanrong Li <[email protected]>
Summary
Add
AllReduceFusionOp.RMS_NORM(value=9) that fuses allreduce + RMSNorm in a single kernel without residual addition. This is useful for models where the residual connection is handled externally (e.g., Mewtwo).Also makes the
residualparameter optional inmoe_finalize_allreduce, allowing MOE layers to skip residual addition when not needed.Changes
C++ Kernels
kARRMSNormfusion pattern in allreduce fusion kernelslaunchResidualRmsNormKernelDispatchwith abool Residualtemplate parameter to handle the residual-free path without runtime branchingmoe_finalize_allreducefused kernel conditionally skip residual load/add whenresidual_inisnullptrPython
AllReduceFusionOp.RMS_NORM = 9enum valuedistributed/ops.pyandfunctional.pyto allow RMS_NORM fusion without a residual tensorOp Registration
allreduceOp.cppto handleRMS_NORMdispatch and accept optional residual (Tensor?) in MOE finalize signatureTests
moe_finalize_allreducewithresidual=NoneDescription
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.Summary by CodeRabbit
Release Notes
New Features
Tests