Skip to content

Commit 48e3a86

Browse files
authored
Simplify and lazify broadcast_to_shape (#5307)
* working for all except masked lazy * use moveaxis * handle lazy masked case * add tests for is_lazy_masked_data * whatsnew * check compute isn't called * update docstring
1 parent 299b335 commit 48e3a86

File tree

5 files changed

+90
-31
lines changed

5 files changed

+90
-31
lines changed

docs/src/whatsnew/latest.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ This document explains the changes made to Iris for this release
3030
✨ Features
3131
===========
3232

33-
#. N/A
33+
#. `@rcomer`_ rewrote :func:`~iris.util.broadcast_to_shape` so it now handles
34+
lazy data. (:pull:`5307`)
3435

3536

3637
🐛 Bugs Fixed

lib/iris/_lazy_data.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,15 @@ def is_lazy_data(data):
4747
return result
4848

4949

50+
def is_lazy_masked_data(data):
    """
    Report whether ``data`` is a lazy array backed by a masked array type.

    The result is truthy only when :func:`is_lazy_data` accepts ``data``
    and the meta array dask derives for it is a numpy masked array.
    Otherwise the falsy result of the laziness check is returned.

    """
    lazy = is_lazy_data(data)
    if not lazy:
        # Not lazy at all: no need to inspect the dask meta.
        return lazy
    meta = da.utils.meta_from_array(data)
    return ma.isMA(meta)
57+
58+
5059
@lru_cache
5160
def _optimum_chunksize_internals(
5261
chunks,
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright Iris contributors
2+
#
3+
# This file is part of Iris and is released under the LGPL license.
4+
# See COPYING and COPYING.LESSER in the root of the repository for full
5+
# licensing details.
6+
"""Test function :func:`iris._lazy_data.is_lazy_masked_data`."""
7+
8+
import dask.array as da
9+
import numpy as np
10+
import pytest
11+
12+
from iris._lazy_data import is_lazy_masked_data
13+
14+
real_arrays = [
15+
np.arange(3),
16+
np.ma.array(range(3)),
17+
np.ma.array(range(3), mask=[0, 1, 1]),
18+
]
19+
lazy_arrays = [da.from_array(arr) for arr in real_arrays]
20+
21+
22+
@pytest.mark.parametrize(
    "arr, expected",
    list(zip(real_arrays + lazy_arrays, [False] * 4 + [True] * 2)),
)
def test_is_lazy_masked_data(arr, expected):
    # The three real arrays and the lazy array built from the plain
    # ndarray expect False; the two lazy arrays built from masked arrays
    # expect True.
    assert is_lazy_masked_data(arr) is expected

lib/iris/tests/unit/util/test_broadcast_to_shape.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
# importing anything else
1010
import iris.tests as tests # isort:skip
1111

12+
from unittest import mock
13+
14+
import dask
15+
import dask.array as da
1216
import numpy as np
1317
import numpy.ma as ma
1418

@@ -40,6 +44,17 @@ def test_added_dimensions_transpose(self):
4044
for j in range(4):
4145
self.assertArrayEqual(b[i, :, j, :].T, a)
4246

47+
@mock.patch.object(dask.base, "compute", wraps=dask.base.compute)
def test_lazy_added_dimensions_transpose(self, mocked_compute):
    # Lazy input: new dimensions are added and the input's own dimensions
    # appear transposed in the result.
    lazy_in = da.random.random([2, 3])
    lazy_out = broadcast_to_shape(lazy_in, (5, 3, 4, 2), (3, 1))
    # Broadcasting by itself must not have gone through dask.base.compute.
    mocked_compute.assert_not_called()
    # Every slice over the two added dimensions should hold the input.
    for i, j in np.ndindex(5, 4):
        self.assertArrayEqual(
            lazy_out[i, :, j, :].T.compute(), lazy_in.compute()
        )
57+
4358
def test_masked(self):
4459
# masked arrays are also accepted
4560
a = np.random.random([2, 3])
@@ -49,6 +64,19 @@ def test_masked(self):
4964
for j in range(4):
5065
self.assertMaskedArrayEqual(b[i, :, j, :].T, m)
5166

67+
@mock.patch.object(dask.base, "compute", wraps=dask.base.compute)
def test_lazy_masked(self, mocked_compute):
    # Lazy masked arrays are also accepted.
    values = np.random.random([2, 3])
    lazy_masked = da.ma.masked_array(values, mask=[[0, 1, 0], [0, 1, 1]])
    result = broadcast_to_shape(lazy_masked, (5, 3, 4, 2), (3, 1))
    # Broadcasting by itself must not have gone through dask.base.compute.
    mocked_compute.assert_not_called()
    # Every slice over the two added dimensions should match the input,
    # mask included.
    for i, j in np.ndindex(5, 4):
        self.assertMaskedArrayEqual(
            result[i, :, j, :].compute().T, lazy_masked.compute()
        )
79+
5280
def test_masked_degenerate(self):
5381
# masked arrays can have degenerate masks too
5482
a = np.random.random([2, 3])

lib/iris/util.py

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import numpy.ma as ma
2424

2525
from iris._deprecation import warn_deprecated
26-
from iris._lazy_data import as_concrete_data, is_lazy_data
26+
from iris._lazy_data import as_concrete_data, is_lazy_data, is_lazy_masked_data
2727
from iris.common import SERVICES
2828
from iris.common.lenient import _lenient_client
2929
import iris.exceptions
@@ -34,8 +34,7 @@ def broadcast_to_shape(array, shape, dim_map):
3434
Broadcast an array to a given shape.
3535
3636
Each dimension of the array must correspond to a dimension in the
37-
given shape. Striding is used to repeat the array until it matches
38-
the desired shape, returning repeated views on the original array.
37+
given shape. The result is a read-only view (see :func:`numpy.broadcast_to`).
3938
If you need to write to the resulting array, make a copy first.
4039
4140
Args:
@@ -76,35 +75,30 @@ def broadcast_to_shape(array, shape, dim_map):
7675
See more at :doc:`/userguide/real_and_lazy_data`.
7776
7877
"""
79-
if len(dim_map) != array.ndim:
80-
# We must check for this condition here because we cannot rely on
81-
# getting an error from numpy if the dim_map argument is not the
82-
# correct length, we might just get a segfault.
83-
raise ValueError(
84-
"dim_map must have an entry for every "
85-
"dimension of the input array"
86-
)
78+
n_orig_dims = len(array.shape)
79+
n_new_dims = len(shape) - n_orig_dims
80+
array = array.reshape(array.shape + (1,) * n_new_dims)
81+
82+
# Get dims in required order.
83+
array = np.moveaxis(array, range(n_orig_dims), dim_map)
84+
new_array = np.broadcast_to(array, shape)
8785

88-
def _broadcast_helper(a):
89-
strides = [0] * len(shape)
90-
for idim, dim in enumerate(dim_map):
91-
if shape[dim] != a.shape[idim]:
92-
# We'll get garbage values if the dimensions of array are not
93-
# those indicated by shape.
94-
raise ValueError("shape and array are not compatible")
95-
strides[dim] = a.strides[idim]
96-
return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
97-
98-
array_view = _broadcast_helper(array)
99-
if ma.isMaskedArray(array):
100-
if array.mask is ma.nomask:
101-
# Degenerate masks can be applied as-is.
102-
mask_view = array.mask
86+
if ma.isMA(array):
87+
# broadcast_to strips masks so we need to handle them explicitly.
88+
mask = ma.getmask(array)
89+
if mask is ma.nomask:
90+
new_mask = ma.nomask
10391
else:
104-
# Mask arrays need to be handled in the same way as the data array.
105-
mask_view = _broadcast_helper(array.mask)
106-
array_view = ma.array(array_view, mask=mask_view)
107-
return array_view
92+
new_mask = np.broadcast_to(mask, shape)
93+
new_array = ma.array(new_array, mask=new_mask)
94+
95+
elif is_lazy_masked_data(array):
96+
# broadcast_to strips masks so we need to handle them explicitly.
97+
mask = da.ma.getmaskarray(array)
98+
new_mask = da.broadcast_to(mask, shape)
99+
new_array = da.ma.masked_array(new_array, new_mask)
100+
101+
return new_array
108102

109103

110104
def delta(ndarray, dimension, circular=False):

0 commit comments

Comments
 (0)