@@ -304,20 +304,20 @@ def test_qroialign(self):
304304 pool_size = 5
305305 img_size = 10
306306 n_channels = 2
307- num_batches = 2
307+ num_imgs = 2
308308 dtype = torch .float
309309
310310 def make_rois (num_rois = 1000 ):
311311 rois = torch .randint (0 , img_size // 2 , size = (num_rois , 5 )).to (dtype )
312- rois [:, 0 ] = torch .randint (0 , num_batches , size = (num_rois ,)) # set batch index
312+ rois [:, 0 ] = torch .randint (0 , num_imgs , size = (num_rois ,)) # set batch index
313313 rois [:, 3 :] += rois [:, 1 :3 ] # make sure boxes aren't degenerate
314314 return rois
315315
316316 for aligned in (True , False ):
317317 for scale , zero_point in ((1 , 0 ), (2 , 10 ), (0.1 , 50 )):
318318 for qdtype in (torch .qint8 , torch .quint8 , torch .qint32 ):
319319
320- x = torch .randint (50 , 100 , size = (num_batches , n_channels , img_size , img_size )).to (dtype )
320+ x = torch .randint (50 , 100 , size = (num_imgs , n_channels , img_size , img_size )).to (dtype )
321321 qx = torch .quantize_per_tensor (x , scale = scale , zero_point = zero_point , dtype = qdtype )
322322
323323 rois = make_rois ()
@@ -364,6 +364,13 @@ def make_rois(num_rois=1000):
364364 t_scale = torch .full_like (abs_diff , fill_value = scale )
365365 self .assertTrue (torch .allclose (abs_diff , t_scale , atol = 1e-5 ))
366366
367+ x = torch .randint (50 , 100 , size = (129 , 3 , 10 , 10 )).to (dtype )
368+ qx = torch .quantize_per_tensor (x , scale = 0 , zero_point = 1 , dtype = torch .qint8 )
369+ rois = make_rois (10 )
370+ qrois = torch .quantize_per_tensor (rois , scale = 0 , zero_point = 1 , dtype = torch .qint8 )
371+ with self .assertRaisesRegex (RuntimeError , "There are 129 input images in the batch, but the RoIs tensor" ):
372+ ops .roi_align (qx , qrois , output_size = pool_size )
373+
367374
368375class PSRoIAlignTester (RoIOpTester , unittest .TestCase ):
369376 def fn (self , x , rois , pool_h , pool_w , spatial_scale = 1 , sampling_ratio = - 1 , ** kwargs ):
0 commit comments