
Comparison op support for named tensors #27077

@zou3519

🐛 Bug

Named tensors are not supported with comparison ops. In particular, we should add support for the following:

  • torch.eq, torch.ne, torch.lt, torch.le, torch.ge, torch.gt
  • torch.isnan, torch.isinf (their implementations only rely on comparison ops, iirc).
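The "iirc" for the second bullet is plausible because, elementwise, both predicates can be written purely in terms of comparisons. If isnan/isinf are in fact implemented that way, enabling names on the comparison ops should cover them too. A minimal scalar sketch in plain Python, assuming IEEE-754 floats (the helper names are hypothetical, and PyTorch's real kernels may be structured differently):

```python
import math


def isnan_via_comparison(v: float) -> bool:
    # NaN is the only float value that compares unequal to itself.
    return v != v


def isinf_via_comparison(v: float) -> bool:
    # Only equality comparisons are used: a float is infinite exactly when it
    # equals +inf or -inf. NaN compares unequal to both, so it is not infinite.
    return v == math.inf or v == -math.inf


if __name__ == "__main__":
    for v in [0.0, -1.5, math.nan, math.inf, -math.inf]:
        print(v, isnan_via_comparison(v), isinf_via_comparison(v))
```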

To Reproduce

import torch

x = torch.randn(3, 3, names=('N', 'C'))
y = torch.randn(3, 3, names=('N', 'C'))
x == y

RuntimeError: eq is not yet supported with named tensors. Please drop names via `tensor = tensor.rename(None)`, call the op with an unnamed tensor, and set names on the result of the operation.

Implementation

TensorIterator already performs name inference, and all of our comparison ops are currently implemented with TensorIterator. To turn on named tensor support, we should annotate each of these operators' entries in native_functions.yaml with `supports_named_tensor: True`, for example:

- func: lt.Scalar(Tensor self, Scalar other) -> Tensor
  use_c10_dispatcher: full
  variants: method, function
  dispatch:
    CPU: lt
    CUDA: lt
    QuantizedCPU: lt_quantized_cpu
  supports_named_tensor: True

That should be all that is necessary to turn on named tensor support.
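Since the comparison ops go through TensorIterator, their output names should follow the same rule that `test_binary_ops` checks for `add`, `mul`, and friends. A minimal sketch of that rule in plain Python, assuming a simplified model: dims are aligned from the right, `None` acts as a wildcard, and mismatched or misaligned names are rejected. The function name and error strings are hypothetical; the real inference lives in C++ and performs additional checks.

```python
from typing import Optional, Sequence, Tuple

# None stands for an unnamed (wildcard) dimension.
Names = Sequence[Optional[str]]


def unify_names_from_right(names: Names, other: Names) -> Tuple[Optional[str], ...]:
    """Hypothetical, simplified model of output-name inference for a
    broadcasting binary op such as `a == b`."""
    result = []
    for i in range(1, max(len(names), len(other)) + 1):
        # Align dims from the right; a dim missing from the shorter input
        # behaves like a wildcard.
        a = names[-i] if i <= len(names) else None
        b = other[-i] if i <= len(other) else None
        if a is not None and b is not None and a != b:
            raise RuntimeError(f"Names {a!r} and {b!r} do not match")
        # A name matched against a wildcard must not also appear elsewhere in
        # the wildcard's tensor; otherwise the dims are misaligned.
        if a is None and b is not None and b in names:
            raise RuntimeError(f"Misaligned dims: {b!r} is at a different position in each input")
        if b is None and a is not None and a in other:
            raise RuntimeError(f"Misaligned dims: {a!r} is at a different position in each input")
        # Keep the non-wildcard name (or None if both are wildcards).
        result.append(a if a is not None else b)
    return tuple(reversed(result))
```

Under this model, `('N', 'C')` against `(None, 'C')` yields `('N', 'C')`, `('N', 'C')` against a 3-D unnamed tensor yields `(None, 'N', 'C')` as in `compute_expected_names`, and `('C', None)` against `(None, 'C')` raises the misalignment error that `test_wildcard` expects.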

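Finally, we should add some tests. For reference, here is the existing `test_binary_ops` test, which checks name propagation for `add`, `div`, `mul`, `sub`, `pow`, `atan2`, and `copy_`: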

# Note: `Function`, `flatten`, and `itertools` are defined or imported
# elsewhere in the surrounding test module.
def test_binary_ops(self):
    def test_basic(op):
        a = torch.empty(2, 3, names=('N', 'C'))
        b = torch.empty(3, 2, names=('C', 'N'))
        c = torch.empty(3, names=('C',))
        d = torch.empty(5, names=('W',))

        self.assertEqual(op(a, a).names, ('N', 'C'))
        self.assertEqual(op(a, c).names, ('N', 'C'))

        with self.assertRaisesRegex(RuntimeError, "do not match"):
            op(a, d)
        with self.assertRaisesRegex(RuntimeError, "do not match"):
            op(a, b)

    def test_wildcard(op):
        a = torch.empty(2, 3, names=('N', 'C'))
        c = torch.empty(2, 3, names=(None, 'C'))
        self.assertEqual(op(a, c).names, ('N', 'C'))

        b = torch.empty(2, 3)
        self.assertEqual(op(a, b).names, ('N', 'C'))

        d = torch.empty(2, 3, names=('C', None))
        with self.assertRaisesRegex(RuntimeError, "Misaligned"):
            op(d, c)

    def test_mixed_unnamed_named(op, is_inplace):
        named2 = torch.randn(1, 1, names=('N', 'C'))
        unnamed1 = torch.randn(1)
        unnamed2 = torch.randn(1, 1)
        unnamed3 = torch.randn(1, 1, 1)

        def compute_expected_names(tensor, other):
            # Exactly one of the two inputs is named.
            assert tensor.has_names() ^ other.has_names()
            named = tensor if tensor.has_names() else other
            unnamed = other if tensor.has_names() else tensor
            unnamed_dim = unnamed.dim()
            if unnamed_dim > named.dim():
                # Extra leading dims from the unnamed input stay unnamed.
                return [None] * (unnamed_dim - named.dim()) + list(named.names)
            else:
                return named.names

        inputs = itertools.chain(
            itertools.product([named2], [unnamed1, unnamed2, unnamed3]),
            itertools.product([unnamed1, unnamed2, unnamed3], [named2]),
        )
        if is_inplace:
            # In-place ops have the constraint that they must not change shape.
            inputs = [(a, b) for (a, b) in inputs if a.dim() >= b.dim()]

        for tensor, other in inputs:
            expected_names = compute_expected_names(tensor, other)
            self.assertEqual(op(tensor, other).names, expected_names)

    def method(name, *args, **kwargs):
        return [Function(name, lambda a, b: getattr(a, name)(b, *args, **kwargs))]

    def out_function(name, *args, **kwargs):
        out_fn = getattr(torch, name)

        def fn(a, b):
            result = torch.empty([0], dtype=a.dtype, device=a.device)
            out_fn(a, b, *args, out=result, **kwargs)
            return result

        return [Function(name, fn)]

    def fn_method_and_inplace(name, *args, **kwargs):
        return (
            method(name, *args, **kwargs) +
            method(name + '_', *args, **kwargs) +
            out_function(name, *args, **kwargs)
        )

    tests = [
        fn_method_and_inplace('add'),
        fn_method_and_inplace('div'),
        fn_method_and_inplace('mul'),
        fn_method_and_inplace('sub'),
        fn_method_and_inplace('pow'),
        fn_method_and_inplace('atan2'),
        method('copy_'),
    ]
    tests = flatten(tests)

    for name, op in tests:
        test_basic(op)
        test_wildcard(op)
        test_mixed_unnamed_named(op, is_inplace=name.endswith('_'))
The tests don't have to be part of `test_binary_ops`; they can be simple checks that the operators and all their variants (out-of-place, `out=`, in-place) do indeed propagate names.
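One way to keep those checks exhaustive is to enumerate the variants up front, in the spirit of `fn_method_and_inplace`. A sketch assuming the six ops from the list at the top of this issue; the helper names are hypothetical, and the call forms are plain strings that a real test would map to callables, as `method` and `out_function` do above:

```python
# Hypothetical helper: enumerate the call forms a named-tensor test should
# cover for each comparison op listed in this issue.
COMPARISON_OPS = ("eq", "ne", "lt", "le", "ge", "gt")


def comparison_variants(ops=COMPARISON_OPS):
    """Return (variant_name, call_form) pairs for every op."""
    variants = []
    for op in ops:
        variants.append((op, f"torch.{op}(a, b)"))              # function
        variants.append((op, f"a.{op}(b)"))                     # method
        variants.append((op + "_", f"a.{op}_(b)"))              # in-place method
        variants.append((op, f"torch.{op}(a, b, out=result)"))  # out=
    return variants


def is_inplace(variant_name):
    # Same convention as test_binary_ops: in-place variants end with '_'.
    return variant_name.endswith("_")


if __name__ == "__main__":
    for name, form in comparison_variants():
        print(f"{form:32} inplace={is_inplace(name)}")
```

In-place variants still need the `a.dim() >= b.dim()` filter from `test_mixed_unnamed_named`, since they must not change the shape of `self`.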

Metadata

Labels: module: named tensor, triaged
