BFloat16: enable prepacked weights's inference #48922
Conversation
[ghstack-poisoned]
💊 CI failures summary and remediations

As of commit 070f721 (more details on the Dr. CI page):

🕵️ 2 new failures recognized by patterns. The following CI failures do not appear to be due to upstream breakages.
@VitalyFedyunin, the error is caused by running on an old CPU (it only supports AVX2), but OneDNN needs the CPU to support at least AVX-512 for the bfloat16 path. So I want to add a global context flag, like torch._C.has_mkldnn, to skip the bfloat16 test cases when the CPU doesn't support AVX-512. Is that OK on your side?
@VitalyFedyunin, there is another option: fall back to the fp32 path in ideep when OneDNN bf16 is not supported on an old device. This option doesn't need any PyTorch code changes except updating ideep. If this option is OK, I will update ideep first. Thanks!
@VitalyFedyunin This option sounds better to us since it does not require an extra CPU arch check on the PyTorch side. IDEEP will support BF16 regardless of CPU arch. We will do a quick upgrade of IDEEP if it looks good to you. :-)
The current error doesn't say anything about AVX support. Falling back to the fp32 path will create a misleading impression for users; it is much better to error out with a clear warning when something is not supported. For the tests you don't need to introduce a new global context. Something simple like @skipIf(test_arch(), "not supported architecture") will suffice.
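The suggested @skipIf pattern could look roughly like the following minimal sketch. Note that `cpu_supports_avx512_bf16` is a hypothetical helper name for illustration, not the name eventually used in the PR:

```python
import unittest

def cpu_supports_avx512_bf16():
    # Hypothetical helper: OneDNN's bf16 path needs the avx512bw, avx512vl,
    # and avx512dq subsets. On systems without /proc/cpuinfo (e.g. Windows,
    # macOS) we conservatively report no support.
    try:
        with open("/proc/cpuinfo") as f:
            flags = f.read()
    except OSError:
        return False
    return all(flag in flags for flag in ("avx512bw", "avx512vl", "avx512dq"))

class TestMkldnnBf16(unittest.TestCase):
    @unittest.skipIf(not cpu_supports_avx512_bf16(), "not supported architecture")
    def test_bf16_path(self):
        # The real bf16 test body would go here; skipped on unsupported CPUs.
        pass
```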
@VitalyFedyunin Thanks for the suggestions. So, we will also add a CPU arch check inside
Differential Revision: [D25537188](https://our.internmc.facebook.com/intern/diff/D25537188) [ghstack-poisoned]
OK, we fixed the tests, but can we please make sure we raise a nice error if someone tries it on an old CPU.
Done. |
@VitalyFedyunin, please help review it. Thanks!
```python
try:
    # For the bf16 path, OneDNN requires the CPU to support Intel AVX-512
    # with avx512bw, avx512vl, and avx512dq.
    cmd = "grep avx512bw /proc/cpuinfo | grep avx512vl | grep avx512dq"
```
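A runnable sketch of what this check amounts to, assuming Linux. `has_bf16_support` here is the name used later in the test file; shelling out to grep mirrors the snippet above rather than claiming to be the final implementation:

```python
import subprocess

def has_bf16_support():
    # For the bf16 path, OneDNN requires avx512bw, avx512vl and avx512dq.
    # Chaining greps over /proc/cpuinfo means the pipeline only exits 0 when
    # all three flags are present. Any failure (e.g. no /proc/cpuinfo on a
    # non-Linux system) is treated as "no support".
    cmd = "grep avx512bw /proc/cpuinfo | grep avx512vl | grep avx512dq"
    completed = subprocess.run(
        cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
    )
    return completed.returncode == 0
```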
Did we give up on testing bf16 on Windows? If so, the tests should be marked to skip on Windows, instead of relying on obscure errors being thrown.
For Windows, the bf16 cases are skipped for now; the reason is that I can't find an easy way to check the CPU features there like on Linux.
```diff
 if isinstance(m, torch.nn.Linear):
-    return MkldnnLinear(m)
+    return MkldnnLinear(m, d)
 elif isinstance(m, torch.nn.Conv1d):
```
Why is the dtype argument ignored on Conv1d, Conv3d and BatchNorm? If they don't support it, we should error out; we cannot silently ignore it. Similarly, if a dtype that's generally not supported is passed (e.g. torch.double) we should error out.
The dtype argument is now added on Conv1d and Conv3d, and the conversion asserts if the given dtype is not supported.
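A torch-free sketch of the constructor-side check being described; the class name and string dtypes are illustrative only, not the PR's actual code:

```python
# Hypothetical illustration of the dtype validation added to the MKLDNN
# module constructors: an unsupported dtype raises instead of being
# silently ignored, as requested in the review above.
SUPPORTED_DTYPES = ("float32", "bfloat16")

class MkldnnConvSketch:
    def __init__(self, dense_module, dtype="float32"):
        assert dtype in SUPPORTED_DTYPES, (
            "mkldnn conv only supports float32 and bfloat16, got " + dtype)
        self.dense_module = dense_module
        self.dtype = dtype
```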
test/test_mkldnn.py (outdated)

```python
self._test_serialization(mkldnn_conv2d, (x.to_mkldnn(),))
self._test_tracing(mkldnn_conv2d, (x.to_mkldnn(),))

@unittest.skipIf(not has_bf16_support(), "OneDNN bfloat16 path requires AVX512")
```
These tests should not be skipped if there's no bf16 support; instead they should catch and assert that the appropriate errors are raised.
Added test cases for when there's no bf16 support.
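For the unsupported-hardware case, the tests can assert the error instead of skipping. A self-contained sketch follows; `to_mkldnn_bf16`, the hard-coded `has_bf16_support`, and the error message are placeholders, not the PR's actual entry point:

```python
import unittest

def has_bf16_support():
    # Stand-in for the real architecture check; hard-coded to False so the
    # error branch below is exercised regardless of the host CPU.
    return False

def to_mkldnn_bf16(tensor):
    # Placeholder conversion entry point: on CPUs without the required
    # AVX-512 subsets it raises a clear error rather than computing silently.
    if not has_bf16_support():
        raise RuntimeError("mkldnn bf16 path needs a CPU with avx512bw, "
                           "avx512vl and avx512dq support")
    return tensor

class TestNoBf16Support(unittest.TestCase):
    def test_unsupported_cpu_raises(self):
        # Assert that the appropriate error is raised, as the review asks.
        with self.assertRaisesRegex(RuntimeError, "avx512"):
            to_mkldnn_bf16(object())
```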
ngimel left a comment:
This PR would really benefit from a better description and documentation; e.g., the fact that bfloat16 in MKLDNN is supported only on AVX-512 processors really should go in the description and in a Note somewhere in the code.
Similarly, I would like to understand why the bias is never converted to bfloat16. If that is by design, it should be commented upon and put in a note.
Similarly for the BatchNorm parameters: if OneDNN wants them in fp32, that's fine, but there should be a comment about it.
```diff
-    return MkldnnConv3d(m)
+    return MkldnnConv3d(m, d)
 elif isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm3d):
     return MkldnnBatchNorm(m)
```
why is dtype ignored for batchnorm? Does it expect its parameters to remain in fp32?
Yes, BatchNorm's parameters remain in fp32: the MKLDNN BatchNorm bf16 path requires a bf16 input but fp32 parameters. For conv and linear, the bias can be fp32 or bf16; for good accuracy, we chose fp32 for the bias.
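The dtype policy described in this exchange can be summarized in a small sketch. The function name is illustrative, and it assumes the conv/linear weight is what gets converted to bf16 (this PR's purpose of prepacking bf16 weights), with biases and BatchNorm parameters held in fp32 as stated above:

```python
def mkldnn_bf16_dtype_plan(module_kind):
    # Summarizes the policy from the discussion above: conv/linear weights
    # are prepacked in bf16 while biases stay fp32 for accuracy, and
    # batchnorm parameters stay fp32 entirely because OneDNN's bf16
    # batchnorm wants a bf16 input with fp32 parameters.
    if module_kind in ("Linear", "Conv1d", "Conv2d", "Conv3d"):
        return {"weight": "bfloat16", "bias": "float32"}
    if module_kind in ("BatchNorm2d", "BatchNorm3d"):
        return {"weight": "float32", "bias": "float32",
                "running_mean": "float32", "running_var": "float32"}
    raise TypeError("no bf16 conversion rule for " + module_kind)
```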
Yes, I added some comments in test_mkldnn.py and torch/utils/mkldnn.py to describe them.
@ngimel, the failing case seems unrelated to this PR.
@VitalyFedyunin merged this pull request in 324c6aa.
Summary: Pull Request resolved: pytorch#48922
Test Plan: Imported from OSS
Reviewed By: ejguan
Differential Revision: D25537188
Pulled By: VitalyFedyunin
fbshipit-source-id: ab6eb1ba8cffb5ba9d00d05db8ef616628f8c932
Stack from ghstack:
Differential Revision: D25537188