Added Bfloat16 tensor for cpu with very limited support #21860
Added Bfloat16 tensor for cpu with very limited support gh-metadata: pytorch pytorch 21860 gh/izdeby/10/head
    #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
this is technically BC breaking, but probably no one uses it.
Should I create separate versions of AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND for 2 and 3 scalar types, to be safe?
nah, I think it's fine. I'd want to know if someone actually uses this so they can complain :).
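The BC concern above can be illustrated with a hedged, language-agnostic sketch: a dtype-dispatch helper whose list of "extra" allowed types is a required argument. Widening that argument (from two extras to three) forces every existing call site to change, which is the same kind of source-level break as changing the C++ macro's parameter list. All names below are illustrative, not the real ATen API.

```python
# Hypothetical Python sketch of the idea behind
# AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND; not the actual ATen macro.

DEFAULT_TYPES = {"float", "double", "int64"}

def dispatch_all_types_and(extra_types, dtype, handlers):
    # Allow the default dtype set plus a caller-chosen list of extras
    # (e.g. bfloat16). Growing this required parameter is the
    # source-level BC break discussed in the review comment above.
    allowed = DEFAULT_TYPES | set(extra_types)
    if dtype not in allowed:
        raise TypeError(f"unsupported dtype: {dtype}")
    return handlers[dtype]()

handlers = {"bfloat16": lambda: "ran bfloat16 kernel"}
result = dispatch_all_types_and(["half", "bool", "bfloat16"], "bfloat16", handlers)
```

A call site written against the two-extra-types version would fail to compile (in C++) or raise (here) once a third required slot is added, which is why a separate macro name is the usual way to stay compatible.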
    all_dtypes = torch.testing.get_all_dtypes()
    do_test_dtypes(self, all_dtypes, torch.strided, torch.device('cpu'))
    if torch.cuda.is_available():
        all_dtypes.remove(torch.bfloat16)  # Remove once _th_zero_ is enabled on cuda for bfloat16
as with the unfold test below, it's better to run the actual test and assert it raises an exception so you force the test to get updated. As it is, we'll probably forget to actually test this.
Fixed for all the tests in this file, but I can't update this test due to a weird behavior in common_utils.py. This test will be updated in the next PR in the stack, and the CUDA path will work.
"can't update this test due to a weird behavior in common_utils.py"
Out of curiosity, what kind of behavior?
In common_utils.py, in do_test_dtypes, I would add something like
if dtype == torch.bfloat16 and device == 'cuda:0':
but the branch was skipped even when dtype was torch.bfloat16 and device was 'cuda:0'. The problem was comparing the device variable against the string 'cuda:0': the check never passed. The same check worked in test_torch.py with no problems.
I was not aware that device comparison with a string was a supported operation :) What I would expect to work is if you did device == torch.device('cuda:0')
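The symptom can be reproduced without PyTorch. Below is a minimal stand-in device class (explicitly NOT torch.device, whose string-comparison behavior depends on the PyTorch version): if its __eq__ only recognizes other device objects, comparing against a raw string is silently false, while comparing device-to-device behaves as intended.

```python
class FakeDevice:
    """Hypothetical stand-in for a device object; NOT torch.device."""

    def __init__(self, spec: str):
        # Split "cuda:0" into type "cuda" and index 0.
        self.type, _, index = spec.partition(":")
        self.index = int(index) if index else None

    def __eq__(self, other):
        # Only recognizes other FakeDevice instances; a raw string
        # never matches, which reproduces the "never passes" symptom.
        if not isinstance(other, FakeDevice):
            return NotImplemented
        return (self.type, self.index) == (other.type, other.index)

    def __hash__(self):
        return hash((self.type, self.index))

device = FakeDevice("cuda:0")
string_check = (device == "cuda:0")              # falls back to False
device_check = (device == FakeDevice("cuda:0"))  # compares like with like
```

Normalizing the string into a device object before comparing, as suggested above, sidesteps the whole question of whether string comparison is supported.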
gchanan left a comment:
Just some follow-on issues to file/fix.
    for dt in torch.testing.get_all_dtypes():
        x = torch.tensor((1, 1), dtype=dt, device=device)
        if (device == 'cuda' and dt == torch.bfloat16):
            self.assertRaises(RuntimeError, lambda: x.clone())
nit: for assertRaises, is there a message you can set for when the test fails? It would be helpful to put something like "placeholder for future functionality -- check if you should update the test" so it's clear we don't actually intend for this to be the perfect behavior.
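For reference, unittest's assertRaises accepts a msg keyword in its context-manager form, which is one way to carry the suggested note. A minimal self-contained sketch (the failing operation is a stand-in, not the actual bfloat16 clone call):

```python
import unittest

def unsupported_op():
    # Stand-in for an op that is expected to fail for now,
    # e.g. a bfloat16 CUDA kernel that is not implemented yet.
    raise RuntimeError("not implemented")

class PlaceholderTest(unittest.TestCase):
    def test_unsupported_op_raises(self):
        # msg is shown only when the assertion FAILS, i.e. when the
        # op stops raising: a nudge that the test should be updated.
        with self.assertRaises(
            RuntimeError,
            msg="placeholder for future functionality -- "
                "if this op now works, update the test",
        ):
            unsupported_op()

result = unittest.TextTestRunner(verbosity=0).run(
    unittest.defaultTestLoader.loadTestsFromTestCase(PlaceholderTest)
)
```

Once the op is actually implemented, the test fails with that message instead of silently continuing to "pass" for the wrong reason.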
Summary: Pull Request resolved: pytorch/pytorch#21860 ghimport-source-id: 5290755b63033cdfdeb911a4ecf4aa282b3db02d Test Plan: Imported from OSS Differential Revision: D15856091 Pulled By: izdeby fbshipit-source-id: 54e7e17be1b5c5a2e80a41feaeaeba75dbb8108f
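For context on the format itself: bfloat16 keeps float32's 8-bit exponent and shortens the mantissa to 7 bits, so a reference CPU conversion can be as simple as keeping the top 16 bits of the float32 encoding. The sketch below uses plain truncation for simplicity; this is an illustration of the format, not PyTorch's actual conversion code, which may round instead.

```python
import struct

def float32_to_bfloat16_bits(x: float) -> int:
    # Reinterpret x as IEEE-754 float32 and keep the high 16 bits
    # (sign, 8-bit exponent, top 7 mantissa bits). Truncation only.
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits >> 16

def bfloat16_bits_to_float32(b: int) -> float:
    # Place the 16 stored bits back into the high half of a float32.
    (x,) = struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))
    return x

def roundtrip(x: float) -> float:
    return bfloat16_bits_to_float32(float32_to_bfloat16_bits(x))
```

Because the exponent field is unchanged, values like 1.0 and -2.0 survive the round trip exactly, while most values lose mantissa precision (e.g. pi becomes roughly 3.1406).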
This PR broke lint.
@ailzhang, I'm working on a fix PR.
Stack from ghstack:
Differential Revision: D15856091
Enabled bfloat16 on CPU for all methods in declaration.cwrap which are supported by Half.
Tested via unit tests