[numpy] torch.{all, any} : Extend Dtype Support#44790
[numpy] torch.{all, any} : Extend Dtype Support#44790kshitij12345 wants to merge 9 commits intopytorch:masterfrom
Conversation
💊 CI failures summary and remediationsAs of commit 1b242ef (more details on the Dr. CI page):
🕵️ 2 new failures recognized by patternsThe following CI failures do not appear to be due to upstream breakages:
|
heitorschueroff
left a comment
There was a problem hiding this comment.
LGTM. Thanks for this PR, it's a very welcomed change.
Note: Now that torch.all and torch.any supports all dtypes, we should document it in the public APIs as mentioned here #44779, but this can be a separate PR.
facebook-github-bot
left a comment
There was a problem hiding this comment.
@heitorschueroff has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
heitorschueroff
left a comment
There was a problem hiding this comment.
It looks like the XLA failures are related to the changes. Could you look into what's causing it please?
|
@heitorschueroff Thanks for looking at it. As for XLA, I m not really sure what is happening. Thanks! |
|
@kshitij12345 XLA change is ready, I will merge it when this pr is merged. |
|
@JackCaoG Thanks for updating XLA. @kshitij12345 Could you rebase please, I'll merge it then. |
|
@heitorschueroff Have fixed the conflict. ROCm failure looks irrelevant. |
facebook-github-bot
left a comment
There was a problem hiding this comment.
@heitorschueroff has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Thank you for this contribution, I'm important your changes now. |
heitorschueroff
left a comment
There was a problem hiding this comment.
It looks like I missed some details in my review. Our internal tests on phabricator are complaining. I left some comments from phabricator, they should be fairly quick to fix and then I can land it without problems.
| return c; | ||
| }, | ||
| /*ident=*/true); | ||
| if (c10::isIntegralType(iter.dtype(), /*include_bool=*/true)) { |
There was a problem hiding this comment.
include_bool -> includeBool
| return c; | ||
| }, | ||
| /*ident=*/false); | ||
| if (c10::isIntegralType(iter.dtype(), /*include_bool=*/true)) { |
There was a problem hiding this comment.
include_bool -> includeBool
| if (c10::isIntegralType(iter.dtype(), /*include_bool=*/true)) { | ||
| binary_kernel_reduce_vec( | ||
| iter, | ||
| [=](uint8_t a, uint8_t b) -> uint8_t { return a && b; }, |
There was a problem hiding this comment.
Avoid the implicit cast with:
[=](uint8_t a, uint8_t b) -> uint8_t { return ((a && b) ? 1 : 0); },
| if (c10::isIntegralType(iter.dtype(), /*include_bool=*/true)) { | ||
| binary_kernel_reduce_vec( | ||
| iter, | ||
| [=](uint8_t a, uint8_t b) -> uint8_t { return a || b; }, |
There was a problem hiding this comment.
Avoid the implicit cast with:
[=](uint8_t a, uint8_t b) -> uint8_t { return ((a && b) ? 1 : 0); },
| // true/false. | ||
| Vec256<uint8_t> c = Vec256<uint8_t>(); | ||
| for (int i = 0; i != Vec256<uint8_t>::size(); i++) { | ||
| c[i] = a[i] && b[i]; |
There was a problem hiding this comment.
Avoid implicit cast with:
c[i] = ((a[i] && b[i]) ? 1 : 0);
| [=](Vec256<uint8_t> a, Vec256<uint8_t> b) { | ||
| Vec256<uint8_t> c = Vec256<uint8_t>(); | ||
| for (int i = 0; i != Vec256<uint8_t>::size(); i++) { | ||
| c[i] = a[i] || b[i]; |
There was a problem hiding this comment.
Avoid implicit cast with:
c[i] = ((a[i] && b[i]) ? 1 : 0);
|
@heitorschueroff Done. |
facebook-github-bot
left a comment
There was a problem hiding this comment.
@heitorschueroff has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
@heitorschueroff merged this pull request in 6575e67. |
1 similar comment
|
@heitorschueroff merged this pull request in 6575e67. |
|
@JackCaoG Please merge the XLA fix. Thanks! |
|
@kshitij12345 there are a few problems with this PR
Can you please work on fixing 2) and 3) ? |
I merged XLA change. If this pr will be reverted I will revert the xla pr as well. Otherwise I will work on a companion pr to fix the result type. |
|
I'd look at the |
|
Also please don't forget to update documentation. |
|
@ngimel Behaviour in previous version (1.5.1) >>> import torch
>>> torch.__version__
'1.5.1+cu101'
>>> x = torch.zeros(3,3)
>>> x.to(torch.uint8).all()
tensor(0, dtype=torch.uint8)
>>> x.to(torch.bool).all()
tensor(False) |
|
cc @mruberry for deprecating return type for uint8. In any case, for all other types there are no bc breaking concerns, so we should implement correct behavior. |
|
I would try to update the uint8 behavior to be consistent and document the change as BC-breaking. If a scripted network relies on the current behavior (extremely unlikely) we can write an upgrader. |
Summary: BC-breaking note: This PR changes the behavior of the any and all functions to always return a bool tensor. Previously these functions were only defined on bool and uint8 tensors, and when called on uint8 tensors they would also return a uint8 tensor. (When called on a bool tensor they would return a bool tensor.) PR summary: #44790 (comment) Fixes 2 and 3 Also Fixes #48352 Changes * Output dtype is always `bool` (consistent with numpy) **BC Breaking (Previously used to match the input dtype**) * Uses vectorized version for all dtypes on CPU * Enables test for complex * Update doc for `torch.all` and `torch.any` TODO * [x] Update docs * [x] Benchmark * [x] Raise issue on XLA Pull Request resolved: #47878 Reviewed By: H-Huang Differential Revision: D25421263 Pulled By: mruberry fbshipit-source-id: c6c681ef94004d2bcc787be61a72aa059b333e69
Summary: BC-breaking note: This PR changes the behavior of the any and all functions to always return a bool tensor. Previously these functions were only defined on bool and uint8 tensors, and when called on uint8 tensors they would also return a uint8 tensor. (When called on a bool tensor they would return a bool tensor.) PR summary: pytorch#44790 (comment) Fixes 2 and 3 Also Fixes pytorch#48352 Changes * Output dtype is always `bool` (consistent with numpy) **BC Breaking (Previously used to match the input dtype**) * Uses vectorized version for all dtypes on CPU * Enables test for complex * Update doc for `torch.all` and `torch.any` TODO * [x] Update docs * [x] Benchmark * [x] Raise issue on XLA Pull Request resolved: pytorch#47878 Reviewed By: H-Huang Differential Revision: D25421263 Pulled By: mruberry fbshipit-source-id: c6c681ef94004d2bcc787be61a72aa059b333e69
Summary: BC-breaking note: This PR changes the behavior of the any and all functions to always return a bool tensor. Previously these functions were only defined on bool and uint8 tensors, and when called on uint8 tensors they would also return a uint8 tensor. (When called on a bool tensor they would return a bool tensor.) PR summary: #44790 (comment) Fixes 2 and 3 Also Fixes #48352 Changes * Output dtype is always `bool` (consistent with numpy) **BC Breaking (Previously used to match the input dtype**) * Uses vectorized version for all dtypes on CPU * Enables test for complex * Update doc for `torch.all` and `torch.any` TODO * [x] Update docs * [x] Benchmark * [x] Raise issue on XLA Pull Request resolved: #47878 Reviewed By: albanD Differential Revision: D25714324 Pulled By: mruberry fbshipit-source-id: a87345f725297524242d69402dfe53060521ea5d
Summary: BC-breaking note: This PR changes the behavior of the any and all functions to always return a bool tensor. Previously these functions were only defined on bool and uint8 tensors, and when called on uint8 tensors they would also return a uint8 tensor. (When called on a bool tensor they would return a bool tensor.) PR summary: pytorch#44790 (comment) Fixes 2 and 3 Also Fixes pytorch#48352 Changes * Output dtype is always `bool` (consistent with numpy) **BC Breaking (Previously used to match the input dtype**) * Uses vectorized version for all dtypes on CPU * Enables test for complex * Update doc for `torch.all` and `torch.any` TODO * [x] Update docs * [x] Benchmark * [x] Raise issue on XLA Pull Request resolved: pytorch#47878 Reviewed By: albanD Differential Revision: D25714324 Pulled By: mruberry fbshipit-source-id: a87345f725297524242d69402dfe53060521ea5d
Reference #44779