151 changes: 105 additions & 46 deletions dask/base.py
@@ -32,7 +32,7 @@
from dask.core import literal, quote
from dask.hashing import hash_buffer_hex
from dask.system import CPU_COUNT
from dask.typing import Key, SchedulerGetCallable
from dask.typing import DaskCollection2, Key, SchedulerGetCallable, TaskGraphFactory
from dask.utils import (
Dispatch,
apply,
@@ -231,6 +231,8 @@ def is_dask_collection(x) -> bool:
implementation of the protocol.

"""
if isinstance(x, DaskCollection2):
return True
if (
isinstance(x, type)
or not hasattr(x, "__dask_graph__")
@@ -414,36 +416,69 @@ def optimization_function(x):
return getattr(x, "__dask_optimize__", dont_optimize)


def collections_to_dsk(collections, optimize_graph=True, optimizations=(), **kwargs):
"""
Convert many collections into a single dask graph, after optimization
"""
from dask.highlevelgraph import HighLevelGraph

optimizations = tuple(optimizations) + tuple(config.get("optimizations", ()))
def newstyle_collections(collections):
from dask.delayed import Delayed

if optimize_graph:
groups = groupby(optimization_function, collections)
is_newstyle = [
isinstance(c, DaskCollection2) and not isinstance(c, Delayed)
for c in collections
]
if any(is_newstyle):
if not all(is_newstyle):
raise RuntimeError("Provided new- and old-style collections.")
if not all(
isinstance(c.__dask_graph_factory__(), TaskGraphFactory)
for c in collections
):
raise TypeError("Newstyle collections must have a TaskGraphFactory graph.")

graphs = []
for opt, val in groups.items():
dsk, keys = _extract_graph_and_keys(val)
dsk = opt(dsk, keys, **kwargs)
return True
return False

for opt_inner in optimizations:
dsk = opt_inner(dsk, keys, **kwargs)

graphs.append(dsk)
def collections_to_dsk(
collections, optimize_graph=True, optimizations=(), **kwargs
) -> TaskGraphFactory:
Comment on lines +439 to +441
Member Author

The code here is still pretty ugly since it requires boilerplate all over for compat purposes.

It probably doesn't make sense to have a switch inside of this function. I'd likely define separate functions, one that only works for newstyle collections and one for oldstyle collections, since we'll need another switch outside as well.
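To make the suggested split concrete, here is a rough sketch under assumed helper names (neither function exists in this PR); the new-style path reduces to combining graph factories and optimizing, while the old-style path would keep the existing HighLevelGraph machinery:

def _newstyle_to_factory(collections):
    # New-style collections expose a graph factory; combine them, then optimize.
    factories = [c.__dask_graph_factory__() for c in collections]
    if len(factories) > 1:
        expr = type(factories[0]).combine_factories(*factories)
    else:
        expr = factories[0]
    return expr.optimize()


def _oldstyle_to_dsk(collections, optimize_graph=True, optimizations=(), **kwargs):
    # Old-style collections would keep the HighLevelGraph path from this diff:
    # group by __dask_optimize__, apply optimizations, merge, wrap with output keys.
    ...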

"""
Convert many collections into a single dask graph, after optimization
"""
from dask.highlevelgraph import HighLevelGraph, TaskFactoryHLGWrapper

# Merge all graphs
if any(isinstance(graph, HighLevelGraph) for graph in graphs):
dsk = HighLevelGraph.merge(*graphs)
if newstyle_collections(collections):
graph_factories = [c.__dask_graph_factory__() for c in collections]
if len(graph_factories) > 1:
expr = type(graph_factories[0]).combine_factories(*graph_factories)
else:
dsk = merge(*map(ensure_dict, graphs))
expr = collections[0]
return expr.optimize()
else:
dsk, _ = _extract_graph_and_keys(collections)

return dsk
optimizations = tuple(optimizations) + tuple(config.get("optimizations", ()))
if optimize_graph:
ext_keys = []
groups = groupby(optimization_function, collections)

graphs = []
for opt, val in groups.items():
dsk, keys = _extract_graph_and_keys(val)
ext_keys.extend(keys)
dsk = opt(dsk, keys, **kwargs)

for opt_inner in optimizations:
dsk = opt_inner(dsk, keys, **kwargs)

graphs.append(dsk)

# Merge all graphs
if any(isinstance(graph, HighLevelGraph) for graph in graphs):
dsk = HighLevelGraph.merge(*graphs)
return TaskFactoryHLGWrapper(dsk, ext_keys)
Member Author

I defined this wrapper just to streamline stuff in distributed. As pointed out in dask/distributed#7942 (comment), we need more than just the graph; we also need the output keys to actually know what to compute.
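For illustration only, a minimal sketch of the idea behind such a wrapper, with made-up names and built only on the public HighLevelGraph API (the actual TaskFactoryHLGWrapper in this PR may differ): it bundles a graph with its output keys so a scheduler can ask one object both for the tasks and for what to compute.

from dask.highlevelgraph import HighLevelGraph

class GraphWithOutputKeys:
    # Hypothetical stand-in for TaskFactoryHLGWrapper: graph plus output keys together.
    def __init__(self, hlg, output_keys):
        self._hlg = hlg
        self._output_keys = list(output_keys)

    @classmethod
    def from_low_level(cls, dsk, output_keys):
        # Wrap a plain {key: task} dict in a single-layer HighLevelGraph.
        hlg = HighLevelGraph.from_collections("graph", dsk, dependencies=())
        return cls(hlg, output_keys)

    def materialize(self):
        # Flatten all layers into a plain dict the scheduler can execute.
        return dict(self._hlg)

    def __dask_output_keys__(self):
        return self._output_keys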

else:
dsk = merge(*map(ensure_dict, graphs))
return TaskFactoryHLGWrapper.from_low_level(dsk, ext_keys)
else:
return TaskFactoryHLGWrapper.from_low_level(
*_extract_graph_and_keys(collections)
)


def _extract_graph_and_keys(vals):
@@ -654,17 +689,27 @@ def compute(
collections=collections,
get=get,
)
if newstyle_collections(collections):
collections = [c.finalize_compute() for c in collections]
graph_factory = collections_to_dsk(collections, optimize_graph, **kwargs)
dsk = graph_factory.materialize()
keys = list(graph_factory.__dask_output_keys__())
return schedule(dsk, keys, **kwargs)
else:
graph_factory = collections_to_dsk(collections, optimize_graph, **kwargs)
from dask.highlevelgraph import TaskFactoryHLGWrapper

dsk = collections_to_dsk(collections, optimize_graph, **kwargs)
keys, postcomputes = [], []
for x in collections:
keys.append(x.__dask_keys__())
postcomputes.append(x.__dask_postcompute__())
assert isinstance(graph_factory, TaskFactoryHLGWrapper)
dsk = graph_factory.materialize()
keys, postcomputes = [], []
for x in collections:
keys.append(x.__dask_keys__())
postcomputes.append(x.__dask_postcompute__())

with shorten_traceback():
results = schedule(dsk, keys, **kwargs)
with shorten_traceback():
results = schedule(dsk, keys, **kwargs)

return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])


def visualize(
@@ -757,7 +802,8 @@ def visualize(
"""
args, _ = unpack_collections(*args, traverse=traverse)

dsk = dict(collections_to_dsk(args, optimize_graph=optimize_graph))
graph_fact = collections_to_dsk(args, optimize_graph=optimize_graph)
dsk = graph_fact.materialize()

return visualize_dsk(
dsk=dsk,
@@ -989,21 +1035,32 @@ def persist(*args, traverse=True, optimize_graph=True, scheduler=None, **kwargs)
collections, optimize_graph=optimize_graph, **kwargs
)
return repack(results)
# FIXME: Both paths needed
if newstyle_collections(collections):
graph_fact = collections_to_dsk(collections, optimize_graph, **kwargs)
dsk = graph_fact.materialize()
keys = graph_fact.__dask_output_keys__()
with shorten_traceback():
results = schedule(dsk, keys, **kwargs)

d = dict(zip(keys, results))
return [c.postpersist(d) for c in collections]
else:
dsk = collections_to_dsk(collections, optimize_graph, **kwargs)
dsk = dsk._hlg
keys, postpersists = [], []
for a in collections:
a_keys = list(flatten(a.__dask_keys__()))
rebuild, state = a.__dask_postpersist__()
keys.extend(a_keys)
postpersists.append((rebuild, a_keys, state))

dsk = collections_to_dsk(collections, optimize_graph, **kwargs)
keys, postpersists = [], []
for a in collections:
a_keys = list(flatten(a.__dask_keys__()))
rebuild, state = a.__dask_postpersist__()
keys.extend(a_keys)
postpersists.append((rebuild, a_keys, state))

with shorten_traceback():
results = schedule(dsk, keys, **kwargs)
with shorten_traceback():
results = schedule(dsk, keys, **kwargs)

d = dict(zip(keys, results))
results2 = [r({k: d[k] for k in ks}, *s) for r, ks, s in postpersists]
return repack(results2)
d = dict(zip(keys, results))
results2 = [r({k: d[k] for k in ks}, *s) for r, ks, s in postpersists]
return repack(results2)


############
@@ -1283,6 +1340,8 @@ def register_numpy():
def normalize_array(x):
if not x.shape:
return (x.item(), x.dtype)
if x.size == 0:
return (None, x.dtype, x.shape)
if hasattr(x, "mode") and getattr(x, "filename", None):
if hasattr(x.base, "ctypes"):
offset = (
6 changes: 3 additions & 3 deletions dask/dataframe/tests/test_merge_column_and_index.py
@@ -103,7 +103,7 @@ def test_merge_known_to_known(
# Assertions
assert_eq(result, expected)
assert_eq(result.divisions, tuple(range(12)))
assert len(result.__dask_graph__()) < 80
assert len(result.optimize().materialize()) < 80


@pytest.mark.parametrize("how", ["inner", "left"])
@@ -121,7 +121,7 @@ def test_merge_known_to_single(
# Assertions
assert_eq(result, expected)
assert result.divisions == ddf_left.divisions
assert len(result.__dask_graph__()) < 30
assert len(result.optimize().materialize()) < 30


@pytest.mark.parametrize("how", ["inner", "right"])
@@ -139,7 +139,7 @@ def test_merge_single_to_known(
# Assertions
assert_eq(result, expected)
assert result.divisions == ddf_right.divisions
assert len(result.__dask_graph__()) < 30
assert len(result.optimize().materialize()) < 30


def test_merge_known_to_unknown(
35 changes: 27 additions & 8 deletions dask/dataframe/tests/test_multi.py
@@ -1880,12 +1880,21 @@ def test_cheap_single_partition_merge_divisions():
actual = aa.merge(bb, on="x", how="inner")
if not dd._dask_expr_enabled():
assert not hlg_layer_topological(actual.dask, -1).is_materialized()

assert not actual.known_divisions
assert not actual.known_divisions
else:
# Right side is a single partition join, i.e. we'll broadcast and preserve
# divisions of left
assert actual.known_divisions
assert actual.divisions == aa.divisions
assert_divisions(actual)

actual = bb.merge(aa, on="x", how="inner")
assert not actual.known_divisions

if not dd._dask_expr_enabled():
assert not actual.known_divisions
else:
assert actual.known_divisions
assert actual.divisions == aa.divisions
assert_divisions(actual)


@@ -1935,7 +1944,9 @@ def test_cheap_single_partition_merge_on_index():

if not dd._dask_expr_enabled():
assert not hlg_layer_topological(actual.dask, -1).is_materialized()
assert not actual.known_divisions
assert not actual.known_divisions
else:
assert actual.known_divisions
assert_eq(actual, expected)

actual = bb.merge(aa, right_index=True, left_on="x", how="inner")
@@ -1944,7 +1955,9 @@

if not dd._dask_expr_enabled():
assert not hlg_layer_topological(actual.dask, -1).is_materialized()
assert not actual.known_divisions
assert not actual.known_divisions
else:
assert actual.known_divisions
assert_eq(actual, expected)


@@ -2014,6 +2027,7 @@ def test_concat_one_series():
assert isinstance(c, dd.DataFrame)


@pytest.mark.skip("dask_expr.from_pandas always knows divisions")
def test_concat_unknown_divisions():
a = pd.Series([1, 2, 3, 4])
b = pd.Series([4, 3, 2, 1])
@@ -2362,9 +2376,14 @@ def check_and_return(ddfs, dfs, join):
res = dd.concat(ddfs, join=join, interleave_partitions=divisions)
assert_eq(res, sol)
if known:
parts = compute_as_if_collection(
dd.DataFrame, res.dask, res.__dask_keys__()
)
try:
graph_fact = res.__dask_graph_factory__().optimize()
dsk = graph_fact.materialize()
keys = graph_fact.__dask_output_keys__()
except AttributeError:
dsk = res.dask
keys = res.__dask_keys__()
parts = compute_as_if_collection(dd.DataFrame, dsk, keys)
for p in [i.iloc[:0] for i in parts]:
check_meta(res._meta, p) # will error if schemas don't align
assert not cat_index or has_known_categories(res.index) == known
43 changes: 28 additions & 15 deletions dask/dataframe/utils.py
@@ -28,7 +28,7 @@
meta_nonempty,
)
from dask.dataframe.extensions import make_scalar
from dask.typing import NoDefault, no_default
from dask.typing import DaskCollection2, NoDefault, no_default
from dask.utils import (
asciitable,
is_dataframe_like,
@@ -465,11 +465,17 @@

def _check_dask(dsk, check_names=True, check_dtypes=True, result=None, scheduler=None):
import dask.dataframe as dd
from dask.typing import DaskCollection2

if is_dask_collection(dsk):
if isinstance(dsk, DaskCollection2):
assert not hasattr(dsk, "__dask_graph__")
graph = dsk.__dask_graph_factory__().optimize()
elif hasattr(dsk, "__dask_graph__"):
graph = dsk.__dask_graph__()
if hasattr(graph, "validate"):
graph.validate()

if hasattr(dsk, "__dask_graph__"):
graph = dsk.__dask_graph__()
if hasattr(graph, "validate"):
graph.validate()
if result is None:
result = dsk.compute(scheduler=scheduler)
if isinstance(dsk, dd.Index) or is_index_like(dsk._meta):
@@ -649,7 +655,11 @@ def index(x):
return x.index

get = get_scheduler(scheduler=scheduler, collections=[type(ddf)])
results = get(ddf.dask, ddf.__dask_keys__())
return
from dask.base import collections_to_dsk

graph_fact = collections_to_dsk([ddf])
results = get(graph_fact.materialize(), ddf.__dask_output_keys__())
for i, df in enumerate(results[:-1]):
if len(df):
assert index(df).min() >= ddf.divisions[i]
@@ -661,15 +671,18 @@


def assert_sane_keynames(ddf):
if not hasattr(ddf, "dask"):
return
for k in ddf.dask.keys():
while isinstance(k, tuple):
k = k[0]
assert isinstance(k, (str, bytes))
assert len(k) < 100
assert " " not in k
assert k.split("-")[0].isidentifier(), k
if is_dask_collection(ddf):
if isinstance(ddf, DaskCollection2):
dsk = ddf.__dask_graph_factory__().lower_completely().materialize()
else:
dsk = ddf.dask
for k in dsk:
while isinstance(k, tuple):
k = k[0]
assert isinstance(k, (str, bytes))
assert len(k) < 100
assert " " not in k
assert k.split("-")[0].isidentifier(), k


def assert_dask_dtypes(ddf, res, numeric_equal=True):