Closed
Labels
module: named tensor (Named tensor support), triaged
Description
🐛 Bug
Named tensors are not yet supported by comparison ops. In particular, we should add support for the following:
- `torch.eq`, `torch.ne`, `torch.lt`, `torch.le`, `torch.ge`, `torch.gt`
- `torch.isnan`, `torch.isinf` (their implementations rely only on comparison ops, iirc).
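The note about `torch.isnan` and `torch.isinf` rests on two identities: NaN is the only floating-point value that compares unequal to itself, and a value is infinite exactly when its absolute value equals infinity. Below is a scalar sketch of those identities in plain Python. The helper names are hypothetical, and this is not PyTorch's implementation, only an illustration of why these two ops reduce to comparisons.

```python
import math


def isnan_via_ne(value: float) -> bool:
    # NaN is the only float for which `x != x` is True.
    return value != value


def isinf_via_eq(value: float) -> bool:
    # +inf and -inf both have absolute value inf; NaN never compares equal to inf.
    return abs(value) == math.inf


# Cross-check the comparison-only versions against the math module.
for v in [0.0, 1.5, math.inf, -math.inf, math.nan]:
    assert isnan_via_ne(v) == math.isnan(v)
    assert isinf_via_eq(v) == math.isinf(v)
```

If the tensor versions are built from the same elementwise expressions (`x != x`, `x.abs() == inf`), as the issue suggests, enabling names on the comparison ops would carry over to them.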
To Reproduce
```python
import torch

x = torch.randn(3, 3, names=('N', 'C'))
y = torch.randn(3, 3, names=('N', 'C'))
x == y
```

```
RuntimeError: eq is not yet supported with named tensors. Please drop names via `tensor = tensor.rename(None)`, call the op with an unnamed tensor, and set names on the result of the operation.
```
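Until the ops are annotated, the workaround from the error message can be wrapped in a small helper. This is a minimal sketch: `call_with_names_dropped` is a hypothetical name, and it relies only on `Tensor.rename(None)` dropping names and `Tensor.rename(*names)` setting them positionally. The caller must supply the output names, because no name inference or checking happens.

```python
from typing import Any, Callable, Optional, Sequence


def call_with_names_dropped(
    op: Callable[..., Any],
    out_names: Sequence[Optional[str]],
    *tensors: Any,
) -> Any:
    """Run `op` on unnamed copies of `tensors`, then name the result.

    Follows the error message: drop names via `rename(None)`, call the op
    with unnamed tensors, and set names on the result. Names are NOT
    checked against each other, so mismatches go unnoticed.
    """
    unnamed = [t.rename(None) for t in tensors]
    result = op(*unnamed)
    return result.rename(*out_names)
```

With the tensors from the repro, `call_with_names_dropped(torch.eq, ('N', 'C'), x, y)` should return a boolean tensor named `('N', 'C')`. Real named-tensor support is preferable because it also rejects inputs whose names do not match.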
Implementation
TensorIterator already performs name inference, and all of our comparison ops are currently implemented with TensorIterator. To turn on this support, we should annotate the operators in `native_functions.yaml` with `supports_named_tensor: True`:
`pytorch/aten/src/ATen/native/native_functions.yaml`, lines 4632 to 4638 in b16358b:

```yaml
- func: lt.Scalar(Tensor self, Scalar other) -> Tensor
  use_c10_dispatcher: full
  variants: method, function
  dispatch:
    CPU: lt
    CUDA: lt
    QuantizedCPU: lt_quantized_cpu
```
That should be all that is necessary to turn on named tensor support.
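For intuition about what TensorIterator's name inference would give the comparison ops, here is a pure-Python model of the broadcasting name-unification rule exercised by the `test_binary_ops` cases quoted below: names are aligned from the right, `None` acts as a wildcard, two different names raise a "do not match" error, and a name that meets a wildcard must not also appear elsewhere in the other tensor's names ("Misaligned"). This is a hypothetical sketch (`unify_names` is not a PyTorch function), inferred from the test expectations rather than the C++ implementation; its error messages only share the substrings the tests check for.

```python
from typing import List, Optional, Sequence, Tuple

Names = Tuple[Optional[str], ...]


def unify_names(left: Sequence[Optional[str]], right: Sequence[Optional[str]]) -> Names:
    """Hypothetical model of output-name inference for a broadcasting binary op."""
    n = max(len(left), len(right))
    # Broadcasting aligns dims from the right, so pad the shorter list with wildcards.
    lhs: List[Optional[str]] = [None] * (n - len(left)) + list(left)
    rhs: List[Optional[str]] = [None] * (n - len(right)) + list(right)
    out: List[Optional[str]] = [None] * n
    # Walk from the rightmost dim so errors are reported right-to-left.
    for i in reversed(range(n)):
        a, b = lhs[i], rhs[i]
        if a is not None and b is not None:
            if a != b:
                raise RuntimeError(f"Names {a!r} and {b!r} do not match (dim {i - n})")
            out[i] = a
        elif a is not None:
            # `a` met a wildcard, so it must not sit at another position in `rhs`.
            if a in rhs:
                raise RuntimeError(f"Misaligned dims: {a!r} appears in a different position")
            out[i] = a
        elif b is not None:
            if b in lhs:
                raise RuntimeError(f"Misaligned dims: {b!r} appears in a different position")
            out[i] = b
    return tuple(out)
```

Under this model, `x == y` from the repro would be named `unify_names(x.names, y.names)`, i.e. `('N', 'C')`, and comparing a tensor named `('N', 'C')` with an unnamed 3-D tensor would give `(None, 'N', 'C')`, matching `compute_expected_names` in the existing test.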
Finally, we should add some tests to `pytorch/test/test_namedtensor.py`, lines 607 to 696 in 73347e0:
```python
# Excerpt: a method of the named-tensor TestCase; `itertools`, `Function`,
# and `flatten` come from the enclosing test module.
def test_binary_ops(self):
    def test_basic(op):
        a = torch.empty(2, 3, names=('N', 'C'))
        b = torch.empty(3, 2, names=('C', 'N'))
        c = torch.empty(3, names=('C',))
        d = torch.empty(5, names=('W',))

        self.assertEqual(op(a, a).names, ('N', 'C'))
        self.assertEqual(op(a, c).names, ('N', 'C'))

        with self.assertRaisesRegex(RuntimeError, "do not match"):
            op(a, d)
        with self.assertRaisesRegex(RuntimeError, "do not match"):
            op(a, b)

    def test_wildcard(op):
        a = torch.empty(2, 3, names=('N', 'C'))
        c = torch.empty(2, 3, names=(None, 'C'))
        self.assertEqual(op(a, c).names, ('N', 'C'))

        b = torch.empty(2, 3)
        self.assertEqual(op(a, b).names, ('N', 'C'))

        d = torch.empty(2, 3, names=('C', None))
        with self.assertRaisesRegex(RuntimeError, "Misaligned"):
            op(d, c)

    def test_mixed_unnamed_named(op, is_inplace):
        named2 = torch.randn(1, 1, names=('N', 'C'))
        unnamed1 = torch.randn(1)
        unnamed2 = torch.randn(1, 1)
        unnamed3 = torch.randn(1, 1, 1)

        def compute_expected_names(tensor, other):
            assert tensor.has_names() ^ other.has_names()
            named = tensor if tensor.has_names() else other
            unnamed = other if tensor.has_names() else tensor
            unnamed_dim = unnamed.dim()
            if unnamed_dim > named.dim():
                return [None] * (unnamed_dim - named.dim()) + list(named.names)
            else:
                return named.names

        inputs = itertools.chain(
            itertools.product([named2], [unnamed1, unnamed2, unnamed3]),
            itertools.product([unnamed1, unnamed2, unnamed3], [named2]),
        )
        if is_inplace:
            # In-place ops have the constraint that they must not change shape.
            inputs = [(a, b) for (a, b) in inputs if a.dim() >= b.dim()]

        for tensor, other in inputs:
            expected_names = compute_expected_names(tensor, other)
            self.assertEqual(op(tensor, other).names, expected_names)

    def method(name, *args, **kwargs):
        return [Function(name, lambda a, b: getattr(a, name)(b, *args, **kwargs))]

    def out_function(name, *args, **kwargs):
        out_fn = getattr(torch, name)

        def fn(a, b):
            result = torch.empty([0], dtype=a.dtype, device=a.device)
            out_fn(a, b, *args, out=result, **kwargs)
            return result

        return [Function(name, fn)]

    def fn_method_and_inplace(name, *args, **kwargs):
        return (
            method(name, *args, **kwargs) +
            method(name + '_', *args, **kwargs) +
            out_function(name, *args, **kwargs)
        )

    tests = [
        fn_method_and_inplace('add'),
        fn_method_and_inplace('div'),
        fn_method_and_inplace('mul'),
        fn_method_and_inplace('sub'),
        fn_method_and_inplace('pow'),
        fn_method_and_inplace('atan2'),
        method('copy_'),
    ]
    tests = flatten(tests)

    for name, op in tests:
        test_basic(op)
        test_wildcard(op)
        test_mixed_unnamed_named(op, is_inplace=name.endswith('_'))
```
The new tests belong in `test_binary_ops` and can be simple checks that the operators and all their variants (out-of-place, `out=`, in-place) do indeed propagate names.