Skip to content

Commit 16fbb10

Browse files
committed
Implement name=
1 parent d0f0348 commit 16fbb10

File tree

4 files changed

+37
-7
lines changed

4 files changed

+37
-7
lines changed

dask/bag/core.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,14 @@
4040
from ..base import tokenize, dont_optimize, DaskMethodsMixin
4141
from ..bytes import open_files
4242
from ..context import globalmethod
43-
from ..core import quote, istask, get_dependencies, reverse_dict, flatten
43+
from ..core import (
44+
quote,
45+
istask,
46+
get_dependencies,
47+
replace_name_in_key,
48+
reverse_dict,
49+
flatten,
50+
)
4451
from ..sizeof import sizeof
4552
from ..delayed import Delayed, unpack_collections
4653
from ..highlevelgraph import HighLevelGraph
@@ -365,7 +372,9 @@ def __dask_postpersist__(self):
365372
return self._rebuild, ()
366373

367374
def _rebuild(self, dsk, name=None):
368-
key = change_name_in_key(self.key, name) if name else self.key
375+
key = self.key
376+
if name:
377+
key = replace_name_in_key(key, name)
369378
return Item(dsk, key)
370379

371380
@staticmethod

dask/core.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,3 +503,21 @@ def quote(x):
503503
if istask(x) or type(x) is list or type(x) is dict:
504504
return (literal(x),)
505505
return x
506+
507+
508+
def replace_name_in_key(key, name: str):
509+
"""Given a dask key, which must be either a single string or a tuple whose first
510+
element is a string (commonly referred to as 'name'), replace the name with a new
511+
one.
512+
513+
e.g.::
514+
515+
>>> replace_name_in_key("foo", "bar")
516+
"bar"
517+
>>> replace_name_in_key(("foo", 1, 2), "bar") -> ("bar", 1, 2)
518+
"""
519+
if isinstance(key, str):
520+
return name
521+
if isinstance(key, tuple) and key and isinstance(key[0], str):
522+
return (name,) + key[1:]
523+
raise TypeError(f"Expected str or tuple[str, Hashable, ...]; got {key}")

dask/delayed.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from .base import tokenize as _tokenize
1212
from .compatibility import is_dataclass, dataclass_fields
1313

14-
from .core import quote
14+
from .core import replace_name_in_key, quote
1515
from .context import globalmethod
1616
from .optimization import cull
1717
from .utils import funcname, methodcaller, OperatorMethodMixin, ensure_dict, apply
@@ -512,7 +512,9 @@ def __dask_postpersist__(self):
512512
return self._rebuild, ()
513513

514514
def _rebuild(self, dsk, name=None):
515-
key = change_name_in_key(self.key, name) if name else self.key
515+
key = self.key
516+
if name:
517+
key = replace_name_in_key(key, name)
516518
return Delayed(key, dsk, getattr(self, "_length", None))
517519

518520
def __getstate__(self):

dask/graphsurgery.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
from .array import Array
66
from .bag import Bag
77
from .base import tokenize, unpack_collections
8+
from .blockwise import blockwise
89
from .core import flatten
910
from .dataframe import DataFrame, Series
1011
from .delayed import Delayed, delayed
11-
from .highlevelgraph import BasicLayer, Layer, HighLevelGraph
12-
from .blockwise import blockwise
12+
from .highlevelgraph import BasicLayer, HighLevelGraph, Layer
13+
from .utils import ensure_dict
1314

1415
__all__ = ("bind", "choke", "clone", "drop")
1516

@@ -206,7 +207,7 @@ def _bind_one_by_layers(
206207
def _bind_one_by_keys(
207208
child: T, blocker: Optional[Delayed], omit_keys: set, seed: Hashable
208209
) -> T:
209-
dsk = dict(child.__dask_graph__())
210+
dsk = ensure_dict(child.__dask_graph__())
210211

211212
# TODO
212213

0 commit comments

Comments
 (0)