Skip to content

Commit bd0f132

Browse files
committed
MAINT: PR 22840 revisions
* `logsumexp()` had some array API compatibility issues on the release branch related to the version of the array API standard and possibly other dependency versions that are older than on `main` branch. This patch works around these by guarding `xp.real()` to only apply to complex input and coercing a complex Python scalar to an array to deal with a promotion issue: `TypeError: Python complex scalars can only be promoted with complex floating-point arrays.`
1 parent 033b138 commit bd0f132

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

scipy/special/_logsumexp.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
xp_broadcast_promote,
88
xp_copy,
99
xp_float_to_complex,
10+
is_complex,
1011
)
1112
from scipy._lib import array_api_extra as xpx
1213

@@ -232,13 +233,14 @@ def _logsumexp(a, b, axis, return_sign, xp):
232233
m = xp.abs(m)
233234
else:
234235
# `a_max` can have a sign component for complex input
235-
sgn = sgn * xp.exp(xp.imag(a_max) * 1.0j)
236+
sgn = sgn * xp.exp(xp.imag(a_max) * xp.asarray(1.0j, dtype=a_max.dtype))
236237

237238
# Take log and undo shift
238239
out = xp.log1p(s) + xp.log(m) + a_max
239240

240241
if return_sign:
241-
out = xp.real(out)
242+
if is_complex(out, xp):
243+
out = xp.real(out)
242244
elif xp.isdtype(out.dtype, 'real floating'):
243245
out[sgn < 0] = xp.nan
244246

0 commit comments

Comments
 (0)