Conversation


@GalAvineri GalAvineri commented Mar 18, 2025

Motivation

#147706 attempts to accelerate BatchSampler over RandomSampler by exploiting the fact that RandomSampler can construct all of the epoch's indices before yielding them.
This PR generalizes that approach to all samplers that share this property (e.g. SequentialSampler).

Content

This PR introduces a new sampler base class, ArrayableSampler (a poor name perhaps, happy for suggestions!),
which provides a to_array method that returns the entire sequence of indices at once instead of yielding them one by one.

BatchSampler is modified to call to_array if it is available, and then partition the indices into batches more efficiently.

RandomSampler and SequentialSampler are changed to inherit from ArrayableSampler instead of Sampler and to implement to_array.

I've also added unit tests for BatchSampler.
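
As a rough illustration of the mechanism (the names ArrayableSampler and to_array come from this PR; the class and function bodies below are a minimal sketch of the idea, not the PR's actual code):

import numpy as np
from torch.utils.data import Sampler

class ArrayableSampler(Sampler[int]):
    """A sampler that can materialize the entire epoch's indices up front."""

    def to_array(self):
        raise NotImplementedError

class SketchSequentialSampler(ArrayableSampler):
    # Hypothetical example of a sampler implementing to_array.
    def __init__(self, data_source):
        self.data_source = data_source

    def to_array(self):
        return np.arange(len(self.data_source))

    def __iter__(self):
        return iter(self.to_array())

    def __len__(self):
        return len(self.data_source)

def batches_from_array(indices, batch_size, drop_last):
    # With the whole index array in hand, BatchSampler can slice it into
    # batches instead of pulling indices one at a time from an iterator.
    end = len(indices) - (len(indices) % batch_size) if drop_last else len(indices)
    for start in range(0, end, batch_size):
        yield indices[start:start + batch_size]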

Results

These are the speedups of the new BatchSampler over the previous one, measured with RandomSampler and SequentialSampler as the underlying samplers:
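(For reference, the speedup column is computed as (original_avg / new_avg - 1) * 100%, as in the benchmark code below; e.g. in the first row, 0.083266 / 0.011100 ≈ 7.50, i.e. a speedup of about 650%, matching the reported 650.15%.)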

Random Sampler
                                        original(avg)  original(std)  new(avg)     new(std)    speedup     
batch_size   drop_last    replacement                                                                      
4            True         True          0.083266       0.001230       0.011100     0.000496      650.15%   
             False        True          0.097272       0.001048       0.010956     0.000122      787.86%   
             True         False         0.071846       0.001248       0.019380     0.000427      270.73%   
             False        False         0.081651       0.000393       0.019177     0.000406      325.77%   
8            True         True          0.080392       0.000948       0.006527     0.000057     1131.65%   
             False        True          0.089747       0.001443       0.006300     0.000141     1324.56%   
             True         False         0.070335       0.000481       0.014993     0.000398      369.10%   
             False        False         0.076151       0.001038       0.014292     0.000989      432.84%   
16           True         True          0.079936       0.001022       0.003918     0.000063     1940.24%   
             False        True          0.088889       0.002255       0.003966     0.000034     2141.47%   
             True         False         0.070394       0.002158       0.012234     0.000371      475.39%   
             False        False         0.073136       0.000844       0.012345     0.000358      492.46%   
32           True         True          0.079251       0.001090       0.002816     0.000034     2714.11%   
             False        True          0.086134       0.001740       0.002776     0.000021     3002.72%   
             True         False         0.068372       0.000683       0.010850     0.000388      530.14%   
             False        False         0.070534       0.000757       0.011073     0.000405      537.00%   
64           True         True          0.076503       0.000867       0.002152     0.000031     3455.23%   
             False        True          0.080709       0.000728       0.002079     0.000033     3781.88%   
             True         False         0.067604       0.000163       0.010141     0.000429      566.67%   
             False        False         0.068694       0.000324       0.010150     0.000402      576.80%   
256          True         True          0.076467       0.000447       0.001673     0.000041     4471.89%   
             False        True          0.079399       0.000464       0.001671     0.000036     4652.11%   
             True         False         0.066305       0.000353       0.009784     0.000383      577.66%   
             False        False         0.068494       0.000760       0.009861     0.000351      594.57%   
1024         True         True          0.077544       0.000437       0.001531     0.000028     4964.72%   
             False        True          0.078970       0.000251       0.001532     0.000035     5055.80%   
             True         False         0.066495       0.000693       0.009903     0.000433      571.45%   
             False        False         0.068854       0.001016       0.009248     0.000885      644.53%   
4096         True         True          0.080214       0.000778       0.001599     0.000085     4915.45%   
             False        True          0.080381       0.001041       0.001580     0.000045     4988.82%   
             True         False         0.067910       0.000534       0.009977     0.000956      580.65%   
             False        False         0.067867       0.000811       0.009625     0.000386      605.11%   
8192         True         True          0.079692       0.001228       0.001605     0.000042     4864.47%   
             False        True          0.081308       0.001007       0.001569     0.000043     5082.90%   
             True         False         0.067922       0.002579       0.009508     0.000522      614.35%   
             False        False         0.067451       0.001880       0.009628     0.000383      600.59%   
16384        True         True          0.082358       0.001515       0.001587     0.000049     5088.30%   
             False        True          0.079919       0.001728       0.001474     0.000034     5323.23%   
             True         False         0.068146       0.000946       0.010022     0.000331      579.98%   
             False        False         0.067269       0.000629       0.009658     0.000369      596.53%   

Sequential Sampler
                           original(avg)  original(std)  new(avg)     new(std)    speedup     
batch_size   drop_last                                                                        
4            True          0.011663       0.000044       0.009717     0.000065       20.04%   
             False         0.022071       0.000238       0.009743     0.000142      126.53%   
8            True          0.009131       0.000133       0.005157     0.000044       77.08%   
             False         0.014645       0.000262       0.004918     0.000120      197.81%   
16           True          0.008144       0.000128       0.002611     0.000016      211.87%   
             False         0.012597       0.000151       0.002699     0.000015      366.73%   
32           True          0.007929       0.000087       0.001406     0.000020      463.90%   
             False         0.009932       0.000150       0.001423     0.000021      598.01%   
64           True          0.006814       0.000077       0.000793     0.000014      759.12%   
             False         0.008856       0.000146       0.000789     0.000009     1022.34%   
256          True          0.006819       0.000096       0.000358     0.000009     1804.35%   
             False         0.008643       0.000073       0.000357     0.000006     2324.08%   
1024         True          0.007234       0.000107       0.000241     0.000006     2903.41%   
             False         0.008019       0.000117       0.000247     0.000007     3147.97%   
4096         True          0.007520       0.000068       0.000263     0.000093     2761.04%   
             False         0.007552       0.000134       0.000258     0.000080     2830.46%   
8192         True          0.007736       0.000096       0.000227     0.000017     3312.73%   
             False         0.007355       0.000107       0.000217     0.000007     3283.53%   
16384        True          0.009124       0.000134       0.000211     0.000017     4215.36%   
             False         0.007744       0.000100       0.000228     0.000024     3303.29%   

Note

While BatchSampler previously yielded List[int], it now yields NumPy arrays.
Furthermore, RandomSampler.to_array uses a NumPy generator instead of a torch generator.

I'll provide speed comparisons using alternative implementations:

  1. Using numpy generator and yielding List[int].
  2. Using torch generator and yielding Tensor.
  3. Using torch generator and yielding List[int].
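
For context, a minimal sketch of the two index-generation routes being compared, using standard NumPy and torch APIs (the exact calls in the PR may differ):

import numpy as np
import torch

n = 1_000_000  # dataset size

# NumPy generator route
rng = np.random.default_rng()
perm_np = rng.permutation(n)                      # without replacement
draw_np = rng.integers(0, n, size=n)              # with replacement

# torch generator route
g = torch.Generator()
perm_t = torch.randperm(n, generator=g)           # without replacement
draw_t = torch.randint(0, n, (n,), generator=g)   # with replacement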

@GalAvineri GalAvineri requested a review from divyanshk as a code owner March 18, 2025 21:40
@pytorch-bot pytorch-bot bot added the release notes: dataloader release notes category label Mar 18, 2025

pytorch-bot bot commented Mar 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149441

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e08e87a with merge base f47573f:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@GalAvineri GalAvineri (Author) commented Mar 18, 2025

This is the code comparing speed between the previous and new versions of BatchSampler, based on the benchmarking in #76950:

import itertools
import timeit
import numpy as np
import pandas as pd
from itertools import product
from tqdm import tqdm

from typing import Iterable, Iterator, Union

# `sampler` is the local, modified copy of torch/utils/data/sampler.py
# that contains the new BatchSampler alongside RandomSampler and co.
from sampler import Sampler, SequentialSampler, RandomSampler, BatchSampler

def iter_sampler(sampler):
    for _ in sampler:
        pass

def time_sampler(sampler):
    # AVG_TIMES is defined in the __main__ block below; this works because
    # the function is only called after the script's globals are set.
    timer = timeit.Timer(lambda: iter_sampler(sampler))
    times = timer.repeat(AVG_TIMES, 1)
    return np.mean(times), np.std(times)

class PrevBatchSampler(Sampler[list[int]]):
    """
    A copy of the previous BatchSampler
    """

    def __init__(
        self,
        sampler: Union[Sampler[int], Iterable[int]],
        batch_size: int,
        drop_last: bool,
    ) -> None:
        # Since collections.abc.Iterable does not check for `__getitem__`, which
        # is one way for an object to be an iterable, we don't do an `isinstance`
        # check here.
        if (
            not isinstance(batch_size, int)
            or isinstance(batch_size, bool)
            or batch_size <= 0
        ):
            raise ValueError(
                f"batch_size should be a positive integer value, but got batch_size={batch_size}"
            )
        if not isinstance(drop_last, bool):
            raise ValueError(
                f"drop_last should be a boolean value, but got drop_last={drop_last}"
            )
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[list[int]]:
        # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
        sampler_iter = iter(self.sampler)
        if self.drop_last:
            # Create multiple references to the same iterator
            args = [sampler_iter] * self.batch_size
            for batch_droplast in zip(*args):
                yield [*batch_droplast]
        else:
            batch = [*itertools.islice(sampler_iter, self.batch_size)]
            while batch:
                yield batch
                batch = [*itertools.islice(sampler_iter, self.batch_size)]

    def __len__(self) -> int:
        # Can only be called if self.sampler has __len__ implemented
        # We cannot enforce this condition, so we turn off typechecking for the
        # implementation below.
        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        if self.drop_last:
            return len(self.sampler) // self.batch_size  # type: ignore[arg-type]
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore[arg-type]


if __name__ == '__main__':
    DATA_SIZE = 1_000_000
    AVG_TIMES = 10

    data = np.zeros(DATA_SIZE)

    batch_sizes = [4, 8, 16, 32, 64, 256, 1024, 4096, 8192, 16384]
    replacements = [True, False]
    drop_lasts = [True, False]

    results_sequential = []
    results_random = []
    for batch_size, replacement, drop_last in tqdm(product(batch_sizes, replacements, drop_lasts)):
        # Sequential
        if not replacement:
            sequential_sampler = SequentialSampler(data)
            prev_batch_sampler = PrevBatchSampler(sequential_sampler, batch_size, drop_last)
            new_batch_sampler = BatchSampler(sequential_sampler, batch_size, drop_last)

            avg_prev, std_prev = time_sampler(prev_batch_sampler)
            avg_new, std_new = time_sampler(new_batch_sampler)
            # Percentage speedup; equivalent to (1/avg_new - 1/avg_prev) * avg_prev * 100.
            speedup = "%.2f" % ((avg_prev / avg_new - 1) * 100) + "%"

            row = [batch_size, drop_last, avg_prev, std_prev, avg_new, std_new, speedup]
            results_sequential.append(row)

        # Random
        random_sampler = RandomSampler(data, replacement)
        prev_batch_sampler = PrevBatchSampler(random_sampler, batch_size, drop_last)
        new_batch_sampler = BatchSampler(random_sampler, batch_size, drop_last)

        avg_prev, std_prev = time_sampler(prev_batch_sampler)
        avg_new, std_new = time_sampler(new_batch_sampler)
        # Percentage speedup; equivalent to (1/avg_new - 1/avg_prev) * avg_prev * 100.
        speedup = "%.2f" % ((avg_prev / avg_new - 1) * 100) + "%"

        row = [batch_size, drop_last, replacement, avg_prev, std_prev, avg_new, std_new, speedup]
        results_random.append(row)


    sequential_columns = ["batch_size", "drop_last",
                          "original(avg)", "original(std)", "new(avg)", "new(std)", "speedup"]
    seq_df = pd.DataFrame(results_sequential, columns=sequential_columns)
    seq_df = seq_df.set_index(["batch_size", "drop_last"])

    random_columns = ["batch_size", "drop_last", "replacement",
                      "original(avg)", "original(std)", "new(avg)", "new(std)", "speedup"]
    random_df = pd.DataFrame(results_random, columns=random_columns)
    random_df = random_df.set_index(["batch_size", "drop_last", "replacement"])

    pd.set_option('display.max_columns', None)
    pd.set_option('display.width', 1000)

    print('Random Sampler')
    print(random_df.to_string(justify='left', col_space=12))
    print('\nSequential Sampler')
    print(seq_df.to_string(justify='left', col_space=12))

@GalAvineri GalAvineri changed the title Added functionality of ArrayableSampler Faster Batch Sampler Mar 18, 2025
@GalAvineri GalAvineri changed the title Faster Batch Sampler Batch Sampler Speedup Mar 18, 2025
@janeyx99 janeyx99 requested a review from sivag1 March 20, 2025 19:41
@janeyx99 janeyx99 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Mar 20, 2025
@GalAvineri GalAvineri (Author) commented Mar 25, 2025

Here are the speedups achieved by the alternative implementations, relative to the previous BatchSampler:

  1. Using numpy generator and yielding numpy arrays
  2. Using numpy generator and yielding List[int]
  3. Using torch generator and yielding Tensor
  4. Using torch generator and yielding List[int]
Random Sampler
implementation                         numpy array  numpy list   torch tensor torch list  
batch_size   drop_last    replacement                                                     
4            False        False          355%        170%          -32%        -45%       
                          True           803%        312%          -11%        -31%       
             True         False          261%        114%          -41%        -55%       
                          True           643%        241%          -28%        -43%       
8            False        False          418%        172%           13%        -17%       
                          True          1294%        408%           49%          4%       
             True         False          350%        139%            1%        -27%       
                          True          1199%        370%           48%          7%       
16           False        False          477%        226%           93%         38%       
                          True          2116%        576%          196%        101%       
             True         False          469%        223%           78%         38%       
                          True          2019%        531%          172%         76%       
32           False        False          528%        263%          209%        108%       
                          True          2972%        700%          410%        211%       
             True         False          511%        256%          200%        103%       
                          True          2624%        629%          373%        182%       
64           False        False          546%        282%          340%        173%       
                          True          3826%        789%          755%        340%       
             True         False          533%        276%          331%        167%       
                          True          3600%        751%          731%        318%       
256          False        False          576%        302%          579%        267%       
                          True          4739%        887%         1778%        560%       
             True         False          558%        288%          571%        260%       
                          True          4585%        828%         1732%        530%       
1024         False        False          570%        297%          659%        298%       
                          True          4699%        832%         2360%        607%       
             True         False          562%        299%          668%        294%       
                          True          4723%        823%         2322%        605%       
4096         False        False          563%        305%          693%        313%       
                          True          4810%        809%         2648%        636%       
             True         False          581%        310%          703%        315%       
                          True          4611%        792%         2582%        609%       
8192         False        False          574%        302%          710%        314%       
                          True          5190%        883%         2849%        668%       
             True         False          567%        299%          706%        311%       
                          True          5167%        853%         2856%        672%       
16384        False        False          576%        295%          714%        311%       
                          True          5330%        840%         2982%        665%       
             True         False          594%        291%          704%        307%       
                          True          5269%        828%         2976%        643%       

Sequential Sampler
implementation            numpy array  numpy list   torch tensor torch list  
batch_size   drop_last                                                       
4            False          108%         -8%          -80%        -85%       
             True            22%        -49%          -90%        -92%       
8            False          205%        -12%          -72%        -80%       
             True           102%        -41%          -81%        -87%       
16           False          311%         -7%          -61%        -74%       
             True           206%        -33%          -73%        -81%       
32           False          555%          0%          -37%        -64%       
             True           407%        -21%          -49%        -71%       
64           False          992%          9%           16%        -49%       
             True           752%        -14%          -10%        -61%       
256          False         2243%         19%          341%        -20%       
             True          1807%         -6%          233%        -36%       
1024         False         2983%         18%         1393%         -8%       
             True          2786%          5%         1232%        -19%       
4096         False         3289%         15%         3692%         -8%       
             True          3241%         10%         3583%        -10%       
8192         False         3380%         11%         5179%         -7%       
             True          3376%          8%         5232%         -8%       
16384        False         3517%          4%         6618%         -8%       
             True          4135%         23%         7551%          5%

@GalAvineri GalAvineri force-pushed the faster-batch-sampler branch from 3e72d4a to e08e87a on March 30, 2025 10:54
@divyanshk divyanshk (Contributor) commented:

@albanD Would you know the guidance on using numpy in torch.utils?

@albanD albanD (Collaborator) left a comment:

Oh, we should not use numpy here. Numpy is not a dependency of PyTorch, so you cannot rely on it being available!

torch.arange and similar functions should give you the same behavior for what you need here!

@GalAvineri GalAvineri (Author) commented Apr 11, 2025

The tensor-based implementation produces a negative speedup for SequentialSampler when batch_size <= 64.
Unfortunately, I wasn't able to improve this to a non-negative speedup.

I have 3 ideas to address this (a combined sketch of ideas 1 and 3 follows at the end of this comment):

  1. Use the numpy based implementation, but guard on whether numpy is available, i.e.

try:
    import numpy
    ...  # use the new (numpy based) code
except ImportError:
    ...  # use the original code

  2. Keep the tensor based change for RandomSampler (since it provides positive speedup for all cases),
     and remove it from SequentialSampler.

  3. Use the tensor based implementation in both RandomSampler and SequentialSampler, with a condition on the batch size, i.e.

if batch_size > 64:
    ...  # use the new (tensor based) code
else:
    ...  # use the original code

Let me know what you think!
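
As a concrete sketch, ideas 1 and 3 could be combined into a single guard. This is hypothetical code, not part of this PR; the threshold of 64 and the helper name use_fast_path are mine, taken from the benchmarks above:

# Hypothetical dispatch combining ideas 1 and 3.
try:
    import numpy as np  # idea 1: only take the fast path if numpy is importable
except ImportError:
    np = None

def use_fast_path(sampler, batch_size):
    # idea 3: the materialized-array path only pays off for large batches,
    # so fall back to the original iterator-based path below the threshold.
    return np is not None and hasattr(sampler, "to_array") and batch_size > 64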

@github-actions github-actions bot commented:

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jun 12, 2025
@github-actions github-actions bot closed this Jul 12, 2025