Skip to content

Commit a8ac78d

Browse files
committed
ManualEvictProto (missing unit tests+docs)
1 parent 9e6fdfb commit a8ac78d

File tree

2 files changed

+48
-10
lines changed

2 files changed

+48
-10
lines changed

distributed/spill.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
import logging
44
import time
5-
from collections.abc import Mapping
5+
from collections.abc import Mapping, Sized
66
from contextlib import contextmanager
77
from functools import partial
8-
from typing import Any, Literal, NamedTuple
8+
from typing import Any, Literal, NamedTuple, Protocol
99

1010
from packaging.version import parse as parse_version
1111

@@ -33,6 +33,36 @@ def __sub__(self, other: SpilledSize) -> SpilledSize: # type: ignore
3333
return SpilledSize(self.memory - other.memory, self.disk - other.disk)
3434

3535

36+
class ManualEvictProto(Protocol):
37+
"""Duck-type API that a third-party alternative to SpillBuffer must respect (in
38+
addition to MutableMapping) if it wishes to support spilling when the
39+
``distributed.worker.memory.spill`` threshold is surpassed.
40+
41+
This is public API. At the moment of writing, Dask-CUDA implements this protocol in
42+
the ProxifyHostFile class.
43+
"""
44+
45+
@property
46+
def fast(self) -> Sized | bool:
47+
"""Access to fast memory. This is normally a MutableMapping, but for the purpose
48+
of the manual eviction API it is just tested for emptiness to know if there is
49+
anything to evict.
50+
"""
51+
... # pragma: nocover
52+
53+
def evict(self) -> int:
54+
"""Manually evict a key/value pair from fast to slow memory.
55+
Return size of the evicted value in fast memory.
56+
57+
If the eviction failed for whatever reason, return -1. This method must
58+
guarantee that the key/value pair that caused the issue has been retained in
59+
fast memory and that the problem has been logged internally.
60+
61+
This method never raises.
62+
"""
63+
... # pragma: nocover
64+
65+
3666
class SpillBuffer(zict.Buffer):
3767
"""MutableMapping that automatically spills out dask key/value pairs to disk when
3868
the total size of the stored data exceeds the target. If max_spill is provided the
@@ -163,11 +193,14 @@ def __setitem__(self, key: str, value: Any) -> None:
163193
assert key not in self.slow
164194

165195
def evict(self) -> int:
166-
"""Manually evict the oldest key/value pair, even if target has not been reached.
167-
Returns sizeof(value).
196+
"""Implementation of :meth:`ManualEvictProto.evict`.
197+
198+
Manually evict the oldest key/value pair, even if target has not been
199+
reached. Returns sizeof(value).
168200
If the eviction failed (value failed to pickle, disk full, or max_spill
169201
exceeded), return -1; the key/value pair that caused the issue will remain in
170-
fast. This method never raises.
202+
fast. The exception has been logged internally.
203+
This method never raises.
171204
"""
172205
try:
173206
with self.handle_errors(None):

distributed/worker_memory.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from collections.abc import Callable, MutableMapping
2929
from contextlib import suppress
3030
from functools import partial
31-
from typing import TYPE_CHECKING, Any, Container, Literal
31+
from typing import TYPE_CHECKING, Any, Container, Literal, cast
3232

3333
import psutil
3434
from tornado.ioloop import PeriodicCallback
@@ -39,7 +39,7 @@
3939

4040
from . import system
4141
from .core import Status
42-
from .spill import SpillBuffer
42+
from .spill import ManualEvictProto, SpillBuffer
4343
from .utils import log_errors
4444
from .utils_perf import ThrottledGC
4545

@@ -208,8 +208,12 @@ def _maybe_pause_or_unpause(self, worker: Worker, memory: int) -> None:
208208
async def _maybe_spill(self, worker: Worker, memory: int) -> None:
209209
if self.memory_spill_fraction is False:
210210
return
211-
if not isinstance(self.data, SpillBuffer):
211+
212+
# SpillBuffer or a duct-type compatible MutableMapping which offers the
213+
# fast property and evict() methods. Dask-CUDA uses this.
214+
if not hasattr(self.data, "fast") or not hasattr(self.data, "evict"):
212215
return
216+
data = cast(ManualEvictProto, self.data)
213217

214218
assert self.memory_limit
215219
frac = memory / self.memory_limit
@@ -231,7 +235,7 @@ async def _maybe_spill(self, worker: Worker, memory: int) -> None:
231235
count = 0
232236
need = memory - target
233237
while memory > target:
234-
if not self.data.fast:
238+
if not data.fast:
235239
logger.warning(
236240
"Unmanaged memory use is high. This may indicate a memory leak "
237241
"or the memory may not be released to the OS; see "
@@ -242,7 +246,8 @@ async def _maybe_spill(self, worker: Worker, memory: int) -> None:
242246
format_bytes(self.memory_limit),
243247
)
244248
break
245-
weight = self.data.evict()
249+
250+
weight = data.evict()
246251
if weight == -1:
247252
# Failed to evict:
248253
# disk full, spill size limit exceeded, or pickle error

0 commit comments

Comments
 (0)