@@ -189,16 +189,16 @@ def __iter__(self) -> Iterator[int]:
189189
190190 def init_generator (self ):
191191 if self .generator is None :
192- seed = int (torch .empty ((), dtype = torch .int64 ).random_ ().item ())
193- generator = torch .Generator ()
194- generator .manual_seed (seed )
192+ seed = int (torch .empty ((), dtype = torch .int64 ).random_ ())
193+ generator = torch .Generator ().manual_seed (seed )
195194 else :
196195 generator = self .generator
197196 return generator
198197
199198 def __len__ (self ) -> int :
200199 return self .num_samples
201200
201+
202202class SubsetRandomSampler (Sampler [int ]):
203203 r"""Samples elements randomly from a given list of indices, without replacement.
204204
@@ -355,9 +355,9 @@ class RandomBatchSampler(RandomSampler, BatchSampler):
355355 def __init__ (
356356 self ,
357357 data_source : Sized ,
358+ batch_size : int ,
359+ drop_last : bool ,
358360 replacement : bool = False ,
359- batch_size : int = 32 ,
360- drop_last : bool = False ,
361361 generator = None ,
362362 ) -> None :
363363 RandomSampler .__init__ (self , data_source , replacement , None , generator )
@@ -384,8 +384,8 @@ def __iter__(self) -> Iterator[torch.Tensor]:
384384 indices = self .sample_indices ()
385385
386386 # Slicing is faster on list when batch size is small
387- # if self.batch_size < 16:
388- indices = indices .tolist ()
387+ if self .batch_size < 16 :
388+ indices = indices .tolist ()
389389
390390 indices_batches = [
391391 indices [i : i + self .batch_size ]
0 commit comments