Skip to content

Conversation

@ZhangGe6
Copy link
Contributor

@ZhangGe6 ZhangGe6 commented Aug 14, 2025

The split/concat kernels expect KV cache in [max_num_pages, 2, num_kv_heads, page_size, head_dim] layout. However, the actually used KV cache layout exposed by KVCacheManager.get_buffers() is [max_num_pages, 2, page_size, num_kv_heads, head_dim]. This layout mismatch causes wrong indexing for split/concat kernels, leading to incorrect transferred prefill KV cache. This patch is a quick fix for flashinfer attn_backend.

Summary by CodeRabbit

  • New Features

    • Option to select KV cache memory layout (NHD or HND) for attention caching.
  • Refactor

    • Default KV cache layout changed to HND.
    • Resource manager API updated to accept and validate layout choices.
  • Chores

    • KV buffer retrieval and reshaping adjusted to honor the chosen layout while preserving compatibility.
  • Tests

    • Unit tests tweaked to reflect KV cache construction specifics.

Description

Test Coverage

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@ZhangGe6 ZhangGe6 requested review from a team as code owners August 14, 2025 23:36
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Aug 14, 2025

📝 Walkthrough

Walkthrough

Default KV layout for FlashInfer attention changed from "NHD" to "HND". KV cache retrieval now accepts and uses a kv_layout parameter; callers (FlashInfer and Star attention) pass metadata.kv_layout when requesting buffers. KVCacheManager.get_buffers validates kv_layout and applies layout-dependent reshape logic for "NHD" and "HND".

Changes

Cohort / File(s) Summary
Attention backend (FlashInfer)
tensorrt_llm/_torch/attention_backend/flashinfer.py
FlashInferAttentionMetadata.kv_layout default switched to "HND"; forward_impl now calls KV cache retrieval with kv_layout=metadata.kv_layout; added comments describing the intended HND buffer layout.
Attention backend (Star)
tensorrt_llm/_torch/attention_backend/star_flashinfer.py
StarAttention.forward updated to fetch KV buffers via metadata.kv_cache_manager.get_buffers(..., kv_layout=metadata.kv_layout) (passes layout through).
Resource manager (KV cache buffers)
tensorrt_llm/_torch/pyexecutor/resource_manager.py
KVCacheManager.get_buffers signature extended to get_buffers(self, layer_idx: int, kv_layout: str = "NHD") -> Optional[torch.Tensor]; asserts valid layout; implements two reshape pathways: "NHD" -> [..., kv_factor, tokens_per_block, num_kv_heads, head_dim] and "HND" -> [..., kv_factor, num_kv_heads, tokens_per_block, head_dim]; docs updated.
Tests
tests/unittest/_torch/test_attention.py
Only whitespace adjustments around the page_size line in two tensor shape constructions; no behavioral or API changes.

Sequence Diagram(s)

sequenceDiagram
  participant Fwd as FlashInferAttention.forward_impl
  participant Star as StarAttention.forward
  participant Meta as FlashInferAttentionMetadata
  participant KV as KVCacheManager

  Note over Fwd,Star: read metadata.kv_layout and request KV buffers with layout
  Fwd->>Meta: read kv_layout
  Fwd->>KV: get_buffers(layer_idx, kv_layout)
  Star->>Meta: read kv_layout
  Star->>KV: get_buffers(layer_idx, kv_layout)

  alt kv_layout == "NHD"
    KV-->>Fwd: buffers reshaped to [max_pages, kv_factor, tokens_per_block, num_kv_heads, head_dim]
    KV-->>Star: buffers reshaped to [max_pages, kv_factor, tokens_per_block, num_kv_heads, head_dim]
  else kv_layout == "HND"
    KV-->>Fwd: buffers reshaped to [max_pages, kv_factor, num_kv_heads, tokens_per_block, head_dim]
    KV-->>Star: buffers reshaped to [max_pages, kv_factor, num_kv_heads, tokens_per_block, head_dim]
  end

  Fwd->>Fwd: continue attention compute with returned buffers
  Star->>Star: continue attention compute with returned buffers
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

  • Pay attention to:
    • Correctness of reshape dimension ordering and indexing in get_buffers.
    • All callsites of get_buffers to ensure kv_layout is passed or default is acceptable.
    • Default change of FlashInferAttentionMetadata.kv_layout to "HND" and any runtime assumptions about previous "NHD" ordering.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 20.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The PR title clearly summarizes the main change: fixing a KV layout mismatch issue affecting split/concat kernels. It is specific, concise, and directly related to the primary objective.
Description check ✅ Passed The PR description explains the issue (KV layout mismatch between expected and actual formats), the impact (incorrect indexing and precision errors), and the solution. However, the description is minimal and the Test Coverage section is empty.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

📜 Recent review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c135edb and 0828526.

📒 Files selected for processing (1)
  • tests/unittest/_torch/test_attention.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/unittest/_torch/test_attention.py

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🔭 Outside diff range comments (2)
tensorrt_llm/_torch/attention_backend/flashinfer.py (1)

1-3: Add missing 2025 NVIDIA copyright header

Per repo guidelines, Python sources must include the NVIDIA copyright header.

Apply this diff at the top of the file:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 import math
 import os
 import weakref
tensorrt_llm/_torch/pyexecutor/resource_manager.py (1)

1-4: Add missing 2025 NVIDIA copyright header

Per repo guidelines, Python sources must include the NVIDIA copyright header.

Apply this diff at the top of the file:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 import copy
 import enum
 import math
 from abc import ABC, abstractmethod
🧹 Nitpick comments (1)
tensorrt_llm/_torch/attention_backend/flashinfer.py (1)

59-61: Default kv_layout to HND aligns with split/concat kernels; consider making layout a planning key

Setting kv_layout="HND" matches the kernel expectation [max_num_pages, 2, num_kv_heads, page_size, head_dim]. One caveat: kv_layout currently isn’t part of the planning key (PlanParams), but affects wrapper construction. If kv_layout changes over the lifetime of the metadata, cached wrappers could be reused with the wrong layout. Either ensure kv_layout is immutable per metadata instance or include it in the plan key.

If you opt to make the plan cache layout-aware, I can draft a small patch to add kv_layout to PlanParams and its equality key.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between b13a5a9 and 8082cf9.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/attention_backend/flashinfer.py (2 hunks)
  • tensorrt_llm/_torch/pyexecutor/resource_manager.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else

Files:

  • tensorrt_llm/_torch/pyexecutor/resource_manager.py
  • tensorrt_llm/_torch/attention_backend/flashinfer.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • tensorrt_llm/_torch/pyexecutor/resource_manager.py
  • tensorrt_llm/_torch/attention_backend/flashinfer.py
🧠 Learnings (1)
📓 Common learnings
Learnt from: thorjohnsen
PR: NVIDIA/TensorRT-LLM#6910
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-14T21:04:50.208Z
Learning: In KV cache onboarding logic during prefill in cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, when calculating which blocks fall within the attention window, use getTokensPerBlock() to advance token indices rather than block->getUniqueTokens().size(), because the calculation needs to consider the post-prefill state where blocks will be filled to capacity, not their current token count.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (1)
tensorrt_llm/_torch/attention_backend/flashinfer.py (1)

498-500: KV buffers now requested with kv_layout — LGTM

Fetching KV cache with kv_layout=metadata.kv_layout is correct and resolves the previous layout mismatch against split/concat kernels.

Note: KVCacheManager.get_buffers is annotated to return Optional[torch.Tensor] but this site assumes a tensor (uses .dtype). Either:

  • tighten the return annotation to torch.Tensor (preferred; see suggested change in resource_manager.py), or
  • add a guard here:
 kv_cache = metadata.kv_cache_manager.get_buffers(
     self.layer_idx, kv_layout=metadata.kv_layout)
+assert kv_cache is not None, "KV buffers are not allocated"

@brb-nv
Copy link
Collaborator

brb-nv commented Sep 30, 2025

Requesting @chuangz0'z review as expert on split/concat kernels.

@liji-nv
Copy link
Collaborator

liji-nv commented Nov 4, 2025

@yuxianq to review

@ZhangGe6 Please use "git commit --amend -s" to sign off the commit. Thanks.

@ZhangGe6
Copy link
Contributor Author

ZhangGe6 commented Nov 4, 2025

@yuxianq I am working on polishing this PR and will update it tomorrow. Thanks for your suggestions!

@ZhangGe6
Copy link
Contributor Author

ZhangGe6 commented Nov 5, 2025

@liji-nv I used "git commit --amend -s" to sign off the commit. Thanks for reminding.

@yuxianq I updated the PR according to review suggestions, please take a review again, thanks.

@ZhangGe6 ZhangGe6 requested a review from yuxianq November 5, 2025 15:38
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f14d110 and a69b2b5.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/pyexecutor/resource_manager.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/pyexecutor/resource_manager.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/pyexecutor/resource_manager.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/pyexecutor/resource_manager.py
🧠 Learnings (2)
📓 Common learnings
Learnt from: thorjohnsen
Repo: NVIDIA/TensorRT-LLM PR: 6910
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-14T21:04:50.248Z
Learning: In KV cache onboarding logic during prefill in cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, when calculating which blocks fall within the attention window, use getTokensPerBlock() to advance token indices rather than block->getUniqueTokens().size(), because the calculation needs to consider the post-prefill state where blocks will be filled to capacity, not their current token count.
📚 Learning: 2025-08-14T21:04:50.248Z
Learnt from: thorjohnsen
Repo: NVIDIA/TensorRT-LLM PR: 6910
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-14T21:04:50.248Z
Learning: In KV cache onboarding logic during prefill in cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, when calculating which blocks fall within the attention window, use getTokensPerBlock() to advance token indices rather than block->getUniqueTokens().size(), because the calculation needs to consider the post-prefill state where blocks will be filled to capacity, not their current token count.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/resource_manager.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/pyexecutor/resource_manager.py (2)
cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp (4)
  • layer_idx (221-224)
  • layer_idx (221-221)
  • layer_idx (226-229)
  • layer_idx (226-226)
cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp (4)
  • layer_idx (217-221)
  • layer_idx (217-217)
  • layer_idx (223-226)
  • layer_idx (223-223)

@yuxianq yuxianq changed the title [Fix][Disaggregated]: Fix precision issue due to KV layout mismatch for split/concat kernels [#6507][fix]: Fix precision issue due to KV layout mismatch for split/concat kernels Nov 7, 2025
@yuxianq
Copy link
Collaborator

yuxianq commented Nov 7, 2025

/bot run --disable-fail-fast

@yuxianq
Copy link
Collaborator

yuxianq commented Nov 7, 2025

@ZhangGe6 It seems that the last 3 commits have not been signed off, please sign off all commits to pass the DCO check, thanks~

@tensorrt-cicd
Copy link
Collaborator

PR_Github #23783 [ run ] triggered by Bot. Commit: a69b2b5

@tensorrt-cicd
Copy link
Collaborator

PR_Github #23783 [ run ] completed with state SUCCESS. Commit: a69b2b5
/LLM/main/L0_MergeRequest_PR pipeline #17903 completed with status: 'FAILURE'

@ZhangGe6 ZhangGe6 changed the title [#6507][fix]: Fix precision issue due to KV layout mismatch for split/concat kernels [#6507][fix] Fix precision issue due to KV layout mismatch for split/concat kernels Nov 8, 2025
…split/concat kernels

The split/concat kernels expect KV cache in
"[max_num_pages, 2, num_kv_heads, page_size, head_dim]" layout. However,
the actually used KV cache layout exposed by"KVCacheManager.get_buffers()" is
"[max_num_pages, 2, page_size, num_kv_heads, head_dim]".
This layout mismatch causes wrong indexing for split/concat kernels,
leading to incorrect transferred prefill KV cache. This patch is a quick fix
for flashinfer attn_backend.

Signed-off-by: ZhangGe6 <[email protected]>
@ZhangGe6 ZhangGe6 reopened this Nov 8, 2025
@ZhangGe6
Copy link
Contributor Author

ZhangGe6 commented Nov 8, 2025

@yuxianq Hi, I trimmed my commits and signed off them, thanks for reminding. In addition, I modified the flashinfer_kv_cache layout in test_attention.py for correctness issue. Please take a review.

@yuxianq
Copy link
Collaborator

yuxianq commented Nov 10, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #23938 [ run ] triggered by Bot. Commit: 5cbcfcb

@tensorrt-cicd
Copy link
Collaborator

PR_Github #23938 [ run ] completed with state SUCCESS. Commit: 5cbcfcb
/LLM/main/L0_MergeRequest_PR pipeline #18025 completed with status: 'FAILURE'

Signed-off-by: Yuxian Qiu <[email protected]>
@yuxianq
Copy link
Collaborator

yuxianq commented Nov 11, 2025

@ZhangGe6 I have fixed some CI errors for this PR, please allow me to push to your branch or cherry-pick bugfix commit from https://github.com/yuxianq/TensorRT-LLM/commits/fix_pd_diff_tp by yourself, I will rerun the CI. Thanks~

@ZhangGe6
Copy link
Contributor Author

@ZhangGe6 I have fixed some CI errors for this PR, please allow me to push to your branch or cherry-pick bugfix commit from https://github.com/yuxianq/TensorRT-LLM/commits/fix_pd_diff_tp by yourself, I will rerun the CI. Thanks~

@yuxianq OK, you can push to my branch directly. Feel free to remind me if there is something I can/should do.

@yuxianq
Copy link
Collaborator

yuxianq commented Nov 12, 2025

@ZhangGe6 I get Authentication error: Authentication required: You must have push access to verify locks when I push to your branch, please give me push permission of your repo, thanks~

@ZhangGe6
Copy link
Contributor Author

@yuxianq Hi, I have sent you an invitation to collaborate on my forked TensorRT-LLM repo (via "Setting -> Access -> Collaborator -> Add people"). Please accept it and try again.

I'm not yet familiar with GitHub operations. Let me know if I missed anything, or I can cherry-pick your bugfix commit later (maybe tonight).

@yuxianq
Copy link
Collaborator

yuxianq commented Nov 12, 2025

/bot run --disable-fail-fast

@yuxianq
Copy link
Collaborator

yuxianq commented Nov 12, 2025

@ZhangGe6 It works. I have started to rerun CI. Thanks~

@tensorrt-cicd
Copy link
Collaborator

PR_Github #24236 [ run ] triggered by Bot. Commit: be3350b

@tensorrt-cicd
Copy link
Collaborator

PR_Github #24236 [ run ] completed with state SUCCESS. Commit: be3350b
/LLM/main/L0_MergeRequest_PR pipeline #18282 completed with status: 'FAILURE'

@yuxianq
Copy link
Collaborator

yuxianq commented Nov 12, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #24301 [ run ] triggered by Bot. Commit: 3c949b5

@tensorrt-cicd
Copy link
Collaborator

PR_Github #24301 [ run ] completed with state SUCCESS. Commit: 3c949b5
/LLM/main/L0_MergeRequest_PR pipeline #18333 completed with status: 'FAILURE'

@yuxianq
Copy link
Collaborator

yuxianq commented Nov 13, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #24354 [ run ] triggered by Bot. Commit: 3c949b5

@tensorrt-cicd
Copy link
Collaborator

PR_Github #24354 [ run ] completed with state SUCCESS. Commit: 3c949b5
/LLM/main/L0_MergeRequest_PR pipeline #18380 completed with status: 'SUCCESS'

@yuxianq yuxianq merged commit 49df731 into NVIDIA:main Nov 13, 2025
5 checks passed
zheyuf pushed a commit to zheyuf/TensorRT-LLM that referenced this pull request Nov 19, 2025
…split/concat kernels (NVIDIA#6917)

Signed-off-by: ZhangGe6 <[email protected]>
Co-authored-by: Yuxian Qiu <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Community want to contribute PRs initiated from Community

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants