Skip to content

Commit 02e1b43

Browse files
committed
Fix for integer-valued indices - allows, e.g. d[:, 0] = d[:, 2]
1 parent 4440264 commit 02e1b43

File tree

2 files changed

+91
-102
lines changed

2 files changed

+91
-102
lines changed

dask/array/core.py

Lines changed: 81 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,48 +1582,52 @@ def __setitem__(self, key, value):
15821582
def parse_indices(shape, indices):
15831583
"""Reformat the indices.
15841584
1585-
The aim of this is is convert the indices to a
1586-
standardised form so that it is easier a) to check them
1587-
for validity, and b) to ascertain which chunks are touched
1588-
by the indices.
1589-
1590-
Note that indices which are decreasing (such as as
1591-
``slice(7, 3, -1)`` and ``[6, 2, 1]``) are recast as
1592-
increasing indices (``slice(4, 8, 1)`` and ``[1, 2, 6]``
1593-
respectively) to facilitate finding which chunks are
1594-
touched by the indices. The make sure that the correct
1595-
values are still assigned, the value is (effectively)
1596-
reversed along the appropriate dimensions at compute time.
1597-
1598-
Parameters
1599-
----------
1600-
shape : sequence of ints
1601-
The shape of the global array.
1602-
indices : tuple
1603-
The given indices for assignment.
1604-
1605-
Returns
1606-
-------
1607-
parsed_indices : list
1608-
The reformated indices that are equivalent to the
1609-
input indices.
1610-
mirror : list
1611-
The dimensions that need to be reversed in the
1612-
assigment value, prior to assignment.
1613-
1614-
Examples
1615-
--------
1616-
>>> parse_indices((8,), (slice(1, -1),))
1617-
(slice(1, 7, 1)] [])
1618-
>>> parse_indices((8,), ([1, 2, 4, 6],))
1619-
(array([1, 2, 4, 6]), [])
1620-
>>> parse_indices((8,), (slice(-1, 2, -1),))
1621-
(slice(3, 8, 1)] [0])
1622-
>>> parse_indices((8,), ([6, 4, 2, 1],))
1623-
(array([1, 2, 4, 6]), [0])
1585+
The aim of this is to convert the indices to a
1586+
standardised form so that it is easier a) to check them
1587+
for validity, and b) to ascertain which chunks are touched
1588+
by the indices.
1589+
1590+
Note that indices which are decreasing (such as
1591+
``slice(7, 3, -1)`` and ``[6, 2, 1]``) are recast as
1592+
increasing indices (``slice(4, 8, 1)`` and ``[1, 2, 6]``
1593+
respectively) to facilitate finding which chunks are
1594+
touched by the indices. To make sure that the correct
1595+
values are still assigned, the value is (effectively)
1596+
reversed along the appropriate dimensions at compute time.
1597+
1598+
Parameters
1599+
----------
1600+
shape : sequence of ints
1601+
The shape of the global array.
1602+
indices : tuple
1603+
The given indices for assignment.
1604+
1605+
Returns
1606+
-------
1607+
parsed_indices : list
1608+
The reformatted indices that are equivalent to the
1609+
input indices.
1610+
indices_shape : list
1611+
The shape of the parsed indices. E.g. indices of
1612+
``(slice(0,2), 5, [4,2,1])`` will have shape ``[2,3]``.
1613+
mirror : list
1614+
The dimensions that need to be reversed in the
1615+
assignment value, prior to assignment.
1616+
1617+
Examples
1618+
--------
1619+
>>> parse_indices((8,), (slice(1, -1),))
1620+
(slice(1, 7, 1), [6], [])
1621+
>>> parse_indices((8,), ([1, 2, 4, 6],))
1622+
(array([1, 2, 4, 6]), [4], [])
1623+
>>> parse_indices((8,), (slice(-1, 2, -1),))
1624+
(slice(3, 8, 1), [5], [0])
1625+
>>> parse_indices((8,), ([6, 4, 2, 1],))
1626+
(array([1, 2, 4, 6]), [4], [0])
16241627
16251628
"""
16261629
parsed_indices = []
1630+
indices_shape = []
16271631
mirror = []
16281632

16291633
if not isinstance(indices, tuple):
@@ -1656,7 +1660,7 @@ def parse_indices(shape, indices):
16561660
parsed_indices.extend([slice(None)] * (ndim - len_parsed_indices))
16571661

16581662
if not ndim and parsed_indices:
1659-
raise IndexError("Scalar array can only be indexed with () or Ellipsis")
1663+
raise IndexError("too many indices for array")
16601664

16611665
n_lists = 0
16621666

@@ -1672,11 +1676,7 @@ def parse_indices(shape, indices):
16721676

16731677
elif isinstance(index, (int, np.integer)):
16741678
# Index is an integer
1675-
if index < 0:
1676-
index += size
1677-
1678-
index = slice(index, index + 1, 1)
1679-
is_slice = True
1679+
index = int(index)
16801680
else:
16811681
n_lists += 1
16821682
if n_lists > 1:
@@ -1799,53 +1799,26 @@ def parse_indices(shape, indices):
17991799

18001800
parsed_indices[i] = index
18011801

1802-
return parsed_indices, mirror
1803-
1804-
def size_of_index(index, size):
1805-
"""Return the number of elements resulting in applying an index to a dimension of given size.
1806-
1807-
Parameters
1808-
----------
1809-
index : slice or sequence of int
1810-
The index being applied to the sequence.
1811-
size : int, optional
1812-
The size of the dimension being indexed (ignored if
1813-
index is sequence of int).
1814-
1815-
Returns
1816-
-------
1817-
size : int
1818-
The length of the sequence resulting from applying the
1819-
index. May be zero.
1820-
1821-
Examples
1822-
--------
1823-
>>> size_of_index(slice(None, None, -2), 10)
1824-
5
1825-
>>> size_of_index([1, 4, 9], 10)
1826-
3
1827-
>>> size_of_index(slice(2, 2), 10)
1828-
0
1829-
>>> size_of_index(slice(4, 2), 10)
1830-
0
1831-
>>> size_of_index(slice(2, 4, -1), 10)
1832-
0
1833-
1834-
"""
1835-
if isinstance(index, slice):
1836-
# Index is a slice object
1837-
start, stop, step = index.indices(size)
1838-
div, mod = divmod(stop - start, step)
1839-
if div <= 0:
1840-
return 0
1841-
1842-
if mod != 0:
1843-
div += 1
1802+
# Find the shape of the indices. E.g. indices of
1803+
# (slice(0,2), 5, [4,2,1]) will have shape [2, 3]. Note
1804+
# that integer indices are not included in the shape.
1805+
for index in parsed_indices:
1806+
if isinstance(index, slice):
1807+
# Index is a slice object
1808+
start, stop, step = index.indices(size)
1809+
div, mod = divmod(stop - start, step)
1810+
if div <= 0:
1811+
indices_shape.append(0)
1812+
else:
1813+
if mod != 0:
1814+
div += 1
18441815

1845-
return div
1816+
indices_shape.append(div)
1817+
elif not isinstance(index, int):
1818+
# Index is a sequence of integers
1819+
indices_shape.append(len(index))
18461820

1847-
# Index is a sequence of integers
1848-
return len(index)
1821+
return parsed_indices, indices_shape, mirror
18491822

18501823
def setitem(
18511824
array,
@@ -1909,6 +1882,7 @@ def setitem(
19091882
for index, (loc0, loc1), size in zip(
19101883
indices, array_location, block_info[None]["chunk-shape"]
19111884
):
1885+
integer_index = isinstance(index, int)
19121886
if isinstance(index, slice):
19131887
# Index is a slice
19141888
stop = size
@@ -1941,6 +1915,14 @@ def setitem(
19411915
n_preceeding, rem = divmod(pre[1] - pre[0], step)
19421916
if rem:
19431917
n_preceeding += 1
1918+
elif integer_index:
1919+
# Index is an integer
1920+
if not loc0 <= index < loc1:
1921+
# This block does not overlap the index
1922+
overlaps = False
1923+
break
1924+
1925+
block_index = index - loc0
19441926
else:
19451927
# Index is a list of integers
19461928
block_index = [i - loc0 for i in index if loc0 <= i < loc1]
@@ -1977,8 +1959,9 @@ def setitem(
19771959
n_preceeding = sum(1 for i in index if i < loc0)
19781960

19791961
block_indices.append(block_index)
1980-
subset_shape.append(block_index_size)
1981-
preceeding_size.append(n_preceeding)
1962+
if not integer_index:
1963+
preceeding_size.append(n_preceeding)
1964+
subset_shape.append(block_index_size)
19821965

19831966
if not overlaps:
19841967
# This block does not overlap the indices, so return
@@ -2063,13 +2046,9 @@ def setitem(
20632046

20642047
# Still here? Then parse the indices from 'key' and apply the
20652048
# assignment via map_blocks
2066-
self_shape = self.shape
20672049

20682050
# Reformat input indices
2069-
indices, mirror = parse_indices(self_shape, key)
2070-
2071-
# Find the shape implied by the indices
2072-
indices_shape = list(map(size_of_index, indices, self_shape))
2051+
indices, indices_shape, mirror = parse_indices(self.shape, key)
20732052

20742053
# Cast 'value' as a dask array
20752054
if value is np.ma.masked:
@@ -2118,7 +2097,7 @@ def setitem(
21182097
# Note that self_common_shape and value_common_shape may be
21192098
# different if any size 1 dimensions are being
21202099
# broadcast.
2121-
offset = self.ndim - value.ndim
2100+
offset = len(indices_shape) - value.ndim
21222101
if offset >= 0:
21232102
# self has the same number or more dimensions than 'value'
21242103
self_common_shape = indices_shape[offset:]
@@ -2130,11 +2109,12 @@ def setitem(
21302109
# 'value' has more dimensions than self
21312110
value_offset = -offset
21322111
if value_shape[:value_offset] != [1] * value_offset:
2133-
# Can only 'allow' value to have more dimensions then
2134-
# self if the extra trailing dimensions all have size
2112+
# Can only allow 'value' to have more dimensions than
2113+
# self if the extra leading dimensions all have size
21352114
# 1.
21362115
raise ValueError(
2137-
f"Can't broadcast shape {value_shape} across shape {self_shape}"
2116+
"could not broadcast input array from shape"
2117+
f"{value_shape} into shape {tuple(indices_shape)}"
21382118
)
21392119

21402120
offset = 0

dask/array/tests/test_array_core.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3615,6 +3615,8 @@ def test_setitem_extended_API():
36153615
x = np.ma.arange(60).reshape((6, 10))
36163616
dx = da.from_array(x.copy(), chunks=(2, 2))
36173617

3618+
x[:, 2] = range(6)
3619+
x[3, :] = range(10)
36183620
x[::2, ::-1] = -1
36193621
x[1::2] = -2
36203622
x[:, [3, 5, 6]] = -3
@@ -3624,14 +3626,19 @@ def test_setitem_extended_API():
36243626
x[x % 2 == 0] = -8
36253627
x[[4, 3, 1]] = -9
36263628
x[5, ...] = -10
3629+
x[..., 4] = -11
36273630
x[2:4, 5:1:-2] = -x[:2, 4:1:-2]
36283631
x[:2, :3] = [[1, 2, 3]]
36293632
x[1, 1:7:2] = np.ma.masked
36303633
x[0, 1:3] = -x[0, 4:2:-1]
36313634
x[...] = x
36323635
x[...] = x[...]
36333636
x[0] = x[-1]
3637+
x[0, :] = x[-2, :]
3638+
x[:, 1] = x[:, -3]
36343639

3640+
dx[:, 2] = range(6)
3641+
dx[3, :] = range(10)
36353642
dx[::2, ::-1] = -1
36363643
dx[1::2] = -2
36373644
dx[:, [3, 5, 6]] = -3
@@ -3641,13 +3648,15 @@ def test_setitem_extended_API():
36413648
dx[dx % 2 == 0] = -8
36423649
dx[[4, 3, 1]] = -9
36433650
dx[5, ...] = -10
3651+
dx[..., 4] = -11
36443652
dx[2:4, 5:1:-2] = -dx[:2, 4:1:-2]
36453653
dx[:2, :3] = [[1, 2, 3]]
36463654
dx[1, 1:7:2] = np.ma.masked
36473655
dx[0, 1:3] = -dx[0, 4:2:-1]
36483656
dx[...] = dx
36493657
dx[...] = dx[...]
3650-
dx[0] = dx[-1]
3658+
dx[0, :] = dx[-2, :]
3659+
dx[:, 1] = dx[:, -3]
36513660

36523661
assert_eq(x, dx.compute())
36533662
assert_eq(x.mask, da.ma.getmaskarray(dx))

0 commit comments

Comments
 (0)