
Conversation

@jmydurant
Collaborator

@jmydurant jmydurant commented Sep 2, 2025

Summary by CodeRabbit

  • New Features

    • Enable saving softmax stats for FP8/E4M3 attention (formerly blocked).
    • Add configurable chunked prefill buffer batch size; exposed in Python APIs and used in execution.
    • Introduce global-offset–based chunked prefill with per-loop max length for more efficient KV loading.
    • Broaden support to SM90/SM120 for FP8 MLA, KV cache reuse, and chunked context.
    • Adjust workspace sizing to scale with chunked prefill buffer batch size.
  • Tests

    • Expand FP8 coverage in accuracy tests and update H100 test lists to exercise FP8 chunked prefill paths.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since insufficient care and validation can break the top of tree.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since insufficient care and validation can break the top of tree.

@jmydurant jmydurant requested review from a team as code owners September 2, 2025 08:35
@coderabbitai
Contributor

coderabbitai bot commented Sep 2, 2025

📝 Walkthrough

Walkthrough

Enables FP8/E4M3 softmax-stats saving and broadens MLA context support. Refactors chunked prefill to use per-loop global offsets and max lengths, adds buffer-batch sizing, and updates kernels, launchers, bindings, and Python/Torch paths. Expands kernel spec generation, adjusts workspace sizing, and updates tests and test lists to include FP8 and chunked prefill.

Changes

Cohort / File(s) | Summary
FMHA test and driver
cpp/kernels/fmha_v2/fmha_test.py, cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp
Allow -save-softmax for E4M3 in tests and main driver; remove E4M3 guard.
Kernel spec generation (qGMMA flash warps)
cpp/kernels/fmha_v2/setup.py
Add return_softmax to combinations; gate normal/MLA kernel generation; fix output_dtype usage; enumerate E4M3 with BF16 output; propagate return_softmax_stats.
FMHA Tile normalizer (E4M3)
cpp/kernels/fmha_v2/src/fmha/fragment.h
Add specialization exposing aliases/constants and implement final_update for dequantized softmax sum scaling.
Epilogue softmax stats
cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h
When RETURN_SOFTMAX_STATS, store dequantized global_sum using SOFTMAX_FP_DEQUANT_SCALE.
AttentionOp runtime sizing
cpp/tensorrt_llm/common/attentionOp.cpp, cpp/tensorrt_llm/common/attentionOp.h
Scale FP8 K/V context MLA buffers by mChunkPrefillBufferBatchSize; add public mChunkPrefillBufferBatchSize.
FMHA runner (MLA gating)
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
Replace BF16-only MLA check with Hopper MLA (BF16/E4M3); update supportReturnSoftmaxStats gating per layout.
MLA chunked prefill kernel/API
cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu, cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh
Replace chunked_size/idx with chunked_ld_global_offset and max_seq_len; rework indexing/loading; update launches and instantiations.
Torch/thop bindings
cpp/tensorrt_llm/nanobind/thop/bindings.cpp, cpp/tensorrt_llm/pybind/thop/bindings.cpp
Add optional chunk_prefill_buffer_batch_size argument to attention binding.
thop AttentionOp plumbing
cpp/tensorrt_llm/thop/attentionOp.h, cpp/tensorrt_llm/thop/attentionOp.cpp
Add optional chunk_prefill_buffer_batch_size to API; enable FP8 Context MLA on SM90/SM120; set mChunkPrefillBufferBatchSize.
MLA preprocess (Torch op)
cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp
Switch MLA chunked KV load to chunked_ld_global_offset and max_seq_len; update helper, public op, and TORCH_LIBRARY signatures/calls.
Unit tests (MLA chunked prefill)
cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
Convert to offset-based chunking; add metadata prep, new buffers, and API changes; adjust kernels and copy helpers.
Torch attention backend (interfaces)
tensorrt_llm/_torch/attention_backend/interface.py
Add chunk_buffer_batch_size (runtime feature) and MLAParams.chunk_prefill_buffer_batch_size.
Torch attention backend (TRT-LLM wrapper)
tensorrt_llm/_torch/attention_backend/trtllm.py
Introduce chunk_prefill_buffer_batch_size; compute/persist chunked_ld_global_offset and max_chunk_len_per_loop; update plan/forward/load_chunked_kv calls and kernel invocation.
High-level module integration
tensorrt_llm/_torch/modules/attention.py
Pass chunked_ld_global_offset and chunked_max_seq_len to load_chunked_kv_cache_for_mla; forward chunked_prefill_buffer_batch_size into attention calls.
Executor gating
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
Enable KV-cache reuse and chunked context for SM120.
Integration tests
tests/integration/defs/accuracy/test_llm_api_pytorch.py, tests/integration/test_lists/test-db/l0_h100.yml
Add FP8 to quant_dtype params; update H100 test list to run FP8 chunked prefill cases.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Py as PyModule (attention.py)
  participant Wrap as TrtllmAttentionWrapper
  participant TRT as TrtllmAttention
  participant Thop as thop.attention
  participant MLA as MLA Chunked KV Loader (CUDA)
  participant Kern as FMHA Kernel

  Py->>Wrap: plan(chunk_size, chunk_prefill_buffer_batch_size)
  Wrap->>Wrap: compute total_chunk_size, chunked_loop_num
  Wrap->>Wrap: pre_process_for_chunked_prefill(..., chunked_ld_global_offset, max_chunk_len_per_loop)

  loop For each prefill loop
    Py->>TRT: load_chunked_kv_cache_for_mla(chunked_ld_global_offset[loop], max_chunk_len_per_loop[loop], ...)
    TRT->>MLA: invokeMLALoadChunkedKV(..., chunked_ld_global_offset, max_seq_len)
    MLA-->>TRT: output_kv, output_k_pe

    Py->>TRT: run(..., chunked_prefill_buffer_batch_size)
    TRT->>Thop: attention(..., chunk_prefill_buffer_batch_size)
    Thop->>Kern: launch FMHA (MLA/FP8 or BF16)
    Kern-->>Thop: results (optionally return_softmax stats)
    Thop-->>TRT: outputs
  end
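
To make the offset bookkeeping in the diagram concrete, here is a minimal, self-contained Python sketch of how per-loop chunk lengths, global load offsets, and per-loop max lengths could be derived from cached KV lengths, a chunk size, and a buffer batch size. The function name and the flat-list shapes are illustrative assumptions, not the wrapper's actual code in trtllm.py.

from typing import List, Tuple


def plan_chunked_prefill(cached_kv_lens: List[int], chunk_size: int,
                         buffer_batch_size: int = 1
                         ) -> Tuple[List[List[int]], List[List[int]], List[int]]:
    # Total tokens a single prefill loop may stage for loading.
    budget = chunk_size * buffer_batch_size
    consumed = [0] * len(cached_kv_lens)  # tokens already loaded per context
    chunk_lens, global_offsets, max_len_per_loop = [], [], []
    while any(c < total for c, total in zip(consumed, cached_kv_lens)):
        remaining = budget
        lens, offsets = [], []
        for i, total in enumerate(cached_kv_lens):
            take = min(total - consumed[i], remaining)
            offsets.append(consumed[i])  # global offset where this chunk starts
            lens.append(take)
            consumed[i] += take
            remaining -= take
        chunk_lens.append(lens)
        global_offsets.append(offsets)
        max_len_per_loop.append(max(lens))  # bounds the per-loop kernel grid
    return chunk_lens, global_offsets, max_len_per_loop


# Example: two contexts with 300 and 200 cached tokens, chunk_size=128, batch=2.
print(plan_chunked_prefill([300, 200], 128, buffer_batch_size=2))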

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested labels

KV-Cache Management


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)

345-351: Replace magic SM literals with a named constant.

Use a single allowed set for MLA KV reuse (and reuse below for chunked prefill) to avoid drift.

Apply within this block:

-        if executor_config.kv_cache_config.enable_block_reuse and sm_version not in [
-                90, 100, 120
-        ]:
+        if executor_config.kv_cache_config.enable_block_reuse and sm_version not in K_ALLOWED_MLA_SM:

Add once near imports (outside this hunk):

# Module-level guard for MLA features
K_ALLOWED_MLA_SM = {90, 100, 120}
cpp/kernels/fmha_v2/setup.py (1)

2757-2760: Bug: limit_v_fragments flag never set (limit_qk_fragments used twice).

This mis-encodes flags in the kernel-traits printer and metadata generation.

-        if kspec.limit_qk_fragments:
-            flags |= 256
+        if kspec.limit_v_fragments:
+            flags |= 256
cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (1)

240-258: Cast KV block pointer to the cache type (or byte pointer) to avoid aliasing/type-size hazards.

Using T* here is incorrect if T != TCache; prefer TCache* (or char*).

-        auto* kvSrc = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_idx_in_kv_cache));
+        // Use the actual cache element type to avoid size/alignment issues.
+        auto* kvSrc = reinterpret_cast<TCache*>(kv_cache.getKBlockPtr(batch_idx, token_idx_in_kv_cache));

Optionally use char* and compute offsets in bytes if that matches KVBlockArray’s contract better.

cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp (1)

251-266: Add missing input validation for chunked path.

  • Unlike the paged path, chunked loader doesn’t check num_ctx_cached_tokens > 0.
  • No dtype/shape checks for chunked_ld_global_offset. Kernel expects int64 vector (1D, >= num_contexts+0).

Add checks to fail fast.

 std::vector<torch::Tensor> loadChunkedKVCacheForMLA(torch::ScalarType out_dtype, int64_t const num_contexts,
     int64_t const num_ctx_cached_tokens, torch::Tensor const& cu_ctx_chunked_kv_lens,
-    torch::Tensor const& chunked_ld_global_offset, torch::Tensor const& kv_cache_block_offsets,
+    torch::Tensor const& chunked_ld_global_offset, torch::Tensor const& kv_cache_block_offsets,
     torch::Tensor const& host_kv_cache_pool_pointers, torch::Tensor const& host_kv_cache_pool_mapping,
     torch::optional<torch::Tensor> kv_scale_orig_quant, torch::optional<torch::Tensor> kv_scale_quant_orig,
     int64_t const layer_idx, int64_t const lora_size, int64_t const rope_size, int64_t const tokens_per_block,
     int64_t const max_seq_len, int64_t const attention_window_size, int64_t const sink_token_length,
     int64_t const beam_width, int64_t const quant_mode)
 {
@@
-    TLLM_CHECK(num_contexts > 0);
-    CHECK_INPUT(cu_ctx_chunked_kv_lens, torch::kInt64);
+    TLLM_CHECK(num_contexts > 0);
+    TORCH_CHECK(num_ctx_cached_tokens > 0);
+    CHECK_INPUT(cu_ctx_chunked_kv_lens, torch::kInt64);
     TORCH_CHECK(cu_ctx_chunked_kv_lens.dim() == 1);
     TORCH_CHECK(cu_ctx_chunked_kv_lens.size(0) >= num_contexts + 1);
+    CHECK_INPUT(chunked_ld_global_offset, torch::kInt64);
+    TORCH_CHECK(chunked_ld_global_offset.dim() == 1);
+    TORCH_CHECK(chunked_ld_global_offset.size(0) >= num_contexts);
🧹 Nitpick comments (25)
cpp/tensorrt_llm/common/attentionOp.h (1)

416-416: Prefer fixed-width int over size_t for serialized config.

size_t varies by platform and can complicate stable serialization. Consider int32_t (or int64_t) for mChunkPrefillBufferBatchSize; if keeping size_t, always cast in serialization (as above).

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)

28-36: Duplicate imports of is_mla.

Two is_mla definitions are imported; one shadows the other and invites confusion. Keep one (or alias explicitly).

-from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
-                    create_py_executor_instance, instantiate_sampler, is_mla)
+from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
+                    create_py_executor_instance, instantiate_sampler)
-from .config_utils import is_mla
+from .config_utils import is_mla
tests/integration/defs/accuracy/test_llm_api_pytorch.py (1)

1713-1713: Update stale comment contradicting new support.

Comment says FP8/NVFP4 chunked prefill isn’t supported, but you just enabled it. Remove or restate with SM gating.

Apply:

-    # currently, chunked prefill is not supported for fp8 and nvfp4
+    # Chunked prefill supports fp8/nvfp4 on SM90/SM100/SM120.
cpp/kernels/fmha_v2/src/fmha/fragment.h (2)

1741-1752: Avoid redeclaring traits/geometry already inherited from Base.

This specialization publicly derives from Tile_o_normalizer_fp32 but then redeclares Fragment_accu, Mma_tile, MMAS_M, MMAS_N, ROWS_PER_THREAD, REGS_PER_THREAD, WARPS_{M,N,K}, BYTES_PER_ELEMENT. It’s redundant and risks drift if Base changes.

Prefer bringing what you need into scope or referencing Base::… directly.

Example (minimal):

 struct Tile_o_normalizer<Ada_qmma_e4m3_fp32_traits, Cta_tile>
     : public Tile_o_normalizer_fp32<Ada_qmma_e4m3_fp32_traits, Cta_tile>
 {
-    using Fragment_accu = Fragment_accumulator<Traits>;
-    using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
-    enum { MMAS_M = Mma_tile::MMAS_M };
-    enum { MMAS_N = Mma_tile::VALID_MMAS_N };
-    enum { ROWS_PER_THREAD = 2 * MMAS_M };
-    enum { REGS_PER_THREAD = 8 };
-    enum { WARPS_M = Cta_tile::WARPS_M };
-    enum { WARPS_N = Cta_tile::WARPS_N };
-    enum { WARPS_K = Cta_tile::WARPS_K };
-    enum { BYTES_PER_ELEMENT = sizeof(float) };
+    using Base = Tile_o_normalizer_fp32<Ada_qmma_e4m3_fp32_traits, Cta_tile>;
+    using Fragment_accu = typename Base::Fragment_accu;
+    using Mma_tile     = typename Base::Mma_tile;
+    using Base::MMAS_M; using Base::MMAS_N; using Base::ROWS_PER_THREAD;
+    using Base::REGS_PER_THREAD; using Base::WARPS_M; using Base::WARPS_N;
+    using Base::WARPS_K; using Base::BYTES_PER_ELEMENT;

1822-1824: Typo: “diviser” → “divisor”.

cpp/kernels/fmha_v2/fmha_test.py (2)

164-169: Align epsilon for FP8 MLA context tests to avoid flakiness.

E4M3 typically needs a looser tolerance. Mirror the top-level tests by setting epsilon for FP8 here as well.

     epsilon = ''
-    if dtype == "-bf16" and s == 4096:
+    if dtype == "-bf16" and s == 4096:
         epsilon += ' -epsilon 0.03'
+    elif dtype in ["-e4m3", "-e4m3 -bf16-output"]:
+        epsilon += ' -epsilon 0.2'

179-191: Consider covering “-e4m3 -bf16-output” in save-softmax runs if supported.

The PR claims FP8 MLA chunked prefill support; if save-softmax is valid with BF16 outputs, include it to prevent regressions. If not supported, add an explicit skip.

-    if dtype in ["-bf16", "-e4m3"]:
+    if dtype in ["-bf16", "-e4m3", "-e4m3 -bf16-output"]:

If this combination is intentionally unsupported, please add a skip with a clear reason.

cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (2)

489-497: Unify Hopper context-MLA detection; avoid duplicate, inconsistent predicates.

Here isHopperContextMLA also checks headSize == headSizeV + 64 and dtype in {BF16,E4M3}. Earlier (Line 456) a different predicate (sm90 && headSizeV == 128) is used. Merge into one helper so all decisions (tiling, WS, and softmax-stats) agree.

Minimal inline refactor within setupLaunchParams:

-    // Now we have SM90 context and FP8 generation MLA kernels
-    bool isHopperContextMLA = isSm90 && mFixedParams.headSizeV == 128;
+    auto const isContextMLAShape = (mFixedParams.headSize == mFixedParams.headSizeV + 64);
+    auto const isHopperContextMLA = isSm90 && isContextMLAShape
+        && (mFixedParams.dataType == DATA_TYPE_BF16 || mFixedParams.dataType == DATA_TYPE_E4M3)
+        && mFixedParams.headSizeV == 128;

Then reuse the same isHopperContextMLA at Lines 489–497 instead of recomputing it.


469-471: Comment typo.

“hooper style” → “Hopper style”.

cpp/tensorrt_llm/thop/attentionOp.h (1)

57-58: Prevent positional-arg break; keep param safe and keyword-only in bindings

Inserting chunk_prefill_buffer_batch_size before q_lora_rank can break existing positional callers. Ensure all Python bindings make args from attention_input_type onward keyword-only, and clamp the value to >= 1 when mapping into mChunkPrefillBufferBatchSize.

Would you like me to add kw-only in both pybind/nanobind bindings and add a clamp where the member is set?

tensorrt_llm/_torch/attention_backend/interface.py (2)

28-29: Validate and document runtime chunk buffer sizing

Add a quick validation that chunk_buffer_batch_size >= 1 and clarify units in the class docstring to avoid oversized/undersized allocations at runtime.
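
As an illustration, a guard along these lines would fail fast on zero or negative values; the dataclass below is a hypothetical stand-in for the runtime-features container, not the repo's exact class.

from dataclasses import dataclass


@dataclass
class RuntimeFeaturesSketch:
    """Hypothetical subset of the runtime features, for illustration only."""

    chunked_prefill: bool = False
    chunk_buffer_batch_size: int = 1  # number of chunk-size buffers staged per prefill loop

    def __post_init__(self) -> None:
        if self.chunk_buffer_batch_size < 1:
            raise ValueError(
                f"chunk_buffer_batch_size must be >= 1, got {self.chunk_buffer_batch_size}")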


634-635: Keep MLA param default aligned and enforce bounds

Default chunk_prefill_buffer_batch_size=1 matches C++ expectations; add a bounds check (>=1) in the code path that forwards this into the ThOP call to prevent negative/zero from configs.

cpp/tensorrt_llm/thop/attentionOp.cpp (1)

631-633: SM gating uses magic numbers; clarify intent and ensure SM100 exclusion is intentional

Current condition enables FP8 Context MLA on SM90 and SM120 only. If SM100 is intentionally excluded, add a short comment. Also avoid magic numbers by capturing SM once.

Apply a small readability tweak:

-        op->mFP8ContextMLA = (tensorrt_llm::common::getSMVersion() == 120 || tensorrt_llm::common::getSMVersion() == 90)
-            && op->mKVCacheQuantMode.hasFp8KvCache();
+        // Enable FP8 Context MLA on Hopper (SM90) and Blackwell-next (SM120) only; SM100 intentionally excluded.
+        auto const sm = tensorrt_llm::common::getSMVersion();
+        op->mFP8ContextMLA = ((sm == 120) || (sm == 90)) && op->mKVCacheQuantMode.hasFp8KvCache();
tensorrt_llm/_torch/modules/attention.py (4)

1159-1165: Fix minor comment typo

“toal_token_q” -> “total_token_q”.

-        # [toal_token_q, num_heads, 2] -> [toal_token_q, num_heads] float2
+        # [total_token_q, num_heads, 2] -> [total_token_q, num_heads] float2

1186-1195: Assert device/dtype for chunked_ld_global_offset to match CUDA kernel expectations

Add lightweight asserts to prevent host/device mismatch or dtype errors at runtime.

Apply:

-            chunked_ld_global_offset = attn_metadata.chunked_ld_global_offset[loop_idx]
-            chunked_max_seq_len = attn_metadata.max_chunk_len_per_loop[loop_idx]
+            chunked_ld_global_offset = attn_metadata.chunked_ld_global_offset[loop_idx]
+            chunked_max_seq_len = attn_metadata.max_chunk_len_per_loop[loop_idx]
+            # Sanity checks for kernel inputs
+            assert temp_cu_chunked_seq_len.is_cuda, "cu_chunked_seq_len must be CUDA tensor"
+            assert chunked_ld_global_offset.is_cuda, "chunked_ld_global_offset must be CUDA tensor"
+            assert chunked_ld_global_offset.dtype == torch.int64, "chunked_ld_global_offset must be int64"

1243-1245: Graceful fallback for chunk_buffer_batch_size to keep backward compatibility

Use getattr to avoid attribute errors with older metadata.

Apply:

-                softmax_stats_tensor=self.temp_softmax_stats_tensor,
-                chunked_prefill_buffer_batch_size=attn_metadata.
-                runtime_features.chunk_buffer_batch_size,
+                softmax_stats_tensor=self.temp_softmax_stats_tensor,
+                chunked_prefill_buffer_batch_size=getattr(
+                    attn_metadata.runtime_features, "chunk_buffer_batch_size", None),

1295-1297: Repeat: fallback for final prefill call

Mirror the same defensive fallback here.

Apply:

-            softmax_stats_tensor=self.temp_softmax_stats_tensor,
-            chunked_prefill_buffer_batch_size=attn_metadata.runtime_features.
-            chunk_buffer_batch_size,
+            softmax_stats_tensor=self.temp_softmax_stats_tensor,
+            chunked_prefill_buffer_batch_size=getattr(
+                attn_metadata.runtime_features, "chunk_buffer_batch_size", None),
cpp/kernels/fmha_v2/setup.py (1)

3818-3821: Combination tuple order change is OK, but add a short comment for maintainers.

The new ordering (alibi, input_layout, enable_attn_logit_softcapping, return_softmax) differs from the hgmma path and can be easy to misread. A one-liner here would prevent future mistakes.

-    combinations = product([False, True], \
-        [InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV,
-         InputLayout.Q_PAGED_KV, InputLayout.SEPARATE_Q_K_V],
-        [False, True], [False, True])
+    # combinations: (alibi, input_layout, enable_attn_logit_softcapping, return_softmax)
+    combinations = product([False, True],
+        [InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV,
+         InputLayout.Q_PAGED_KV, InputLayout.SEPARATE_Q_K_V],
+        [False, True], [False, True])
cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (3)

150-161: Remove duplicate early-exit check.

merge_op_val==0 is checked twice back-to-back; keep only one.

-    int64_t merge_op_val = merge_op[batch_idx];
-    if (merge_op_val == 0)
-    {
-        return; // skip this batch
-    }
+    int64_t merge_op_val = merge_op[batch_idx];
@@
-    if (merge_op_val == 0)
-    {
-        return; // skip this batch
-    }

274-286: K-PE write indexing is correct given head_dim gating; consider a comment.

head_dim_idx >= kLoraSize in this branch; subtracting kLoraSize is safe. A clarifying comment helps.

-            int const global_st_idx
-                = global_st_offset * KT::kRopeSize + local_token_idx * KT::kRopeSize + (head_dim_idx - KT::kLoraSize);
+            // head_dim_vec_idx >= kKVThreadPerHead => head_dim_idx >= kLoraSize; map to rope sub-dim.
+            int const global_st_idx = global_st_offset * KT::kRopeSize
+                + local_token_idx * KT::kRopeSize + (head_dim_idx - KT::kLoraSize);

337-349: Instantiation set expanded to bf16; good coverage.

Consider adding a unit test that mixes T=float with TCache=__nv_fp8_e4m3 to validate dequant path.

I can add a minimal test to cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu if helpful.

tensorrt_llm/_torch/attention_backend/trtllm.py (2)

191-192: Document and validate the new plan() parameter.

Add an Args entry for chunk_prefill_buffer_batch_size and (optionally) assert it’s >= 1 to prevent silent misconfigurations.

@@ def plan(
-        chunk_prefill_buffer_batch_size: int = 1,
+        chunk_prefill_buffer_batch_size: int = 1,
@@
-        """
+        """
@@
+            chunk_prefill_buffer_batch_size (int): Number of chunk buffers processed per loop (host-side staging/batching for chunked prefill). Default: 1.
@@
         self.is_spec_decoding_enabled = is_spec_decoding_enabled
@@
         self.spec_decoding_generation_lengths = spec_decoding_generation_lengths
         self.chunk_prefill_buffer_batch_size = chunk_prefill_buffer_batch_size
+        assert self.chunk_prefill_buffer_batch_size >= 1, "chunk_prefill_buffer_batch_size must be >= 1"
         self.kwargs.update(kwargs)

Also applies to: 276-277


896-959: Allocate per-context shapes to reduce memory.

chunked_seq_len is allocated with second dim = num_seqs, but only :num_contexts are used. Consider allocating [chunked_loop_num, num_contexts] for both device and host tensors to cut memory and H2D traffic.
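
A small sketch of the suggested allocation, assuming int64 tensors and a pinned host copy (attribute names are illustrative, not the exact ones in trtllm.py):

import torch

chunked_loop_num, num_contexts = 4, 8  # example sizes
# Size the second dimension by num_contexts instead of num_seqs.
chunked_seq_len = torch.zeros((chunked_loop_num, num_contexts),
                              dtype=torch.int64, device="cuda")
host_chunked_seq_len = torch.zeros((chunked_loop_num, num_contexts),
                                   dtype=torch.int64, pin_memory=True)
# The H2D copy now moves only the context rows the kernels actually read.
chunked_seq_len.copy_(host_chunked_seq_len, non_blocking=True)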

cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu (2)

349-359: Debug helper overload is fine.

Named tensor printer improves diagnostics without impacting tests.


663-783: Metadata preparation mirrors runtime logic; consider replacing macros.

The algorithm matches the Python path. For readability, prefer inline lambdas over #define macros for 2D indexing.

-#define chunked_seq_len(loop,b) chunked_seq_len_vec[(loop) * (this->mBatchSize) + (b)]
-#define cu_chunked_seq_len(loop,b) h_cu_chunk_lens_ptr[(loop) * (this->mBatchSize + 1) + (b)]
-#define chunked_ld_global_offset(loop,b) h_chunked_ld_global_offset_ptr[(loop) * (this->mBatchSize) + (b)]
+auto chunked_seq_len = [&](int loop, int b) -> auto& { return chunked_seq_len_vec[loop * this->mBatchSize + b]; };
+auto cu_chunked_seq_len = [&](int loop, int b) -> auto& { return h_cu_chunk_lens_ptr[loop * (this->mBatchSize + 1) + b]; };
+auto chunked_ld_global_offset = [&](int loop, int b) -> auto& { return h_chunked_ld_global_offset_ptr[loop * this->mBatchSize + b]; };
-#undef chunked_seq_len
-#undef cu_chunked_seq_len
-#undef chunked_ld_global_offset
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between f90375f and 5272e59.

📒 Files selected for processing (22)
  • cpp/kernels/fmha_v2/fmha_test.py (1 hunks)
  • cpp/kernels/fmha_v2/setup.py (3 hunks)
  • cpp/kernels/fmha_v2/src/fmha/fragment.h (1 hunks)
  • cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h (1 hunks)
  • cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp (0 hunks)
  • cpp/tensorrt_llm/common/attentionOp.cpp (1 hunks)
  • cpp/tensorrt_llm/common/attentionOp.h (1 hunks)
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp (1 hunks)
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (5 hunks)
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh (1 hunks)
  • cpp/tensorrt_llm/nanobind/thop/bindings.cpp (1 hunks)
  • cpp/tensorrt_llm/pybind/thop/bindings.cpp (1 hunks)
  • cpp/tensorrt_llm/thop/attentionOp.cpp (3 hunks)
  • cpp/tensorrt_llm/thop/attentionOp.h (1 hunks)
  • cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp (5 hunks)
  • cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu (18 hunks)
  • tensorrt_llm/_torch/attention_backend/interface.py (2 hunks)
  • tensorrt_llm/_torch/attention_backend/trtllm.py (14 hunks)
  • tensorrt_llm/_torch/modules/attention.py (3 hunks)
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1 hunks)
  • tests/integration/defs/accuracy/test_llm_api_pytorch.py (1 hunks)
  • tests/integration/test_lists/test-db/l0_h100.yml (1 hunks)
💤 Files with no reviewable changes (1)
  • cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp
🧰 Additional context used
📓 Path-based instructions (7)
**/*.{h,hpp,hh,hxx,cc,cpp,cxx,cu,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx,cc,cpp,cxx,cu,cuh}: Closing braces of C++ namespaces must include a comment naming the namespace (e.g., } // namespace foo)
Avoid using literals (except 0, nullptr, true, false) directly in logic; use named constants for comparisons
Use Allman brace style in C++
Place semicolon of empty for/while loop on its own line
Use brace-delimited statements for bodies of switch/while/do/for and always brace if/else bodies
C++ type names use UpperCamelCase
Local variables, methods, and namespaces use lowerCamelCase
Non-static, externally visible globals use g prefix with lowerCamelCase (e.g., gDontUseGlobalFoos)
Static or anonymous-namespace globals use s prefix with lowerCamelCase (e.g., sMutableStaticGlobal)
Locally visible static variables use s prefix (e.g., static std::once_flag sFlag)
Member variables use m prefix with CamelCase (public may omit but encouraged)
Constants (enums, globals, static consts, function-scope magic numbers) use k prefix with UPPER_SNAKE (e.g., kDIGIT_NUM)
Function-scope non-literal, non-magic constants use normal non-const naming (e.g., const bool pass)
If macros are necessary, name them in UPPER_SNAKE_CASE
Avoid Hungarian notation except allowed app’s hungarian like nb for counts
Constructor parameters conflicting with member names get a trailing underscore (e.g., foo_)
Use uppercase literal suffixes (e.g., 1234L not 1234l)
Format C++ with clang-format (LLVM style), max line length 120; justify any exceptions with clang-format off/on blocks
Use C++-style comments; C comments not allowed except special inline cases; single-line comments use //
Use inline parameter comments in calls when arguments aren't obvious (e.g., /* checkForErrors = */ false)
Disable code with #if/#endif (optionally mnemonic conditions or no-op macros); do not comment out code; avoid dead code
Use the least forceful C++ cast; avoid removing const/volatile; avoid C-style and functional casts (except explicit constructors); cast void* to T* with static_cast...

Files:

  • cpp/tensorrt_llm/common/attentionOp.cpp
  • cpp/tensorrt_llm/common/attentionOp.h
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
  • cpp/tensorrt_llm/thop/attentionOp.h
  • cpp/tensorrt_llm/nanobind/thop/bindings.cpp
  • cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h
  • cpp/tensorrt_llm/pybind/thop/bindings.cpp
  • cpp/kernels/fmha_v2/src/fmha/fragment.h
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh
  • cpp/tensorrt_llm/thop/attentionOp.cpp
  • cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu
  • cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp
**/*.{cc,cpp,cxx,cu}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{cc,cpp,cxx,cu}: Prefer const or constexpr variables over #define for constants in C++
Declare variables const if not modified after initialization
Use smart pointers for heap allocation; prefer unique_ptr for sole ownership, shared_ptr for shared; weak_ptr only exceptionally; avoid deprecated smart pointers
Avoid declaring large functions inline unless there’s a quantifiable benefit; remember in-class definitions are implicitly inline
Every defined function must be referenced at least once; avoid unused methods

Files:

  • cpp/tensorrt_llm/common/attentionOp.cpp
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
  • cpp/tensorrt_llm/nanobind/thop/bindings.cpp
  • cpp/tensorrt_llm/pybind/thop/bindings.cpp
  • cpp/tensorrt_llm/thop/attentionOp.cpp
  • cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu
  • cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp
**/*

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Filenames compiled into a target must be case-insensitively unique

Files:

  • cpp/tensorrt_llm/common/attentionOp.cpp
  • cpp/tensorrt_llm/common/attentionOp.h
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
  • tensorrt_llm/_torch/attention_backend/interface.py
  • cpp/tensorrt_llm/thop/attentionOp.h
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
  • tests/integration/test_lists/test-db/l0_h100.yml
  • cpp/tensorrt_llm/nanobind/thop/bindings.cpp
  • cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h
  • cpp/kernels/fmha_v2/fmha_test.py
  • cpp/tensorrt_llm/pybind/thop/bindings.cpp
  • cpp/kernels/fmha_v2/src/fmha/fragment.h
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh
  • cpp/kernels/fmha_v2/setup.py
  • tests/integration/defs/accuracy/test_llm_api_pytorch.py
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • cpp/tensorrt_llm/thop/attentionOp.cpp
  • cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu
  • tensorrt_llm/_torch/modules/attention.py
  • cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp
**/*.{h,hpp,hh,hxx,cc,cpp,cxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use spaces, not tabs; indent 4 spaces

Files:

  • cpp/tensorrt_llm/common/attentionOp.cpp
  • cpp/tensorrt_llm/common/attentionOp.h
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
  • tensorrt_llm/_torch/attention_backend/interface.py
  • cpp/tensorrt_llm/thop/attentionOp.h
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
  • cpp/tensorrt_llm/nanobind/thop/bindings.cpp
  • cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h
  • cpp/kernels/fmha_v2/fmha_test.py
  • cpp/tensorrt_llm/pybind/thop/bindings.cpp
  • cpp/kernels/fmha_v2/src/fmha/fragment.h
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh
  • cpp/kernels/fmha_v2/setup.py
  • tests/integration/defs/accuracy/test_llm_api_pytorch.py
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • cpp/tensorrt_llm/thop/attentionOp.cpp
  • cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu
  • tensorrt_llm/_torch/modules/attention.py
  • cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp
**/*.{cpp,cc,cxx,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • cpp/tensorrt_llm/common/attentionOp.cpp
  • cpp/tensorrt_llm/common/attentionOp.h
  • cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp
  • tensorrt_llm/_torch/attention_backend/interface.py
  • cpp/tensorrt_llm/thop/attentionOp.h
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
  • cpp/tensorrt_llm/nanobind/thop/bindings.cpp
  • cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h
  • cpp/kernels/fmha_v2/fmha_test.py
  • cpp/tensorrt_llm/pybind/thop/bindings.cpp
  • cpp/kernels/fmha_v2/src/fmha/fragment.h
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh
  • cpp/kernels/fmha_v2/setup.py
  • tests/integration/defs/accuracy/test_llm_api_pytorch.py
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • cpp/tensorrt_llm/thop/attentionOp.cpp
  • cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu
  • tensorrt_llm/_torch/modules/attention.py
  • cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp
**/*.{h,hpp,hh,hxx}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx}: Prefer const or constexpr over #define for constants in C++ headers
Use Doxygen for documenting interfaces; use //! for comments and //!< for member annotations in C++
Use include guards in headers with symbol format TRTLLM_<FILENAME>_H (no underscores prefix/suffix; filename only)

Files:

  • cpp/tensorrt_llm/common/attentionOp.h
  • cpp/tensorrt_llm/thop/attentionOp.h
  • cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h
  • cpp/kernels/fmha_v2/src/fmha/fragment.h
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Code must target Python 3.8+
Indent with 4 spaces; do not use tabs (Python)
Maintain module namespace on import: prefer from package.subpackage import foo; use foo.Symbol()
Python filenames use snake_case
Python class names use PascalCase
Python functions and methods use snake_case
Python local variables use snake_case; if starting with a number concept, prefix with k (e.g., k_99th_percentile)
Python global variables use G_ prefix with UPPER_SNAKE_CASE
Python constants use UPPER_SNAKE_CASE
Avoid shadowing variables from outer scopes
Initialize all externally visible class members in __init__
For public interfaces, prefer docstrings over comments; comments should be for in-function or file-local interfaces
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes and variables inline with docstrings immediately after assignment
Avoid reflection when a non-reflective approach suffices
Limit except clauses to specific exceptions where possible
When using try/except for duck-typing, keep try body minimal and move logic to else

Files:

  • tensorrt_llm/_torch/attention_backend/interface.py
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
  • cpp/kernels/fmha_v2/fmha_test.py
  • cpp/kernels/fmha_v2/setup.py
  • tests/integration/defs/accuracy/test_llm_api_pytorch.py
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • tensorrt_llm/_torch/modules/attention.py
🧠 Learnings (5)
📚 Learning: 2025-08-14T21:04:50.248Z
Learnt from: thorjohnsen
PR: NVIDIA/TensorRT-LLM#6910
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-14T21:04:50.248Z
Learning: In KV cache onboarding logic during prefill in cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, when calculating which blocks fall within the attention window, use getTokensPerBlock() to advance token indices rather than block->getUniqueTokens().size(), because the calculation needs to consider the post-prefill state where blocks will be filled to capacity, not their current token count.

Applied to files:

  • cpp/tensorrt_llm/common/attentionOp.cpp
  • cpp/tensorrt_llm/common/attentionOp.h
  • cpp/tensorrt_llm/thop/attentionOp.h
  • cpp/tensorrt_llm/nanobind/thop/bindings.cpp
  • cpp/tensorrt_llm/pybind/thop/bindings.cpp
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • cpp/tensorrt_llm/thop/attentionOp.cpp
  • cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu
  • tensorrt_llm/_torch/modules/attention.py
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tests/integration/test_lists/test-db/l0_h100.yml
📚 Learning: 2025-08-15T06:46:54.897Z
Learnt from: eopXD
PR: NVIDIA/TensorRT-LLM#6767
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-15T06:46:54.897Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp addToken function, newly allocated blocks are unshared by design. The beam search path in addToken (when sequence.getNumTokens() > windowSize) is currently broken/non-functional with SWA, so the block allocation doesn't follow a shared-then-unshared pattern.

Applied to files:

  • cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
  • cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu
📚 Learning: 2025-08-21T09:41:49.347Z
Learnt from: eopXD
PR: NVIDIA/TensorRT-LLM#6768
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:2010-2045
Timestamp: 2025-08-21T09:41:49.347Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, updateSequenceCacheBlockOffsets is specifically for updating bookkeeping when blocks are added during the context phase, not for refreshing offsets after detach operations. During detach operations, GenerationRequest::removeFrontBlock handles the necessary cache block bookkeeping internally.

Applied to files:

  • cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
📚 Learning: 2025-08-20T06:56:02.889Z
Learnt from: eopXD
PR: NVIDIA/TensorRT-LLM#6768
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:577-579
Timestamp: 2025-08-20T06:56:02.889Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, maxSequenceLength is now enforced as a non-optional argument in the BlockManager constructor, so concerns about std::nullopt defaulting to 0 are not applicable. When windowSize > maxSequenceLength, a warning should be added instead of handling optional parameter cases.

Applied to files:

  • cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
🧬 Code graph analysis (6)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)
tensorrt_llm/logger.py (1)
  • warning (131-132)
cpp/kernels/fmha_v2/src/fmha/fragment.h (2)
cpp/kernels/fmha_v2/src/fmha/hopper/fragment.h (3)
  • void (113-119)
  • void (145-151)
  • void (481-505)
cpp/kernels/fmha_v2/src/softmax_impl.h (7)
  • void (116-119)
  • void (139-142)
  • float (56-59)
  • float (87-90)
  • float (94-105)
  • float (146-149)
  • float (884-897)
tests/integration/defs/accuracy/test_llm_api_pytorch.py (1)
tests/integration/defs/conftest.py (1)
  • parametrize_with_ids (1786-1811)
cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (2)
cpp/tensorrt_llm/kernels/mlaKernels.cu (1)
  • dequantCopy (156-157)
tensorrt_llm/_torch/attention_backend/trtllm.py (2)
  • max_seq_len (565-575)
  • max_seq_len (578-582)
tensorrt_llm/_torch/modules/attention.py (2)
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1)
  • attn_metadata (68-69)
tensorrt_llm/_torch/attention_backend/trtllm.py (1)
  • load_chunked_kv_cache_for_mla (1344-1390)
cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp (1)
cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (2)
  • invokeMLALoadChunkedKV (323-335)
  • invokeMLALoadChunkedKV (323-325)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (27)
cpp/tensorrt_llm/common/attentionOp.h (1)

54-56: Verify the new batch-size knob actually feeds workspace-sizing paths.

getWorkspaceSizeForContext() should reflect mChunkPrefillBufferBatchSize for FP8 Context MLA with separate Q/KV inputs. Please confirm usage downstream (enqueue/alloc) and add a brief note in toString() for debug.

Would you like a repo-scanning script to locate all uses of mChunkPrefillBufferBatchSize and check serialization/clone sites?

tests/integration/test_lists/test-db/l0_h100.yml (1)

225-225: Ensure FP8 chunked-prefill test fits pre-merge CI budget

  • Verify the average runtime for
    test_chunked_prefill[quant_dtype=fp8-kv_cache_reuse=True-fp8kv=True-overlap_scheduler=True]
    on H100 stays within our CI time budget.
  • For example, compute it from recent CI logs:
    # Extract durations for this exact test from the last 20 H100 runs
    grep -E "test_chunked_prefill\[\s*quant_dtype=fp8-kv_cache_reuse=True" ci/logs/h100/durations.log \
      | tail -n 20 \
      | awk '{sum+=$NF; count++} END{print "avg:", sum/count, "s"}'
cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h (1)

1321-1325: Good: dequantized softmax-sum persisted for stats.

Multiplying global_sum by SOFTMAX_FP_DEQUANT_SCALE under RETURN_SOFTMAX_STATS preserves output scaling while exposing dequant stats. Please verify the consumer reads global_sum (not global_sum_mi) post-epilogue.

tests/integration/defs/accuracy/test_llm_api_pytorch.py (1)

1709-1712: Expanding quant_dtype to include fp8 for chunked prefill looks good.

Params align with the test signature and H100/BW gating.

cpp/tensorrt_llm/common/attentionOp.cpp (1)

768-770: Clamp chunk prefill batch size and cast operands to size_t
Clamping mChunkPrefillBufferBatchSize to std::max<size_t>(1, ...) prevents zero‐sized FP8 buffers (it’s a size_t and may be set to 0) and ensures all multiplications stay in size_t. Ensure enqueue-time workspace pointer arithmetic also uses this clamped chunkBatch for offset calculations to avoid buffer overlap.

cpp/tensorrt_llm/pybind/thop/bindings.cpp (1)

51-53: Mark tail args keyword-only to preserve backward compatibility

Add py::kw_only() before the newly added arg to prevent positional-API breaks:

-        py::arg("rotary_embedding_max_position_info"), py::arg("use_paged_context_fmha"),
-        py::arg("attention_input_type") = std::nullopt, py::arg("is_mla_enable"),
+        py::arg("rotary_embedding_max_position_info"), py::arg("use_paged_context_fmha"),
+        py::kw_only(),
+        py::arg("attention_input_type") = std::nullopt, py::arg("is_mla_enable"),

py::kw_only() is already used elsewhere in the codebase, confirming the current pybind11 version supports it.

cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cuh (1)

38-40: API change: document new params & verify binding

  • Document chunked_ld_global_offset (global leading-dimension offset per chunk) and max_seq_len (tokens) in the invokeMLALoadChunkedKV declaration (.cuh).
  • Declarations (.cuh) and definitions (.cu) already match. Confirm the Python binding in cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp (around the m.def("load_chunked_kv_cache_for_mla", …)) includes chunked_ld_global_offset and max_seq_len.
cpp/kernels/fmha_v2/setup.py (4)

3825-3830: Gating looks correct and prevents invalid fp8 variants.

Skipping normal D-size specs when return_softmax or output_dtype is set matches the intent; MLA-only softmax-stats and BF16 output are generated below.


3866-3899: Inconsistent KV step: comment says 128, code uses 256 for D in (80, 96, 104, 128). Verify which is intended.

If 256 is correct, fix the comment. If 128 is intended (to reduce regs/smem), update the code.

-                    kv_loop_step=256,
+                    kv_loop_step=128,

Or change the preceding comment to “64 < D <= 128: KV_STEP = 256”.


3921-3935: Good: KV_STEP=128 for D >= 160 aligns with ‘avoid register spilling’.

No action; noting consistency with the rationale in the comment.


6219-6223: BF16-output fp8 path enabled; confirm downstream name/matchers include “_output_bf16”.

encode_name appends “_output_{output_dtype}”, and metadata parsing strips “_output_”. Please confirm tests enumerate/run the BF16-output variants on SM90.

cpp/tensorrt_llm/kernels/mlaChunkedPrefill.cu (3)

223-232: Template contract is clear. Nice use of static_assert.

The T/TCache relationship and compile-time branching make the fp8 and non-fp8 paths explicit.


261-271: Dequantization path looks right; scale read once per kernel is fine.

No action.


321-335: Grid uses max_seq_len; ensure it covers batches with zero-length chunks.

Kernel already early-returns for real_chunked_size <= 0, so this is safe. LGTM.

tensorrt_llm/_torch/attention_backend/trtllm.py (5)

70-70: New wrapper field wiring looks correct.

chunk_prefill_buffer_batch_size is threaded through the wrapper and used in the native call. No issues spotted here.


821-857: Chunked metadata computation LGTM.

The rolling fill across contexts with per-loop buffer budget and the global-offset prefix sums look correct.


864-871: Merge-op toggling logic is sound.

“Copy-only” on first non-empty loop or after a zero-length loop prevents redundant accumulation. Good.


476-477: Confirm native attention signature matches new parameter order
Verify that the C++/CUDA binding for thop.attention includes the chunk_prefill_buffer_batch_size argument in the same position as in the Python call to prevent ABI mismatches.
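
One quick way to check this from Python is to print the registered op schema and eyeball the argument order; this assumes the op is exposed as torch.ops.trtllm.attention, so adjust the namespace if the binding name differs:

import torch

import tensorrt_llm  # noqa: F401  (importing registers the custom torch ops)

# Assumption: the thop attention op is registered as trtllm::attention.
print(torch.ops.trtllm.attention.default._schema)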


1349-1389: Per-loop 1D tensors correctly passed to load_chunked_kv_cache_for_mla: in tensorrt_llm/_torch/modules/attention.py (lines 1191–1193) you slice attn_metadata.cu_chunked_seq_len[loop_idx], attn_metadata.chunked_ld_global_offset[loop_idx], and attn_metadata.max_chunk_len_per_loop[loop_idx] before passing them, so each argument is a 1D view as expected.

cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp (3)

52-63: Helper signature update is consistent with kernel changes.

Passing both cu_ctx_chunked_len and chunked_ld_global_offset plus max_seq_len aligns with the kernel API. Looks good.


296-339: Type branches correctly forward new params.

All dtype/quant branches pass chunked_ld_global_offset and max_seq_len. Good.


502-526: Torch schema matches C++ signature.

The new chunked_ld_global_offset and max_seq_len are reflected in the schema. LGTM.

cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu (5)

25-64: Reference kernel update is correct.

Indexing KV via chunked_ld_global_offset[b] + s matches the new per-loop global-offset model.


169-185: Host KV copy respects per-loop offsets.

copyRelatedChunkedKV now accounts for chunked_ld_global_offset. Correct.


511-523: Chunked-loop bookkeeping is consistent.

Shapes for h_cu_chunk_lens and h_chunked_ld_global_offset reflect [loops+1, B+1] and [loops+1, B], respectively. Looks correct.


822-853: Merged-attention main loop pointer bumps are correct.

Advancing merge_op, cu_chunk_lens, and chunked_ld_global_offset per loop ensures each iteration consumes the right slices.

Also applies to: 856-871


1029-1075: Chunked load test validates per-loop max length.

Looping chunked_loop_num - 1 and comparing against the reference is appropriate. Nice coverage.

@jmydurant
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #17356 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #17356 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13045 completed with status: 'FAILURE'

@jmydurant
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #17466 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #17466 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13125 completed with status: 'SUCCESS'

@jmydurant jmydurant force-pushed the user/mingyangj/opt_chunked_prefill branch from d056760 to 4da0181 on September 4, 2025 04:12
@jmydurant
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #17616 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #17616 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13242 completed with status: 'FAILURE'

@kaiyux
Member

kaiyux commented Sep 6, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #17834 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #17834 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13350 completed with status: 'FAILURE'

@jmydurant
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #17854 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #17854 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13366 completed with status: 'SUCCESS'

@jmydurant jmydurant requested a review from yuxianq September 8, 2025 05:18
@tensorrt-cicd
Collaborator

PR_Github #18390 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #18343 [ run ] completed with state ABORTED
LLM/main/L0_MergeRequest_PR #13757 (Blue Ocean) completed with status: ABORTED

@tensorrt-cicd
Collaborator

PR_Github #18390 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13795 completed with status: 'FAILURE'

@kaiyux
Member

kaiyux commented Sep 11, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #18426 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #18426 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #13821 completed with status: 'FAILURE'

@kaiyux
Member

kaiyux commented Sep 11, 2025

/bot run --disable-fail-fast

auto-merge was automatically disabled September 15, 2025 04:28

Head branch was pushed to by a user without write access

@jmydurant jmydurant force-pushed the user/mingyangj/opt_chunked_prefill branch from 327f0b5 to 1d0327d on September 15, 2025 04:28
@jmydurant
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #18563 [ run ] triggered by Bot

@kaiyux kaiyux enabled auto-merge (squash) September 15, 2025 11:48
@tensorrt-cicd
Collaborator

PR_Github #18563 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13936 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@kaiyux kaiyux merged commit 7deefb3 into NVIDIA:main Sep 15, 2025
5 checks passed
Wong4j pushed a commit to Wong4j/TensorRT-LLM that referenced this pull request Sep 20, 2025
MrGeva pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Sep 21, 2025
