Skip to content

Commit 105ef10

Browse files
author
Abel Aoun
committed
DOC: Add documentation for take_along_axis
1 parent 06ae5c0 commit 105ef10

File tree

3 files changed

+63
-19
lines changed

3 files changed

+63
-19
lines changed

dask/array/chunk.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -433,16 +433,38 @@ def getitem(obj, index):
433433
return result
434434

435435

436-
def take_along_axis_chunk(arr, indices, offset, x_size, axis):
437-
# Needed when idx is unsigned
436+
def take_along_axis_chunk(
437+
arr: np.ndarray, indices: np.ndarray, offset: np.ndarray, arr_size: int, axis: int
438+
):
439+
"""Slice an ndarray according to ndarray indices along an axis.
440+
441+
Parameters
442+
----------
443+
arr: np.ndarray, dtype=Any
444+
The data array.
445+
indices: np.ndarray, dtype=int64
446+
The indices of interest.
447+
offset: np.ndarray, shape=(1, ), dtype=int64
448+
Index of the first element along axis of the current chunk of arr
449+
arr_size: int
450+
Total size along axis of the full dask Array of which arr is a chunk.
451+
axis: int
452+
The axis along which the indices select elements.
453+
454+
Returns
455+
-------
456+
out: np.ndarray
457+
The indexed arr.
458+
"""
459+
# Needed when indices is unsigned
438460
indices = indices.astype(np.int64)
439461
# Normalize negative indices
440-
indices = np.where(indices < 0, indices + x_size, indices)
462+
indices = np.where(indices < 0, indices + arr_size, indices)
441463
# A chunk of the offset dask Array is a numpy array with shape (1, ).
442464
# It indicates the index of the first element along axis of the current
443-
# chunk of x.
465+
# chunk of arr.
444466
indices = indices - offset
445-
# Drop elements of idx that do not fall inside the current chunk of x
467+
# Drop elements of indices that do not fall inside the current chunk of arr.
446468
idx_filter = (indices >= 0) & (indices < arr.shape[axis])
447469
indices[~idx_filter] = 0
448470
res = np.take_along_axis(arr, indices, axis=axis)

dask/array/slicing.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from itertools import product
88
from numbers import Integral, Number
99
from operator import itemgetter
10+
from typing import TYPE_CHECKING
1011

1112
import numpy as np
1213
from tlz import concat, memoize, merge, pluck
@@ -17,6 +18,9 @@
1718
from dask.highlevelgraph import HighLevelGraph
1819
from dask.utils import cached_cumsum, is_arraylike
1920

21+
if TYPE_CHECKING:
22+
from dask.array import Array
23+
2024
colon = slice(None, None, None)
2125

2226

@@ -2164,7 +2168,23 @@ def setitem(x, v, indices):
21642168
return x
21652169

21662170

2167-
def take_along_axis(arr, indices, axis):
2171+
def take_along_axis(arr: Array, indices: Array, axis: int):
2172+
"""Slice a dask ndarray according to a dask ndarray of indices along an axis.
2173+
2174+
Parameters
2175+
----------
2176+
arr: dask.array.Array, dtype=Any
2177+
Data array.
2178+
indices: dask.array.Array, dtype=int64
2179+
Indices of interest.
2180+
axis: int
2181+
The axis along which the indices select elements.
2182+
2183+
Returns
2184+
-------
2185+
out: dask.array.Array
2186+
The indexed arr.
2187+
"""
21682188
from dask.array.core import Array, blockwise, from_array
21692189

21702190
if axis < 0:
@@ -2179,29 +2199,30 @@ def take_along_axis(arr, indices, axis):
21792199
# e.g. chunks=(..., (5, 3, 4), ...) -> offset=[0, 5, 8]
21802200
offset = np.roll(np.cumsum(arr.chunks[axis]), 1)
21812201
offset[0] = 0
2182-
offset = from_array(offset, chunks=1)
2202+
da_offset = from_array(offset, chunks=1)
21832203
# Tamper with the declared chunks of offset to make blockwise align it with
2184-
# x[axis]
2185-
offset = Array(offset.dask, offset.name, (arr.chunks[axis],), offset.dtype)
2204+
# arr[axis]
2205+
da_offset = Array(
2206+
da_offset.dask, da_offset.name, (arr.chunks[axis],), da_offset.dtype
2207+
)
21862208
# Define axis labels for blockwise
2187-
x_axes = tuple(range(arr.ndim))
2209+
arr_axes = tuple(range(arr.ndim))
21882210
idx_label = (arr.ndim,) # arbitrary unused
2189-
index_axes = x_axes[:axis] + idx_label + x_axes[axis + 1 :]
2211+
index_axes = arr_axes[:axis] + idx_label + arr_axes[axis + 1 :]
21902212
offset_axes = (axis,)
2191-
p_axes = x_axes[: axis + 1] + idx_label + x_axes[axis + 1 :]
2192-
# Calculate the cartesian product of every chunk of x vs
2193-
# every chunk of index
2213+
p_axes = arr_axes[: axis + 1] + idx_label + arr_axes[axis + 1 :]
2214+
# Compute take_along_axis for each chunk
2215+
# TODO: Add meta argument for blockwise?
21942216
p = blockwise(
21952217
take_along_axis_chunk,
21962218
p_axes,
21972219
arr,
2198-
x_axes,
2220+
arr_axes,
21992221
indices,
22002222
index_axes,
2201-
offset,
2223+
da_offset,
22022224
offset_axes,
2203-
# align_arrays=False,
2204-
x_size=arr.shape[axis],
2225+
arr_size=arr.shape[axis],
22052226
axis=axis,
22062227
dtype=arr.dtype,
22072228
)

docs/source/array-slicing.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ supports the following:
1010
* Slicing one :class:`~dask.array.Array` with an :class:`~dask.array.Array` of bools: ``x[x > 0]``
1111
* Slicing one :class:`~dask.array.Array` with a zero or one-dimensional :class:`~dask.array.Array`
1212
of ints: ``a[b.argtopk(5)]``
13+
* Slicing one :class:`~dask.array.Array` with a multi-dimensional :class:`~dask.array.Array` of ints.
14+
This can be done using ``dask.array.slicing.take_along_axis``.
1315

1416
However, it does not currently support the following:
1517

@@ -19,7 +21,6 @@ However, it does not currently support the following:
1921
issue. Also, users interested in this should take a look at
2022
:attr:`~dask.array.Array.vindex`.
2123

22-
* Slicing one :class:`~dask.array.Array` with a multi-dimensional :class:`~dask.array.Array` of ints
2324

2425
.. _array.slicing.efficiency:
2526

0 commit comments

Comments
 (0)