Skip to content

Commit 3eea7b8

Browse files
Modify _drain_helper() to handle parallel calls without race-condition (#6028)
* Modify _drain_helper() to handle parallel calls of _send_frame() without race-condition. * Update CHANGES/2934.bugfix * Update 2934.bugfix Co-authored-by: Andrew Svetlov <[email protected]>
1 parent 11b46df commit 3eea7b8

File tree

4 files changed

+24
-4
lines changed

4 files changed

+24
-4
lines changed

CHANGES/2934.bugfix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Modify _drain_helper() to handle concurrent `await resp.write(...)` or `ws.send_json(...)` calls without race-condition.

CONTRIBUTORS.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ Navid Sheikhol
230230
Nicolas Braem
231231
Nikolay Kim
232232
Nikolay Novik
233+
Nándor Mátravölgyi
233234
Oisin Aylward
234235
Olaf Conradi
235236
Oleg Höfling

aiohttp/base_protocol.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ async def _drain_helper(self) -> None:
8181
if not self._paused:
8282
return
8383
waiter = self._drain_waiter
84-
assert waiter is None or waiter.cancelled()
85-
waiter = self._loop.create_future()
86-
self._drain_waiter = waiter
87-
await waiter
84+
if waiter is None:
85+
waiter = self._loop.create_future()
86+
self._drain_waiter = waiter
87+
await asyncio.shield(waiter)

tests/test_base_protocol.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,21 @@ async def wait() -> None:
180180
with suppress(asyncio.CancelledError):
181181
await t
182182
assert pr._drain_waiter is None
183+
184+
185+
async def test_parallel_drain_race_condition() -> None:
186+
loop = asyncio.get_event_loop()
187+
pr = BaseProtocol(loop=loop)
188+
tr = mock.Mock()
189+
pr.connection_made(tr)
190+
pr.pause_writing()
191+
192+
ts = [loop.create_task(pr._drain_helper()) for _ in range(5)]
193+
assert not (await asyncio.wait(ts, timeout=0.5))[
194+
0
195+
], "All draining tasks must be pending"
196+
197+
assert pr._drain_waiter is not None
198+
pr.resume_writing()
199+
await asyncio.gather(*ts)
200+
assert pr._drain_waiter is None

0 commit comments

Comments
 (0)