[None][fix] Reuse batch_indices_cuda across CUDA graph captures in EAGLE3#14381
📝 Walkthrough
This PR threads a new CUDA-allocated
Changes
EAGLE3 batch_indices_cuda Resource Management
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~8 minutes
🧹 Nitpick comments (1)
tensorrt_llm/_torch/speculative/eagle3.py (1)
67-71: 💤 Low value. Optional: Consider removing trailing commas for consistency.
The new torch.empty(...) allocations use trailing commas (after device='cuda',), but existing allocations in this file omit them (e.g., lines 66, 221-226, 401). For consistency with the established style in this file, consider removing the trailing commas.
Minor style adjustment:
  self.batch_indices_cuda = torch.empty(
      [max_num_requests],
      dtype=torch.int,
-     device='cuda',
+     device='cuda'
  )
Apply the same change to the other two allocations at lines 143-147 and 411-415.
Also applies to: 143-147, 411-415
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/speculative/eagle3.py` around lines 67 - 71, Remove the trailing commas from the torch.empty(...) allocation calls for consistency with the file's style: update the call that assigns self.batch_indices_cuda (currently torch.empty([max_num_requests], dtype=torch.int, device='cuda',)) to remove the comma after device='cuda', and make the same change to the two other torch.empty allocations referenced in the comment (the other torch.empty calls in this file) so none of them end with a trailing comma after the final argument.
📒 Files selected for processing (1)
tensorrt_llm/_torch/speculative/eagle3.py
/bot run
PR_Github #49579 [ run ] triggered by Bot. Commit:
PR_Github #49579 [ run ] completed with state
/bot run
PR_Github #49770 [ run ] triggered by Bot. Commit:
PR_Github #49770 [ run ] completed with state
/bot run
PR_Github #49989 [ run ] triggered by Bot. Commit:
PR_Github #49989 [ run ] completed with state
0099d6c to 13f3798
/bot run --disable-fail-fast
PR_Github #50358 [ run ] triggered by Bot. Commit:
PR_Github #50358 [ run ] completed with state
… in Eagle3
Share a single batch_indices_cuda buffer from the resource manager across all CUDA graph metadata copies, mirroring the existing hidden_states dedup pattern (PR NVIDIA#13920).
Previously each of the 34 CUDA graph variants allocated its own torch.empty([max_num_requests]) tensor in __post_init__. Since only one graph executes at a time and the buffer is overwritten via [:num_seqs].copy_() before each use, sharing is safe.
Validated on H100 with LLaMA-3.1-8B + Eagle3:
- Baseline: 34 unique batch_indices_cuda tensors
- Fixed: 1 shared tensor, identical inference outputs
Signed-off-by: Aurelien Chartier <[email protected]>
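The sharing pattern described in the commit message can be sketched roughly as follows. This is a minimal illustration, not the code in eagle3.py: the class names `ResourceManagerSketch` and `GraphMetadataSketch`, the `prepare` method, and the helper `count_unique_buffers` are all hypothetical, and a plain `array.array` stands in for the `torch.empty([max_num_requests], dtype=torch.int, device='cuda')` tensor so the sketch runs without a GPU.

```python
from array import array
from dataclasses import dataclass
from typing import Optional, Sequence


class ResourceManagerSketch:
    """Hypothetical owner of the single preallocated index buffer."""

    def __init__(self, max_num_requests: int):
        # Stand-in for torch.empty([max_num_requests], dtype=torch.int, device='cuda').
        self.batch_indices = array("i", [0] * max_num_requests)


@dataclass
class GraphMetadataSketch:
    """Hypothetical per-CUDA-graph metadata copy."""

    max_num_requests: int
    resource_manager: Optional[ResourceManagerSketch] = None
    batch_indices: Optional[array] = None

    def __post_init__(self):
        if self.resource_manager is not None:
            # Reuse the manager's buffer instead of allocating a new one.
            self.batch_indices = self.resource_manager.batch_indices
        else:
            # Fallback: the previous behaviour, one private buffer per copy.
            self.batch_indices = array("i", [0] * self.max_num_requests)

    def prepare(self, seq_slots: Sequence[int]) -> None:
        # Overwrite the leading num_seqs entries before each use,
        # analogous to buffer[:num_seqs].copy_(...) in the real code.
        num_seqs = len(seq_slots)
        self.batch_indices[:num_seqs] = array("i", seq_slots)


def count_unique_buffers(copies) -> int:
    # Count distinct underlying buffers by object identity.
    return len({id(c.batch_indices) for c in copies})


manager = ResourceManagerSketch(max_num_requests=8)
shared = [GraphMetadataSketch(8, resource_manager=manager) for _ in range(34)]
private = [GraphMetadataSketch(8) for _ in range(34)]
```

Calling `count_unique_buffers(private)` and `count_unique_buffers(shared)` contrasts the two allocation strategies. Sharing is only sound under the condition the commit message states: one graph executes at a time and each one rewrites the leading entries before reading them.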
13f3798 to 542e27d
/bot run
PR_Github #50840 [ run ] triggered by Bot. Commit:
PR_Github #50840 [ run ] completed with state
/bot run
PR_Github #50853 [ run ] triggered by Bot. Commit:
PR_Github #50853 [ run ] completed with state
/bot run
PR_Github #50862 [ run ] triggered by Bot. Commit:
PR_Github #50862 [ run ] completed with state
/bot run
PR_Github #50881 [ run ] triggered by Bot. Commit:
PR_Github #50881 [ run ] completed with state
Description
Reuse batch_indices_cuda across CUDA graph captures in EAGLE3
This applies the same fix as #13920 to batch_indices_cuda.
Test Coverage
Manually validated.
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either api-compatible or api-breaking. For api-breaking, include BREAKING in the PR title.
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment /bot help.