[None][feat] Support kv_cahce_reuse for HyperCLOVAX-Vision model #7789
Conversation
📝 Walkthrough

Adds GPT-OSS to docs and restructures feature matrices. Reuses pre-initialized attention metadata in the CLIP and SigLIP vision models. Refactors the HyperCLOVA-X multimodal input processor and embedding flow. Pads LLaVA-Next image patches for batching. Refactors the multimodal hashing/positioning pipeline. Adjusts a log level in the multimodal token position lookup.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor Caller
    participant Registry as inputs.registry
    participant Proc as InputProcessor
    participant Utils as multimodal utils
    participant MM as MultimodalInput
    Caller->>Registry: multimodal_hashing_process(inputs)
    Note over Registry: Early compute
    Registry->>Proc: get_multimodal_hashes(...)
    Registry->>Proc: get_prompt_token_ids(...)
    alt No multimodal tokens
        Registry-->>Caller: (prompt_token_ids, None)
    else Has multimodal tokens
        Registry->>Proc: get_vocab_size(), get_mm_token_ids(), get_mm_special_token_ids()
        Registry->>Utils: find_mm_token_positions(prompt_token_ids, vocab_size/mm_ids/...)
        Note over Registry: Flatten + int32 cast
        Registry->>MM: from_components(int32_hashes, start_positions, counts, offsets)
        Registry-->>Caller: (prompt_token_ids, { "multimodal_input": MM })
    end
```
```mermaid
sequenceDiagram
    autonumber
    actor User
    participant HCX as HCXVisionForCausalLM
    participant Enc as MM Encoder
    participant Utils as modeling_multimodal_utils
    User->>HCX: forward(input_ids, multimodal_params)
    alt DISAGG enabled
        HCX-->>User: NotImplementedError
    else Non-DISAGG
        HCX->>Utils: get_multimodal_embeddings(encoder_forward_fn=Enc.forward, multimodal_params[:N])
        Utils->>Enc: forward(pixel_values,...)
        Utils-->>HCX: mm_embeds
        HCX->>Utils: find_input_mm_embeds(mm_embeds, multimodal_params[:N])
        Utils-->>HCX: mm_embeds (aligned)
        HCX-->>User: logits
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (11)
tensorrt_llm/inputs/multimodal.py (3)
1-1: Add NVIDIA Apache-2.0 header (2025). Per repo guidelines, prepend the current-year NVIDIA Apache-2.0 header to all Python sources.

Apply:

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```
626-629: Always return a tuple from find_mm_token_positions. Returning a single list breaks callers expecting a 2-tuple and will raise unpack errors in registry.py.
Apply:
```diff
-    if not torch.any(mm_mask):
-        return []
+    if not torch.any(mm_mask):
+        return [], []
```
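For illustration only, the invariant the fix enforces, in a standalone stdlib sketch (names hypothetical): the early-exit path must return the same shape as the main path so callers can always unpack.

```python
def find_positions(mask):
    """Return (positions, offsets) -- always a 2-tuple, even when empty."""
    if not any(mask):
        return [], []  # consistent shape; a bare `return []` breaks unpacking
    positions = [i for i, m in enumerate(mask) if m]
    offsets = list(range(len(positions)))
    return positions, offsets

# safe unpack on the empty path
pos, off = find_positions([False, False])
```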
150-193: UnboundLocalError risk for `remainder` in MultimodalRuntimeData. `remainder` is defined only when (num_unseen_mm_tokens or num_mm_tokens_in_chunk) is None; it is later used unconditionally.
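The underlying pitfall is Python's conditional-only binding: once any branch assigns a name, reading it on a path that skipped the assignment raises UnboundLocalError. A minimal standalone reproduction (hypothetical names):

```python
def compute(num_unseen=None):
    if num_unseen is None:
        remainder = 0              # bound only on this branch
        num_unseen = remainder
    return num_unseen + remainder  # unbound read when num_unseen was provided

try:
    compute(num_unseen=5)          # skips the branch -> remainder never bound
except UnboundLocalError:
    pass                           # the bug the review flags
```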
Apply:
```diff
-        if self.num_unseen_mm_tokens is None or self.num_mm_tokens_in_chunk is None:
+        remainder = 0
+        if self.num_unseen_mm_tokens is None or self.num_mm_tokens_in_chunk is None:
             # Compute cached multimodal tokens based on positions and cached tokens
             self.num_unseen_mm_tokens = 0
             self.num_mm_tokens_in_chunk = 0
-            remainder = 0
             for pos, length in zip(self.mm_token_positions, self.mm_token_lengths):
                 ...
+        else:
+            # Ensure defined when provided by caller
+            remainder = 0
```

tensorrt_llm/_torch/models/modeling_llava_next.py (2)
1-1: Add NVIDIA Apache-2.0 header (2025). Add the standard header per guidelines.

Apply:

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```
5-11: Import torch.nn.functional. Undefined name `F` in `_pad_for_batching`.
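As a standalone illustration of the pattern `F` is needed for, here is a stdlib-only analog of padding per-image patch rows to a common length so they can be batched (the real code uses torch.nn.functional.pad; these names are hypothetical):

```python
def pad_for_batching(patch_lists, pad_value=0):
    """Pad variable-length per-image patch rows to one common length."""
    max_len = max(len(p) for p in patch_lists)
    return [p + [pad_value] * (max_len - len(p)) for p in patch_lists]

# two images with 2 and 3 patches -> both padded to length 3, so they stack
batch = pad_for_batching([[1, 2], [3, 4, 5]])
```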
Apply:
```diff
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
```

tensorrt_llm/inputs/registry.py (1)
1-1: Add NVIDIA Apache-2.0 header (2025). Add the standard header per guidelines.

Apply:

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```

tensorrt_llm/_torch/models/modeling_hyperclovax.py (2)
1-1: Add NVIDIA Apache-2.0 header (2025). Add the standard header per guidelines.

Apply:

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```
605-608: Use typing.Dict/Any for Python 3.8 compatibility. Replace `dict[str, any]` with `Dict[str, Any]`.

Apply:

```diff
-    def _post_process(self,
-                      input_ids: torch.Tensor,
-                      preprocessed_image: dict[str, any] = None):
+    def _post_process(self,
+                      input_ids: torch.Tensor,
+                      preprocessed_image: Dict[str, Any] = None):
 @@
-    def _preprocess(self, text_prompt: dict[str, any], images: List[Any],
+    def _preprocess(self, text_prompt: Dict[str, Any], images: List[Any],
                     mm_processor_kwargs: Dict[str, Any]):
```

Also applies to: 667-669
tensorrt_llm/_torch/models/modeling_clip.py (1)
1-1: Add NVIDIA Apache-2.0 header (2025) at top of file. Required by repo guidelines.

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```
1-1: Add NVIDIA Apache-2.0 header (2025) at top of file. Required by repo guidelines.

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```
7-9: Duplicate/conflicting import of SiglipVisionConfig. `SiglipVisionConfig` is imported from both configuration_siglip and modeling_siglip; the latter likely shadows the config class. Keep the config import only from the configuration module.

```diff
-from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
-from transformers.models.siglip.modeling_siglip import (SiglipVisionConfig,
-                                                        SiglipVisionEmbeddings)
+from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
+from transformers.models.siglip.modeling_siglip import SiglipVisionEmbeddings
```
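The shadowing mechanics can be seen with two stdlib modules that happen to export the same name; the later `from ... import` silently wins:

```python
from collections import OrderedDict   # concrete class
from typing import OrderedDict        # generic alias -- silently shadows the first

import collections
import typing

# Only the last binding survives: code meaning "the concrete class" would
# actually get whatever the later module exported under that name.
shadowed = OrderedDict
```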
🧹 Nitpick comments (7)
tensorrt_llm/inputs/multimodal.py (1)
648-656: Clarify semantics of special token positions. Returned start_special_token_positions are offsets within the flattened MM-token union, not absolute prompt indices. Consider renaming to special_token_offsets or documenting this explicitly to avoid misuse.
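The distinction, in a small standalone sketch (vocab threshold and token ids here are assumptions, not the model's real values): an offset into the flattened MM-token union is generally a different number from the token's absolute prompt index.

```python
VOCAB_SIZE = 32000   # hypothetical: ids >= VOCAB_SIZE are multimodal
SPECIAL_ID = 32001   # hypothetical special MM token
prompt = [101, 5, 32000, 32000, 7, SPECIAL_ID, 9]

# absolute prompt indices of all multimodal tokens
abs_positions = [i for i, t in enumerate(prompt) if t >= VOCAB_SIZE]

# offsets of the special token *within the flattened MM-token union*
mm_union = [t for t in prompt if t >= VOCAB_SIZE]
special_offsets = [i for i, t in enumerate(mm_union) if t == SPECIAL_ID]
# the special token sits at prompt index 5 but MM-union offset 2
```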
docs/source/models/supported-models.md (1)
47-58: Table consistency check. Ensure the modality/feature values reflect actual runtime support (e.g., KV cache reuse states) for each class; discrepancies here create user confusion.
tensorrt_llm/_torch/models/modeling_hyperclovax.py (2)
595-601: Silence the linter for unused parameters or use underscores. Minor cleanup.

Apply:

```diff
-    def get_num_tokens_per_image(
-        self,
-        image: Image.Image,
-        **kwargs,
-    ):
+    def get_num_tokens_per_image(
+        self,
+        _image: Image.Image,
+        **_kwargs,
+    ):
         return self.vision_query_lengths[0].pop(0)
```
1059-1062: Remove extraneous f-prefix. No interpolation; keep a plain string.

Apply:

```diff
-        raise NotImplementedError(
-            "HCXVisionForCausalLM does not support disaggregated inference yet. Please unset "
-            f"the TLLM_MULTIMODAL_DISAGGREGATED environment variable, or set it to '0'."
-        )
+        raise NotImplementedError(
+            "HCXVisionForCausalLM does not support disaggregated inference yet. Please unset "
+            "the TLLM_MULTIMODAL_DISAGGREGATED environment variable, or set it to '0'."
+        )
```

tensorrt_llm/_torch/models/modeling_clip.py (2)
187-193: Pre-allocating a mutable attn_metadata on the model can be racy across concurrent forwards. If this model is shared across threads/requests, mutating a single instance is not thread-safe. Either document single-use/serialized access or scope attn_metadata per runner/engine instance.
118-121: Derive the view shape from num_contexts/max_seq_len, not the seq_lens buffer length. Preallocated buffers may have capacity greater than the batch size. Use explicit metadata fields.

```diff
-        encoder_states = encoder_states + (hidden_states.view(
-            attn_metadata.seq_lens.shape[0], attn_metadata.seq_lens[0],
-            -1), )
+        encoder_states = encoder_states + (hidden_states.view(
+            attn_metadata.num_contexts, attn_metadata.max_seq_len, -1), )
 @@
-        encoder_states = encoder_states + (hidden_states.view(
-            attn_metadata.seq_lens.shape[0], attn_metadata.seq_lens[0], -1), )
+        encoder_states = encoder_states + (hidden_states.view(
+            attn_metadata.num_contexts, attn_metadata.max_seq_len, -1), )
```

Also applies to: 130-132
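A standalone sketch of the hazard (plain lists, hypothetical numbers): when the preallocated buffer has more slots than live requests, shaping from buffer capacity over-counts, while shaping from explicit batch metadata is exact.

```python
def reshape_tokens(flat_tokens, num_contexts, max_seq_len):
    """Reshape a flat token run using explicit batch metadata."""
    assert len(flat_tokens) == num_contexts * max_seq_len
    return [flat_tokens[i * max_seq_len:(i + 1) * max_seq_len]
            for i in range(num_contexts)]

buffer_capacity = 8               # preallocated slots, larger than the live batch
num_contexts, max_seq_len = 2, 4  # the actual batch
flat = list(range(num_contexts * max_seq_len))
batched = reshape_tokens(flat, num_contexts, max_seq_len)
# shaping from buffer_capacity would demand 8 * 4 = 32 tokens and fail
```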
tensorrt_llm/_torch/models/modeling_siglip.py (1)
80-86: Same shared-state caveat as CLIP: attn_metadata on the model is mutable. Not thread-safe if the model is shared across requests. Clarify usage or scope per runner/engine.
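One conventional remedy, sketched as a hypothetical wrapper rather than the repo's actual API: serialize forwards that mutate the shared metadata behind a lock.

```python
import threading

class SerializedVisionRunner:
    """Wrap a forward fn whose shared, preallocated metadata is mutated per call."""

    def __init__(self, forward_fn):
        self._forward_fn = forward_fn
        self._lock = threading.Lock()
        self.seq_len = None              # stands in for shared attn_metadata state

    def forward(self, tokens):
        with self._lock:                 # serialize metadata mutation + compute
            self.seq_len = len(tokens)
            return self._forward_fn(tokens)

runner = SerializedVisionRunner(lambda toks: [t * 2 for t in toks])
out = runner.forward([1, 2, 3])
```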
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- docs/source/models/supported-models.md (2 hunks)
- tensorrt_llm/_torch/models/modeling_clip.py (2 hunks)
- tensorrt_llm/_torch/models/modeling_hyperclovax.py (7 hunks)
- tensorrt_llm/_torch/models/modeling_llava_next.py (2 hunks)
- tensorrt_llm/_torch/models/modeling_siglip.py (2 hunks)
- tensorrt_llm/inputs/multimodal.py (1 hunks)
- tensorrt_llm/inputs/registry.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Use only spaces, no tabs; indent with 4 spaces.
Files:
- tensorrt_llm/inputs/multimodal.py
- tensorrt_llm/_torch/models/modeling_clip.py
- tensorrt_llm/_torch/models/modeling_siglip.py
- tensorrt_llm/inputs/registry.py
- tensorrt_llm/_torch/models/modeling_hyperclovax.py
- tensorrt_llm/_torch/models/modeling_llava_next.py
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.
Files:
- tensorrt_llm/inputs/multimodal.py
- tensorrt_llm/_torch/models/modeling_clip.py
- tensorrt_llm/_torch/models/modeling_siglip.py
- tensorrt_llm/inputs/registry.py
- tensorrt_llm/_torch/models/modeling_hyperclovax.py
- tensorrt_llm/_torch/models/modeling_llava_next.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).
Files:
- tensorrt_llm/inputs/multimodal.py
- tensorrt_llm/_torch/models/modeling_clip.py
- tensorrt_llm/_torch/models/modeling_siglip.py
- tensorrt_llm/inputs/registry.py
- tensorrt_llm/_torch/models/modeling_hyperclovax.py
- tensorrt_llm/_torch/models/modeling_llava_next.py
🧬 Code graph analysis (5)
tensorrt_llm/inputs/multimodal.py (1)
- tensorrt_llm/logger.py (1): debug (144-145)
tensorrt_llm/_torch/models/modeling_clip.py (3)
- tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1): attn_metadata (96-97)
- tensorrt_llm/_torch/speculative/eagle3.py (2): prepare (123-166), prepare (244-254)
- tensorrt_llm/_torch/speculative/mtp.py (1): prepare (161-209)
tensorrt_llm/_torch/models/modeling_siglip.py (4)
- tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1): attn_metadata (96-97)
- tensorrt_llm/_torch/speculative/interface.py (1): prepare (172-175)
- tensorrt_llm/_torch/speculative/eagle3.py (2): prepare (123-166), prepare (244-254)
- tensorrt_llm/_torch/speculative/mtp.py (1): prepare (161-209)
tensorrt_llm/inputs/registry.py (5)
- tensorrt_llm/inputs/multimodal.py (7): apply_mm_hashes (478-510), find_mm_token_lengths (529-572), find_mm_token_positions (575-656), validate_mm_inputs (659-702), hexdigest_to_int32 (513-526), MultimodalInput (19-85), from_components (72-77)
- tensorrt_llm/runtime/model_runner_cpp.py (1): vocab_size (489-490)
- tensorrt_llm/_torch/models/modeling_hyperclovax.py (2): get_vocab_size (592-593), get_mm_token_ids (602-603)
- tensorrt_llm/_torch/models/modeling_gemma3vl.py (1): mm_token_ids (291-292)
- tensorrt_llm/_torch/models/modeling_mistral.py (1): mm_token_ids (506-507)
tensorrt_llm/_torch/models/modeling_hyperclovax.py (3)
- tensorrt_llm/inputs/multimodal.py (1): MultimodalParams (196-455)
- tensorrt_llm/inputs/registry.py (5): BaseMultimodalInputProcessor (45-178), InputProcessor (22-42), get_vocab_size (64-84), get_num_tokens_per_image (130-148), get_mm_token_ids (86-100)
- tensorrt_llm/_torch/models/modeling_multimodal_utils.py (3): find_input_mm_embeds (158-233), fuse_input_embeds (276-330), get_multimodal_embeddings (99-155)
🪛 Ruff (0.12.2)
tensorrt_llm/inputs/registry.py
480-482: Avoid specifying long messages outside the exception class
(TRY003)
tensorrt_llm/_torch/models/modeling_hyperclovax.py
597-597: Unused method argument: image
(ARG002)
598-598: Unused method argument: kwargs
(ARG002)
1061-1061: f-string without any placeholders
Remove extraneous f prefix
(F541)
tensorrt_llm/_torch/models/modeling_llava_next.py
375-375: Undefined name F
(F821)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (5)
tensorrt_llm/inputs/multimodal.py (2)
606-608: Log level downgrade LGTM. Demoting to debug reduces noise; in this repo, logger.debug maps to VERBOSE. No action needed.
529-538: Fix return annotation and stale comment in find_mm_token_lengths. The function returns a dict, not List[int]; the comment is misleading.

Apply:

```diff
-def find_mm_token_lengths(mm_data: Dict[str, Any],
-                          input_processor: Any) -> List[int]:
+def find_mm_token_lengths(mm_data: Dict[str, Any],
+                          input_processor: Any) -> Dict[str, List[int]]:
 @@
-    return num_mm_tokens  # flatten all mm instances to a single list
+    return num_mm_tokens  # mapping: modality -> list of lengths
```

Also applies to: 572-573
docs/source/models/supported-models.md (1)
13-13: Verify GPT-OSS listing. Please confirm that `GptOssForCausalLM` exists in the codebase and has a working loader; otherwise gate the row behind a release flag or remove it until available.

tensorrt_llm/inputs/registry.py (1)
503-506: Attach MultimodalInput robustly. This block assumes extra_processed_inputs is a dict (after the guards above). Looks good once the guards are in.
tensorrt_llm/_torch/models/modeling_siglip.py (1)
93-105: Confirm max_seq_len semantics. You switched to `max_seq_len = seq_len` (per-request maximum). Verify the backend expects the per-request max rather than the total flattened tokens; a mismatch can under-allocate.
Force-pushed from 9e58dec to 40f1830
/bot run

PR_Github #20235 [ run ] triggered by Bot

PR_Github #20235 [ run ] completed with state

/bot run

PR_Github #20331 [ run ] triggered by Bot

PR_Github #20331 [ run ] completed with state
chang-l left a comment:
LGTM
Thank you for the MR, @yechank-nvidia. Is there a good way to test this change? If there's an existing test harness, we should probably parameterize it to test the models being touched in this MR?
Signed-off-by: yechank <[email protected]>
Signed-off-by: yechank <[email protected]>
Signed-off-by: yechank <[email protected]>
Signed-off-by: yechank <[email protected]>
Force-pushed from 47a8c65 to c0e9ed1
/bot run

PR_Github #21847 [ run ] triggered by Bot. Commit:

PR_Github #21847 [ run ] completed with state
brb-nv left a comment:
LGTM.
…DIA#7789) Signed-off-by: yechank <[email protected]>
…DIA#7789) Signed-off-by: yechank <[email protected]> Signed-off-by: yufeiwu-nv <[email protected]>
Summary by CodeRabbit

- New Features
- Bug Fixes
- Documentation
- Refactor
- Chores