@@ -199,32 +199,43 @@ def bind(
199199
200200 omit , _ = unpack_collections (omit )
201201 if assume_layers :
202- bind_one = _bind_one_by_layers
203- omit_set = { layer for coll in omit for layer in coll . __dask_layers__ ()}
202+ omit_layers = { layer for coll in omit for layer in coll . __dask_layers__ ()}
203+ omit_keys = set ()
204204 else :
205205 # One or more collections do not use HighLevelGraph; downgrade all graphs to
206206 # plain dicts
207- bind_one = _bind_one_by_keys
208- omit_set = {key for coll in omit for key in ensure_dict (coll .__dask_graph__ ())}
207+ omit_layers = set ()
208+ omit_keys = {key for coll in omit for key in ensure_dict (coll .__dask_graph__ ())}
209209
210210 return repack (
211- [bind_one (child , blocker , omit_set , seed ) for child in unpacked_children ]
211+ [
212+ _bind_one (child , blocker , omit_layers , omit_keys , seed )
213+ for child in unpacked_children
214+ ]
212215 )
213216
214217
215- def _bind_one_by_layers (
216- child : T , blocker : Optional [Delayed ], omit_layers : Set [str ], seed : Hashable
218+ def _bind_one (
219+ child : T ,
220+ blocker : Optional [Delayed ],
221+ omit_layers : Set [str ],
222+ omit_keys : Set [str ],
223+ seed : Hashable ,
217224) -> T :
218225 try :
219- old_name = get_collection_name (child )
226+ name = get_collection_name (child )
220227 except KeyError :
221228 return child # Collection with no keys; e.g. Array of size 0
222229
223230 dsk = child .__dask_graph__ ()
224- assert isinstance (dsk , HighLevelGraph )
231+ if isinstance (dsk , HighLevelGraph ):
232+ new_layers = dict (dsk .layers )
233+ new_deps = dict (dsk .dependencies )
234+ else :
235+ # Squash layers, if any
236+ new_layers = {name : ensure_dict (dsk )}
237+ new_deps = {name : set ()}
225238
226- new_layers = dict (dsk .layers )
227- new_deps = dict (dsk .dependencies )
228239 if blocker is not None :
229240 blocker_key = blocker .key
230241 blocker_dsk = blocker .__dask_graph__ ()
@@ -234,7 +245,10 @@ def _bind_one_by_layers(
234245 else :
235246 blocker_key = None
236247
237- to_visit = set (child .__dask_layers__ ())
248+ try :
249+ to_visit = set (child .__dask_layers__ ())
250+ except AttributeError :
251+ to_visit = {name }
238252
239253 while to_visit :
240254 name = to_visit .pop ()
@@ -246,33 +260,29 @@ def _bind_one_by_layers(
246260 deps_to_visit = deps - omit_layers
247261 to_visit |= deps_to_visit
248262
249- bind_before_keys = {
263+ # If assume_layers=True (the default), the stop_before_keys set will be
264+ # non-empty only in the layers immediately above the top layers of the
265+ # collections passed through the 'omit' parameter.
266+ # If assume_layers=False, it will instead always be populated with the output
267+ # of __dask_keys__ of the collections passed through the 'omit' parameter.
268+ stop_before_keys = {
250269 key for dep in deps_to_omit for key in new_layers [dep ].get_external_keys ()
251- }
252- new_layers [new_name ] = layer .bind (
253- to_key = blocker_key , before_keys = bind_before_keys , seed = seed
270+ } | omit_keys
271+ new_layers [new_name ] = layer .clone (
272+ stop_before = stop_before_keys , bind_to = blocker_key , seed = seed
254273 )
255274 new_deps [new_name ] = {
256275 clone_key (dep , seed = seed ) for dep in deps_to_visit
257276 } | deps_to_omit
277+ if deps_to_omit and blocker_key :
278+ new_deps [new_name ].add (blocker_key )
258279
259280 rebuild , args = child .__dask_postpersist__ ()
260281 return rebuild (
261- HighLevelGraph (new_layers , new_deps ), * args , name = clone_key (old_name , seed )
282+ HighLevelGraph (new_layers , new_deps ), * args , name = clone_key (name , seed )
262283 )
263284
264285
265- def _bind_one_by_keys (
266- child : T , blocker : Optional [Delayed ], omit_keys : set , seed : Hashable
267- ) -> T :
268- dsk = ensure_dict (child .__dask_graph__ ())
269-
270- # TODO
271-
272- rebuild , args = child .__dask_postpersist__ ()
273- return rebuild (dsk , * args , name = "bind-" + tokenize (child , seed ))
274-
275-
276286def clone (* collections , omit = None , seed : Hashable = None , assume_layers : bool = True ):
277287 """Clone dask collections, returning equivalent collections that are generated from
278288 independent calculations.
0 commit comments