
torch.max with explicit outputs broken #1853

@mys007


I have an issue using `torch.max` with explicit output tensors, i.e. the documented signature `torch.max(input, dim, keepdim=False, max=None, max_indices=None)`. I'm running PyTorch 0.1.12. Consider the following example:

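```python
import torch

# Uninitialized tensors are enough to reproduce the argument-parsing error.
t = torch.Tensor(3, 4)
i = torch.LongTensor(1, 4)  # buffer for the indices of the maxima
r = torch.Tensor(1, 4)      # buffer for the maximum values

# Works: torch.max allocates its own outputs.
a, b = torch.max(t, 0)
print(a, b)

# Fails: explicit output tensors passed via max= and max_indices=.
torch.max(t, 0, max=r, max_indices=i)
```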

The first call works, but the second one, which passes explicit outputs, raises an error:

```
TypeError: torch.max received an invalid combination of arguments - got (torch.FloatTensor, int, max_indices=torch.LongTensor, max=torch.FloatTensor), but expected one of:
 * (torch.FloatTensor source)
 * (torch.FloatTensor source, torch.FloatTensor other)
 * (torch.FloatTensor source, int dim)
```
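For comparison, here is a minimal sketch of the same reduction with preallocated outputs. It assumes a newer PyTorch release than 0.1.12, where explicit outputs are passed as a single `(values, indices)` tuple through `out=` rather than through the `max=`/`max_indices=` keywords quoted above; it will not run on 0.1.12 itself (e.g. `torch.tensor` and `torch.empty` do not exist there). The input values are made up so the result is deterministic.

```python
import torch

# Fixed 3x4 input (illustrative values) so the result is deterministic.
t = torch.tensor([[1., 5., 2., 0.],
                  [3., 4., 9., 1.],
                  [2., 6., 7., 8.]])

# Preallocated outputs shaped (1, 4) like `r` and `i` in the report;
# keepdim=True keeps the reduced dimension so the shapes match.
r = torch.empty(1, 4)
i = torch.empty(1, 4, dtype=torch.long)

# Explicit outputs go in one (values, indices) tuple via `out=`.
torch.max(t, 0, keepdim=True, out=(r, i))

# The buffers hold the same result as the allocating call.
a, b = torch.max(t, 0, keepdim=True)
print(r, i)
print(torch.equal(r, a), torch.equal(i, b))
```

Here `r` receives the maximum of each column and `i` the row index where it occurs.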
