Skip to content

[BUG] dask.array.einsum doesn't support CuPy backend #4898

@beckernick

Description

@beckernick

I'm running into NumPy-array-related errors when trying to use CuPy arrays with `da.einsum`. It looks like the computation still reaches the `asarray` branching logic, which appears to be set to `True` somewhere internally despite the CuPy inputs. Perhaps the setting is being reset, or is not fully passed down from the original arrays to an intermediate array.

import cupy as cp
import dask.array as da

A = cp.array([0., 1, 2])

B = cp.array([[ 0,  1.,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11]])

dA = da.from_array(A, asarray=False)
dB = da.from_array(B, asarray=False)

res = da.einsum('i,ij->i', dA, dB).compute()
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-147-06b849502129> in <module>
     11 dB = da.from_array(B, asarray=False)
     12 
---> 13 res = da.einsum('i,ij->i', dA, dB).compute()

/conda/envs/rapids/lib/python3.7/site-packages/dask/base.py in compute(self, **kwargs)
    154         dask.base.compute
    155         """
--> 156         (result,) = compute(self, traverse=False, **kwargs)
    157         return result
    158 

/conda/envs/rapids/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
    396     keys = [x.__dask_keys__() for x in collections]
    397     postcomputes = [x.__dask_postcompute__() for x in collections]
--> 398     results = schedule(dsk, keys, **kwargs)
    399     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    400 

/conda/envs/rapids/lib/python3.7/site-packages/dask/threaded.py in get(dsk, result, cache, num_workers, pool, **kwargs)
     74     results = get_async(pool.apply_async, len(pool._pool), dsk, result,
     75                         cache=cache, get_id=_thread_get_id,
---> 76                         pack_exception=pack_exception, **kwargs)
     77 
     78     # Cleanup pools associated to dead threads

/conda/envs/rapids/lib/python3.7/site-packages/dask/local.py in get_async(apply_async, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, **kwargs)
    460                         _execute_task(task, data)  # Re-execute locally
    461                     else:
--> 462                         raise_exception(exc, tb)
    463                 res, worker_id = loads(res_info)
    464                 state['cache'][key] = res

/conda/envs/rapids/lib/python3.7/site-packages/dask/compatibility.py in reraise(exc, tb)
    110         if exc.__traceback__ is not tb:
    111             raise exc.with_traceback(tb)
--> 112         raise exc
    113 
    114     import pickle as cPickle

/conda/envs/rapids/lib/python3.7/site-packages/dask/local.py in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    228     try:
    229         task, data = loads(task_info)
--> 230         result = _execute_task(task, data)
    231         id = get_id()
    232         result = dumps((result, id))

/conda/envs/rapids/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    116     elif istask(arg):
    117         func, args = arg[0], arg[1:]
--> 118         args2 = [_execute_task(a, cache) for a in args]
    119         return func(*args2)
    120     elif not ishashable(arg):

/conda/envs/rapids/lib/python3.7/site-packages/dask/core.py in <listcomp>(.0)
    116     elif istask(arg):
    117         func, args = arg[0], arg[1:]
--> 118         args2 = [_execute_task(a, cache) for a in args]
    119         return func(*args2)
    120     elif not ishashable(arg):

/conda/envs/rapids/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    117         func, args = arg[0], arg[1:]
    118         args2 = [_execute_task(a, cache) for a in args]
--> 119         return func(*args2)
    120     elif not ishashable(arg):
    121         return arg

/conda/envs/rapids/lib/python3.7/site-packages/dask/optimization.py in __call__(self, *args)
    940                              % (len(self.inkeys), len(args)))
    941         return core.get(self.dsk, self.outkey,
--> 942                         dict(zip(self.inkeys, args)))
    943 
    944     def __reduce__(self):

/conda/envs/rapids/lib/python3.7/site-packages/dask/core.py in get(dsk, out, cache)
    147     for key in toposort(dsk):
    148         task = dsk[key]
--> 149         result = _execute_task(task, cache)
    150         cache[key] = result
    151     result = _execute_task(out, cache)

/conda/envs/rapids/lib/python3.7/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    117         func, args = arg[0], arg[1:]
    118         args2 = [_execute_task(a, cache) for a in args]
--> 119         return func(*args2)
    120     elif not ishashable(arg):
    121         return arg

/conda/envs/rapids/lib/python3.7/site-packages/dask/compatibility.py in apply(func, args, kwargs)
     91     def apply(func, args, kwargs=None):
     92         if kwargs:
---> 93             return func(*args, **kwargs)
     94         else:
     95             return func(*args)

/conda/envs/rapids/lib/python3.7/site-packages/dask/array/einsumfuncs.py in chunk_einsum(*operands, **kwargs)
     17     dtype = kwargs.pop('kernel_dtype')
     18     einsum = einsum_lookup.dispatch(type(operands[0]))
---> 19     chunk = einsum(subscripts, *operands, dtype=dtype, **kwargs)
     20 
     21     # Avoid concatenate=True in blockwise by adding 1's

/conda/envs/rapids/lib/python3.7/site-packages/numpy/core/einsumfunc.py in einsum(*operands, **kwargs)
   1344     # If no optimization, run pure einsum
   1345     if optimize_arg is False:
-> 1346         return c_einsum(*operands, **kwargs)
   1347 
   1348     valid_einsum_kwargs = ['out', 'dtype', 'order', 'casting']

ValueError: object __array__ method not producing an array

Some initial debugging

> /conda/envs/rapids/lib/python3.7/site-packages/numpy/core/einsumfunc.py(1346)einsum()
   1344     # If no optimization, run pure einsum
   1345     if optimize_arg is False:
-> 1346         return c_einsum(*operands, **kwargs)
   1347 
   1348     valid_einsum_kwargs = ['out', 'dtype', 'order', 'casting']

ipdb>  type(operands[1])
<class 'cupy.core.core.ndarray'>
ipdb>  type(operands[2])
<class 'cupy.core.core.ndarray'>

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions