Description
🚀 Feature
Both `torch.linalg.matrix_rank` and `torch.linalg.pinv` compute the singular values of the input matrix and truncate them based on a tolerance (the argument is called `rcond` in `torch.linalg.pinv` and `tol` in `torch.linalg.matrix_rank`).
The current behavior for setting the tolerance, and the default tolerance values, follow NumPy.
However, NumPy is inconsistent both in its default values and in whether the provided tolerance is treated as relative or absolute:
| | default for the argument | default tolerance | truncation criterion (if default) | truncation criterion (if specified) |
|---|---|---|---|---|
| `matrix_rank` | `tol=None` | `eps * max(rows, cols)` | `eps * max(rows, cols) * max(singular_values)` | `tol` |
| `pinv` | `rcond=1e-15` | `1e-15` | `1e-15 * max(singular_values)` | `rcond * max(singular_values)` |
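The inconsistency can be observed directly in NumPy. The snippet below is a minimal sketch based on NumPy's documented semantics (details may vary across NumPy versions): an explicit `tol` in `np.linalg.matrix_rank` acts as an absolute threshold, so rescaling the matrix changes the result, while `rcond` in `np.linalg.pinv` is multiplied by the largest singular value, so rescaling does not change which singular values are dropped.

```python
import numpy as np

# Singular values of A are 1.0 and 1e-3; scaling by 100 gives 100.0 and 0.1.
A = np.diag([1.0, 1e-3])
B = 100 * A

# matrix_rank: an explicit tol is an absolute threshold, so scaling changes the rank.
rank_A = np.linalg.matrix_rank(A, tol=1e-2)  # 1e-3 is not > 1e-2, dropped
rank_B = np.linalg.matrix_rank(B, tol=1e-2)  # 0.1 > 1e-2, kept

# pinv: rcond is relative to the largest singular value.
pinv_A = np.linalg.pinv(A, rcond=1e-2)  # cutoff = 1e-2 * 1.0   = 1e-2
pinv_B = np.linalg.pinv(B, rcond=1e-2)  # cutoff = 1e-2 * 100.0 = 1.0

print(rank_A, rank_B)  # 1 2
# In both pinv results only the largest singular value was inverted.
print(np.linalg.matrix_rank(pinv_A), np.linalg.matrix_rank(pinv_B))  # 1 1
```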
The proposal is to implement a unified way to specify absolute or relative tolerances for truncating singular values, as follows:
```python
def matrix_rank_or_pinv(input, *, atol=0, rtol=default_rtol):
    ...
    singular_values = ...  # compute singular values of input
    truncation_criteria = max(atol, rtol * max(singular_values))
    keep_mask = singular_values > truncation_criteria  # True for singular values that survive truncation
    ...
```

Possible choices of `default_rtol`:
- NumPy uses `eps * max(rows, cols)` for `matrix_rank` and `1e-15` for `pinv`.
- TensorFlow uses the same default as NumPy for `matrix_rank` but `10 * eps * max(rows, cols)` for `pinv`.
- JAX uses the same defaults as TensorFlow.
- Julia uses `eps * min(rows, cols)` for both `pinv` and `matrix_rank`.
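To get a feel for the magnitudes, here is a small sketch that evaluates the candidate defaults listed above for a `float32` matrix with 3 rows and 5 columns. The helper `default_rtols` and its dictionary keys are hypothetical names used only for illustration; they are not part of any library.

```python
import numpy as np

def default_rtols(rows, cols, dtype=np.float32):
    """Hypothetical helper: the default relative tolerances listed above,
    evaluated for one matrix shape and dtype."""
    eps = float(np.finfo(dtype).eps)
    return {
        "numpy matrix_rank": eps * max(rows, cols),
        "numpy pinv": 1e-15,
        "tensorflow/jax pinv": 10 * eps * max(rows, cols),
        "julia (both)": eps * min(rows, cols),
    }

for name, rtol in default_rtols(3, 5).items():
    print(f"{name:>20}: {rtol:.3e}")
```

For `float32`, the fixed `1e-15` default is roughly eight orders of magnitude below machine epsilon (about `1.19e-07`), so it is unlikely to discard singular values that are zero only up to rounding error.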
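Using `max(atol, rtol * ...)` to define the truncation criterion follows the convention of Python's `math.isclose`, which also compares against the larger of a relative and an absolute tolerance.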
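A minimal NumPy sketch of the proposed criterion, for real 2-D inputs only (no batching, no complex support). The names `matrix_rank_sketch` and `pinv_sketch` are hypothetical, and the default `rtol = eps * max(rows, cols)` is an assumption picked from the candidates above (NumPy's `matrix_rank` default); this is not PyTorch's implementation.

```python
import numpy as np

def _keep_mask(singular_values, atol, rtol):
    # Proposed criterion: keep sigma_i only if sigma_i > max(atol, rtol * sigma_max).
    sigma_max = singular_values.max() if singular_values.size else 0.0
    return singular_values > max(atol, rtol * sigma_max)

def _default_rtol(a):
    # Assumption: NumPy's matrix_rank default, eps * max(rows, cols).
    return np.finfo(a.dtype).eps * max(a.shape)

def matrix_rank_sketch(a, *, atol=0.0, rtol=None):
    a = np.asarray(a, dtype=float)  # real inputs only in this sketch
    if rtol is None:
        rtol = _default_rtol(a)
    s = np.linalg.svd(a, compute_uv=False)
    return int(np.count_nonzero(_keep_mask(s, atol, rtol)))

def pinv_sketch(a, *, atol=0.0, rtol=None):
    a = np.asarray(a, dtype=float)
    if rtol is None:
        rtol = _default_rtol(a)
    u, s, vh = np.linalg.svd(a, full_matrices=False)
    keep = _keep_mask(s, atol, rtol)
    s_inv = np.zeros_like(s)
    s_inv[keep] = 1.0 / s[keep]  # invert only the kept singular values
    # A^+ = V diag(s_inv) U^T
    return (vh.T * s_inv) @ u.T

A = np.diag([1.0, 1e-3])
print(matrix_rank_sketch(A))                   # 2 (default relative cutoff ~4.4e-16)
print(matrix_rank_sketch(A, atol=1e-2))        # 1 (absolute cutoff 1e-2)
print(matrix_rank_sketch(100 * A, rtol=1e-2))  # 1 (relative cutoff 1e-2 * 100 = 1.0)
print(matrix_rank_sketch(100 * A, atol=1e-2))  # 2 (0.1 > 1e-2)
```

Because the cutoff is `max(atol, rtol * sigma_max)`, setting `rtol=0` recovers NumPy's absolute `tol` for `matrix_rank`, and setting `atol=0` recovers the relative `rcond` of `pinv`.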
Backwards compatibility / NumPy compatibility:
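```python
# Existing NumPy-style signatures expressed through the proposed atol/rtol interface.
# eps is the machine epsilon of input's dtype; rows and cols are input's matrix dimensions.
def matrix_rank(input, tol=None):
    if tol is None:
        return matrix_rank(input, atol=0, rtol=eps * max(rows, cols))
    else:
        return matrix_rank(input, atol=tol, rtol=0)

def pinv(input, rcond=1e-15):
    return pinv(input, atol=0, rtol=rcond)
```

cc @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @walterddr @IvanYashchuk @rgommers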