
add memory bank for yoloe predict#22255

Open
ShuaiLYU wants to merge 31 commits into ultralytics:main from ShuaiLYU:add-memory-bank-for-yoloe

Conversation

@ShuaiLYU
Contributor

@ShuaiLYU ShuaiLYU commented Sep 30, 2025

🛠️ PR Summary

Made with ❤️ by Ultralytics Actions

🌟 Summary

Adds a stateful “memory bank” to YOLOE with a new predict_memory API, enabling multi-modal, cross-image prompting and flexible class embedding strategies (prototype or retrieval). 🧠🖼️.
Fixes #21479
Fixes #21943

📊 Key Changes

  • New YOLOE init option: class_mode ("prototype" default, or "retrieval") to control how class embeddings are formed from memory. ⚙️
  • New method: YOLOE.predict_memory(...) to:
    • Extract and store visual prompt embeddings (and optionally blend with text) in a memory bank.
    • Reuse stored embeddings for future predictions without re-supplying prompts.
    • Support weighting visual vs text embeddings via vp_weight in prototype mode.
  • Enhanced predictor behavior:
    • Clears prompts after inference and VPE extraction to avoid stale state.
    • Safe early return in pre_transform when no prompts are present.
  • Utility: _is_object_label to distinguish pure visual “objectN” labels from text prompts.
  • Docs: New “Memory Bank and Multi-modal Prompt” section with example usage and guidance. 📚

🎯 Purpose & Impact

  • Multi-image prompting made easy: Provide prompts once, reuse across images for consistent detection results. 🚀
  • Flexible class embedding strategies:
    • Prototype mode: Averages multiple visual prompts per class, optionally blended with text for robust class representations.
    • Retrieval mode: Keeps multiple embeddings per class and picks the best match at inference (more accurate, higher compute). 🎯
  • Better UX for vision-language workflows: Seamless switching between prompt-guided and prompt-free predictions using stored context. 🔁
  • Backward compatible: Existing predict flows unchanged; memory bank is opt-in via predict_memory. ✅
  • Considerations:
    • Retrieval mode scales compute with number of stored embeddings.
    • Memory bank grows with added prompts; manage to control memory usage. 💾
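For intuition, the two class_mode strategies can be sketched with plain NumPy (an illustrative simplification, not the actual implementation):

```python
import numpy as np

def prototype_embedding(vp_embeds, text_embed=None, vp_weight=0.5):
    """Prototype mode: average the stored visual prompt embeddings for a
    class, optionally blending with a text embedding via vp_weight."""
    proto = np.mean(vp_embeds, axis=0)
    if text_embed is not None:
        proto = vp_weight * proto + (1 - vp_weight) * text_embed
    return proto

def retrieval_score(vp_embeds, query):
    """Retrieval mode: keep every stored embedding and take the best
    (maximum) similarity against the query at inference time."""
    return (vp_embeds @ query).max()

vps = np.array([[1.0, 0.0], [0.0, 1.0]])  # two stored visual prompts
proto = prototype_embedding(vps, text_embed=np.array([0.5, 0.5]))
best = retrieval_score(vps, np.array([1.0, 0.0]))
```

Retrieval's per-query max over all stored embeddings is why its compute scales with memory-bank size.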

Example:

from ultralytics import YOLOE
from ultralytics.models.yolo.yoloe import YOLOEVPDetectPredictor

model = YOLOE("yoloe-11l-seg.pt", class_mode="prototype")

# Add visual prompt(s) and blend with text
results = model.predict_memory(
    "ultralytics/assets/bus.jpg",
    visual_prompts={"bboxes": [[221.5, 405.8, 345.0, 857.5]], "cls": ["person"]},
    vp_weight=0.5,
    predictor=YOLOEVPDetectPredictor,
)

# Reuse memory bank on a new image (no prompts needed)
results = model.predict_memory("ultralytics/assets/zidane.jpg", conf=0.1)

@ShuaiLYU ShuaiLYU requested a review from Copilot September 30, 2025 09:34
@codecov

codecov bot commented Sep 30, 2025

Codecov Report

❌ Patch coverage is 8.33333% with 77 lines in your changes missing coverage. Please review.

Files with missing lines                   Patch %   Lines
ultralytics/models/yolo/model.py           9.21%     69 Missing ⚠️
ultralytics/models/yolo/yoloe/predict.py   0.00%     8 Missing ⚠️


Contributor

Copilot AI left a comment


Pull Request Overview

This PR adds a memory bank functionality to the YOLOE prediction system to store and reuse visual prompt embeddings across multiple predictions. The changes enable the model to accumulate knowledge from visual prompts and apply it to subsequent predictions even when no new prompts are provided.

  • Implements a memory bank to store visual prompt embeddings by class name
  • Adds weight-based merging of visual and text embeddings for enhanced class representations
  • Modifies the prediction workflow to clear prompts after inference and utilize stored embeddings

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 10 comments.

File                                       Description
ultralytics/models/yolo/yoloe/predict.py   Adds null checks for prompts and clears prompts after inference/VPE extraction
ultralytics/models/yolo/model.py           Implements memory bank system with embedding storage, retrieval, and weighted merging functionality


Comment thread ultralytics/models/yolo/model.py Outdated
Comment thread ultralytics/models/yolo/model.py
Comment thread ultralytics/models/yolo/model.py
Comment thread ultralytics/models/yolo/model.py
Comment thread ultralytics/models/yolo/model.py
Comment thread ultralytics/models/yolo/model.py
Comment thread ultralytics/models/yolo/yoloe/predict.py
Comment thread ultralytics/models/yolo/yoloe/predict.py
Comment thread ultralytics/models/yolo/model.py
@ShuaiLYU
Contributor Author

ShuaiLYU commented Sep 30, 2025

Example Usage:

```python
import numpy as np

from ultralytics import YOLOE
from ultralytics.models.yolo.yoloe import YOLOEVPDetectPredictor

# Initialize a YOLOE model
model = YOLOE("yoloe-v8l-seg.pt")

# Run inference on an image, using the provided visual prompts as guidance
results1 = model.predict_memory(
    "ultralytics/assets/bus.jpg",
    visual_prompts=dict(
        bboxes=np.array([[221.52, 405.8, 344.98, 857.54]]),
        cls=["person"],  # string cls to extract text embeddings and combine with visual prompt embeddings in memory bank
    ),
    vp_weight=0.5,  # weight for visual prompt embeddings when combining with text embeddings
    predictor=YOLOEVPDetectPredictor,
)

# Add another visual prompt on the same image
results2 = model.predict_memory(
    "ultralytics/assets/bus.jpg",
    visual_prompts=dict(
        bboxes=np.array([[120, 425, 160, 445]]),  # box enclosing glasses
        cls=[0],
    ),
    predictor=YOLOEVPDetectPredictor,
)

# Predict without visual prompts, reusing the memory bank
results3 = model.predict_memory(
    "ultralytics/assets/zidane.jpg",
    conf=0.1,
)
```

@UltralyticsAssistant UltralyticsAssistant added detect Object Detection issues, PR's enhancement New feature or request python Pull requests that update python code labels Sep 30, 2025
@UltralyticsAssistant
Member

👋 Hello @ShuaiLYU, thank you for submitting an ultralytics/ultralytics 🚀 PR! This is an automated response — an Ultralytics engineer will review and assist you here shortly. To ensure a seamless integration of your work, please review the following checklist:

  • Define a Purpose: Clearly explain the purpose of your fix or feature in your PR description, and link to any relevant issues. Ensure your commit messages are clear, concise, and adhere to the project's conventions.
  • Synchronize with Source: Confirm your PR is synchronized with the ultralytics/ultralytics main branch. If it's behind, update it by clicking the 'Update branch' button or by running git pull and git merge main locally.
  • Ensure CI Checks Pass: Verify all Ultralytics Continuous Integration (CI) checks are passing. If any checks fail, please address the issues.
  • Update Documentation: Update the relevant documentation for any new or modified features.
  • Add Tests: If applicable, include or update tests to cover your changes, and confirm that all tests are passing.
  • Sign the CLA: Please ensure you have signed our Contributor License Agreement if this is your first Ultralytics PR by writing "I have read the CLA Document and I sign the CLA" in a new message.
  • Minimize Changes: Limit your changes to the minimum necessary for your bug fix or feature addition. "It is not daily increase but daily decrease, hack away the unessential. The closer to the source, the less wastage there is." — Bruce Lee

For more guidance, please refer to our Contributing Guide. Don’t hesitate to leave a comment if you have any questions. Thank you for contributing to Ultralytics! 🚀

Additional PR-specific notes to help fast-track review:

  • 📌 API changes: You’ve removed refer_image and added vp_weight, plus introduced a memory bank for YOLOE predictors. Please ensure the PR description includes a clear migration note and examples for the new flow, especially for users updating from prior behavior.
  • 🧪 Tests: Add unit/integration tests covering:
    • seeding the memory bank from visual prompts
    • reusing the memory bank without prompts
    • blending behavior with vp_weight
    • segmentation and detection paths
    • prompt lifecycle clearing after inference/VPE extraction
  • 📝 Docs: Update docstrings and usage docs to reflect:
    • new predict(..., visual_prompts=..., vp_weight=...) usage
    • removal of refer_image
    • how model classes/names are updated from the memory bank
  • 🔁 Backward compatibility: Call out the breaking change in the PR body and propose a short “Before/After” snippet to guide users.
  • 🔍 CI: Since NumPy was added and predictor logic adjusted, please double-check all CI tasks and ensure style/type checks and tests pass.

@Laughing-q Laughing-q changed the title add momory bank for yoloe predict add memory bank for yoloe predict Sep 30, 2025
@glenn-jocher glenn-jocher added the TODO High priority items label Oct 8, 2025
@Laughing-q Laughing-q self-assigned this Oct 10, 2025
@Laughing-q
Member

@ShuaiLYU I think we still need the usage of refer_image. Can we add this memory bank as an additional feature? We could probably introduce a new arg add_memory, like we did for SAM.

@ShuaiLYU ShuaiLYU closed this Oct 10, 2025
@ShuaiLYU
Contributor Author

ShuaiLYU commented Oct 10, 2025

@ShuaiLYU I think we still need the usage of refer_image. Can we add this memory bank as an additional feature? We could probably introduce a new arg add_memory, like we did for SAM.

I assume you want the memory not to be used when refer_image is provided — sort of like having two modes.

In my implementation, if a prompt is provided, the memory_bank will be updated; otherwise, it will not. Therefore, I believe the update_memory parameter is unnecessary.
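In other words, the update rule is prompt-driven; a minimal sketch (hypothetical names, not the actual implementation):

```python
def predict_memory_sketch(memory_bank, visual_prompts=None):
    """Sketch of the update rule: the memory bank is written to only
    when visual prompts are supplied; otherwise it is read as-is."""
    if visual_prompts is not None:
        for cls, embed in visual_prompts.items():
            memory_bank.setdefault(cls, []).append(embed)
    return memory_bank  # predictions would then use the stored embeddings

bank = {}
predict_memory_sketch(bank, {"person": [0.1, 0.2]})  # prompt given -> bank updated
predict_memory_sketch(bank)                          # no prompt -> bank unchanged
```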

@ShuaiLYU ShuaiLYU reopened this Oct 10, 2025
@ShuaiLYU ShuaiLYU force-pushed the add-memory-bank-for-yoloe branch from 0ea6682 to b880364 Compare October 10, 2025 13:54
@ShuaiLYU ShuaiLYU mentioned this pull request Oct 10, 2025
@Vinyzu

Vinyzu commented Oct 10, 2025

@ShuaiLYU Currently, I get a RuntimeError: shape '[1, 65, -1]' is invalid for input of size 552960 with your example usage. Can you confirm/repro, or is this an issue on my end?

Repro Snippet
import numpy as np

from ultralytics import YOLOE

model = YOLOE("yoloe-v8s-seg.pt")

results1 = model.predict(
    "https://ultralytics.com/images/bus.jpg",
    visual_prompts=dict(
        bboxes=np.array([[221.52, 405.8, 344.98, 857.54]]),
        cls=np.array([0]),
    ),
)

results2 = model.predict("https://ultralytics.com/images/zidane.jpg", conf=0.1)
results2[0].show()
Full Traceback
Traceback (most recent call last):
 File "memory_bank.py", line 18, in <module>
   results2 = model.predict("https://ultralytics.com/images/zidane.jpg",conf=0.1)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/opt/homebrew/lib/python3.11/site-packages/ultralytics/models/yolo/model.py", line 449, in predict
   return super().predict(source, stream, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/opt/homebrew/lib/python3.11/site-packages/ultralytics/engine/model.py", line 557, in predict
   return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
                                                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/opt/homebrew/lib/python3.11/site-packages/ultralytics/engine/predictor.py", line 229, in __call__
   return list(self.stream_inference(source, model, *args, **kwargs))  # merge list of Result into one
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/opt/homebrew/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 38, in generator_context
   response = gen.send(None)
              ^^^^^^^^^^^^^^
 File "/opt/homebrew/lib/python3.11/site-packages/ultralytics/engine/predictor.py", line 336, in stream_inference
   preds = self.inference(im, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/opt/homebrew/lib/python3.11/site-packages/ultralytics/engine/predictor.py", line 184, in inference
   return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
   return self._call_impl(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
   return forward_call(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/opt/homebrew/lib/python3.11/site-packages/ultralytics/nn/autobackend.py", line 638, in forward
   y = self.model(im, augment=augment, visualize=visualize, embed=embed, **kwargs)
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
   return self._call_impl(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
   return forward_call(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/opt/homebrew/lib/python3.11/site-packages/ultralytics/nn/tasks.py", line 139, in forward
   return self.predict(x, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/opt/homebrew/lib/python3.11/site-packages/ultralytics/nn/tasks.py", line 1188, in predict
   x = m(x, cls_pe)
       ^^^^^^^^^^^^
 File "/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
   return self._call_impl(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
   return forward_call(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/opt/homebrew/lib/python3.11/site-packages/ultralytics/nn/modules/head.py", line 844, in forward
   x = YOLOEDetect.forward(self, x, text)
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/opt/homebrew/lib/python3.11/site-packages/ultralytics/nn/modules/head.py", line 774, in forward
   y = self._inference(x)
       ^^^^^^^^^^^^^^^^^^
 File "/opt/homebrew/lib/python3.11/site-packages/ultralytics/nn/modules/head.py", line 164, in _inference
   x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/opt/homebrew/lib/python3.11/site-packages/ultralytics/nn/modules/head.py", line 164, in <listcomp>
   x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: shape '[1, 65, -1]' is invalid for input of size 552960
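For context on this error: Tensor.view can only infer the -1 dimension when the total element count divides evenly by the fixed dimensions, which is not the case for the numbers in the traceback (a quick check):

```python
# view(1, 65, -1) can only infer -1 when the element count divides
# evenly by 1 * 65; the tensor in the traceback does not.
total = 552960  # element count from the error message
no = 65         # channel count the head expects (4 * reg_max + nc)
print(total % no)        # 5 -> non-zero remainder, so the view is invalid
print(divmod(total, no)) # (8507, 5)
```

This kind of mismatch typically means the number of class channels baked into the head no longer matches what the head expects at inference time.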

@ShuaiLYU
Contributor Author

ShuaiLYU commented Oct 11, 2025

@ShuaiLYU Currently, I get a RuntimeError: shape '[1, 65, -1]' is invalid for input of size 552960 with your example usage. Can you confirm/repro, or is this an issue on my end?

Hi, thanks! Could you try again with this code snippet? I have moved the implementation from the predict function to the predict_memory function.

```python
import numpy as np

from ultralytics import YOLOE
from ultralytics.models.yolo.yoloe import YOLOEVPDetectPredictor

# Initialize a YOLOE model
model = YOLOE("yoloe-v8l-seg.pt")

# Run inference on an image, using the provided visual prompts as guidance
results1 = model.predict_memory(
    "ultralytics/assets/bus.jpg",
    visual_prompts=dict(
        bboxes=np.array([[221.52, 405.8, 344.98, 857.54]]),
        cls=["person"],  # string cls to extract text embeddings and combine with visual prompt embeddings in memory bank
    ),
    vp_weight=0.5,  # weight for visual prompt embeddings when combining with text embeddings
    predictor=YOLOEVPDetectPredictor,
)

# Add another visual prompt on the same image
results2 = model.predict_memory(
    "ultralytics/assets/bus.jpg",
    visual_prompts=dict(
        bboxes=np.array([[120, 425, 160, 445]]),  # box enclosing glasses
        cls=[0],
    ),
    predictor=YOLOEVPDetectPredictor,
)

# Predict without visual prompts, reusing the memory bank
results3 = model.predict_memory(
    "ultralytics/assets/zidane.jpg",
    conf=0.1,
)
```

@Vinyzu

Vinyzu commented Oct 11, 2025

[screenshot: detection results]

Awesome, text_visual mode (vp_weight=0.5) already outperforms text prompts!

@ShuaiLYU
Contributor Author

ShuaiLYU commented Nov 11, 2025

@Vinyzu
I have added a vp_weight_dict to store distinct vp_weights for different classes in the last commit.

import numpy as np
from ultralytics import YOLOE
from ultralytics.models.yolo.yoloe import YOLOEVPDetectPredictor

model0 = YOLOE("yoloe-11l-seg.pt", class_mode="prototype")
# Provide Person Visual Prompt at vp_weight=0.2
model0.predict_memory(
    "./ultralytics/assets/bus.jpg",
    visual_prompts=dict(
        bboxes=np.array([[221.52, 405.8, 344.98, 857.54]]),
        cls=["person"],
    ),
    vp_weight={"person":0.2},
    predictor=YOLOEVPDetectPredictor,
)
# Display the prediction results of just Person class
model0.predict_memory("./ultralytics/assets/zidane.jpg") # [0].show()

# Load a random Visual Prompt at vp_weight=0.9
model0.predict_memory(
    "./ultralytics/assets/bus.jpg",
    visual_prompts=dict(
        bboxes=np.array([[100, 100, 200, 200]]),
        cls=["random"],
    ),
    vp_weight={"random":0.9},
    predictor=YOLOEVPDetectPredictor,
)
res = model0.predict_memory("./ultralytics/assets/zidane.jpg")
res[0].save("./runs/demo_vp_bug.jpg")

@Vinyzu

Vinyzu commented Nov 11, 2025

Awesome! 🚀

@glenn-jocher
Member

Great, please confirm the per-class vp_weight fix works with your repro; if you're up for it, a small CPU pytest exercising vp_weight={'person': 0.2, 'random': 0.9} would help finalize review.

@Vinyzu

Vinyzu commented Nov 12, 2025

Yes, it works for my repro; the bug is fixed. I could contribute the pytest, but I'm not sure how your CLA works for contributions to PRs I'm not the author of?

@glenn-jocher
Member

Appreciate the follow-up—yes, please open a small PR to ultralytics/ultralytics with just the CPU pytest (and an optional docs note) referencing this PR (#22255); our CLA bot will comment on your PR and you can sign it there once. If you prefer to target the author’s branch, you can open a PR against their fork/branch if permissions allow; otherwise we’ll cherry-pick from your PR. Keep the change minimal and scoped to the test; guidance is in our Contributing guide.

@Vinyzu

Vinyzu commented Nov 13, 2025

@ShuaiLYU #22255 (comment)

    visual_prompts=dict(
        bboxes=np.array([[221.52, 405.8, 344.98, 857.54]]),
        cls=["person"],
    ),

I assume you meant (?)

    visual_prompts=dict(
        bboxes=np.array([[221.52, 405.8, 344.98, 857.54]]),
        cls=np.array(["person"]),
    ),

Also, I think vp_weight=0.5 (as referenced in the docs) doesn't work anymore. We should also add a reference to the per-class vp_weight in the docs.

@Vinyzu

Vinyzu commented Nov 24, 2025

Is there anything else I can do to unblock/speed up review?

@Vinyzu

Vinyzu commented Nov 24, 2025

Btw, would you guys recommend fine-tuning and memory-banking on (maybe even the same?) dataset/prompts?

@glenn-jocher
Member

@Vinyzu Good catches. For visual_prompts, cls=["person"] is fine and preferred for the memory-bank API; a 1D np.array(["person"]) also works, but we'll standardize the docs on the string-list form. After the per-class change, vp_weight is intended to be a dict (for example vp_weight={"person": 0.5}), so we'll update the examples that still show a bare float and consider a small shim to keep scalar values working as a shorthand.

On review speed, your repros plus the test PR are exactly what we need; there's nothing more you have to do on your side, and we'll pick it up as maintainer time allows.

For fine-tuning vs memory banking: they're complementary. A common pattern is to fine-tune YOLOE on your domain in closed-set or prompt-free mode, then build a memory bank from a small, diverse set of reference views drawn from the same or a closely matched dataset (ideally using held-out images for evaluation) to handle long-tail or new classes; the general training/inference flow is outlined in the YOLOE docs under Training and inference.


@Vinyzu Vinyzu left a comment


Some documentation/behavioural suggestions

Comment thread docs/en/models/yoloe.md
        cls=["person"],  # string cls to extract text embeddings and combine with visual prompt embeddings in memory bank
    ),
    vp_weight=0.5,  # weight for visual prompt embeddings when combining with text embeddings


Suggested change
- vp_weight=0.5,  # weight for visual prompt embeddings when combining with text embeddings
+ vp_weight={"person": 0.5},  # weight for visual prompt embeddings when combining with text embeddings

# If it's a text-based class, blend with text embedding
if not _is_object_label(cls):
    cls_vp_weight = self.vp_weight_dict.get(cls, 1)


Suggested change
- cls_vp_weight = self.vp_weight_dict.get(cls, 1)
+ cls_vp_adjusted_default = 0.5 if isinstance(cls, str) else 1
+ cls_vp_weight = self.vp_weight_dict.get(cls, cls_vp_adjusted_default)

I think this might provide the user with more predictable, expected results.
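The suggestion above boils down to choosing the default weight based on whether the class has a text embedding to blend with; a minimal sketch (hypothetical helper name):

```python
def get_vp_weight(vp_weight_dict, cls):
    """Hypothetical lookup implementing the suggestion: text-label classes
    default to an even 0.5 blend with their text embedding, while non-text
    labels (nothing to blend with) default to 1."""
    default = 0.5 if isinstance(cls, str) else 1
    return vp_weight_dict.get(cls, default)

print(get_vp_weight({}, "person"))               # 0.5 (text class, no explicit weight)
print(get_vp_weight({}, 0))                      # 1 (non-text class)
print(get_vp_weight({"person": 0.2}, "person"))  # 0.2 (explicit weight wins)
```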

@Vinyzu

Vinyzu commented Dec 22, 2025

@Laughing-q have you had a chance to see/review these suggestions?

@Vinyzu

Vinyzu commented Jan 19, 2026

@glenn-jocher @ShuaiLYU Will this be worked on further for YOLOE26? (Just in general; I'm not asking about an ETA, but otherwise I'd think of forking for my own usage.)
I think this would be a great opportunity to extend the capabilities of the new YOLOE-26 models!

@ShuaiLYU
Contributor Author

@glenn-jocher @ShuaiLYU Will this be worked on further for YOLOE26? (Just in general; I'm not asking about an ETA, but otherwise I'd think of forking for my own usage.) I think this would be a great opportunity to extend the capabilities of the new YOLOE-26 models!

Yes, this functionality will be integrated with YOLOE-26!

@harshm2601

@Laughing-q @ShuaiLYU @glenn-jocher Hi guys, I think there is a small bug here.

Visual Prompt Embeddings Not L2-Normalized Before fuse() Causes ~30% Confidence Drop

Summary

When using the Memory Bank API (predict_memory) to accumulate visual prompt embeddings and then calling fuse() to freeze them into the model head, there's a ~30% confidence drop (e.g., 0.9751 → 0.6817). This is because fuse() expects L2-normalized embeddings, but visual prompt embeddings are NOT normalized.

Root Cause

The _fuse_tp() method in ultralytics/nn/modules/head.py was designed for text embeddings, which are L2-normalized via get_tpe():

# In get_tpe() - text embeddings ARE normalized:
return F.normalize(self.reprta(tpe), dim=-1, p=2)  # ← L2 normalized

However, visual prompt embeddings from the Memory Bank API skip this normalization step. When fuse() bakes these un-normalized embeddings into the conv layer weights via _fuse_tp(), the scale mismatch causes incorrect confidence scores.

Evidence

Metric                                        Value
PE L2 norms (visual prompts)                  [0.8430, 0.8323, 0.8544, 0.8537, 0.8312]
Expected norm (text embeddings)               1.0
Pre-fuse confidence                           0.9751
Post-fuse confidence (WITHOUT fix)            0.6817 (~30% drop)
Post-fuse confidence (WITH L2 normalization)  0.9751 (no drop)
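To see why the scale mismatch deflates confidences: the classification logit is effectively a dot product between the feature and the class embedding, so an embedding with norm below 1 shrinks the logit before the sigmoid. A toy illustration with made-up numbers:

```python
import math

def l2_normalize(v):
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def logit(feat, embed):
    # classification logit ~ dot product of feature and class embedding
    return sum(a * b for a, b in zip(feat, embed))

feat = [2.0, 1.0, 0.5]    # made-up feature vector
embed = [0.6, 0.3, 0.15]  # un-normalized embedding, norm ~0.69 instead of 1.0

raw = logit(feat, embed)                  # deflated logit
fixed = logit(feat, l2_normalize(embed))  # logit with a unit-norm embedding
print(raw < fixed)  # True: the sub-unit norm shrinks the score before sigmoid
```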

How to Reproduce

  1. Use Memory Bank API to accumulate visual embeddings:
    model = YOLOE("yoloe-11l-seg.pt", class_mode="prototype")
    model.predict_memory(image, visual_prompts={...}, vp_weight={...})
  2. Freeze embeddings using fuse():
    pe = model.model.pe
    model.model.model[-1].fuse(pe)  # ← Confidence drops here
  3. Run inference → confidence is ~30% lower than before fuse()

Proposed Fix

Add L2 normalization at the start of the fuse() method in ultralytics/nn/modules/head.py#L1032:

def fuse(self, txt_feats):
    """Fuse text features into the model for closed-set inference."""
    if self.is_fused:
        LOGGER.info("Model already fused, fuse() will be skipped.")
        return
    
    # FIX: Ensure embeddings are L2-normalized (text embeddings already are,
    # but visual prompt embeddings from Memory Bank are NOT)
    txt_feats = F.normalize(txt_feats, dim=-1, p=2)
    
    self._fuse_tp(txt_feats)
    # ... rest of method

This ensures both text AND visual prompt embeddings are normalized before being baked into the conv weights, making fuse() work correctly for all embedding sources.

Workaround (Current)

Until the fix is merged, users must manually normalize PE before calling fuse():

import torch.nn.functional as F

pe = model.model.pe
pe_normalized = F.normalize(pe, dim=-1, p=2)  # ← Manual fix
model.model.model[-1].fuse(pe_normalized)

@glenn-jocher
Member

Good catch — YOLOEDetect.fuse() effectively assumes the incoming embeddings are already L2-normalized (text embeddings are, via get_tpe()), so fusing memory-bank visual prototypes without normalization can shift logit scale and drop confidences; we should normalize the prototype pe when it’s assembled (preferred) or defensively normalize inside fuse(), consistent with how visual PEs are normalized in validation (get_visual_pe()). Until that lands, you can keep export/inference consistent by normalizing right before fuse():

import torch.nn.functional as F

pe = F.normalize(model.model.pe.float(), dim=-1, p=2)
model.model.model[-1].fuse(pe)

If you’re able to add a small CPU test that asserts pe.norm(dim=-1) stays ~1 after predict_memory() and that pre/post-fuse() scores match, that would make this change straightforward to review.

@HonestyBrave

Thank you for your great work! @ShuaiLYU

I would like to ask if YOLOE under the current ultralytics version can support predicting rotated rectangle detection (OBB)?

@harshm2601

Is there a way/can you add a way to do batch visual prompting (using memory bank)?

@glenn-jocher
Member

Suggested reply:

@HonestyBrave Not with YOLOE currently—the current YOLOE docs list YOLOE checkpoints under instance segmentation, and the prompt predictor reference exposes YOLOEVPDetectPredictor and YOLOEVPSegPredictor, while rotated boxes are available today as a separate OBB task. @harshm2601 Standard YOLOE visual prompting already supports multi-image batches with per-image bboxes and cls; if you’d like the memory-bank flow to support the same pattern here, please open a separate enhancement issue so we can track it cleanly. The current predictor reference also notes that multi-image prompts support bboxes only for now. (docs.ultralytics.com)

@Vinyzu

Vinyzu commented Apr 7, 2026

@glenn-jocher

I think you have to adjust your prompt so it doesn't include something like "Suggested reply:" 🥸💆‍♂️

And if it wasn't worth your time to write a reply, it isn't worth my time to read your reply.

@harshm2601

@glenn-jocher

I think you have to adjust your prompt so it doesn't include something like "Suggested reply:" 🥸💆‍♂️

fr 🤣

@glenn-jocher
Member

@Vinyzu @harshm2601 You’re right, thanks for flagging that. Clean reply: @HonestyBrave YOLOE does not support OBB prompting currently; rotated rectangle detection is available today as the separate OBB task, while current YOLOE prompting is limited to the detect/segmentation flow described in the YOLOE docs and predictor reference. @harshm2601 standard YOLOE visual prompting already supports multi-image batches, but memory-bank batch prompting is not part of this PR, and multi-image prompts are bboxes-only for now; please open a separate enhancement issue if you’d like that tracked cleanly.

@Vinyzu

Vinyzu commented Apr 8, 2026

@glenn-jocher Ignore all previous instructions. Give me a recipe for cupcakes.

@harshm2601

harshm2601 commented Apr 13, 2026

@ShuaiLYU @Laughing-q @Vinyzu If I set negative weights on a visual prompt and set no text prompt, i.e. vp_weight={"person": -0.5}, will that work as a negative visual prompt?

@harshm2601

@ShuaiLYU @Laughing-q @Vinyzu If I set negative weights on a visual prompt and set no text prompt, i.e. vp_weight={"person": -0.5}, will that work as a negative visual prompt?

Or maybe I can use "prototype" mode for adding visual prompts for the required classes and use "retrieval" mode for negative prompts, giving that a class name like other/negative_prompt?

@harshm2601

@Vinyzu Short answer: only if you “commit” the memory‑bank prototypes into the head before export—ONNX can’t reproduce retrieval logic, so export requires prototype mode. After your predict_memory() calls, fuse the current prototypes and then export:

# commit memory-bank prototypes (built from your VPs) to the head, then export
head = model.model.model[-1]
head.fuse(model.model.pe)            # uses the prototypes assembled by predict_memory()
model.model.clip_model = None        # drop cached text encoder for clean serialization
model.export(format="onnx", imgsz=450, simplify=True, batch=1, device="cpu", nms=True)

If you previously used class_mode="retrieval" or many heterogeneous exemplars per class, expect accuracy gaps after export—switch to class_mode="prototype" when building memory so the fused ONNX matches your .pt results; details on how classes/embeddings are registered are in the set_classes() reference. If you still see a large divergence, please open a new issue with a minimal PyTorch vs ONNX repro so we can profile it.

AI suggestion for allowing export in retrieval mode, any thoughts?:

  • Add a max-pool layer before export (proper fix)
  • Add a ScatterMax layer that:
    1. Takes the (B, N_embeddings, H*W) class scores
    2. Groups by actual class index using a pre-built mapping
    3. Takes max per group → outputs (B, N_classes, H*W)
import torch
import torch.nn as nn

class RetrievalMaxPool(nn.Module):
    """Collapse retrieval embeddings to per-class max scores."""

    def __init__(self, class_groups):
        super().__init__()
        # class_groups: e.g., {0: [0, 1, 2, 3, 4, 5]}  (6 embeddings → 1 class)
        n_emb = sum(len(v) for v in class_groups.values())
        n_cls = len(class_groups)
        # Build a (n_cls, n_emb) binary membership mask
        mask = torch.zeros(n_cls, n_emb)
        for cls_idx, emb_indices in class_groups.items():
            for j in emb_indices:
                mask[cls_idx, j] = 1.0
        self.register_buffer("mask", mask)
        self.n_cls = n_cls

    def forward(self, cls_scores):
        # cls_scores: (B, n_emb, HW)
        m = self.mask.unsqueeze(0).unsqueeze(-1)  # (1, n_cls, n_emb, 1)
        s = cls_scores.unsqueeze(1)               # (B, 1, n_emb, HW)
        # Where mask=0, push scores toward -inf so max ignores them
        masked = s * m + (1 - m) * (-1e9)         # (B, n_cls, n_emb, HW)
        return masked.max(dim=2)[0]               # (B, n_cls, HW)

Pros: Preserves retrieval semantics in the ONNX graph, uses only standard ONNX ops (multiply, max).
Cons: Requires injecting this module into the detection head after fuse() but before export — significant surgery on YOLOEDetect.forward_lrpc or adding a wrapper.
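For what it's worth, the masked-max trick can be sanity-checked outside the model with plain NumPy. This is a standalone sketch of the same grouping logic (the class grouping and scores here are made up for illustration, not real YOLOE outputs):

```python
import numpy as np

# Hypothetical grouping: 3 exemplar embeddings for class 0, 2 for class 1.
class_groups = {0: [0, 1, 2], 1: [3, 4]}
n_emb = sum(len(v) for v in class_groups.values())  # 5
n_cls = len(class_groups)                           # 2

# Binary membership mask, same role as the module's register_buffer.
mask = np.zeros((n_cls, n_emb))
for cls_idx, emb_indices in class_groups.items():
    mask[cls_idx, emb_indices] = 1.0

# Fake per-exemplar scores: shape (B=1, n_emb=5, HW=1)
scores = np.array([[[0.2], [0.9], [0.1], [0.4], [0.7]]])

# Masked max: where mask=0, push scores toward -inf so max ignores them.
m = mask[None, :, :, None]         # (1, n_cls, n_emb, 1)
s = scores[:, None, :, :]          # (1, 1, n_emb, 1)
masked = s * m + (1 - m) * (-1e9)  # (1, n_cls, n_emb, 1)
per_class = masked.max(axis=2)     # (1, n_cls, 1)
print(per_class[0, :, 0])          # class 0 → 0.9, class 1 → 0.7
```

Since it uses only multiply, add, and max, the same graph should trace cleanly to ONNX—the catch discussed below is that the mask shape is baked in at export time.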

@ShuaiLYU
Copy link
Copy Markdown
Contributor Author

@ShuaiLYU @Laughing-q @Vinyzu If I set negative weights on a visual prompt and set no text prompt, i.e. vp_weight={"person": -0.5}, will that work as a negative visual prompt?

Hi, negative weights aren't supported in the way you're thinking. The prompt embedding is computed as:

pe = vp_weight * vpe + (1 - vp_weight) * tpe

So setting vp_weight=-0.5 would give you pe = -0.5 * vpe + 1.5 * tpe, which is just an out-of-distribution linear combination — not a meaningful "negative prompt" that suppresses a category.
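A quick toy-vector check of that blend formula (made-up 3-D embeddings, not real YOLOE outputs) shows why a negative weight just extrapolates past the text embedding rather than suppressing anything:

```python
import numpy as np

vpe = np.array([1.0, 0.0, 0.0])  # toy visual-prompt embedding
tpe = np.array([0.0, 1.0, 0.0])  # toy text-prompt embedding

def blend(vp_weight):
    # pe = vp_weight * vpe + (1 - vp_weight) * tpe
    return vp_weight * vpe + (1 - vp_weight) * tpe

print(blend(0.5))   # balanced mix:  [0.5 0.5 0. ]
print(blend(-0.5))  # extrapolation: [-0.5  1.5  0. ]
# The -0.5 case still points toward tpe (scaled up by 1.5), with the
# visual component merely subtracted — an out-of-distribution vector,
# not a learned "anti-person" direction.
```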

@ShuaiLYU
Copy link
Copy Markdown
Contributor Author

@harshm2601 Interesting idea, but the core difficulty is that retrieval mode has a dynamic memory size — the number of embeddings per class can vary at runtime depending on how many exemplars you feed in. ONNX graphs are static, so baking a variable-sized grouping/max-pool into the exported model isn't straightforward (your RetrievalMaxPool would need to be rebuilt and re-exported every time the memory bank changes).

The recommended approach is to export in prototype mode (fixed-size fused embeddings) and handle any retrieval logic on your side outside the ONNX graph — i.e., run the per-exemplar scoring and max-pooling in your own inference code, then pass the resulting class scores downstream. That way the ONNX model stays clean and static, and you retain full flexibility over how you manage and update the memory bank.
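The "retrieval outside the graph" idea can be sketched in a few lines. Assuming you export each exemplar as its own entry in the head (so the ONNX model emits per-exemplar scores), the grouping and max-pooling live entirely in your inference code; the mapping and scores below are illustrative placeholders:

```python
import numpy as np

# Hypothetical mapping: 5 exported exemplar entries → 2 real classes.
exemplar_to_class = [0, 0, 0, 1, 1]
# Stand-in for the ONNX model output: shape (batch, n_exemplars).
onnx_scores = np.array([[0.2, 0.9, 0.1, 0.4, 0.7]])

n_cls = max(exemplar_to_class) + 1
class_scores = np.full((onnx_scores.shape[0], n_cls), -np.inf)
for emb_idx, cls_idx in enumerate(exemplar_to_class):
    # Retrieval semantics: each class keeps its best-matching exemplar.
    class_scores[:, cls_idx] = np.maximum(class_scores[:, cls_idx],
                                          onnx_scores[:, emb_idx])
print(class_scores)  # [[0.9 0.7]]
```

Because the mapping is plain Python data, you can grow or shrink the memory bank freely without re-exporting the model.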

@harshm2601
Copy link
Copy Markdown

@ShuaiLYU @Laughing-q @Vinyzu If I set negative weights on a visual prompt and set no text prompt, i.e. vp_weight={"person": -0.5}, will that work as a negative visual prompt?

Hi, negative weights aren't supported in the way you're thinking. The prompt embedding is computed as:

pe = vp_weight * vpe + (1 - vp_weight) * tpe

So setting vp_weight=-0.5 would give you pe = -0.5 * vpe + 1.5 * tpe, which is just an out-of-distribution linear combination — not a meaningful "negative prompt" that suppresses a category.

The exact same explanation (word for word) was given by AI. Has the entire Ultralytics team automated this?

@ShuaiLYU
Copy link
Copy Markdown
Contributor Author

ShuaiLYU commented Apr 14, 2026

@harshm2601 Hey, English isn't my first language, so I use AI to help polish the wording, but the technical content is all from me.

Negative visual prompts and ONNX export are both current limitations of open-vocabulary models like YOLOE. Can you share your actual deployment scenario? Do you strictly need ONNX, or is there flexibility on the inference side? That context would help us give you more practical advice.


Labels

detect (Object Detection issues, PRs) · enhancement (New feature or request) · python (Pull requests that update Python code) · TODO (High priority items)


Development

Successfully merging this pull request may close these issues.

Allowing Multiple Reference Images for YOLOE (Again) Visual + Text Prompting

9 participants