argmax for half datatype. #28623

@lcskrishna

Description

🐛 Bug

In PyTorch version 1.2.0, argmax is supported for the half datatype.
However, in the current version it raises a runtime error saying that argmax_cuda is not implemented for Half.

To Reproduce

Steps to reproduce the behavior:

Version 1.2.0a0+eddebaf (works):

>>> import torch
>>> torch.__version__
'1.2.0a0+eddebaf'
>>> a = torch.randn(1,1,2,2).cuda().half()
>>> print(a)
tensor([[[[ 2.3008,  0.9243],
          [-0.2551, -0.0803]]]], device='cuda:0', dtype=torch.float16)
>>> a.argmax()
tensor(0, device='cuda:0')

Whereas with the current version:

>>> import torch
>>> torch.__version__
'1.4.0a0+bc69744'
>>> a = torch.randn(1,1,2,2).cuda().half()
>>> print(a)
tensor([[[[-1.3799,  0.4214],
          [ 0.5303,  0.2732]]]], device='cuda:0', dtype=torch.float16)
>>> a.argmax()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: "argmax_cuda" not implemented for 'Half'

I have seen that some related changes were made in PR #26181.
So, is there a way to run argmax with fp16?
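In the meantime, one possible workaround (a sketch, not an official fix) is to upcast the fp16 tensor to fp32 just for the argmax call. Since argmax returns integer indices, the indices are exact and remain valid for the original half tensor; only the reduction itself runs in fp32:

```python
import torch

# Workaround sketch: run argmax on an fp32 copy of the fp16 tensor.
# Shown on CPU here; append .cuda() to the tensor if a GPU is available.
a = torch.randn(1, 1, 2, 2).half()

# argmax on the fp32 copy; the returned index is into the flattened tensor
idx = a.float().argmax()

# The index is valid for the original half tensor as well.
assert a.flatten()[idx] == a.max()
```

This costs an extra fp32 copy of the tensor, which may matter for very large activations, but it avoids the `"argmax_cuda" not implemented for 'Half'` error until Half support is restored.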

cc @ezyang @gchanan @zou3519 @jerryzh168

Metadata

Labels

high priority · module: cuda (Related to torch.cuda, and CUDA support in general) · module: regression (It used to work, and now it doesn't) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
