Skip to content

Commit 71f8806

Browse files
committed
cosmetic & adjust tolerance
1 parent 06f8164 commit 71f8806

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

test/onnx/test_pytorch_onnx_onnxruntime.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10993,28 +10993,28 @@ def forward(self, x):
1099310993
def test_grid_sample(self):
1099410994
n, c, h_in, w_in, h_out, w_out = 1, 1, 3, 2, 2, 4
1099510995

10996-
class Module(torch.nn.Module):
10996+
class GridSampleModule(torch.nn.Module):
1099710997

10998-
def __init__(self, mode: str, padding_mode: str, align_corners: bool) -> None:
10998+
def __init__(self, mode, padding_mode, align_corners) -> None:
1099910999
super().__init__()
1100011000
self.mode, self.padding_mode, self.align_corners = mode, padding_mode, align_corners
1100111001

1100211002
def forward(self, input, grid):
1100311003
return torch.nn.functional.grid_sample(input, grid, self.mode, self.padding_mode, self.align_corners)
1100411004

1100511005
for mode, padding_mode, align_corners in itertools.product(
11006-
("bilinear", "nearest", "bicubic"), # mode
11007-
("zeros", "border", "reflection"), # padding_mode
11008-
(True, False), # align_corners
11006+
("bilinear", "nearest", "bicubic"),
11007+
("zeros", "border", "reflection"),
11008+
(True, False),
1100911009
):
11010-
# note (mkozuki): Skip the combinations that fail locally.
11011-
if (mode, padding_mode, align_corners) in (
11012-
("bicubic", "border", True),
11013-
("bicubic", "border", False),
11014-
):
11015-
continue
11010+
atol_rtol = {}
11011+
if (mode, padding_mode) == ("bicubic", "border"):
11012+
if align_corners:
11013+
atol_rtol.update({"atol": 0.3, "rtol": 0.4})
11014+
else:
11015+
atol_rtol.update({"atol": 0.02, "rtol": 0.02})
1101611016
input, grid = torch.randn(n, c, h_in, w_in), torch.randn(n, h_out, w_out, 2)
11017-
self.run_test(Module(mode, padding_mode, align_corners), (input, grid))
11017+
self.run_test(GridSampleModule(mode, padding_mode, align_corners), (input, grid), **atol_rtol)
1101811018

1101911019

1102011020
def make_test(name, base, layer, bidirectional, initial_state,

0 commit comments

Comments
 (0)