[Feature request] Let DistributedSampler take a Sampler as input #23430
Description
🚀 Feature
Motivation
Currently, DistributedSampler assumes it takes a Dataset as argument, but the only information it actually uses from it is its length (len(dataset)).
We sometimes want a custom Sampler to be used in distributed mode, so it might be desirable to also let DistributedSampler accept a Sampler as argument.
Potential implementation
The only difference is that in
pytorch/torch/utils/data/distributed.py
Lines 57 to 61 in 46224ef
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
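For reference, the slicing step quoted above partitions indices round-robin across ranks. A minimal, self-contained illustration (the variable names mirror DistributedSampler's attributes but this is not the library code):

```python
# Illustrative only: how indices[rank:total_size:num_replicas]
# splits a shared index list across replicas.
total_size = 8
num_replicas = 2
indices = list(range(total_size))

# Each rank takes every num_replicas-th index, starting at its own rank.
per_rank = [indices[rank:total_size:num_replicas] for rank in range(num_replicas)]
print(per_rank)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
```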
We would additionally have something like
if isinstance(self.dataset, Sampler):
orig_indices = list(iter(self.dataset))
indices = [orig_indices[i] for i in indices]
return iter(indices)
Pitch
More modularity and code reuse
sampler = MyNiceSampler(dataset)
if distributed:
sampler = DistributedSampler(sampler)
Additionally, it makes the code clearer, in my view. Instead of
if distributed:
sampler = DistributedSampler(dataset)
else:
sampler = RandomSampler(dataset)
we can always have
sampler = RandomSampler(dataset)
if distributed:
sampler = DistributedSampler(sampler, shuffle=False)
The two versions may look very similar at first sight, but they imply different things.
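The proposal can be sketched as a standalone wrapper. The class name DistributedSamplerWrapper and its exact signature are hypothetical, not part of PyTorch; the sketch only illustrates the behavior described above (padding the wrapped sampler's output and subsampling it per rank):

```python
import math

class DistributedSamplerWrapper:
    """Hypothetical sketch: wrap any Sampler and subsample its indices per rank.

    Not the actual DistributedSampler implementation; it only illustrates
    the proposed behavior for Sampler inputs.
    """

    def __init__(self, sampler, num_replicas, rank):
        self.sampler = sampler
        self.num_replicas = num_replicas
        self.rank = rank
        # Each rank draws the same number of samples, rounding up.
        self.num_samples = math.ceil(len(sampler) / num_replicas)
        self.total_size = self.num_samples * num_replicas

    def __iter__(self):
        # Materialize the wrapped sampler's index order.
        orig_indices = list(iter(self.sampler))
        # Pad so the list divides evenly across replicas
        # (mirrors what DistributedSampler does with dataset indices).
        orig_indices += orig_indices[: self.total_size - len(orig_indices)]
        # Round-robin subsample, preserving the wrapped sampler's ordering.
        return iter(orig_indices[self.rank:self.total_size:self.num_replicas])

    def __len__(self):
        return self.num_samples
```

Any object with __iter__ and __len__ works as the wrapped sampler here, e.g. a plain list of indices standing in for MyNiceSampler's output.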
Alternatives
We can integrate the functionality of DistributedSampler inside our custom sampler, but this seems redundant.
cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @ssnl @VitalyFedyunin @ejguan @NivekT @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang