
Add relative and absolute tolerances for matrix_rank, pinv #54151

@IvanYashchuk

Description

🚀 Feature

Both torch.linalg.matrix_rank and torch.linalg.pinv calculate the singular values of the provided matrix and truncate them based on the specified tolerance (the argument is called rcond for torch.linalg.pinv and tol for torch.linalg.matrix_rank).
The currently implemented behavior for setting the tolerance, and the default tolerance values, follow NumPy.
However, NumPy is not consistent in its default values, nor in whether the provided tolerance is treated as relative or absolute.

|             | default for the argument | default tolerance     | truncation criterion (if default) | truncation criterion (if specified) |
|-------------|--------------------------|-----------------------|-----------------------------------|-------------------------------------|
| matrix_rank | None                     | eps * max(rows, cols) | tol * max(singular_values)        | tol                                 |
| pinv        | 1e-15                    | 1e-15                 | rcond * max(singular_values)      | rcond * max(singular_values)        |
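
For illustration, here is a small NumPy snippet (the matrix and the 1e-3 value are arbitrary examples, not part of the proposal) showing the inconsistency: the same number acts as an absolute cutoff in matrix_rank but as a relative one in pinv.

import numpy as np

A = np.diag([1000.0, 0.5])  # singular values: 1000 and 0.5

# matrix_rank treats an explicit tol as an absolute cutoff:
# both singular values exceed 1e-3, so the computed rank is 2.
np.linalg.matrix_rank(A, tol=1e-3)   # -> 2

# pinv treats rcond as a relative cutoff (rcond * largest singular value = 1),
# so the 0.5 singular value is zeroed out as if A were rank-deficient.
np.linalg.pinv(A, rcond=1e-3)        # -> diag([0.001, 0.0]) instead of diag([0.001, 2.0])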

The proposal is to implement a unified way to specify absolute and relative tolerances for the truncation of singular values, as follows:

def matrix_rank_or_pinv(input, *, atol=0, rtol=default_rtol):
    ...
    singular_values = ...  # compute singular values of input
    truncation_criteria = max(atol, rtol * max(singular_values))
    truncated_singular_values = singular_values > truncation_criteria  # True for values above the cutoff (kept)
    ...

Possible choices of default_rtol:

  • NumPy uses eps * max(rows, cols) for matrix_rank and 1e-15 for pinv
  • TensorFlow uses the same default as NumPy for matrix_rank but 10 * eps * max(rows, cols) for pinv
  • JAX uses the same defaults as TensorFlow
  • Julia uses eps * min(rows, cols) both for pinv and matrix_rank

Using max(atol, rtol * ...) to define the truncation criterion follows math.isclose, which combines its absolute and relative tolerances in the same way.
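
To make the proposal concrete, here is a minimal runnable sketch of the rule above; the helper name is hypothetical, torch.linalg.svdvals is assumed to be available, and the default rtol is just one of the candidates listed above (NumPy's eps * max(rows, cols)).

import torch

def _rank_with_tolerances(input, *, atol=0.0, rtol=None):
    # Hypothetical helper, not an existing PyTorch API: counts singular values
    # kept under the proposed max(atol, rtol * sigma_max) truncation rule.
    if rtol is None:
        rtol = torch.finfo(input.dtype).eps * max(input.shape[-2:])  # one candidate default_rtol
    singular_values = torch.linalg.svdvals(input)
    cutoff = max(atol, rtol * singular_values.max().item())
    return int((singular_values > cutoff).sum())

A = torch.diag(torch.tensor([1000.0, 0.5]))
_rank_with_tolerances(A)              # 2: default tolerances keep both singular values
_rank_with_tolerances(A, atol=1.0)    # 1: absolute cutoff truncates the 0.5 singular value
_rank_with_tolerances(A, rtol=1e-2)   # 1: relative cutoff 1e-2 * 1000 = 10 truncates it too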

Backwards compatibility / NumPy compatibility:

def matrix_rank(input, tol=None):
    if tol is None:
        # NumPy's default: a relative tolerance of eps * max(rows, cols)
        return matrix_rank(input, atol=0, rtol=eps * max(rows, cols))
    else:
        # an explicit tol is an absolute cutoff
        return matrix_rank(input, atol=tol, rtol=0)

def pinv(input, rcond=1e-15):
    # rcond has always been a relative tolerance, so it maps to rtol
    return pinv(input, atol=0, rtol=rcond)
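
As a quick, purely illustrative sanity check that the matrix_rank mapping above reproduces NumPy's current behavior (an explicit tol is already an absolute cutoff):

import numpy as np

A = np.diag([1000.0, 0.5])
s = np.linalg.svd(A, compute_uv=False)  # singular values

tol = 1e-3
atol, rtol = tol, 0                     # the proposed mapping for an explicit tol
cutoff = max(atol, rtol * s.max())
assert np.linalg.matrix_rank(A, tol=tol) == int((s > cutoff).sum())  # both give 2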

cc @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @walterddr @IvanYashchuk @rgommers

Labels

enhancement, module: linear algebra, module: numpy, triaged
