Description
What happened:
If the kwargs for xr.apply_ufunc are defined in a separate dictionary and that dictionary is reused in several calls with dask="parallelized", the second call fails. The dimension names in output_sizes (inside dask_gufunc_kwargs) are modified internally by the first call.
What you expected to happen:
apply_ufunc should leave the kwargs dictionary unmodified, so that it can be reused across calls.
Minimal Complete Verifiable Example:
```python
import numpy as np
import xarray as xr


def dummy1(data, nfft):
    return data[..., (nfft // 2) + 1 :] * 2


def dummy2(data, nfft):
    return data[..., (nfft // 2) + 1 :] / 2


def xoperations(xarr, **kwargs):
    ufunc_kwargs = dict(
        kwargs=kwargs,
        input_core_dims=[["time"]],
        output_core_dims=[["freq"]],
        dask="parallelized",
        output_dtypes=[float],  # np.float is only an alias of float (removed in NumPy 1.24)
        dask_gufunc_kwargs=dict(output_sizes={"freq": int(kwargs["nfft"] / 2) + 1}),
    )

    ans1 = xr.apply_ufunc(dummy1, xarr, **ufunc_kwargs)
    ans2 = xr.apply_ufunc(dummy2, xarr, **ufunc_kwargs)

    return ans1, ans2


test = xr.DataArray(
    4, coords=[("time", np.arange(1000)), ("lon", np.arange(160, 300, 10))]
).chunk({"time": -1, "lon": 10})

xoperations(test, nfft=1024)
```

This returns:
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-1-822bd3b2d4da> in <module>
     32 ).chunk({"time": -1, "lon": 10})
     33 
---> 34 xoperations(test, nfft=1024)

<ipython-input-1-822bd3b2d4da> in xoperations(xarr, **kwargs)
     23 
     24     ans1 = xr.apply_ufunc(dummy1, xarr, **ufunc_kwargs)
---> 25     ans2 = xr.apply_ufunc(dummy2, xarr, **ufunc_kwargs)
     26 
     27     return ans1, ans2

~/GitLab/xarray_test/xarray/xarray/core/computation.py in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, *args)
   1086             join=join,
   1087             exclude_dims=exclude_dims,
-> 1088             keep_attrs=keep_attrs,
   1089         )
   1090     # feed Variables directly through apply_variable_ufunc

~/GitLab/xarray_test/xarray/xarray/core/computation.py in apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args)
    260 
    261     data_vars = [getattr(a, "variable", a) for a in args]
--> 262     result_var = func(*data_vars)
    263 
    264     if signature.num_outputs > 1:

~/GitLab/xarray_test/xarray/xarray/core/computation.py in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args)
    632                 if key not in signature.all_output_core_dims:
    633                     raise ValueError(
--> 634                         f"dimension '{key}' in 'output_sizes' must correspond to output_core_dims"
    635                     )
    636                 output_sizes_renamed[signature.dims_map[key]] = value

ValueError: dimension 'dim0' in 'output_sizes' must correspond to output_core_dims
```
This is easy to verify by adding a print statement before and after the first apply_ufunc call. Everything stays the same except the dimension names in output_sizes:

```
{'kwargs': {'nfft': 1024}, 'input_core_dims': [['time']], 'output_core_dims': [['freq']], 'dask': 'parallelized', 'output_dtypes': [<class 'float'>], 'dask_gufunc_kwargs': {'output_sizes': {'freq': 513}}}
{'kwargs': {'nfft': 1024}, 'input_core_dims': [['time']], 'output_core_dims': [['freq']], 'dask': 'parallelized', 'output_dtypes': [<class 'float'>], 'dask_gufunc_kwargs': {'output_sizes': {'dim0': 513}}}
```

Anything else we need to know?:
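Reusing the dictionary breaks because unpacking with ** copies only the top-level mapping: the nested dask_gufunc_kwargs dict that apply_ufunc receives is the very same object stored in ufunc_kwargs. The sketch below is a hypothetical, simplified model of the renaming step seen in the traceback. The function name rename_output_sizes and the "freq" -> "dim0" mapping scheme are assumptions for illustration, not xarray's actual code. It reproduces the same error on the second call without needing xarray or dask.

```python
def rename_output_sizes(**kwargs):
    """Hypothetical model of the output_sizes handling; not xarray's real code."""
    # kwargs is a fresh dict, but kwargs["dask_gufunc_kwargs"] is the caller's object.
    gufunc_kwargs = kwargs["dask_gufunc_kwargs"]
    core_dims = [dim for dims in kwargs["output_core_dims"] for dim in dims]
    # Assumed internal naming scheme: "freq" -> "dim0", and so on.
    dims_map = {name: f"dim{i}" for i, name in enumerate(core_dims)}

    renamed = {}
    for key, value in gufunc_kwargs["output_sizes"].items():
        if key not in core_dims:
            raise ValueError(
                f"dimension '{key}' in 'output_sizes' must correspond to output_core_dims"
            )
        renamed[dims_map[key]] = value
    # Writing back into the shared nested dict is what leaks out to the caller.
    gufunc_kwargs["output_sizes"] = renamed


ufunc_kwargs = dict(
    output_core_dims=[["freq"]],
    dask_gufunc_kwargs=dict(output_sizes={"freq": 513}),
)

rename_output_sizes(**ufunc_kwargs)  # the first call succeeds...
print(ufunc_kwargs["dask_gufunc_kwargs"])  # {'output_sizes': {'dim0': 513}}

try:
    rename_output_sizes(**ufunc_kwargs)  # ...the second call sees the renamed key
    second_error = None
except ValueError as err:
    second_error = str(err)
print(second_error)
```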
I have a fork with a fix ready to be sent as a PR. I simply imported the copy module and used deepcopy like this:

```python
dask_gufunc_kwargs = copy.deepcopy(dask_gufunc_kwargs)
```

around here:
xarray/xarray/core/computation.py, lines 1013 to 1020 in 2acd0fc:

```python
if dask == "parallelized":
    if dask_gufunc_kwargs is None:
        dask_gufunc_kwargs = {}
    # todo: remove warnings after deprecation cycle
    if meta is not None:
        warnings.warn(
            "``meta`` should be given in the ``dask_gufunc_kwargs`` parameter."
            " It will be removed as direct parameter in a future version.",
```
If it's good enough then I can send the PR.
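The effect of copying dask_gufunc_kwargs before it is modified can be checked with a simplified model. This is a hypothetical sketch: rename_output_sizes and its "dim0" naming are illustrative assumptions, not xarray internals. Because the renamed mapping lives only in the callee's private copy, the caller's dictionary survives and can be reused.

```python
import copy


def rename_output_sizes(**kwargs):
    """Hypothetical model of output_sizes handling, with the proposed copy."""
    # Work on a private copy instead of the caller's nested dict.
    gufunc_kwargs = copy.deepcopy(kwargs["dask_gufunc_kwargs"])
    core_dims = [dim for dims in kwargs["output_core_dims"] for dim in dims]
    # Assumed internal naming scheme: "freq" -> "dim0", and so on.
    dims_map = {name: f"dim{i}" for i, name in enumerate(core_dims)}

    renamed = {}
    for key, value in gufunc_kwargs["output_sizes"].items():
        if key not in core_dims:
            raise ValueError(
                f"dimension '{key}' in 'output_sizes' must correspond to output_core_dims"
            )
        renamed[dims_map[key]] = value
    gufunc_kwargs["output_sizes"] = renamed  # only the copy is modified
    return gufunc_kwargs


ufunc_kwargs = dict(
    output_core_dims=[["freq"]],
    dask_gufunc_kwargs=dict(output_sizes={"freq": 513}),
)

first = rename_output_sizes(**ufunc_kwargs)
second = rename_output_sizes(**ufunc_kwargs)  # reusing the dict now works
print(first, second)  # {'output_sizes': {'dim0': 513}} {'output_sizes': {'dim0': 513}}
print(ufunc_kwargs["dask_gufunc_kwargs"])  # {'output_sizes': {'freq': 513}}
```

In this model a shallow copy (dict(...)) would also be enough, since only a top-level key of dask_gufunc_kwargs is reassigned; deepcopy additionally protects against in-place changes to nested values.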
Environment:
Output of xr.show_versions()
INSTALLED VERSIONS
commit: 2acd0fc
python: 3.7.8 | packaged by conda-forge | (default, Jul 31 2020, 02:25:08)
[GCC 7.5.0]
python-bits: 64
OS: Linux
OS-release: 3.12.74-60.64.40-default
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: en_US.UTF-8
libhdf5: 1.10.5
libnetcdf: 4.7.4
xarray: 0.14.2.dev337+g2acd0fc6
pandas: 1.1.1
numpy: 1.19.1
scipy: 1.5.2
netCDF4: 1.5.3
pydap: installed
h5netcdf: 0.8.1
h5py: 2.10.0
Nio: 1.5.5
zarr: 2.4.0
cftime: 1.2.1
nc_time_axis: 1.2.0
PseudoNetCDF: installed
rasterio: 1.1.5
cfgrib: 0.9.8.4
iris: 2.4.0
bottleneck: 1.3.2
dask: 2.25.0
distributed: 2.25.0
matplotlib: 3.3.1
cartopy: 0.18.0
seaborn: 0.10.1
numbagg: installed
pint: 0.15
setuptools: 49.6.0.post20200814
pip: 20.2.2
conda: None
pytest: 6.0.1
IPython: 7.18.1
sphinx: None