Skip to content

Commit 815649c

Browse files
Update _NPY_MAXDIMS (#7167)
* Update `_NPY_MAXDIMS` numpy 1.x supported a maximum of 32 dimensions and then raised an error .. this was an issue for large scale simulations with qubits up to 35 for state vector simulation and up to 18 qubits for density matrix simulations. As a workaround to support these cases we worked with 1D and 2D arrays for these cases which was very inefficient. now numpy 2.0 supports up to 64 dimensions * use the value from numpy instead of hardcoding it * use suggestion from @maffoo * Update cirq-core/cirq/linalg/transformations.py Co-authored-by: Pavol Juhas <[email protected]> * nit * Update transformations.py * Handle negative dimensions in can_numpy_support_shape And add a direct unit test. --------- Co-authored-by: Pavol Juhas <[email protected]>
1 parent b3561b4 commit 815649c

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

cirq-core/cirq/linalg/transformations.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Utility methods for transforming matrices or vectors."""
1616

1717
import dataclasses
18+
import functools
1819
from typing import Any, List, Optional, Sequence, Tuple, Union
1920

2021
import numpy as np
@@ -29,8 +30,6 @@
2930
# user provides a different np.array([]) value.
3031
RaiseValueErrorIfNotProvided: np.ndarray = np.array([])
3132

32-
_NPY_MAXDIMS = 32 # Should be changed once numpy/numpy#5744 is resolved.
33-
3433

3534
def reflection_matrix_pow(reflection_matrix: np.ndarray, exponent: float):
3635
"""Raises a matrix with two opposing eigenvalues to a power.
@@ -807,6 +806,15 @@ def transpose_flattened_array(t: np.ndarray, shape: Sequence[int], axes: Sequenc
807806
return ret
808807

809808

809+
@functools.cache
810+
def _can_numpy_support_dims(num_dims: int) -> bool:
811+
try:
812+
_ = np.empty((1,) * num_dims)
813+
return True
814+
except ValueError: # pragma: no cover
815+
return False
816+
817+
810818
def can_numpy_support_shape(shape: Sequence[int]) -> bool:
811819
"""Returns whether numpy supports the given shape or not numpy/numpy#5744."""
812-
return len(shape) <= _NPY_MAXDIMS
820+
return min(shape, default=0) >= 0 and _can_numpy_support_dims(len(shape))

cirq-core/cirq/linalg/transformations_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,3 +648,8 @@ def test_transpose_flattened_array(num_dimensions):
648648
assert np.array_equal(want, got)
649649
got = linalg.transpose_flattened_array(A.reshape(shape), shape, axes).reshape(want.shape)
650650
assert np.array_equal(want, got)
651+
652+
653+
@pytest.mark.parametrize('shape, result', [((), True), (30 * (1,), True), ((-3, 1, -1), False)])
654+
def test_can_numpy_support_shape(shape: tuple[int, ...], result: bool) -> None:
655+
assert linalg.can_numpy_support_shape(shape) is result

0 commit comments

Comments
 (0)