79 changes: 41 additions & 38 deletions distributed/batched.py
@@ -1,9 +1,7 @@
import asyncio
import logging
from collections import deque

from tornado import gen, locks
from tornado.ioloop import IOLoop

import dask
from dask.utils import parse_timedelta

@@ -36,17 +34,13 @@ class BatchedSend:
['Hello,', 'world!']
"""

# XXX why doesn't BatchedSend follow either the IOStream or Comm API?

def __init__(self, interval, loop=None, serializers=None):
# XXX is the loop arg useful?
self.loop = loop or IOLoop.current()
def __init__(self, interval, serializers=None, name=None):
self.interval = parse_timedelta(interval, default="ms")
self.waker = locks.Event()
self.stopped = locks.Event()
self.waker = asyncio.Event()
self.please_stop = False
self.buffer = []
self.comm = None
self.name = name
self.message_count = 0
self.batch_count = 0
self.byte_count = 0
@@ -55,11 +49,20 @@ def __init__(self, loop=None, serializers=None):
maxlen=dask.config.get("distributed.comm.recent-messages-log-length")
)
self.serializers = serializers
self._consecutive_failures = 0
self._background_task = None

def start(self, comm):
if self._background_task and not self._background_task.done():
raise RuntimeError("Background task still running")
Review comment (collaborator):
Suggested change
raise RuntimeError("Background task still running")
raise RuntimeError(f"Background task still running for {self!r}")

self.please_stop = False
Review comment (collaborator):
If the BatchedSend isn't fully stopped, it's possible for restarting to create a second _background_send coroutine racing against the first. I would imagine this is a problem, not certain though.

  • If you call BatchedSend.abort() and then BatchedSend.start() without a (long enough) await in between
  • If you call BatchedSend.start() and the BatchedSend has just aborted, but the _background_send coroutine is currently waiting on self.waker.wait or self.comm.write.
  • If you just (accidentally) call BatchedSend.start() multiple times in a row when it hasn't been stopped already.

Again, this is the sort of thing I'd like to be checking for and validating against. Even if we don't currently have code that can trigger this error condition, it's all too easy for future maintainers to use the API in ways we don't currently expect, and then subtly break things in ways that are hard to debug and cause significant pain for users (just like this issue was). See #5481 (comment).

I don't think we can validate this properly without having a handle on the _background_send coroutine to check whether it's still running. We also need to decide what the desired behavior is if you call start() while the coroutine is still running and BatchedSend isn't closed yet, or if it's currently trying to close. Do we just error? Or do we need start() to be idempotent as long as it's passed the same comm each time?
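For illustration, a minimal standalone sketch (the names mirror the diff; everything else is a stand-in) of how holding a handle to the background task lets start() refuse to spawn a second, racing loop:

import asyncio

class LoopOwner:
    # Stand-in for BatchedSend; only the start/stop machinery is sketched.
    def __init__(self):
        self._background_task = None
        self.please_stop = False

    async def _background_send(self):
        while not self.please_stop:
            await asyncio.sleep(0.01)  # placeholder for the real send loop

    def start(self):
        # The guard from the diff: refuse to start while a loop is alive.
        if self._background_task and not self._background_task.done():
            raise RuntimeError("Background task still running")
        self.please_stop = False
        self._background_task = asyncio.create_task(self._background_send())

async def main():
    owner = LoopOwner()
    owner.start()
    try:
        owner.start()  # second start while the first loop is alive
    except RuntimeError as e:
        print(e)
    owner.please_stop = True
    await owner._background_task
    owner.start()  # fine again once the first task has finished
    owner.please_stop = True
    await owner._background_task

asyncio.run(main())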

self.waker.set()
self.next_deadline = None
self.comm = comm
self.loop.add_callback(self._background_send)

self._background_task = asyncio.create_task(
self._background_send(),
name=f"background-send-{self.name}",
)

def closed(self):
return self.comm and self.comm.closed()
@@ -72,13 +75,15 @@ def __repr__(self):

__str__ = __repr__
Review comment (collaborator):
Can you update __repr__ to include self.name?
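A possible shape for that, as a hypothetical standalone sketch rather than code from this PR:

class ReprSketch:
    # Stand-in showing a __repr__ that surfaces the new `name` attribute.
    def __init__(self, name, buffer):
        self.name = name
        self.buffer = buffer

    def __repr__(self):
        return f"<BatchedSend {self.name!r}: {len(self.buffer)} in buffer>"

print(ReprSketch("Client", ["msg"]))  # <BatchedSend 'Client': 1 in buffer>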


@gen.coroutine
def _background_send(self):
async def _background_send(self):
while not self.please_stop:
try:
yield self.waker.wait(self.next_deadline)
timeout = None
if self.next_deadline:
timeout = self.next_deadline - time()
await asyncio.wait_for(self.waker.wait(), timeout=timeout)
self.waker.clear()
except gen.TimeoutError:
except asyncio.TimeoutError:
pass
if not self.buffer:
# Nothing to send
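One subtlety in the hunk above: tornado's locks.Event.wait accepts an absolute deadline, while asyncio.wait_for takes a relative timeout, hence the next_deadline - time() conversion. A self-contained sketch of the asyncio pattern, with stand-in names:

import asyncio
from time import time
from typing import Optional

async def wait_until(event: asyncio.Event, deadline: Optional[float]) -> bool:
    # Absolute deadline -> relative timeout; a deadline already in the
    # past yields a non-positive timeout and times out immediately.
    timeout = None if deadline is None else deadline - time()
    try:
        await asyncio.wait_for(event.wait(), timeout=timeout)
        return True   # woken explicitly
    except asyncio.TimeoutError:
        return False  # deadline hit, like the `pass` branch above

async def main():
    print(await wait_until(asyncio.Event(), time() + 0.05))  # False

asyncio.run(main())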
@@ -90,8 +95,9 @@ def _background_send(self):
payload, self.buffer = self.buffer, []
self.batch_count += 1
self.next_deadline = time() + self.interval

try:
nbytes = yield self.comm.write(
nbytes = await self.comm.write(
payload, serializers=self.serializers, on_error="raise"
)
if nbytes < 1e6:
@@ -100,7 +106,12 @@ def _background_send(self):
self.recent_message_log.append("large-message")
self.byte_count += nbytes
except CommClosedError:
logger.info("Batched Comm Closed %r", self.comm, exc_info=True)
logger.info(
"Batched Comm Closed %r. Lost %s messages.",
self.comm,
len(payload),
exc_info=True,
)
break
except Exception:
# We cannot safely retry self.comm.write, as we have no idea
@@ -115,7 +126,6 @@ def _background_send(self):
payload = None # lose ref
else:
# nobreak. We've been gracefully closed.
self.stopped.set()
return

# If we've reached here, it means `break` was hit above and
@@ -125,50 +135,43 @@ def _background_send(self):
# This means that any messages in our buffer are lost.
# To propagate exceptions, we rely on subsequent `BatchedSend.send`
# calls to raise CommClosedErrors.
Comment on lines 136 to 137
Review comment (collaborator):
Suggested change
# To propagate exceptions, we rely on subsequent `BatchedSend.send`
# calls to raise CommClosedErrors.
# The exception will not be propagated. Instead, users of `BatchedSend` are expected
# to be implementing explicit reconnection logic when the comm closes. Reconnection often
# involves application logic reconciling state (because messages buffered on the
# `BatchedSend` may have been lost), then calling `start` again with a new comm object.

This is no longer true.
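For context, a rough sketch of the reconnection pattern the suggested comment describes. Here connect and reconcile_state are hypothetical stand-ins for application logic, not APIs from this PR:

async def reconnect(bcomm, address, connect, reconcile_state):
    # Close first so the old background task has fully finished and
    # start() cannot race it (see the start() guard discussed above).
    await bcomm.close()
    # Messages buffered on bcomm may have been lost; the application
    # reconciles its own state before resuming the stream.
    await reconcile_state()
    bcomm.start(await connect(address))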

Review comment (collaborator):
Also, are we confident nothing was relying on this behavior? Some places where we might be:

Probably important:

def worker_send(self, worker, msg):
"""Send message to worker
This also handles connection failures by adding a callback to remove
the worker on the next cycle.
"""
stream_comms: dict = self.stream_comms
try:
stream_comms[worker].send(msg)
except (CommClosedError, AttributeError):
self.loop.add_callback(
self.remove_worker,
address=worker,
stimulus_id=f"worker-send-comm-fail-{time()}",
)

for worker, msgs in worker_msgs.items():
try:
w = stream_comms[worker]
w.send(*msgs)
except KeyError:
# worker already gone
pass
except (CommClosedError, AttributeError):
self.loop.add_callback(
self.remove_worker,
address=worker,
stimulus_id=f"send-all-comm-fail-{time()}",
)

try:
self.scheduler.client_comms[c].send(
{"op": "pubsub-msg", "name": name, "msg": msg}
)
except (KeyError, CommClosedError):
self.remove_subscriber(name=name, client=c)

Less important:

try:
c.send(msg)
# logger.debug("Scheduler sends message to client %s", msg)
except CommClosedError:
if self.status == Status.running:
logger.critical(
"Closed comm %r while trying to write %s", c, msg, exc_info=True
)

def client_send(self, client, msg):
"""Send message to client"""
client_comms: dict = self.client_comms
c = client_comms.get(client)
if c is None:
return
try:
c.send(msg)
except CommClosedError:
if self.status == Status.running:
logger.critical(
"Closed comm %r while trying to write %s", c, msg, exc_info=True
)

try:
c.send(*msgs)
except CommClosedError:
if self.status == Status.running:
logger.critical(
"Closed comm %r while trying to write %s",
c,
msgs,
exc_info=True,
)

def add_worker(self, worker=None, **kwargs):
ident = self.scheduler.workers[worker].identity()
del ident["metrics"]
del ident["last_seen"]
try:
self.bcomm.send(["add", {"workers": {worker: ident}}])
except CommClosedError:
self.scheduler.remove_plugin(name=self.name)

except CommClosedError:
logger.info("Worker comm %r closed while stealing: %r", victim, ts)
return "comm-closed"

I have a feeling that removing these "probably important" code paths is actually a good thing (besides PubSub). It centralizes the logic for handling worker disconnects into handle_worker. If any of these places were previously raising CommClosedError, handle_worker should already have been calling remove_worker in its finally statement. Now we're not duplicating that.

self.stopped.set()
self.abort()

def send(self, *msgs: dict) -> None:
"""Schedule a message for sending to the other side

This completes quickly and synchronously
"""
if self.comm is not None and self.comm.closed():
raise CommClosedError(f"Comm {self.comm!r} already closed.")

self.message_count += len(msgs)
self.buffer.extend(msgs)
        # Avoid spurious wakeups if possible
        if self.next_deadline is None:
        if self.comm and not self.comm.closed() and self.next_deadline is None:
            self.waker.set()

    @gen.coroutine
    def close(self, timeout=None):
        """Flush existing messages and then close comm

        If set, raises `tornado.util.TimeoutError` after a timeout.
        """
        if self.comm is None:
            return
    async def close(self):
        """Flush existing messages and then close comm"""
        self.please_stop = True
        self.waker.set()
        yield self.stopped.wait(timeout=timeout)
        if not self.comm.closed():

        if self._background_task:
            await self._background_task
Review comment (collaborator):
If the background task failed, we'll raise an exception here instead of trying to flush the buffer. I assume that's okay, but should be documented.

Reply from the PR author:
background task doesn't raise any exceptions
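That claim rests on the structure of _background_send: every awaited call is wrapped in try/except blocks that log and break rather than re-raise, so the task always finishes cleanly. A toy illustration of the consequence for close(), with stand-in code rather than the real loop:

import asyncio

async def background_send_like():
    while True:
        try:
            raise ConnectionError("comm closed")  # stand-in for comm.write
        except ConnectionError:
            break  # swallowed (and logged, in the real code); never re-raised

async def main():
    task = asyncio.create_task(background_send_like())
    await task               # completes normally; nothing to catch
    print(task.exception())  # None

asyncio.run(main())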


if self.comm and not self.comm.closed():
try:
if self.buffer:
self.buffer, payload = [], self.buffer
yield self.comm.write(
await self.comm.write(
payload, serializers=self.serializers, on_error="raise"
)
except CommClosedError:
pass
yield self.comm.close()
await self.comm.close()

def abort(self):
if self.comm is None:
return
self.please_stop = True
self.buffer = []
self.waker.set()
if not self.comm.closed():
if self.comm and not self.comm.closed():
self.comm.abort()
8 changes: 2 additions & 6 deletions distributed/client.py
@@ -1287,7 +1287,7 @@ async def _ensure_connected(self, timeout=None):
if msg[0].get("warning"):
warnings.warn(version_module.VersionMismatchWarning(msg[0]["warning"]))

bcomm = BatchedSend(interval="10ms", loop=self.loop)
bcomm = BatchedSend(interval="10ms", name="Client")
bcomm.start(comm)
self.scheduler_comm = bcomm
if self._set_as_default:
@@ -1533,11 +1533,7 @@ async def _close(self, fast=False):
with suppress(asyncio.CancelledError, TimeoutError):
await asyncio.wait_for(asyncio.shield(handle_report_task), 0.1)

if (
self.scheduler_comm
and self.scheduler_comm.comm
and not self.scheduler_comm.comm.closed()
):
if self.scheduler_comm:
await self.scheduler_comm.close()

for key in list(self.futures):
46 changes: 28 additions & 18 deletions distributed/scheduler.py
@@ -3501,7 +3501,7 @@ async def close(self, fast=False, close_workers=False):
await future

for comm in self.client_comms.values():
comm.abort()
await comm.close()
Comment on lines 3503 to +3504
Review comment (collaborator):
Should this be an asyncio.gather instead of serial?
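A sketch of that concurrent variant, runnable standalone with a fake comm; return_exceptions keeps one failing close from aborting the rest:

import asyncio

class FakeComm:
    async def close(self):
        await asyncio.sleep(0.01)  # stand-in for BatchedSend.close()

async def close_all(comms):
    await asyncio.gather(*(c.close() for c in comms), return_exceptions=True)

asyncio.run(close_all([FakeComm() for _ in range(3)]))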


await self.rpc.close()

@@ -3641,7 +3641,7 @@ def heartbeat_worker(
@log_errors
async def add_worker(
self,
comm=None,
comm,
*,
address: str,
status: str,
@@ -3733,7 +3733,10 @@ async def add_worker(
# for key in keys: # TODO
# self.mark_key_in_memory(key, [address])

self.stream_comms[address] = BatchedSend(interval="5ms", loop=self.loop)
self.stream_comms[address] = BatchedSend(
interval="5ms",
name=f"Scheduler->Worker-{address}",
)

if ws.nthreads > len(ws.processing):
self.idle[ws.address] = ws
@@ -3766,7 +3769,19 @@ async def add_worker(
nbytes=nbytes[key],
typename=types[key],
)
recommendations, client_msgs, worker_msgs = t
recommendations, new_cmsgs, new_wmsgs = t
for c, new_msgs in new_cmsgs.items():
msgs = client_msgs.get(c) # type: ignore
if msgs is not None:
msgs.extend(new_msgs)
else:
client_msgs[c] = new_msgs
for w, new_msgs in new_wmsgs.items():
msgs = worker_msgs.get(w) # type: ignore
if msgs is not None:
msgs.extend(new_msgs)
else:
worker_msgs[w] = new_msgs
self._transitions(
recommendations, client_msgs, worker_msgs, stimulus_id
)
@@ -4305,7 +4320,7 @@ async def remove_worker(self, address, stimulus_id, safe=False, close=True):

logger.info("Remove worker %s", ws)
if close:
with suppress(AttributeError, CommClosedError):
with suppress(AttributeError):
self.stream_comms[address].send({"op": "close", "report": False})

self.remove_resources(address)
@@ -4319,6 +4334,7 @@ async def remove_worker(self, address, stimulus_id, safe=False, close=True):
del self.host_info[host]

self.rpc.remove(address)
await self.stream_comms[address].close()
del self.stream_comms[address]
Comment on lines +4337 to 4338
Review comment (collaborator):
Suggested change
await self.stream_comms[address].close()
del self.stream_comms[address]
bcomm = self.stream_comms.pop(address)
await bcomm.close()

I think what you had exposes us to a race condition where:

  1. We start closing the current BatchedSend instance
  2. Same worker reconnects. The event loop runs add_worker before the await close() here finishes
  3. add_worker overwrites self.stream_comms[address] with a different (fresh) BatchedSend instance
  4. await here comes back and we delete the new instance (without even closing it), which other code is expecting to exist in the stream_comms dict

Additionally, this block should probably be moved to right before the plugins section, since that's the only other await in this function, and the plugins section should be moved to L4396. I don't like an await mixed into all this state updating. It means we're giving up control in a half-removed state, where for example stream_comms[address] doesn't exist, but workers[address] does. Remember, many other things can run whenever there's an await, not just the one thing you called.
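A minimal demonstration of why popping before awaiting avoids the race, with hypothetical stand-ins for the scheduler's dict and the close call:

import asyncio

stream_comms = {"tcp://worker-1": "old-bcomm"}

async def remove_worker_safe(address):
    # Remove the entry *before* yielding control: a worker that
    # reconnects while we are closing gets a fresh stream_comms slot
    # that this coroutine will neither close nor delete.
    bcomm = stream_comms.pop(address)
    await asyncio.sleep(0)  # stand-in for `await bcomm.close()`

asyncio.run(remove_worker_safe("tcp://worker-1"))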

Review comment (collaborator):
Also, similar fix should be made to

await self.client_comms[client].close()
del self.client_comms[client]

del self.aliases[ws.name]
self.idle.pop(ws.address, None)
@@ -4673,7 +4689,7 @@ async def add_client(
logger.exception(e)

try:
bcomm = BatchedSend(interval="2ms", loop=self.loop)
bcomm = BatchedSend(interval="2ms", name="Scheduler->Client")
bcomm.start(comm)
self.client_comms[client] = bcomm
msg = {"op": "stream-start"}
@@ -4914,7 +4930,7 @@ async def handle_worker(self, comm=None, worker=None, stimulus_id=None):
await self.handle_stream(comm=comm, extra={"worker": worker})
finally:
if worker in self.stream_comms:
worker_comm.abort()
await worker_comm.close()
Review comment (collaborator):
Suggested change
await worker_comm.close()

remove_worker is already going to do this for us, and more cleanly

await self.remove_worker(address=worker, stimulus_id=stimulus_id)

def add_plugin(
Expand Down Expand Up @@ -5021,13 +5037,7 @@ def client_send(self, client, msg):
c = client_comms.get(client)
if c is None:
return
try:
c.send(msg)
except CommClosedError:
if self.status == Status.running:
logger.critical(
"Closed comm %r while trying to write %s", c, msg, exc_info=True
)
c.send(msg)

def send_all(self, client_msgs: dict, worker_msgs: dict):
"""Send messages to client and workers"""
@@ -5057,7 +5067,7 @@ def send_all(self, client_msgs: dict, worker_msgs: dict):
except KeyError:
# worker already gone
pass
except (CommClosedError, AttributeError):
except AttributeError:
self.loop.add_callback(
self.remove_worker,
address=worker,
@@ -7701,7 +7711,7 @@ class WorkerStatusPlugin(SchedulerPlugin):
name = "worker-status"

def __init__(self, scheduler, comm):
self.bcomm = BatchedSend(interval="5ms")
self.bcomm = BatchedSend(interval="5ms", name="WorkerStatus")
self.bcomm.start(comm)

self.scheduler = scheduler
@@ -7722,8 +7732,8 @@ def remove_worker(self, worker=None, **kwargs):
except CommClosedError:
self.scheduler.remove_plugin(name=self.name)

def teardown(self):
self.bcomm.close()
async def close(self) -> None:
await self.bcomm.close()
Review comment (collaborator):
I believe teardown was never called before (it's not part of the SchedulerPlugin interface).

bcomm.close() is going to close the underlying comm object. I believe that's a comm to a client? Is that something WorkerStatusPlugin should actually be closing unilaterally, or should add_client for example be responsible for that here

finally:
if not comm.closed():
self.client_comms[client].send({"op": "stream-closed"})
try:
if not sys.is_finalizing():
await self.client_comms[client].close()
del self.client_comms[client]
if self.status == Status.running:
logger.info("Close client connection: %s", client)

Just want to point out that you're exposing a previously-dead codepath, which has some questionable behavior regardless of your changes in this PR.

I think it'll be okay as things stand right now (only because close won't be called in practice until the cluster is shutting down), but it feels a little weird. I don't get the sense that multiple BatchedSends are meant to be multiplexed onto a single comm object, which is what WorkerStatusPlugin is doing. Multiplexing the send side is probably okay, it's just odd to multiplex the close side.



class CollectTaskMetaDataPlugin(SchedulerPlugin):
41 changes: 38 additions & 3 deletions distributed/tests/test_batched.py
@@ -116,9 +116,6 @@ async def test_send_before_close():
await asyncio.sleep(0.01)
assert time() < start + 5

with pytest.raises(CommClosedError):
b.send("123")


@gen_test()
async def test_close_closed():
@@ -252,3 +249,41 @@ async def test_serializers():
assert "function" in value

assert comm.closed()


@gen_test()
async def test_restart():
async with EchoServer() as e:
comm = await connect(e.address)

b = BatchedSend(interval="2ms")
b.start(comm)
b.send(123)
assert await comm.read() == (123,)
await b.close()
assert b.closed()

# We can buffer stuff even while it is closed
b.send(345)

new_comm = await connect(e.address)
b.start(new_comm)

assert await new_comm.read() == (345,)
await b.close()
assert new_comm.closed()


@gen_test()
async def test_restart_fails_if_still_running():
async with EchoServer() as e:
comm = await connect(e.address)

b = BatchedSend(interval="2ms")
b.start(comm)
with pytest.raises(RuntimeError):
b.start(comm)

b.send(123)
assert await comm.read() == (123,)
await b.close()