This proposes an algorithm for computing the result type of a mixed-type operation like torch.add. The function will be exposed to Python via torch.result_type(*tensors). This proposal covers the default result type calculation; some operators may override the default behavior.
The analogous NumPy function is numpy.result_type.
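For illustration, a couple of calls showing the intended behavior. This is a sketch: the two-operand call form is an assumption, not part of the proposal's spelled-out signature.

```python
import torch

t = torch.ones(3, dtype=torch.uint8)
print(torch.result_type(t, 5))    # torch.uint8: the wrapped integer does not promote
print(torch.result_type(t, 5.5))  # torch.float32: floating-point outranks integer
```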
Wrapped numbers
A Tensor is considered a "wrapped number" if it was auto-wrapped from a C++ or Python number type. Integer types are wrapped as 0-dim int64 tensors and floating-point types are wrapped as 0-dim double tensors. All wrapped numbers are 0-dim tensors, but not all 0-dim tensors are wrapped numbers. In general, wrapped numbers behave like normal 0-dim tensors, except that they are handled specially in the torch.result_type calculation.
For example, in tensor + 5 and torch.add(tensor, 5), the 5 is wrapped as a 0-dim torch.int64 tensor (a "wrapped number"). However, torch.add(tensor, torch.tensor(5)) involves no wrapped number, because torch.tensor(5) is an explicit construction. Wrapped-number status does not propagate: returned tensors are never considered wrapped numbers.
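The distinction matters for the result dtype. A minimal sketch, using the same cases as the examples further below:

```python
import torch

t = torch.ones(3, dtype=torch.uint8)

# 5.5 is auto-wrapped: its double-ness is ignored and the default dtype wins.
print((t + 5.5).dtype)                                  # torch.float32

# An explicitly constructed 0-dim double tensor is not a wrapped number,
# so its dtype participates fully in the promotion.
print((t + torch.tensor(5.5, dtype=torch.float64)).dtype)  # torch.float64
```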
Result type calculation
Each operand has a category (integer or floating-point) and a priority, listed here from highest to lowest:
- Tensors of dimension 1 or larger
- Tensors of dimension 0 that are not wrapped numbers
- Wrapped numbers
By default, only the highest-priority operands participate in the type promotion logic. A lower-priority operand also participates if its category (e.g. floating-point) outranks the category of every higher-priority operand (e.g. integer).
In pseudocode, the result-type calculation is:

```python
def result_type(*args):
    return promote_types(infer_scalar_type(arg) for arg in args
                         if participates(arg, args))

def infer_scalar_type(arg):
    if is_wrapped_number(arg):
        # Wrapped floats take the default dtype; wrapped integers are int64.
        return torch.get_default_dtype() if is_floating_point(arg) else torch.int64
    else:
        return arg.dtype

def participates(arg, args):
    # Operands of the highest priority always participate.
    if priority(arg) >= max(priority(other) for other in args):
        return True
    # A lower-priority operand participates only if its category outranks
    # that of every higher-priority operand.
    if category(arg) > max(category(other) for other in args
                           if priority(other) > priority(arg)):
        return True
    return False

def priority(arg):
    if arg.dim() > 0: return 3                 # tensors of dimension 1 or larger
    elif not is_wrapped_number(arg): return 2  # 0-dim tensors that are not wrapped
    else: return 1                             # wrapped numbers

def category(arg):
    return 2 if is_floating_point(arg) else 1  # floating-point outranks integer
```
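Wrapped-number status is not queryable from Python, so the pseudocode above cannot run as-is. Below is a minimal runnable sketch, assuming plain Python numbers stand in for wrapped numbers and using torch.promote_types for pairwise promotion:

```python
import functools
import torch

def result_type(*args):
    dtypes = [infer_scalar_type(a) for a in args if participates(a, args)]
    return functools.reduce(torch.promote_types, dtypes)

def is_wrapped_number(arg):
    # Assumption: plain Python numbers model wrapped numbers, since
    # wrapped-number status itself is not exposed to Python.
    return not isinstance(arg, torch.Tensor)

def infer_scalar_type(arg):
    if is_wrapped_number(arg):
        return torch.get_default_dtype() if isinstance(arg, float) else torch.int64
    return arg.dtype

def priority(arg):
    if is_wrapped_number(arg):
        return 1
    return 3 if arg.dim() > 0 else 2

def category(arg):
    if is_wrapped_number(arg):
        return 2 if isinstance(arg, float) else 1
    return 2 if arg.dtype.is_floating_point else 1

def participates(arg, args):
    if priority(arg) >= max(priority(other) for other in args):
        return True
    higher = (category(o) for o in args if priority(o) > priority(arg))
    return category(arg) > max(higher)

print(result_type(torch.ones(3, dtype=torch.uint8), 5.5))  # torch.float32
```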
Examples (assuming the default tensor dtype is float32):

```
randn(3, dtype=float32) * 5                                -> float32
tensor([0, 0, 1], dtype=uint8) + 1                         -> uint8
tensor([0, 0, 1], dtype=uint8) + 1000                      -> uint8    # NOTE: integer overflow
tensor([0, 0, 1], dtype=uint8) + 5.5                       -> float32  (default tensor dtype)
tensor([0, 0, 1], dtype=uint8) + tensor(5.5, dtype=double) -> double
randn(3, dtype=float32) + tensor(5.5, dtype=double)        -> float32
tensor(5.5, dtype=float16) + 2.2                           -> float16
tensor(5.5, dtype=float16) + 100000                        -> float16  # NOTE: inf
tensor(5.5, dtype=float16) + tensor(100000.0)              -> float32  (default tensor dtype)
```
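If the function lands as proposed, several of these rows can be spot-checked directly. A sketch, again assuming two-operand calls:

```python
import torch

assert torch.result_type(torch.randn(3), 5) == torch.float32
assert torch.result_type(torch.tensor([0, 0, 1], dtype=torch.uint8), 1000) == torch.uint8
assert torch.result_type(torch.tensor(5.5, dtype=torch.float16), 2.2) == torch.float16
assert torch.result_type(torch.randn(3), torch.tensor(5.5, dtype=torch.double)) == torch.float32
```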
Appendix:
Why don't we use NumPy's behavior?
NumPy's result_type logic has two undesirable behaviors. The first is that it requires examining the actual values of scalars (and 0-dim arrays); this would require a host-device synchronization for 0-dim CUDA tensors. The second is that it often up-promotes 0-dim arrays to float64 or int64. For example:
```
type(np.array(4.0, dtype=np.float32) + 1) -> np.float64
type(np.array(0, dtype=np.uint8) + 1)     -> np.int64
```
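A minimal contrast, assuming NumPy's legacy promotion rules as described above (newer NumPy versions have since changed them):

```python
import numpy as np
import torch

# Legacy NumPy up-promotes the 0-dim operand based on the Python
# scalar's default type:
print(type(np.array(4.0, dtype=np.float32) + 1))           # np.float64

# Under this proposal the 0-dim tensor's dtype wins, since the wrapped
# integer 1 has lower priority and no higher category:
print((torch.tensor(4.0, dtype=torch.float32) + 1).dtype)  # torch.float32
```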