Description
A motivating example is returning bounding box annotations for an image along with the image itself. The annotation list can contain a variable number of boxes depending on the image, and padding it to a fixed length (and storing that length) would be nasty and unnecessarily complex.
import torch.utils.data

loader = torch.utils.data.DataLoader(dataset=[(torch.zeros(3, 128, 128), torch.zeros(i, 4)) for i in range(1, 3)], batch_size=2)
for batch in loader:
    print(batch)

Currently this blows up with the message below, because default_collate wants to torch.stack the batch elements regardless of whether they have the same size:
File "...torch/utils/data/dataloader.py", line 188, in __next__
batch = self.collate_fn([self.dataset[i] for i in indices])
File ".../torch/utils/data/dataloader.py", line 110, in default_collate
return [default_collate(samples) for samples in transposed]
File ".../torch/utils/data/dataloader.py", line 92, in default_collate
return torch.stack(batch, 0, out=out)
File ".../torch/functional.py", line 56, in stack
inputs = [t.unsqueeze(dim) for t in sequence]
RuntimeError: cannot unsqueeze empty tensor at .../torch/lib/TH/generic/THTensor.c:530
Returning a list instead of a variable-sized tensor doesn't work either. Nor is providing a custom collate very nice: most of the default behavior needs to be copied, because default_collate recurses into itself for tuples and lists (visible in the traceback above), so a thin wrapper passed as collate_fn never reaches the nested annotation tensors, and the default collate doesn't allow hooks. (A full custom collate along these lines is sketched at the end of this issue.)
A solution would be either adding an easy way to extend the default collate, or changing the first branch of default_collate to something like:

if all(map(torch.is_tensor, batch)) and any(tensor.size() != batch[0].size() for tensor in batch):
    return batch

As a workaround, I'm currently monkey-patching the default collate like this:
collate_old = torch.utils.data.dataloader.default_collate
# Fall back to returning the raw list when the batch holds tensors of differing sizes.
torch.utils.data.dataloader.default_collate = lambda batch: (
    batch if all(map(torch.is_tensor, batch)) and any(tensor.size() != batch[0].size() for tensor in batch)
    else collate_old(batch))
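
For anyone who would rather not monkey-patch: roughly the same effect can be had by passing a collate function through DataLoader's collate_fn argument, as long as it performs the tuple/list recursion itself rather than deferring that part to default_collate (otherwise the nested recursive calls bypass the wrapper, as noted above). A minimal sketch; ragged_collate is a hypothetical name, not anything provided by torch:

import torch
import torch.utils.data
from torch.utils.data.dataloader import default_collate

def ragged_collate(batch):
    if all(map(torch.is_tensor, batch)):
        # Leave tensors of differing sizes as a plain list instead of stacking.
        if any(t.size() != batch[0].size() for t in batch):
            return batch
        return default_collate(batch)
    if isinstance(batch[0], (tuple, list)):
        # Transpose a batch of (image, boxes) tuples and recurse, so nested
        # variable-sized tensors are also caught by this function.
        return [ragged_collate(samples) for samples in zip(*batch)]
    return default_collate(batch)

loader = torch.utils.data.DataLoader(
    dataset=[(torch.zeros(3, 128, 128), torch.zeros(i, 4)) for i in range(1, 3)],
    batch_size=2,
    collate_fn=ragged_collate)

Note that this duplicates the sequence branch of default_collate, which is exactly the kind of copying complained about above; a hook in the default implementation would reduce it to a one-liner.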