Refactor batch sampler #8958
Conversation
facebook-github-bot left a comment
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@pytorchbot retest this please
This change makes the following code significantly slower:

```python
train_dataset = \
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    sampler=torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=1, rank=0),
    **kwargs)
```

The reason seems to be that it creates a lot of tensors for each index element. I was able to make it fast again by adding a cast to int.
Thanks for the message @alsrgv! The fix would be to convert the permutation to plain ints, i.e. indices = torch.randperm(len(self.dataset), generator=g).tolist(). Can you send a PR fixing this?
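For illustration, here is a minimal sketch of the fix being discussed (not the actual PyTorch source; the class name and simplified replica handling are hypothetical). Calling .tolist() on the permutation makes the sampler yield plain Python ints instead of 0-d tensors, which is what caused the slowdown.

```python
import torch

class SimpleDistributedSampler(torch.utils.data.Sampler):
    """Sketch of a DistributedSampler-like sampler that yields Python ints."""

    def __init__(self, dataset, num_replicas=1, rank=0, seed=0):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed)
        # .tolist() converts the permutation tensor into a list of Python ints;
        # iterating the tensor directly would yield 0-d tensors per index.
        indices = torch.randperm(len(self.dataset), generator=g).tolist()
        # Keep only this replica's slice of the permutation.
        return iter(indices[self.rank::self.num_replicas])

    def __len__(self):
        return len(self.dataset) // self.num_replicas
```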
Fixes it for me! |
Summary: Pull Request resolved: #10361 Differential Revision: D9240798 Pulled By: ezyang fbshipit-source-id: dc4cfe79612f711bbcff34a147877df6a5f7b89f
Summary: Since #8958 was merged, BatchSampler samples 0-d tensors from WeightedRandomSampler instead of integers, which significantly reduces performance. This PR fixes it the same way #10361 fixed DistributedSampler. Pull Request resolved: #10636 Differential Revision: D9423869 Pulled By: zou3519 fbshipit-source-id: f94da2d4cccf70e63beea6cfc3d1230b5610ae44
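For comparison, a minimal sketch of applying the same idea to a WeightedRandomSampler-style sampler (again, not the actual #10636 patch; the class name and simplifications are hypothetical), so that BatchSampler receives plain integers rather than 0-d tensors:

```python
import torch

class SimpleWeightedRandomSampler(torch.utils.data.Sampler):
    """Sketch of a WeightedRandomSampler-like sampler that yields Python ints."""

    def __init__(self, weights, num_samples, replacement=True):
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        # Draw weighted sample indices as a 1-d tensor.
        samples = torch.multinomial(self.weights, self.num_samples, self.replacement)
        # .tolist() ensures downstream consumers (e.g. BatchSampler) get ints.
        return iter(samples.tolist())

    def __len__(self):
        return self.num_samples
```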
Fixes #8652, fixes #8957