
Commit cecf2c0

Merge branch 'main' into WSMR/wait_for_state
2 parents: b1ba2b0 + 33c5cb2

5 files changed: +172 −13 lines

distributed/scheduler.py

Lines changed: 6 additions & 0 deletions
@@ -1065,8 +1065,13 @@ class TaskState:
     #: Cached hash of :attr:`~TaskState.client_key`
     _hash: int
 
+    # Support for weakrefs to a class with __slots__
+    __weakref__: Any = None
     __slots__ = tuple(__annotations__)  # type: ignore
 
+    # Instances not part of slots since class variable
+    _instances: ClassVar[weakref.WeakSet[TaskState]] = weakref.WeakSet()
+
     def __init__(self, key: str, run_spec: object):
         self.key = key
         self._hash = hash(key)
@@ -1101,6 +1106,7 @@ def __init__(self, key: str, run_spec: object):
         self.metadata = {}
         self.annotations = {}
         self.erred_on = set()
+        TaskState._instances.add(self)
 
     def __hash__(self) -> int:
         return self._hash
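
Note (not part of the diff): the `__weakref__` annotation is what makes the `WeakSet` tracking possible on a class that uses `__slots__` — a slotted class has no `__weakref__` slot by default, so it cannot be weakly referenced until one is declared. A minimal, self-contained sketch of the pattern, with illustrative names and assuming CPython's immediate refcount-based cleanup:

    import weakref
    from typing import Any, ClassVar

    class Plain:
        __slots__ = ("key",)  # no __weakref__ slot: weak references are rejected

    class Tracked:
        key: str
        # Declaring __weakref__ as a slot is what allows weak references
        __weakref__: Any
        __slots__ = tuple(__annotations__)

        # Class variable; defined after __slots__ so it is not turned into a slot
        _instances: ClassVar[weakref.WeakSet] = weakref.WeakSet()

        def __init__(self, key: str):
            self.key = key
            Tracked._instances.add(self)

    try:
        weakref.ref(Plain())
    except TypeError as e:
        print(e)  # cannot create weak reference to 'Plain' object

    t = Tracked("x")
    assert len(Tracked._instances) == 1
    del t  # the WeakSet entry disappears with the last strong reference
    assert len(Tracked._instances) == 0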

distributed/tests/test_active_memory_manager.py

Lines changed: 104 additions & 2 deletions
@@ -9,13 +9,20 @@
 
 import pytest
 
-from distributed import Nanny, wait
+from distributed import Event, Nanny, Scheduler, Worker, wait
 from distributed.active_memory_manager import (
     ActiveMemoryManagerExtension,
     ActiveMemoryManagerPolicy,
+    RetireWorker,
 )
 from distributed.core import Status
-from distributed.utils_test import captured_logger, gen_cluster, inc, slowinc
+from distributed.utils_test import (
+    assert_story,
+    captured_logger,
+    gen_cluster,
+    inc,
+    slowinc,
+)
 
 NO_AMM_START = {"distributed.scheduler.active-memory-manager.start": False}
 
@@ -903,6 +910,101 @@ async def test_RetireWorker_all_recipients_are_paused(c, s, a, b):
     assert await c.submit(inc, 1) == 2
 
 
+@gen_cluster(
+    client=True,
+    config={
+        "distributed.scheduler.active-memory-manager.start": True,  # to avoid a one-off AMM instance
+        "distributed.scheduler.active-memory-manager.policies": [],
+    },
+    timeout=15,
+)
+async def test_RetireWorker_new_keys_arrive_after_all_keys_moved_away(
+    c, s: Scheduler, a: Worker, b: Worker
+):
+    """
+    If all keys have been moved off a worker, but then new keys arrive (due to task completion or `gather_dep`)
+    before the worker has actually closed, make sure we still retire it (instead of hanging forever).
+
+    This test is timing-sensitive. If it runs too slowly, it *should* `pytest.skip` itself.
+
+    See https://github.com/dask/distributed/issues/6223 for motivation.
+    """
+    ws_a = s.workers[a.address]
+    ws_b = s.workers[b.address]
+    event = Event()
+
+    # Put 200 keys on the worker, so `_track_retire_worker` will sleep for 0.5s
+    xs = c.map(lambda x: x, range(200), workers=[a.address])
+    await wait(xs)
+
+    # Put an extra task on the worker, which we will allow to complete once the `xs`
+    # have been replicated.
+    extra = c.submit(
+        lambda: event.wait("2s"),
+        workers=[a.address],
+        allow_other_workers=True,
+        key="extra",
+    )
+
+    while (
+        extra.key not in a.state.tasks or a.state.tasks[extra.key].state != "executing"
+    ):
+        await asyncio.sleep(0.01)
+
+    t = asyncio.create_task(c.retire_workers([a.address]))
+
+    # Wait for all `xs` to be replicated.
+    while len(ws_b.has_what) != len(xs):
+        await asyncio.sleep(0)
+
+    # `_track_retire_worker` _should_ now be sleeping for 0.5s, because there were >=200 keys on A.
+    # In this test, everything from the beginning of the transfers needs to happen within 0.5s.
+
+    # Simulate the policy running again. Because the default 2s AMM interval is longer
+    # than the 0.5s wait, what we're about to trigger is unlikely, but it's still possible
+    # for the times to line up (especially with a custom AMM interval).
+    amm: ActiveMemoryManagerExtension = s.extensions["amm"]
+    assert len(amm.policies) == 1
+    policy = next(iter(amm.policies))
+    assert isinstance(policy, RetireWorker)
+
+    amm.run_once()
+
+    # The policy has removed itself, because all `xs` have been replicated.
+    assert not amm.policies
+    assert policy.done(), {ts.key: ts.who_has for ts in ws_a.has_what}
+
+    # But what if a new key arrives now, while `_track_retire_worker` is still (maybe)
+    # sleeping? Let `extra` complete and wait for it to hit the scheduler.
+    await event.set()
+    await wait(extra)
+
+    if a.address not in s.workers:
+        # It took more than 0.5s to get here, and the scheduler closed our worker. Dang.
+        pytest.skip(
+            "Timing didn't work out: `_track_retire_worker` finished before `extra` completed."
+        )
+
+    # `retire_workers` doesn't hang
+    await t
+    assert a.address not in s.workers
+    assert not amm.policies
+
+    # `extra` was not transferred from `a` to `b`. Instead, it was recomputed on `b`.
+    story = b.state.story(extra.key)
+    assert_story(
+        story,
+        [
+            (extra.key, "compute-task", "released"),
+            (extra.key, "released", "waiting", "waiting", {"extra": "ready"}),
+            (extra.key, "waiting", "ready", "ready", {"extra": "executing"}),
+        ],
+    )
+
+    # `extra` completes successfully, and its result is fetched from the other worker.
+    await extra.result()
+
+
 # FIXME can't drop runtime of this test below 10s; see distributed#5585
 @pytest.mark.slow
 @gen_cluster(
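
Side note (an illustration, not part of the change): the test above holds the `extra` task open with a `distributed.Event`, which lets the test decide exactly when a running task is allowed to finish. A stripped-down sketch of that gating pattern — the cluster setup, event name, and key are made up for the example:

    from distributed import Client, Event

    def gated(x):
        # Runs on a worker; blocks until the driver sets the named event
        # (or the timeout expires).
        Event("gate").wait(timeout="5s")
        return x + 1

    if __name__ == "__main__":
        client = Client(processes=False)
        fut = client.submit(gated, 1, key="extra")
        # ... a test would rearrange cluster state here while "extra" is executing ...
        Event("gate", client=client).set()  # unblock the task
        assert fut.result() == 2
        client.close()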

distributed/tests/test_worker_state_machine.py

Lines changed: 42 additions & 1 deletion
@@ -1,12 +1,16 @@
 from __future__ import annotations
 
 import asyncio
+import gc
 from collections.abc import Iterator
 
 import pytest
+from tlz import first
 
-from distributed import Worker, wait
+import distributed.profile as profile
+from distributed import Nanny, Worker, wait
 from distributed.protocol.serialize import Serialize
+from distributed.scheduler import TaskState as SchedulerTaskState
 from distributed.utils import recursive_to_dict
 from distributed.utils_test import (
     BlockedGetData,
@@ -37,6 +41,15 @@
 )
 
 
+def test_TaskState_tracking(cleanup):
+    gc.collect()
+    x = TaskState("x")
+    assert len(TaskState._instances) == 1
+    assert first(TaskState._instances) == x
+    del x
+    assert len(TaskState._instances) == 0
+
+
 def test_TaskState_get_nbytes():
     assert TaskState("x", nbytes=123).get_nbytes() == 123
     # Default to distributed.scheduler.default-data-size
@@ -670,6 +683,34 @@ async def test_missing_to_waiting(c, s, w1, w2, w3):
     await f1
 
 
+@gen_cluster(client=True, Worker=Nanny)
+async def test_task_state_instance_are_garbage_collected(c, s, a, b):
+    futs = c.map(inc, range(10))
+    red = c.submit(sum, futs)
+    f1 = c.submit(inc, red, pure=False)
+    f2 = c.submit(inc, red, pure=False)
+
+    async def check(dask_worker):
+        while dask_worker.tasks:
+            await asyncio.sleep(0.01)
+        with profile.lock:
+            gc.collect()
+        assert not TaskState._instances
+
+    await c.gather([f2, f1])
+    del futs, red, f1, f2
+    await c.run(check)
+
+    async def check(dask_scheduler):
+        while dask_scheduler.tasks:
+            await asyncio.sleep(0.01)
+        with profile.lock:
+            gc.collect()
+        assert not SchedulerTaskState._instances
+
+    await c.run_on_scheduler(check)
+
+
 @gen_cluster(client=True, nthreads=[("", 1)] * 3)
 async def test_fetch_to_missing_on_refresh_who_has(c, s, w1, w2, w3):
     """

distributed/utils_test.py

Lines changed: 4 additions & 3 deletions
@@ -58,6 +58,7 @@
 from distributed.node import ServerNode
 from distributed.proctitle import enable_proctitle_on_children
 from distributed.protocol import deserialize
+from distributed.scheduler import TaskState as SchedulerTaskState
 from distributed.security import Security
 from distributed.utils import (
     DequeHandler,
@@ -72,6 +73,7 @@
 )
 from distributed.worker import WORKER_ANY_RUNNING, Worker
 from distributed.worker_state_machine import InvalidTransition
+from distributed.worker_state_machine import TaskState as WorkerTaskState
 
 try:
     import ssl
@@ -1839,9 +1841,8 @@ def check_instances():
     Scheduler._instances.clear()
     SpecCluster._instances.clear()
     Worker._initialized_clients.clear()
-    # assert all(n.status == "closed" for n in Nanny._instances), {
-    #     n: n.status for n in Nanny._instances
-    # }
+    SchedulerTaskState._instances.clear()
+    WorkerTaskState._instances.clear()
     Nanny._instances.clear()
     _global_clients.clear()
    Comm._instances.clear()
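
For context (an illustration, not the project's actual fixture): clearing the new `_instances` sets in `check_instances` keeps one test's leftover `TaskState` objects from tripping assertions in the next test. A leak check along the same lines could be written as a pytest fixture; it relies only on the `_instances` WeakSets introduced in this commit, everything else is hypothetical:

    import gc

    import pytest

    from distributed.scheduler import TaskState as SchedulerTaskState
    from distributed.worker_state_machine import TaskState as WorkerTaskState

    @pytest.fixture
    def no_leaked_task_states():
        # Start from a clean slate so earlier tests cannot cause false positives.
        SchedulerTaskState._instances.clear()
        WorkerTaskState._instances.clear()
        yield
        # Collect cycles first; otherwise objects kept alive only by cycles linger.
        gc.collect()
        assert not SchedulerTaskState._instances
        assert not WorkerTaskState._instances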

distributed/worker_state_machine.py

Lines changed: 16 additions & 7 deletions
@@ -8,6 +8,7 @@
 import operator
 import random
 import sys
+import weakref
 from collections import defaultdict, deque
 from collections.abc import (
     Callable,
@@ -262,20 +263,23 @@ class TaskState:
     #: True if the task is in memory or erred; False otherwise
     done: bool = False
 
+    _instances: ClassVar[weakref.WeakSet[TaskState]] = weakref.WeakSet()
+
     # Support for weakrefs to a class with __slots__
     __weakref__: Any = field(init=False)
 
+    def __post_init__(self):
+        TaskState._instances.add(self)
+
     def __repr__(self) -> str:
         return f"<TaskState {self.key!r} {self.state}>"
 
     def __eq__(self, other: object) -> bool:
-        if not isinstance(other, TaskState) or other.key != self.key:
-            return False
-        # When a task transitions to forgotten and exits Worker.tasks, it should be
-        # immediately dereferenced. If the same task is recreated later on on the
-        # worker, we should not have to deal with its previous incarnation lingering.
-        assert other is self
-        return True
+        # A task may be forgotten, and a new TaskState object with the same key may be
+        # created in its place later on. In the worker state there should never be two
+        # TaskState objects with the same key, but we can't assert that here, because
+        # this comparison is also used by WeakSets for instance tracking purposes.
+        return other is self
 
     def __hash__(self) -> int:
         return hash(self.key)
@@ -3002,6 +3006,11 @@ def validate_state(self) -> None:
         if self.transition_counter_max:
             assert self.transition_counter < self.transition_counter_max
 
+        # Test that there aren't multiple TaskState objects with the same key in data_needed
+        assert len({ts.key for ts in self.data_needed}) == len(self.data_needed)
+        for tss in self.data_needed_per_worker.values():
+            assert len({ts.key for ts in tss}) == len(tss)
+
 
 class BaseWorker(abc.ABC):
     """Wrapper around the :class:`WorkerState` that implements instructions handling.
