Skip to content

Commit b58918f

Browse files
committed
Added fixes
1 parent 4fe6d28 commit b58918f

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

torch/utils/data/sampler.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -189,16 +189,16 @@ def __iter__(self) -> Iterator[int]:
189189

190190
def init_generator(self):
191191
if self.generator is None:
192-
seed = int(torch.empty((), dtype=torch.int64).random_().item())
193-
generator = torch.Generator()
194-
generator.manual_seed(seed)
192+
seed = int(torch.empty((), dtype=torch.int64).random_())
193+
generator = torch.Generator().manual_seed(seed)
195194
else:
196195
generator = self.generator
197196
return generator
198197

199198
def __len__(self) -> int:
200199
return self.num_samples
201200

201+
202202
class SubsetRandomSampler(Sampler[int]):
203203
r"""Samples elements randomly from a given list of indices, without replacement.
204204
@@ -355,9 +355,9 @@ class RandomBatchSampler(RandomSampler, BatchSampler):
355355
def __init__(
356356
self,
357357
data_source: Sized,
358+
batch_size: int,
359+
drop_last: bool,
358360
replacement: bool = False,
359-
batch_size: int = 32,
360-
drop_last: bool = False,
361361
generator=None,
362362
) -> None:
363363
RandomSampler.__init__(self, data_source, replacement, None, generator)
@@ -384,8 +384,8 @@ def __iter__(self) -> Iterator[torch.Tensor]:
384384
indices = self.sample_indices()
385385

386386
# Slicing is faster on list when batch size is small
387-
# if self.batch_size < 16:
388-
indices = indices.tolist()
387+
if self.batch_size < 16:
388+
indices = indices.tolist()
389389

390390
indices_batches = [
391391
indices[i : i + self.batch_size]

0 commit comments

Comments
 (0)