
Broadcasting behaviour for linear algebra solvers #52915

@IvanYashchuk


🚀 Feature Discussion

The question is whether solve-like functions need to support batch broadcasting for b of shape (n,).

PyTorch currently includes several functions for problems of the type: find x such that ||Ax - b|| is minimized (Ax = b). Let's call them "solve-like functions":

torch.linalg.solve
torch.solve
torch.cholesky_solve
torch.triangular_solve
torch.lu_solve
torch.lstsq

For A of shape (n, n), only torch.linalg.solve allows the b input to have shape (n,) or (n, nrhs); the other functions require b to be a 2-dimensional tensor of shape (n, nrhs).
Supporting 1-dimensional b inputs is NumPy- and SciPy-compatible. SciPy doesn't support batched inputs; NumPy supports batched inputs for numpy.linalg.solve, but not for numpy.linalg.lstsq.
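As a quick sanity check of the NumPy side of this claim (the shapes here are illustrative, not from the issue), numpy.linalg.solve broadcasts over leading batch dimensions while numpy.linalg.lstsq rejects anything but 2-D inputs:

```python
import numpy as np

n, nrhs = 3, 2
A = np.random.randn(2, n, n)      # batch of 2 square matrices
B = np.random.randn(2, n, nrhs)   # matching batch of right-hand sides

# numpy.linalg.solve broadcasts over the leading batch dimension
X = np.linalg.solve(A, B)
assert X.shape == (2, n, nrhs)

# numpy.linalg.lstsq only accepts 2-D inputs and raises on a batched A
try:
    np.linalg.lstsq(A, B, rcond=None)
    batched_lstsq_ok = True
except np.linalg.LinAlgError:
    batched_lstsq_ok = False
assert not batched_lstsq_ok
```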

numpy.linalg.solve supports batch-wise broadcasting only for b inputs of shape (n, nrhs):

import torch
import numpy as np
a = torch.randn(2, 3, 1, 3, 3)
b = torch.randn(3)
np.linalg.solve(a, b) # doesn't work
np.linalg.solve(a, b.unsqueeze(-1)) # this one works
# both cases work currently for torch.linalg.solve
a = torch.randn(3, 3)
b = torch.randn(2, 3, 1, 3)
np.linalg.solve(a, b) # doesn't work
np.linalg.solve(a, b.unsqueeze(-1)) # this one works
# torch.linalg.solve currently has the same behaviour

NumPy's behaviour makes sense because (a.inverse() @ b).shape = torch.Size([2, 3, 1, 3]) and batched matrix multiplication doesn't work for a @ (a.inverse() @ b), but works for a @ (a.inverse() @ b).unsqueeze(-1).
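The shape argument above can be checked directly in NumPy (using random arrays of the same shapes; x stands in for a.inverse() @ b): matmul treats the trailing pair of dimensions as the matrix, so stacked (1, 3) matrices cannot be multiplied by stacked (3, 3) matrices, while stacked (3, 1) column vectors can:

```python
import numpy as np

a = np.random.randn(2, 3, 1, 3, 3)
x = np.random.randn(2, 3, 1, 3)   # the shape of a.inverse() @ b above

# As a stack of matrices, x has matrix shape (1, 3); (3, 3) @ (1, 3) is invalid
try:
    np.matmul(a, x)
    matmul_ok = True
except ValueError:
    matmul_ok = False
assert not matmul_ok

# With a trailing singleton dimension, x becomes a stack of (3, 1) column vectors
y = np.matmul(a, x[..., np.newaxis])
assert y.shape == (2, 3, 1, 3, 1)
```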

For NumPy compatibility, we need to support batch broadcasting for b of shape (n, nrhs) in torch.linalg.solve and, for consistency, apply the same behaviour to all solve-like functions.

Do solve-like functions need to support batch broadcasting for b of shape (n,)?

The problem here is the ambiguity in deciding whether b is a matrix or a vector. For example, for A of shape (3, 3, 3), how should we interpret b of shape (3, 3): as a single matrix input to be batch-broadcast, or as a batch of vectors?

Currently torch.linalg.solve treats the matrix case as primary, and b is regarded as a vector if b.ndim == 1 or ((A.ndim - b.ndim == 1) and (A.shape[:-1] == b.shape)). This rule is compatible with NumPy, see https://github.com/numpy/numpy/blob/v1.20.0/numpy/linalg/linalg.py#L384-L389.
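The rule can be written out as a small shape-only helper (a hypothetical function, not part of PyTorch, mirroring the NumPy 1.x check linked above):

```python
def b_is_vector(a_shape, b_shape):
    """Return True if b should be treated as a (batch of) vector(s)."""
    a_ndim, b_ndim = len(a_shape), len(b_shape)
    # b.ndim == 1, or b has exactly one dimension fewer than A and its
    # shape equals A's shape with the last axis dropped
    return b_ndim == 1 or (a_ndim - b_ndim == 1 and a_shape[:-1] == b_shape)

# The ambiguous case from above: A of shape (3, 3, 3) with b of shape (3, 3)
# is resolved as a batch of vectors (3 - 2 == 1 and (3, 3) == (3, 3)) ...
assert b_is_vector((3, 3, 3), (3, 3)) is True
# ... while b of shape (3, 2) is a single matrix to be batch-broadcast
assert b_is_vector((3, 3, 3), (3, 2)) is False
# and a plain 1-D b is always a vector
assert b_is_vector((3, 3), (3,)) is True
```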

Interestingly, NumPy fails for this case:

a = torch.randn(3, 3, 3)
b = torch.randn(3)
np.linalg.solve(a, b) # doesn't work
np.linalg.solve(a, b.unsqueeze(-1)) # also doesn't work
# both cases work with torch.linalg.solve

Additional context

Memory inefficiency of the current implementation is discussed in #49252.

cc @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @walterddr @IvanYashchuk
