-
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Closed
Labels
Description
Running into numpy-array-related issues when trying to use CuPy arrays with einsum. It looks like this may be caused by execution still reaching the `asarray` branching logic, which appears to be set to `True` somewhere internally despite the CuPy inputs. Perhaps the flag is being reset, or is not fully propagated from the original arrays down to an intermediate array.
import cupy as cp
import dask.array as da
A = cp.array([0., 1, 2])
B = cp.array([[ 0, 1., 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
dA = da.from_array(A, asarray=False)
dB = da.from_array(B, asarray=False)
res = da.einsum('i,ij->i', dA, dB).compute()
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-147-06b849502129> in <module>
11 dB = da.from_array(B, asarray=False)
12
---> 13 res = da.einsum('i,ij->i', dA, dB).compute()
/conda/envs/rapids/lib/python3.7/site-packages/dask/base.py in compute(self, **kwargs)
154 dask.base.compute
155 """
--> 156 (result,) = compute(self, traverse=False, **kwargs)
157 return result
158
/conda/envs/rapids/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
396 keys = [x.__dask_keys__() for x in collections]
397 postcomputes = [x.__dask_postcompute__() for x in collections]
--> 398 results = schedule(dsk, keys, **kwargs)
399 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
400
/conda/envs/rapids/lib/python3.7/site-packages/dask/threaded.py in get(dsk, result, cache, num_workers, pool, **kwargs)
74 results = get_async(pool.apply_async, len(pool._pool), dsk, result,
75 cache=cache, get_id=_thread_get_id,
---> 76 pack_exception=pack_exception, **kwargs)
77
78 # Cleanup pools associated to dead threads
/conda/envs/rapids/lib/python3.7/site-packages/dask/local.py in get_async(apply_async, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, **kwargs)
460 _execute_task(task, data) # Re-execute locally
461 else:
--> 462 raise_exception(exc, tb)
463 res, worker_id = loads(res_info)
464 state['cache'][key] = res
/conda/envs/rapids/lib/python3.7/site-packages/dask/compatibility.py in reraise(exc, tb)
110 if exc.__traceback__ is not tb:
111 raise exc.with_traceback(tb)
--> 112 raise exc
113
114 import pickle as cPickle
/conda/envs/rapids/lib/python3.7/site-packages/dask/local.py in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
228 try:
229 task, data = loads(task_info)
--> 230 result = _execute_task(task, data)
231 id = get_id()
232 result = dumps((result, id))
/conda/envs/rapids/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
116 elif istask(arg):
117 func, args = arg[0], arg[1:]
--> 118 args2 = [_execute_task(a, cache) for a in args]
119 return func(*args2)
120 elif not ishashable(arg):
/conda/envs/rapids/lib/python3.7/site-packages/dask/core.py in <listcomp>(.0)
116 elif istask(arg):
117 func, args = arg[0], arg[1:]
--> 118 args2 = [_execute_task(a, cache) for a in args]
119 return func(*args2)
120 elif not ishashable(arg):
/conda/envs/rapids/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
117 func, args = arg[0], arg[1:]
118 args2 = [_execute_task(a, cache) for a in args]
--> 119 return func(*args2)
120 elif not ishashable(arg):
121 return arg
/conda/envs/rapids/lib/python3.7/site-packages/dask/optimization.py in __call__(self, *args)
940 % (len(self.inkeys), len(args)))
941 return core.get(self.dsk, self.outkey,
--> 942 dict(zip(self.inkeys, args)))
943
944 def __reduce__(self):
/conda/envs/rapids/lib/python3.7/site-packages/dask/core.py in get(dsk, out, cache)
147 for key in toposort(dsk):
148 task = dsk[key]
--> 149 result = _execute_task(task, cache)
150 cache[key] = result
151 result = _execute_task(out, cache)
/conda/envs/rapids/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
117 func, args = arg[0], arg[1:]
118 args2 = [_execute_task(a, cache) for a in args]
--> 119 return func(*args2)
120 elif not ishashable(arg):
121 return arg
/conda/envs/rapids/lib/python3.7/site-packages/dask/compatibility.py in apply(func, args, kwargs)
91 def apply(func, args, kwargs=None):
92 if kwargs:
---> 93 return func(*args, **kwargs)
94 else:
95 return func(*args)
/conda/envs/rapids/lib/python3.7/site-packages/dask/array/einsumfuncs.py in chunk_einsum(*operands, **kwargs)
17 dtype = kwargs.pop('kernel_dtype')
18 einsum = einsum_lookup.dispatch(type(operands[0]))
---> 19 chunk = einsum(subscripts, *operands, dtype=dtype, **kwargs)
20
21 # Avoid concatenate=True in blockwise by adding 1's
/conda/envs/rapids/lib/python3.7/site-packages/numpy/core/einsumfunc.py in einsum(*operands, **kwargs)
1344 # If no optimization, run pure einsum
1345 if optimize_arg is False:
-> 1346 return c_einsum(*operands, **kwargs)
1347
1348 valid_einsum_kwargs = ['out', 'dtype', 'order', 'casting']
ValueError: object __array__ method not producing an array

Some initial debugging:
> /conda/envs/rapids/lib/python3.7/site-packages/numpy/core/einsumfunc.py(1346)einsum()
1344 # If no optimization, run pure einsum
1345 if optimize_arg is False:
-> 1346 return c_einsum(*operands, **kwargs)
1347
1348 valid_einsum_kwargs = ['out', 'dtype', 'order', 'casting']
ipdb> type(operands[1])
<class 'cupy.core.core.ndarray'>
ipdb> type(operands[2])
<class 'cupy.core.core.ndarray'>