-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Enabled comparison ops with named tensors #27162
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@izdeby has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@izdeby has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
zou3519
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good. Let's expand our testing to make sure names are propagated when we pass an out= tensor to torch.eq, torch.ne, etc
| self.assertEqual((a > 1).names, ['N', 'C']) | ||
| self.assertEqual((a < 1).names, ['N', 'C']) | ||
| self.assertEqual((a >= 1).names, ['N', 'C']) | ||
| self.assertEqual((a <= 1).names, ['N', 'C']) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be good to add explicit tests for the out= variants. Those can be accessed with torch.eq(a, b, out=blah).
| b = torch.randn(3, 3, names=('N', 'C'), device=device) | ||
| scalar = torch.randn([], device=device) | ||
|
|
||
| self.assertEqual((a == b).names, ['N', 'C']) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: It's possible to write less code by having a list of all of the operations:
ops = [lambda a, b: a == b, lambda a, b: a != b, ...]
and then running them through a for loop. But that makes it harder to pdb into and can be less readable
| self.assertEqual((a > scalar).names, ['N', 'C']) | ||
| self.assertEqual((a < scalar).names, ['N', 'C']) | ||
| self.assertEqual((a >= scalar).names, ['N', 'C']) | ||
| self.assertEqual((a <= scalar).names, ['N', 'C']) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you also add support for torch.isnan in this PR? torch.isnan(x) is implemented with x != x, so it should be sufficient to add supports_named_tensor: True to its native_functions entry.
I realized that torch.isinf isn't as simple to add named tensor support for (although it does call a comparison op, it requires zeros_like), so we can punt on that for now.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@izdeby has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
zou3519
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you. The linter seems to be complaining but aside from that this looks good.
|
Here are instructions for how to submit a patch to v1.3.0 #27011, we should do that when this PR gets merged into master. |
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@izdeby has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Fixing this [issue](pytorch/pytorch#27077). Tested via unit tests Pull Request resolved: pytorch/pytorch#27162 Differential Revision: D17694187 Pulled By: izdeby fbshipit-source-id: 939017c91605c89a0e08e0c3f8fe21de93bba95b
|
@ailzhang, looks like this PR breaks XLA |
|
Thanks @izdeby ! The failed tests were caused by one of my PR upstreaming a patch from pytorch/xla to pytorch/pytorch. So it's all good for this PR. :D |
Summary: Fixing this [issue](pytorch#27077). Tested via unit tests Pull Request resolved: pytorch#27162 Differential Revision: D17694187 Pulled By: izdeby fbshipit-source-id: 939017c91605c89a0e08e0c3f8fe21de93bba95b
Summary: Fixing this [issue](pytorch#27077). Tested via unit tests Pull Request resolved: pytorch#27162 Differential Revision: D17694187 Pulled By: izdeby fbshipit-source-id: 939017c91605c89a0e08e0c3f8fe21de93bba95b
Fixing this issue.
Tested via unit tests