Dask gufunc kwarg "output_sizes" is not deep copied #4399

@griverat

What happened:

Defining the kwargs for xr.apply_ufunc in a separate dictionary and reusing it across multiple calls with dask="parallelized" raises an error, because the dimension names in output_sizes (inside dask_gufunc_kwargs) are modified in place internally.

What you expected to happen:

The kwargs dictionary should be left unmodified so that it can be reused across calls.

Minimal Complete Verifiable Example:

import numpy as np

import xarray as xr


def dummy1(data, nfft):
    return data[..., (nfft // 2) + 1 :] * 2


def dummy2(data, nfft):
    return data[..., (nfft // 2) + 1 :] / 2


def xoperations(xarr, **kwargs):
    ufunc_kwargs = dict(
        kwargs=kwargs,
        input_core_dims=[["time"]],
        output_core_dims=[["freq"]],
        dask="parallelized",
        output_dtypes=[float],
        dask_gufunc_kwargs=dict(output_sizes={"freq": int(kwargs["nfft"] / 2) + 1}),
    )

    ans1 = xr.apply_ufunc(dummy1, xarr, **ufunc_kwargs)
    ans2 = xr.apply_ufunc(dummy2, xarr, **ufunc_kwargs)

    return ans1, ans2


test = xr.DataArray(
    4, coords=[("time", np.arange(1000)), ("lon", np.arange(160, 300, 10))]
).chunk({"time": -1, "lon": 10})

xoperations(test, nfft=1024)

This returns

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-1-822bd3b2d4da> in <module>
     32 ).chunk({"time": -1, "lon": 10})
     33 
---> 34 xoperations(test, nfft=1024)

<ipython-input-1-822bd3b2d4da> in xoperations(xarr, **kwargs)
     23 
     24     ans1 = xr.apply_ufunc(dummy1, xarr, **ufunc_kwargs)
---> 25     ans2 = xr.apply_ufunc(dummy2, xarr, **ufunc_kwargs)
     26 
     27     return ans1, ans2

~/GitLab/xarray_test/xarray/xarray/core/computation.py in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, *args)
   1086             join=join,
   1087             exclude_dims=exclude_dims,
-> 1088             keep_attrs=keep_attrs,
   1089         )
   1090     # feed Variables directly through apply_variable_ufunc

~/GitLab/xarray_test/xarray/xarray/core/computation.py in apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args)
    260 
    261     data_vars = [getattr(a, "variable", a) for a in args]
--> 262     result_var = func(*data_vars)
    263 
    264     if signature.num_outputs > 1:

~/GitLab/xarray_test/xarray/xarray/core/computation.py in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args)
    632                     if key not in signature.all_output_core_dims:
    633                         raise ValueError(
--> 634                             f"dimension '{key}' in 'output_sizes' must correspond to output_core_dims"
    635                         )
    636                     output_sizes_renamed[signature.dims_map[key]] = value

ValueError: dimension 'dim0' in 'output_sizes' must correspond to output_core_dims

This is easy to verify by printing ufunc_kwargs before and after the first apply_ufunc call. Everything is unchanged except the dimension names in output_sizes:

{'kwargs': {'nfft': 1024}, 'input_core_dims': [['time']], 'output_core_dims': [['freq']], 'dask': 'parallelized', 'output_dtypes': [<class 'float'>], 'dask_gufunc_kwargs': {'output_sizes': {'freq': 513}}}
{'kwargs': {'nfft': 1024}, 'input_core_dims': [['time']], 'output_core_dims': [['freq']], 'dask': 'parallelized', 'output_dtypes': [<class 'float'>], 'dask_gufunc_kwargs': {'output_sizes': {'dim0': 513}}}
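
The same failure mode can be reproduced without xarray or dask. The following is only a minimal sketch of the pattern, not xarray's actual code: `apply_with_rename` and `dims_map` are hypothetical stand-ins, loosely modelled on the renaming loop visible in the traceback above.

```python
def apply_with_rename(gufunc_kwargs, dims_map):
    """Hypothetical stand-in that renames output_sizes keys in the caller's dict."""
    output_sizes = gufunc_kwargs.pop("output_sizes", {})
    renamed = {}
    for key, value in output_sizes.items():
        if key not in dims_map:
            raise ValueError(
                f"dimension '{key}' in 'output_sizes' must correspond to output_core_dims"
            )
        renamed[dims_map[key]] = value
    # Writing back into the caller's dict is what leaks the renamed keys.
    gufunc_kwargs["output_sizes"] = renamed
    return renamed


shared = {"output_sizes": {"freq": 513}}
dims_map = {"freq": "dim0"}  # assumed mapping from user-facing to internal names

first_result = apply_with_rename(shared, dims_map)
after_first_call = dict(shared)  # snapshot: the key is now 'dim0', not 'freq'

second_call_error = None
try:
    # The second call sees 'dim0', which is not a known output core dim.
    apply_with_rename(shared, dims_map)
except ValueError as err:
    second_call_error = str(err)
```

The second call fails for the same reason as in the report: the shared dictionary no longer holds the key the caller originally passed.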

Anything else we need to know?:

I have a fork with a fix ready to be sent as a PR. I just imported the copy module and used deepcopy like this

dask_gufunc_kwargs = copy.deepcopy(dask_gufunc_kwargs)

around here

    if dask == "parallelized":
        if dask_gufunc_kwargs is None:
            dask_gufunc_kwargs = {}
        # todo: remove warnings after deprecation cycle
        if meta is not None:
            warnings.warn(
                "``meta`` should be given in the ``dask_gufunc_kwargs`` parameter."
                " It will be removed as direct parameter in a future version.",

If it's good enough then I can send the PR.
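
The idea behind the proposed fix can be sketched in plain Python. This is an illustration under assumptions, not the actual patch: `rename_sizes_safely` and `dims_map` are hypothetical names, and the only point is that copying the incoming dictionary before modifying it leaves the caller's object intact.

```python
import copy


def rename_sizes_safely(gufunc_kwargs, dims_map):
    """Hypothetical helper: rename output_sizes keys on a private copy."""
    # Deep-copy first so nothing below can touch the caller's dictionary.
    gufunc_kwargs = copy.deepcopy(gufunc_kwargs) if gufunc_kwargs is not None else {}
    sizes = gufunc_kwargs.pop("output_sizes", {})
    gufunc_kwargs["output_sizes"] = {dims_map[k]: v for k, v in sizes.items()}
    return gufunc_kwargs


shared = {"output_sizes": {"freq": 513}}
dims_map = {"freq": "dim0"}  # assumed mapping from user-facing to internal names

first = rename_sizes_safely(shared, dims_map)
second = rename_sizes_safely(shared, dims_map)  # works: 'freq' is still present
```

With the copy in place, both calls see the original `freq` key, so the shared dictionary can be reused any number of times.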

Environment:

Output of xr.show_versions()

INSTALLED VERSIONS

commit: 2acd0fc
python: 3.7.8 | packaged by conda-forge | (default, Jul 31 2020, 02:25:08)
[GCC 7.5.0]
python-bits: 64
OS: Linux
OS-release: 3.12.74-60.64.40-default
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: en_US.UTF-8
libhdf5: 1.10.5
libnetcdf: 4.7.4

xarray: 0.14.2.dev337+g2acd0fc6
pandas: 1.1.1
numpy: 1.19.1
scipy: 1.5.2
netCDF4: 1.5.3
pydap: installed
h5netcdf: 0.8.1
h5py: 2.10.0
Nio: 1.5.5
zarr: 2.4.0
cftime: 1.2.1
nc_time_axis: 1.2.0
PseudoNetCDF: installed
rasterio: 1.1.5
cfgrib: 0.9.8.4
iris: 2.4.0
bottleneck: 1.3.2
dask: 2.25.0
distributed: 2.25.0
matplotlib: 3.3.1
cartopy: 0.18.0
seaborn: 0.10.1
numbagg: installed
pint: 0.15
setuptools: 49.6.0.post20200814
pip: 20.2.2
conda: None
pytest: 6.0.1
IPython: 7.18.1
sphinx: None
