Open
Labels
function request, module: reductions, module: ux, triaged
Description
I am trying to sum a tensor over its first n axes, where n is a parameter I may not know in advance. Why doesn't something like the following work?
import torch
v = torch.randn(100, 20, 10)
torch.sum(v, range(2))  # I would write range(n) here.
This gives an error:
TypeError: sum() received an invalid combination of arguments - got (Tensor, range), but expected one of:
 * (Tensor input, torch.dtype dtype)
      didn't match because some of the arguments have invalid types: (Tensor, range)
 * (Tensor input, tuple of names dim, bool keepdim, torch.dtype dtype, Tensor out)
 * (Tensor input, tuple of ints dim, bool keepdim, torch.dtype dtype, Tensor out)
However, if I convert the range to a list first:
v = torch.randn(100, 20, 10)
torch.sum(v, list(range(2)))
Then it works. Typing list(range(n)) every time is inconvenient compared to just range(n).
So please allow range arguments for the dim parameter of reduction operations such as sum.
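In the meantime, a small wrapper can hide the conversion. The following is only a sketch of a possible workaround, not part of the PyTorch API: `sum_over` is a hypothetical helper name, and it simply turns whatever iterable of ints it receives into a tuple before calling torch.sum.

```python
import torch


def sum_over(t, dims, keepdim=False):
    """Hypothetical helper (not a PyTorch function).

    Sums `t` over `dims`, where `dims` may be any iterable of ints,
    such as a range, by converting it to a tuple first.
    """
    return torch.sum(t, dim=tuple(dims), keepdim=keepdim)


v = torch.randn(100, 20, 10)
n = 2  # number of leading axes to reduce
out = sum_over(v, range(n))
print(out.shape)  # torch.Size([10])
```

The n = 0 case (an empty range) is worth checking on the installed version, since an empty dim may be treated as a reduction over all axes rather than a no-op.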