[None][feat] Enable rms norm fusion for Nemotron MOE #8563
Conversation
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: junq <[email protected]>
Signed-off-by: Lucas Liebenwein <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Suyog Gupta <[email protected]>
Signed-off-by: Suyog Gupta <[email protected]>
Signed-off-by: Suyog Gupta <[email protected]>
/bot run

PR_Github #22201 [ run ] triggered by Bot. Commit:
📝 Walkthrough

The pull request updates the default post-load fusion backend from flashinfer to triton, optimizes the MoE kernel for small batch sizes, enhances RMS norm weight handling with float32 precision, adds a new float32-weight pattern matching function for RMS norm, and restructures RMS norm fusion tests to support triton backend validation.

Changes
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~22 minutes

Pre-merge checks and finishing touches
✅ Passed checks (3 passed)
Warning: Tools execution failed with the following error: Failed to run tools: Stream setup permanently failed: Ping-pong health check failed.
Actionable comments posted: 1
🧹 Nitpick comments (3)
tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py (1)
17-19: Consider expanding the docstring to follow Google-style format.

While the current docstring documents the key behavior, consider adding parameter descriptions to align with coding guidelines for Google-style docstrings.
Example expansion:
- """Rms norm kernel. - Forces weights to be in float32 for the kernel. - """ + """RMS normalization kernel. + + Applies RMS normalization with weights forced to float32 precision during + computation for numerical accuracy. + + Args: + output: Output tensor pointer. + input: Input tensor pointer. + weights: Weight tensor pointer. + eps: Epsilon value for numerical stability. + M: Number of rows. + N_COLS: Number of columns. + BLOCK_N: Block size for Triton kernel. + """As per coding guidelines
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (2)
29-29: Consider parameterizing the device.

The device is hard-coded to `"cuda"`. While this is acceptable in a test file under `singlegpu`, consider accepting a `device` parameter for consistency with the `TestModel` pattern, where other modules use `device="cuda"` inline.
37-37: Consider adding a comment about float32 weight precision.

The weight is explicitly cast to float32 during computation even though the module may be stored in float16. Adding an inline comment would clarify this intentional design choice.
For example:
```diff
-        return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)
+        # Use float32 for weights to match Nemotron-style RMSNorm computation
+        return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
- tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py (2 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (2 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Use only spaces, no tabs; indent with 4 spaces.
Files:
- tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py
- tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.
Files:
- tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py
- tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).
Files:
- tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py
- tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py
🧠 Learnings (1)
📚 Learning: 2025-10-20T17:09:21.560Z
Learnt from: nvchenghaoz
PR: NVIDIA/TensorRT-LLM#8469
File: tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py:180-182
Timestamp: 2025-10-20T17:09:21.560Z
Learning: In tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py, the _gated_rmsnorm_replacement function does not need to cast the output of torch.ops.auto_deploy.torch_rmsnorm_gated back to the input dtype, even though the custom op returns fp32. The dtype handling is managed elsewhere or the fp32 output is acceptable for downstream consumers.
Applied to files:
- tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py
- tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
🔇 Additional comments (7)
tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)
127-127: Verify Triton backend is production-ready for RMS norm fusion patterns.

This configuration change switches the post_load_fusion backend from flashinfer to triton. While this aligns with the PR objective to enable RMS norm fusion for Nemotron MOE via Triton, ensure that:
- The Triton backend is compatible with all RMS norm fusion patterns (especially Nemotron-style patterns mentioned in the PR)
- Related Triton kernel updates for float32 weight handling are properly integrated
- Unit tests validate this backend switch across different RMS norm patterns
- No performance regressions occur from the flashinfer → triton switch
Since only this configuration file is provided, please confirm that:
- Related Triton kernel and pattern-matching code changes are properly tested
- The RMS norm fusion tests cover both the float32 weight handling and Triton backend execution
- Any performance benchmarks comparing flashinfer vs. triton backend are acceptable
You can run the following to verify test coverage for RMS norm fusion with the new backend:
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py (1)
253-258: Verify variable scope and formula correctness.

Three concerns with this optimization:
1. Variable scope: `top_k` is referenced but not shown in the provided context. Verify it's in scope (function parameter or defined earlier).
2. Tensor size inconsistency: Line 250 uses `sorted_token_ids.numel()` while line 258 uses `sorted_token_ids.size(0)`. If `sorted_token_ids` is multidimensional, these return different values. Use consistent methods for correctness.
3. Formula verification: The calculation `A.size(0) * top_k * config["BLOCK_SIZE_M"]` seems large for an upper bound on expert assignments. Given that each of `A.size(0)` tokens routes to `top_k` experts, the natural upper bound would be `A.size(0) * top_k`, not scaled by `BLOCK_SIZE_M`. Please verify this formula is correct for the intended optimization.

Run the following script to verify the formula and variable usage:
tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py (1)
31-31: LGTM! Explicit float32 casting for weights aligns with PR objectives.

The change ensures weights are explicitly cast to float32 before multiplication, maintaining precision during the computation while correctly casting back to the input dtype. This aligns with the PR objective to enable float32 weight handling for Nemotron MOE RMS norm fusion.
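For reference, a minimal PyTorch sketch of the intended numerics (an illustration only, not the Triton kernel itself; the function name and the mean-of-squares formulation are assumptions):

```python
import torch


def rmsnorm_f32_weights_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Reference numerics: normalize in float32, scale by float32 weights, cast back to x.dtype."""
    xf = x.to(torch.float32)
    out = xf * torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (weight.to(torch.float32) * out).to(x.dtype)


# Tiny usage example with half-precision activations and weights.
x = torch.randn(4, 64, dtype=torch.float16)
w = torch.randn(64, dtype=torch.float16)
y = rmsnorm_f32_weights_reference(x, w)
assert y.dtype == torch.float16
```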
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (2)
49-67: LGTM! New pattern function correctly implements float32 weight handling.

The implementation properly matches Nemotron-style RMSNorm by explicitly casting weights to float32 before scaling. The apparent duplication with the existing `_rms_norm_pattern` is intentional and necessary: pattern matching requires exact computation graph matches, so two distinct patterns enable coverage of both weight handling approaches.
155-168: LGTM! Registration logic correctly handles both pattern variants.

The two-tier loop structure properly registers both pattern functions across all input/weight dtype configurations, enabling comprehensive matching of both standard and Nemotron-style RMSNorm implementations.
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (2)
89-105: LGTM! Well-structured test parameterization.

The test restructuring effectively covers multiple scenarios:
- Two epsilon values (1e-2, 1e-6)
- Multiple backend variants (flashinfer, torch, triton)
- Dedicated test for NemotronH variant with Triton backend
The comment on line 103 helpfully clarifies the Triton backend requirement for the Nemotron H RMSNorm variant.
57-57: Add a docstring for the helper function.

The `_run_test` helper function is missing a docstring. The coding guidelines require Google-style docstrings for functions. Apply this diff:

```diff
 def _run_test(model, op, variant):
+    """Run RMSNorm fusion test for a given model and backend variant.
+
+    Args:
+        model: The test model instance.
+        op: The expected operator after fusion.
+        variant: The backend variant name (e.g., 'flashinfer', 'torch', 'triton').
+    """
     def checker(gm):
         return any(is_op(n, op) for n in gm.graph.nodes)
```

As per coding guidelines.
Likely an incorrect or invalid review comment.
📝 Walkthrough

This pull request implements a triton-based backend for RMSNorm fusion with explicit float32 weight handling. Changes include switching the backend configuration, optimizing triton kernels for small batches, adding float32 weight casting in kernel logic, extending pattern matching for float32-weight variants, and expanding test coverage with Nemotron variant support.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
participant Config as Config Layer
participant PatternMatch as Pattern Matching
participant KernelFP32 as Float32<br/>RMSNorm Kernel
participant KernelDefault as Default<br/>RMSNorm Kernel
Config->>PatternMatch: Backend: triton
PatternMatch->>PatternMatch: Try _rms_norm_pattern_float32_weights
alt Float32 Weights Detected
PatternMatch->>KernelFP32: Cast weights to float32<br/>before multiplication
KernelFP32-->>PatternMatch: Optimized scaled output
else Default Weights
PatternMatch->>KernelDefault: Original pattern logic
KernelDefault-->>PatternMatch: Scaled output
end
PatternMatch-->>Config: Fusion complete
```

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~22 minutes

The changes span multiple functional areas (configuration, kernel optimization, pattern matching logic, and test coverage) with moderate logic density per file. While individual edits are straightforward (type casting, batch size gating, pattern registration), the heterogeneity across components and the introduction of the new float32-weight variant requires separate reasoning for each cohort.

Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py (1)
1-1: Missing NVIDIA Apache-2.0 header (2025)

Add the required header at the top of the file.

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```

tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py (1)

1-1: Missing NVIDIA Apache-2.0 header (2025)

Add the required header at the top.

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```

tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (1)

1-1: Missing NVIDIA Apache-2.0 header (2025)

Add the required header.

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (1)

1-1: Missing NVIDIA Apache-2.0 header (2025)

Please add the standard header to this test file as well (policy applies to .py).

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```
🧹 Nitpick comments (7)
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py (1)
253-258: EM upper-bound is safe; clarify assumption and guard when not capturing
- The bound EM = min(max_padded, M * top_k * BLOCK_SIZE_M) is a valid upper bound (actual padded tokens ≤ T * BLOCK_SIZE_M, where T = M * top_k), so it won't skip real work. The current comment about "top_ids … unique" is misleading.
- Suggest clearer comment and (optionally) a debug assert when not under CUDA graph capture.
Apply this diff to clarify and add a guarded assert:
```diff
-    if A.size(0) < config["BLOCK_SIZE_M"]:
-        # optimize for small batch_size.
-        # We assume that top_ids of each token is unique,
-        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
-        # and we can skip some invalid blocks.
-        EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
+    if A.size(0) < config["BLOCK_SIZE_M"]:
+        # Optimize for small batches by shrinking the launch grid.
+        # Let T = M * top_k. After per-expert padding, the total tokens processed
+        # EM_actual <= T * BLOCK_SIZE_M and <= sorted_token_ids.size(0).
+        # Using EM = min(sorted_token_ids.size(0), T * BLOCK_SIZE_M) preserves correctness
+        # while cutting empty blocks when E >> M.
+        EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
+        # Optional safety check (avoid host sync during CUDA graph capture):
+        # if not torch.cuda.is_current_stream_capturing():
+        #     EM_actual = int(num_tokens_post_padded.item())
+        #     assert EM >= EM_actual
```

Please run a quick A/B to confirm numerical parity with and without this gate for small M (e.g., M ∈ {1, 2, 8}, various top_k and E).
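As a separate sanity check on the bound itself, a small standalone sketch with hypothetical sizes (no TensorRT-LLM dependencies; the per-expert round-up-to-BLOCK padding scheme and the allocation size of sorted_token_ids are assumptions) exercises only the arithmetic:

```python
import math
import random

BLOCK = 16  # assumed BLOCK_SIZE_M
for _ in range(1000):
    M = random.randint(1, BLOCK - 1)        # small batch, as gated in the diff above
    top_k = random.randint(1, 8)
    E = random.randint(1, 256)               # number of experts
    # Randomly assign M * top_k token slots to experts.
    counts = [0] * E
    for _ in range(M * top_k):
        counts[random.randrange(E)] += 1
    # Padded length: each non-empty expert's count rounded up to a multiple of BLOCK.
    em_actual = sum(math.ceil(c / BLOCK) * BLOCK for c in counts if c > 0)
    max_padded = M * top_k + E * (BLOCK - 1)  # assumed allocation size of sorted_token_ids
    em_bound = min(max_padded, M * top_k * BLOCK)
    assert em_actual <= em_bound, (M, top_k, E, em_actual, em_bound)
```

Under these assumptions the bound never undercounts, since each non-empty expert contributes at most count * BLOCK padded slots.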
tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)
127-127: Backend switched to Triton: ensure env/op availability

Switching fuse_rmsnorm to backend: triton is fine; please confirm torch.ops.auto_deploy.triton_rms_norm is registered across all build variants and CI images. Also update any user docs referencing FlashInfer as default.
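A quick availability check might look like the sketch below; the import that triggers op registration is an assumption and may differ per build:

```python
# Hedged sketch: assumes importing the auto_deploy package registers the custom ops.
import torch

import tensorrt_llm._torch.auto_deploy  # noqa: F401  (assumed registration side effect)

try:
    op = torch.ops.auto_deploy.triton_rms_norm
except (AttributeError, RuntimeError) as e:
    raise SystemExit(f"triton_rms_norm is not registered in this build: {e}")
print("triton_rms_norm available:", op)
```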
tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py (2)
17-19: Docstring good; minor clarity

Docstring now reflects f32 weights. Consider noting output is cast back to input dtype.

31-31: Cast weight once; ensure eps in f32

Tiny cleanup: cast w to f32 at load and multiply directly; also make eps explicitly f32 to avoid mixed-precision surprises.

```diff
-    out = (w.to(tl.float32) * out).to(x.dtype)
+    w = w.to(tl.float32)
+    out = (w * out).to(x.dtype)
```

Optionally:

```diff
-    out = xf / tl.sqrt(var + eps)
+    out = xf / tl.sqrt(var + tl.full((), eps, tl.float32))
```

tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (2)
155-168: Avoid duplicate registrations and stale docstring
- Add skip_duplicates=True to prevent redundant pattern registrations across dtype variants.
- Top-of-file docstring still says “using FlashInfer”; update to mention Triton.
```diff
-        search_fns = [
+        search_fns = [
             _rms_norm_pattern,
             _rms_norm_pattern_float32_weights,
         ]
         for search_fn in search_fns:
             for input_dtype, weight_dtype in configs:
                 register_ad_pattern(
                     search_fn=search_fn,
                     replace_fn=partial(_rms_norm_replacement, backend=self.config.backend),
                     patterns=patterns,
                     dummy_args=dummy_args(input_dtype, weight_dtype),
-                    op_ignore_types={},
+                    op_ignore_types={},
                     scalar_workaround={"eps": 1e-6},
+                    skip_duplicates=True,
                 )
```
1-1: Docstring mentions only FlashInfer

Update to "FlashInfer or Triton" to reflect current behavior.
-"""Graph transform to optimize RMSNorm execution using FlashInfer.""" +"""Graph transform to optimize RMSNorm execution using FlashInfer or Triton."""tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (1)
89-99: Also parametrize Triton on the baseline path

Add ("triton", torch.ops.auto_deploy.triton_rms_norm) to ensure the non-Nemotron path is covered with the new default backend.
```diff
 @pytest.mark.parametrize("eps", [1e-2, 1e-6])
 @pytest.mark.parametrize(
     "variant, op",
     [
         ("flashinfer", torch.ops.auto_deploy.flashinfer_rms_norm),
         ("torch", torch.ops.auto_deploy.torch_rmsnorm),
+        ("triton", torch.ops.auto_deploy.triton_rms_norm),
     ],
 )
 def test_rmsnorm_fusion(eps, variant, op):
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
- tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py (2 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (2 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Use only spaces, no tabs; indent with 4 spaces.
Files:
- tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.
Files:
- tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).
Files:
- tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py
🧠 Learnings (1)
📚 Learning: 2025-10-20T17:09:21.560Z
Learnt from: nvchenghaoz
PR: NVIDIA/TensorRT-LLM#8469
File: tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py:180-182
Timestamp: 2025-10-20T17:09:21.560Z
Learning: In tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py, the _gated_rmsnorm_replacement function does not need to cast the output of torch.ops.auto_deploy.torch_rmsnorm_gated back to the input dtype, even though the custom op returns fp32. The dtype handling is managed elsewhere or the fp32 output is acceptable for downstream consumers.
Applied to files:
- tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/rms_norm.py
🧬 Code graph analysis (3)
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (1)
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (1)
register_ad_pattern(99-182)
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py (1)
cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h (1)
top_k(233-233)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (2)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py (4)
forward (24-29), forward (39-43), forward (53-57), RMSNorm (16-29)

tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py (1)
torch_rmsnorm(65-77)
🔇 Additional comments (5)
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (1)
49-67: Float32-weight pattern looks correct

Pattern matches Nemotron-H style (f32 weight multiply) and returns to input dtype. LGTM.
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py (4)
26-38: NemotronH_RMSNorm test variant: LGTM

Covers float32-weight path and dtype round-trip.

41-48: Model wiring: LGTM

Conditional Nemotron-H selection is clear and keeps baseline identical.

57-87: Test harness is clean

Good reuse via _run_test; numeric check thresholds look fine.

102-105: Nemotron-H Triton-only test: LGTM

Asserts backend restriction and parity.
PR_Github #22201 [ run ] completed with state

/bot reuse-pipeline

PR_Github #22224 [ reuse-pipeline ] triggered by Bot. Commit:

PR_Github #22224 [ reuse-pipeline ] completed with state

/bot reuse-pipeline

PR_Github #22228 [ reuse-pipeline ] triggered by Bot. Commit:

PR_Github #22228 [ reuse-pipeline ] completed with state

/bot reuse-pipeline

PR_Github #22235 [ reuse-pipeline ] triggered by Bot. Commit:

PR_Github #22235 [ reuse-pipeline ] completed with state
Signed-off-by: Chenghao Zhang <[email protected]> Signed-off-by: junq <[email protected]> Signed-off-by: Lucas Liebenwein <[email protected]> Signed-off-by: Suyog Gupta <[email protected]> Co-authored-by: Chenghao Zhang <[email protected]> Co-authored-by: QI JUN <[email protected]> Co-authored-by: Lucas Liebenwein <[email protected]> Signed-off-by: yufeiwu-nv <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]> Signed-off-by: junq <[email protected]> Signed-off-by: Lucas Liebenwein <[email protected]> Signed-off-by: Suyog Gupta <[email protected]> Co-authored-by: Chenghao Zhang <[email protected]> Co-authored-by: QI JUN <[email protected]> Co-authored-by: Lucas Liebenwein <[email protected]>
- Add pattern matching for Nemotron-style RMSNorm (a sketch of this norm variant is included below).
- Update the Triton kernel to use float32 for weights.
- Update the unit tests.
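For context, the Nemotron-style RMSNorm shape that the new pattern targets looks roughly like the minimal, self-contained module below. It is based on the float32-weight line quoted in the review; the class name is illustrative, not the actual test class.

```python
import torch
from torch import nn


class NemotronStyleRMSNorm(nn.Module):
    """Minimal sketch of an RMSNorm variant that applies weights in float32."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Weights are cast to float32 before scaling; the result returns to the input dtype.
        return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)
```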
Summary by CodeRabbit

Performance

- Optimized the Triton MoE kernel launch for small batch sizes.

Improvements

- Switched the default post-load fusion backend from flashinfer to triton and added float32 weight handling for RMS norm fusion.

Tests

- Restructured RMS norm fusion tests to cover the triton backend and Nemotron-style RMSNorm variants.