[None][feat] MNNVLAllreduce Kernel Refactor #8018
Conversation
Signed-off-by: Shiyu Li <[email protected]>
📝 Walkthrough

Introduces Lamport-based synchronization utilities. Replaces the legacy two-shot allreduce and RMSNorm paths with fused oneshot/twoshot allreduce entry points and a unified parameter struct. Updates the Torch C++ op, Python custom-op registration, distributed workspace management, and tests. Removes the old two-shot CUDA file. Minor binding/formatting edits.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Host as Host (thop::mnnvlFusionAllReduce)
    participant Kern as oneshot_allreduce_fusion_op
    participant Dev as CUDA Grid
    participant Lam as LamportFlags
    Host->>Kern: Build AllReduceFusionParams
    Kern->>Dev: Launch fused oneshot kernel (dtype, WORLD_SIZE)
    rect rgba(221,238,255,0.4)
        note over Dev: Stage-less fused allreduce
        Dev->>Lam: cta_arrive()
        Lam-->>Dev: flag advanced / current buf index
        Dev->>Dev: Allreduce + optional RMSNorm (if enabled)
        Dev->>Lam: wait_and_update()
    end
    Dev-->>Host: outputs (output, residual_out)
```
```mermaid
sequenceDiagram
    autonumber
    participant Host as Host (thop::mnnvlFusionAllReduce)
    participant Kern as twoshot_allreduce_fusion_op
    participant Grid1 as Kernel: SCATTER
    participant Grid2 as Kernel: BROADCAST
    participant Lam as LamportFlags
    participant RMS as (optional) RMSNorm Kernel
    Host->>Kern: Params (incl. rmsnorm_fusion, buffer_flags)
    Kern->>Grid1: Launch SCATTER
    rect rgba(232,246,221,0.5)
        note over Grid1,Lam: Two-shot Stage 0
        Grid1->>Lam: cta_arrive()
        Lam-->>Grid1: current/dirty buf
        Grid1->>Grid1: Scatter + partial reduce
        Grid1->>Lam: wait_and_update()
    end
    Kern->>Grid2: Launch BROADCAST
    rect rgba(255,243,205,0.5)
        note over Grid2,Lam: Two-shot Stage 1
        Grid2->>Lam: cta_arrive()
        Grid2->>Grid2: Broadcast + finalize
        Grid2->>Lam: wait_and_update()
    end
    alt rmsnorm_fusion
        Kern->>RMS: Launch RMSNorm fused path
        RMS-->>Host: output, residual_out
    else
        Grid2-->>Host: output, residual_out
    end
```
```mermaid
sequenceDiagram
    autonumber
    participant Py as Python (distributed.ops)
    participant WS as Workspace Manager
    participant Cpp as Torch C++ Op
    participant Kern as Fused Kernel(s)
    Py->>WS: get_or_scale_allreduce_mnnvl_workspace(dtype, size?)
    WS-->>Py: {mcast_buffer, comm_buffer, buffer_flags, size}
    Py->>Cpp: mnnvl_fusion_allreduce(input, residual?, gamma?, epsilon?, buffers, flags, rmsnorm)
    Cpp->>Kern: oneshot or twoshot dispatch
    Kern-->>Cpp: [output, residual_out]
    Cpp-->>Py: [output, residual_out]
```
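For orientation, a minimal host-side sketch of the flow in the last diagram is shown below. The module path, op namespace, argument order, and workspace attribute names are assumptions taken from the diagram, not the exact registered API.

```python
# Sketch only: names below are illustrative, not the final API surface.
import torch
from tensorrt_llm._torch.distributed import ops as dist_ops  # assumed module path


def fused_allreduce_rmsnorm(hidden, residual, gamma, eps=1e-6):
    # 1) Obtain (or grow) the MNNVL workspace for this dtype/size.
    workspace = dist_ops.get_or_scale_allreduce_mnnvl_workspace(
        dtype=hidden.dtype, size=hidden.numel())
    # 2) Call the fused op; oneshot vs. twoshot is chosen inside the C++ op.
    output, residual_out = torch.ops.trtllm.mnnvl_fusion_allreduce(
        hidden, residual, gamma, eps,
        workspace["comm_buffer"], workspace["buffer_flags"],
        True)  # rmsnorm_fusion
    return output, residual_out
```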
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60–90 minutes
Actionable comments posted: 10
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (7)
cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp (2)
158-162: Fix OOB pointer arithmetic in fabric-handle exchange (double-scaling by sizeof).

Pointer arithmetic on `exphndl` multiplies by `sizeof(CUmemFabricHandle)` twice, causing out-of-bounds writes and incorrect allgather offsets. Use element indexing or byte pointers. Apply this diff:

```diff
-    cudaMallocHost(&exphndl, mGroupSize * sizeof(CUmemFabricHandle));
-    memcpy(exphndl + mGroupRank * sizeof(CUmemFabricHandle), &myhndl, sizeof(CUmemFabricHandle));
-    mGroupComm.allgather(
-        exphndl + mGroupRank * sizeof(CUmemFabricHandle), exphndl, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR);
+    cudaMallocHost(&exphndl, mGroupSize * sizeof(CUmemFabricHandle));
+    memcpy(&exphndl[mGroupRank], &myhndl, sizeof(CUmemFabricHandle));
+    mGroupComm.allgather(
+        reinterpret_cast<char*>(&exphndl[mGroupRank]),
+        reinterpret_cast<char*>(exphndl),
+        sizeof(CUmemFabricHandle),
+        mpi::MpiType::kCHAR);
```
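Context for the double-scaling claim (plain C++ pointer semantics, not project code): for a typed pointer, `p + i` already advances by `i * sizeof(*p)` bytes, so adding `i * sizeof(T)` to a `T*` jumps `sizeof(T)` times too far. A tiny standalone illustration, using a stand-in struct:

```cpp
#include <cassert>
#include <cstdint>

struct Handle { unsigned char data[64]; };  // stand-in for CUmemFabricHandle

int main()
{
    Handle buf[4];
    // Element indexing: &buf[1] is exactly sizeof(Handle) bytes past &buf[0].
    auto const step = reinterpret_cast<std::uintptr_t>(&buf[1]) - reinterpret_cast<std::uintptr_t>(&buf[0]);
    assert(step == sizeof(Handle));
    // buf + 1 * sizeof(Handle) would instead land 64 elements away,
    // which is the out-of-bounds pattern flagged above.
    return 0;
}
```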
143-149: Avoid power-of-two-only rounding for driver-provided granularities.

`roundUp()` assumes power-of-two granularity; `alloc_granularity` and `mc_granularity` are not guaranteed to be powers of two. This can under/over-allocate and break address mapping. Suggested safe rounding helper (outside this range):

```cpp
// Replace roundUp() with arithmetic rounding that doesn't assume power-of-two.
inline size_t roundUp(size_t val, size_t gran)
{
    if (gran == 0)
    {
        return val;
    }
    return ((val + gran - 1) / gran) * gran;
}
```

Then the existing calls here remain correct.
cpp/tensorrt_llm/pybind/runtime/bindings.cpp (3)
200-206: Compile error: pybind class helper typo (`py::classh` → `py::class_`).

`py::classh` is invalid and will not compile.

```diff
-    py::classh<tr::ITensor, PyITensor>(m, "ITensor").def(py::init());
+    py::class_<tr::ITensor, PyITensor>(m, "ITensor").def(py::init());
```
224-227: Compile error: pybind class helper typo (`py::classh` → `py::class_`).

```diff
-    py::classh<tr::BufferManager>(m, "BufferManager")
+    py::class_<tr::BufferManager>(m, "BufferManager")
```
229-239: Compile error: pybind class helper typo (`py::classh` → `py::class_`).

```diff
-    py::classh<tr::TllmRuntime>(m, "TllmRuntime")
+    py::class_<tr::TllmRuntime>(m, "TllmRuntime")
```

cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.h (2)
2-2: Update copyright year.

Header should reflect the current year per guidelines.

```diff
- * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
```

As per coding guidelines.
16-16: Use include guards instead of pragma once.

Project guideline requires TRTLLM_* include guards for headers.

```diff
-#pragma once
+#ifndef TRTLLM_MNNVLALLREDUCEKERNELS_H
+#define TRTLLM_MNNVLALLREDUCEKERNELS_H
```

And add at file end:

```diff
-} // namespace tensorrt_llm::kernels::mnnvl
+} // namespace tensorrt_llm::kernels::mnnvl
+
+#endif // TRTLLM_MNNVLALLREDUCEKERNELS_H
```

As per coding guidelines.
🧹 Nitpick comments (17)
cpp/tensorrt_llm/common/lamportUtils.cuh (6)
34-56: Type naming and constants style.

- Type `fp16_bit_cast` should be `Fp16BitCast` (type names CamelCase).
- Constant `NEGZERO_FP16` should be `kNEGZERO_FP16` (constants use k-prefixed UPPER_SNAKE_CASE).

As per coding guidelines. Example:

```diff
-constexpr uint16_t NEGZERO_FP16 = 0x8000U;
+constexpr uint16_t kNEGZERO_FP16 = 0x8000U;

-template <typename T>
-union fp16_bit_cast
+template <typename T>
+union Fp16BitCast
 {
     ...
 };
```

And update usages accordingly.
58-75: Use renamed helpers and ensure `constexpr` paths compile on host.

Replace `fp16_bit_cast` and the updated constant names in `negZero()`; also ensure host compilation of the half/bfloat unions is allowed in your toolchain. Example:

```diff
-    return fp16_bit_cast<T>(NEGZERO_FP16).fp;
+    return Fp16BitCast<T>(kNEGZERO_FP16).fp;
```
222-229: Optional: remove unused `grid` in CGA path.

`cg::grid_group grid = cg::this_grid();` is unused in `cta_arrive()`. Consider removing it to avoid warnings.
255-261: Avoid magic number 3; use the declared constant for buffer count.

Hardcoding `3` weakens maintainability.

```diff
-        flag_ptr[0] = {(mCurrentIndex + 1) % 3, // Current index
+        flag_ptr[0] = {(mCurrentIndex + 1) % LamportBufferLayout::num_lamport_buffers, // Current index
             mCurrentIndex,                     // Dirty index
             mCurBufferLayout.bytes_per_buffer, // Buffer size
             mCurBufferLayout.num_stages};      // Dirty - Number of stages
```
128-141: Double-check `getTotalBytes()` math for non-divisible stage sizes.

`bytes_per_buffer / num_stages` truncates. If `bytes_per_buffer` isn't divisible by `num_stages`, the total size becomes inconsistent. Consider storing and using the per-stage size explicitly.
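A minimal sketch of that suggestion, with type and field names invented for illustration (the real LamportBufferLayout members may differ):

```cpp
#include <cstddef>

// Hypothetical layout helper: store the per-stage size explicitly instead of
// re-deriving it by integer division, so totals stay consistent even when
// bytes_per_buffer is not a multiple of num_stages.
struct StageLayout
{
    std::size_t bytesPerStage; // rounded-up stage size, fixed at construction
    std::size_t numStages;

    static StageLayout make(std::size_t bytesPerBuffer, std::size_t numStages)
    {
        std::size_t const perStage = (bytesPerBuffer + numStages - 1) / numStages;
        return StageLayout{perStage, numStages};
    }

    std::size_t getTotalBytes() const
    {
        return bytesPerStage * numStages; // >= bytesPerBuffer by construction
    }
};
```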
273-285: Minor: naming and typos.

- `clear_boundry` → `clear_boundary`.
- Consider `constexpr auto elems = sizeof(PackedType)/sizeof(float);` if specializing `getPackedLamportInit` to a matching element type is required elsewhere.

tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py (1)
112-114: Scope env var to test to avoid bleed-through.

Setting TLLM_TEST_MNNVL globally can impact later tests in the same process. Prefer a scoped set/unset around the call or use a fixture to restore the env. Apply this minimal change:

```diff
-    elif strategy == AllReduceStrategy.MNNVL:
-        os.environ["TLLM_TEST_MNNVL"] = "1"
+    elif strategy == AllReduceStrategy.MNNVL:
+        prev = os.environ.get("TLLM_TEST_MNNVL")
+        os.environ["TLLM_TEST_MNNVL"] = "1"
+        try:
+            pass  # set just before barrier; restored after
+        finally:
+            if prev is None:
+                os.environ.pop("TLLM_TEST_MNNVL", None)
+            else:
+                os.environ["TLLM_TEST_MNNVL"] = prev
```
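An alternative sketch using pytest's monkeypatch fixture, which restores or unsets the variable automatically at teardown (the fixture and test names here are illustrative, not the test file's actual structure):

```python
import pytest


@pytest.fixture
def mnnvl_env(monkeypatch):
    # monkeypatch.setenv records the previous value and restores it when the
    # test finishes, so TLLM_TEST_MNNVL never leaks into later tests.
    monkeypatch.setenv("TLLM_TEST_MNNVL", "1")
    yield


def test_mnnvl_allreduce_strategy(mnnvl_env):
    # ... run the MNNVL allreduce path here ...
    pass
```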
24-25: Remove unused NVML include.

nvml.h is not used in this TU and adds an unnecessary dependency.

```diff
-#include <nvml.h>
```
308-314: Drop unused variable.

clear_token_stride is computed but never used, causing warnings.

```diff
-    int clear_token_stride = gridDim.x;
```
696-894: RMSNorm Lamport kernel looks correct; consider minor readability nits.

The volatile single-lane validity check assumes 16B atomic writes; acceptable per design. Consider a brief comment near the value.x validity check to document the 16B write assumption.
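For example, the suggested documentation could look roughly like the snippet below; the identifier names are assumptions and may not match the kernel exactly.

```cuda
// Illustrative only: names are assumptions, not the kernel's identifiers.
__device__ inline bool slotIsValid(float4 value)
{
    // Assumption documented here: the producer writes the whole slot with a
    // single 16B (float4) store, so a non-sentinel value.x implies that
    // value.y/.z/.w from that same store are visible as well.
    // -0.0f (bit pattern 0x80000000) is the "not yet written" sentinel.
    return __float_as_uint(value.x) != 0x80000000u;
}
```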
cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.h (1)
27-57: Document new public types and params.

Enum and struct should have Doxygen comments (//!, //!<) to clarify stage semantics and buffer/flag fields.

```diff
-enum MNNVLTwoShotStage : uint8_t
+//! \brief Two-shot pipeline stages for Lamport-managed buffers.
+enum MNNVLTwoShotStage : uint8_t
 {
-    SCATTER = 0,
-    BROADCAST = 1,
-    NUM_STAGES = 2,
+    SCATTER = 0,   //!< Stage 0: rank-scatter to destination shards
+    BROADCAST = 1, //!< Stage 1: broadcast reduced shards
+    NUM_STAGES = 2 //!< Total number of stages
 };

-struct AllReduceFusionParams
+//! \brief Parameters for MNNVL fused all-reduce (oneshot/twoshot) and optional RMSNorm fusion.
+struct AllReduceFusionParams
 {
-    // Environmental/Aux data
+    //! \name Environment/Aux
+    //! @{
     int nranks;
     int rank;
     nvinfer1::DataType dtype;
     int num_tokens;
     int token_dim;
-    void** buffer_ptrs_dev;
+    void** buffer_ptrs_dev; //!< Device array of per-rank UC pointers
     void* buffer_ptr_local;
     void* multicast_ptr;
-    uint32_t* buffer_flags;
-    bool rmsnorm_fusion;
+    uint32_t* buffer_flags; //!< Lamport flags buffer
+    bool rmsnorm_fusion;    //!< Enable residual+RMSNorm fusion
+    //! @}

-    // Input and output data
+    //! \name IO
+    //! @{
     void const* input;
     void const* residual_in;
     void const* gamma;
     double epsilon;
     void* residual_out;
     void* output;
     cudaStream_t stream;
+    //! @}
 };
```

As per coding guidelines.
cpp/tensorrt_llm/thop/allreduceOp.cpp (2)
1159-1164: Validate contiguity of optional inputs when fused.

When rmsnorm_fusion is true, also check gamma and residual_in contiguity to prevent unexpected strided reads.

```diff
-    TORCH_CHECK(input.is_contiguous(), "[mnnvlFusionAllReduce] input must be contiguous");
+    TORCH_CHECK(input.is_contiguous(), "[mnnvlFusionAllReduce] input must be contiguous");
+    if (rmsnorm_fusion) {
+        TORCH_CHECK(gamma->is_contiguous(), "[mnnvlFusionAllReduce] gamma must be contiguous");
+        TORCH_CHECK(residual_in->is_contiguous(), "[mnnvlFusionAllReduce] residual must be contiguous");
+    }
```
1203-1211: Name the oneshot threshold to avoid magic numbers.

Helps keep Python and C++ in sync.

```diff
-    // FIXME: Find a better heuristic
-    if (num_tokens * hidden_dim * allreduce_params.nranks * input.itemsize() <= 64 * 1024 * 8)
+    // FIXME: Find a better heuristic
+    constexpr size_t kOneShotBytesThreshold = 64 * 1024 * 8; // keep in sync with Python helper
+    if (num_tokens * hidden_dim * allreduce_params.nranks * input.itemsize() <= kOneShotBytesThreshold)
```

Please add a shared constant (header) if this heuristic stabilizes.
tensorrt_llm/_torch/distributed/ops.py (4)
118-120: MPI communicator leak (minor).

The split communicator is never freed; it can accumulate in long-lived processes.

```diff
-    comm.Barrier()
+    comm.Barrier()
+    try:
+        pass
+    finally:
+        try:
+            comm.Free()
+        except Exception:
+            pass
```
124-131: Document buffer_flags layout; consider a named struct.

The packed uint32 layout is opaque. At minimum, add a brief docstring or constants for the indices.

```diff
-    buffer_flags = torch.tensor(
-        [0, 2, buffer_size_bytes, 0, *num_bytes_to_clear, 0],
+    # Layout: [cur_idx, dirty_idx, bytes_per_buf, dirty_num_stages, numBytesToClear[4], access_count_ptr]
+    buffer_flags = torch.tensor(
+        [0, 2, buffer_size_bytes, 0, *num_bytes_to_clear, 0],
         dtype=torch.uint32,
         device=torch.device("cuda", mapping.local_rank),
     )
```

Based on learnings.
375-381: Shorten exception message (Ruff TRY003).

Keep messages concise to satisfy lint and readability.

```diff
-            raise ValueError(
-                f"MNNVL all reduce only supports dtype {MNNVLAllReduce.get_supported_dtypes()} and without cp."
-            )
+            raise ValueError("MNNVL allreduce supports dtypes "
+                             f"{MNNVLAllReduce.get_supported_dtypes()} and requires cp disabled.")
```
443-447: Shorten overflow error (Ruff TRY003).

Trim the error message; include the key numbers only.

```diff
-            raise ValueError(
-                f"[MNNVL AllReduce] Shard ({num_tokens}, {hidden_dim}), TP Size {self.mapping.tp_size}: Required Workspace size {workspace_size_bytes} bytes is too large!"
-            )
+            raise ValueError(f"[MNNVL AllReduce] Required workspace {workspace_size_bytes} bytes exceeds uint32 limits "
+                             f"for shard ({num_tokens}, {hidden_dim}), TP {self.mapping.tp_size}.")
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
- cpp/tensorrt_llm/common/lamportUtils.cuh (1 hunks)
- cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.cu (1 hunks)
- cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.h (1 hunks)
- cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu (0 hunks)
- cpp/tensorrt_llm/nanobind/runtime/bindings.cpp (2 hunks)
- cpp/tensorrt_llm/pybind/runtime/bindings.cpp (2 hunks)
- cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp (1 hunks)
- cpp/tensorrt_llm/thop/allreduceOp.cpp (3 hunks)
- tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py (1 hunks)
- tensorrt_llm/_torch/distributed/ops.py (17 hunks)
- tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py (3 hunks)
💤 Files with no reviewable changes (1)
- cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu
🧰 Additional context used
📓 Path-based instructions (8)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}: Namespace closing braces must include a trailing comment with the namespace name (e.g., '} // namespace foo').
Prefer const or constexpr variables over #define for constants.
Declare variables that are not modified after initialization as const.
Avoid magic literals in code; except for 0, nullptr, true, false. Use named constants for comparisons and logic.
Use Allman brace style for formatting.
Place the semicolon of an empty for/while loop on a new line.
Bodies of switch/while/do-while/for must be compound statements (brace-delimited), and if/else must always be followed by brace-delimited statements.
Type names (e.g., classes) must be CamelCase starting with an uppercase letter (e.g., FooBar).
Local variables, methods, and namespaces use lowerCamelCase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not in an anonymous namespace must be lowerCamelCase prefixed with 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number globals that are static or in an anonymous namespace use lowerCamelCase prefixed with 's' (e.g., sMutableStaticGlobal).
Locally visible static variables use lowerCamelCase with 's' prefix (e.g., static std::once_flag sFlag).
Private/protected member variables use 'm' prefix with CamelCase (e.g., mNbFooValues). Public members may omit, but 'm' is encouraged for clarity.
Constants (enums, global constants, static constants, and function-scope magic/literal constants) use uppercase SNAKE_CASE with 'k' prefix (e.g., kDIGIT_NUM).
Function-scope constants that are not magic numbers or literals are named like non-constant variables (e.g., bool const pass = a && b).
If macros are necessary, name them in UPPER_SNAKE_CASE (e.g., FOO_VERSION) and prefer constants over #define.
Use LLVM clang-format; wrap lines at a maximum of 120 columns; use '// clang-format off/on' sparingly with justification.
Use smart pointers for heap allocations; prefer unique_ptr for sole ownership, shared_ptr for shared...
Files:
- cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp
- cpp/tensorrt_llm/thop/allreduceOp.cpp
- cpp/tensorrt_llm/common/lamportUtils.cuh
- cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.cu
- cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.h
- cpp/tensorrt_llm/pybind/runtime/bindings.cpp
- cpp/tensorrt_llm/nanobind/runtime/bindings.cpp
**/*.{cpp,cxx,cc,cu,h,hpp,hh,hxx,cuh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
C++ filenames should be lowerCamelCase (first letter lowercase) and must be case-insensitive unique within a compilation target.
Files:
- cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp
- cpp/tensorrt_llm/thop/allreduceOp.cpp
- cpp/tensorrt_llm/common/lamportUtils.cuh
- cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.cu
- cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.h
- cpp/tensorrt_llm/pybind/runtime/bindings.cpp
- cpp/tensorrt_llm/nanobind/runtime/bindings.cpp
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Use only spaces, no tabs; indent with 4 spaces.
Files:
- cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp
- tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py
- cpp/tensorrt_llm/thop/allreduceOp.cpp
- cpp/tensorrt_llm/common/lamportUtils.cuh
- cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.cu
- cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.h
- cpp/tensorrt_llm/pybind/runtime/bindings.cpp
- tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
- tensorrt_llm/_torch/distributed/ops.py
- cpp/tensorrt_llm/nanobind/runtime/bindings.cpp
**/*.{h,hpp,hh,hxx,cpp,cxx,cc}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc}: Prefer anonymous namespaces over 'static' for internal linkage of functions.
All templates (class/function/member/static) must be instantiated at least once; non-POD classes should have private data members.
Files:
- cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp
- cpp/tensorrt_llm/thop/allreduceOp.cpp
- cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.h
- cpp/tensorrt_llm/pybind/runtime/bindings.cpp
- cpp/tensorrt_llm/nanobind/runtime/bindings.cpp
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).
Files:
- cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp
- tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py
- cpp/tensorrt_llm/thop/allreduceOp.cpp
- cpp/tensorrt_llm/common/lamportUtils.cuh
- cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.cu
- cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.h
- cpp/tensorrt_llm/pybind/runtime/bindings.cpp
- tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
- tensorrt_llm/_torch/distributed/ops.py
- cpp/tensorrt_llm/nanobind/runtime/bindings.cpp
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.
Files:
- tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py
- tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
- tensorrt_llm/_torch/distributed/ops.py
**/*.{h,hpp,hh,hxx,cuh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Use include guards named 'TRTLLM_<FILE_NAME_IN_CAPS_WITH_UNDERSCORES>_H' (no leading or trailing underscore; directory names excluded).
Files:
- cpp/tensorrt_llm/common/lamportUtils.cuh
- cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.h
**/*.{h,hpp,hh,hxx}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Document new class interfaces and function prototypes with Doxygen; use //! for single-line and //!< for members.
Files:
cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.h
🧠 Learnings (10)
📚 Learning: 2025-09-23T15:01:00.070Z
Learnt from: nv-lschneider
PR: NVIDIA/TensorRT-LLM#7910
File: cpp/tensorrt_llm/kernels/nccl_device/config.cu:15-17
Timestamp: 2025-09-23T15:01:00.070Z
Learning: In TensorRT-LLM NCCL device kernels, the <sstream> header is not needed as an explicit include in config.cu because it's provided transitively through other headers. Local compilation testing confirms this works without the explicit include.
Applied to files:
- cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp
- cpp/tensorrt_llm/thop/allreduceOp.cpp
📚 Learning: 2025-09-23T15:13:48.819Z
Learnt from: nv-lschneider
PR: NVIDIA/TensorRT-LLM#7910
File: cpp/tensorrt_llm/kernels/nccl_device/multimem.h:20-30
Timestamp: 2025-09-23T15:13:48.819Z
Learning: TRT-LLM targets modern CUDA toolkits that support FP8 datatypes, so cuda_fp8.h can be included unconditionally without version guards in TRT-LLM code.
Applied to files:
- cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp
- cpp/tensorrt_llm/thop/allreduceOp.cpp
- cpp/tensorrt_llm/common/lamportUtils.cuh
📚 Learning: 2025-09-23T15:01:00.070Z
Learnt from: nv-lschneider
PR: NVIDIA/TensorRT-LLM#7910
File: cpp/tensorrt_llm/kernels/nccl_device/config.cu:15-17
Timestamp: 2025-09-23T15:01:00.070Z
Learning: In TensorRT-LLM NCCL device kernels (cpp/tensorrt_llm/kernels/nccl_device/config.cu), std::ostringstream is used but <sstream> doesn't need to be explicitly included because it's provided transitively through other headers like tensorrt_llm/common/cudaUtils.h or config.h. Local compilation testing confirms this works without the explicit include.
Applied to files:
- cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp
- cpp/tensorrt_llm/thop/allreduceOp.cpp
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.
Applied to files:
cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp
📚 Learning: 2025-08-08T05:06:31.596Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp:36-36
Timestamp: 2025-08-08T05:06:31.596Z
Learning: CUTLASS extension files (under cpp/tensorrt_llm/cutlass_extensions/) follow CUTLASS coding style conventions, including using #pragma once instead of TRTLLM_ prefixed header guards, even though they are .hpp files.
Applied to files:
cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp
📚 Learning: 2025-09-23T15:12:38.312Z
Learnt from: nv-lschneider
PR: NVIDIA/TensorRT-LLM#7910
File: cpp/tensorrt_llm/thop/allreduceOp.cpp:352-446
Timestamp: 2025-09-23T15:12:38.312Z
Learning: In TensorRT-LLM NCCL device implementation, NCCL version 2.28+ requirements are handled at runtime in the nccl_device/config layer rather than with compile-time guards. This allows the allreduceOp to remain version-agnostic and delegates version compatibility validation to the appropriate lower-level components that can gracefully handle unsupported configurations.
Applied to files:
- cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp
- cpp/tensorrt_llm/thop/allreduceOp.cpp
📚 Learning: 2025-09-23T15:12:38.312Z
Learnt from: nv-lschneider
PR: NVIDIA/TensorRT-LLM#7910
File: cpp/tensorrt_llm/thop/allreduceOp.cpp:352-446
Timestamp: 2025-09-23T15:12:38.312Z
Learning: In TensorRT-LLM NCCL device allreduce implementation (cpp/tensorrt_llm/thop/allreduceOp.cpp), the goto pattern in runNCCLAllReduceDeviceFusion is intentionally used for future extensibility, allowing multiple switch cases to fallback to the default handler. While not aesthetically ideal, this pattern supports adding more fusion cases later that can reuse the same fallback logic.
Applied to files:
- cpp/tensorrt_llm/thop/allreduceOp.cpp
- cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.cu
- tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
📚 Learning: 2025-08-14T06:36:40.701Z
Learnt from: timlee0212
PR: NVIDIA/TensorRT-LLM#6886
File: tensorrt_llm/_torch/models/modeling_deepseekv3.py:0-0
Timestamp: 2025-08-14T06:36:40.701Z
Learning: In DeepSeek V3 model (tensorrt_llm/_torch/models/modeling_deepseekv3.py), the disagreement between AllReduce.__init__ guard and _compute_mlp_tp_size logic for MNNVL usage is expected by design. The AllReduce component and MLP TP-size computation intentionally use different criteria for MNNVL availability decisions.
Applied to files:
- cpp/tensorrt_llm/thop/allreduceOp.cpp
- cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.cu
- cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.h
- tensorrt_llm/_torch/distributed/ops.py
📚 Learning: 2025-09-16T09:30:09.716Z
Learnt from: tongyuantongyu
PR: NVIDIA/TensorRT-LLM#7763
File: cpp/tensorrt_llm/CMakeLists.txt:297-301
Timestamp: 2025-09-16T09:30:09.716Z
Learning: In the TensorRT-LLM project, NCCL libraries are loaded earlier by PyTorch libraries or the bindings library, so the main shared library doesn't need NCCL paths in its RPATH - the libraries will already be available in the process address space when needed.
Applied to files:
cpp/tensorrt_llm/thop/allreduceOp.cpp
📚 Learning: 2025-09-23T14:58:05.372Z
Learnt from: nv-lschneider
PR: NVIDIA/TensorRT-LLM#7910
File: cpp/tensorrt_llm/kernels/nccl_device/config.cu:42-49
Timestamp: 2025-09-23T14:58:05.372Z
Learning: In TensorRT-LLM NCCL device kernels (cpp/tensorrt_llm/kernels/nccl_device/), the token partitioning intentionally uses ceil-like distribution (same token_per_rank for all ranks) to ensure all ranks launch the same number of blocks. This is required for optimal NCCL device API barrier performance, even though it may launch extra blocks for non-existent tokens on later ranks. Runtime bounds checking in the kernel (blockID validation) handles the overshoot cases.
Applied to files:
cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.cu
🧬 Code graph analysis (6)
tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py (2)
cpp/tensorrt_llm/thop/allreduceOp.cpp (2)
- strategy (691-722)
- strategy (691-691)

tensorrt_llm/functional.py (1)
- AllReduceStrategy (3876-3885)
cpp/tensorrt_llm/thop/allreduceOp.cpp (1)
cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.cu (4)
- oneshot_allreduce_fusion_op (459-557)
- oneshot_allreduce_fusion_op (459-459)
- twoshot_allreduce_fusion_op (896-1051)
- twoshot_allreduce_fusion_op (896-896)
cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.cu (4)
cpp/include/tensorrt_llm/common/cudaUtils.h (1)
- getMultiProcessorCount (399-406)

cpp/include/tensorrt_llm/common/dataType.h (1)
- getDTypeSize (26-44)

cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu (1)
- __syncthreads (127-144)

cpp/tensorrt_llm/thop/allreduceOp.cpp (14)
- input (159-191), input (159-162), input (207-285), input (207-209), input (287-306), input (287-289), input (308-339), input (308-310), input (341-406), input (341-343), input (408-590), input (408-412), input (592-658), input (592-595)
cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.h (1)
cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.cu (4)
- oneshot_allreduce_fusion_op (459-557)
- oneshot_allreduce_fusion_op (459-459)
- twoshot_allreduce_fusion_op (896-1051)
- twoshot_allreduce_fusion_op (896-896)
tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py (1)
cpp/tensorrt_llm/thop/allreduceOp.cpp (14)
- input (159-191), input (159-162), input (207-285), input (207-209), input (287-306), input (287-289), input (308-339), input (308-310), input (341-406), input (341-343), input (408-590), input (408-412), input (592-658), input (592-595)
tensorrt_llm/_torch/distributed/ops.py (3)
tensorrt_llm/mapping.py (2)
- is_multi_node (424-425)
- local_rank (403-404)

cpp/tensorrt_llm/runtime/mcastGPUBuffer.h (1)
- McastGPUBuffer (26-68)

tensorrt_llm/_mnnvl_utils.py (2)
- MnnvlMemory (53-338)
- supports_mnnvl (332-338)
🪛 Ruff (0.13.1)
tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
60-60: Unused function argument: gamma
(ARG001)
60-60: Unused function argument: epsilon
(ARG001)
60-60: Unused function argument: buffer
(ARG001)
60-60: Unused function argument: buffer_flags
(ARG001)
61-61: Unused function argument: rmsnorm_fusion
(ARG001)
tensorrt_llm/_torch/distributed/ops.py
378-380: Avoid specifying long messages outside the exception class
(TRY003)
445-447: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (11)
cpp/tensorrt_llm/nanobind/runtime/bindings.cpp (1)
304-306: LGTM: decoder_stream binding retains semantics.

The property continues to return a reference to the underlying CudaStream with the correct RV policy.
cpp/tensorrt_llm/pybind/runtime/bindings.cpp (1)
398-400: LGTM: decoder_stream binding keeps the same return policy and semantics.

No behavior change.
cpp/tensorrt_llm/common/lamportUtils.cuh (1)
1-15: Header license/style: OK.

Apache-2.0 header present; formatting consistent.
tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py (3)
171-179: Good parameter coverage for runtime paths.

The added seq_len sets exercise oneshot/twoshot switching and workspace reallocation. Nice.
Please confirm CI timing remains acceptable given longer sequences; otherwise, consider marking the largest case as slow.
182-183: Hidden-size edge cases are covered.

Including 8, 7176, and 8192 validates packing divisibility and max-dim constraints.
186-187: Restricting to the MNNVL strategy is appropriate.

Keeps the test focused on the new fused path.
cpp/tensorrt_llm/kernels/communicationKernels/mnnvlAllreduceKernels.cu (2)
292-455: Kernel synchronization and bounds look correct.

LamportFlags usage, PDL attributes, CGA cluster reduction, and OOB early-return are coherent. Negative-zero sentinel handling aligns with the flags API. No functional blockers spotted.
Please ensure lamportUtils.cuh guarantees 16B atomicity preconditions for float4 writes on all supported dtypes.
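As a rough illustration of that precondition (not the actual utility code; function names and buffer layout are invented), a Lamport-style slot only works if the payload and the validity lane land in one 16-byte store:

```cuda
#include <cuda_runtime.h>

// Hypothetical sketch: each slot is written with ONE 16B vectorized store, so
// a non-sentinel .x lane guarantees the other three lanes from the same store
// are visible too. -0.0f (0x80000000) is the "not yet written" sentinel.
__device__ inline void produceSlot(float4* slot, float4 value)
{
    asm volatile("st.global.v4.f32 [%0], {%1, %2, %3, %4};" ::"l"(slot), "f"(value.x), "f"(value.y), "f"(value.z),
        "f"(value.w));
}

__device__ inline float4 consumeSlot(float4 const* slot)
{
    float4 v;
    do
    {
        asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];"
                     : "=f"(v.x), "=f"(v.y), "=f"(v.z), "=f"(v.w)
                     : "l"(slot));
    } while (__float_as_uint(v.x) == 0x80000000u); // still the -0.0f sentinel
    return v;
}
```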
459-557: Launch heuristics and dtype/world-size dispatch are sound.

The elts_per_thread derivation, grid config, and world-size gating cover the supported cases and fail fast otherwise.
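For intuition only, the kind of derivation being described might look like the sketch below; the constants, divisibility assumptions, and clamping policy are illustrative, not the kernel's exact launch logic.

```cpp
#include <algorithm>
#include <cstddef>

// Illustrative launch-shape derivation: each thread moves one 16B (float4)
// packet, so the element count per thread follows from the dtype size.
struct LaunchShape
{
    int eltsPerThread;
    int blocks;
    int threadsPerBlock;
};

inline LaunchShape deriveLaunchShape(std::size_t dtypeBytes, int numTokens, int tokenDim, int smCount)
{
    int const eltsPerThread = static_cast<int>(16 / dtypeBytes); // 8 for fp16/bf16, 4 for fp32
    int const threadsPerToken = tokenDim / eltsPerThread;        // divisibility checked upstream
    int const threadsPerBlock = std::min(1024, threadsPerToken);
    int const blocks = std::min(numTokens, smCount);             // cap grid by SM count
    return {eltsPerThread, blocks, threadsPerBlock};
}
```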
cpp/tensorrt_llm/thop/allreduceOp.cpp (1)
1163-1169: Hidden-dim divisibility check is correct.

Using elts_per_load derived from float4 aligns with the kernel vectorization assumptions.
tensorrt_llm/_torch/distributed/ops.py (2)
456-459: Shape/view usage is correct.

Reinterpreting the UC buffer into (3, -1) per dtype to expose three Lamport buffers is appropriate.
460-483: Fused/non-fused dispatch to the torch op looks correct.

Argument wiring matches the C++ op signature; output reshaping is consistent.
Force-pushed a669187 to 664a3bd (Compare)
Signed-off-by: Shiyu Li <[email protected]>
Force-pushed 664a3bd to e58fc23 (Compare)
/bot run --add-multi-gpu-test --disable-fail-fast

PR_Github #20123 [ run ] triggered by Bot

PR_Github #20123 [ run ] completed with state
Signed-off-by: Shiyu Li <[email protected]>
/bot run --add-multi-gpu-test --disable-fail-fast

PR_Github #22255 [ run ] completed with state

/bot run --add-multi-gpu-test --disable-fail-fast

PR_Github #22313 [ run ] triggered by Bot. Commit:
Signed-off-by: Shiyu Li <[email protected]>
PR_Github #22313 [ run ] completed with state

/bot run --add-multi-gpu-test --disable-fail-fast

PR_Github #22666 [ run ] triggered by Bot. Commit:

PR_Github #22666 [ run ] completed with state
Co-authored-by: Jin Li <[email protected]> Signed-off-by: Shiyu Li <[email protected]>
Signed-off-by: Shiyu Li <[email protected]>
/bot run --add-multi-gpu-test --disable-fail-fast

PR_Github #23426 [ run ] triggered by Bot. Commit:

PR_Github #23426 [ run ] completed with state

/bot run --add-multi-gpu-test --disable-fail-fast

PR_Github #23538 [ run ] triggered by Bot. Commit:

PR_Github #23538 [ run ] completed with state
Signed-off-by: Shiyu Li <[email protected]> Co-authored-by: Jin Li <[email protected]> Signed-off-by: FredricZ-2007 <[email protected]>
## 📌 Description

This PR ports all changes in [TensorRT-LLM#8018](NVIDIA/TensorRT-LLM#8018) into FlashInfer. Apart from the changes mentioned in the original PR, this PR also introduces new API interfaces, `trtllm_mnnvl_allreduce` and `trtllm_mnnvl_fused_allreduce_add_rmsnorm`, to replace the original ones. The workspace allocation is wrapped as an entire class with a given buffer size, so the user does not need to worry about the details inside. This PR adds support for IPC-socket-based mcast device memory bootstrap so that it can run on DGX machines that do not support fabric handles. @wenscarl This PR also incorporates the changes made in #2056 and should be able to replace that PR. A bcast interface is added to the comm backend, as this is needed during the handle transfer. The old API is tagged as deprecated and redirected to the new APIs; users of the old API should not need to make any changes.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

- **New Features**
  - Fused all-reduce with optional RMSNorm fusion and selectable one-shot/two-shot strategies; new Python APIs and workspace utilities; IPC-based handle exchange and bcast support.
- **Improvements**
  - Pluggable handle-exchange backends (Fabric/POSIX), stricter I/O and shape validation, renamed/standardized fusion entry points and parameter surfaces, cached CUDA SM count for tuning, and safer lifecycle/cleanup.
- **Tests**
  - MPI-aware tests for fused and legacy flows, workspace-based runs, synchronization, and expanded sequence/hidden-size coverage.
Summary by CodeRabbit
Description
This PR refactors the MNNVLAllreduce implementation in TRTLLM:
Limitations:
4096x8192 TP32, twoshot 128K x 8192)

Test Coverage
pytest tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py

PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
`/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...`

Provide a user friendly way for developers to interact with a Jenkins server.

Run `/bot [-h|--help]` to print this help message. See details below for each supported subcommand.

run

`run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]`

Launch build/test pipelines. All previously running jobs will be killed.

- `--reuse-test (optional)pipeline-id` (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- `--disable-reuse-test` (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
- `--disable-fail-fast` (OPTIONAL) : Disable fail fast on build/tests/infra failures.
- `--skip-test` (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
- `--stage-list "A10-PyTorch-1, xxx"` (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
- `--gpu-type "A30, H100_PCIe"` (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- `--test-backend "pytorch, cpp"` (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
- `--only-multi-gpu-test` (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--disable-multi-gpu-test` (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--add-multi-gpu-test` (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.
- `--post-merge` (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- `--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"` (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- `--detailed-log` (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- `--debug` (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the `stage-list` parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see `docs/source/reference/ci-overview.md` and the `scripts/test_to_stage_mapping.py` helper.

kill

`kill`

Kill all running builds associated with pull request.

skip

`skip --comment COMMENT`

Skip testing for latest commit on pull request. `--comment "Reason for skipping build/test"` is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

`reuse-pipeline`

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.