Skip to content

Commit c68a459

Browse files
committed
bind implementation
1 parent 6589c18 commit c68a459

File tree

1 file changed

+38
-28
lines changed

1 file changed

+38
-28
lines changed

dask/graph_manipulation.py

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -199,32 +199,43 @@ def bind(
199199

200200
omit, _ = unpack_collections(omit)
201201
if assume_layers:
202-
bind_one = _bind_one_by_layers
203-
omit_set = {layer for coll in omit for layer in coll.__dask_layers__()}
202+
omit_layers = {layer for coll in omit for layer in coll.__dask_layers__()}
203+
omit_keys = set()
204204
else:
205205
# One or more collections do not use HighLevelGraph; downgrade all graphs to
206206
# plain dicts
207-
bind_one = _bind_one_by_keys
208-
omit_set = {key for coll in omit for key in ensure_dict(coll.__dask_graph__())}
207+
omit_layers = set()
208+
omit_keys = {key for coll in omit for key in ensure_dict(coll.__dask_graph__())}
209209

210210
return repack(
211-
[bind_one(child, blocker, omit_set, seed) for child in unpacked_children]
211+
[
212+
_bind_one(child, blocker, omit_layers, omit_keys, seed)
213+
for child in unpacked_children
214+
]
212215
)
213216

214217

215-
def _bind_one_by_layers(
216-
child: T, blocker: Optional[Delayed], omit_layers: Set[str], seed: Hashable
218+
def _bind_one(
219+
child: T,
220+
blocker: Optional[Delayed],
221+
omit_layers: Set[str],
222+
omit_keys: Set[str],
223+
seed: Hashable,
217224
) -> T:
218225
try:
219-
old_name = get_collection_name(child)
226+
name = get_collection_name(child)
220227
except KeyError:
221228
return child # Collection with no keys; e.g. Array of size 0
222229

223230
dsk = child.__dask_graph__()
224-
assert isinstance(dsk, HighLevelGraph)
231+
if isinstance(dsk, HighLevelGraph):
232+
new_layers = dict(dsk.layers)
233+
new_deps = dict(dsk.dependencies)
234+
else:
235+
# Squash layers, if any
236+
new_layers = {name: ensure_dict(dsk)}
237+
new_deps = {name: set()}
225238

226-
new_layers = dict(dsk.layers)
227-
new_deps = dict(dsk.dependencies)
228239
if blocker is not None:
229240
blocker_key = blocker.key
230241
blocker_dsk = blocker.__dask_graph__()
@@ -234,7 +245,10 @@ def _bind_one_by_layers(
234245
else:
235246
blocker_key = None
236247

237-
to_visit = set(child.__dask_layers__())
248+
try:
249+
to_visit = set(child.__dask_layers__())
250+
except AttributeError:
251+
to_visit = {name}
238252

239253
while to_visit:
240254
name = to_visit.pop()
@@ -246,33 +260,29 @@ def _bind_one_by_layers(
246260
deps_to_visit = deps - omit_layers
247261
to_visit |= deps_to_visit
248262

249-
bind_before_keys = {
263+
# If assume_layers=True (the default), the stop_before_keys set will be
264+
# non-empty only in the layers immediately above the top layers of the
265+
# collections passed through the 'omit' parameter.
266+
# If assume_layers=False, it will instead always be populated with all the keys
267+
# in the graphs of the collections passed through the 'omit' parameter.
268+
stop_before_keys = {
250269
key for dep in deps_to_omit for key in new_layers[dep].get_external_keys()
251-
}
252-
new_layers[new_name] = layer.bind(
253-
to_key=blocker_key, before_keys=bind_before_keys, seed=seed
270+
} | omit_keys
271+
new_layers[new_name] = layer.clone(
272+
stop_before=stop_before_keys, bind_to=blocker_key, seed=seed
254273
)
255274
new_deps[new_name] = {
256275
clone_key(dep, seed=seed) for dep in deps_to_visit
257276
} | deps_to_omit
277+
if deps_to_omit and blocker_key:
278+
new_deps[new_name].add(blocker_key)
258279

259280
rebuild, args = child.__dask_postpersist__()
260281
return rebuild(
261-
HighLevelGraph(new_layers, new_deps), *args, name=clone_key(old_name, seed)
282+
HighLevelGraph(new_layers, new_deps), *args, name=clone_key(name, seed)
262283
)
263284

264285

265-
def _bind_one_by_keys(
266-
child: T, blocker: Optional[Delayed], omit_keys: set, seed: Hashable
267-
) -> T:
268-
dsk = ensure_dict(child.__dask_graph__())
269-
270-
# TODO
271-
272-
rebuild, args = child.__dask_postpersist__()
273-
return rebuild(dsk, *args, name="bind-" + tokenize(child, seed))
274-
275-
276286
def clone(*collections, omit=None, seed: Hashable = None, assume_layers: bool = True):
277287
"""Clone dask collections, returning equivalent collections that are generated from
278288
independent calculations.

0 commit comments

Comments
 (0)