
[numpy] torch.{all/any} : output dtype is always bool #47878

Closed

kshitij12345 wants to merge 45 commits into pytorch:master from kshitij12345:develop/numpy/all-any-bool-out

Conversation

@kshitij12345
Collaborator

@kshitij12345 kshitij12345 commented Nov 12, 2020

PR summary:

#44790 (comment)

Fixes 2 and 3

Also Fixes #48352

Changes

  • Output dtype is always bool (consistent with numpy) (except for uint8, where it's uint8)
  • Uses vectorized version for all dtypes on CPU
  • Enables test for complex
  • Update doc for torch.all and torch.any
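
A minimal sketch of the resulting behavior (assuming a build that includes this PR):

```python
import torch

# After this PR, all/any return a bool tensor for every input dtype...
x = torch.tensor([1.0, 0.0, 2.0])
assert torch.all(x).dtype == torch.bool
assert torch.any(x).dtype == torch.bool

# ...except uint8 inputs, which keep returning uint8 for backward compatibility.
u = torch.tensor([1, 0], dtype=torch.uint8)
assert torch.any(u).dtype == torch.uint8
```

This matches NumPy, where `np.all`/`np.any` always produce a boolean result, modulo the uint8 carve-out kept for BC.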

TODO

  • Update docs
  • Benchmark
  • Raise issue on XLA

* output dtype of any/all is always bool.
* remove AndOps and OrOps.
* vectorised version for all dtypes CPU.
* verify exact_dtype.
* enable complex dtype.
@dr-ci

dr-ci bot commented Nov 12, 2020

💊 CI failures summary and remediations

As of commit b145649 (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

1 failure not recognized by patterns:

Job: CircleCI pytorch_linux_bionic_py3_8_gcc9_coverage_test1, step: Run tests
1 job timed out:
  • pytorch_linux_bionic_py3_8_gcc9_coverage_test1


  } else {
    auto iter = make_reduction(
-       "all", result, self, dim, keepdim, self.scalar_type());
+       "all", result, self, dim, keepdim, /*out_dtype=*/kBool);
Collaborator

This invocation of make_reduction will explicitly cast the input to bool type; I doubt that's what you want.

Collaborator Author

I think we need that, since dynamic casting is not supported on CPU:

void binary_kernel_reduce_vec(TensorIterator& iter, func_t op, vec_func_t vop, double ident = 0) {
  using traits = binary_function_traits<func_t>;
  static_assert(
    all_same<
      typename traits::result_type,
      typename traits::arg1_t,
      typename traits::arg2_t>::value,
    "all types must match");

And the inferred type of the data is based on the return type of the op passed to binary_kernel_reduce_vec:

static inline void reduction128(char** data, int64_t n, int64_t stride, func_t op, vec_func_t vop, bool reduce) {
  VEC_LOOP_HEADER(func_t, data)

#define VEC_LOOP_HEADER(func_t, data) \
  using scalar_t = typename function_traits<func_t>::result_type; \
  using Vec = Vec256<scalar_t>; \
  char* out_ptr = data[0]; \
  (void) out_ptr;

Let me know if I am missing something 🤔
Thanks!

Collaborator

It should be possible to not always cast the inputs to bool. I think there are two cases:

  • both inputs have the same scalar type
  • the inputs have different scalar types

For the first case the conversion to bool can happen in the kernel.

For the second case the inputs do need to be copied to avoid excessive template instantiations. In this case the inputs might as well be cast to bool ahead of the kernel being called.

The first case, where the scalar types are the same, is probably much more common than different scalar types, and would be nice to optimize for. Would it be possible to address that case?

Collaborator Author

The first case, where the scalar types are the same, is probably much more common than different scalar types, and would be nice to optimize for. Would it be possible to address that case?

This is already taken care of,

static TensorIterator make_reduction(
    const char* name, Tensor& result, const Tensor& self, IntArrayRef dim,
    bool keepdim, ScalarType in_dtype, ScalarType out_dtype)
{
  // check that result type and dtype match if provided
  TORCH_CHECK(
      !result.defined() || result.scalar_type() == out_dtype,
      name, ": provided dtype must match dtype of result. Got ",
      toString(result.scalar_type()),
      " and ",
      toString(out_dtype),
      ".");
  int64_t ndim = self.dim();
  auto mask = make_dim_mask(dim, ndim);
  allocate_reduction_result(result, self, mask, keepdim, out_dtype);
  auto viewed_result = review_reduce_result(result, ndim, mask, keepdim);
  namedinference::propagate_names_for_reduction(result, self, dim, keepdim);
  if (self.scalar_type() == in_dtype) {
    return TensorIterator::reduce_op(viewed_result, self);
  }
  return TensorIterator::reduce_op(viewed_result, self.to(in_dtype));
}

At the bottom there is

if (self.scalar_type() == in_dtype) {
  return TensorIterator::reduce_op(viewed_result, self);
}

@kshitij12345
Collaborator Author

kshitij12345 commented Nov 13, 2020

Benchmark

Benchmark Script

import torch
import sys
from torch.utils.benchmark import Timer
from torch.utils.benchmark import Compare

print('Using pytorch %s' % (torch.__version__))

shapes = [(32,), (32, 32), (2, 16, 32), (2, 16, 32, 32), (8, 16, 64, 64)]
num_threads = 8
dtypes = [torch.bool, torch.float]
repeats = 10

for dtype in dtypes:
    results = []
    for mat1_shape in shapes:
        mat1 = torch.randn(*mat1_shape, device='cpu').to(dtype)
        mat1_cuda = mat1.to('cuda')

        tasks = [("torch.all(mat1)", "torch.all CPU"),
                 ("torch.all(mat1_cuda)", "torch.all CUDA")]

        timers = [Timer(stmt=stmt, num_threads=num_threads, label=f"All {dtype}", sub_label=f"{(mat1_shape)}", description=label, globals=globals()) for stmt, label in tasks]

        for i, timer in enumerate(timers * repeats):
            results.append(
                timer.blocked_autorange()
            )
            print(f"\r{i + 1} / {len(timers) * repeats}", end="")
            sys.stdout.flush()
    comparison = Compare(results)
    comparison.print()

Before PR (output dtype is the same as the input dtype, hence no casting)

Using pytorch 1.8.0a0+73e121d
20 / 20[-------------------- All torch.bool --------------------]
                       |  torch.all CPU  |  torch.all CUDA
8 threads: -----------------------------------------------
      (32,)            |        3.4      |       10.1     
      (32, 32)         |        4.3      |       10.5     
      (2, 16, 32)      |        4.4      |       10.6     
      (2, 16, 32, 32)  |       48.7      |       10.6     
      (8, 16, 64, 64)  |       83.7      |       15.4     

Times are in microseconds (us).

20 / 20[------------------ All torch.float32 -------------------]
                       |  torch.all CPU  |  torch.all CUDA
8 threads: -----------------------------------------------
      (32,)            |        3.0      |       10.3     
      (32, 32)         |        7.6      |       10.4     
      (2, 16, 32)      |        7.6      |       10.4     
      (2, 16, 32, 32)  |      149.0      |       10.5     
      (8, 16, 64, 64)  |      317.3      |       15.2     

Times are in microseconds (us).

After PR

Using pytorch 1.8.0a0+d7c8d3c
20 / 20[-------------------- All torch.bool --------------------]
                       |  torch.all CPU  |  torch.all CUDA
8 threads: -----------------------------------------------
      (32,)            |        3.4      |       10.3     
      (32, 32)         |        4.4      |       10.5     
      (2, 16, 32)      |        4.5      |       10.6     
      (2, 16, 32, 32)  |       47.9      |       10.5     
      (8, 16, 64, 64)  |       83.3      |       16.0     

Times are in microseconds (us).

20 / 20[------------------ All torch.float32 -------------------]
                       |  torch.all CPU  |  torch.all CUDA
8 threads: -----------------------------------------------
      (32,)            |        5.9      |       22.4     
      (32, 32)         |        7.4      |       21.0     
      (2, 16, 32)      |        6.8      |       20.2     
      (2, 16, 32, 32)  |       63.9      |       20.8     
      (8, 16, 64, 64)  |      111.2      |       26.0     

Times are in microseconds (us).

Performance of CUDA has taken a hit. I think for CUDA we can avoid the casting.

EDIT
After avoiding the casting on CUDA.

Using pytorch 1.8.0a0+d7c8d3c
20 / 20[-------------------- All torch.bool --------------------]
                       |  torch.all CPU  |  torch.all CUDA
8 threads: -----------------------------------------------
      (32,)            |        3.5      |       10.9     
      (32, 32)         |        4.2      |       10.6     
      (2, 16, 32)      |        4.2      |       10.8     
      (2, 16, 32, 32)  |       45.1      |       11.1     
      (8, 16, 64, 64)  |       83.1      |       15.8     

Times are in microseconds (us).

20 / 20[------------------ All torch.float32 -------------------]
                       |  torch.all CPU  |  torch.all CUDA
8 threads: -----------------------------------------------
      (32,)            |        5.9      |       11.0     
      (32, 32)         |        7.3      |       11.2     
      (2, 16, 32)      |        7.2      |       11.3     
      (2, 16, 32, 32)  |       55.0      |       11.4     
      (8, 16, 64, 64)  |      113.6      |       16.6     

Times are in microseconds (us).

@codecov

codecov bot commented Nov 13, 2020

Codecov Report

Merging #47878 (21322fd) into master (47db191) will increase coverage by 46.17%.
The diff coverage is 63.63%.

@@             Coverage Diff             @@
##           master   #47878       +/-   ##
===========================================
+ Coverage   34.74%   80.92%   +46.17%     
===========================================
  Files         460     1855     +1395     
  Lines       58048   200200   +142152     
===========================================
+ Hits        20170   162014   +141844     
- Misses      37878    38186      +308     

* dont cast input for cuda.
* update test to reflect cuda doesn't support complex.
@smessmer smessmer added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Nov 13, 2020
@kshitij12345
Collaborator Author

@ngimel PTAL

@ngimel
Collaborator

ngimel commented Nov 16, 2020

Please update the docs; currently they say that any/all are unique to BoolTensor.

Contributor

@heitorschueroff heitorschueroff left a comment

I made some comments on the docs. For the testing, are you testing with discontiguous/strided input tensors?

* update docs
* update test for non-contiguous tensor
@kshitij12345
Collaborator Author

@mruberry Have made the necessary changes to preserve dtype when input dtype is uint8. Please review.

TORCH_CHECK(self.layout() == Layout::Strided,
            "any only supports strided layout, got: ", self.layout());
// Refer [all, any : uint8 compatibility]
TORCH_CHECK(result.scalar_type() == ScalarType::Bool || result.scalar_type() == ScalarType::Byte,
Collaborator

This needs a slight update.

I suggest separating these lines into two TORCH_CHECKs. The first checks that result's dtype is bool OR self's dtype is byte. If that check fails then the error message should report that any "expected" a boolean 'out' tensor but got ...

The second TORCH_CHECK checks that result's dtype is byte or self's dtype is not byte. If that check fails then the error message should report that any "expected" a uint8 'out' tensor but got ...

These conditionals can be phrased slightly differently if you prefer.

    torch.any(x, dim, out=out)
    self.assertEqual(expected, out)
else:
    with self.assertRaisesRegex(RuntimeError, "all only supports bool tensor for result, got"):
Collaborator

These error messages will need to be updated, too.

self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype)

def _test_out_variant(x, dim):
    out = torch.empty_like(x)
Collaborator

The actual output will be a scalar tensor, right? Let's make out a scalar tensor here, too. Otherwise, when we start throwing runtime errors on resizing out= tensors, we'll have to update this.
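
A minimal sketch of this suggestion (the tensor values here are illustrative, not from the actual test): a reduction over the only dimension of a 1-D input produces a 0-dim result, so a 0-dim out= tensor needs no resize.

```python
import torch

x = torch.tensor([True, False, True])

# 0-dim (scalar) out tensor, matching the shape of the reduction result.
out = torch.empty((), dtype=torch.bool)
torch.any(x, dim=0, out=out)

assert out.dim() == 0       # no resize was needed
assert out.item() is True   # x contains at least one True
```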

Collaborator

@mruberry mruberry left a comment

Nice updates, @kshitij12345! Two small comments inline.

I'm going to import this now to run some additional BC-compat tests on it.

Contributor

@facebook-github-bot facebook-github-bot left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@kshitij12345
Collaborator Author

Is it OK if I push the new changes, or will that stop the internal BC-compat tests?

@mruberry
Collaborator

Is it ok if I push the new changes or will it stop the internal BC compat tests?

Go ahead. It won't affect them.

@mruberry
Collaborator

Update: this is passing our internal BC-compat tests for existing models, which is great news. Let's make those few last tweaks and land this.

@JackCaoG
Collaborator

@mruberry If I understand correctly, the new behavior is

  1. int8 -> int8
  2. uint8 -> uint8
  3. everything_else -> bool

If that is the case, I will work on the pt/xla PR.

@mruberry
Collaborator

@mruberry If I understand correctly, the new behavior is

  1. int8 -> int8
  2. uint8 -> uint8
  3. everything_else -> bool

If that is the case I will work on the pt/xla pr

Hey @JackCaoG! Was just about to ping you now that this is looking OK. It's just:

  1. uint8 -> uint8
  2. everything else -> bool

No rush. Obviously we'll wait for the pt/xla PR to be ready. Sorry that we had to thrash on this for BC concerns.
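
The one place the final behavior diverges from NumPy is the uint8 carve-out kept for BC; a quick side-by-side sketch:

```python
import numpy as np
import torch

a = np.array([1, 0], dtype=np.uint8)
t = torch.tensor([1, 0], dtype=torch.uint8)

# NumPy always reduces to bool, even for uint8 input.
assert np.any(a).dtype == np.bool_

# torch preserves uint8 here for backward compatibility.
assert torch.any(t).dtype == torch.uint8
```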

hwangdeyu pushed a commit to hwangdeyu/pytorch that referenced this pull request Jan 6, 2021
Summary:
BC-breaking note:

This PR changes the behavior of the any and all functions to always return a bool tensor. Previously these functions were only defined on bool and uint8 tensors, and when called on uint8 tensors they would also return a uint8 tensor. (When called on a bool tensor they would return a bool tensor.)

PR summary:

pytorch#44790 (comment)

Fixes 2 and 3

Also Fixes pytorch#48352

Changes
* Output dtype is always `bool` (consistent with numpy) **BC-breaking** (previously used to match the input dtype)
* Uses vectorized version for all dtypes on CPU
* Enables test for complex
* Update doc for `torch.all` and `torch.any`

TODO
* [x] Update docs
* [x] Benchmark
* [x] Raise issue on XLA

Pull Request resolved: pytorch#47878

Reviewed By: H-Huang

Differential Revision: D25421263

Pulled By: mruberry

fbshipit-source-id: c6c681ef94004d2bcc787be61a72aa059b333e69
@JackCaoG
Collaborator

JackCaoG commented Jan 7, 2021

Hi @mruberry, the XLA PR is ready. I will merge it after this one is merged.

@mruberry
Collaborator

mruberry commented Jan 7, 2021

Hi @mruberry XLA pr is ready, I will merge it after this one is merged

Awesome, thanks @JackCaoG. I'll check that this PR is mergeable later today and, if it is, notify you that I've started the merge process tomorrow morning during work hours. Then I'll ping you once it's landed.

@mruberry mruberry removed the module: bc-breaking Related to a BC-breaking change label Jan 8, 2021
@mruberry
Collaborator

mruberry commented Jan 8, 2021

Good morning, @JackCaoG. Final tests are running. Expected merge time is around 11AM PST today. I'll keep you updated.

@JackCaoG
Collaborator

JackCaoG commented Jan 8, 2021

@mruberry It seems the pytorch PR has been merged? I will go ahead and merge the xla PR.

@mruberry
Collaborator

mruberry commented Jan 8, 2021

@JackCaoG Yep, sorry, I was in a meeting. Merged.

hwangdeyu pushed a commit to hwangdeyu/pytorch that referenced this pull request Jan 14, 2021
Summary:
BC-breaking note:

This PR changes the behavior of the any and all functions to always return a bool tensor. Previously these functions were only defined on bool and uint8 tensors, and when called on uint8 tensors they would also return a uint8 tensor. (When called on a bool tensor they would return a bool tensor.)

PR summary:

pytorch#44790 (comment)

Fixes 2 and 3

Also Fixes pytorch#48352

Changes
* Output dtype is always `bool` (consistent with numpy) **BC-breaking** (previously used to match the input dtype)
* Uses vectorized version for all dtypes on CPU
* Enables test for complex
* Update doc for `torch.all` and `torch.any`

TODO
* [x] Update docs
* [x] Benchmark
* [x] Raise issue on XLA

Pull Request resolved: pytorch#47878

Reviewed By: albanD

Differential Revision: D25714324

Pulled By: mruberry

fbshipit-source-id: a87345f725297524242d69402dfe53060521ea5d
@kshitij12345 kshitij12345 deleted the develop/numpy/all-any-bool-out branch January 15, 2021 06:43

Labels

Merged · open source · Reverted · triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[FR] .all and .any should have dim argument

8 participants