66import pytest
77
88import torch
9- from common_utils import cache , cpu_and_gpu , needs_cuda
9+ from common_utils import cache , cpu_and_gpu , needs_cuda , set_rng_seed
1010from prototype_common_utils import assert_close , make_bounding_boxes , make_image
1111from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
1212from prototype_transforms_kernel_infos import KERNEL_INFOS
@@ -81,6 +81,8 @@ class TestKernels:
8181 @sample_inputs
8282 @pytest .mark .parametrize ("device" , cpu_and_gpu ())
8383 def test_scripted_vs_eager (self , info , args_kwargs , device ):
84+ if info .seed is not None :
85+ set_rng_seed (info .seed )
8486 kernel_eager = info .kernel
8587 kernel_scripted = script (kernel_eager )
8688
@@ -111,6 +113,8 @@ def _unbatch(self, batch, *, data_dims):
111113 @sample_inputs
112114 @pytest .mark .parametrize ("device" , cpu_and_gpu ())
113115 def test_batched_vs_single (self , info , args_kwargs , device ):
116+ if info .seed is not None :
117+ set_rng_seed (info .seed )
114118 (batched_input , * other_args ), kwargs = args_kwargs .load (device )
115119
116120 feature_type = features .Image if features .is_simple_tensor (batched_input ) else type (batched_input )
@@ -146,6 +150,8 @@ def test_batched_vs_single(self, info, args_kwargs, device):
146150 @sample_inputs
147151 @pytest .mark .parametrize ("device" , cpu_and_gpu ())
148152 def test_no_inplace (self , info , args_kwargs , device ):
153+ if info .seed is not None :
154+ set_rng_seed (info .seed )
149155 (input , * other_args ), kwargs = args_kwargs .load (device )
150156
151157 if input .numel () == 0 :
@@ -159,6 +165,8 @@ def test_no_inplace(self, info, args_kwargs, device):
159165 @sample_inputs
160166 @needs_cuda
161167 def test_cuda_vs_cpu (self , info , args_kwargs ):
168+ if info .seed is not None :
169+ set_rng_seed (info .seed )
162170 (input_cpu , * other_args ), kwargs = args_kwargs .load ("cpu" )
163171 input_cuda = input_cpu .to ("cuda" )
164172
@@ -170,6 +178,8 @@ def test_cuda_vs_cpu(self, info, args_kwargs):
170178 @sample_inputs
171179 @pytest .mark .parametrize ("device" , cpu_and_gpu ())
172180 def test_dtype_and_device_consistency (self , info , args_kwargs , device ):
181+ if info .seed is not None :
182+ set_rng_seed (info .seed )
173183 (input , * other_args ), kwargs = args_kwargs .load (device )
174184
175185 output = info .kernel (input , * other_args , ** kwargs )
@@ -182,6 +192,8 @@ def test_dtype_and_device_consistency(self, info, args_kwargs, device):
182192
183193 @reference_inputs
184194 def test_against_reference (self , info , args_kwargs ):
195+ if info .seed is not None :
196+ set_rng_seed (info .seed )
185197 args , kwargs = args_kwargs .load ("cpu" )
186198
187199 actual = info .kernel (* args , ** kwargs )
0 commit comments