77from itertools import product
88from numbers import Integral , Number
99from operator import itemgetter
10+ from typing import TYPE_CHECKING
1011
1112import numpy as np
1213from tlz import concat , memoize , merge , pluck
1718from dask .highlevelgraph import HighLevelGraph
1819from dask .utils import cached_cumsum , is_arraylike
1920
21+ if TYPE_CHECKING :
22+ from dask .array import Array
23+
2024colon = slice (None , None , None )
2125
2226
@@ -2164,7 +2168,23 @@ def setitem(x, v, indices):
21642168 return x
21652169
21662170
2167- def take_along_axis (arr , indices , axis ):
2171+ def take_along_axis (arr : Array , indices : Array , axis : int ):
2172+ """Slice a dask ndarray according to dask ndarray of indices along an axis.
2173+
2174+ Parameters
2175+ ----------
2176+ arr: dask.array.Array, dtype=Any
2177+ Data array.
2178+ indices: dask.array.Array, dtype=int64
2179+ Indices of interest.
2180+ axis:int
2181+ The axis along which the indices are from.
2182+
2183+ Returns
2184+ -------
2185+ out: dask.array.Array
2186+ The indexed arr.
2187+ """
21682188 from dask .array .core import Array , blockwise , from_array
21692189
21702190 if axis < 0 :
@@ -2179,29 +2199,30 @@ def take_along_axis(arr, indices, axis):
21792199 # e.g. chunks=(..., (5, 3, 4), ...) -> offset=[0, 5, 8]
21802200 offset = np .roll (np .cumsum (arr .chunks [axis ]), 1 )
21812201 offset [0 ] = 0
2182- offset = from_array (offset , chunks = 1 )
2202+ da_offset = from_array (offset , chunks = 1 )
21832203 # Tamper with the declared chunks of offset to make blockwise align it with
2184- # x[axis]
2185- offset = Array (offset .dask , offset .name , (arr .chunks [axis ],), offset .dtype )
2204+ # arr[axis]
2205+ da_offset = Array (
2206+ da_offset .dask , da_offset .name , (arr .chunks [axis ],), da_offset .dtype
2207+ )
21862208 # Define axis labels for blockwise
2187- x_axes = tuple (range (arr .ndim ))
2209+ arr_axes = tuple (range (arr .ndim ))
21882210 idx_label = (arr .ndim ,) # arbitrary unused
2189- index_axes = x_axes [:axis ] + idx_label + x_axes [axis + 1 :]
2211+ index_axes = arr_axes [:axis ] + idx_label + arr_axes [axis + 1 :]
21902212 offset_axes = (axis ,)
2191- p_axes = x_axes [: axis + 1 ] + idx_label + x_axes [axis + 1 :]
2192- # Calculate the cartesian product of every chunk of x vs
2193- # every chunk of index
2213+ p_axes = arr_axes [: axis + 1 ] + idx_label + arr_axes [axis + 1 :]
2214+ # Compute take_along_axis for each chunk
2215+ # TODO: Add meta argument for blockwise ?
21942216 p = blockwise (
21952217 take_along_axis_chunk ,
21962218 p_axes ,
21972219 arr ,
2198- x_axes ,
2220+ arr_axes ,
21992221 indices ,
22002222 index_axes ,
2201- offset ,
2223+ da_offset ,
22022224 offset_axes ,
2203- # align_arrays=False,
2204- x_size = arr .shape [axis ],
2225+ arr_size = arr .shape [axis ],
22052226 axis = axis ,
22062227 dtype = arr .dtype ,
22072228 )
0 commit comments