
import os
import json
+import contextlib

from threading import Lock
from functools import partial
@@ -156,6 +157,7 @@ async def get_event_publisher(
    request: Request,
    inner_send_chan: MemoryObjectSendStream,
    iterator: Iterator,
+    on_complete=None,
):
    async with inner_send_chan:
        try:
@@ -175,6 +177,9 @@ async def get_event_publisher(
            with anyio.move_on_after(1, shield=True):
                print(f"Disconnected from client (via refresh/close) {request.client}")
                raise e
+        finally:
+            if on_complete:
+                on_complete()


def _logit_bias_tokens_to_input_ids(
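The new on_complete hook is invoked from a finally clause, so it runs both when the iterator is drained normally and when anyio cancels the publisher after a client disconnect. Below is a minimal sketch of that property, assuming only anyio; the publisher and main names are illustrative and not part of this change.

import anyio


async def publisher(on_complete=None):
    # Stand-in for get_event_publisher: stream until cancelled, then always
    # run the completion callback, even on cancellation.
    try:
        while True:
            await anyio.sleep(0.01)  # stand-in for sending chunks
    finally:
        if on_complete:
            on_complete()


async def main():
    calls = []
    async with anyio.create_task_group() as tg:
        tg.start_soon(publisher, lambda: calls.append("closed"))
        await anyio.sleep(0.05)
        tg.cancel_scope.cancel()  # simulate the client disconnecting
    assert calls == ["closed"]


anyio.run(main)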
@@ -258,8 +263,11 @@ async def authenticate(
async def create_completion(
    request: Request,
    body: CreateCompletionRequest,
-    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> llama_cpp.Completion:
+    exit_stack = contextlib.ExitStack()
+    llama_proxy = await run_in_threadpool(
+        lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
+    )
    if isinstance(body.prompt, list):
        assert len(body.prompt) <= 1
        body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
@@ -312,6 +320,7 @@ async def create_completion(
        def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
            yield first_response
            yield from iterator_or_completion
+            exit_stack.close()

        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
@@ -321,6 +330,7 @@ def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
                request=request,
                inner_send_chan=send_chan,
                iterator=iterator(),
+                on_complete=exit_stack.close,
            ),
            sep="\n",
            ping_message_factory=_ping_message_factory,
@@ -449,8 +459,15 @@ async def create_chat_completion(
            },
        }
    ),
-    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> llama_cpp.ChatCompletion:
+    # This is a workaround for an issue in FastAPI dependencies
+    # where the dependency is cleaned up before a StreamingResponse
+    # is complete.
+    # https://github.com/tiangolo/fastapi/issues/11143
+    exit_stack = contextlib.ExitStack()
+    llama_proxy = await run_in_threadpool(
+        lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
+    )
    exclude = {
        "n",
        "logit_bias_type",
@@ -491,6 +508,7 @@ async def create_chat_completion(
        def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
            yield first_response
            yield from iterator_or_completion
+            exit_stack.close()

        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
@@ -500,11 +518,13 @@ def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
                request=request,
                inner_send_chan=send_chan,
                iterator=iterator(),
+                on_complete=exit_stack.close,
            ),
            sep="\n",
            ping_message_factory=_ping_message_factory,
        )
    else:
+        exit_stack.close()
        return iterator_or_completion
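Taken together, the change sidesteps FastAPI tearing down a generator dependency before a StreamingResponse has finished: each route wraps get_llama_proxy with contextlib.contextmanager itself, enters it on an ExitStack, and closes that stack only when the response is actually done (at the end of the streaming iterator, via on_complete, or just before returning a non-streaming result). Below is a minimal, self-contained sketch of the same pattern; the get_resource dependency and handler are hypothetical stand-ins, and only contextlib, ExitStack, and run_in_threadpool are taken from the diff.

import asyncio
import contextlib

from starlette.concurrency import run_in_threadpool


def get_resource():
    # Generator-style dependency, shaped like get_llama_proxy: yield the
    # resource, then release it after the yield.
    resource = {"model": "demo"}
    try:
        yield resource
    finally:
        resource.clear()  # stand-in for releasing the model/lock


async def handler():
    # Enter the dependency manually so its lifetime is tied to this ExitStack
    # rather than to FastAPI's per-request dependency teardown.
    exit_stack = contextlib.ExitStack()
    resource = await run_in_threadpool(
        lambda: exit_stack.enter_context(contextlib.contextmanager(get_resource)())
    )
    try:
        return resource["model"]
    finally:
        # The routes above defer this to the end of the streaming iterator or
        # to on_complete; for a plain response it can happen immediately.
        exit_stack.close()


print(asyncio.run(handler()))  # -> demo

In the streaming paths, exit_stack.close is both appended to the iterator and handed to on_complete, so the proxy (and any lock it holds) is released as soon as either path completes; since closing an already-closed ExitStack is a no-op, the double registration is harmless.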