-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Added possibility to index scalars by bool masks #21030
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Added possibility to index scalars by bool masks gh-metadata: pytorch pytorch 21030 gh/izdeby/1/head
Why? :) |
| uintMask = torch.tensor(True, dtype=torch.uint8, device=device) | ||
| boolMask = torch.tensor(True, dtype=torch.bool, device=device) | ||
| self.assertEqual(a[uintMask], a[boolMask]) | ||
| self.assertEqual(a[uintMask].dtype, a[boolMask].dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test feels a little goofy, because aren't you intending to change the uintMask behavior in the near future?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well once its changed i will update the test but for now i want to show that the result is identical, as expected.
|
Should there be documentation updates in these PRs? |
|
|
@pytorchbot retest this please |
ezyang
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm getting somewhat nervous about the testing, because I feel like there is almost definitely existing tests for uintMask and the right thing would be adapt them to test boolMask too, but that does not appear to have been done in this PR. But if you want to do this all as a big pass when you start changing uintMask behavior in BC-breaking ways, I guess that's fine.
Stack from ghstack:
This PR is a part of a stack which will change result tensor type of comparison ops from uint8 to bool. As this change is rather big and a lot of prep work is needed, im breaking it into a stack.
Changes in this PR:
Differential Revision: D15530498