Random Batch Sampler Speedup #147706
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147706
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 648ef95 with merge base 56039b5. This comment was automatically generated by Dr. CI and updates every 15 minutes.
This is the speed testing code:

```python
from typing import Iterator, Optional, Sized
import timeit

import numpy as np
import pandas as pd
from numpy.typing import NDArray
from torch.utils.data import Sampler, RandomSampler, BatchSampler


class RandomBatchSampler(Sampler[list[int]]):
    """Draws all indices up front and yields them in batches via slicing."""

    def __init__(
        self,
        data_source: Sized,
        replacement: bool = False,
        generator: Optional[np.random.Generator] = None,
        batch_size: int = 32,
        drop_last: bool = False,
    ) -> None:
        super().__init__()
        self.data_source = data_source
        self.replacement = replacement
        self.generator = generator
        self.batch_size = batch_size
        # Only drop the last batch if it is actually incomplete.
        self.drop_last = drop_last and len(data_source) % self.batch_size > 0
        if not isinstance(self.replacement, bool):
            raise TypeError(
                f"replacement should be a boolean value, but got replacement={self.replacement}"
            )
        self.n_batches = len(data_source) // batch_size

    def sample_indices(self) -> NDArray[np.int_]:
        generator = (
            self.generator if self.generator is not None else np.random.default_rng()
        )
        if self.replacement:
            indices = generator.integers(0, len(self.data_source), len(self.data_source))
        else:
            indices = np.arange(len(self.data_source))
            generator.shuffle(indices)
        return indices

    def __iter__(self) -> Iterator[list[int]]:
        indices = self.sample_indices()
        # Batch by slicing the precomputed index array instead of iterating it.
        indices_batches = [
            indices[i : i + self.batch_size]
            for i in range(0, len(indices), self.batch_size)
        ]
        if self.drop_last:
            indices_batches.pop()
        yield from indices_batches

    def __len__(self):
        return self.n_batches


def _iter_on_origin_sampler(batch_size, drop_last, replacement):
    for _ in BatchSampler(RandomSampler(range(DATA_SIZE), replacement=replacement),
                          batch_size=batch_size, drop_last=drop_last):
        pass


def _iter_on_new_sampler(batch_size, drop_last, replacement):
    for _ in RandomBatchSampler(range(DATA_SIZE), batch_size=batch_size,
                                drop_last=drop_last, replacement=replacement):
        pass


if __name__ == '__main__':
    DATA_SIZE = 100_000
    AVG_TIMES = 10
    data = np.zeros(DATA_SIZE)
    results = []
    for batch_size in [4, 8, 64, 256, 1024, 4096, 8192, 16384]:
        for drop_last in [True, False]:
            for replacement in [True, False]:
                timer = timeit.Timer(lambda: _iter_on_origin_sampler(batch_size, drop_last, replacement))
                times_original = timer.repeat(AVG_TIMES, 1)
                original_avg = np.mean(times_original)
                original_std = np.std(times_original)
                desc_original = f"{original_avg:.4f} +- {original_std:.1e}"
                timer = timeit.Timer(lambda: _iter_on_new_sampler(batch_size, drop_last, replacement))
                times_new = timer.repeat(AVG_TIMES, 1)
                new_avg = np.mean(times_new)
                new_std = np.std(times_new)
                desc_new = f"{new_avg:.4f} +- {new_std:.1e}"
                speedup_percent = "%.2f" % ((1 / new_avg - 1 / original_avg) * original_avg * 100) + "%"
                current_row = [batch_size, drop_last, replacement,
                               desc_original,
                               desc_new,
                               speedup_percent]
                results.append(current_row)
    columns = ["batch_size", "drop_last", "replacement", "avg and std original", "avg and std new", "speedup"]
    results = pd.DataFrame(results, columns=columns)
    pd.set_option('display.max_columns', None)
    pd.set_option('display.width', 1000)
    print(results.to_string(index=False))
```
/easycla
@divyanshk Please help take a look at the PR, I think you're already on it. In the meantime, I have unblocked CI on your PR; please sign the CLA first following the instructions in #147706 (comment).
Thanks for waiting on me @GalAvineri
Thanks.
torch/utils/data/sampler.py (Outdated)

```python
        self,
        data_source: Sized,
        replacement: bool = False,
        generator: Optional[np.random.Generator] = None,
```
This probably shouldn't be optional.
I'll change the arguments to be exactly as in RandomSampler and BatchSampler.
torch/utils/data/sampler.py (Outdated)

```python
    def sample_indices(self) -> NDArray[np.int_]:
        generator = (
            self.generator if self.generator is not None else np.random.default_rng()
```
In case self.generator is None, we can use torch.Generator()?
I'll remove all numpy from the implementation.
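For illustration, here is a minimal numpy-free sketch of what the index sampling could look like using only torch primitives. This is an assumption about the direction, not the exact code that landed in the PR; the function name is hypothetical.

```python
from typing import Optional

import torch


def sample_indices(n: int, replacement: bool, generator: Optional[torch.Generator]) -> torch.Tensor:
    # Fall back to a freshly seeded torch.Generator when none is provided.
    if generator is None:
        seed = int(torch.empty((), dtype=torch.int64).random_())
        generator = torch.Generator().manual_seed(seed)
    if replacement:
        # torch.randint plays the role of numpy's choice with replacement.
        return torch.randint(high=n, size=(n,), generator=generator)
    # Without replacement: a random permutation of all indices.
    return torch.randperm(n, generator=generator)
```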
I've opened an alternative PR #149441 that generalizes the approach in this PR for other samplers besides RandomSampler.
Since #149441 got a bit complicated, I thought to proceed with this simpler PR. I removed all numpy usage from the implementation. @divyanshk Please let me know if there is anything else you would like :)
These are the speedups after removing the usage of numpy:
test/test_dataloader.py (Outdated)

```python
        for replacement in [False, True]:
            for drop_last in [False, True]:
```
Can we use the @parametrize decorator to handle this?
Of course :)
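For reference, a rough sketch of how the nested loops could become `@parametrize` cases. The test class, method name, and assertion are hypothetical; `parametrize`, `instantiate_parametrized_tests`, and `run_tests` come from `torch.testing._internal.common_utils`.

```python
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torch.utils.data import BatchSampler, RandomSampler


class TestRandomBatchSampler(TestCase):
    # Hypothetical test: checks the batch count produced by the combined samplers.
    @parametrize("replacement", [False, True])
    @parametrize("drop_last", [False, True])
    def test_batch_count(self, replacement, drop_last):
        data = range(100)
        batches = list(
            BatchSampler(
                RandomSampler(data, replacement=replacement),
                batch_size=7,
                drop_last=drop_last,
            )
        )
        expected = 100 // 7 if drop_last else (100 + 6) // 7
        self.assertEqual(len(batches), expected)


instantiate_parametrized_tests(TestRandomBatchSampler)

if __name__ == "__main__":
    run_tests()
```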
```python
    def __len__(self) -> int:
        return self.num_samples
```
I think the linter would require 2 blank lines before a class definition.
My bad, I didn't run the linter yet. Thank you for catching that!
torch/utils/data/sampler.py (Outdated)

```python
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        for i in torch.randperm(len(self.indices), generator=self.generator):
```
This change will go away with a rebase on the latest branch.
torch/utils/data/sampler.py (Outdated)

```python
    def init_generator(self):
        if self.generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
```
Maybe it could be simplified as:

```python
seed = int(torch.empty((), dtype=torch.int64).random_())
generator = torch.Generator().manual_seed(seed)
```
I originally just refactored existing code and extracted it as a method.
If you'd like, I'll apply the change you suggest :)
Just a suggestion. It's nice that manual_seed also returns a Generator :)
I agree with the suggestion :) Thank you!
torch/utils/data/sampler.py (Outdated)

```python
        batch_size: int = 32,
        drop_last: bool = False,
```
Let's not have defaults for batch_size and drop_last.
Done :)
torch/utils/data/sampler.py (Outdated)

```python
        indices = self.sample_indices()

        # Slicing is faster on list when batch size is small
        # if self.batch_size < 16:
```
The `if` should not be commented out. This is a typo.
Perfect timing, was literally about to type that, thanks
Thank you for the review! Let me know your thoughts about this draft :)
```python
        return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore[arg-type]


class RandomBatchSampler(Sampler[Union[torch.Tensor, list[int]]]):
```
In case batch_size < 16, the yielded type is list[int]; otherwise it is Tensor.
Is this an issue?
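For context, a hypothetical sketch of the kind of type-branching iteration being discussed here (not the exact PR code): small batches slice a Python list, larger batches slice the index tensor directly. The threshold of 16 and the helper name are assumptions taken from this thread.

```python
from typing import Iterator, Union

import torch


def iter_batches(
    indices: torch.Tensor, batch_size: int, drop_last: bool
) -> Iterator[Union[torch.Tensor, list[int]]]:
    # Assumed threshold of 16, based on the benchmarking discussed below.
    seq = indices.tolist() if batch_size < 16 else indices
    for i in range(0, len(seq), batch_size):
        batch = seq[i : i + batch_size]  # list[int] or Tensor, depending on the branch
        if drop_last and len(batch) < batch_size:
            break
        yield batch
```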
We decided on 16 based on the benchmarking?

> Slicing is faster on list when batch size is small

Do you mind mentioning how large the difference is? I'm just wondering if we can get away with just one type - it is not intuitive how the types switch unless the user goes looking at the code.
Also, do we see a significant perf drop if we just do List[int] like in BatchSampler?
This is the speedup comparison between outputting Tensor and List[int]:

```
batch_size  drop_last  replacement   tensor speedup   list speedup
4 True True -59.74% 70.91%
False True -48.60% 105.75%
True False -66.06% 32.98%
False False -58.86% 72.66%
8 True True -21.46% 128.56%
False True -8.93% 162.25%
True False -38.05% 73.13%
False False -29.65% 92.26%
16 True True 53.87% 156.65%
False True 63.77% 177.13%
True False 15.04% 92.64%
False False 19.29% 107.54%
32 True True 184.59% 171.52%
False True 204.93% 187.73%
True False 99.53% 102.54%
False False 106.50% 111.35%
64 True True 418.91% 165.22%
False True 442.52% 169.36%
True False 218.32% 97.65%
False False 238.59% 106.67%
128 True True 825.41% 223.34%
False True 855.68% 238.24%
True False 354.56% 127.31%
False False 375.22% 136.33%
256 True True 1229.66% 201.41%
False True 1330.88% 217.68%
True False 562.71% 119.89%
False False 528.40% 129.21%
512 True True 1883.99% 210.61%
False True 1830.45% 219.07%
True False 628.83% 118.31%
False False 744.82% 133.60%
1024 True True 2177.97% 205.58%
False True 2314.08% 214.30%
True False 605.68% 120.53%
False False 1060.29% 124.61%
2048 True True 2551.70% 208.84%
False True 2664.56% 208.69%
True False 655.49% 121.70%
False False 732.72% 133.70%
4096 True True 2781.90% 208.66%
False True 2954.09% 209.90%
True False 684.58% 121.75%
False False 683.95% 125.92%
8192 True True 2882.72% 215.19%
False True 2808.82% 209.92%
True False 669.49% 124.20%
False False 723.95% 124.53%
16384 True True 3021.89% 219.85%
False True 3062.53% 218.42%
True False 722.83% 130.07%
False False 766.00% 130.34%
```

Tensor speedup is negative when batch_size < 16, which is why I used the list implementation in those cases.
The performance gap between the Tensor and List[int] outputs grows with batch_size.
If you'd like, we could just go with the List[int] implementation.
When I have an improvement that outputs only Tensor, I'll open another PR and we can discuss this again :)
@GalAvineri Let us keep the list[int] implementation. It aligns with the other samplers, and we shouldn't have a branch in the return type (it would be confusing for users).
@divyanshk @vadimkantorov Looking forward to your response!
## Motivation

Many PRs optimizing samplers (e.g. #147706, #137423) are leveraging an ad hoc script for benchmarking samplers. The script and outputs are often copied over in PRs. We want to begin centralizing benchmarks for torch.utils.data components.

## What?

* This PR adds a new sub-folder in `benchmarks` for `data`. This is aimed to cover benchmarking scripts for torch.utils.data components like dataloader and sampler.
* Specifically, this PR includes a simple script to time samplers. This is often "copy-pasted" in PRs optimizing samplers. Having it in a centralized location should prevent that, and allow a common standard.

## Output

```
Benchmark Results:
+--------------+-------------+----------------+-----------+-----------+
| Batch Size   | Drop Last   | Original (s)   | New (s)   | Speedup   |
+==============+=============+================+===========+===========+
| 4            | True        | 0.004          | 0.0088    | -119.62%  |
+--------------+-------------+----------------+-----------+-----------+
| 4            | False       | 0.0083         | 0.009     | -9.23%    |
+--------------+-------------+----------------+-----------+-----------+
| 8            | True        | 0.003          | 0.0074    | -147.64%  |
+--------------+-------------+----------------+-----------+-----------+
| 8            | False       | 0.0054         | 0.0075    | -38.72%   |
+--------------+-------------+----------------+-----------+-----------+
| 64           | True        | 0.0021         | 0.0056    | -161.92%  |
+--------------+-------------+----------------+-----------+-----------+
| 64           | False       | 0.0029         | 0.0055    | -92.50%   |
+--------------+-------------+----------------+-----------+-----------+
| 640          | True        | 0.002          | 0.0055    | -168.75%  |
+--------------+-------------+----------------+-----------+-----------+
| 640          | False       | 0.0024         | 0.0062    | -161.35%  |
+--------------+-------------+----------------+-----------+-----------+
| 6400         | True        | 0.0021         | 0.0055    | -160.13%  |
+--------------+-------------+----------------+-----------+-----------+
| 6400         | False       | 0.0021         | 0.0068    | -215.46%  |
+--------------+-------------+----------------+-----------+-----------+
| 64000        | True        | 0.0042         | 0.0065    | -55.29%   |
+--------------+-------------+----------------+-----------+-----------+
| 64000        | False       | 0.0029         | 0.0077    | -169.56%  |
+--------------+-------------+----------------+-----------+-----------+
```

Pull Request resolved: #156974
Approved by: https://github.com/ramanishsingh
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as Stale.
Motivation

`Sampler` outputs indices using a generator, forcing `BatchSampler` to iterate over the indices one by one before grouping them into batches. If `Sampler` constructs the whole sequence of indices before yielding it, batching could be done more efficiently over the sequence than by iterating over a generator. This occurs, for example, in the widely used `RandomSampler`.

This PR replaces iteration with slicing by merging `RandomSampler` and `BatchSampler` into `RandomBatchSampler`.

Builds upon #137423. Benchmarking code is based on #76950.

In order to support the `replacement` argument I used numpy's `choice`, since I couldn't find an efficient alternative in pytorch. Therefore I also used a `numpy.random.Generator` for the `generator` argument. If it is required to not use numpy, I could look into finding an efficient torch alternative for `choice`.
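To make the motivation concrete, here is a small illustrative comparison (not code from the PR) of batching by accumulating indices one by one, as BatchSampler does over a sampler's iterator, versus slicing a precomputed index sequence; both produce identical batches.

```python
import torch

n, batch_size = 100_000, 256
indices = torch.randperm(n).tolist()

# One-by-one accumulation, as BatchSampler does over a sampler's iterator.
batches_iter = []
batch = []
for idx in indices:
    batch.append(idx)
    if len(batch) == batch_size:
        batches_iter.append(batch)
        batch = []
if batch:
    batches_iter.append(batch)  # keep the last partial batch (drop_last=False)

# Slicing the precomputed sequence, the approach proposed in this PR.
batches_slice = [indices[i : i + batch_size] for i in range(0, n, batch_size)]

assert batches_iter == batches_slice
```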