Skip to content

Commit 202788b

Browse files
Circumvent bad type annotation for unitary_group.rvs
See scipy/scipy-stubs#987
1 parent a7cd499 commit 202788b

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

graphix/random_objects.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,20 +34,20 @@ def rand_herm(sz: IntLike, rng: Generator | None = None) -> npt.NDArray[np.compl
3434
return tmp + tmp.conj().T
3535

3636

37-
def rand_unit(sz: IntLike, rng: Generator | None = None) -> npt.NDArray[np.float64]:
37+
def rand_unit(sz: IntLike, rng: Generator | None = None) -> npt.NDArray[np.complex128]:
3838
"""Generate haar random unitary matrix of size sz*sz."""
3939
rng = ensure_rng(rng)
4040
if sz == 1:
4141
return np.array([np.exp(1j * rng.random(size=1) * 2 * np.pi)])
42-
# unitary_group.rvs returns onp.Array3D[np.float64]
43-
# https://github.com/scipy/scipy-stubs/blob/3b629159e8da5cc3aa82b871135489d6d2fd5f8e/scipy-stubs/stats/_multivariate.pyi#L370
44-
return unitary_group.rvs(sz, random_state=rng)
42+
# unitary_group.rvs is currently annotated onp.Array3D[np.float64] in scipy-stubs
43+
# See https://github.com/scipy/scipy-stubs/issues/987
44+
return unitary_group.rvs(sz, random_state=rng).astype(np.complex128, copy=False)
4545

4646

4747
UNITS = np.array([1, 1j])
4848

4949

50-
def rand_dm(dim: IntLike, rng: Generator | None = None, rank: IntLike | None = None) -> npt.NDArray[np.float64]:
50+
def rand_dm(dim: IntLike, rng: Generator | None = None, rank: IntLike | None = None) -> npt.NDArray[np.complex128]:
5151
"""Generate random density matrices (positive semi-definite matrices with unit trace).
5252
5353
Returns either a :class:`graphix.sim.density_matrix.DensityMatrix` or a :class:`np.ndarray` depending on the parameter *dm_dtype*.

0 commit comments

Comments
 (0)