🚀 Feature Discussion
The question is whether solve-like functions need to support batch broadcasting for b of shape (n,).
PyTorch currently includes several functions for problems of the form: find x such that ||Ax - b|| is minimized (or Ax = b is solved exactly). Let's call these "solve-like functions":
torch.linalg.solve
torch.solve
torch.cholesky_solve
torch.triangular_solve
torch.lu_solve
torch.lstsq
For A with shape (n, n), only torch.linalg.solve accepts a b input of shape (n,) or (n, nrhs); the other functions require b to be a 2-dimensional tensor of shape (n, nrhs).
Supporting 1-dimensional b inputs is NumPy- and SciPy-compatible. SciPy doesn't support batched inputs; NumPy supports batched inputs for numpy.linalg.solve, but not for numpy.linalg.lstsq.
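To illustrate the NumPy-compatible behavior for a square A, here is a short sketch (NumPy only; the array values are arbitrary) showing both accepted shapes of b:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3))

# 1-dimensional right-hand side of shape (n,): the solution is also 1-D
b_vec = rng.standard_normal(3)
x_vec = np.linalg.solve(A, b_vec)
assert x_vec.shape == (3,)

# 2-dimensional right-hand side of shape (n, nrhs): one solve per column
b_mat = rng.standard_normal((3, 2))
x_mat = np.linalg.solve(A, b_mat)
assert x_mat.shape == (3, 2)
```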
numpy.linalg.solve supports batch-wise broadcasting only for b inputs of shape (n, nrhs):
```python
import torch
import numpy as np

a = torch.randn(2, 3, 1, 3, 3)
b = torch.randn(3)
np.linalg.solve(a, b)                # doesn't work
np.linalg.solve(a, b.unsqueeze(-1))  # this one works
# both cases currently work for torch.linalg.solve

a = torch.randn(3, 3)
b = torch.randn(2, 3, 1, 3)
np.linalg.solve(a, b)                # doesn't work
np.linalg.solve(a, b.unsqueeze(-1))  # this one works
# torch.linalg.solve currently has the same behaviour
```
NumPy's behaviour makes sense because `(a.inverse() @ b).shape = torch.Size([2, 3, 1, 3])` and batched matrix multiplication doesn't work for `a @ (a.inverse() @ b)`, but works for `a @ (a.inverse() @ b).unsqueeze(-1)`.
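The shape argument above can be checked directly with matmul broadcasting (a NumPy-only sketch of the same shapes; the arrays are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((3, 3))
b = rng.standard_normal((2, 3, 1, 3))

# matmul treats the last two dimensions as matrix dimensions:
# (3, 3) @ (..., 1, 3) fails because A's columns (3) != b's rows (1)
try:
    a @ b
except ValueError:
    pass  # shapes are not aligned for batched matrix multiplication

# a trailing axis turns b into a batch of (3, 1) column vectors,
# so (3, 3) @ (2, 3, 1, 3, 1) broadcasts to shape (2, 3, 1, 3, 1)
x = a @ b[..., None]
assert x.shape == (2, 3, 1, 3, 1)
```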
For NumPy compatibility, we need to support batch broadcasting for b of shape (n, nrhs) in torch.linalg.solve and, consequently, apply the same behavior to all solve-like functions.
Do solve-like functions need to support batch broadcasting for b of shape (n,)?
The problem here is the ambiguity in deciding whether b is a matrix or a vector. For example, for A of shape (3, 3, 3), how should we interpret b of shape (3, 3): as a single matrix input to be batch-broadcast, or as a batch of vectors?
Currently torch.linalg.solve treats the matrix case as primary, and b is regarded as a vector if `b.ndim == 1 or ((A.ndim - b.ndim == 1) and (A.shape[:-1] == b.shape))`. This rule is compatible with NumPy, see https://github.com/numpy/numpy/blob/v1.20.0/numpy/linalg/linalg.py#L384-L389.
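The disambiguation rule can be sketched as a small shape-only helper (an illustration of the rule quoted above, not PyTorch's actual implementation; the function name is made up):

```python
def is_vector_rhs(a_shape, b_shape):
    """Return True if b should be treated as a (batch of) vector(s).

    Mirrors the rule quoted above: b is a vector case if it is 1-D,
    or if it has exactly one dimension fewer than A and matches
    A's shape with the last dimension dropped.
    """
    a_ndim, b_ndim = len(a_shape), len(b_shape)
    return b_ndim == 1 or (a_ndim - b_ndim == 1 and a_shape[:-1] == b_shape)

assert is_vector_rhs((3, 3), (3,))           # plain vector b
assert is_vector_rhs((3, 3, 3), (3, 3))      # matches A.shape[:-1]: batch of vectors
assert not is_vector_rhs((3, 3, 3), (3, 2))  # a matrix to be batch-broadcast
```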
Interestingly NumPy fails for this case:
```python
a = torch.randn(3, 3, 3)
b = torch.randn(3)
np.linalg.solve(a, b)                # doesn't work
np.linalg.solve(a, b.unsqueeze(-1))  # also doesn't work
# both cases work with torch.linalg.solve
```
Additional context
Memory inefficiency of the actual implementation is discussed here #49252.
cc @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @walterddr @IvanYashchuk