Add start_dim and end_dim functionality for common reduction operations. #54766

@selflein

🚀 Feature

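Add start_dim and end_dim arguments to common reduction operations such as torch.sum, torch.mean, etc., so that the operation is applied to dimensions start_dim through end_dim (inclusive). This would work similarly to torch.flatten(input, start_dim=0, end_dim=-1), where both bounds are inclusive and negative indices count from the last dimension.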

Motivation

I often find myself adding helper functions like this:

import torch

def sum_except_batch(x, num_batch_dims=1):
    """Sums all elements of `x` except for the first `num_batch_dims` dimensions."""
    # Reduce over every dimension after the leading batch dimensions.
    reduce_dims = list(range(num_batch_dims, x.ndimension()))
    return torch.sum(x, dim=reduce_dims)

where I explicitly have to deal with shapes in order to reduce over multiple dimensions.

Pitch

Add a second signature alongside the existing one for reduction operations.
For example, torch.sum currently has

torch.sum(input, dim, keepdim=False, *, dtype=None)

This issue proposes adding a further signature:

torch.sum(input, start_dim=0, end_dim=-1, keepdim=False, *, dtype=None)
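
Until such a signature exists, the range can be turned into an explicit dimension tuple by hand. The following is only a minimal sketch of that idea, not an existing PyTorch API: `reduction_dims` is a hypothetical helper name, and it assumes that both bounds are inclusive and that negative indices wrap around, as they do in torch.flatten.

```python
def reduction_dims(ndim, start_dim=0, end_dim=-1):
    """Return the dimensions from start_dim through end_dim, inclusive.

    Hypothetical helper (not part of PyTorch). Assumes inclusive bounds
    and negative-index wrapping, mirroring torch.flatten.
    """
    if ndim < 1:
        raise ValueError("expected a tensor with at least one dimension")
    # Wrap negative indices, e.g. -1 becomes ndim - 1.
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    if not (0 <= start < ndim and 0 <= end < ndim):
        raise IndexError("start_dim or end_dim is out of range")
    if start > end:
        raise ValueError("start_dim must not come after end_dim")
    return tuple(range(start, end + 1))
```

The result can be passed to the existing signature, for example torch.sum(x, dim=reduction_dims(x.dim(), 1, -1)) to sum over everything except the first dimension. Because the helper rejects start_dim > end_dim, it never returns an empty tuple.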

cc @heitorschueroff

Labels: feature, module: reductions, triaged