-
-
Notifications
You must be signed in to change notification settings - Fork 12.2k
ENH: Implement take_along_axis as described in #8708 #8714
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,10 +16,219 @@ | |
| __all__ = [ | ||
| 'column_stack', 'row_stack', 'dstack', 'array_split', 'split', | ||
| 'hsplit', 'vsplit', 'dsplit', 'apply_over_axes', 'expand_dims', | ||
| 'apply_along_axis', 'kron', 'tile', 'get_array_wrap' | ||
| 'apply_along_axis', 'kron', 'tile', 'get_array_wrap', 'take_along_axis', | ||
| 'put_along_axis' | ||
| ] | ||
|
|
||
|
|
||
| def _make_along_axis_idx(arr, indices, axis): | ||
| # compute dimensions to iterate over | ||
| shape_ones = (1,) * indices.ndim | ||
| ins_ndim = indices.ndim - (arr.ndim - 1) # inserted dimensions | ||
| if ins_ndim < 0: | ||
| raise ValueError("`indices` must have ndim >= arr.ndim - 1") | ||
| dest_dims = list(range(axis)) + [None] + list(range(axis+ins_ndim, indices.ndim)) | ||
|
|
||
| # build a fancy index, consisting of orthogonal aranges, with the | ||
| # requested index inserted at the right location | ||
| fancy_index = [] | ||
| for dim, n in zip(dest_dims, arr.shape): | ||
| if dim is None: | ||
| fancy_index.append(indices) | ||
| else: | ||
| ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim+1:] | ||
| fancy_index.append(_nx.arange(n).reshape(ind_shape)) | ||
|
|
||
| return tuple(fancy_index) | ||
|
|
||
|
|
||
| def take_along_axis(arr, indices, axis): | ||
| """ | ||
| Take the elements described by `indices` along each 1-D slice of the given | ||
| `axis`, matching up subspaces of arr and indices. | ||
|
|
||
| This function can be used to index with the result of `argsort`, `argmax`, | ||
| and other `arg` functions. | ||
|
|
||
| This is equivalent to (but faster than) the following use of `ndindex` and | ||
| `s_`, which sets each of ``ii``, ``jj``, and ``kk`` to a tuple of indices:: | ||
|
|
||
| # extract the subshapes as labelled in the docs below | ||
| Ni, Nk = a.shape[:axis], a.shape[axis+1:] | ||
| Nj = indices.shape[len(Ni):indices.ndim - len(Nk)] | ||
|
|
||
| for ii in ndindex(Ni): | ||
| for jj in ndindex(Nj): | ||
| for kk in ndindex(Nk): | ||
| a_1d = a[ii + s_[:,] + kk] | ||
| out[ii + jj + kk] = a_1d[indices[ii + jj + kk]] | ||
|
|
||
| Equivalently, eliminating the inner loop, this can be expressed as:: | ||
|
|
||
| Ni, Nk = a.shape[:axis], a.shape[axis+1:] | ||
| for ii in ndindex(Ni): | ||
| for kk in ndindex(Nk): | ||
| a_1d = a[ii + s_[:,] + kk] | ||
| out[ii + s_[...,] + kk] = a_1d[indices[ii + s_[...,] + kk]] | ||
|
|
||
| .. versionadded:: 1.15.0 | ||
|
|
||
| Parameters | ||
| ---------- | ||
| arr: array_like (Ni..., M, Nk...) | ||
| source array | ||
| indices: array_like (Ni..., Nj..., Nk...) | ||
| indices to take along each 1d slice of `arr` | ||
| axis: int | ||
| the axis to take 1d slices along | ||
|
|
||
| Returns | ||
| ------- | ||
| out: ndarray (Ni..., Nj..., Nk...) | ||
| The indexed result, as described above. | ||
|
|
||
| See Also | ||
| -------- | ||
| take : Take along an axis without matching up subspaces | ||
|
|
||
| Examples | ||
| -------- | ||
|
|
||
| For this sample array | ||
|
|
||
| >>> a = np.array([[10, 30, 20], [60, 40, 50]]) | ||
|
|
||
| We can sort either by using sort directly, or argsort and this function | ||
|
|
||
| >>> np.sort(a, axis=1) | ||
| array([[10, 20, 30], | ||
| [40, 50, 60]]) | ||
| >>> ai = np.argsort(a, axis=1) | ||
| >>> ai | ||
| array([[0, 2, 1], | ||
| [1, 2, 0]], dtype=int64) | ||
| >>> np.take_along_axis(a, ai, axis=1) | ||
| array([[10, 20, 30], | ||
| [40, 50, 60]]) | ||
|
|
||
| The same works for max and min: | ||
|
|
||
| >>> np.max(a, axis=1) | ||
| array([30, 60]) | ||
| >>> ai = np.argmax(a, axis=1) | ||
| >>> ai | ||
| array([1, 0], dtype=int64) | ||
| >>> np.take_along_axis(a, ai, axis=1) | ||
| array([30, 60]) | ||
|
|
||
| If we want to get the max and min at the same time, we can stack the | ||
| indices first | ||
|
|
||
| >>> ai_min = np.argmin(a, axis=1) | ||
| >>> ai_max = np.argmax(a, axis=1) | ||
| >>> ai = np.stack([ai_min, ai_max], axis=axis) | ||
| >>> ai | ||
| array([[0, 1], | ||
| [1, 0]], dtype=int64) | ||
| >>> np.take_along_axis(a, ai, axis=1) | ||
| array([[10, 30], | ||
| [40, 60]]) | ||
| """ | ||
| # normalize inputs | ||
| arr = asanyarray(arr) | ||
| indices = asanyarray(indices) | ||
| if axis is None: | ||
|
||
| arr = arr.ravel() | ||
| axis = 0 | ||
| else: | ||
| axis = normalize_axis_index(axis, arr.ndim) | ||
| if not _nx.issubdtype(indices.dtype, _nx.integer): | ||
| raise IndexError('arrays used as indices must be of integer type') | ||
|
||
|
|
||
| # use the fancy index | ||
| return arr[_make_along_axis_idx(arr, indices, axis)] | ||
|
|
||
|
|
||
| def put_along_axis(arr, indices, values, axis): | ||
|
||
| """ | ||
| Put `values` at the elements described by `indices` along each 1-D slice of | ||
| the given `axis` of `arr`, matching up subspaces of arr and indices. | ||
|
|
||
| This function can be used to index with the result of `argsort`, `argmax`, | ||
| and other `arg` functions. | ||
|
|
||
| This is equivalent to (but faster than) the following use of `ndindex` and | ||
| `s_`, which sets each of ``ii``, ``jj``, and ``kk`` to a tuple of indices:: | ||
|
|
||
| # extract the subshapes as labelled in the docs below | ||
| Ni, Nk = a.shape[:axis], a.shape[axis+1:] | ||
| Nj = indices.shape[len(Ni):indices.ndim - len(Nk)] | ||
|
|
||
| for ii in ndindex(Ni): | ||
| for jj in ndindex(Nj): | ||
| for kk in ndindex(Nk): | ||
| a_1d = a[ii + s_[:,] + kk] | ||
| a_1d[indices[ii + jj + kk]] = values[ii + jj + kk] | ||
|
|
||
| Equivalently, eliminating the inner loop, this can be expressed as:: | ||
|
|
||
| Ni, Nk = a.shape[:axis], a.shape[axis+1:] | ||
| for ii in ndindex(Ni): | ||
| for kk in ndindex(Nk): | ||
| a_1d = a[ii + s_[:,] + kk] | ||
| a_1d[indices[ii + s_[...,] + kk]] = values[ii + s_[...,] + kk] | ||
|
|
||
| .. versionadded:: 1.15.0 | ||
|
|
||
| Parameters | ||
| ---------- | ||
| arr: array_like (Ni..., M, Nk...) | ||
| source array | ||
| indices: array_like (Ni..., Nj..., Nk...) | ||
| indices to change along each 1d slice of `arr` | ||
| values: array_like (Ni..., Nj..., Nk...) | ||
| values to insert at those indices. Its shape and dimension are | ||
| broadcast to match that of `indices`. | ||
| axis: int | ||
| the axis to take 1d slices along | ||
|
|
||
| See Also | ||
| -------- | ||
| take_along_axis : Take along an axis without matching up subspaces | ||
|
|
||
| Examples | ||
| -------- | ||
|
|
||
| For this sample array | ||
|
|
||
| >>> a = np.array([[10, 30, 20], [60, 40, 50]]) | ||
|
|
||
| We can replace the maximum values with: | ||
|
|
||
| >>> ai = np.expand_dims(np.argmax(a, axis=1), axis=1) | ||
| >>> ai | ||
| array([[1], | ||
| [0]], dtype=int64) | ||
| >>> np.put_along_axis(a, ai, 99, axis=1) | ||
| >>> a | ||
| array([[10, 99, 20], | ||
| [99, 40, 50]]) | ||
|
|
||
| """ | ||
| # normalize inputs | ||
| indices = asanyarray(indices) | ||
| if axis is None: | ||
| arr = arr.ravel() | ||
| axis = 0 | ||
| else: | ||
| axis = normalize_axis_index(axis, arr.ndim) | ||
| if not _nx.issubdtype(indices.dtype, _nx.integer): | ||
| raise IndexError('arrays used as indices must be of integer type') | ||
|
|
||
| # use the fancy index | ||
| arr[_make_along_axis_idx(arr, indices, axis)] = values | ||
|
|
||
|
|
||
| def apply_along_axis(func1d, axis, arr, *args, **kwargs): | ||
| """ | ||
| Apply a function to 1-D slices along the given axis. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
default to axis=-1 like 99% of other numpy functions?