[None][fix] Fix chunked prefill API contract for nemotron nano VL#13025
Conversation
f4313d5 to
35608fe
Compare
|
Caution Review failedAn error occurred during the review process. Please try again later. 📝 WalkthroughWalkthroughThese changes refactor multimodal embedding handling in Nemotron Nano models. The Changes
Sequence Diagram(s)sequenceDiagram
participant Forward as Forward()
participant Encoder as _encode_multimodal()
participant Cache as MultimodalParams<br/>(Cache Storage)
participant Merge as merge_evs_mm_embeds()
participant Context as Context Params
Forward->>Encoder: Call with uncached params
Encoder->>Encoder: Vision/Audio encoding
Encoder->>Cache: Store embeddings in<br/>multimodal_data
Encoder->>Cache: Store num_tokens_in_video
Encoder-->>Forward: Return List[Tensor]
Forward->>Merge: Pass merged embeddings<br/>+ context params
Merge-->>Forward: Merged result
Forward->>Context: Reconstruct num_tokens<br/>from cached values
Forward-->>Forward: Continue with<br/>merged embeddings
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes 🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
35608fe to
e7868dd
Compare
|
/bot run |
|
PR_Github #43265 [ run ] triggered by Bot. Commit: |
|
PR_Github #43265 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #43372 [ run ] triggered by Bot. Commit: |
|
PR_Github #43372 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #43443 [ run ] triggered by Bot. Commit: |
|
PR_Github #43443 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #43486 [ run ] triggered by Bot. Commit: |
|
PR_Github #43486 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #43540 [ run ] triggered by Bot. Commit: |
|
PR_Github #43540 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #43556 [ run ] triggered by Bot. Commit: |
|
PR_Github #43556 [ run ] completed with state
|
* Why? In order to opt into the caching functionality for chunked prefix, there are certain assumptions on the return type of the encoder's forward function. These assumptions did not hold for nemotron nano VL prior to this commit. * What? This commit fixes this issue, and adds tests to catch regressions. Signed-off-by: William Zhang <[email protected]>
e7868dd to
1caf386
Compare
|
/bot run --disable-fail-fast |
|
/bot run |
1 similar comment
|
/bot run |
|
@coderabbitai what's the api contract described here? |
|
Tip For best results, initiate chat on the files or code changes. 🧩 Analysis chain🏁 Script executed: Length of output: 10999 🏁 Script executed: Length of output: 11619 🏁 Script executed: Length of output: 3401
The Contract:
|
| Requirement | Detail |
|---|---|
| Return type | List[torch.Tensor] |
| Length | Always 1 (all per-request embeddings concatenated into a single tensor), or [] when there is nothing to encode |
| Content | result[0] is a 2-D torch.Tensor of shape [total_tokens, hidden_dim] across all requests in the batch |
| Auxiliary outputs | Must not be returned as extra list elements; must be stored as side-channel data in each param.multimodal_data dict instead |
The test helper _assert_compatible_with_chunked_prefill captures this precisely:
assert len(multimodal_embeddings) == 1
assert isinstance(multimodal_embeddings[0], torch.Tensor)Why this matters for caching
get_multimodal_embeddings has an explicit guard:
if len(encoder_embeddings) > 1:
logger.warning("Multiple modalities caching is not supported yet. ...")
return encoder_embeddings # ← caching is SKIPPED entirelySo if encoder_forward_fn returns anything other than a single-element list, the embedding is never cached, and the encoder re-runs on every chunk during chunked prefill — defeating the whole purpose.
What was broken before this PR
Before the fix, NemotronH_Nano_VL_V2._encode_multimodal was returning a tuple (mm_embeddings, mm_num_tokens):
mm_embeddingswas a list of per-request tensors (length > 1 for batches)mm_num_tokenswas a separate auxiliary output for EVS video pruning
This violated both requirements: the length was > 1 and the elements weren't plain tensors. As a result, get_multimodal_embeddings fell into the warning/skip path and never cached any embeddings.
The fix
_encode_multimodal now:
- Concatenates all per-request embeddings into one tensor and returns
[torch.cat(mm_embeddings, dim=0)] - Stashes
num_tokens_in_video(needed by EVS) as a side-channel intoparam.multimodal_data["num_tokens_in_video"]instead of returning it
The forward method then reads those stashed counts back when EVS pruning is needed:
# Retrieve per-video count stashed by `_encode_multimodal`.|
PR_Github #43673 [ run ] triggered by Bot. Commit: |
|
PR_Github #43673 [ run ] completed with state
|
|
/bot run |
|
PR_Github #43680 [ run ] triggered by Bot. Commit: |
|
PR_Github #43680 [ run ] completed with state
|
|
/bot run |
|
PR_Github #43728 [ run ] triggered by Bot. Commit: |
|
PR_Github #43728 [ run ] completed with state
|
|
/bot run |
|
PR_Github #43791 [ run ] triggered by Bot. Commit: |
|
PR_Github #43791 [ run ] completed with state |
|
/bot run |
|
PR_Github #43810 [ run ] triggered by Bot. Commit: |
|
PR_Github #43810 [ run ] completed with state
|
|
/bot run |
|
PR_Github #43909 [ run ] triggered by Bot. Commit: |
|
PR_Github #43909 [ run ] completed with state
|
|
/bot run |
|
PR_Github #43928 [ run ] triggered by Bot. Commit: |
|
PR_Github #43928 [ run ] completed with state
|
|
/bot run |
|
PR_Github #43945 [ run ] triggered by Bot. Commit: |
|
PR_Github #43945 [ run ] completed with state |

Summary by CodeRabbit
Improvements
Tests
Description
In order to opt into the caching functionality for chunked prefix, there are certain assumptions on the return type of the encoder's forward function. These assumptions did not hold for nemotron nano VL prior to this commit.
This commit fixes this issue, and adds tests to catch regressions.
Test Coverage
See above.
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.