feat: Allow DistributedSampler to accept Sampler as input #171597
ankitlade12 wants to merge 2 commits into pytorch:main
Conversation
Fixes pytorch#23430

This change enables composition of sampling strategies in distributed training by allowing DistributedSampler to wrap another Sampler.

Changes:
- Add `Union[Dataset, Sampler]` type hint to `__init__`
- Add `_input_is_sampler` flag for runtime detection
- Update `__iter__` to use the wrapped sampler's indices when applicable
- Update the docstring with new usage examples
- Add 7 comprehensive unit tests covering various sampler types

Example usage:

```python
base_sampler = WeightedRandomSampler(weights, num_samples)
sampler = DistributedSampler(base_sampler, shuffle=False)
```
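The wrapping idea from the description can be sketched in plain Python. This is a hedged illustration only: `ToySampler`, `ToyDistributedSampler`, and the interleaved split are stand-ins invented here, not the actual PyTorch implementation, though `_input_is_sampler` mirrors the flag named in the PR.

```python
class ToySampler:
    """Stands in for a torch Sampler: yields a fixed index order."""
    def __init__(self, indices):
        self.indices = list(indices)

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


class ToyDistributedSampler:
    """Toy sketch: accept either a 'dataset' (anything with len) or a sampler."""
    def __init__(self, source, num_replicas, rank):
        # Runtime detection, analogous to the PR's _input_is_sampler flag.
        self._input_is_sampler = isinstance(source, ToySampler)
        self.source = source
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        if self._input_is_sampler:
            # Respect the wrapped sampler's own index order.
            indices = list(self.source)
        else:
            indices = list(range(len(self.source)))
        # Interleaved split: rank r takes every num_replicas-th index.
        return iter(indices[self.rank::self.num_replicas])


base = ToySampler([4, 2, 7, 0, 5, 1])
shard0 = list(ToyDistributedSampler(base, num_replicas=2, rank=0))
shard1 = list(ToyDistributedSampler(base, num_replicas=2, rank=1))
print(shard0)  # [4, 7, 5]
print(shard1)  # [2, 0, 1]
```

Together the two shards cover the wrapped sampler's indices exactly once, which is the composition property the PR is after.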
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/171597

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures as of commit b677028 with merge base b35a75b.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Some design concerns from the original issue:
…mpler

- Propagate set_epoch (or set_seed) to child samplers
- Add split_contiguous parameter to allow contiguous index blocks
- Add unit tests for both new features
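The two follow-up features in this commit can be sketched as follows. All names here (`EpochAwareSampler`, `WrappingSampler`) are illustrative stand-ins, not the PR's actual code; only the ideas of `set_epoch` propagation and a `split_contiguous` flag come from the commit message above.

```python
import random


class EpochAwareSampler:
    """Toy child sampler whose shuffle order is seeded by the epoch."""
    def __init__(self, n):
        self.n = n
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        rng = random.Random(self.epoch)
        idx = list(range(self.n))
        rng.shuffle(idx)
        return iter(idx)


class WrappingSampler:
    """Sketch of a distributed sampler wrapping a child sampler."""
    def __init__(self, child, num_replicas, rank, split_contiguous=False):
        self.child = child
        self.num_replicas = num_replicas
        self.rank = rank
        self.split_contiguous = split_contiguous

    def set_epoch(self, epoch):
        # Propagate so the child reshuffles consistently on every rank.
        if hasattr(self.child, "set_epoch"):
            self.child.set_epoch(epoch)

    def __iter__(self):
        indices = list(self.child)
        if self.split_contiguous:
            # Contiguous block per rank: indices[r*chunk : (r+1)*chunk].
            chunk = len(indices) // self.num_replicas
            start = self.rank * chunk
            return iter(indices[start:start + chunk])
        # Interleaved: every num_replicas-th index, starting at this rank.
        return iter(indices[self.rank::self.num_replicas])


child = EpochAwareSampler(8)
s = WrappingSampler(child, num_replicas=2, rank=0, split_contiguous=True)
s.set_epoch(1)
first = list(s)
s.set_epoch(2)
second = list(s)
s.set_epoch(1)
repeat = list(s)  # set_epoch reached the child, so epoch 1 reproduces its order
```

Contiguous blocks can matter for samplers whose output order is meaningful (e.g. curriculum-style ordering), while interleaving matches DistributedSampler's default striding.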
I've addressed both concerns in this implementation. I've added 10 tests to verify that both the epoch propagation and the splitting strategies work as expected.
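A test suite like the one described would need to cover the core partition invariant: across all ranks, the shards of a wrapped sampler are disjoint and together reproduce the child's indices. A hypothetical unittest sketch (`PartitionSampler` is a toy stand-in, not the PR's code):

```python
import unittest


class PartitionSampler:
    """Toy stand-in for DistributedSampler wrapping a child sampler."""
    def __init__(self, child_indices, num_replicas, rank):
        self.child_indices = list(child_indices)
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        return iter(self.child_indices[self.rank::self.num_replicas])


class TestWrappedSamplerPartition(unittest.TestCase):
    def test_shards_partition_child_order(self):
        child = [9, 3, 5, 1, 7, 0]
        shards = [list(PartitionSampler(child, 3, r)) for r in range(3)]
        combined = [i for shard in shards for i in shard]
        # Every child index appears exactly once across all ranks.
        self.assertEqual(sorted(combined), sorted(child))
        self.assertEqual(len(set(combined)), len(combined))


result = unittest.TestResult()
TestWrappedSamplerPartition("test_shards_partition_child_order").run(result)
print("passed:", result.wasSuccessful())
```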
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as Stale.
Hey @divyanshk @ramanishsingh @aelavender — just wanted to bump this PR. It was recently marked as Stale, but it's ready for review. I've addressed the design concerns raised by @vadimkantorov (set_epoch propagation and contiguous vs. interleaved splitting) and added tests for both. Would really appreciate it if you could take a look when you get a chance. Happy to make any changes based on your feedback. Thanks!