Error in np.mean with where keyword (added in 1.20 version) #18552
Closed
Description
Hi!
While porting the latest API changes introduced in NumPy 1.20 to the jax library, I encountered an error when calling np.mean(..., where=mask).
A Colab reproduction is linked below. Is this expected behavior?
Thank you for any help!
Reproducing code example:
import numpy as np
print(np.__version__) # should be `1.20`
a = np.random.randn(2,3,4)
# correctly computes sums along axis=2 with 'where' mask
np.sum(a, axis=2, keepdims=False, where=[False, True, False, True])
# but computing means with newly added 'where' keyword fails
np.mean(a, axis=2, keepdims=False, where=[False, True, False, True])

Here's also a reproduction in Colab: https://colab.research.google.com/drive/1KFvp4BjMwO27iNVT4R9O3IvVfwMSw0nB?usp=sharing
Error message:
Traceback (most recent call last):
File "bug.py", line 10, in <module>
np.mean(a, axis=2, keepdims=False, where=[False, True, False, True])
File "<__array_function__ internals>", line 5, in mean
File "/usr/local/lib/python3.8/site-packages/numpy/core/fromnumeric.py", line 3419, in mean
return _methods._mean(a, axis=axis, dtype=dtype,
File "/usr/local/lib/python3.8/site-packages/numpy/core/_methods.py", line 167, in _mean
if rcount == 0 if where is True else umr_any(rcount == 0):
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
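My guess at the cause, as a sketch rather than a confirmed diagnosis: the traceback points at umr_any(rcount == 0), and I'm assuming umr_any is np.logical_or.reduce, whose default axis is 0. If rcount is the per-slice count of selected elements, it has shape (2, 3) here, so the reduction would return a length-3 array instead of a scalar. The names counts, is_empty, reduced, and reduced_all below are mine, used only to illustrate this.

```python
import numpy as np

a = np.random.randn(2, 3, 4)
where = [False, True, False, True]

# Number of selected elements per output slice; shape (2, 3), every entry is 2.
counts = np.sum(np.broadcast_to(where, a.shape), axis=2)
is_empty = counts == 0

# ufunc.reduce defaults to axis=0, so a 2-D input yields a 1-D result.
reduced = np.logical_or.reduce(is_empty)
print(reduced.shape)  # (3,)

# Using that array in an `if` raises the same ValueError as in the traceback.
try:
    if reduced:
        pass
except ValueError as exc:
    print("ValueError:", exc)

# Reducing over all axes gives a single boolean that is safe to test.
reduced_all = np.logical_or.reduce(is_empty, axis=None)
print(bool(reduced_all))  # False
```

If this reading is right, the failure should only appear when the result of the masked reduction has more than one dimension.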
NumPy/Python version information:
1.20.1 3.8.8 (default, Feb 21 2021, 10:35:39)
[Clang 12.0.0 (clang-1200.0.32.29)]
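In the meantime, a possible workaround, sketched under the assumption that dividing the masked sum by the number of selected elements matches what np.mean(..., where=mask) is meant to return. masked_mean is a hypothetical helper of mine, not a NumPy function, and it relies only on np.sum(..., where=...), which works in the example above.

```python
import numpy as np


def masked_mean(a, mask, axis):
    """Hypothetical helper: mean over `axis` using only elements where `mask` is True."""
    a = np.asarray(a)
    # Broadcast the mask to the full shape so the count matches the sum.
    mask = np.broadcast_to(np.asarray(mask, dtype=bool), a.shape)
    total = np.sum(a, axis=axis, where=mask)
    count = np.sum(mask, axis=axis)
    # Note: a slice with no selected elements divides by zero and yields nan.
    return total / count


rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 4))

result = masked_mean(a, [False, True, False, True], axis=2)

# Cross-check against plain indexing of the selected positions.
expected = a[..., [1, 3]].mean(axis=2)
print(np.allclose(result, expected))
```

This sketch does not handle dtype, out, or keepdims, so it is not a drop-in replacement for np.mean.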