
ENH: linalg: enable N-D batch support in special matrix functions#21446

Merged
j-bowhay merged 10 commits into scipy:main from mdhaber:batch_special_matrices
Nov 9, 2024

Conversation

@mdhaber
Contributor

@mdhaber mdhaber commented Aug 24, 2024

Reference issue

#16090 (comment)

What does this implement/fix?

Begins to add N-dimensional batch support to the special matrix functions. This PR adds documentation, tests, and a simple implementation for API consistency. More performant implementations can come in follow-up PRs.
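For context, a hypothetical sketch of what batch support means here (not the PR's implementation): a function like companion maps an (..., n) array to an (..., n-1, n-1) stack, applying the 2-D definition over the leading dimensions. A minimal NumPy-only version might look like:

```python
import numpy as np

def batched_companion(a):
    """Hypothetical batched companion: loop the 2-D definition over the
    leading (batch) dimensions. `a` has shape (..., n); the result has
    shape (..., n-1, n-1). A sketch of the semantics, not SciPy's code.
    """
    a = np.asarray(a, dtype=float)
    *batch_shape, n = a.shape
    out = np.zeros((*batch_shape, n - 1, n - 1), dtype=a.dtype)
    for index in np.ndindex(*batch_shape):
        c = a[index]
        out[index][0, :] = -c[1:] / c[0]     # first row from coefficients
        out[index][1:, :-1] = np.eye(n - 2)  # ones on the subdiagonal
    return out

A = np.arange(1., 13.).reshape(3, 4)  # batch of three cubic polynomials
print(batched_companion(A).shape)     # -> (3, 3, 3)
```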

Additional information

gh-16090 will add support for companion. Done!
gh-21419 will add support for circulant.
A follow-up will add support for leslie and toeplitz. These accept two arguments of different shapes, which is easy to handle with existing code; originally I planned to move the machinery from scipy.stats.axis_nan_policy to _lib, but I decided to leave that code where it is and just import it. The import happens the first time the function is called, to avoid import cycles. This is temporary: an immediate follow-up PR can either move the machinery to _lib (avoiding the circular import) or eliminate the import by adding batch support directly.

I left out toeplitz because it documents that it ravels arrays of any number of dimensions, so we can't add batch support until after a cycle of warnings for >1 dimensional input. Which of the following is preferred?
1. Emit a FutureWarning that the behavior will change in SciPy 1.17; to ensure consistent behavior, the user must ravel input before passing it into toeplitz.
2. Emit a DeprecationWarning about support for >1 dimensional input. Raise an error in SciPy 1.17, and add N-D batch support after that.
Update: toeplitz gets a FutureWarning (option 1), since it documents that it ravels N-D input. In SciPy 1.17 we can add N-D batch support, as we do here for leslie.

I went ahead and added the code inside each function rather than using a decorator because I was going to have to edit the documentation of each function anyway. If we could leave the documentation alone (perhaps just adding a standard note about batch support), I might do all of this with a decorator instead.

Here is a wrapper we can use for most other linalg functions:

from functools import wraps
import numpy as np
from scipy import linalg

def _apply_over_batch(*argdefs, result_packer=None):
    names, ndims = list(zip(*argdefs))
    n_arrays = len(names)
    result_packer = result_packer if result_packer is not None else lambda res: res[0]

    def wrapper(f):
        @wraps(f)
        def wrapped(*args, **kwargs):
            args = list(args)

            # Separate the array arguments (`arrays`) from the rest
            # (`other_args`/`kwargs`); collect arrays passed by keyword
            arrays, other_args = args[:n_arrays], args[n_arrays:]
            n_positional = len(arrays)
            for i, name in enumerate(names):
                if name in kwargs:
                    if i < n_positional:
                        raise ValueError(f'{f.__name__}() got multiple values for argument `{name}`.')
                    arrays.append(kwargs.pop(name))

            # Determine batch shape
            batch_shapes = []
            for i, (array, ndim) in enumerate(zip(arrays, ndims)):
                array = np.asarray(array)
                arrays[i] = array
                batch_shapes.append(array.shape[:-ndim] if ndim > 0 else array.shape)

            if not any(batch_shapes):
                return f(*arrays, *other_args, **kwargs)

            batch_shape = np.broadcast_shapes(*batch_shapes)  # Gives an OK error message

            # Main loop
            results = []
            for index in np.ndindex(batch_shape):
                result = f(*(array[index] for array in arrays), *other_args, **kwargs)
                # Distinguish between single output and multiple outputs.
                # Could improve this by defining the wrapper with that information.
                result = (result,) if not isinstance(result, tuple) else result
                results.append(result)
            results = list(zip(*results))

            # Reshape results
            for i, result in enumerate(results):
                result = np.stack(result)
                core_shape = result.shape[1:]
                results[i] = np.reshape(result, batch_shape + core_shape)

            return result_packer(results)

        return wrapped
    return wrapper

rng = np.random.default_rng(13578923452)
shape = (3, 4, 5, 5)
A = rng.random(shape)
AT = np.swapaxes(A, -1, -2)
A = A + AT + np.eye(shape[-1])*shape[-1]

B = rng.random(shape)
BT = np.swapaxes(B, -1, -2)
B = B + BT + np.eye(shape[-1])*shape[-1]

# f = _apply_over_batch(('a', 2), ('b', 2), result_packer=tuple)(linalg.eig)
f = _apply_over_batch(('a', 2))(linalg.inv)

res = f(A)

@mdhaber mdhaber added the enhancement A new feature or improvement label Aug 24, 2024
@mdhaber mdhaber requested review from ilayn and larsoner as code owners August 24, 2024 18:15
@rgommers
Member

#16090 (comment)

It may be useful to move that into a new status/tracking issue, since it's a fairly important discussion that starts halfway in another PR now. Looks like all participants there are on board with the plan (and I am too, full batching support would be a clear win and very nice to have in my opinion), so this is more a "make it more discoverable" rather than a "possibly controversial" request for an issue.

I left out toeplitz because it documents that it ravels arrays of any number of dimensions, so we can't add batch support until after a cycle of warnings for >1 dimensional input. Which of the following is preferred?

1. Emit a `FutureWarning` that the behavior will change in SciPy 1.17; to ensure consistent behavior, the user must `ravel` input before passing it into `toeplitz`.

2. Emit a `DeprecationWarning` about support for >1 dimensional input. Raise an error in SciPy 1.17, and add N-D batch support after that.

I prefer option (1). It seems simpler, and the end result is the same. FutureWarning is also more appropriate than DeprecationWarning, and it bubbles up to end users by default, which is what you want here I think.
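The "bubbles up to end users by default" point is checkable: CPython's default warning filters ignore DeprecationWarning unless it is triggered in `__main__`, while FutureWarning has no default filter and is always shown. A sketch that rebuilds the default filters and triggers both warnings from a simulated library module:

```python
import warnings

# Code standing in for a library module (i.e., not __main__).
lib_code = (
    "import warnings\n"
    "warnings.warn('old behavior', DeprecationWarning)\n"
    "warnings.warn('behavior will change', FutureWarning)\n"
)

with warnings.catch_warnings(record=True) as caught:
    warnings.resetwarnings()
    # Rebuild CPython's default filter list. filterwarnings() prepends,
    # so the calls appear in reverse order of the final list.
    warnings.filterwarnings("ignore", category=ImportWarning)
    warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    warnings.filterwarnings("default", category=DeprecationWarning,
                            module="__main__")
    # The `module` used for filter matching comes from the frame's __name__.
    exec(compile(lib_code, "fake_lib.py", "exec"), {"__name__": "fake_lib"})

# Only the FutureWarning surfaces from library code under default filters.
print([w.category.__name__ for w in caught])  # -> ['FutureWarning']
```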

@mdhaber
Contributor Author

mdhaber commented Aug 26, 2024

@rgommers thanks, that's definitely my preference, too. I proposed the other alternative in part based on your comment in gh-19890.

FutureWarning's for behavioral changes are a problem, unless there's a new keyword or other such thing that users can switch to.

Just to check - the problem there was that the user could do nothing in the meantime to preserve the behavior without getting the FutureWarning?

And here, we don't have that problem - the user can get the same behavior without the warning simply by passing a raveled array, and they will not be bothered by the transition again.

I'll go ahead and add that warning. I will put up a tracking issue in a bit, but it will take a little while to put together something that covers more than just the special matrices. (The comment was only about special matrices, and those are getting taken care of pretty much all together.)

@rgommers
Member

Just to check - the problem there was that the user could do nothing in the meantime to preserve the behavior without getting the FutureWarning?

Yes indeed. In this case, the user can simply do what the warning suggests (ravel) and the problem goes away. Then a warning is nice. If they can't take such an action, it's much less nice.

@mdhaber
Contributor Author

mdhaber commented Aug 26, 2024

Right - it's just that a stricter BC policy (which I've seen referred to as the "Hinsen principle", but only in scientific-python/specs#180) would require that either the behavior stays the same or an error is raised. I wanted to be sure that wasn't the concern. (I have trouble imagining a use of FutureWarnings at all under such a strict policy.)

@rgommers
Member

[...] would require that either the behavior stays the same or an error is raised. I wanted to be sure that wasn't the concern

I don't think it is. There are always tradeoffs, and of course it's never great to change behavior (in general we do have to try harder to avoid that than changes that yield exceptions). However in this case:

  • The return shape changes, it's not like the answer of some computation only changes numerical value (which would be worse). So most likely the user will get an exception in their code if they miss the warning and upgrade to the new behavior, rather than a different answer.
  • The behavior only changes for non-idiomatic usage of the function.
  • There is no good alternative here; renaming the function doesn't seem like a reasonable option (much more disruptive, and we'd end up with a non-preferred name).

Comment thread scipy/linalg/_special_matrices.py Outdated
n = f.shape[-1]

if f.ndim > 1 or s.ndim > 1:
from scipy.stats._resampling import _vectorize_statistic
Contributor Author


We can eliminate this import by using the decorator from gh-21462.

Member


I assume we can now do this?

Contributor Author


Yeah I think so

@ilayn
Member

ilayn commented Aug 27, 2024

By the way, if we remove the 2D checks from them, cosm, sinm, tanm, coshm, sinhm, and tanhm are ready to accept N-D array inputs, since they are all arithmetic over expm.
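To illustrate the point: once expm handles batches (or is looped over them), the trigonometric matrix functions follow from arithmetic identities. A hypothetical batched cosm built on scipy.linalg.expm via cos(A) = (exp(iA) + exp(-iA)) / 2, looped here for illustration (a sketch, not SciPy's implementation):

```python
import numpy as np
from scipy.linalg import expm

def batched_cosm(A):
    """Hypothetical batched cosm via cos(A) = (exp(iA) + exp(-iA)) / 2,
    looping expm over the leading batch dimensions."""
    A = np.asarray(A)
    out = np.empty(A.shape, dtype=complex)
    for index in np.ndindex(A.shape[:-2]):
        out[index] = 0.5 * (expm(1j * A[index]) + expm(-1j * A[index]))
    return out.real if np.isrealobj(A) else out

# Check on a batch of diagonal matrices: cosm(diag(v)) == diag(cos(v)).
v = np.array([[0.1, 0.2], [0.3, 0.4]])
A = np.zeros((2, 2, 2))
A[:, [0, 1], [0, 1]] = v
print(np.allclose(batched_cosm(A)[..., [0, 1], [0, 1]], np.cos(v)))  # -> True
```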

Comment thread scipy/linalg/_special_matrices.py Outdated
Member

@j-bowhay j-bowhay left a comment


A few minor comments but otherwise looks in good shape

Comment thread scipy/linalg/_special_matrices.py Outdated
Comment thread scipy/linalg/_special_matrices.py Outdated
Comment thread scipy/linalg/_special_matrices.py Outdated
@j-bowhay
Member

j-bowhay commented Nov 9, 2024

Happy to look again once the comments above are resolved

@mdhaber
Contributor Author

mdhaber commented Nov 9, 2024

Sorry I missed that.

Comment thread scipy/linalg/_special_matrices.py Outdated
Member

@j-bowhay j-bowhay left a comment


Looks good to me!


Labels

enhancement A new feature or improvement scipy.linalg
