Allow range in dim argument of reducing operations such as sum #32288

@cossio

I am trying to sum a tensor over its first n axes, where n is a parameter I may not know in advance. Why doesn't something like the following work?

import torch
v = torch.randn(100, 20, 10)
torch.sum(v, range(2))  # I would write range(n) here.

This gives an error:

TypeError: sum() received an invalid combination of arguments - got (Tensor, range), but expected one of:
 * (Tensor input, torch.dtype dtype)
      didn't match because some of the arguments have invalid types: (Tensor, range)
 * (Tensor input, tuple of names dim, bool keepdim, torch.dtype dtype, Tensor out)
 * (Tensor input, tuple of ints dim, bool keepdim, torch.dtype dtype, Tensor out)

However, if I convert the range to a list first:

v = torch.randn(100, 20, 10)
torch.sum(v, list(range(2)))

Then it works. Typing list(range(n)) every time is inconvenient compared to just range(n).
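Until something like this is supported, the conversion can be hidden in a small helper. The following is only a sketch: `sum_leading` is a hypothetical name, not part of PyTorch, and it assumes `0 <= n <= t.dim()`. It guards the `n == 0` case explicitly, because an empty `dim` may be treated as "reduce over every axis" rather than "reduce over none", depending on the PyTorch version.

```python
import torch

def sum_leading(t, n):
    """Sum t over its first n axes (hypothetical helper, not a PyTorch API)."""
    if n == 0:
        # Nothing to reduce. Return early instead of passing an empty dim,
        # which some versions may interpret as "reduce over all axes".
        return t
    # torch.sum accepts a tuple of ints for dim, so materialize the range.
    return torch.sum(t, dim=tuple(range(n)))

v = torch.randn(100, 20, 10)
s = sum_leading(v, 2)  # s.shape is torch.Size([10])
```

A tuple is used here instead of a list only as a matter of taste; both match the "tuple of ints dim" overload shown in the error message.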

So please allow range arguments for the dim parameter of reducing operations such as sum.
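To illustrate what the requested behavior could look like from user code, here is a rough sketch of a wrapper that normalizes `dim` before calling a reduction. The names `normalize_dim` and `reduce_over` are made up for this example, and this is not how PyTorch parses arguments internally; it only assumes that the reduction passed in (for instance `torch.sum` or `torch.mean`) accepts an int or a tuple of ints as `dim`.

```python
import torch

def normalize_dim(dim):
    """Hypothetical helper: accept an int, a range, or any iterable of ints."""
    if isinstance(dim, int):
        return dim
    # Materialize ranges, generators, lists, etc. into a tuple of ints.
    return tuple(int(d) for d in dim)

def reduce_over(op, t, dim, **kwargs):
    """Call a reduction such as torch.sum or torch.mean with a normalized dim."""
    return op(t, dim=normalize_dim(dim), **kwargs)

v = torch.randn(100, 20, 10)
total = reduce_over(torch.sum, v, range(2))                    # shape: (10,)
mean = reduce_over(torch.mean, v, range(1, 3), keepdim=True)   # shape: (100, 1, 1)
```

Note that an empty range would be normalized to an empty tuple here, so callers who may pass `range(0)` would still need to decide how that case should behave.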

cc @heitorschueroff

Labels: function request, module: reductions, module: ux, triaged