
Conversation

@suyoggupta
Collaborator

@suyoggupta suyoggupta commented Oct 22, 2025

Add pattern matching for Nemotron-style RMSNorm.
Update the Triton kernel to use float32 for weights.
Update the unit test.
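
For reference, a minimal sketch of the Nemotron-style RMSNorm computation the new pattern targets (names and shapes are illustrative; the actual module used for validation in the test file is NemotronH_RMSNorm):

import torch
from torch import nn


class NemotronStyleRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        # Key difference from the standard variant: the weight also participates
        # in float32 before the result is cast back to the input dtype.
        return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)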

Summary by CodeRabbit

  • Performance

    • Optimized Mixture-of-Experts kernel for improved small batch performance.
  • Improvements

    • Enhanced RMS normalization numerical stability by enforcing float32 precision for weights during computation.
    • Updated post-load fusion backend to Triton.
  • Tests

    • Added expanded RMS normalization test coverage including new variant validation with Triton backend.

nvchenghaoz and others added 15 commits October 17, 2025 18:27
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Suyog Gupta <[email protected]>
@suyoggupta suyoggupta marked this pull request as ready for review October 22, 2025 04:57
@suyoggupta suyoggupta requested a review from a team as a code owner October 22, 2025 04:57
@suyoggupta suyoggupta requested a review from MrGeva October 22, 2025 04:57
@suyoggupta
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #22201 [ run ] triggered by Bot. Commit: c741482

@coderabbitai
Contributor

coderabbitai bot commented Oct 22, 2025

📝 Walkthrough

The pull request updates the default post-load fusion backend from flashinfer to triton, optimizes the MoE kernel for small batch sizes, enhances RMS norm weight handling with float32 precision, adds a new float32-weight pattern matching function for RMS norm, and restructures RMS norm fusion tests to support triton backend validation.

Changes

  • Configuration update (tensorrt_llm/_torch/auto_deploy/config/default.yaml):
    Changes the post-load fusion transform backend from flashinfer to triton.
  • Kernel optimizations (tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py, tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py):
    The MoE kernel adds a small-batch optimization that adjusts the EM calculation; the RMS norm kernel enforces float32 precision for weights during computation.
  • Transform pattern matching (tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py):
    Introduces the _rms_norm_pattern_float32_weights pattern function for float32-weight RMS norm scenarios; restructures pattern registration to support multiple search functions.
  • Test restructuring (tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py):
    Adds a NemotronH_RMSNorm variant, introduces a _run_test helper, and restructures tests into explicit test cases with Triton backend validation; removes the previous parameterized test structure.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~22 minutes

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Description Check ✅ Passed: Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check ✅ Passed: The title "[None][feat] Enable rms norm fusion for Nemotron MOE" directly describes the primary objective of the changeset. The PR introduces a new RMS normalization pattern function for Nemotron-style implementations with float32 weight handling, updates the Triton kernel to enforce float32 precision for weights, changes the backend configuration from flashinfer to triton for post-load fusion, and adds comprehensive tests including test_rmsnorm_fusion_nemotron_h. All these changes are cohesively directed toward enabling RMS norm fusion specifically for Nemotron MOE, making the title fully aligned with the substantial changes in the PR.
  • Docstring Coverage ✅ Passed: No functions found in the changes. Docstring coverage check skipped.

Warning

Tools execution failed with the following error:

Failed to run tools: Stream setup permanently failed: Ping-pong health check failed



Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (3)
tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py (1)

17-19: Consider expanding the docstring to follow Google-style format.

While the current docstring documents the key behavior, consider adding parameter descriptions to align with coding guidelines for Google-style docstrings.

Example expansion:

-    """Rms norm kernel.
-    Forces weights to be in float32 for the kernel.
-    """
+    """RMS normalization kernel.
+    
+    Applies RMS normalization with weights forced to float32 precision during
+    computation for numerical accuracy.
+    
+    Args:
+        output: Output tensor pointer.
+        input: Input tensor pointer.
+        weights: Weight tensor pointer.
+        eps: Epsilon value for numerical stability.
+        M: Number of rows.
+        N_COLS: Number of columns.
+        BLOCK_N: Block size for Triton kernel.
+    """

As per coding guidelines

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (2)

29-29: Consider parameterizing the device.

The device is hard-coded to "cuda". While this is acceptable in a test file under singlegpu, consider accepting a device parameter for consistency with the TestModel pattern where other modules use device="cuda" inline.


37-37: Consider adding a comment about float32 weight precision.

The weight is explicitly cast to float32 during computation even though the module may be stored in float16. Adding an inline comment would clarify this intentional design choice.

For example:

-        return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)
+        # Use float32 for weights to match Nemotron-style RMSNorm computation
+        return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between df689f8 and c741482.

📒 Files selected for processing (5)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py
🧠 Learnings (1)
📚 Learning: 2025-10-20T17:09:21.560Z
Learnt from: nvchenghaoz
PR: NVIDIA/TensorRT-LLM#8469
File: tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py:180-182
Timestamp: 2025-10-20T17:09:21.560Z
Learning: In tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py, the _gated_rmsnorm_replacement function does not need to cast the output of torch.ops.auto_deploy.torch_rmsnorm_gated back to the input dtype, even though the custom op returns fp32. The dtype handling is managed elsewhere or the fp32 output is acceptable for downstream consumers.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
🔇 Additional comments (7)
tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

127-127: Verify Triton backend is production-ready for RMS norm fusion patterns.

This configuration change switches the post_load_fusion backend from flashinfer to triton. While this aligns with the PR objective to enable rms norm fusion for Nemotron MOE via Triton, ensure that:

  1. The Triton backend is compatible with all RMS norm fusion patterns (especially Nemotron-style patterns mentioned in the PR)
  2. Related Triton kernel updates for float32 weight handling are properly integrated
  3. Unit tests validate this backend switch across different RMS norm patterns
  4. No performance regressions occur from the flashinfer → triton switch

Since only this configuration file is provided, please confirm that:

  • Related Triton kernel and pattern-matching code changes are properly tested
  • The RMS norm fusion tests cover both the float32 weight handling and Triton backend execution
  • Any performance benchmarks comparing flashinfer vs. triton backend are acceptable

You can run the following to verify test coverage for RMS norm fusion with the new backend:

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py (1)

253-258: Verify variable scope and formula correctness.

Three concerns with this optimization:

  1. Variable scope: top_k is referenced but not shown in the provided context. Verify it's in scope (function parameter or defined earlier).

  2. Tensor size inconsistency: Line 250 uses sorted_token_ids.numel() while line 258 uses sorted_token_ids.size(0). If sorted_token_ids is multidimensional, these return different values. Use consistent methods for correctness.

  3. Formula verification: The calculation A.size(0) * top_k * config["BLOCK_SIZE_M"] seems large for an upper bound on expert assignments. Given that each of A.size(0) tokens routes to top_k experts, the natural upper bound would be A.size(0) * top_k, not scaled by BLOCK_SIZE_M. Please verify this formula is correct for the intended optimization.

Run the following script to verify the formula and variable usage:

tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py (1)

31-31: LGTM! Explicit float32 casting for weights aligns with PR objectives.

The change ensures weights are explicitly cast to float32 before multiplication, maintaining precision during the computation while correctly casting back to the input dtype. This aligns with the PR objective to enable float32 weight handling for Nemotron MOE RMS norm fusion.
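
For illustration, a minimal Triton-style sketch of where that cast sits in an RMS norm row kernel (this is an assumed shape, not the actual signature in rms_norm.py; variable names follow the diff discussed below):

import triton
import triton.language as tl


@triton.jit
def rms_norm_fp32_weight_kernel(out_ptr, x_ptr, w_ptr, eps, N_COLS, BLOCK_N: tl.constexpr):
    # One program per row; illustrative layout only.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N_COLS
    x = tl.load(x_ptr + row * N_COLS + cols, mask=mask, other=0.0)
    w = tl.load(w_ptr + cols, mask=mask, other=0.0)
    xf = x.to(tl.float32)
    var = tl.sum(xf * xf, axis=0) / N_COLS
    out = xf / tl.sqrt(var + eps)
    # The change under review: force weights to float32 for the multiply,
    # then cast the result back to the input dtype.
    out = (w.to(tl.float32) * out).to(x.dtype)
    tl.store(out_ptr + row * N_COLS + cols, out, mask=mask)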

tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (2)

49-67: LGTM! New pattern function correctly implements float32 weight handling.

The implementation properly matches Nemotron-style RMSNorm by explicitly casting weights to float32 before scaling. The apparent duplication with the existing _rms_norm_pattern is intentional and necessary—pattern matching requires exact computation graph matches, so two distinct patterns enable coverage of both weight handling approaches.
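
A rough sketch of what such a float32-weight search pattern could look like (illustrative only; the actual function in rms_norm.py may differ in argument order and details):

import torch


def _rms_norm_pattern_float32_weights(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Same computation as the baseline pattern, except the weight is cast to
    # float32 before scaling, matching the Nemotron-style module.
    input_dtype = x.dtype
    x_f32 = x.to(torch.float32)
    variance = x_f32.pow(2).mean(-1, keepdim=True)
    x_norm = x_f32 * torch.rsqrt(variance + eps)
    return (weight.to(torch.float32) * x_norm).to(input_dtype)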


155-168: LGTM! Registration logic correctly handles both pattern variants.

The two-tier loop structure properly registers both pattern functions across all input/weight dtype configurations, enabling comprehensive matching of both standard and Nemotron-style RMSNorm implementations.
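
To make the two-tier structure concrete, a hedged sketch of the registration loop is shown below (register_ad_pattern, dummy_args, and the pattern/replacement functions are assumed to come from the surrounding module; the dtype pairs listed here are illustrative, not the actual configs):

from functools import partial

import torch


def _register_rms_norm_patterns(patterns, backend, register_ad_pattern, dummy_args):
    """Register both search functions for every (input_dtype, weight_dtype) pair."""
    configs = [
        (torch.float16, torch.float16),
        (torch.bfloat16, torch.bfloat16),
        (torch.float16, torch.float32),  # Nemotron-style: low-precision input, fp32 weight
    ]
    search_fns = [_rms_norm_pattern, _rms_norm_pattern_float32_weights]
    for search_fn in search_fns:
        for input_dtype, weight_dtype in configs:
            register_ad_pattern(
                search_fn=search_fn,
                replace_fn=partial(_rms_norm_replacement, backend=backend),
                patterns=patterns,
                dummy_args=dummy_args(input_dtype, weight_dtype),
                op_ignore_types={},
                scalar_workaround={"eps": 1e-6},
            )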

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (2)

89-105: LGTM! Well-structured test parameterization.

The test restructuring effectively covers multiple scenarios:

  • Two epsilon values (1e-2, 1e-6)
  • Multiple backend variants (flashinfer, torch, triton)
  • Dedicated test for NemotronH variant with Triton backend

The comment on line 103 helpfully clarifies the Triton backend requirement for the Nemotron H RMSNorm variant.
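
For reference, the structure described above could look roughly like the following (assumed shape only; the real file defines TestModel, NemotronH_RMSNorm, and _run_test with their own signatures):

import pytest
import torch


@pytest.mark.parametrize("eps", [1e-2, 1e-6])
@pytest.mark.parametrize(
    "variant, op",
    [
        ("flashinfer", torch.ops.auto_deploy.flashinfer_rms_norm),
        ("torch", torch.ops.auto_deploy.torch_rmsnorm),
    ],
)
def test_rmsnorm_fusion(eps, variant, op):
    _run_test(TestModel(eps=eps), op, variant)


def test_rmsnorm_fusion_nemotron_h():
    # The Nemotron-H (float32-weight) RMSNorm pattern is only fused by the Triton backend.
    _run_test(TestModel(eps=1e-6, use_nemotron_h=True), torch.ops.auto_deploy.triton_rms_norm, "triton")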


57-57: Add a docstring for the helper function.

The _run_test helper function is missing a docstring. The coding guidelines require Google-style docstrings for functions.

Apply this diff:

 def _run_test(model, op, variant):
+    """Run RMSNorm fusion test for a given model and backend variant.
+    
+    Args:
+        model: The test model instance.
+        op: The expected operator after fusion.
+        variant: The backend variant name (e.g., 'flashinfer', 'torch', 'triton').
+    """
     def checker(gm):
         return any(is_op(n, op) for n in gm.graph.nodes)

As per coding guidelines.

Likely an incorrect or invalid review comment.

@coderabbitai
Contributor

coderabbitai bot commented Oct 22, 2025

📝 Walkthrough

This pull request implements a triton-based backend for RMSNorm fusion with explicit float32 weight handling. Changes include switching the backend configuration, optimizing triton kernels for small batches, adding float32 weight casting in kernel logic, extending pattern matching for float32-weight variants, and expanding test coverage with Nemotron variant support.

Changes

  • Backend Configuration (tensorrt_llm/_torch/auto_deploy/config/default.yaml):
    Changed the fuse_rmsnorm backend from flashinfer to triton within transforms.post_load_fusion.
  • Triton Kernel Optimizations (tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py, tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py):
    rms_norm.py: Added explicit float32 coercion for weights before multiplication with the scaled output. triton_moe.py: Added a batch-size optimization gate in _invoke_kernel to reduce wasted computation when the batch size is smaller than BLOCK_SIZE_M.
  • Transform Pattern Matching (tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py):
    Added _rms_norm_pattern_float32_weights() to match RMSNorm patterns with float32 weights; refactored the pattern registration loop to register both the original and float32-weight variants across all dtype configurations.
  • Test Harness Expansion (tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py):
    Added a NemotronH_RMSNorm module; extended TestModel with a use_nemotron_h flag; refactored the test structure with a _run_test() helper; added test_rmsnorm_fusion_nemotron_h() to cover the Nemotron variant with the Triton backend.

Sequence Diagram

sequenceDiagram
    participant Config as Config Layer
    participant PatternMatch as Pattern Matching
    participant KernelFP32 as Float32<br/>RMSNorm Kernel
    participant KernelDefault as Default<br/>RMSNorm Kernel
    
    Config->>PatternMatch: Backend: triton
    PatternMatch->>PatternMatch: Try _rms_norm_pattern_float32_weights
    alt Float32 Weights Detected
        PatternMatch->>KernelFP32: Cast weights to float32<br/>before multiplication
        KernelFP32-->>PatternMatch: Optimized scaled output
    else Default Weights
        PatternMatch->>KernelDefault: Original pattern logic
        KernelDefault-->>PatternMatch: Scaled output
    end
    PatternMatch-->>Config: Fusion complete

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~22 minutes

The changes span multiple functional areas—configuration, kernel optimization, pattern matching logic, and test coverage—with moderate logic density per file. While individual edits are straightforward (type casting, batch size gating, pattern registration), the heterogeneity across components and introduction of the new float32-weight variant requires separate reasoning for each cohort.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 25.00%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check ✅ Passed: The pull request title "Enable rms norm fusion for Nemotron MOE" accurately captures the main intent of the changeset. The changes add support for RMS norm fusion with float32 weights (a new pattern, _rms_norm_pattern_float32_weights), switch the backend to triton, and include a dedicated NemotronH_RMSNorm test class with specific test coverage for the Nemotron variant. While the changeset also includes a batch-size optimization in the MOE kernel and other supporting changes, the primary focus is enabling RMS norm fusion for Nemotron MOE models that previously was not supported. The title is concise, specific, and clearly communicates the feature being enabled and the target model.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py (1)

1-1: Missing NVIDIA Apache-2.0 header (2025)

Add the required header at the top of the file.

+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py (1)

1-1: Missing NVIDIA Apache-2.0 header (2025)

Add the required header at the top.

+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (1)

1-1: Missing NVIDIA Apache-2.0 header (2025)

Add the required header.

+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (1)

1-1: Missing NVIDIA Apache-2.0 header (2025)

Please add the standard header to this test file as well (policy applies to .py).

+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
🧹 Nitpick comments (7)
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py (1)

253-258: EM upper-bound is safe; clarify assumption and guard when not capturing

  • The bound EM = min(max_padded, M * top_k * BLOCK_SIZE_M) is a valid upper bound (actual padded tokens ≤ T * BLOCK_SIZE_M), so it won’t skip real work. The current comment about “top_ids … unique” is misleading.
  • Suggest clearer comment and (optionally) a debug assert when not under CUDA graph capture.

Apply this diff to clarify and add a guarded assert:

-    if A.size(0) < config["BLOCK_SIZE_M"]:
-        # optimize for small batch_size.
-        # We assume that top_ids of each token is unique,
-        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
-        # and we can skip some invalid blocks.
-        EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
+    if A.size(0) < config["BLOCK_SIZE_M"]:
+        # Optimize for small batches by shrinking the launch grid.
+        # Let T = M * top_k. After per‑expert padding, the total tokens processed
+        # EM_actual ≤ T * BLOCK_SIZE_M and ≤ sorted_token_ids.size(0).
+        # Using EM = min(sorted_token_ids.size(0), T * BLOCK_SIZE_M) preserves correctness
+        # while cutting empty blocks when E » M.
+        EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
+        # Optional safety check (avoid host sync during CUDA graph capture):
+        # if not torch.cuda.is_current_stream_capturing():
+        #     EM_actual = int(num_tokens_post_padded.item())
+        #     assert EM >= EM_actual

Please run a quick A/B to confirm numerical parity with and without this gate for small M (e.g., M ∈ {1, 2, 8}, various top_k and E).
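
To see the effect of the bound on grid size, a small back-of-the-envelope sketch follows (all numbers are made up, and the padded-length formula is an assumption about typical moe_align_block_size behavior, not taken from this PR):

# Illustrative small-batch / many-expert configuration.
M, top_k, BLOCK_SIZE_M, num_experts = 2, 4, 16, 64

# Assumed worst-case padded length of sorted_token_ids (per-expert padding).
sorted_len = M * top_k + num_experts * (BLOCK_SIZE_M - 1)  # 968

# Bound from the comment above: EM_actual <= (M * top_k) * BLOCK_SIZE_M.
em = min(sorted_len, M * top_k * BLOCK_SIZE_M)  # min(968, 128) = 128

blocks_before = -(-sorted_len // BLOCK_SIZE_M)  # ceil(968 / 16) = 61 M-blocks
blocks_after = -(-em // BLOCK_SIZE_M)           # ceil(128 / 16) = 8 M-blocks
print(blocks_before, blocks_after)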

tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

127-127: Backend switched to Triton: ensure env/op availability

Switching fuse_rmsnorm to backend: triton is fine; please confirm torch.ops.auto_deploy.triton_rms_norm is registered across all build variants and CI images. Also update any user docs referencing FlashInfer as default.

tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py (2)

17-19: Docstring good; minor clarity

Docstring now reflects f32 weights. Consider noting output is cast back to input dtype.


31-31: Cast weight once; ensure eps in f32

Tiny cleanup: cast w to f32 at load and multiply directly; also make eps explicitly f32 to avoid mixed-precision surprises.

-    out = (w.to(tl.float32) * out).to(x.dtype)
+    w = w.to(tl.float32)
+    out = (w * out).to(x.dtype)

Optionally:

-    out = xf / tl.sqrt(var + eps)
+    out = xf / tl.sqrt(var + tl.full((), eps, tl.float32))
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (2)

155-168: Avoid duplicate registrations and stale docstring

  • Add skip_duplicates=True to prevent redundant pattern registrations across dtype variants.
  • Top-of-file docstring still says “using FlashInfer”; update to mention Triton.
-        search_fns = [
+        search_fns = [
             _rms_norm_pattern,
             _rms_norm_pattern_float32_weights,
         ]
         for search_fn in search_fns:
             for input_dtype, weight_dtype in configs:
                 register_ad_pattern(
                     search_fn=search_fn,
                     replace_fn=partial(_rms_norm_replacement, backend=self.config.backend),
                     patterns=patterns,
                     dummy_args=dummy_args(input_dtype, weight_dtype),
-                    op_ignore_types={},
+                    op_ignore_types={},
                     scalar_workaround={"eps": 1e-6},
+                    skip_duplicates=True,
                 )

1-1: Docstring mentions only FlashInfer

Update to “FlashInfer or Triton” to reflect current behavior.

-"""Graph transform to optimize RMSNorm execution using FlashInfer."""
+"""Graph transform to optimize RMSNorm execution using FlashInfer or Triton."""
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (1)

89-99: Also parametrize Triton on the baseline path

Add (“triton”, torch.ops.auto_deploy.triton_rms_norm) to ensure non-Nemotron path is covered with the new default backend.

 @pytest.mark.parametrize("eps", [1e-2, 1e-6])
 @pytest.mark.parametrize(
     "variant, op",
     [
         ("flashinfer", torch.ops.auto_deploy.flashinfer_rms_norm),
         ("torch", torch.ops.auto_deploy.torch_rmsnorm),
+        ("triton", torch.ops.auto_deploy.triton_rms_norm),
     ],
 )
 def test_rmsnorm_fusion(eps, variant, op):
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2b4e812 and d50402e.

📒 Files selected for processing (5)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (3) and 🧠 Learnings (1): same as in the previous review above.
🧬 Code graph analysis (3)
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (1)
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (1)
  • register_ad_pattern (99-182)
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py (1)
cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h (1)
  • top_k (233-233)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (2)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py (4)
  • forward (24-29)
  • forward (39-43)
  • forward (53-57)
  • RMSNorm (16-29)
tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py (1)
  • torch_rmsnorm (65-77)
🔇 Additional comments (5)
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (1)

49-67: Float32-weight pattern looks correct

Pattern matches Nemotron-H style (f32 weight multiply) and returns to input dtype. LGTM.

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (4)

26-38: NemotronH_RMSNorm test variant: LGTM

Covers float32-weight path and dtype round-trip.


41-48: Model wiring: LGTM

Conditional Nemotron-H selection is clear and keeps baseline identical.


57-87: Test harness is clean

Good reuse via _run_test; numeric check thresholds look fine.


102-105: Nemotron-H Triton-only test: LGTM

Asserts backend restriction and parity.

@tensorrt-cicd
Collaborator

PR_Github #22201 [ run ] completed with state SUCCESS. Commit: c741482
/LLM/main/L0_MergeRequest_PR pipeline #16740 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@suyoggupta
Collaborator Author

/bot reuse-pipeline

@tensorrt-cicd
Collaborator

PR_Github #22224 [ reuse-pipeline ] triggered by Bot. Commit: d3f3b99

@tensorrt-cicd
Collaborator

PR_Github #22224 [ reuse-pipeline ] completed with state SUCCESS. Commit: d3f3b99
Reusing PR_Github #22201 for commit d3f3b99

@suyoggupta
Collaborator Author

/bot reuse-pipeline

@tensorrt-cicd
Collaborator

PR_Github #22228 [ reuse-pipeline ] triggered by Bot. Commit: 7ea21cb

@suyoggupta suyoggupta enabled auto-merge (squash) October 23, 2025 02:23
@tensorrt-cicd
Collaborator

PR_Github #22228 [ reuse-pipeline ] completed with state SUCCESS. Commit: 7ea21cb
Reusing PR_Github #22201 for commit 7ea21cb

@suyoggupta
Collaborator Author

/bot reuse-pipeline

@tensorrt-cicd
Collaborator

PR_Github #22235 [ reuse-pipeline ] triggered by Bot. Commit: 804736e

@tensorrt-cicd
Collaborator

PR_Github #22235 [ reuse-pipeline ] completed with state SUCCESS. Commit: 804736e
Reusing PR_Github #22201 for commit 804736e

@suyoggupta suyoggupta merged commit 2956978 into NVIDIA:main Oct 23, 2025
5 checks passed
yufeiwu-nv pushed a commit to yufeiwu-nv/TensorRT-LLM that referenced this pull request Oct 24, 2025
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: junq <[email protected]>
Signed-off-by: Lucas Liebenwein <[email protected]>
Signed-off-by: Suyog Gupta <[email protected]>
Co-authored-by: Chenghao Zhang <[email protected]>
Co-authored-by: QI JUN <[email protected]>
Co-authored-by: Lucas Liebenwein <[email protected]>
Signed-off-by: yufeiwu-nv <[email protected]>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Nov 1, 2025
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: junq <[email protected]>
Signed-off-by: Lucas Liebenwein <[email protected]>
Signed-off-by: Suyog Gupta <[email protected]>
Co-authored-by: Chenghao Zhang <[email protected]>
Co-authored-by: QI JUN <[email protected]>
Co-authored-by: Lucas Liebenwein <[email protected]>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Nov 3, 2025
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: junq <[email protected]>
Signed-off-by: Lucas Liebenwein <[email protected]>
Signed-off-by: Suyog Gupta <[email protected]>
Co-authored-by: Chenghao Zhang <[email protected]>
Co-authored-by: QI JUN <[email protected]>
Co-authored-by: Lucas Liebenwein <[email protected]>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Nov 3, 2025
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: junq <[email protected]>
Signed-off-by: Lucas Liebenwein <[email protected]>
Signed-off-by: Suyog Gupta <[email protected]>
Co-authored-by: Chenghao Zhang <[email protected]>
Co-authored-by: QI JUN <[email protected]>
Co-authored-by: Lucas Liebenwein <[email protected]>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Nov 3, 2025
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: junq <[email protected]>
Signed-off-by: Lucas Liebenwein <[email protected]>
Signed-off-by: Suyog Gupta <[email protected]>
Co-authored-by: Chenghao Zhang <[email protected]>
Co-authored-by: QI JUN <[email protected]>
Co-authored-by: Lucas Liebenwein <[email protected]>
