@@ -326,23 +326,42 @@ async def test_bad_disk(c, s, a, b):
326326 await assert_scheduler_cleanup (s )
327327
328328
329- async def wait_until_worker_has_tasks (
330- prefix : str , worker : str , count : int , scheduler : Scheduler , interval : float = 0.01
331- ) -> None :
332- ws = scheduler .workers [worker ]
333- while (
334- len (
335- [
336- key
337- for key , ts in scheduler .tasks .items ()
338- if prefix in key_split (key )
339- and ts .state == "memory"
340- and {ws } == ts .who_has
341- ]
342- )
343- < count
344- ):
345- await asyncio .sleep (interval )
329+ from distributed .diagnostics .plugin import SchedulerPlugin
330+
331+
332+ class ObserveTasksPlugin (SchedulerPlugin ):
333+ def __init__ (self , prefixes , count , worker ):
334+ self .prefixes = prefixes
335+ self .count = count
336+ self .worker = worker
337+ self .counter = defaultdict (int )
338+ self .event = asyncio .Event ()
339+
340+ async def start (self , scheduler ):
341+ self .scheduler = scheduler
342+
343+ def transition (self , key , start , finish , * args , ** kwargs ):
344+ if (
345+ finish == "processing"
346+ and key_split (key ) in self .prefixes
347+ and self .scheduler .tasks [key ].processing_on
348+ and self .scheduler .tasks [key ].processing_on .address == self .worker
349+ ):
350+ self .counter [key_split (key )] += 1
351+ if self .counter [key_split (key )] == self .count :
352+ self .event .set ()
353+ return key , start , finish
354+
355+
356+ @contextlib .asynccontextmanager
357+ async def wait_until_worker_has_tasks (prefix , worker , count , scheduler ):
358+ plugin = ObserveTasksPlugin ([prefix ], count , worker )
359+ scheduler .add_plugin (plugin , name = "observe-tasks" )
360+ await plugin .start (scheduler )
361+ try :
362+ yield plugin .event
363+ finally :
364+ scheduler .remove_plugin ("observe-tasks" )
346365
347366
348367async def wait_for_tasks_in_state (
@@ -562,8 +581,12 @@ async def test_get_or_create_from_dangling_transfer(c, s, a, b):
562581@pytest .mark .slow
563582@gen_cluster (client = True , nthreads = [("" , 1 )])
564583async def test_crashed_worker_during_transfer (c , s , a ):
565- async with Nanny (s .address , nthreads = 1 ) as n :
566- killed_worker_address = n .worker_address
584+ async with (
585+ Nanny (s .address , nthreads = 1 ) as n ,
586+ wait_until_worker_has_tasks (
587+ "shuffle-transfer" , n .worker_address , 1 , s
588+ ) as event ,
589+ ):
567590 df = dask .datasets .timeseries (
568591 start = "2000-01-01" ,
569592 end = "2000-03-01" ,
@@ -573,9 +596,7 @@ async def test_crashed_worker_during_transfer(c, s, a):
573596 with dask .config .set ({"dataframe.shuffle.method" : "p2p" }):
574597 shuffled = df .shuffle ("x" )
575598 fut = c .compute ([shuffled , df ], sync = True )
576- await wait_until_worker_has_tasks (
577- "shuffle-transfer" , killed_worker_address , 1 , s
578- )
599+ await event .wait ()
579600 await n .process .process .kill ()
580601
581602 result , expected = await fut
@@ -605,20 +626,16 @@ async def test_restarting_does_not_deadlock(c, s):
605626 )
606627 df = await c .persist (df )
607628 expected = await c .compute (df )
608-
609- async with Nanny (s .address ) as b :
629+ async with Worker (s .address ) as b :
610630 with dask .config .set ({"dataframe.shuffle.method" : "p2p" }):
611631 out = df .shuffle ("x" )
612632 assert not s .workers [b .worker_address ].has_what
613633 result = c .compute (out )
614- await wait_until_worker_has_tasks (
615- "shuffle-transfer" , b .worker_address , 1 , s
616- )
634+ while not s .extensions ["shuffle" ].active_shuffles :
635+ await asyncio .sleep (0 )
617636 a .status = Status .paused
618637 await async_poll_for (lambda : len (s .running ) == 1 , timeout = 5 )
619- b .close_gracefully ()
620- await b .process .process .kill ()
621-
638+ b .batched_stream .close ()
622639 await async_poll_for (lambda : not s .running , timeout = 5 )
623640
624641 a .status = Status .running
@@ -672,8 +689,12 @@ def mock_mock_get_worker_for_range_sharding(
672689 "distributed.shuffle._shuffle._get_worker_for_range_sharding" ,
673690 mock_mock_get_worker_for_range_sharding ,
674691 ):
675- async with Nanny (s .address , nthreads = 1 ) as n :
676- killed_worker_address = n .worker_address
692+ async with (
693+ Nanny (s .address , nthreads = 1 ) as n ,
694+ wait_until_worker_has_tasks (
695+ "shuffle-transfer" , n .worker_address , 1 , s
696+ ) as event ,
697+ ):
677698 df = dask .datasets .timeseries (
678699 start = "2000-01-01" ,
679700 end = "2000-03-01" ,
@@ -683,9 +704,7 @@ def mock_mock_get_worker_for_range_sharding(
683704 with dask .config .set ({"dataframe.shuffle.method" : "p2p" }):
684705 shuffled = df .shuffle ("x" )
685706 fut = c .compute ([shuffled , df ], sync = True )
686- await wait_until_worker_has_tasks (
687- "shuffle-transfer" , n .worker_address , 1 , s
688- )
707+ await event .wait ()
689708 await n .process .process .kill ()
690709
691710 result , expected = await fut
@@ -1033,8 +1052,10 @@ async def test_restarting_during_unpack_raises_killed_worker(c, s, a, b):
10331052@pytest .mark .slow
10341053@gen_cluster (client = True , nthreads = [("" , 1 )])
10351054async def test_crashed_worker_during_unpack (c , s , a ):
1036- async with Nanny (s .address , nthreads = 2 ) as n :
1037- killed_worker_address = n .worker_address
1055+ async with (
1056+ Nanny (s .address , nthreads = 2 ) as n ,
1057+ wait_until_worker_has_tasks (UNPACK_PREFIX , n .worker_address , 1 , s ) as event ,
1058+ ):
10381059 df = dask .datasets .timeseries (
10391060 start = "2000-01-01" ,
10401061 end = "2000-03-01" ,
@@ -1046,7 +1067,7 @@ async def test_crashed_worker_during_unpack(c, s, a):
10461067 shuffled = df .shuffle ("x" )
10471068 result = c .compute (shuffled )
10481069
1049- await wait_until_worker_has_tasks ( UNPACK_PREFIX , killed_worker_address , 1 , s )
1070+ await event . wait ( )
10501071 await n .process .process .kill ()
10511072
10521073 result = await result
@@ -1486,7 +1507,10 @@ def block(df, in_event, block_event):
14861507 block_event .wait ()
14871508 return df
14881509
1489- async with Nanny (s .address , nthreads = 1 ) as n :
1510+ async with (
1511+ Nanny (s .address , nthreads = 1 ) as n ,
1512+ wait_until_worker_has_tasks (UNPACK_PREFIX , n .worker_address , 1 , s ) as event ,
1513+ ):
14901514 df = dask .datasets .timeseries (
14911515 start = "2000-01-01" ,
14921516 end = "2000-03-01" ,
@@ -1507,7 +1531,7 @@ def block(df, in_event, block_event):
15071531 allow_other_workers = True ,
15081532 )
15091533
1510- await wait_until_worker_has_tasks ( UNPACK_PREFIX , n . worker_address , 1 , s )
1534+ await event . wait ( )
15111535 await in_event .wait ()
15121536 await n .process .process .kill ()
15131537 await block_event .set ()
@@ -1524,7 +1548,10 @@ def block(df, in_event, block_event):
15241548
15251549@gen_cluster (client = True , nthreads = [("" , 1 )])
15261550async def test_crashed_worker_after_shuffle_persisted (c , s , a ):
1527- async with Nanny (s .address , nthreads = 1 ) as n :
1551+ async with (
1552+ Nanny (s .address , nthreads = 1 ) as n ,
1553+ wait_until_worker_has_tasks (UNPACK_PREFIX , n .worker_address , 1 , s ) as event ,
1554+ ):
15281555 df = df = dask .datasets .timeseries (
15291556 start = "2000-01-01" ,
15301557 end = "2000-01-10" ,
@@ -1536,7 +1563,7 @@ async def test_crashed_worker_after_shuffle_persisted(c, s, a):
15361563 out = df .shuffle ("x" )
15371564 out = out .persist ()
15381565
1539- await wait_until_worker_has_tasks ( UNPACK_PREFIX , n . worker_address , 1 , s )
1566+ await event . wait ( )
15401567 await out
15411568
15421569 await n .process .process .kill ()
0 commit comments