
Proposal: type promotion logic (torch.result_type) #9515

Description

@colesbury

This proposes an algorithm for computing the result type of a mixed-type operation like torch.add. The function will be exposed to Python via torch.result_type(*tensors). This proposal covers the default result type calculation; some operators may override the default behavior.

The analogous NumPy function is numpy.result_type.

Wrapped numbers

A Tensor is considered a "wrapped number" if it is auto-wrapped from a C++ or Python number type. Integer types are wrapped as 0-dim int64 tensors and floating-point types are wrapped as 0-dim double tensors. All wrapped numbers are 0-dim tensors, but not all 0-dim tensors are wrapped numbers. In general, wrapped numbers behave like normal 0-dim tensors, except that they are handled specially in the torch.result_type calculation.

For example, in tensor + 5 and torch.add(tensor, 5), 5 gets wrapped as a 0-dim torch.int64 tensor (a "wrapped number"). However, torch.add(tensor, torch.tensor(5)) involves no wrapped number, because torch.tensor(5) is an explicit construction. Wrapped-number status does not propagate: tensors returned from operations are never considered wrapped numbers.
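Assuming torch.result_type behaves as proposed here, the distinction can be illustrated directly (mirroring the uint8 examples later in this proposal, with a float32 default dtype):

import torch

t = torch.tensor([0, 0, 1], dtype=torch.uint8)

# 5.5 is auto-wrapped as a 0-dim double tensor (a wrapped number), so it
# promotes as the default dtype rather than as double:
torch.result_type(t, 5.5)                                    # torch.float32

# An explicitly constructed 0-dim double tensor is not a wrapped number,
# so its actual dtype participates:
torch.result_type(t, torch.tensor(5.5, dtype=torch.double))  # torch.float64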

Result type calculation

Each operand has a category (integer or floating-point) and a priority, listed here from highest to lowest:

  1. Tensors of dimension 1 or larger
  2. Tensors of dimension 0 that are not wrapped numbers
  3. Wrapped numbers

By default, only the highest-priority operands participate in the type promotion logic. A lower-priority operand also participates if its category (e.g. floating-point) outranks the category of every higher-priority operand (e.g. integer).

In pseudo-code the result-type calculation is:

def result_type(*args):
  # Promote across the dtypes of the participating operands only.
  return promote_types(infer_scalar_type(arg) for arg in args if participates(arg, args))

def infer_scalar_type(arg):
  # Wrapped numbers promote as the default dtype (floating-point) or as
  # int64 (integral), rather than as their storage dtype.
  if is_wrapped_number(arg):
    return torch.get_default_dtype() if is_floating_point(arg) else torch.int64
  else:
    return arg.dtype

def participates(arg, args):
  # Operands of the highest priority always participate.
  if priority(arg) >= max(priority(other) for other in args):
    return True
  # A lower-priority operand participates only if its category outranks
  # that of every higher-priority operand.
  if category(arg) > max(category(other) for other in args if priority(other) > priority(arg)):
    return True
  return False

def priority(arg):
  if arg.dim() > 0: return 3                 # tensors of dimension 1 or larger
  elif not is_wrapped_number(arg): return 2  # 0-dim tensors that are not wrapped numbers
  else: return 1                             # wrapped numbers

def category(arg):
  if is_floating_point(arg): return 2  # floating-point outranks integer
  else: return 1
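The pseudo-code above relies on the internal wrapped-number flag, so it is not directly runnable. As a self-contained sketch of the same logic, operands can be modeled as plain records; the Operand class and its wrapped field below are hypothetical stand-ins, not PyTorch API:

import functools
from dataclasses import dataclass

import torch

@dataclass
class Operand:
  # Pure-Python stand-in for a tensor operand; `wrapped` models the
  # internal wrapped-number flag described above.
  dtype: torch.dtype
  dim: int
  wrapped: bool = False

def priority(op):
  if op.dim > 0:
    return 3
  elif not op.wrapped:
    return 2
  return 1

def category(op):
  return 2 if op.dtype.is_floating_point else 1

def infer_scalar_type(op):
  if op.wrapped:
    return torch.get_default_dtype() if op.dtype.is_floating_point else torch.int64
  return op.dtype

def participates(op, ops):
  if priority(op) >= max(priority(o) for o in ops):
    return True
  return category(op) > max(category(o) for o in ops if priority(o) > priority(op))

def result_type(*ops):
  dtypes = [infer_scalar_type(o) for o in ops if participates(o, ops)]
  return functools.reduce(torch.promote_types, dtypes)

# Mirrors "tensor([0, 0, 1], dtype=uint8) + 5.5 -> float32" below:
print(result_type(Operand(torch.uint8, dim=1),
                  Operand(torch.float64, dim=0, wrapped=True)))  # torch.float32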

Examples (assuming default float32 tensor dtype):

randn(3, dtype=float32) * 5 -> float32
tensor([0, 0, 1], dtype=uint8) + 1 -> uint8
tensor([0, 0, 1], dtype=uint8) + 1000 -> uint8  # NOTE: integer overflow
tensor([0, 0, 1], dtype=uint8) + 5.5 -> float32 (default tensor dtype)
tensor([0, 0, 1], dtype=uint8) + tensor(5.5, dtype=double) -> double

randn(3, dtype=float32) + tensor(5.5, dtype=double) -> float32
tensor(5.5, dtype=float16) + 2.2 -> float16
tensor(5.5, dtype=float16) + 100000 -> float16 # NOTE: inf
tensor(5.5, dtype=float16) + tensor(100000.0) -> float32 (default tensor dtype)
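If operators adopt this result-type calculation, several of the examples above can be spot-checked directly (again assuming the default dtype is float32):

import torch

assert torch.get_default_dtype() == torch.float32
assert (torch.randn(3, dtype=torch.float32) * 5).dtype == torch.float32
assert (torch.tensor([0, 0, 1], dtype=torch.uint8) + 1).dtype == torch.uint8
assert (torch.tensor([0, 0, 1], dtype=torch.uint8) + 5.5).dtype == torch.float32
# The 0-dim double tensor does not up-promote the float32 tensor:
assert (torch.randn(3, dtype=torch.float32)
        + torch.tensor(5.5, dtype=torch.double)).dtype == torch.float32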

Appendix:

Why don't we use NumPy's behavior?

NumPy's result_type logic has two undesirable behaviors. The first is that it requires examining the actual value of scalars (and 0-dim arrays), which would require a host-device synchronization for 0-dim CUDA tensors. The second is that it often up-promotes 0-dim arrays to float64 or int64. For example:

type(np.array(4.0, dtype=np.float32) + 1) -> np.float64
type(np.array(0, dtype=np.uint8) + 1) -> np.int64
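For contrast, the value-dependence this proposal avoids can be seen with NumPy's legacy (pre-NEP 50) promotion; the outputs below are from legacy NumPy, and NumPy 2.x changed these rules:

import numpy as np

# Legacy NumPy inspects the scalar's value when choosing the result dtype:
np.result_type(np.uint8, 1)     # dtype('uint8')  -- 1 fits in uint8
np.result_type(np.uint8, 1000)  # dtype('uint16') -- 1000 does not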
