Skip to content

Commit 2953090

Browse files
authored
Fix test_restarting_does_not_deadlock (#8849)
1 parent 09ed8af commit 2953090

File tree

1 file changed

+69
-42
lines changed

1 file changed

+69
-42
lines changed

distributed/shuffle/tests/test_shuffle.py

Lines changed: 69 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -326,23 +326,42 @@ async def test_bad_disk(c, s, a, b):
326326
await assert_scheduler_cleanup(s)
327327

328328

329-
async def wait_until_worker_has_tasks(
330-
prefix: str, worker: str, count: int, scheduler: Scheduler, interval: float = 0.01
331-
) -> None:
332-
ws = scheduler.workers[worker]
333-
while (
334-
len(
335-
[
336-
key
337-
for key, ts in scheduler.tasks.items()
338-
if prefix in key_split(key)
339-
and ts.state == "memory"
340-
and {ws} == ts.who_has
341-
]
342-
)
343-
< count
344-
):
345-
await asyncio.sleep(interval)
329+
from distributed.diagnostics.plugin import SchedulerPlugin


class ObserveTasksPlugin(SchedulerPlugin):
    """Scheduler plugin that watches task transitions and fires an event.

    Counts, per key prefix, how many tasks matching one of ``prefixes``
    reach the ``processing`` state on the worker at address ``worker``.
    Once any prefix's counter reaches ``count``, ``self.event`` is set so
    a test can await that moment without polling scheduler state.
    """

    def __init__(self, prefixes, count, worker):
        self.prefixes = prefixes
        self.count = count
        self.worker = worker
        # Per-prefix tally of matching "-> processing" transitions.
        self.counter = defaultdict(int)
        self.event = asyncio.Event()

    async def start(self, scheduler):
        # Keep a handle on the scheduler so ``transition`` can look up
        # which worker a task is processing on.
        self.scheduler = scheduler

    def transition(self, key, start, finish, *args, **kwargs):
        prefix = key_split(key)
        if finish == "processing" and prefix in self.prefixes:
            processing_on = self.scheduler.tasks[key].processing_on
            if processing_on and processing_on.address == self.worker:
                self.counter[prefix] += 1
                if self.counter[prefix] == self.count:
                    self.event.set()
        return key, start, finish
354+
355+
356+
@contextlib.asynccontextmanager
async def wait_until_worker_has_tasks(prefix, worker, count, scheduler):
    """Yield an event that fires once ``count`` tasks run on ``worker``.

    Installs an :class:`ObserveTasksPlugin` on ``scheduler`` for the
    duration of the ``with`` block; the yielded ``asyncio.Event`` is set
    when ``count`` tasks whose key prefix equals ``prefix`` have started
    processing on the worker at address ``worker``.  The plugin is removed
    again on exit, even if the body raises.
    """
    observer = ObserveTasksPlugin([prefix], count, worker)
    scheduler.add_plugin(observer, name="observe-tasks")
    await observer.start(scheduler)
    try:
        yield observer.event
    finally:
        scheduler.remove_plugin("observe-tasks")
346365

347366

348367
async def wait_for_tasks_in_state(
@@ -562,8 +581,12 @@ async def test_get_or_create_from_dangling_transfer(c, s, a, b):
562581
@pytest.mark.slow
563582
@gen_cluster(client=True, nthreads=[("", 1)])
564583
async def test_crashed_worker_during_transfer(c, s, a):
565-
async with Nanny(s.address, nthreads=1) as n:
566-
killed_worker_address = n.worker_address
584+
async with (
585+
Nanny(s.address, nthreads=1) as n,
586+
wait_until_worker_has_tasks(
587+
"shuffle-transfer", n.worker_address, 1, s
588+
) as event,
589+
):
567590
df = dask.datasets.timeseries(
568591
start="2000-01-01",
569592
end="2000-03-01",
@@ -573,9 +596,7 @@ async def test_crashed_worker_during_transfer(c, s, a):
573596
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
574597
shuffled = df.shuffle("x")
575598
fut = c.compute([shuffled, df], sync=True)
576-
await wait_until_worker_has_tasks(
577-
"shuffle-transfer", killed_worker_address, 1, s
578-
)
599+
await event.wait()
579600
await n.process.process.kill()
580601

581602
result, expected = await fut
@@ -605,20 +626,16 @@ async def test_restarting_does_not_deadlock(c, s):
605626
)
606627
df = await c.persist(df)
607628
expected = await c.compute(df)
608-
609-
async with Nanny(s.address) as b:
629+
async with Worker(s.address) as b:
610630
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
611631
out = df.shuffle("x")
612632
assert not s.workers[b.worker_address].has_what
613633
result = c.compute(out)
614-
await wait_until_worker_has_tasks(
615-
"shuffle-transfer", b.worker_address, 1, s
616-
)
634+
while not s.extensions["shuffle"].active_shuffles:
635+
await asyncio.sleep(0)
617636
a.status = Status.paused
618637
await async_poll_for(lambda: len(s.running) == 1, timeout=5)
619-
b.close_gracefully()
620-
await b.process.process.kill()
621-
638+
b.batched_stream.close()
622639
await async_poll_for(lambda: not s.running, timeout=5)
623640

624641
a.status = Status.running
@@ -672,8 +689,12 @@ def mock_mock_get_worker_for_range_sharding(
672689
"distributed.shuffle._shuffle._get_worker_for_range_sharding",
673690
mock_mock_get_worker_for_range_sharding,
674691
):
675-
async with Nanny(s.address, nthreads=1) as n:
676-
killed_worker_address = n.worker_address
692+
async with (
693+
Nanny(s.address, nthreads=1) as n,
694+
wait_until_worker_has_tasks(
695+
"shuffle-transfer", n.worker_address, 1, s
696+
) as event,
697+
):
677698
df = dask.datasets.timeseries(
678699
start="2000-01-01",
679700
end="2000-03-01",
@@ -683,9 +704,7 @@ def mock_mock_get_worker_for_range_sharding(
683704
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
684705
shuffled = df.shuffle("x")
685706
fut = c.compute([shuffled, df], sync=True)
686-
await wait_until_worker_has_tasks(
687-
"shuffle-transfer", n.worker_address, 1, s
688-
)
707+
await event.wait()
689708
await n.process.process.kill()
690709

691710
result, expected = await fut
@@ -1033,8 +1052,10 @@ async def test_restarting_during_unpack_raises_killed_worker(c, s, a, b):
10331052
@pytest.mark.slow
10341053
@gen_cluster(client=True, nthreads=[("", 1)])
10351054
async def test_crashed_worker_during_unpack(c, s, a):
1036-
async with Nanny(s.address, nthreads=2) as n:
1037-
killed_worker_address = n.worker_address
1055+
async with (
1056+
Nanny(s.address, nthreads=2) as n,
1057+
wait_until_worker_has_tasks(UNPACK_PREFIX, n.worker_address, 1, s) as event,
1058+
):
10381059
df = dask.datasets.timeseries(
10391060
start="2000-01-01",
10401061
end="2000-03-01",
@@ -1046,7 +1067,7 @@ async def test_crashed_worker_during_unpack(c, s, a):
10461067
shuffled = df.shuffle("x")
10471068
result = c.compute(shuffled)
10481069

1049-
await wait_until_worker_has_tasks(UNPACK_PREFIX, killed_worker_address, 1, s)
1070+
await event.wait()
10501071
await n.process.process.kill()
10511072

10521073
result = await result
@@ -1486,7 +1507,10 @@ def block(df, in_event, block_event):
14861507
block_event.wait()
14871508
return df
14881509

1489-
async with Nanny(s.address, nthreads=1) as n:
1510+
async with (
1511+
Nanny(s.address, nthreads=1) as n,
1512+
wait_until_worker_has_tasks(UNPACK_PREFIX, n.worker_address, 1, s) as event,
1513+
):
14901514
df = dask.datasets.timeseries(
14911515
start="2000-01-01",
14921516
end="2000-03-01",
@@ -1507,7 +1531,7 @@ def block(df, in_event, block_event):
15071531
allow_other_workers=True,
15081532
)
15091533

1510-
await wait_until_worker_has_tasks(UNPACK_PREFIX, n.worker_address, 1, s)
1534+
await event.wait()
15111535
await in_event.wait()
15121536
await n.process.process.kill()
15131537
await block_event.set()
@@ -1524,7 +1548,10 @@ def block(df, in_event, block_event):
15241548

15251549
@gen_cluster(client=True, nthreads=[("", 1)])
15261550
async def test_crashed_worker_after_shuffle_persisted(c, s, a):
1527-
async with Nanny(s.address, nthreads=1) as n:
1551+
async with (
1552+
Nanny(s.address, nthreads=1) as n,
1553+
wait_until_worker_has_tasks(UNPACK_PREFIX, n.worker_address, 1, s) as event,
1554+
):
15281555
df = df = dask.datasets.timeseries(
15291556
start="2000-01-01",
15301557
end="2000-01-10",
@@ -1536,7 +1563,7 @@ async def test_crashed_worker_after_shuffle_persisted(c, s, a):
15361563
out = df.shuffle("x")
15371564
out = out.persist()
15381565

1539-
await wait_until_worker_has_tasks(UNPACK_PREFIX, n.worker_address, 1, s)
1566+
await event.wait()
15401567
await out
15411568

15421569
await n.process.process.kill()

0 commit comments

Comments
 (0)