
[feature request] Support tensors of different sizes as batch elements in DataLoader #1512

@vadimkantorov

Description


A motivating example is returning bounding-box annotations for images along with each image. The annotation list can contain a variable number of boxes depending on the image, and padding them to a single length (and storing that length) is nasty and unnecessarily complex.

import torch.utils.data
loader = torch.utils.data.DataLoader(dataset = [(torch.zeros(3, 128, 128), torch.zeros(i, 4)) for i in range(1, 3)], batch_size = 2)
for batch in loader:
    print(batch)

Currently this blows up with the message below, because the default collate wants to torch.stack batch elements regardless of whether they have the same size:

File "...torch/utils/data/dataloader.py", line 188, in __next__
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File ".../torch/utils/data/dataloader.py", line 110, in default_collate
    return [default_collate(samples) for samples in transposed]
  File ".../torch/utils/data/dataloader.py", line 92, in default_collate
    return torch.stack(batch, 0, out=out)
  File ".../torch/functional.py", line 56, in stack
    inputs = [t.unsqueeze(dim) for t in sequence]
RuntimeError: cannot unsqueeze empty tensor at .../torch/lib/TH/generic/THTensor.c:530

Returning a list instead of a variable-sized tensor doesn't work either. Providing a custom collate_fn isn't very nice either, since most of the default behavior has to be copied, and the default collate doesn't allow hooks.
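To illustrate the duplication being complained about: a custom collate_fn has to replicate the default dispatch just to add one pass-through case. A minimal sketch, assuming the usual default_collate import; collate_variable_size and its list-on-mismatch behavior are my own choices here, not an existing API:

```python
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

def collate_variable_size(batch):
    # Hypothetical wrapper: keep variable-sized tensors as a plain list,
    # recurse through tuples/lists field by field, and defer everything
    # else to the default collate.
    elem = batch[0]
    if torch.is_tensor(elem):
        if any(t.size() != elem.size() for t in batch):
            return list(batch)            # sizes differ: no stacking
        return default_collate(list(batch))
    if isinstance(elem, (tuple, list)):
        return [collate_variable_size(list(samples)) for samples in zip(*batch)]
    return default_collate(batch)

dataset = [(torch.zeros(3, 128, 128), torch.zeros(i, 4)) for i in range(1, 3)]
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_variable_size)
images, boxes = next(iter(loader))
# images is a stacked (2, 3, 128, 128) tensor; boxes stays a list of two tensors
```

Note that this already reimplements the tensor and sequence branches of default_collate, and would still miss its handling of numbers, strings, and dicts.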

A solution would be either adding an easy way to extend the default collate, or changing collate's first branch to something like:

if all(map(torch.is_tensor, batch)) and any(tensor.size() != batch[0].size() for tensor in batch):
    return batch

As a workaround, I'm currently monkey-patching the default collate like this:

collate_old = torch.utils.data.dataloader.default_collate
torch.utils.data.dataloader.default_collate = lambda batch: batch if (
    all(map(torch.is_tensor, batch))
    and any(tensor.size() != batch[0].size() for tensor in batch)
) else collate_old(batch)

cc @ssnl @VitalyFedyunin @ejguan @NivekT @cpuhrsch

Labels

    feature (A request for a proper, new feature.)
    module: dataloader (Related to torch.utils.data.DataLoader and Sampler.)
    module: nestedtensor (NestedTensor tag, see issue #25032.)
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module.)
