Currently we have a few functions on Tensor with multiple overloads. This makes parsing, documentation, and error messages more complicated.
These overloads get ambiguous with zero-dim tensors (scalars) because they can bind to either the "float" or "Tensor" overloads.
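For illustration, this is the kind of call that is ambiguous today (a sketch; the exact binding depends on the argument parser):

import torch

x = torch.randn(3, 4)
s = torch.tensor(1)  # zero-dim tensor (a scalar)

# `s` could bind either to `dim` in max(input, dim, keepdim=False)
# or to `other` in the element-wise max(input, other) overload.
torch.max(x, s)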
We should combine the Tensor/scalar overloads and move optional arguments to the end of the function (as keyword-only args). The old signatures will be deprecated (issue a warning) and removed from the documentation and error messages.
Functions that accept a Tensor argument will also accept Python numbers in its place. The numbers will automatically be promoted to zero-dim tensors.
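A rough sketch of what a combined signature could look like, using add as an example (the body is illustrative, not the actual ATen implementation):

import numbers
import torch

def add(input, other, *, alpha=1):
    # Promote Python numbers to zero-dim tensors so a single Tensor
    # overload covers both the Tensor and scalar cases.
    if isinstance(other, numbers.Number):
        other = torch.tensor(other)
    # `alpha` is keyword-only; the old positional form would warn instead.
    return input + alpha * other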
torch.max currently has three overloads:
torch.max(input) -> Tensor
torch.max(input, dim, keepdim=False) -> Tensor, LongTensor
torch.max(input, other) -> Tensor
We should combine the first two overloads and make the third overload (element-wise max) a separate function (fmax). Eventually, max over a dimension should only return the max elements (not the indices) unless return_indices=True. For backwards compatibility, if return_indices is unspecified and a dim is given, it defaults to True and a deprecation warning is issued:
import warnings
import torch

def max(input, dim=None, keepdim=False, return_indices=None):
    if isinstance(dim, torch.Tensor):
        # old element-wise overload max(input, other): raise deprecation warning
        warnings.warn("torch.max(input, other) is deprecated; use fmax(input, other)")
        return fmax(input, dim)
    if return_indices is None:
        if dim is not None:
            # raise deprecation warning about return_indices defaulting to True
            warnings.warn("torch.max with dim returns indices by default; "
                          "pass return_indices explicitly")
        return_indices = dim is not None
    # dispatch to ATen implementation
    ...

def fmax(input, other):
    # element-wise maximum, preferring the non-NaN element
    ...

We should also do the same for torch.min.
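Under the proposed signatures, typical calls would look roughly like this (fmax is the proposed name, not an existing function; behavior follows the sketch above):

x = torch.randn(3, 4)
y = torch.randn(3, 4)

torch.max(x)                               # max over all elements -> Tensor
torch.max(x, dim=1)                        # warns; still returns (values, indices) for now
torch.max(x, dim=1, return_indices=True)   # explicit: (values, indices)
torch.max(x, dim=1, return_indices=False)  # values only (the eventual default)
fmax(x, y)                                 # element-wise maximum (was torch.max(x, y))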
We have already updated add, addmm, addbmm, addcmul, sub.