Skip to content

Commit c2c1ece

Browse files
authored
Pickle da.argwhere and da.count_nonzero (#10885)
1 parent c472a8a commit c2c1ece

File tree

2 files changed

+26
-15
lines changed

2 files changed

+26
-15
lines changed

dask/array/routines.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2056,24 +2056,19 @@ def _isnonzero_vec(v):
20562056
_isnonzero_vec = np.vectorize(_isnonzero_vec, otypes=[bool])
20572057

20582058

2059+
def _isnonzero(a):
2060+
# Output of np.vectorize can't be pickled
2061+
return _isnonzero_vec(a)
2062+
2063+
20592064
def isnonzero(a):
2060-
if a.dtype.kind in {"U", "S"}:
2061-
# NumPy treats all-whitespace strings as falsy (like in `np.nonzero`).
2062-
# but not in `.astype(bool)`. To match the behavior of numpy at least until
2063-
# 1.19, we use `_isnonzero_vec`. When NumPy changes behavior, we should just
2064-
# use the try block below.
2065-
# https://github.com/numpy/numpy/issues/9875
2066-
return a.map_blocks(_isnonzero_vec, dtype=bool)
2065+
"""Handle special cases where conversion to bool does not work correctly.
2066+
xref: https://github.com/numpy/numpy/issues/9479
2067+
"""
20672068
try:
2068-
np.zeros(tuple(), dtype=a.dtype).astype(bool)
2069+
np.zeros([], dtype=a.dtype).astype(bool)
20692070
except ValueError:
2070-
######################################################
2071-
# Handle special cases where conversion to bool does #
2072-
# not work correctly. #
2073-
# #
2074-
# xref: https://github.com/numpy/numpy/issues/9479 #
2075-
######################################################
2076-
return a.map_blocks(_isnonzero_vec, dtype=bool)
2071+
return a.map_blocks(_isnonzero, dtype=bool)
20772072
else:
20782073
return a.astype(bool)
20792074

dask/array/tests/test_routines.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import contextlib
44
import itertools
5+
import pickle
56
import sys
67
import warnings
78
from numbers import Number
@@ -2708,3 +2709,18 @@ def test_tril_triu_indices(n, k, m, chunks):
27082709
)
27092710
else:
27102711
assert_eq(actual, expected)
2712+
2713+
2714+
def test_pickle_vectorized_routines():
2715+
"""Test that graphs that internally use np.vectorize can be pickled"""
2716+
a = da.from_array(["foo", "bar", ""])
2717+
2718+
b = da.count_nonzero(a)
2719+
assert_eq(b, 2, check_dtype=False)
2720+
b2 = pickle.loads(pickle.dumps(b))
2721+
assert_eq(b2, 2, check_dtype=False)
2722+
2723+
c = da.argwhere(a)
2724+
assert_eq(c, [[0], [1]], check_dtype=False)
2725+
c2 = pickle.loads(pickle.dumps(c))
2726+
assert_eq(c2, [[0], [1]], check_dtype=False)

0 commit comments

Comments
 (0)