@@ -1748,40 +1748,43 @@ def _test_neg(self, cast):
17481748 def test_neg(self):
17491749 self._test_neg(self, lambda t: t)
17501750
1751- def test_bitwise_not(self):
1752- res = 0xffff - torch.arange(127, dtype=torch.int8)
1753- for t in ( torch.BoolTensor,
1754- torch.ByteTensor , torch.LongTensor , torch.IntTensor , torch.ShortTensor , torch.CharTensor ):
1755- if t == torch.BoolTensor :
1756- a = torch.tensor([True, False])
1757- expected_res = torch.tensor([False, True])
1751+ @staticmethod
1752+ def _test_bitwise_not(self, device):
1753+ res = 0xffff - torch.arange(127, dtype= torch.int8, device=device)
1754+ for dtype in (torch.bool, torch.uint8 , torch.int8 , torch.int16 , torch.int32 , torch.int64 ):
1755+ if dtype == torch.bool :
1756+ a = torch.tensor([True, False], device=device )
1757+ expected_res = torch.tensor([False, True], device=device )
17581758 else:
1759- a = torch.arange(127, dtype=t. dtype)
1760- expected_res = res.type(t )
1759+ a = torch.arange(127, dtype=dtype, device=device )
1760+ expected_res = res.type(dtype )
17611761 # new tensor
17621762 self.assertEqual(expected_res, a.bitwise_not())
17631763 # out
1764- b = t( )
1764+ b = torch.empty(0, dtype=dtype, device=device )
17651765 torch.bitwise_not(a, out=b)
17661766 self.assertEqual(expected_res, b)
17671767 # in-place
17681768 a.bitwise_not_()
17691769 self.assertEqual(expected_res, a)
17701770
17711771 # test exceptions
1772- for t in(torch.HalfTensor , torch.FloatTensor , torch.DoubleTensor ):
1773- a = torch.zeros(10, dtype=t. dtype)
1772+ for dtype in(torch.half , torch.float , torch.double ):
1773+ a = torch.zeros(10, dtype=dtype, device=device )
17741774 # new tensor
17751775 with self.assertRaises(RuntimeError):
17761776 a.bitwise_not()
17771777 # out
1778- b = t( )
1778+ b = torch.empty(0, dtype=dtype, device=device )
17791779 with self.assertRaises(RuntimeError):
17801780 torch.bitwise_not(a, out=b)
17811781 # in-place
17821782 with self.assertRaises(RuntimeError):
17831783 a.bitwise_not_()
17841784
1785+ def test_bitwise_not(self):
1786+ self._test_bitwise_not(self, 'cpu')
1787+
17851788 def test_threshold(self):
17861789 for dtype in torch.testing.get_all_math_dtypes('cpu'):
17871790 if dtype != torch.uint8 and dtype != torch.float16:
0 commit comments