(Worker) State Machine determinism and replayability #5736

@fjetter

Description

The dask scheduling logic on both the scheduler and the worker side uses the
model of a finite state machine to calculate decisions. A few definitions about
finite state machines first:

  • A deterministic finite state machine is a mathematical model of
    computation describing an abstract machine with a finite number of
    distinct states. Given a stimulus S_i and a state W_i, there is a function
    F such that the new state W_i+1 can be calculated as F(W_i, S_i) -> W_i+1
  • The only way to change the state W_i is to apply the transformation F with a
    stimulus S_i
  • Given an initial state W_0 and the sequence of all stimuli S_i, it
    is possible to reconstruct any state W_i by applying the transformation F
    sequentially for all i (see the sketch below)
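
In code, this third property is just a left fold of F over the recorded
stimuli. A minimal sketch, where F, the initial state and the stimuli are
placeholders for the worker machinery described below:

from functools import reduce

def replay(F, W_0, stimuli):
    # Apply F sequentially: F(...F(F(W_0, S_0), S_1)..., S_i-1) -> W_i
    return reduce(F, stimuli, W_0)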

How does this model apply to us and where do we violate it?

Most of these arguments can be made for the scheduler as well but we'll restrict
ourselves to the Worker to keep the scope contained.

The worker state is defined primarily by the Worker.tasks dictionary, which
contains TaskState objects with various task-specific attributes. On top of
this, there are a few Worker attributes which hold global or remote-worker
specific state. A few examples are Worker.data_needed, Worker.has_what,
Worker.pending_data_per_worker and Worker.ready, but the list goes on. We
currently do not properly separate these state machine attributes from the
server / networking / other code.

The function F is a bit more difficult to define. Naively one would expect
this to be Worker.transitions, but this is not the case since it does not
accept stimuli. Worker.transitions, in short T, accepts a set of already
made decisions we call recommendations. The recommendations are generated by
stimulus handlers H, like Worker.handle_compute_task or
Worker.handle_free_keys. Therefore, to define the state transition function we
need a combination of H and T, M ~ T * H, such that W_i+1 = M(W_i, S_i) = T(W_i, H(W_i, S_i)).
Our implementation of handlers introduces a certain ambiguity since it is not
entirely clear whether a piece of logic should reside on the side of the
handler or of the transition function. However, every decision should be the
result of the stimulus and the current state, such that, given all stimuli in
order and the initial state, we can reconstruct every iteration.
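
Written as a minimal sketch, where handle_stimulus stands in for the handlers H
and transitions for T (unlike the real Worker, which mutates its state in
place, this sketch returns the new state for clarity):

def M(W, S):
    # H: turn the stimulus into a set of recommendations based on the current state
    recommendations = handle_stimulus(W, S)
    # T: apply the recommendations, yielding the next state W_i+1
    return transitions(W, recommendations)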

There are three (actually four) places where this pattern is violated and where
the stimulus generation is not only tightly coupled to the handling and the
transition itself but also to asynchronous actors.

Specifically, dependency gathering (Worker.gather_dep), but also to a softer
extent task execution (Worker.execute), break the above pattern since they
simultaneously generate stimuli and handle them while interacting with an
async actor (i.e. a remote worker or the threadpool). There is no way to
inspect, assert on or pause the state machine naturally. This prohibits writing
effective tests, increases instability and renders deterministic replayability
impossible.

Worse even are the ensure_communicating and ensure_computing methods which
are triggered in various places of our code to work off the queues of
ready-to-compute / ready-to-be-fetched tasks. This pattern effectively delays
state transitions and performs a set of these transitions in bulk, which is
beneficial to dependency fetching. However, they are called recursively
(actually rather pseudo-recursively, in the sense of
ensure_communicating -> async gather_dep -> ensure_communicating).

Pseudo code below

def stimulus_XY(...):
    # First decision: *What* to fetch

    transition(ts, "released->fetch")
    data_needed.push(ts)

def ensure_communicating():
    # Second decision: Is there *capacity* to fetch?
    while is_gather_channel_available():
        ts = data_needed.pop()
        transition(ts, "fetch->flight")
        loop.add_callback(gather_dep, ts)


async def gather_dep(ts, ...):
    try:
        if not assumptions_still_valid(ts):
            # Another decision might have cancelled this fetch already
            return
        response = await fetch_data()

        if response.get('busy'):
            await sleep(a_while)
            retry_fetch(ts)
        elif response.get('error'):
            flag_remote_dead(worker)
            reschedule_fetch(ts)
        elif response.get('data'):
            transition(ts, "memory")
        else:
            # Stale who_has information: remove who_has / trigger find-missing
            stale_information(ts, worker)
    finally:
        # Pseudo-recursively loop back into ensure_communicating
        ensure_communicating()

Problems

  • The knowledge about the capacity to fetch is not encoded in the state
    machine. This requires us to run the infinite callback ensure_* to check
    periodically whether something changed that would allow us to fetch a new
    dependency. Instead, this should be a stimulus to the state machine as
    well, since whatever frees up capacity is always triggered by a stimulus
    itself (see the sketch after this list)
  • There is no way to stop the state machine and freeze it in a specific
    iteration i since it is always moving. We frequently need to deal with an
    intermediate state (one of the motivations for the states resumed and
    cancelled)
  • Dealing with intermediate state exposes us to many edge cases to consider,
    ultimately resulting in an unstable implementation
  • There is currently no way to know exactly what the stimuli were. While we're
    logging their outcome as part of the transition log, we can sometimes only
    guess what the trigger was. For some stimulus handlers we do log the
    incoming stimulus as part of the transition log, e.g. the only stimulus body
    of handle_free_keys is a list of keys, which is what we append to our log.
    For more complicated handlers like handle_compute_task we do not do this. If
    we start doing this, we should ensure not to log too much information and
    restrict the logged info to what is relevant to the state machine, e.g.
    TaskState.runspec is not relevant to the state machine and we should
    therefore not remember it, to reduce the memory footprint.
  • By not being able to define the state at a given time i, it is impossible to
    write proper deterministic tests for the state machine. Instead, we rely on
    sophisticated mocking to construct very timing-sensitive intermediate states.
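
To make the first point concrete, here is a rough sketch of how the fetch
capacity could live inside the state machine and change only through stimuli,
removing the need for the ensure_communicating polling loop. All names
(in_flight_requests, total_out_connections, handle_gather_dep_done,
maybe_fetch_more) are illustrative, not the existing Worker API:

def handle_gather_dep_done(state, stimulus):
    # Inbound stimulus: a gather request finished, i.e. a fetch slot is free again
    state.in_flight_requests -= 1
    return maybe_fetch_more(state)

def maybe_fetch_more(state):
    # Capacity is part of the state, so this is a plain function of the state
    # instead of a periodically triggered ensure_* callback. The returned
    # recommendations are then applied by the transition engine T.
    recommendations = {}
    while state.in_flight_requests < state.total_out_connections and state.data_needed:
        ts = state.data_needed.pop()
        recommendations[ts] = "flight"
        state.in_flight_requests += 1
    return recommendations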

Given that the state machine is extended sufficiently with the information to
make the capacity decision, the ensure_* loops can be removed such that the
entire state machine can be calculated deterministically and synchronously.
Every interaction with an asynchronous actor will be mapped to a set of
out-/inbound stimuli. This will allow us to

  • log every stimulus [1] and reconstruct the state at any given time i
  • write specific fakes which interact with the state machine using stimuli,
    e.g.
async def gather_dep(worker: str, keys: Collection[str]) -> list[dict]:
    try:
        response = await fetch_data(worker, keys)
    except Exception:
        return [{"stimulus": "gather_dep_failed", "keys": keys}]

    if response.get("busy"):
        return [{"stimulus": "gather_dep_remote_busy", "keys": keys}]
    return [
        {
            "stimulus": "gather_dep_response",
            "key": key,
            "data": data,
            "nbyte": sizeof(data)
        }
        for key, data in response.get("data", {}).items()
    ]

async def gather_dep_fake(worker: str, keys: Collection[str]) -> list[dict]:
    """This fake will emulate a failure with an exception when connecting to worker A and will
    return an empty result when fetching `missing` from B.
    Otherwise it will return data as expected."""

    # No need to do *anything* asynchronous here. This is only a
    # coroutine function to keep interfaces and signatures clean
    # but ideally there is no IO
    if worker == "A":
        return [{"stimulus": "gather_dep_failed", "keys": keys}]
    elif worker == "B":
        res = []
        for key in keys:
            if key == "missing":
                res.append({
                    "stimulus": "gather_dep_missing",
                    "key": key
                })
            else:
                res.append({
                    "stimulus": "gather_dep_response",
                    "key": key,
                    "data": DummyData(),
                    "nbyte": 42
                })
        return res
    else:
        return [{
            "stimulus": "gather_dep_response",
            "key": key,
            "data": DummyData(),
            "nbyte": 42
        } for key in keys]
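
A test could then drive the state machine with such a fake entirely
synchronously. The names WorkerStateMachine and handle_stimulus below are
hypothetical stand-ins for whatever the refactored interface ends up looking
like:

async def test_gather_dep_missing(ws: "WorkerStateMachine"):
    # ws has been seeded with the tasks "missing" and "x" in state flight on worker B
    stimuli = await gather_dep_fake("B", ["missing", "x"])
    for stimulus in stimuli:
        # handle_stimulus = H followed by T, fully synchronous and deterministic
        ws.handle_stimulus(stimulus)
    assert ws.tasks["missing"].state == "missing"
    assert ws.tasks["x"].state == "memory"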

[1] Logging all stimuli is actually cheaper than logging all transitions since one stimulus usually triggers many, many recommendations and transitions
