
Error in np.mean with where keyword (added in version 1.20) #18552

@mtsokol

Description


Hi!
While porting the latest API changes introduced in 1.20 to the JAX library, I encountered an error when using np.mean(..., where=mask).
I've also linked a Colab reproduction below. Is this expected behavior?

Thank you for any help!

Reproducing code example:

import numpy as np
print(np.__version__)  # 1.20.x (reported against 1.20.1)

a = np.random.randn(2, 3, 4)

# np.sum correctly sums along axis=2; the 4-element 'where' mask broadcasts over the last axis
np.sum(a, axis=2, keepdims=False, where=[False, True, False, True])

# but np.mean with the 'where' keyword (new in 1.20) raises ValueError
np.mean(a, axis=2, keepdims=False, where=[False, True, False, True])

A reproduction is also available in Colab: https://colab.research.google.com/drive/1KFvp4BjMwO27iNVT4R9O3IvVfwMSw0nB?usp=sharing
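Until this is resolved, one possible workaround is to build the masked mean from np.sum, which already accepts where= (as the first call above shows). This is a minimal sketch, not the fix itself; masked_mean is a hypothetical helper, not a NumPy function. It divides the masked sum by the number of selected elements along the same axis.

```python
import numpy as np


def masked_mean(a, mask, axis):
    """Hypothetical helper: mean of `a` along `axis` over elements where `mask` is True."""
    # Broadcast the mask to the full array shape so the per-slot count
    # matches exactly the elements that np.sum includes.
    mask = np.broadcast_to(mask, a.shape)
    total = np.sum(a, axis=axis, where=mask)  # masked-out elements are skipped
    count = np.sum(mask, axis=axis)           # number of selected elements per slot
    # A slot with zero selected elements would divide by zero (nan plus a RuntimeWarning).
    return total / count


a = np.random.randn(2, 3, 4)
mask = [False, True, False, True]
result = masked_mean(a, mask, axis=2)

# Cross-check: averaging positions 1 and 3 of the last axis directly.
expected = a[..., [1, 3]].mean(axis=2)
print(np.allclose(result, expected))  # True
```

Broadcasting the mask explicitly with np.broadcast_to keeps the denominator consistent with the numerator even when the mask is shorter than the array's rank.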

Error message:

Traceback (most recent call last):
  File "bug.py", line 10, in <module>
    np.mean(a, axis=2, keepdims=False, where=[False, True, False, True])
  File "<__array_function__ internals>", line 5, in mean
  File "/usr/local/lib/python3.8/site-packages/numpy/core/fromnumeric.py", line 3419, in mean
    return _methods._mean(a, axis=axis, dtype=dtype,
  File "/usr/local/lib/python3.8/site-packages/numpy/core/_methods.py", line 167, in _mean
    if rcount == 0 if where is True else umr_any(rcount == 0):
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
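The failing line in _methods.py hints at a likely cause. A minimal sketch, assuming (as the name suggests) that umr_any in numpy/core/_methods.py is np.logical_or.reduce, whose default axis is 0: with a 3-D input reduced over axis=2, the per-slot count has shape (2, 3), so reducing it over axis 0 alone still leaves a length-3 array, and the `if` cannot turn that into a single bool. The sketch below reproduces this with public APIs only; `rcount` is a local stand-in for the internal variable, not NumPy's own code.

```python
import numpy as np

a = np.random.randn(2, 3, 4)
mask = [False, True, False, True]

# Per-output-slot count of selected elements along axis 2; shape (2, 3).
rcount = np.sum(np.broadcast_to(mask, a.shape), axis=2)

# Assumption: umr_any behaves like np.logical_or.reduce, which reduces over axis 0 by default.
partial = np.logical_or.reduce(rcount == 0)
print(partial.shape)  # (3,): still an array

try:
    if partial:  # same situation as the `if ... umr_any(rcount == 0)` line in the traceback
        pass
except ValueError as exc:
    print("ValueError:", exc)

# Reducing over all axes yields a single boolean, which is safe in an `if`.
any_empty = np.logical_or.reduce(rcount == 0, axis=None)
print(bool(any_empty))  # False: every slot has 2 selected elements
```

If this reading is right, a 2-D input (where the count is 1-D and reduces to a scalar) would not hit the error, and reducing with axis=None would be one way to avoid it; that is a suggestion, not a confirmed patch.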

NumPy/Python version information:

1.20.1 3.8.8 (default, Feb 21 2021, 10:35:39)
[Clang 12.0.0 (clang-1200.0.32.29)]
