Skip to content

Commit 37d9fe6

Browse files
authored
Refactor connection waiters to be cancellation safe (#9671)
1 parent 5b654d5 commit 37d9fe6

File tree

4 files changed

+331
-151
lines changed

4 files changed

+331
-151
lines changed

CHANGES/9670.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
9671.bugfix.rst

CHANGES/9671.bugfix.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Fixed a deadlock that could occur while attempting to get a new connection slot after a timeout -- by :user:`bdraco`.
2+
3+
The connector was not cancellation-safe.

aiohttp/connector.py

Lines changed: 99 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import sys
77
import traceback
88
import warnings
9-
from collections import defaultdict, deque
9+
from collections import OrderedDict, defaultdict
1010
from contextlib import suppress
1111
from http import HTTPStatus
1212
from itertools import chain, cycle, islice
@@ -252,9 +252,11 @@ def __init__(
252252
self._force_close = force_close
253253

254254
# {host_key: FIFO list of waiters}
255-
self._waiters: DefaultDict[ConnectionKey, deque[asyncio.Future[None]]] = (
256-
defaultdict(deque)
257-
)
255+
# The FIFO is implemented with an OrderedDict with None keys because
256+
# python does not have an ordered set.
257+
self._waiters: DefaultDict[
258+
ConnectionKey, OrderedDict[asyncio.Future[None], None]
259+
] = defaultdict(OrderedDict)
258260

259261
self._loop = loop
260262
self._factory = functools.partial(ResponseHandler, loop=loop)
@@ -335,7 +337,7 @@ def _cleanup(self) -> None:
335337
# recreate it ever!
336338
self._cleanup_handle = None
337339

338-
now = self._loop.time()
340+
now = monotonic()
339341
timeout = self._keepalive_timeout
340342

341343
if self._conns:
@@ -366,14 +368,6 @@ def _cleanup(self) -> None:
366368
timeout_ceil_threshold=self._timeout_ceil_threshold,
367369
)
368370

369-
def _drop_acquired_per_host(
370-
self, key: "ConnectionKey", val: ResponseHandler
371-
) -> None:
372-
if conns := self._acquired_per_host.get(key):
373-
conns.remove(val)
374-
if not conns:
375-
del self._acquired_per_host[key]
376-
377371
def _cleanup_closed(self) -> None:
378372
"""Double confirmation for transport close.
379373
@@ -446,6 +440,9 @@ def _close_immediately(self) -> List[Awaitable[object]]:
446440
finally:
447441
self._conns.clear()
448442
self._acquired.clear()
443+
for keyed_waiters in self._waiters.values():
444+
for keyed_waiter in keyed_waiters:
445+
keyed_waiter.cancel()
449446
self._waiters.clear()
450447
self._cleanup_handle = None
451448
self._cleanup_closed_transports.clear()
@@ -489,117 +486,109 @@ async def connect(
489486
) -> Connection:
490487
"""Get from pool or create new connection."""
491488
key = req.connection_key
492-
available = self._available_connections(key)
493-
wait_for_conn = available <= 0 or key in self._waiters
494-
if not wait_for_conn and (proto := self._get(key)) is not None:
489+
if (conn := await self._get(key, traces)) is not None:
495490
# If we do not have to wait and we can get a connection from the pool
496491
# we can avoid the timeout ceil logic and directly return the connection
497-
return await self._reused_connection(key, proto, traces)
492+
return conn
498493

499494
async with ceil_timeout(timeout.connect, timeout.ceil_threshold):
500-
# Wait if there are no available connections or if there are/were
501-
# waiters (i.e. don't steal connection from a waiter about to wake up)
502-
if wait_for_conn:
495+
if self._available_connections(key) <= 0:
503496
await self._wait_for_available_connection(key, traces)
504-
if (proto := self._get(key)) is not None:
505-
return await self._reused_connection(key, proto, traces)
497+
if (conn := await self._get(key, traces)) is not None:
498+
return conn
506499

507500
placeholder = cast(
508501
ResponseHandler, _TransportPlaceholder(self._placeholder_future)
509502
)
510503
self._acquired.add(placeholder)
511504
self._acquired_per_host[key].add(placeholder)
512505

513-
if traces:
514-
for trace in traces:
515-
await trace.send_connection_create_start()
516-
517506
try:
507+
# Traces are done inside the try block to ensure that the
508+
# that the placeholder is still cleaned up if an exception
509+
# is raised.
510+
if traces:
511+
for trace in traces:
512+
await trace.send_connection_create_start()
518513
proto = await self._create_connection(req, traces, timeout)
519-
if self._closed:
520-
proto.close()
521-
raise ClientConnectionError("Connector is closed.")
514+
if traces:
515+
for trace in traces:
516+
await trace.send_connection_create_end()
522517
except BaseException:
523-
if not self._closed:
524-
self._acquired.remove(placeholder)
525-
self._drop_acquired_per_host(key, placeholder)
526-
self._release_waiter()
518+
self._release_acquired(key, placeholder)
527519
raise
528520
else:
529-
if not self._closed:
530-
self._acquired.remove(placeholder)
531-
self._drop_acquired_per_host(key, placeholder)
532-
533-
if traces:
534-
for trace in traces:
535-
await trace.send_connection_create_end()
536-
537-
return self._acquired_connection(proto, key)
538-
539-
async def _reused_connection(
540-
self, key: "ConnectionKey", proto: ResponseHandler, traces: List["Trace"]
541-
) -> Connection:
542-
if traces:
543-
# Acquire the connection to prevent race conditions with limits
544-
placeholder = cast(
545-
ResponseHandler, _TransportPlaceholder(self._placeholder_future)
546-
)
547-
self._acquired.add(placeholder)
548-
self._acquired_per_host[key].add(placeholder)
549-
for trace in traces:
550-
await trace.send_connection_reuseconn()
551-
self._acquired.remove(placeholder)
552-
self._drop_acquired_per_host(key, placeholder)
553-
return self._acquired_connection(proto, key)
521+
if self._closed:
522+
proto.close()
523+
raise ClientConnectionError("Connector is closed.")
554524

555-
def _acquired_connection(
556-
self, proto: ResponseHandler, key: "ConnectionKey"
557-
) -> Connection:
558-
"""Mark proto as acquired and wrap it in a Connection object."""
525+
# The connection was successfully created, drop the placeholder
526+
# and add the real connection to the acquired set. There should
527+
# be no awaits after the proto is added to the acquired set
528+
# to ensure that the connection is not left in the acquired set
529+
# on cancellation.
530+
acquired_per_host = self._acquired_per_host[key]
531+
self._acquired.remove(placeholder)
532+
acquired_per_host.remove(placeholder)
559533
self._acquired.add(proto)
560-
self._acquired_per_host[key].add(proto)
534+
acquired_per_host.add(proto)
561535
return Connection(self, key, proto, self._loop)
562536

563537
async def _wait_for_available_connection(
564538
self, key: "ConnectionKey", traces: List["Trace"]
565539
) -> None:
566-
"""Wait until there is an available connection."""
567-
fut: asyncio.Future[None] = self._loop.create_future()
568-
569-
# This connection will now count towards the limit.
570-
self._waiters[key].append(fut)
540+
"""Wait for an available connection slot."""
541+
# We loop here because there is a race between
542+
# the connection limit check and the connection
543+
# being acquired. If the connection is acquired
544+
# between the check and the await statement, we
545+
# need to loop again to check if the connection
546+
# slot is still available.
547+
attempts = 0
548+
while True:
549+
fut: asyncio.Future[None] = self._loop.create_future()
550+
keyed_waiters = self._waiters[key]
551+
keyed_waiters[fut] = None
552+
if attempts:
553+
# If we have waited before, we need to move the waiter
554+
# to the front of the queue as otherwise we might get
555+
# starved and hit the timeout.
556+
keyed_waiters.move_to_end(fut, last=False)
571557

572-
if traces:
573-
for trace in traces:
574-
await trace.send_connection_queued_start()
558+
try:
559+
# Traces happen in the try block to ensure that the
560+
# the waiter is still cleaned up if an exception is raised.
561+
if traces:
562+
for trace in traces:
563+
await trace.send_connection_queued_start()
564+
await fut
565+
if traces:
566+
for trace in traces:
567+
await trace.send_connection_queued_end()
568+
finally:
569+
# pop the waiter from the queue if its still
570+
# there and not already removed by _release_waiter
571+
keyed_waiters.pop(fut, None)
572+
if not self._waiters.get(key, True):
573+
del self._waiters[key]
575574

576-
try:
577-
await fut
578-
except BaseException as e:
579-
if key in self._waiters:
580-
# remove a waiter even if it was cancelled, normally it's
581-
# removed when it's notified
582-
with suppress(ValueError):
583-
# fut may no longer be in list
584-
self._waiters[key].remove(fut)
585-
586-
raise e
587-
finally:
588-
if key in self._waiters and not self._waiters[key]:
589-
del self._waiters[key]
575+
if self._available_connections(key) > 0:
576+
break
577+
attempts += 1
590578

591-
if traces:
592-
for trace in traces:
593-
await trace.send_connection_queued_end()
579+
async def _get(
580+
self, key: "ConnectionKey", traces: List["Trace"]
581+
) -> Optional[Connection]:
582+
"""Get next reusable connection for the key or None.
594583
595-
def _get(self, key: "ConnectionKey") -> Optional[ResponseHandler]:
596-
"""Get next reusable connection for the key or None."""
584+
The connection will be marked as acquired.
585+
"""
597586
try:
598587
conns = self._conns[key]
599588
except KeyError:
600589
return None
601590

602-
t1 = self._loop.time()
591+
t1 = monotonic()
603592
while conns:
604593
proto, t0 = conns.pop()
605594
# We will we reuse the connection if its connected and
@@ -608,7 +597,16 @@ def _get(self, key: "ConnectionKey") -> Optional[ResponseHandler]:
608597
if not conns:
609598
# The very last connection was reclaimed: drop the key
610599
del self._conns[key]
611-
return proto
600+
self._acquired.add(proto)
601+
self._acquired_per_host[key].add(proto)
602+
if traces:
603+
for trace in traces:
604+
try:
605+
await trace.send_connection_reuseconn()
606+
except BaseException:
607+
self._release_acquired(key, proto)
608+
raise
609+
return Connection(self, key, proto, self._loop)
612610

613611
# Connection cannot be reused, close it
614612
transport = proto.transport
@@ -642,25 +640,23 @@ def _release_waiter(self) -> None:
642640

643641
waiters = self._waiters[key]
644642
while waiters:
645-
waiter = waiters.popleft()
643+
waiter, _ = waiters.popitem(last=False)
646644
if not waiter.done():
647645
waiter.set_result(None)
648646
return
649647

650648
def _release_acquired(self, key: "ConnectionKey", proto: ResponseHandler) -> None:
649+
"""Release acquired connection."""
651650
if self._closed:
652651
# acquired connection is already released on connector closing
653652
return
654653

655-
try:
656-
self._acquired.remove(proto)
657-
self._drop_acquired_per_host(key, proto)
658-
except KeyError: # pragma: no cover
659-
# this may be result of undetermenistic order of objects
660-
# finalization due garbage collection.
661-
pass
662-
else:
663-
self._release_waiter()
654+
self._acquired.discard(proto)
655+
if conns := self._acquired_per_host.get(key):
656+
conns.discard(proto)
657+
if not conns:
658+
del self._acquired_per_host[key]
659+
self._release_waiter()
664660

665661
def _release(
666662
self,
@@ -692,7 +688,7 @@ def _release(
692688
conns = self._conns.get(key)
693689
if conns is None:
694690
conns = self._conns[key] = []
695-
conns.append((protocol, self._loop.time()))
691+
conns.append((protocol, monotonic()))
696692

697693
if self._cleanup_handle is None:
698694
self._cleanup_handle = helpers.weakref_handle(

0 commit comments

Comments
 (0)