Status: Closed
Labels: high priority, module: cuda, module: regression, triaged
🐛 Bug
In PyTorch 1.2.0, argmax is supported for the half (float16) datatype on CUDA tensors.
In the current version, however, it raises a runtime error saying that argmax_cuda is not implemented for 'Half'.
To Reproduce
Steps to reproduce the behavior:
With version 1.2.0a0+eddebaf:
>>> import torch
>>> torch.__version__
'1.2.0a0+eddebaf'
>>> a = torch.randn(1,1,2,2).cuda().half()
>>> print(a)
tensor([[[[ 2.3008,  0.9243],
          [-0.2551, -0.0803]]]], device='cuda:0', dtype=torch.float16)
>>> a.argmax()
tensor(0, device='cuda:0')
Whereas with the current version:
>>> import torch
>>> torch.__version__
'1.4.0a0+bc69744'
>>> a = torch.randn(1,1,2,2).cuda().half()
>>> print(a)
tensor([[[[-1.3799,  0.4214],
          [ 0.5303,  0.2732]]]], device='cuda:0', dtype=torch.float16)
>>> a.argmax()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: "argmax_cuda" not implemented for 'Half'
I have seen that some related changes were made in PR #26181.
So, is there a way to run argmax with fp16?
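One workaround that appears to work until half support is restored (a minimal sketch: upcasting float16 to float32 is lossless, so the returned index matches argmax over the original tensor, at the cost of a temporary float32 copy):
>>> import torch
>>> a = torch.tensor([[0.5, -1.0], [2.0, 0.25]]).cuda().half()
>>> a.float().argmax()  # upcast only for the reduction; index is into the flattened tensor
tensor(2, device='cuda:0')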