Conversation

@nairbv (Collaborator) commented Apr 24, 2019

```
import torch
a = torch.nn.Embedding(3, 4, sparse=True).half().cuda()
a(torch.LongTensor([1, 0]).cuda()).sum().backward()
```

gave: `RuntimeError: torch.cuda.sparse.HalfTensor is not enabled`

This PR enables `sparse.HalfTensor` on CUDA. It still won't work on CPU.
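For context on why `Embedding(sparse=True)` produces a sparse gradient at all: the backward of an embedding lookup only touches the rows that were indexed in the forward pass, so the gradient can be stored in COO form (indices plus values) rather than as a full dense matrix. A minimal plain-Python sketch of that idea (illustrative only, not the actual PyTorch/ATen implementation; the function name is hypothetical):

```python
# Hypothetical sketch (plain Python, not PyTorch internals): the backward of an
# embedding lookup only receives gradient for the rows that were indexed, so it
# can be stored in COO form as (indices, values) instead of a dense matrix.

def embedding_backward_sparse(indices, grad_output, num_embeddings):
    """Return a COO-style sparse gradient for an embedding lookup.

    indices:        row indices used in the forward pass, e.g. [1, 0]
    grad_output:    one gradient row (list of floats) per looked-up index
    num_embeddings: total number of rows in the embedding table
    """
    assert all(0 <= i < num_embeddings for i in indices)
    assert len(indices) == len(grad_output)
    # The sparse gradient is just the pair (indices, values); untouched rows
    # (row 2 here, for a 3-row table) take no storage at all.
    return list(indices), [list(g) for g in grad_output]

# Mirrors the repro above: lookup of rows [1, 0] in a 3x4 table, sum().backward()
# sends a gradient of all ones to each looked-up row.
idx, vals = embedding_backward_sparse([1, 0], [[1.0] * 4, [1.0] * 4], 3)
```

Storing only the touched rows is what makes sparse embedding training cheap when the vocabulary is large but each batch indexes few rows.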

@pytorchbot added labels `module: internals` (Related to internal abstractions in c10 and ATen) and `module: nn` (Related to torch.nn) Apr 24, 2019
@nairbv (Collaborator, Author) commented Apr 29, 2019

@pytorchbot retest this please

@nairbv requested review from gchanan and zou3519, April 30, 2019 19:29
@nairbv (Collaborator, Author) commented May 6, 2019

@pytorchbot rebase this please

@pytorchbot commented

Sorry, I can't merge this because there are conflicts. To merge this yourself, run the commands below:

```
git fetch origin master
git fetch [email protected]:nairbv/pytorch.git sparse_half_tensors
git checkout FETCH_HEAD
git merge origin/master
git push [email protected]:nairbv/pytorch.git HEAD:sparse_half_tensors
```

(To learn more about this bot, see Bot commands.)

@gchanan (Contributor) commented May 7, 2019

Summary from in-person discussion: for now, let's punt on providing math ops for float16 on CPU, which means we won't be able to do:

  1. coalesce() on a CPU, Half, Sparse tensor, because the semantics of coalesce are to add together the elements that share indices.
  2. sparse_to_dense on CPU, Half, Sparse, because it is currently implemented as zeros(dense) + sparse.

But this will allow us to train sparse half embeddings on CUDA, which is the high-priority ask.

(to_dense requires add_, and add is much slower for half than for float on CPU.)
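Both operations being punted on reduce to additions, which is exactly the half-precision math that isn't available on CPU here. A plain-Python sketch of their semantics for a 1-D COO tensor (function names are illustrative, not the ATen implementations):

```python
# Plain-Python sketch of the two punted operations for a 1-D COO tensor.
# Names are illustrative, not the ATen implementations.

def coalesce(indices, values):
    """Sum values that share an index; return sorted unique indices."""
    acc = {}
    for i, v in zip(indices, values):
        acc[i] = acc.get(i, 0.0) + v   # this add is the half-precision math
    keys = sorted(acc)
    return keys, [acc[k] for k in keys]

def to_dense(indices, values, size):
    """zeros(size) plus the sparse entries -- again an add under the hood."""
    dense = [0.0] * size
    for i, v in zip(indices, values):
        dense[i] += v
    return dense

idx, vals = coalesce([2, 0, 2], [1.0, 3.0, 4.0])
# idx == [0, 2], vals == [3.0, 5.0]
dense = to_dense(idx, vals, 4)
# dense == [3.0, 0.0, 5.0, 0.0]
```

Since every step above is an element-wise add, supporting these on CPU half would require the half arithmetic kernels this PR deliberately leaves out.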
@nairbv (Collaborator, Author) commented May 8, 2019

@pytorchbot retest this please

@nairbv (Collaborator, Author) commented May 9, 2019

@pytorchbot retest this please

@facebook-github-bot (Contributor) left a comment

@nairbv is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request May 10, 2019
Summary:
```
import torch
a = torch.nn.Embedding(3, 4, sparse=True).half().cuda()
a(torch.LongTensor([1, 0]).cuda()).sum().backward()

```
gave: `RuntimeError: torch.cuda.sparse.HalfTensor is not enabled`

This PR enables sparse.HalfTensor on cuda. Still won't work for CPU.
Pull Request resolved: pytorch/pytorch#19695

Differential Revision: D15281162

Pulled By: nairbv

fbshipit-source-id: 0d83d946a059393bd53d8b8102e2daa9b4c02588
@facebook-github-bot commented

@nairbv merged this pull request in d68802b.
