Proposal: simplify overloaded Tensor function signatures #2739

@colesbury

Currently we have a few functions on Tensor with multiple overloads. This makes parsing, documentation, and error messages more complicated.

These overloads get ambiguous with zero-dim tensors (scalars) because they can bind to either the "float" or "Tensor" overloads.

We should combine the Tensor/scalar overloads and move optional arguments to the end of the function (as keyword-only args). The old signatures will be deprecated (issue a warning) and removed from the documentation and error messages.
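As an illustration of the deprecation path, here is a minimal sketch in plain Python, using floats in place of tensors. The function `add`, its `alpha` keyword, and the old positional order `add(input, alpha, other)` are assumptions chosen for the example, not the actual binding code.

```python
import warnings


def add(input, other, *args, alpha=1):
    """Hypothetical merged signature: add(input, other, *, alpha=1).

    The old positional form add(input, alpha, other) is still accepted,
    but it emits a DeprecationWarning and is rewritten to the new form.
    """
    if args:
        if len(args) > 1:
            raise TypeError("add() takes at most 3 positional arguments")
        warnings.warn(
            "add(input, alpha, other) is deprecated; "
            "use add(input, other, alpha=alpha)",
            DeprecationWarning,
            stacklevel=2,
        )
        # In the old form the second positional argument was the scale factor.
        alpha, other = other, args[0]
    return input + alpha * other


print(add(1.0, 2.0, alpha=3))  # new form, no warning
print(add(1.0, 3, 2.0))        # old form, same result plus a warning
```

Because `alpha` follows `*args`, it can only be passed by keyword, so the documented signature stays unambiguous while old call sites keep working during the deprecation period.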

Functions that accept a Tensor argument will also accept a Python number in its place. The number will automatically be promoted to a zero-dim tensor.
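The promotion rule can be sketched with a toy stand-in for a tensor. `ToyTensor`, `as_tensor_arg`, and `mul` are hypothetical names invented for this example; the point is only that a single signature handles a Python number and a zero-dim tensor through the same code path.

```python
from dataclasses import dataclass
from numbers import Number


@dataclass(frozen=True)
class ToyTensor:
    """Hypothetical stand-in for a tensor: flat data plus a shape tuple."""
    data: tuple
    shape: tuple

    @property
    def ndim(self):
        return len(self.shape)


def as_tensor_arg(value):
    """Promote a Python number to a zero-dim ToyTensor; pass tensors through."""
    if isinstance(value, ToyTensor):
        return value
    if isinstance(value, Number):
        return ToyTensor(data=(value,), shape=())
    raise TypeError(f"expected ToyTensor or number, got {type(value).__name__}")


def mul(input, other):
    """Single signature: either operand may be a ToyTensor or a Python number."""
    input, other = as_tensor_arg(input), as_tensor_arg(other)
    if input.ndim == 0:
        # Multiplication commutes, so keep the zero-dim operand on the right.
        input, other = other, input
    if other.ndim == 0:
        scale = other.data[0]
        return ToyTensor(tuple(x * scale for x in input.data), input.shape)
    if input.shape != other.shape:
        raise ValueError("shapes must match in this toy example")
    return ToyTensor(tuple(a * b for a, b in zip(input.data, other.data)), input.shape)


v = ToyTensor((1.0, 2.0, 3.0), (3,))
# A Python number and an explicit zero-dim tensor give the same result.
assert mul(v, 2) == mul(v, ToyTensor((2,), ()))
```

With promotion done up front, there is no separate "float" overload left for a zero-dim tensor to collide with.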

torch.max currently has three overloads:

torch.max(input) -> Tensor
torch.max(input, dim, keepdim=False) -> Tensor, LongTensor
torch.max(input, other) -> Tensor

We should combine the first two overloads and make the third overload (element-wise max) a separate function (fmax). Eventually, max should return only the max elements (not the indices) unless return_indices=True. For backwards compatibility, if return_indices is unspecified, it defaults to True when dim is given (with a deprecation warning) and to False otherwise:

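import warnings
import torch

def max(input, dim=None, keepdim=False, return_indices=None):
    if isinstance(dim, torch.Tensor):
        # old element-wise overload max(input, other): warn and forward to fmax
        warnings.warn("max(input, other) is deprecated; use fmax(input, other)", DeprecationWarning)
        return fmax(input, dim)
    if return_indices is None:
        if dim is not None:
            # old behavior returned indices implicitly whenever dim was given
            warnings.warn("pass return_indices explicitly when dim is given", DeprecationWarning)
        return_indices = dim is not None
    # dispatch to ATen implementation
    ...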

def fmax(input, other):
    # element-wise maximum preferring non-NaN
    ...
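To make "preferring non-NaN" concrete, here is a small sketch of the intended element-wise semantics on plain Python lists of floats. It illustrates the rule only and is not the ATen implementation; the helper name `_fmax_scalar` is an assumption for this example.

```python
import math


def _fmax_scalar(a, b):
    """Return the larger of a and b, ignoring a NaN if the other value is a number."""
    if math.isnan(a):
        return b
    if math.isnan(b):
        return a
    return a if a > b else b


def fmax(input, other):
    """Element-wise maximum of two equal-length sequences, preferring non-NaN."""
    if len(input) != len(other):
        raise ValueError("input and other must have the same length")
    return [_fmax_scalar(a, b) for a, b in zip(input, other)]


nan = float("nan")
print(fmax([1.0, nan, 3.0], [2.0, 5.0, nan]))  # [2.0, 5.0, 3.0]
```

A NaN appears in the result only where both inputs are NaN at that position.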

We should also do the same for torch.min.

We have already updated add, addmm, addbmm, addcmul, sub.
