@@ -299,6 +299,78 @@ def _test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwa
299299 for aligned in (True , False ):
300300 super ()._test_forward (device , contiguous , x_dtype , rois_dtype , aligned = aligned )
301301
302+ def test_qroialign (self ):
303+ """Make sure quantized version of RoIAlign is close to float version"""
304+ pool_size = 5
305+ img_size = 10
306+ n_channels = 2
307+ num_imgs = 1
308+ dtype = torch .float
309+
310+ def make_rois (num_rois = 1000 ):
311+ rois = torch .randint (0 , img_size // 2 , size = (num_rois , 5 )).to (dtype )
312+ rois [:, 0 ] = torch .randint (0 , num_imgs , size = (num_rois ,)) # set batch index
313+ rois [:, 3 :] += rois [:, 1 :3 ] # make sure boxes aren't degenerate
314+ return rois
315+
316+ for aligned in (True , False ):
317+ for scale , zero_point in ((1 , 0 ), (2 , 10 ), (0.1 , 50 )):
318+ for qdtype in (torch .qint8 , torch .quint8 , torch .qint32 ):
319+
320+ x = torch .randint (50 , 100 , size = (num_imgs , n_channels , img_size , img_size )).to (dtype )
321+ qx = torch .quantize_per_tensor (x , scale = scale , zero_point = zero_point , dtype = qdtype )
322+
323+ rois = make_rois ()
324+ qrois = torch .quantize_per_tensor (rois , scale = scale , zero_point = zero_point , dtype = qdtype )
325+
326+ x , rois = qx .dequantize (), qrois .dequantize () # we want to pass the same inputs
327+
328+ y = ops .roi_align (
329+ x ,
330+ rois ,
331+ output_size = pool_size ,
332+ spatial_scale = 1 ,
333+ sampling_ratio = - 1 ,
334+ aligned = aligned ,
335+ )
336+ qy = ops .roi_align (
337+ qx ,
338+ qrois ,
339+ output_size = pool_size ,
340+ spatial_scale = 1 ,
341+ sampling_ratio = - 1 ,
342+ aligned = aligned ,
343+ )
344+
345+ # The output qy is itself a quantized tensor and there might have been a loss of info when it was
346+ # quantized. For a fair comparison we need to quantize y as well
347+ quantized_float_y = torch .quantize_per_tensor (y , scale = scale , zero_point = zero_point , dtype = qdtype )
348+
349+ try :
350+ # Ideally, we would assert this, which passes with (scale, zero) == (1, 0)
351+ self .assertTrue ((qy == quantized_float_y ).all ())
352+ except AssertionError :
353+ # But because the computation aren't exactly the same between the 2 RoIAlign procedures, some
354+ # rounding error may lead to a difference of 2 in the output.
355+ # For example with (scale, zero) = (2, 10), 45.00000... will be quantized to 44
356+ # but 45.00000001 will be rounded to 46. We make sure below that:
357+ # - such discrepancies between qy and quantized_float_y are very rare (less then 5%)
358+ # - any difference between qy and quantized_float_y is == scale
359+ diff_idx = torch .where (qy != quantized_float_y )
360+ num_diff = diff_idx [0 ].numel ()
361+ self .assertTrue (num_diff / qy .numel () < .05 )
362+
363+ abs_diff = torch .abs (qy [diff_idx ].dequantize () - quantized_float_y [diff_idx ].dequantize ())
364+ t_scale = torch .full_like (abs_diff , fill_value = scale )
365+ self .assertTrue (torch .allclose (abs_diff , t_scale , atol = 1e-5 ))
366+
367+ x = torch .randint (50 , 100 , size = (2 , 3 , 10 , 10 )).to (dtype )
368+ qx = torch .quantize_per_tensor (x , scale = 1 , zero_point = 0 , dtype = torch .qint8 )
369+ rois = make_rois (10 )
370+ qrois = torch .quantize_per_tensor (rois , scale = 1 , zero_point = 0 , dtype = torch .qint8 )
371+ with self .assertRaisesRegex (RuntimeError , "Only one image per batch is allowed" ):
372+ ops .roi_align (qx , qrois , output_size = pool_size )
373+
302374
303375class PSRoIAlignTester (RoIOpTester , unittest .TestCase ):
304376 def fn (self , x , rois , pool_h , pool_w , spatial_scale = 1 , sampling_ratio = - 1 , ** kwargs ):
0 commit comments