Skip to content

Commit 6e72f2f

Browse files
authored
Add seeds on Kernel Info and reduce randomness for Gaussian Blur (#6741)
* Add seeds on Kernel Info and reduce randomness for Gaussian Blur * Fix linter
1 parent 4d4711d commit 6e72f2f

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

test/prototype_transforms_kernel_infos.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,14 @@ def __init__(
4949
test_marks=None,
5050
# See InfoBase
5151
closeness_kwargs=None,
52+
seed=None,
5253
):
5354
super().__init__(id=kernel_name or kernel.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
5455
self.kernel = kernel
5556
self.sample_inputs_fn = sample_inputs_fn
5657
self.reference_fn = reference_fn
5758
self.reference_inputs_fn = reference_inputs_fn
59+
self.seed = seed
5860

5961

6062
DEFAULT_IMAGE_CLOSENESS_KWARGS = dict(
@@ -1304,7 +1306,7 @@ def sample_inputs_center_crop_video():
13041306

13051307
def sample_inputs_gaussian_blur_image_tensor():
13061308
make_gaussian_blur_image_loaders = functools.partial(
1307-
make_image_loaders, sizes=["random"], color_spaces=[features.ColorSpace.RGB]
1309+
make_image_loaders, sizes=[(7, 33)], color_spaces=[features.ColorSpace.RGB]
13081310
)
13091311

13101312
for image_loader, kernel_size in itertools.product(make_gaussian_blur_image_loaders(), [5, (3, 3), [3, 3]]):
@@ -1317,7 +1319,7 @@ def sample_inputs_gaussian_blur_image_tensor():
13171319

13181320

13191321
def sample_inputs_gaussian_blur_video():
1320-
for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
1322+
for video_loader in make_video_loaders(sizes=[(7, 33)], num_frames=[5]):
13211323
yield ArgsKwargs(video_loader, kernel_size=[3, 3])
13221324

13231325

@@ -1331,10 +1333,13 @@ def sample_inputs_gaussian_blur_video():
13311333
xfail_jit_python_scalar_arg("kernel_size"),
13321334
xfail_jit_python_scalar_arg("sigma"),
13331335
],
1336+
seed=0,
13341337
),
13351338
KernelInfo(
13361339
F.gaussian_blur_video,
13371340
sample_inputs_fn=sample_inputs_gaussian_blur_video,
1341+
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
1342+
seed=0,
13381343
),
13391344
]
13401345
)

test/test_prototype_transforms_functional.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytest
77

88
import torch
9-
from common_utils import cache, cpu_and_gpu, needs_cuda
9+
from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed
1010
from prototype_common_utils import assert_close, make_bounding_boxes, make_image
1111
from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
1212
from prototype_transforms_kernel_infos import KERNEL_INFOS
@@ -81,6 +81,8 @@ class TestKernels:
8181
@sample_inputs
8282
@pytest.mark.parametrize("device", cpu_and_gpu())
8383
def test_scripted_vs_eager(self, info, args_kwargs, device):
84+
if info.seed is not None:
85+
set_rng_seed(info.seed)
8486
kernel_eager = info.kernel
8587
kernel_scripted = script(kernel_eager)
8688

@@ -111,6 +113,8 @@ def _unbatch(self, batch, *, data_dims):
111113
@sample_inputs
112114
@pytest.mark.parametrize("device", cpu_and_gpu())
113115
def test_batched_vs_single(self, info, args_kwargs, device):
116+
if info.seed is not None:
117+
set_rng_seed(info.seed)
114118
(batched_input, *other_args), kwargs = args_kwargs.load(device)
115119

116120
feature_type = features.Image if features.is_simple_tensor(batched_input) else type(batched_input)
@@ -146,6 +150,8 @@ def test_batched_vs_single(self, info, args_kwargs, device):
146150
@sample_inputs
147151
@pytest.mark.parametrize("device", cpu_and_gpu())
148152
def test_no_inplace(self, info, args_kwargs, device):
153+
if info.seed is not None:
154+
set_rng_seed(info.seed)
149155
(input, *other_args), kwargs = args_kwargs.load(device)
150156

151157
if input.numel() == 0:
@@ -159,6 +165,8 @@ def test_no_inplace(self, info, args_kwargs, device):
159165
@sample_inputs
160166
@needs_cuda
161167
def test_cuda_vs_cpu(self, info, args_kwargs):
168+
if info.seed is not None:
169+
set_rng_seed(info.seed)
162170
(input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
163171
input_cuda = input_cpu.to("cuda")
164172

@@ -170,6 +178,8 @@ def test_cuda_vs_cpu(self, info, args_kwargs):
170178
@sample_inputs
171179
@pytest.mark.parametrize("device", cpu_and_gpu())
172180
def test_dtype_and_device_consistency(self, info, args_kwargs, device):
181+
if info.seed is not None:
182+
set_rng_seed(info.seed)
173183
(input, *other_args), kwargs = args_kwargs.load(device)
174184

175185
output = info.kernel(input, *other_args, **kwargs)
@@ -182,6 +192,8 @@ def test_dtype_and_device_consistency(self, info, args_kwargs, device):
182192

183193
@reference_inputs
184194
def test_against_reference(self, info, args_kwargs):
195+
if info.seed is not None:
196+
set_rng_seed(info.seed)
185197
args, kwargs = args_kwargs.load("cpu")
186198

187199
actual = info.kernel(*args, **kwargs)

0 commit comments

Comments
 (0)