Skip to content

Commit c437e63

Browse files
authored
Fix meta_from_array to support Xarray test suite (dask#4938)
Fixes pydata/xarray#3009
1 parent d8ff4c4 commit c437e63

File tree

6 files changed

+77
-30
lines changed

6 files changed

+77
-30
lines changed

dask/array/core.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -888,8 +888,7 @@ def __new__(cls, dask, name, chunks, dtype=None, meta=None, shape=None):
888888
if dtype:
889889
self._meta = np.empty((0,) * self.ndim, dtype=dtype)
890890
else:
891-
from .utils import meta_from_array
892-
self._meta = meta_from_array(meta, meta.ndim)
891+
self._meta = meta_from_array(meta)
893892

894893
for plugin in config.get('array_plugins', ()):
895894
result = plugin(self)
@@ -2984,8 +2983,7 @@ def concatenate(seq, axis=0, allow_unknown_chunksizes=False):
29842983
if not seq:
29852984
raise ValueError("Need array(s) to concatenate")
29862985

2987-
from .utils import meta_from_array
2988-
meta = np.concatenate([meta_from_array(s, s.ndim) for s in seq])
2986+
meta = np.concatenate([meta_from_array(s) for s in seq])
29892987

29902988
# Promote types to match meta
29912989
seq = [a.astype(meta.dtype) for a in seq]
@@ -4167,3 +4165,6 @@ def from_npy_stack(dirname, mmap_mode='r'):
41674165
dsk = dict(zip(keys, values))
41684166

41694167
return Array(dsk, name, chunks, dtype)
4168+
4169+
4170+
from .utils import meta_from_array

dask/array/gufunc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from toolz import concat, merge, unique
1111

1212
from .core import Array, asarray, blockwise, getitem, apply_infer_dtype
13-
from .utils import normalize_meta
13+
from .utils import meta_from_array
1414
from ..highlevelgraph import HighLevelGraph
1515
from ..core import flatten
1616

@@ -398,7 +398,7 @@ def apply_gufunc(func, signature, *args, **kwargs):
398398
leaf_name = "%s_%d-%s" % (name, i, token)
399399
leaf_dsk = {(leaf_name,) + key[1:] + core_chunkinds: ((getitem, key, i) if nout else key) for key in keys}
400400
graph = HighLevelGraph.from_collections(leaf_name, leaf_dsk, dependencies=[tmp])
401-
meta = normalize_meta(tmp._meta, len(output_shape), dtype=odt)
401+
meta = meta_from_array(tmp._meta, len(output_shape), dtype=odt)
402402
leaf_arr = Array(graph,
403403
leaf_name,
404404
chunks=output_chunks,

dask/array/linalg.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -798,9 +798,9 @@ def lu(a):
798798
# l_permuted is not referred in upper triangulars
799799

800800
pp, ll, uu = scipy.linalg.lu(np.ones(shape=(1, 1), dtype=a.dtype))
801-
pp_meta = meta_from_array(a, a.ndim, dtype=pp.dtype)
802-
ll_meta = meta_from_array(a, a.ndim, dtype=ll.dtype)
803-
uu_meta = meta_from_array(a, a.ndim, dtype=uu.dtype)
801+
pp_meta = meta_from_array(a, dtype=pp.dtype)
802+
ll_meta = meta_from_array(a, dtype=ll.dtype)
803+
uu_meta = meta_from_array(a, dtype=uu.dtype)
804804

805805
graph = HighLevelGraph.from_collections(name_p, dsk, dependencies=[a])
806806
p = Array(graph, name_p, shape=a.shape, chunks=a.chunks, meta=pp_meta)
@@ -1048,7 +1048,7 @@ def _cholesky(a):
10481048
graph_upper = HighLevelGraph.from_collections(name_upper, dsk, dependencies=[a])
10491049
graph_lower = HighLevelGraph.from_collections(name, dsk, dependencies=[a])
10501050
cho = scipy.linalg.cholesky(np.array([[1, 2], [2, 5]], dtype=a.dtype))
1051-
meta = meta_from_array(a, a.ndim, dtype=cho.dtype)
1051+
meta = meta_from_array(a, dtype=cho.dtype)
10521052

10531053
lower = Array(graph_lower, name, shape=a.shape, chunks=a.chunks, meta=meta)
10541054
# do not use .T, because part of transposed blocks are already calculated
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import numpy as np
2+
import pytest
3+
4+
from dask.array.utils import meta_from_array
5+
6+
asarrays = [np.asarray]
7+
8+
try:
9+
import sparse
10+
asarrays.append(sparse.COO.from_numpy)
11+
except ImportError:
12+
pass
13+
14+
try:
15+
import cupy
16+
asarrays.append(cupy.asarray)
17+
except ImportError:
18+
pass
19+
20+
21+
@pytest.mark.parametrize("asarray", asarrays)
22+
def test_meta_from_array(asarray):
23+
x = np.ones((1, 2, 3), dtype='float32')
24+
x = asarray(x)
25+
26+
assert meta_from_array(x).shape == (0, 0, 0)
27+
assert meta_from_array(x).dtype == 'float32'
28+
assert type(meta_from_array(x)) is type(x)
29+
30+
assert meta_from_array(x, ndim=2).shape == (0, 0)
31+
assert meta_from_array(x, ndim=4).shape == (0, 0, 0, 0)
32+
assert meta_from_array(x, dtype="float64").dtype == "float64"

dask/array/utils.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,19 @@ def normalize_to_array(x):
2828
return x
2929

3030

31-
def normalize_meta(x, ndim, dtype=None):
32-
if ndim > x.ndim:
33-
meta = x[(Ellipsis, ) + tuple(None for _ in range(ndim - x.ndim))]
34-
meta = meta[tuple(slice(0, 0, None) for _ in range(meta.ndim))]
35-
elif ndim < x.ndim:
36-
meta = np.sum(x, axis=tuple(d for d in range((x.ndim - ndim))))
37-
else:
38-
meta = x
39-
40-
if dtype:
41-
meta = meta.astype(dtype)
42-
43-
return meta
44-
45-
46-
def meta_from_array(x, ndim, dtype=None):
31+
def meta_from_array(x, ndim=None, dtype=None):
32+
""" Normalize an array to appropriate meta object
33+
34+
Parameters
35+
----------
36+
x: array-like
37+
ndim: int
38+
dtype: dtype
39+
40+
Returns
41+
-------
42+
array-like
43+
"""
4744
if isinstance(x, list) or isinstance(x, tuple):
4845
ndims = [0 if isinstance(a, numbers.Number)
4946
else a.ndim if hasattr(a, 'ndim') else len(a) for a in x]
@@ -53,11 +50,28 @@ def meta_from_array(x, ndim, dtype=None):
5350
# x must itself be a Dask Array: some libraries (e.g. zarr) implement a
5451
# _meta attribute that is incompatible with Dask Array._meta
5552
if hasattr(x, '_meta') and isinstance(x, Array):
56-
meta = x._meta
57-
else:
53+
x = x._meta
54+
55+
if ndim is None:
56+
ndim = x.ndim
57+
58+
try:
5859
meta = x[tuple(slice(0, 0, None) for _ in range(x.ndim))]
60+
if meta.ndim != ndim:
61+
if ndim > x.ndim:
62+
meta = meta[(Ellipsis, ) + tuple(None for _ in range(ndim - meta.ndim))]
63+
meta = meta[tuple(slice(0, 0, None) for _ in range(meta.ndim))]
64+
elif ndim == 0:
65+
meta = meta.sum()
66+
else:
67+
meta = meta.reshape((0,) * ndim)
68+
except Exception:
69+
meta = np.empty((0,) * ndim, dtype=dtype or x.dtype)
70+
71+
if dtype and meta.dtype != dtype:
72+
meta = meta.astype(dtype)
5973

60-
return normalize_meta(meta, ndim, dtype)
74+
return meta
6175

6276

6377
def allclose(a, b, equal_nan=False, **kwargs):

dask/array/wrap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def wrap_func_like(func, *args, **kwargs):
7373
Transform np creation function into blocked version
7474
"""
7575
x = args[0]
76-
meta = meta_from_array(x, x.ndim)
76+
meta = meta_from_array(x)
7777
shape = kwargs.get('shape', x.shape)
7878

7979
parsed = _parse_wrap_args(func, args, kwargs, shape)

0 commit comments

Comments
 (0)