Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 08528c5

Browse files
authored
[numpy] add op median (#17084)
* part * wrapper * sanity
1 parent 56e7985 commit 08528c5

File tree

7 files changed

+189
-4
lines changed

7 files changed

+189
-4
lines changed

3rdparty/mkldnn

python/mxnet/ndarray/numpy/_op.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'matmul',
3737
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram',
3838
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'all', 'any', 'sort',
39-
'tensordot', 'eye', 'linspace',
39+
'tensordot', 'eye', 'linspace', 'median',
4040
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'hsplit', 'vsplit', 'dsplit',
4141
'concatenate', 'append', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
4242
'average', 'mean', 'maximum', 'fmax', 'minimum', 'fmin', 'around', 'round', 'round_', 'flatnonzero',
@@ -6923,6 +6923,55 @@ def percentile(a, q, axis=None, out=None, overwrite_input=None, interpolation='l
69236923
return _api_internal.percentile(a, q, axis, interpolation, keepdims, out)
69246924

69256925

6926+
@set_module('mxnet.ndarray.numpy')
6927+
def median(a, axis=None, out=None, overwrite_input=None, keepdims=False):
6928+
r"""
6929+
Compute the median along the specified axis.
6930+
Returns the median of the array elements.
6931+
Parameters
6932+
----------
6933+
a : array_like
6934+
Input array or object that can be converted to an array.
6935+
axis : {int, sequence of int, None}, optional
6936+
Axis or axes along which the medians are computed. The default
6937+
is to compute the median along a flattened version of the array.
6938+
A sequence of axes is supported since version 1.9.0.
6939+
out : ndarray, optional
6940+
Alternative output array in which to place the result. It must
6941+
have the same shape and buffer length as the expected output,
6942+
but the type (of the output) will be cast if necessary.
6943+
keepdims : bool, optional
6944+
If this is set to True, the axes which are reduced are left
6945+
in the result as dimensions with size one. With this option,
6946+
the result will broadcast correctly against the original `arr`.
6947+
Returns
6948+
-------
6949+
median : ndarray
6950+
A new array holding the result. If the input contains integers
6951+
or floats smaller than ``float32``, then the output data-type is
6952+
``np.float32``. Otherwise, the data-type of the output is the
6953+
same as that of the input. If `out` is specified, that array is
6954+
returned instead.
6955+
See Also
6956+
--------
6957+
mean, percentile
6958+
Examples
6959+
--------
6960+
>>> a = np.array([[10, 7, 4], [3, 2, 1]])
6961+
>>> a
6962+
array([[10, 7, 4],
6963+
[ 3, 2, 1]])
6964+
>>> np.median(a)
6965+
3.5
6966+
>>> np.median(a, axis=0)
6967+
array([6.5, 4.5, 2.5])
6968+
>>> np.median(a, axis=1)
6969+
array([7., 2.])
6970+
"""
6971+
return quantile(a=a, q=0.5, axis=axis, out=out, overwrite_input=overwrite_input,
6972+
interpolation='midpoint', keepdims=keepdims)
6973+
6974+
69266975
@set_module('mxnet.ndarray.numpy')
69276976
def quantile(a, q, axis=None, out=None, overwrite_input=None, interpolation='linear', keepdims=False): # pylint: disable=too-many-arguments
69286977
"""

python/mxnet/numpy/multiarray.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
from . import fallback
5353

5454

55-
__all__ = ['ndarray', 'empty', 'empty_like', 'array', 'shape',
55+
__all__ = ['ndarray', 'empty', 'empty_like', 'array', 'shape', 'median',
5656
'zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like', 'all', 'any', 'broadcast_to',
5757
'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'fmod', 'power', 'bitwise_not',
5858
'delete',
@@ -8981,6 +8981,55 @@ def percentile(a, q, axis=None, out=None, overwrite_input=None, interpolation='l
89818981
interpolation=interpolation, keepdims=keepdims)
89828982

89838983

8984+
@set_module('mxnet.numpy')
8985+
def median(a, axis=None, out=None, overwrite_input=None, keepdims=False):
8986+
r"""
8987+
Compute the median along the specified axis.
8988+
Returns the median of the array elements.
8989+
Parameters
8990+
----------
8991+
a : array_like
8992+
Input array or object that can be converted to an array.
8993+
axis : {int, sequence of int, None}, optional
8994+
Axis or axes along which the medians are computed. The default
8995+
is to compute the median along a flattened version of the array.
8996+
A sequence of axes is supported since version 1.9.0.
8997+
out : ndarray, optional
8998+
Alternative output array in which to place the result. It must
8999+
have the same shape and buffer length as the expected output,
9000+
but the type (of the output) will be cast if necessary.
9001+
keepdims : bool, optional
9002+
If this is set to True, the axes which are reduced are left
9003+
in the result as dimensions with size one. With this option,
9004+
the result will broadcast correctly against the original `arr`.
9005+
Returns
9006+
-------
9007+
median : ndarray
9008+
A new array holding the result. If the input contains integers
9009+
or floats smaller than ``float32``, then the output data-type is
9010+
``np.float32``. Otherwise, the data-type of the output is the
9011+
same as that of the input. If `out` is specified, that array is
9012+
returned instead.
9013+
See Also
9014+
--------
9015+
mean, percentile
9016+
Examples
9017+
--------
9018+
>>> a = np.array([[10, 7, 4], [3, 2, 1]])
9019+
>>> a
9020+
array([[10, 7, 4],
9021+
[ 3, 2, 1]])
9022+
>>> np.median(a)
9023+
3.5
9024+
>>> np.median(a, axis=0)
9025+
array([6.5, 4.5, 2.5])
9026+
>>> np.median(a, axis=1)
9027+
array([7., 2.])
9028+
"""
9029+
return _mx_nd_np.median(a, axis=axis, overwrite_input=overwrite_input,
9030+
keepdims=keepdims, out=out)
9031+
9032+
89849033
@set_module('mxnet.numpy')
89859034
def quantile(a, q, axis=None, out=None, overwrite_input=None, interpolation='linear', keepdims=False): # pylint: disable=too-many-arguments
89869035
"""

python/mxnet/numpy_dispatch_protocol.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
174174
'shares_memory',
175175
'may_share_memory',
176176
'quantile',
177+
'median',
177178
'percentile',
178179
'diff',
179180
'ediff1d',

python/mxnet/symbol/numpy/_symbol.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
'delete', 'add', 'broadcast_to', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'fmod',
4040
'power', 'arctan2',
4141
'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'fabs', 'exp',
42-
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'matmul',
42+
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'matmul', 'median',
4343
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram', 'insert',
4444
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'sort', 'tensordot', 'eye', 'linspace',
4545
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'hsplit', 'vsplit', 'dsplit',
@@ -6163,6 +6163,43 @@ def percentile(a, q, axis=None, out=None, overwrite_input=None, interpolation='l
61636163
keepdims=keepdims, q_scalar=None, out=out)
61646164

61656165

6166+
@set_module('mxnet.symbol.numpy')
6167+
def median(a, axis=None, out=None, overwrite_input=None, keepdims=False):
6168+
r"""
6169+
Compute the median along the specified axis.
6170+
Returns the median of the array elements.
6171+
Parameters
6172+
----------
6173+
a : _Symbol
6174+
Input array or object that can be converted to an array.
6175+
axis : {int, sequence of int, None}, optional
6176+
Axis or axes along which the medians are computed. The default
6177+
is to compute the median along a flattened version of the array.
6178+
A sequence of axes is supported since version 1.9.0.
6179+
out : _Symbol, optional
6180+
Alternative output array in which to place the result. It must
6181+
have the same shape and buffer length as the expected output,
6182+
but the type (of the output) will be cast if necessary.
6183+
keepdims : bool, optional
6184+
If this is set to True, the axes which are reduced are left
6185+
in the result as dimensions with size one. With this option,
6186+
the result will broadcast correctly against the original `arr`.
6187+
Returns
6188+
-------
6189+
median : _Symbol
6190+
A new array holding the result. If the input contains integers
6191+
or floats smaller than ``float32``, then the output data-type is
6192+
``np.float32``. Otherwise, the data-type of the output is the
6193+
same as that of the input. If `out` is specified, that array is
6194+
returned instead.
6195+
See Also
6196+
--------
6197+
mean, percentile
6198+
"""
6199+
return quantile(a=a, q=0.5, axis=axis, out=out, overwrite_input=overwrite_input,
6200+
interpolation='midpoint', keepdims=keepdims)
6201+
6202+
61666203
@set_module('mxnet.symbol.numpy')
61676204
def quantile(a, q, axis=None, out=None, overwrite_input=None, interpolation='linear', keepdims=False): # pylint: disable=too-many-arguments
61686205
"""

tests/python/unittest/test_numpy_interoperability.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,14 @@ def _add_workload_diagonal():
178178
OpArgMngr.add_workload('diagonal', B, 0, 2, 1)
179179

180180

181+
def _add_workload_median(array_pool):
182+
OpArgMngr.add_workload('median', array_pool['4x1'])
183+
OpArgMngr.add_workload('median', array_pool['4x1'], axis=0, keepdims=True)
184+
OpArgMngr.add_workload('median', np.array([[1, 2, 3], [4, 5, 6]]))
185+
OpArgMngr.add_workload('median', np.array([[1, 2, 3], [4, 5, 6]]), axis=0)
186+
OpArgMngr.add_workload('median', np.array([[1, 2, 3], [4, 5, 6]]), axis=1)
187+
188+
181189
def _add_workload_quantile():
182190
x1 = np.arange(8) * 0.5
183191
x2 = np.arange(100.)
@@ -2915,6 +2923,7 @@ def _prepare_workloads():
29152923
_add_workload_diff()
29162924
_add_workload_ediff1d()
29172925
_add_workload_quantile()
2926+
_add_workload_median(array_pool)
29182927
_add_workload_percentile()
29192928
_add_workload_resize()
29202929
_add_workload_full_like(array_pool)

tests/python/unittest/test_numpy_op.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7388,6 +7388,46 @@ def test_np_share_memory():
73887388
assert not op(np.ones((5, 0), dtype=dt), np.ones((0, 3, 0), dtype=adt))
73897389

73907390

7391+
def test_np_median():
7392+
class TestMedian(HybridBlock):
7393+
def __init__(self, axis=None, keepdims=False):
7394+
super(TestMedian, self).__init__()
7395+
self._axis = axis
7396+
self._keepdims = keepdims
7397+
7398+
def hybrid_forward(self, F, a):
7399+
return F.np.median(a, axis=self._axis, keepdims=self._keepdims)
7400+
7401+
flags = [True, False]
7402+
dtypes = ['float16', 'float32', 'float64']
7403+
qtypes = ['float32', 'float64']
7404+
tensor_shapes = [
7405+
((2, 3), None),
7406+
((2, 3, 4, 5), 3),
7407+
((2, 3, 4), (0, 2)),
7408+
((2, 3, 4), 1)
7409+
]
7410+
7411+
for hybridize, keepdims, (a_shape, axis), dtype in \
7412+
itertools.product(flags, flags, tensor_shapes, dtypes):
7413+
atol = 3e-4 if dtype == 'float16' else 1e-4
7414+
rtol = 3e-2 if dtype == 'float16' else 1e-2
7415+
test_median = TestMedian(axis=axis, keepdims=keepdims)
7416+
if hybridize:
7417+
test_median.hybridize()
7418+
a = np.random.uniform(-1.0, 1.0, size=a_shape)
7419+
np_out = _np.median(a.asnumpy(), axis=axis, keepdims=keepdims)
7420+
mx_out = test_median(a)
7421+
7422+
assert mx_out.shape == np_out.shape
7423+
assert_almost_equal(mx_out.asnumpy(), np_out, atol=atol, rtol=rtol)
7424+
7425+
mx_out = np.median(a, axis=axis, keepdims=keepdims)
7426+
np_out = _np.median(a.asnumpy(), axis=axis, keepdims=keepdims)
7427+
7428+
assert_almost_equal(mx_out.asnumpy(), np_out, atol=atol, rtol=rtol)
7429+
7430+
73917431
@with_seed()
73927432
@use_np
73937433
def test_np_quantile():

0 commit comments

Comments
 (0)