@@ -10993,28 +10993,28 @@ def forward(self, x):
1099310993 def test_grid_sample (self ):
1099410994 n , c , h_in , w_in , h_out , w_out = 1 , 1 , 3 , 2 , 2 , 4
1099510995
10996- class Module (torch .nn .Module ):
10996+ class GridSampleModule (torch .nn .Module ):
1099710997
10998- def __init__ (self , mode : str , padding_mode : str , align_corners : bool ) -> None :
10998+ def __init__ (self , mode , padding_mode , align_corners ) -> None :
1099910999 super ().__init__ ()
1100011000 self .mode , self .padding_mode , self .align_corners = mode , padding_mode , align_corners
1100111001
1100211002 def forward (self , input , grid ):
1100311003 return torch .nn .functional .grid_sample (input , grid , self .mode , self .padding_mode , self .align_corners )
1100411004
1100511005 for mode , padding_mode , align_corners in itertools .product (
11006- ("bilinear" , "nearest" , "bicubic" ), # mode
11007- ("zeros" , "border" , "reflection" ), # padding_mode
11008- (True , False ), # align_corners
11006+ ("bilinear" , "nearest" , "bicubic" ),
11007+ ("zeros" , "border" , "reflection" ),
11008+ (True , False ),
1100911009 ):
11010- # note (mkozuki): Skip the combinations that fail locally.
11011- if (mode , padding_mode , align_corners ) in (
11012- ( "bicubic" , "border" , True ),
11013- ( "bicubic" , "border" , False ),
11014- ) :
11015- continue
11010+ atol_rtol = {}
11011+ if (mode , padding_mode ) == ( "bicubic" , "border" ):
11012+ if align_corners :
11013+ atol_rtol . update ({ "atol" : 0.3 , "rtol" : 0.4 })
11014+ else :
11015+ atol_rtol . update ({ "atol" : 0.02 , "rtol" : 0.02 })
1101611016 input , grid = torch .randn (n , c , h_in , w_in ), torch .randn (n , h_out , w_out , 2 )
11017- self .run_test (Module (mode , padding_mode , align_corners ), (input , grid ))
11017+ self .run_test (GridSampleModule (mode , padding_mode , align_corners ), (input , grid ), ** atol_rtol )
1101811018
1101911019
1102011020def make_test (name , base , layer , bidirectional , initial_state ,
0 commit comments