Closed
Description
I'm having a problem using `torch.max` with explicit output tensors, i.e. `torch.max(input, dim, keepdim=False, max=None, max_indices=None)`. I'm running PyTorch 0.1.12. Given the following example:

```python
import torch

t = torch.Tensor(3, 4)        # uninitialized 3x4 input
i = torch.LongTensor(1, 4)    # preallocated buffer for the indices
r = torch.Tensor(1, 4)        # preallocated buffer for the max values

# Returning the results works
a, b = torch.max(t, 0)
print(a, b)

# Writing the results into r and i fails
torch.max(t, 0, max=r, max_indices=i)
```

the first call works, but the second one throws an error:

```
TypeError: torch.max received an invalid combination of arguments - got (torch.FloatTensor, int, max_indices=torch.LongTensor, max=torch.FloatTensor), but expected one of:
 * (torch.FloatTensor source)
 * (torch.FloatTensor source, torch.FloatTensor other)
 * (torch.FloatTensor source, int dim)
```
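For comparison, here is a minimal sketch of the same explicit-output call as later PyTorch releases document it. This is an assumption about PyTorch 1.x or newer, not about the 0.1.12 build in this report: there, the `max=`/`max_indices=` keywords are replaced by a single `out=` tuple of (values, indices), and with the default `keepdim=False` the preallocated buffers have shape `(4,)` rather than `(1, 4)`. The sketch also uses `torch.randn` instead of the uninitialized `torch.Tensor(3, 4)` so that the results are meaningful.

```python
import torch

# Assumption: PyTorch 1.x or newer, where the explicit-output form of
# torch.max(input, dim) takes one out=(values, indices) tuple instead of
# the max=/max_indices= keywords used in the 0.1.12 report.
torch.manual_seed(0)
t = torch.randn(3, 4)  # initialized input instead of torch.Tensor(3, 4)

# With keepdim=False (the default) the reduced dim is dropped: shape (4,).
values = torch.empty(4)
indices = torch.empty(4, dtype=torch.long)

# Results are written into the preallocated tensors.
torch.max(t, 0, out=(values, indices))
print(values, indices)
```

To keep the `(1, 4)` buffer shapes from the report, pass `keepdim=True` and allocate the buffers with `torch.empty(1, 4)` instead.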