Conversation

@martinraison
Contributor

This is a follow-up to #735. The goal was to add a similar level of support for CUDA sparse tensors. This includes:

  • sparse_mask
  • to_dense
  • contiguous
  • transpose
  • spaddmm
  • spcadd
  • mul, div
  • cadd, csub, cmul
  • (new for CPU as well) an hspmm operation that multiplies a sparse matrix by a dense matrix and returns the result as a hybrid tensor (i.e. 1 sparse dimension, 1 dense dimension); a usage sketch follows this list. This seems to be the most natural output format for this problem, since the output is a collection of non-zero rows
  • updated Embedding layer to work with CUDA and use efficient sparse operations
  • cuSPARSE handle management (same logic as for cuBLAS)
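
For reference, here is a minimal sketch of the hspmm idea using current torch.sparse API names (torch.sparse_coo_tensor and torch.hspmm); the exact names at the time of this PR differed, so treat it as illustrative rather than as this PR's API:

import torch

i = torch.tensor([[0, 2], [1, 0]])         # two non-zeros in a 4x3 sparse matrix
v = torch.tensor([1.0, 2.0])
S = torch.sparse_coo_tensor(i, v, (4, 3))  # sparse input
D = torch.randn(3, 5)                      # dense input

H = torch.hspmm(S, D)                      # hybrid result: 1 sparse dim, 1 dense dim
print(H)                                   # only the non-zero rows (0 and 2) are stored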

Bonus:

  • faster contiguous for CPU sparse tensors (the previous implementation was using insertion sort)
  • faster cadd for CPU sparse tensors using BLAS
  • abs for ShortTensor. For some reason it was missing, and I needed it for the testing code
  • more test coverage for sparse tensors
  • faster tensor comparison in tests. This comes at the cost of slightly more complex logic (which is usually bad in test code), but it is hard to avoid since converting all sparse matrices to dense in order to compare them is very slow

I hacked this quick thing together just to get a sense of the speed improvement: https://gist.github.com/martinraison/1e7c18c6f6eda87f1cb4995b0e6a22a5

With the default params I get:

  • 10 sec/batch with dense + CPU
  • 0.86 sec/batch with dense + CUDA
  • 0.15 sec/batch with sparse + CPU
  • 0.13 sec/batch with sparse + CUDA

This shouldn't be treated as a rigorous benchmark, though: it measures the time for a complete training iteration, including all the Python processing and the forward/backward passes through tanh/linear/cross-entropy.

@ngimel
Collaborator

ngimel commented Mar 31, 2017

A word of caution - be very careful with cuSPARSE, because some routines call cudaFree (grrrr!), which not only has performance implications, but will deadlock on multi-GPU runs with NCCL.

@adamlerer
Contributor

This PR is super impressive, thank you!

General point: I think the main use case is for LookupTable, right? My benchmarks suggest that while this is much faster than dense, we still have a ways to go to match legacy LookupTable performance. I think we need to copy over the LookupTable kernels to handle the corresponding sparse ops, so we can match the perf.

https://gist.github.com/adamlerer/865c1a09000c7dc8208e1456255209c2

$ pt benchmark.py 0
cuda=False
Old done in 0.169607 s
New done in 1.51252 s
New done in 0.791161 s (no optimizer)
$ pt benchmark.py 1
cuda=True
Old done in 0.127513 s
New done in 3.06416 s
New done in 2.91579 s (no optimizer)

@adamlerer
Contributor

@martinraison has addressed all my comments for merge - I will try to speed up the GPU sparse operations necessary for lookup table this week, but that doesn't block merging this.

@apaszke
Contributor

apaszke commented Apr 10, 2017

I'm still unsure about the meaning and handling of contiguous tensors. In dense libs, we never make the input contiguous for the user, because it could break sharing + it makes for a much less magical API. I'm also worried that someone will be adding non-linear ops and will forget that non-contiguous means not only that the indices aren't ordered, but also that they may be duplicated. These bugs will be hard to catch, because it's unlikely that someone will test with such inputs.
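
For concreteness, "non-contiguous" (in today's torch.sparse vocabulary, uncoalesced) means the same index can appear more than once, with duplicate values summed on densification. A minimal sketch with current API names, shown purely for illustration:

import torch

i = torch.tensor([[0, 0, 2]])   # index 0 appears twice
v = torch.tensor([1.0, 2.0, 3.0])
t = torch.sparse_coo_tensor(i, v, (4,))

print(t.is_coalesced())         # False: duplicates have not been merged
print(t.to_dense())             # tensor([3., 0., 3., 0.]) - duplicates are summed
print(t.coalesce().to_dense())  # same values once the duplicates are merged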

Is it true that right now we don't do any contiguous() calls on the tensors that are used for gradients in Embedding?

@martinraison
Contributor Author

martinraison commented Apr 10, 2017

@apaszke

In dense libs, we never make the input contiguous for the user, because it could break sharing + it makes for a much less magical API

@fmassa also noticed it is surprising that contiguous operates in-place. We could just rename it to contiguous_. We might also be able to enforce that operations never make the input contiguous automatically (but either make a contiguous copy or throw an exception so that the user calls contiguous_ explicitly).

I'm also worried that someone will be adding non-linear ops, and they will forget that non-contiguous means not only that the indices aren't ordered, but they're duplicated too

I believe making indices unique is an O(N log N) operation anyway, so if we're doing that we might as well make the tensor contiguous. However, I think we'd be wasting a lot of computation if we forced all sparse tensors to always be contiguous. For example, it means you would have to spend O(N log N) time on every single sparse gradient aggregation, rather than doing a single O(N log N) pass when applying the gradient update.
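
To illustrate the tradeoff, here is a hypothetical sketch with current API names (where the compaction step is called coalesce rather than this PR's contiguous):

import torch

# Repeated sparse gradient accumulation just concatenates indices/values;
# the O(N log N) de-duplication is paid once before the parameter update.
grads = [torch.sparse_coo_tensor([[1, 3]], [0.5, 0.5], (10,)) for _ in range(4)]

acc = grads[0]
for g in grads[1:]:
    acc = acc + g      # cheap: nnz grows, no sort or merge needed

acc = acc.coalesce()   # single sort + reduce at update time
print(acc)             # indices [1, 3], values [2.0, 2.0]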

Is it true that right now we don't do any contiguous() calls on the tensors that are using for gradients in Embedding?

In the case of Adagrad, which is non-linear, we have to call contiguous on sparse gradients. Right now only SGD and Adagrad are supported with sparse gradients (which is why sparse=True is not the default for Embedding). @adamlerer and I have been discussing the best way to support other optimizers (it requires adjusting the math a bit to support momentum without deviating too much from the dense formulation).
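
To see why the non-linearity forces a contiguous (unique-index) gradient, here is a toy numerical sketch of an Adagrad-style update (plain Python, hypothetical numbers, not the real optimizer code):

import math

g_dup = [1.0, 1.0]   # two gradient contributions for the same embedding row
lr = 1.0

# Applying the update once per duplicate: the accumulator changes between steps.
state, step_per_dup = 0.0, 0.0
for g in g_dup:
    state += g * g
    step_per_dup += lr * g / math.sqrt(state)   # 1/1 + 1/sqrt(2) ~= 1.71

# Merging the duplicates first, then applying one non-linear update.
g_sum = sum(g_dup)
state = g_sum * g_sum
step_coalesced = lr * g_sum / math.sqrt(state)  # 2/2 = 1.0

print(step_per_dup, step_coalesced)             # the two results differ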

@adamlerer
Contributor

I dug into speeding up sparse for nn.Embedding based on my benchmark on GPU (https://gist.github.com/adamlerer/865c1a09000c7dc8208e1456255209c2). I think my findings apply more broadly though.

  • Most importantly, the contiguous (or reorder) operation, which merges rows with the same index, is inevitably slow on GPU because it requires a host synchronization: the number of unique rows (computed on the GPU) becomes the new nnz, i.e. indices.size[0] (equivalently values.size[0]), which is a CPU-side field. So it ends up being much faster to compute operations like spcadd (the main op for the Embedding update) directly on the non-'contiguous' tensor, using indexAdd or LookupTable_accGradParameters.
    The potential problem with never compacting sparse tensors is that they can grow to unbounded size if you repeatedly call cadd without ever calling contiguous. Maybe make it the user's responsibility to call contiguous when necessary? This can be problematic when the cadd is buried in a library somewhere, e.g. in autograd backward. I don't have a good answer here... We could have heuristics for when to compact a tensor. Something like this would be kinda clever (too clever?):
void THCSTensor_(cadd)(THCSTensor *r, THCSTensor *a, THCSTensor *b) {
  ...
  /* lower bound on the number of unique indices in the result */
  r->minUnique = max(a->minUnique, b->minUnique);
  ...
  /* compact only once enough duplicates may have accumulated */
  if (r->nnz / r->minUnique > COMPACTION_THRESHOLD) {
    THCSTensor_(contiguous)(r);
    r->minUnique = r->nnz;
  }
}
  • LookupTable_accGradParameters is just computing indexAdd, yet it's implemented completely differently. The different implementation came from a desire a couple of years ago for LookupTable to be deterministic and therefore not use atomicAdd. But then why does indexAdd still use atomicAdd? Either we care about determinism or we don't... my guess is that we don't any more, since even cudnn isn't deterministic.
    I benchmarked them, and the non-deterministic one is several times faster, regardless of how many index collisions there are. We should decide whether we care about determinism, use that to pick an indexAdd kernel in THC, and then delegate LookupTable, spcadd, etc. all to that (a rough sketch of the two strategies appears after this list).

  • The autograd backward pass on nn.Embedding spends about 1 ms in Python autograd code, which, after my kernel speedups, takes >90% of the time for batch sizes below 1e4 (it was taking ~50% of the time before the kernel changes). So batch=1e3 isn't interesting; we should look at batch=1e4.
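
As a rough Python-level sketch of the two strategies (not the actual THC kernels; index_add_ is a real tensor method, while the deterministic variant below is a hypothetical illustration of the sort-and-segment-reduce approach):

import torch

def index_add_atomic(dst, idx, src):
    # On CUDA this maps to atomicAdd: fast, but the floating-point summation
    # order (and hence the rounding) can vary from run to run.
    return dst.index_add_(0, idx, src)

def index_add_sorted(dst, idx, src):
    # Deterministic alternative: sort by index, prefix-sum the values, and take
    # differences at segment boundaries so each destination row is written once.
    order = torch.argsort(idx, stable=True)
    idx_s, src_s = idx[order], src[order]
    csum = torch.cumsum(src_s, dim=0)
    last = torch.cat([idx_s[1:] != idx_s[:-1],
                      torch.ones(1, dtype=torch.bool, device=idx.device)])
    ends = csum[last]                                  # running sum at each segment end
    seg = torch.cat([ends[:1], ends[1:] - ends[:-1]])  # per-index totals
    dst[idx_s[last]] += seg
    return dst

Either way, LookupTable_accGradParameters and spcadd could then delegate to whichever variant is chosen.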

@ngimel
Collaborator

ngimel commented Apr 14, 2017

@adamlerer, a note about cudnn being non-deterministic:

  1. All convolution routines have deterministic variants (unfortunately, some of them are so slow they cannot be used in production).
  2. Maxpooling used to be non-deterministic if the kernel is bigger than the stride, but starting from cudnn v6 there's an option for deterministic maxpooling (and cudnn maxpooling is not used in pytorch anyway).

So I agree: for LookupTable, like in cudnn, there should be an option for "deterministic and (possibly) slow" and "nondeterministic and fast".

How did you benchmark the autograd overhead on nn.Embedding? 1 ms seems like a lot.

@adamlerer
Contributor

adamlerer commented Apr 14, 2017

@apaszke re: contiguous, let's think about exactly what we mean by "weight sharing". I think you mean something like:

A = torch.sparse.Tensor(...)
B = A.t()   # A and B are shared views; you don't want to silently break this sharing
C = A[:10]  # also a shared view

Except that this already doesn't work for sparse tensors! There's no way (at least currently) to have a view on a sparse tensor, so I don't think this matters. I'd like to think of sparse tensor reordering (which really needs a better name than contiguous) as an internal optimization detail; so logically, reorder(A) === A (cf. the C++ mutable keyword). This is not the case for dense tensors, where contiguous is externally visible. Let me know if you can think of a counterexample.
Re: possibility of user error with contiguous and non-contiguous tensors, yes you're right. It's exactly the reason why non-contiguous tensors lead to so many errors in TH proper. No getting around that - if you're writing THCS functions, you need to be aware of contiguous vs non-contiguous.


@martinraison yes I agree we should have a deterministic variant. But notice that this extends to indexAdd as well. Currently LookupTable is deterministic, indexAdd is non-deterministic. So I think we should have a deterministic and non-deterministic variant of indexAdd, and LookupTable / spcAdd should just delegate to that.

I benchmarked autograd overhead like this, with a small batch size like 10. It takes ~700 us/batch, and the CPU sits at 100%. perf top shows it's spending most of its time in Python functions as well as libTHC routines that look like they're in the memory allocator. Let me know if you have any suggestions for better profiling.

...
for i in range(N):
    emb.zero_grad()
    out = emb(vbatch)
    out.backward(out.data)
    optimizer.step()

@apaszke
Contributor

apaszke commented Apr 18, 2017

@adamlerer so do you propose that only the reordering should be an implementation detail, or would you like to make sorting an internal thing too? The ordering isn't visible externally, is it?

@adamlerer
Contributor

The reorder function does a sort on indices followed by a segmented reduction by index. I'm not sure what you mean by sorting vs. reordering.

My thought is that reorder should have no user-visible functional implications, i.e. the coalesced tensor version should be logically equivalent to the uncoalesced one. However, reorder should be exposed to the user because it may be necessary for efficiency (it's currently called contiguous() but should be called something different like coalesce_()).

@soumith
Contributor

soumith commented Apr 18, 2017

this is now rebased to fewer commits and merged into master

@soumith closed this Apr 18, 2017