Skip to content

Commit 2bbe896

Browse files
SubtleSpark and claude committed
fix(memory): serialize local embedding initialization to avoid duplicate model loads
Concurrent calls to ensureContext() during file-level parallel indexing (EMBEDDING_INDEX_CONCURRENCY=4) could each pass the `if (!llama)` check before the first await resolved, causing the model to be loaded multiple times into VRAM. This exhausted GPU memory and made local embeddings unusable for users with 2+ memory files. Guard initialization with a cached Promise so all concurrent callers share a single init sequence. Co-Authored-By: Claude Opus 4.6 <[email protected]>
1 parent 417509c commit 2bbe896

File tree

2 files changed

+92
-10
lines changed

2 files changed

+92
-10
lines changed

src/memory/embeddings.test.ts

Lines changed: 72 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -480,3 +480,75 @@ describe("local embedding normalization", () => {
480480
}
481481
});
482482
});
483+
484+
describe("local embedding ensureContext concurrency", () => {
485+
afterEach(() => {
486+
vi.resetAllMocks();
487+
vi.resetModules();
488+
vi.unstubAllGlobals();
489+
vi.doUnmock("./node-llama.js");
490+
});
491+
492+
it("loads the model only once when embedBatch is called concurrently", async () => {
493+
const getLlamaSpy = vi.fn();
494+
const loadModelSpy = vi.fn();
495+
const createContextSpy = vi.fn();
496+
497+
vi.doMock("./node-llama.js", () => ({
498+
importNodeLlamaCpp: async () => ({
499+
getLlama: async (...args: unknown[]) => {
500+
getLlamaSpy(...args);
501+
// Simulate real async delay so concurrent callers can interleave
502+
await new Promise((r) => setTimeout(r, 50));
503+
return {
504+
loadModel: async (...modelArgs: unknown[]) => {
505+
loadModelSpy(...modelArgs);
506+
await new Promise((r) => setTimeout(r, 50));
507+
return {
508+
createEmbeddingContext: async () => {
509+
createContextSpy();
510+
return {
511+
getEmbeddingFor: vi.fn().mockResolvedValue({
512+
vector: new Float32Array([1, 0, 0, 0]),
513+
}),
514+
};
515+
},
516+
};
517+
},
518+
};
519+
},
520+
resolveModelFile: async () => "/fake/model.gguf",
521+
LlamaLogLevel: { error: 0 },
522+
}),
523+
}));
524+
525+
const { createEmbeddingProvider } = await import("./embeddings.js");
526+
527+
const result = await createEmbeddingProvider({
528+
config: {} as never,
529+
provider: "local",
530+
model: "",
531+
fallback: "none",
532+
});
533+
534+
// Launch 4 concurrent embedBatch calls (simulates EMBEDDING_INDEX_CONCURRENCY = 4)
535+
const results = await Promise.all([
536+
result.provider.embedBatch(["text1"]),
537+
result.provider.embedBatch(["text2"]),
538+
result.provider.embedBatch(["text3"]),
539+
result.provider.embedBatch(["text4"]),
540+
]);
541+
542+
// All calls should return valid embeddings
543+
expect(results).toHaveLength(4);
544+
for (const embeddings of results) {
545+
expect(embeddings).toHaveLength(1);
546+
expect(embeddings[0]).toHaveLength(4);
547+
}
548+
549+
// The model should only be loaded once despite 4 concurrent calls
550+
expect(getLlamaSpy).toHaveBeenCalledTimes(1);
551+
expect(loadModelSpy).toHaveBeenCalledTimes(1);
552+
expect(createContextSpy).toHaveBeenCalledTimes(1);
553+
});
554+
});

src/memory/embeddings.ts

Lines changed: 20 additions & 10 deletions
Original file line number · Diff line number · Diff line change
@@ -91,19 +91,29 @@ async function createLocalEmbeddingProvider(
9191
let llama: Llama | null = null;
9292
let embeddingModel: LlamaModel | null = null;
9393
let embeddingContext: LlamaEmbeddingContext | null = null;
94+
let initPromise: Promise<LlamaEmbeddingContext> | null = null;
9495

95-
const ensureContext = async () => {
96-
if (!llama) {
97-
llama = await getLlama({ logLevel: LlamaLogLevel.error });
96+
const ensureContext = async (): Promise<LlamaEmbeddingContext> => {
97+
if (embeddingContext) {
98+
return embeddingContext;
9899
}
99-
if (!embeddingModel) {
100-
const resolved = await resolveModelFile(modelPath, modelCacheDir || undefined);
101-
embeddingModel = await llama.loadModel({ modelPath: resolved });
100+
if (initPromise) {
101+
return initPromise;
102102
}
103-
if (!embeddingContext) {
104-
embeddingContext = await embeddingModel.createEmbeddingContext();
105-
}
106-
return embeddingContext;
103+
initPromise = (async () => {
104+
if (!llama) {
105+
llama = await getLlama({ logLevel: LlamaLogLevel.error });
106+
}
107+
if (!embeddingModel) {
108+
const resolved = await resolveModelFile(modelPath, modelCacheDir || undefined);
109+
embeddingModel = await llama.loadModel({ modelPath: resolved });
110+
}
111+
if (!embeddingContext) {
112+
embeddingContext = await embeddingModel.createEmbeddingContext();
113+
}
114+
return embeddingContext;
115+
})();
116+
return initPromise;
107117
};
108118

109119
return {

0 commit comments

Comments (0)