Commit 574e808

xuhdev authored and facebook-github-bot committed
Add a bitwise NOT operator for integer and Boolean types (CUDA).
Summary: Pull Request resolved: #22320
Test Plan: Imported from OSS
Differential Revision: D16183578
Pulled By: colesbury
fbshipit-source-id: 2f72cce5e10fd637be1ac87e1bbfe0937a661034
1 parent e2dc1fc commit 574e808
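
As a quick illustration of what this commit enables, here is a minimal usage sketch of the new CUDA path (assuming a CUDA-enabled build; the printed values are illustrative but follow two's-complement semantics):

    import torch

    # Integer dtypes: bitwise NOT flips every bit, so ~x == -x - 1 in two's complement.
    a = torch.tensor([0, 1, 2], dtype=torch.int8, device='cuda')
    print(a.bitwise_not())        # tensor([-1, -2, -3], device='cuda:0', dtype=torch.int8)

    # Boolean dtype: bitwise NOT acts as logical negation.
    b = torch.tensor([True, False], device='cuda')
    print(torch.bitwise_not(b))   # tensor([False,  True], device='cuda:0')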

File tree: 5 files changed (+57 −18 lines)


aten/src/ATen/native/UnaryOps.cpp

Lines changed: 20 additions & 1 deletion
@@ -36,6 +36,26 @@
 namespace at {
 namespace native {
 
+Tensor bitwise_not(const Tensor& self) {
+  Tensor result = at::empty({0}, self.options());
+  return at::bitwise_not_out(result, self);
+}
+
+Tensor& bitwise_not_(Tensor& self) {
+  return at::bitwise_not_out(self, self);
+}
+
+Tensor& bitwise_not_out(Tensor& result, const Tensor& self) {
+  checkBackend("bitwise_not", result, self.type().backend());
+  assert_no_internal_overlap(result, "bitwise_not");
+  auto iter = TensorIterator::unary_op(result, self);
+  bitwise_not_stub(iter->device_type(), *iter);
+#ifdef BUILD_NAMEDTENSOR
+  at::namedinference::propagate_names(result, self);
+#endif
+  return result;
+}
+
 Tensor clamp(const Tensor& self, optional<Scalar> min, optional<Scalar> max) {
   Tensor result = at::empty({0}, self.options());
   return clamp_out(result, self, min, max);

@@ -167,7 +187,6 @@ IMPLEMENT_UNARY_OP_VEC(abs)
 IMPLEMENT_UNARY_OP_VEC(acos)
 IMPLEMENT_UNARY_OP_VEC(asin)
 IMPLEMENT_UNARY_OP_VEC(atan)
-IMPLEMENT_UNARY_OP_VEC(bitwise_not)
 IMPLEMENT_UNARY_OP_VEC(ceil)
 IMPLEMENT_UNARY_OP_VEC(cos)
 IMPLEMENT_UNARY_OP_VEC(cosh)
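
The three entry points added above correspond to the three call forms the tests exercise: bitwise_not allocates a new result, bitwise_not_ rewrites its input, and bitwise_not_out writes into a caller-provided tensor, with all three funnelling through bitwise_not_stub. A small sketch of the Python-side behaviour (illustrative; CPU or CUDA tensors work alike once this commit is in):

    import torch

    a = torch.arange(4, dtype=torch.int32)
    out = torch.empty(0, dtype=torch.int32)

    r = a.bitwise_not()               # new tensor
    torch.bitwise_not(a, out=out)     # out= variant, fills the preallocated `out`
    a.bitwise_not_()                  # in-place, mutates `a`

    assert torch.equal(r, out) and torch.equal(r, a)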

aten/src/ATen/native/cuda/UnaryOpsKernel.cu

Lines changed: 16 additions & 1 deletion
@@ -1,14 +1,28 @@
+#include <limits>
 #include <ATen/native/UnaryOps.h>
 #include <ATen/native/cuda/Loops.cuh>
 #include <ATen/Context.h>
 #include <ATen/Dispatch.h>
 #include <ATen/native/cuda/Loops.cuh>
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/TensorIterator.h>
-#include <limits>
 
 namespace at { namespace native {
 
+void bitwise_not_kernel_cuda(TensorIterator& iter) {
+  if (iter.dtype() == ScalarType::Bool) {
+    gpu_kernel(iter, []GPU_LAMBDA(bool a) {
+      return !a;
+    });
+  } else {
+    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_not_cuda", [&]() {
+      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+        return ~a;
+      });
+    });
+  }
+}
+
 template <typename scalar_t>
 void fill_kernel_impl(TensorIterator& iter, Scalar value_scalar) {
   auto value = value_scalar.to<scalar_t>();

@@ -24,5 +38,6 @@ static void fill_kernel_cuda(TensorIterator& iter, Scalar value) {
 }
 
 REGISTER_DISPATCH(fill_stub, &fill_kernel_cuda);
+REGISTER_DISPATCH(bitwise_not_stub, &bitwise_not_kernel_cuda);
 
 }}
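
Note that the kernel handles torch.bool separately: booleans are negated logically with !, while integral dtypes go through AT_DISPATCH_INTEGRAL_TYPES and use the bit-flipping ~. A short sketch of the observable difference (CPU tensors shown for brevity; the CUDA kernel added here follows the same rule):

    import torch

    x = torch.tensor([True, False])
    print(torch.bitwise_not(x))   # tensor([False,  True])  -- logical NOT, stays 0/1

    y = torch.tensor([1, 0], dtype=torch.uint8)
    print(torch.bitwise_not(y))   # tensor([254, 255], dtype=torch.uint8)  -- all bits flipped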

aten/src/ATen/native/native_functions.yaml

Lines changed: 2 additions & 3 deletions
@@ -362,12 +362,11 @@
 
 - func: bitwise_not_(Tensor(a!) self) -> Tensor(a!)
   variants: method
-  dispatch:
-    CPU: _bitwise_not__cpu
 
 - func: bitwise_not(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
-    CPU: _bitwise_not_out_cpu
+    CPU: bitwise_not_out
+    CUDA: bitwise_not_out
 
 - func: blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
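
With the backend-specific _bitwise_not_out_cpu kernel gone, both CPU and CUDA now dispatch to the shared bitwise_not_out, which picks the device kernel via bitwise_not_stub. A quick parity check one might run (assuming a CUDA-enabled build):

    import torch

    cpu = torch.arange(127, dtype=torch.int16)
    assert torch.equal(cpu.bitwise_not(), cpu.cuda().bitwise_not().cpu())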

test/test_cuda.py

Lines changed: 3 additions & 0 deletions
@@ -1088,6 +1088,9 @@ def test_type_conversions_same_gpu(self):
     def test_neg(self):
         _TestTorchMixin._test_neg(self, lambda t: t.cuda())
 
+    def test_bitwise_not(self):
+        _TestTorchMixin._test_bitwise_not(self, 'cuda')
+
     def test_isinf(self):
         _TestTorchMixin._test_isinf(self, lambda t: t.cuda())
test/test_torch.py

Lines changed: 16 additions & 13 deletions
@@ -1748,40 +1748,43 @@ def _test_neg(self, cast):
     def test_neg(self):
         self._test_neg(self, lambda t: t)
 
-    def test_bitwise_not(self):
-        res = 0xffff - torch.arange(127, dtype=torch.int8)
-        for t in (torch.BoolTensor,
-                  torch.ByteTensor, torch.LongTensor, torch.IntTensor, torch.ShortTensor, torch.CharTensor):
-            if t == torch.BoolTensor:
-                a = torch.tensor([True, False])
-                expected_res = torch.tensor([False, True])
+    @staticmethod
+    def _test_bitwise_not(self, device):
+        res = 0xffff - torch.arange(127, dtype=torch.int8, device=device)
+        for dtype in (torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
+            if dtype == torch.bool:
+                a = torch.tensor([True, False], device=device)
+                expected_res = torch.tensor([False, True], device=device)
             else:
-                a = torch.arange(127, dtype=t.dtype)
-                expected_res = res.type(t)
+                a = torch.arange(127, dtype=dtype, device=device)
+                expected_res = res.type(dtype)
             # new tensor
             self.assertEqual(expected_res, a.bitwise_not())
             # out
-            b = t()
+            b = torch.empty(0, dtype=dtype, device=device)
             torch.bitwise_not(a, out=b)
             self.assertEqual(expected_res, b)
             # in-place
            a.bitwise_not_()
             self.assertEqual(expected_res, a)
 
         # test exceptions
-        for t in(torch.HalfTensor, torch.FloatTensor, torch.DoubleTensor):
-            a = torch.zeros(10, dtype=t.dtype)
+        for dtype in(torch.half, torch.float, torch.double):
+            a = torch.zeros(10, dtype=dtype, device=device)
             # new tensor
             with self.assertRaises(RuntimeError):
                 a.bitwise_not()
             # out
-            b = t()
+            b = torch.empty(0, dtype=dtype, device=device)
             with self.assertRaises(RuntimeError):
                 torch.bitwise_not(a, out=b)
             # in-place
             with self.assertRaises(RuntimeError):
                 a.bitwise_not_()
 
+    def test_bitwise_not(self):
+        self._test_bitwise_not(self, 'cpu')
+
     def test_threshold(self):
         for dtype in torch.testing.get_all_math_dtypes('cpu'):
             if dtype != torch.uint8 and dtype != torch.float16: