
feat(embedding): add native loading for BERT/XLMRoBERTa embedding models#330

Merged
jundot merged 2 commits into jundot:main from yes999zc:feat/native-embedding-clean
Mar 21, 2026

Conversation

@yes999zc
Contributor

What

Add native embedding support for BERT/XLMRoBERTa-family MLX models in omlx/models/embedding.py, with mlx-embeddings kept as the fallback for unsupported architectures.

Why

Some embedding models (for example BAAI/bge-m3 MLX variants) are not supported by mlx-embeddings, even though oMLX already has a native XLM-RoBERTa implementation that can produce normalized text embeddings.

Changes

  • add _load_native() to detect and load local BertModel / BertForMaskedLM / XLMRobertaModel embedding models natively
  • keep mlx-embeddings path as the fallback for other architectures
  • support native tokenization + forward pass in embed()
  • preserve existing compiled eager/fallback behavior for the mlx-embeddings path
  • add tests for native BERT/XLMRoBERTa embedding loading
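The dispatch described above can be sketched roughly as follows. This is an illustrative sketch only: the architecture set comes from the PR description, but `select_loader` and its return values are hypothetical names, not oMLX's actual API.

```python
# Hypothetical sketch of the architecture-detection dispatch in _load_native():
# check the model's config.json "architectures" field and choose the native
# path for supported classes, otherwise fall back to mlx-embeddings.
NATIVE_ARCHITECTURES = {"BertModel", "BertForMaskedLM", "XLMRobertaModel"}

def select_loader(config: dict) -> str:
    """Pick the loading path from a parsed config.json dict."""
    archs = config.get("architectures") or []
    if any(a in NATIVE_ARCHITECTURES for a in archs):
        return "native"        # load via omlx's xlm_roberta implementation
    return "mlx-embeddings"    # fallback for unsupported architectures
```

With this shape, adding another natively supported architecture is a one-line change to the set.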

Validation

Local smoke-tested with real models:

  • bge-small-en-v1.5
  • mxbai-embed-large-v1
  • mlx-community/bge-m3-mlx-fp16

All loaded successfully and returned embeddings end-to-end.

Notes

This PR is intentionally kept narrow to embedding only. Jina reranker support is excluded and can be proposed separately after dedicated verification.

FocusFlow Dev added 2 commits March 21, 2026 08:45
- Add _load_native() method for loading embedding models via omlx's
  native xlm_roberta.py implementation (BERT, XLMRoBERTa architectures)
- Fall back to mlx-embeddings for other embedding architectures
- Fix tokenizer handling for both PreTrainedTokenizer (callable) and
  tokenizers.Tokenizer (encode method) in embed()
- Fix config ModelArgs property setter issue for embedding mode

Supports: BERT, XLMRoBERTa embedding models without mlx-embeddings dependency
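The tokenizer fix mentioned in this commit can be illustrated with duck typing: a `transformers` `PreTrainedTokenizer` is callable and returns a dict containing `input_ids`, while a raw `tokenizers.Tokenizer` exposes `encode()` returning an `Encoding` with an `.ids` attribute. The helper name below is an assumption, not the actual code in `embed()`.

```python
# Illustrative sketch (hypothetical helper name) of handling both tokenizer
# flavors: callable PreTrainedTokenizer vs. tokenizers.Tokenizer with encode().
def tokenize_ids(tokenizer, text: str) -> list:
    if callable(tokenizer):
        # transformers PreTrainedTokenizer path: __call__ returns a dict
        return tokenizer(text)["input_ids"]
    # tokenizers.Tokenizer path: encode() returns an Encoding with .ids
    return tokenizer.encode(text).ids
```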
- Test _load_native() for BERT and XLMRoBERTa architectures
- Test fallback for unknown architectures
- Test that embed produces L2-normalized vectors
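The L2-normalization property tested above is straightforward to check: a normalized vector has Euclidean norm 1. A minimal self-contained sketch of both the normalization and the assertion a test might make (plain Python, not the actual test code):

```python
import math

def l2_normalize(vec):
    """Scale a vector so its Euclidean (L2) norm is 1."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]

def is_unit_norm(vec, tol=1e-6):
    """True if the vector's L2 norm is 1 within tolerance."""
    return abs(math.sqrt(sum(x * x for x in vec)) - 1.0) < tol
```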
@jundot force-pushed the main branch 7 times, most recently from f6faf2f to c2beead on March 21, 2026 05:58
Owner

@jundot left a comment

Reviewed the code. Clean scope, no issues found.

Native loading reuses the existing xlm_roberta.py implementation nicely, and the mlx-embeddings fallback keeps backward compatibility intact. Attention mask handling for padding is also correct.

LGTM, good to merge.

@jundot merged commit 1e41845 into jundot:main Mar 21, 2026