
[feature request] Jagged / padding version of torch.stack / torch.cat + some general nested tensor discussion #65156

@vadimkantorov


This is a frequently needed primitive for collation: #1512

import torch

def stack_jagged(tensors, fill_value = 0):
    # avoids F.pad so that no intermediate padded copy is allocated per tensor;
    # could just as well adopt F.pad's argument names
    # output shape: batch dim followed by the elementwise max over all input shapes
    shape = [len(tensors)] + [max(t.shape[dim] for t in tensors) for dim in range(tensors[0].dim())]
    res = torch.full(shape, fill_value, dtype = tensors[0].dtype, device = tensors[0].device)
    for r, t in zip(res, tensors):
        # copy each tensor into the leading corner of its padded slot
        r[tuple(map(slice, t.shape))] = t
    return res

if __name__ == '__main__':
    print(stack_jagged([torch.zeros(3, 4, 5), torch.zeros(3, 5, 6)]).shape)  # torch.Size([2, 3, 5, 6])

Also, having a better indexing primitive than r[tuple(map(slice, t.shape))] = t would be nice, e.g. a multi-dimensional narrow/slice function.
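
For illustration, a minimal sketch of what such a helper could look like; the name narrow_nd and its signature are hypothetical, not an existing PyTorch API:

def narrow_nd(tensor, sizes):
    # return a view of the leading sizes[d] elements along each dimension d,
    # i.e. tensor[:sizes[0], :sizes[1], ...]
    return tensor[tuple(slice(s) for s in sizes)]

With this, the copy above could read narrow_nd(r, t.shape).copy_(t).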

A similar function (but one that also returns masks) is in DETR: https://github.com/facebookresearch/detr/blob/eb9f7e03ed8e2ed2cd55528989fe7df890bc3fc0/util/misc.py#L306
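
A rough sketch of a mask-returning variant of stack_jagged above, in the same spirit as the DETR helper but not its verbatim implementation (here the mask is True at valid, non-padding positions):

def stack_jagged_with_mask(tensors, fill_value = 0):
    shape = [len(tensors)] + [max(t.shape[dim] for t in tensors) for dim in range(tensors[0].dim())]
    res = torch.full(shape, fill_value, dtype = tensors[0].dtype, device = tensors[0].device)
    mask = torch.zeros(shape, dtype = torch.bool, device = tensors[0].device)
    for r, m, t in zip(res, mask, tensors):
        idx = tuple(map(slice, t.shape))
        r[idx] = t
        m[idx] = True
    return res, mask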

Related to this are the various NestedTensor implementations.

Another useful feature would be the ability to specify padding multiples, e.g. to always pad each dimension up to a multiple of 64.
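
A sketch of how stack_jagged above could be extended for this; the multiple argument is an assumption for illustration:

def stack_jagged_to_multiple(tensors, fill_value = 0, multiple = 64):
    # round each padded dimension up to the nearest multiple via ceil division
    shape = [len(tensors)] + [-(-max(t.shape[dim] for t in tensors) // multiple) * multiple for dim in range(tensors[0].dim())]
    res = torch.full(shape, fill_value, dtype = tensors[0].dtype, device = tensors[0].device)
    for r, t in zip(res, tensors):
        r[tuple(map(slice, t.shape))] = t
    return res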

cc @nikitaved @pearu @cpuhrsch


Labels: feature, module: nestedtensor, module: padding, module: sparse, triaged
