Skip to content

Conversation

@yibinl-nvidia
Copy link
Collaborator

@yibinl-nvidia yibinl-nvidia commented Nov 4, 2025

Summary by CodeRabbit

  • New Features

    • StarCoder2 model now available with support for 3B, 7B, and 15B parameter variants
    • Includes optimized sliding window attention mechanism for efficient inference
    • Full compatibility with CUDA graph acceleration for improved performance
  • Tests

    • Added comprehensive validation and correctness testing for the new model

Description

This PR implements PyTorch backend support for Starcoder2 3B, 7B, and 15B checkpoint, as well as the FP8 quantized checkpoint. Several tests are added to check network raw outputs or full e2e accuracy tests on GSM8K.

Token level output comparision against HF implementation:

Prompt: 'def fibonacci(n):'                                            
                                                                       
TRT-LLM generated: '                                                                                                                          
    if n == 0:                                                                                                                                        
        return 0
    elif n == 1:                                                       
        return'
HF generated:      '
    if n == 0:
        return 0
    elif n == 1:
        return'

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@yibinl-nvidia yibinl-nvidia self-assigned this Nov 11, 2025
@yibinl-nvidia yibinl-nvidia marked this pull request as ready for review November 11, 2025 03:24
@yibinl-nvidia yibinl-nvidia requested a review from a team as a code owner November 11, 2025 03:24
@yibinl-nvidia
Copy link
Collaborator Author

/bot run

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Nov 11, 2025

📝 Walkthrough

Walkthrough

Adds a complete StarCoder2 model implementation for TensorRT-LLM, including custom LayerNorm, grouped-query attention with sliding window support, decoder layers, a transformer model, and a causal language model wrapper with weight loading logic for GPT-2 style MLP naming conventions. Includes comprehensive unit tests covering sanity checks, HuggingFace reference comparison, and token generation correctness.

Changes

Cohort / File(s) Summary
Model exports
tensorrt_llm/_torch/models/__init__.py
Added Starcoder2ForCausalLM to public API exports.
Model implementation
tensorrt_llm/_torch/models/modeling_starcoder2.py
Introduced complete StarCoder2 model family: Starcoder2LayerNorm (meta-tensor compatible), Starcoder2Attention (grouped-query with sliding window), Starcoder2DecoderLayer (attention + MLP + residual), Starcoder2Model (transformer backbone), and Starcoder2ForCausalLM (public wrapper with weight loading and GPT-2 to internal MLP name mapping).
Model tests
tests/unittest/_torch/modeling/test_modeling_starcoder2.py
Added comprehensive test suite with base configurations (3B/7B/15B), Scenario dataclass, KV cache manager factory, sanity tests, HuggingFace reference comparison tests, and token generation correctness tests with optional CUDA graph acceleration.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Starcoder2ForCausalLM
    participant Starcoder2Model
    participant DecoderLayer as Starcoder2DecoderLayer
    participant Attention as Starcoder2Attention
    participant MLP

    User->>Starcoder2ForCausalLM: forward(input_ids)
    Starcoder2ForCausalLM->>Starcoder2Model: forward(input_ids)
    
    Starcoder2Model->>Starcoder2Model: embed tokens & init position_ids
    
    loop for each decoder layer
        Starcoder2Model->>DecoderLayer: forward(hidden_states, position_ids)
        DecoderLayer->>DecoderLayer: input normalization
        DecoderLayer->>Attention: forward(hidden_states, position_ids)
        rect rgba(200, 220, 255, 0.3)
            Note over Attention: Grouped-query attention<br/>with sliding window
        end
        Attention-->>DecoderLayer: attn_output
        DecoderLayer->>MLP: forward(attn_output)
        MLP-->>DecoderLayer: mlp_output
        DecoderLayer->>DecoderLayer: residual connections
        DecoderLayer-->>Starcoder2Model: layer_output
    end
    
    Starcoder2Model->>Starcoder2Model: final layer normalization
    Starcoder2Model-->>Starcoder2ForCausalLM: hidden_states
    
    rect rgba(220, 240, 220, 0.3)
        Note over Starcoder2ForCausalLM: Output projection<br/>(logits)
    end
    Starcoder2ForCausalLM-->>User: logits
Loading
sequenceDiagram
    participant User
    participant Starcoder2ForCausalLM
    participant Loader as Weight Loader
    participant HFWeights as HF Model Weights
    participant InternalModules as Internal Modules

    User->>Starcoder2ForCausalLM: load_weights(weights, weight_mapper)
    
    alt with weight_mapper
        Starcoder2ForCausalLM->>Loader: load_weights(weight_mapper path)
    else without weight_mapper
        Starcoder2ForCausalLM->>Loader: load_weights(default path)
    end
    
    Loader->>HFWeights: read c_fc, c_proj (GPT-2 MLP naming)
    rect rgba(255, 220, 200, 0.3)
        Note over Loader: Map GPT-2 names to<br/>internal up_proj/down_proj
    end
    HFWeights-->>Loader: weight tensors
    Loader->>InternalModules: set_parameter(mapped_name)
    InternalModules-->>Starcoder2ForCausalLM: weights loaded
    Starcoder2ForCausalLM-->>User: ready for inference
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Weight mapping logic: Custom GPT-2 to internal naming convention mapping (c_fc/c_proj to up_proj/down_proj) requires careful verification.
  • Attention implementation: Grouped-query attention with sliding window support needs scrutiny for correctness, particularly boundary handling and window size configuration.
  • Meta-tensor compatibility: Starcoder2LayerNorm's reset_parameters override must be verified to properly support meta tensor initialization without breaking normal initialization paths.
  • Test coverage complexity: Multiple test scenarios across different backends, configurations (3B/7B/15B), and CUDA graph paths require tracing through parameterized test logic.
  • Integration points: Verify proper integration with KVCacheManager, AttentionMetadata, and SpecMetadata across forward passes.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 46.67% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Description check ❓ Inconclusive The pull request description provides a clear summary of the implementation (StarCoder2 PyTorch backend support for 3B/7B/15B models with FP8 quantization) and includes a concrete example comparing token-level output between TRT-LLM and HuggingFace implementations, demonstrating functional correctness. However, the Test Coverage section is empty and the PR Checklist items lack specific details about which tests validate the changes. Please specify which tests safeguard these changes (e.g., test_starcoder2_sanity, test_starcoder2_allclose_to_hf, test_starcoder2_generated_tokens_match_hf) in the Test Coverage section to clarify test coverage for the implementation.
✅ Passed checks (1 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly and specifically describes the main change: adding StarCoder2 model support to the PyTorch backend, which aligns with all file changes (modeling_starcoder2.py implementation, exports, and comprehensive tests).
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Tip

📝 Customizable high-level summaries are now available in beta!

You can now customize how CodeRabbit generates the high-level summary in your pull requests — including its content, structure, tone, and formatting.

  • Provide your own instructions using the high_level_summary_instructions setting.
  • Format the summary however you like (bullet lists, tables, multi-section layouts, contributor stats, etc.).
  • Use high_level_summary_in_walkthrough to move the summary from the description to the walkthrough section.

Example instruction:

"Divide the high-level summary into five sections:

  1. 📝 Description — Summarize the main change in 50–60 words, explaining what was done.
  2. 📓 References — List relevant issues, discussions, documentation, or related PRs.
  3. 📦 Dependencies & Requirements — Mention any new/updated dependencies, environment variable changes, or configuration updates.
  4. 📊 Contributor Summary — Include a Markdown table showing contributions:
    | Contributor | Lines Added | Lines Removed | Files Changed |
  5. ✔️ Additional Notes — Add any extra reviewer context.
    Keep each section concise (under 200 words) and use bullet or numbered lists for clarity."

Note: This feature is currently in beta for Pro-tier users, and pricing will be announced later.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0938a3a and 9129c6e.

📒 Files selected for processing (3)
  • tensorrt_llm/_torch/models/__init__.py (2 hunks)
  • tensorrt_llm/_torch/models/modeling_starcoder2.py (1 hunks)
  • tests/unittest/_torch/modeling/test_modeling_starcoder2.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/models/__init__.py
  • tensorrt_llm/_torch/models/modeling_starcoder2.py
  • tests/unittest/_torch/modeling/test_modeling_starcoder2.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/models/__init__.py
  • tensorrt_llm/_torch/models/modeling_starcoder2.py
  • tests/unittest/_torch/modeling/test_modeling_starcoder2.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/models/__init__.py
  • tensorrt_llm/_torch/models/modeling_starcoder2.py
  • tests/unittest/_torch/modeling/test_modeling_starcoder2.py
🧠 Learnings (5)
📓 Common learnings
Learnt from: moraxu
Repo: NVIDIA/TensorRT-LLM PR: 6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
Learnt from: galagam
Repo: NVIDIA/TensorRT-LLM PR: 6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.
Learnt from: fredricz-20070104
Repo: NVIDIA/TensorRT-LLM PR: 7645
File: tests/integration/test_lists/qa/llm_function_core.txt:648-648
Timestamp: 2025-09-09T09:40:45.658Z
Learning: In TensorRT-LLM test lists, it's common and intentional for the same test to appear in multiple test list files when they serve different purposes (e.g., llm_function_core.txt for comprehensive core functionality testing and llm_function_core_sanity.txt for quick sanity checks). This duplication allows tests to be run in different testing contexts.
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
Repo: NVIDIA/TensorRT-LLM PR: 6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tests/unittest/_torch/modeling/test_modeling_starcoder2.py
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
Repo: NVIDIA/TensorRT-LLM PR: 6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tests/unittest/_torch/modeling/test_modeling_starcoder2.py
📚 Learning: 2025-09-09T09:40:45.658Z
Learnt from: fredricz-20070104
Repo: NVIDIA/TensorRT-LLM PR: 7645
File: tests/integration/test_lists/qa/llm_function_core.txt:648-648
Timestamp: 2025-09-09T09:40:45.658Z
Learning: In TensorRT-LLM test lists, it's common and intentional for the same test to appear in multiple test list files when they serve different purposes (e.g., llm_function_core.txt for comprehensive core functionality testing and llm_function_core_sanity.txt for quick sanity checks). This duplication allows tests to be run in different testing contexts.

Applied to files:

  • tests/unittest/_torch/modeling/test_modeling_starcoder2.py
📚 Learning: 2025-08-26T09:49:04.956Z
Learnt from: pengbowang-nv
Repo: NVIDIA/TensorRT-LLM PR: 7192
File: tests/integration/test_lists/test-db/l0_dgx_b200.yml:56-72
Timestamp: 2025-08-26T09:49:04.956Z
Learning: In TensorRT-LLM test configuration files, the test scheduling system handles wildcard matching with special rules that prevent duplicate test execution even when the same tests appear in multiple yaml files with overlapping GPU wildcards (e.g., "*b200*" and "*gb200*").

Applied to files:

  • tests/unittest/_torch/modeling/test_modeling_starcoder2.py
🧬 Code graph analysis (3)
tensorrt_llm/_torch/models/__init__.py (1)
tensorrt_llm/_torch/models/modeling_starcoder2.py (1)
  • Starcoder2ForCausalLM (249-304)
tensorrt_llm/_torch/models/modeling_starcoder2.py (6)
tensorrt_llm/_torch/attention_backend/interface.py (3)
  • AttentionMetadata (44-394)
  • PositionalEmbeddingParams (564-582)
  • RopeParams (408-560)
tensorrt_llm/_torch/models/modeling_utils.py (3)
  • register_auto_model (617-623)
  • _load_weights_impl (816-937)
  • _load_weights_impl_v2 (940-1016)
tensorrt_llm/_torch/modules/embedding.py (1)
  • Embedding (180-264)
tensorrt_llm/_torch/modules/linear.py (1)
  • TensorParallelMode (50-62)
tensorrt_llm/_torch/speculative/interface.py (1)
  • SpecMetadata (152-240)
tensorrt_llm/_torch/model_config.py (1)
  • torch_dtype (206-211)
tests/unittest/_torch/modeling/test_modeling_starcoder2.py (6)
tensorrt_llm/_torch/models/modeling_starcoder2.py (5)
  • Starcoder2ForCausalLM (249-304)
  • forward (70-86)
  • forward (139-174)
  • forward (213-245)
  • load_weights (267-304)
tests/unittest/utils/util.py (1)
  • default_dtype (406-410)
tensorrt_llm/_torch/attention_backend/utils.py (1)
  • get_attention_backend (15-37)
tensorrt_llm/_torch/metadata.py (1)
  • KVCacheParams (9-31)
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (2)
  • CUDAGraphRunner (25-413)
  • attn_metadata (131-132)
tensorrt_llm/mapping.py (1)
  • Mapping (336-493)
🪛 Ruff (0.14.4)
tensorrt_llm/_torch/models/modeling_starcoder2.py

124-124: Avoid specifying long messages outside the exception class

(TRY003)


223-226: Avoid specifying long messages outside the exception class

(TRY003)


267-267: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)

tests/unittest/_torch/modeling/test_modeling_starcoder2.py

117-117: Avoid specifying long messages outside the exception class

(TRY003)


146-146: Value being cast to int is already an integer

Remove unnecessary int call

(RUF046)


168-168: Consider [*context_sequence_lengths, 1, 1] instead of concatenation

Replace with [*context_sequence_lengths, 1, 1]

(RUF005)


245-245: Unused lambda argument: param_num

(ARG005)


426-426: Unused lambda argument: param_num

(ARG005)


542-542: Loop control variable step not used within loop body

Rename unused step to _step

(B007)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check

@tensorrt-cicd
Copy link
Collaborator

PR_Github #24096 [ run ] triggered by Bot. Commit: 9129c6e

@tensorrt-cicd
Copy link
Collaborator

PR_Github #24096 [ run ] completed with state SUCCESS. Commit: 9129c6e
/LLM/main/L0_MergeRequest_PR pipeline #18161 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@yibinl-nvidia
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #24515 [ run ] triggered by Bot. Commit: 7d135ba

@tensorrt-cicd
Copy link
Collaborator

PR_Github #24515 [ run ] completed with state SUCCESS. Commit: 7d135ba
/LLM/main/L0_MergeRequest_PR pipeline #18503 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@yibinl-nvidia yibinl-nvidia force-pushed the dev-yibinl-starcoder2 branch 2 times, most recently from 180a560 to c1ece3e Compare November 20, 2025 15:04
@yibinl-nvidia yibinl-nvidia changed the title [TRTLLM-7967][feat] Adding Starcoder2 PyTorch Flow Support [TRTLLM-7967][feat] Adding Starcoder2 PyTorch Backend Support Nov 20, 2025
@tensorrt-cicd
Copy link
Collaborator

PR_Github #25212 [ run ] completed with state SUCCESS. Commit: c1ece3e
/LLM/main/L0_MergeRequest_PR pipeline #19067 completed with status: 'FAILURE'

@yibinl-nvidia
Copy link
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #25256 [ run ] triggered by Bot. Commit: c1ece3e

@tensorrt-cicd
Copy link
Collaborator

PR_Github #25256 [ run ] completed with state SUCCESS. Commit: c1ece3e
/LLM/main/L0_MergeRequest_PR pipeline #19103 completed with status: 'FAILURE'

@yibinl-nvidia
Copy link
Collaborator Author

/bot run --disable-fail-fast

@yibinl-nvidia
Copy link
Collaborator Author

@2ez4bz @Wanli-Jiang all comments are addressed, could you review again? Thank you!

@tensorrt-cicd
Copy link
Collaborator

PR_Github #25394 [ run ] triggered by Bot. Commit: cf69b40

@tensorrt-cicd
Copy link
Collaborator

PR_Github #25394 [ run ] completed with state SUCCESS. Commit: cf69b40
/LLM/main/L0_MergeRequest_PR pipeline #19212 completed with status: 'FAILURE'

@yibinl-nvidia
Copy link
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #25407 [ run ] triggered by Bot. Commit: cf69b40

@tensorrt-cicd
Copy link
Collaborator

PR_Github #25407 [ run ] completed with state SUCCESS. Commit: cf69b40
/LLM/main/L0_MergeRequest_PR pipeline #19223 completed with status: 'SUCCESS'

@yibinl-nvidia yibinl-nvidia merged commit 1ce483c into NVIDIA:main Nov 24, 2025
5 checks passed
@yibinl-nvidia yibinl-nvidia deleted the dev-yibinl-starcoder2 branch November 24, 2025 19:23
codego7250 pushed a commit to codego7250/TensorRT-LLM that referenced this pull request Dec 11, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants