-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Implement multiple dispatch #25653
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement multiple dispatch #25653
Conversation
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10). There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. Signed-off-by: Edward Z. Yang <[email protected]>
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10). There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. Signed-off-by: Edward Z. Yang <[email protected]>
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10). There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. Signed-off-by: Edward Z. Yang <[email protected]>
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10). There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. Signed-off-by: Edward Z. Yang <[email protected]>
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10). There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. Signed-off-by: Edward Z. Yang <[email protected]>
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10). There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. Signed-off-by: Edward Z. Yang <[email protected]>
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10). There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. Signed-off-by: Edward Z. Yang <[email protected]> ghstack-source-id: 89b5bf4 Pull Request resolved: #25653
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10). There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. Signed-off-by: Edward Z. Yang <[email protected]>
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10). There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. Signed-off-by: Edward Z. Yang <[email protected]>
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10). There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. Signed-off-by: Edward Z. Yang <[email protected]>
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10). There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. Signed-off-by: Edward Z. Yang <[email protected]> ghstack-source-id: 6368528 Pull Request resolved: #25653
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10). There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. Signed-off-by: Edward Z. Yang <[email protected]>
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10). There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. Signed-off-by: Edward Z. Yang <[email protected]> ghstack-source-id: f11f8a7 Pull Request resolved: #25653
|
I messed up the handling of |
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10). There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. Signed-off-by: Edward Z. Yang <[email protected]>
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10). There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. Signed-off-by: Edward Z. Yang <[email protected]>
aten/src/ATen/core/ATenDispatch.cpp
Outdated
| // us a non-empty list of tensors. So report a better error in this case. | ||
| // TODO: Maybe we should reword this error message | ||
| if (tid == TensorTypeId::UndefinedTensorId) { | ||
| TORCH_CHECK(false, "expected a non-empty list of Tensors") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Once we're in this branch, perf doesn't matter anymore. You say that an empty list of tensors is usually the reason for this, but doesn't always have to be. The error message can be very misleading if it actually wasn't an empty list of tensors. Can we add an if/else here and show a correct error message?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As the code is currently structured, this unfortunately can't be conveniently done. I will have to think about how to restructure this area so I can get a good error message in this case :/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perf only doesn't matter if we assume the user doesn't use try/except as part of a serious pipeline.
aten/src/ATen/core/Variadic.h
Outdated
| self()(at::ArrayRef<T>{args}); | ||
| } | ||
|
|
||
| bool short_circuit() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
make this constexpr. If we're lucky, the compiler optimizes the if/else statement in apply() away and generates faster code for iterating over the arguments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I mean, inlining should get it already, but making this constexpr won't hurt, so why not.
aten/src/ATen/function_wrapper.py
Outdated
| ', '.join(swizzle_self(t) for t in multidispatch_tensors) | ||
| ) | ||
| else: | ||
| # TODO: Err, what? If we didn't trigger multidispatch_tensors |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So throw an error instead?
| "For integral input tensors, argument alpha must not be a floating point number."; | ||
|
|
||
| Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) { | ||
| if (other.is_sparse()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yay, fewer special cases in op kernels :)
smessmer
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changes look good. Only concern is the missing implementation for multi dispatch in the c10 dispatcher.
| void* ATenOpTable::getFallbackOp(TensorTypeId tid) const { | ||
| // TODO: an alternate strategy here would be to mask out the dead key | ||
| // and then redispatch gain (automatic delegation). I haven't done this | ||
| // for now to make it easier to smoke out error cases. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Redispatching sounds an interesting feature. Would it involve an auto tensor type conversion to make redispatching successful? For example, when dispatching op(t1, t2) failed with tid(t1) < tid(t2), t2 would be converted to the type of t1 before masking out tid(t2) and redispatching with tid(t1)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know! If I imagine lazy tensor, I'll have registered a boxed generic fallback function that works on everything. So that fallback would be responsible for doing the conversion (in this case, materializing the tensor, and then redispatching). So you wouldn't need the dispatcher to take care of this. I'm not sure I can think of a case where you wouldn't be able to provide a generic fallback (and it can do better than an automatic conversion.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, a generic fallback should work. The bottom line is that there is always a working CPU-dense kernel for each op and converting to CPU-dense tensor is supported by all other tensor types (correct assumption?). But such a fallback mechanism does not necessarily rely on lazy tensor feature, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Generic fallbacks are specific to a tensor type. So for example, if I register a "lazy tensor" generic fallback, it will get called if there ever is a lazy tensor anywhere in the argument list.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am confused when you guys talk about Lazy Tensor ... is it our Lazy Tensor or now you have implemented some of your own? 😁
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id.
Billing of changes:
* ATenDispatch fallback code (i.e., what gets run if there is no entry for a function in the table) now lives out-of-line in a function `getFallbackOp`. This gave me an opportunity to write a more detailed error message, providing information about what registrations were available. There is a TODO in the fallback code, suggesting that we could automatically redispatch in the event that there is no handler for the key. But this is a bit of a design question, because it's not clear if automatic redispatch would cover up errors in the dispatch table (i.e., there *should* have been something registered at some key, but there wasn't.)
* Collection of Tensor/TensorList arguments is done using the trusty old IterArgs helper class. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10--you can't actually move the whole thing because it is literally incompatible with other code in the codebase). So instead of calling `type_set()` to get the type set of the dispatch argument, now we just call `at::detail::multi_dispatch_tensor_type_set` on all of the tensor/tensor list arguments.
* The code generator is adjusted to codegen collection of arguments as needed. There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check.
The new generated code looks like this:
```
inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) const {
static auto table = globalATenDispatch().getOpTable("aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)");
return table->getOp<Tensor & (Tensor &, const Tensor &, bool)>(at::detail::multi_dispatch_tensor_type_set(*this, src))(const_cast<Tensor&>(*this), src, non_blocking);
}
```
The key difference is that previously we wrote `type_set()` as argument to getOp; now it is a call to `multi_dispatch_tensor_type_set` which collects the type ids together.
After turning on multi-dispatch, I had to refactor existing code which previously dispatched one place, but now dispatches somewhere else. The primary component affected by this is sparse.
* Binary operations (add/sub/mul/div/addmm) now dispatch to sparse kernels even if you did add(dense, sparse). So I delete all the sparse handling code from dense kernels, and bulk up the sparse error handling to handle when the first argument is dense. In the case of addmm, I can eliminate the bridge code entirely (well, not quite: more on this below). I also updated the dispatch on sparse to actually point at sparse kernels. Pay special attention to the handling of `div_` by scalar: previously this logic lived in the "dense" `div_` implementation, but there is actually not any sparse kernel we dispatch to. I solved this particular problem by making a redispatch, but another valid approach would have been to add specific dispatches for sparse div on scalar. This codepath is poorly tested because it is only exercised from C++.
* One minor annoyance is that because I now want separate dispatch for dense and sparse, I also need to replicate the `add`, `add_`, `add_out` trifecta on the sparse side. I opted for a compromise here: I wrote new a new `add_sparse` trifecta, but reused the implementation between CPU and CUDA. This means that I hav to do another dispatch once I get to `add_out`. The alternative would have been to do twice as many copies for CPU and CUDA (thereby eliminating the extra dispatch) but that seemed distinctly not worth it.
* A lot of kernels in sparse assumed that the dispatch argument must be sparse. This is no longer true with dispatch, so I converted the asserts into plain error checking. This also means that we've perturbed the error message in the case of TestSparseOneOff.test_cuda_sparse_cpu_dense_add (I just updated the saved error message)
* `addmm` is a little bit even more special: the bridge code also handled broadcasting. I replicated the broadcasting logic between CPU and CUDA implementations to avoid an extra dispatch.
* `_sparse_addmm` gave me a bit of trouble, because I had forgotten why we had `torch.sparse.addmm` in the first place. But in the end, its changes followed along with the structural changes I made in addmm. I opted for an extra dispatch here for simplicity.
* c10d has some Variable-Tensor confusion in its sparse code. I've worked around it by judiciously inserting "no variable type" guards, but a more correct fix would be to just solve the confusion entirely.
Benchmark:
Apply the following patch to the base commit and this commit:
```
diff --git a/aten/src/ATen/native/Const.cpp b/aten/src/ATen/native/Const.cpp
new file mode 100644
index 0000000000..b66f4d3
--- /dev/null
+++ b/aten/src/ATen/native/Const.cpp
@@ -0,0 +1,10 @@
+#include <ATen/ATen.h>
+
+namespace at {
+namespace native {
+
+Tensor _const5(const Tensor& self, const Tensor& second, const Tensor& third, const Tensor& fourth, const Tensor& fifth) {
+ return self;
+}
+
+}} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index b494ed7..fddae63 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -5878,3 +5878,9 @@
dispatch:
CPU: im2col_backward_cpu
CUDA: im2col_backward_cuda
+
+# For benchmarking
+- func: _const5(Tensor self, Tensor second, Tensor third, Tensor fourth, Tensor fifth) -> Tensor
+ variants: function
+ dispatch:
+ CPU: _const5
```
Comparisons with timeit:
One-argument, representative case:
Before:
```
In [6]: %timeit x.reshape(1, 1)
1.46 µs ± 1.38 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [7]: %timeit x.reshape(1, 1)
1.48 µs ± 29.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [8]: %timeit x.reshape(1, 1)
1.52 µs ± 61.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
```
After:
```
In [3]: %timeit x.reshape(1, 1)
1.42 µs ± 1.34 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [4]: %timeit x.reshape(1, 1)
1.43 µs ± 1.01 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [5]: %timeit x.reshape(1, 1)
1.42 µs ± 0.982 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
```
Five-argument, synthetic case (we expect, with enough Tensor arguments, for there to be a slowdown, as we scale `O(n)` with number of arguments, compared to old dispatcher which is `O(1)` with number of arguments):
Before:
```
In [1]: import torch
In [2]: x = torch.zeros(1)
In [3]: %timeit torch._const5(x, x, x, x, x)
949 ns ± 1.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [4]: %timeit torch._const5(x, x, x, x, x)
954 ns ± 1.96 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [5]: %timeit torch._const5(x, x, x, x, x)
947 ns ± 0.601 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
```
After:
```
In [3]: %timeit torch._const5(x, x, x, x, x)
985 ns ± 9.11 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [4]: %timeit torch._const5(x, x, x, x, x)
984 ns ± 1.17 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [5]: %timeit torch._const5(x, x, x, x, x)
988 ns ± 0.555 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
```
Signed-off-by: Edward Z. Yang <[email protected]>
Differential Revision: [D17265918](https://our.internmc.facebook.com/intern/diff/D17265918)
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id.
Billing of changes:
* ATenDispatch fallback code (i.e., what gets run if there is no entry for a function in the table) now lives out-of-line in a function `getFallbackOp`. This gave me an opportunity to write a more detailed error message, providing information about what registrations were available. There is a TODO in the fallback code, suggesting that we could automatically redispatch in the event that there is no handler for the key. But this is a bit of a design question, because it's not clear if automatic redispatch would cover up errors in the dispatch table (i.e., there *should* have been something registered at some key, but there wasn't.)
* Collection of Tensor/TensorList arguments is done using the trusty old IterArgs helper class. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10--you can't actually move the whole thing because it is literally incompatible with other code in the codebase). So instead of calling `type_set()` to get the type set of the dispatch argument, now we just call `at::detail::multi_dispatch_tensor_type_set` on all of the tensor/tensor list arguments.
* The code generator is adjusted to codegen collection of arguments as needed. There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check.
The new generated code looks like this:
```
inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) const {
static auto table = globalATenDispatch().getOpTable("aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)");
return table->getOp<Tensor & (Tensor &, const Tensor &, bool)>(at::detail::multi_dispatch_tensor_type_set(*this, src))(const_cast<Tensor&>(*this), src, non_blocking);
}
```
The key difference is that previously we wrote `type_set()` as argument to getOp; now it is a call to `multi_dispatch_tensor_type_set` which collects the type ids together.
After turning on multi-dispatch, I had to refactor existing code which previously dispatched one place, but now dispatches somewhere else. The primary component affected by this is sparse.
* Binary operations (add/sub/mul/div/addmm) now dispatch to sparse kernels even if you did add(dense, sparse). So I delete all the sparse handling code from dense kernels, and bulk up the sparse error handling to handle when the first argument is dense. In the case of addmm, I can eliminate the bridge code entirely (well, not quite: more on this below). I also updated the dispatch on sparse to actually point at sparse kernels. Pay special attention to the handling of `div_` by scalar: previously this logic lived in the "dense" `div_` implementation, but there is actually not any sparse kernel we dispatch to. I solved this particular problem by making a redispatch, but another valid approach would have been to add specific dispatches for sparse div on scalar. This codepath is poorly tested because it is only exercised from C++.
* One minor annoyance is that because I now want separate dispatch for dense and sparse, I also need to replicate the `add`, `add_`, `add_out` trifecta on the sparse side. I opted for a compromise here: I wrote new a new `add_sparse` trifecta, but reused the implementation between CPU and CUDA. This means that I hav to do another dispatch once I get to `add_out`. The alternative would have been to do twice as many copies for CPU and CUDA (thereby eliminating the extra dispatch) but that seemed distinctly not worth it.
* A lot of kernels in sparse assumed that the dispatch argument must be sparse. This is no longer true with dispatch, so I converted the asserts into plain error checking. This also means that we've perturbed the error message in the case of TestSparseOneOff.test_cuda_sparse_cpu_dense_add (I just updated the saved error message)
* `addmm` is a little bit even more special: the bridge code also handled broadcasting. I replicated the broadcasting logic between CPU and CUDA implementations to avoid an extra dispatch.
* `_sparse_addmm` gave me a bit of trouble, because I had forgotten why we had `torch.sparse.addmm` in the first place. But in the end, its changes followed along with the structural changes I made in addmm. I opted for an extra dispatch here for simplicity.
* c10d has some Variable-Tensor confusion in its sparse code. I've worked around it by judiciously inserting "no variable type" guards, but a more correct fix would be to just solve the confusion entirely.
Benchmark:
Apply the following patch to the base commit and this commit:
```
diff --git a/aten/src/ATen/native/Const.cpp b/aten/src/ATen/native/Const.cpp
new file mode 100644
index 0000000000..b66f4d3
--- /dev/null
+++ b/aten/src/ATen/native/Const.cpp
@@ -0,0 +1,10 @@
+#include <ATen/ATen.h>
+
+namespace at {
+namespace native {
+
+Tensor _const5(const Tensor& self, const Tensor& second, const Tensor& third, const Tensor& fourth, const Tensor& fifth) {
+ return self;
+}
+
+}} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index b494ed7..fddae63 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -5878,3 +5878,9 @@
dispatch:
CPU: im2col_backward_cpu
CUDA: im2col_backward_cuda
+
+# For benchmarking
+- func: _const5(Tensor self, Tensor second, Tensor third, Tensor fourth, Tensor fifth) -> Tensor
+ variants: function
+ dispatch:
+ CPU: _const5
```
Comparisons with timeit:
One-argument, representative case:
Before:
```
In [6]: %timeit x.reshape(1, 1)
1.46 µs ± 1.38 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [7]: %timeit x.reshape(1, 1)
1.48 µs ± 29.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [8]: %timeit x.reshape(1, 1)
1.52 µs ± 61.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
```
After:
```
In [3]: %timeit x.reshape(1, 1)
1.42 µs ± 1.34 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [4]: %timeit x.reshape(1, 1)
1.43 µs ± 1.01 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [5]: %timeit x.reshape(1, 1)
1.42 µs ± 0.982 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
```
Five-argument, synthetic case (we expect, with enough Tensor arguments, for there to be a slowdown, as we scale `O(n)` with number of arguments, compared to old dispatcher which is `O(1)` with number of arguments):
Before:
```
In [1]: import torch
In [2]: x = torch.zeros(1)
In [3]: %timeit torch._const5(x, x, x, x, x)
949 ns ± 1.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [4]: %timeit torch._const5(x, x, x, x, x)
954 ns ± 1.96 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [5]: %timeit torch._const5(x, x, x, x, x)
947 ns ± 0.601 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
```
After:
```
In [3]: %timeit torch._const5(x, x, x, x, x)
985 ns ± 9.11 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [4]: %timeit torch._const5(x, x, x, x, x)
984 ns ± 1.17 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [5]: %timeit torch._const5(x, x, x, x, x)
988 ns ± 0.555 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
```
Signed-off-by: Edward Z. Yang <[email protected]>
Differential Revision: [D17265918](https://our.internmc.facebook.com/intern/diff/D17265918)
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. XLA companion patch at pytorch/xla#1031 Billing of changes: * ATenDispatch fallback code (i.e., what gets run if there is no entry for a function in the table) now lives out-of-line in a function `getFallbackOp`. This gave me an opportunity to write a more detailed error message, providing information about what registrations were available. There is a TODO in the fallback code, suggesting that we could automatically redispatch in the event that there is no handler for the key. But this is a bit of a design question, because it's not clear if automatic redispatch would cover up errors in the dispatch table (i.e., there *should* have been something registered at some key, but there wasn't.) * Collection of Tensor/TensorList arguments is done using the trusty old IterArgs helper class. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10--you can't actually move the whole thing because it is literally incompatible with other code in the codebase). So instead of calling `type_set()` to get the type set of the dispatch argument, now we just call `at::detail::multi_dispatch_tensor_type_set` on all of the tensor/tensor list arguments. * The code generator is adjusted to codegen collection of arguments as needed. There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. The new generated code looks like this: ``` inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) const { static auto table = globalATenDispatch().getOpTable("aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)"); return table->getOp<Tensor & (Tensor &, const Tensor &, bool)>(at::detail::multi_dispatch_tensor_type_set(*this, src))(const_cast<Tensor&>(*this), src, non_blocking); } ``` The key difference is that previously we wrote `type_set()` as argument to getOp; now it is a call to `multi_dispatch_tensor_type_set` which collects the type ids together. After turning on multi-dispatch, I had to refactor existing code which previously dispatched one place, but now dispatches somewhere else. The primary component affected by this is sparse. * Binary operations (add/sub/mul/div/addmm) now dispatch to sparse kernels even if you did add(dense, sparse). So I delete all the sparse handling code from dense kernels, and bulk up the sparse error handling to handle when the first argument is dense. In the case of addmm, I can eliminate the bridge code entirely (well, not quite: more on this below). I also updated the dispatch on sparse to actually point at sparse kernels. Pay special attention to the handling of `div_` by scalar: previously this logic lived in the "dense" `div_` implementation, but there is actually not any sparse kernel we dispatch to. I solved this particular problem by making a redispatch, but another valid approach would have been to add specific dispatches for sparse div on scalar. This codepath is poorly tested because it is only exercised from C++. * One minor annoyance is that because I now want separate dispatch for dense and sparse, I also need to replicate the `add`, `add_`, `add_out` trifecta on the sparse side. I opted for a compromise here: I wrote new a new `add_sparse` trifecta, but reused the implementation between CPU and CUDA. This means that I hav to do another dispatch once I get to `add_out`. The alternative would have been to do twice as many copies for CPU and CUDA (thereby eliminating the extra dispatch) but that seemed distinctly not worth it. * A lot of kernels in sparse assumed that the dispatch argument must be sparse. This is no longer true with dispatch, so I converted the asserts into plain error checking. This also means that we've perturbed the error message in the case of TestSparseOneOff.test_cuda_sparse_cpu_dense_add (I just updated the saved error message) * `addmm` is a little bit even more special: the bridge code also handled broadcasting. I replicated the broadcasting logic between CPU and CUDA implementations to avoid an extra dispatch. * `_sparse_addmm` gave me a bit of trouble, because I had forgotten why we had `torch.sparse.addmm` in the first place. But in the end, its changes followed along with the structural changes I made in addmm. I opted for an extra dispatch here for simplicity. * c10d has some Variable-Tensor confusion in its sparse code. I've worked around it by judiciously inserting "no variable type" guards, but a more correct fix would be to just solve the confusion entirely. Benchmark: Apply the following patch to the base commit and this commit: ``` diff --git a/aten/src/ATen/native/Const.cpp b/aten/src/ATen/native/Const.cpp new file mode 100644 index 0000000000..b66f4d3 --- /dev/null +++ b/aten/src/ATen/native/Const.cpp @@ -0,0 +1,10 @@ +#include <ATen/ATen.h> + +namespace at { +namespace native { + +Tensor _const5(const Tensor& self, const Tensor& second, const Tensor& third, const Tensor& fourth, const Tensor& fifth) { + return self; +} + +}} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b494ed7..fddae63 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5878,3 +5878,9 @@ dispatch: CPU: im2col_backward_cpu CUDA: im2col_backward_cuda + +# For benchmarking +- func: _const5(Tensor self, Tensor second, Tensor third, Tensor fourth, Tensor fifth) -> Tensor + variants: function + dispatch: + CPU: _const5 ``` Comparisons with timeit: One-argument, representative case: Before: ``` In [6]: %timeit x.reshape(1, 1) 1.46 µs ± 1.38 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [7]: %timeit x.reshape(1, 1) 1.48 µs ± 29.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [8]: %timeit x.reshape(1, 1) 1.52 µs ± 61.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` After: ``` In [3]: %timeit x.reshape(1, 1) 1.42 µs ± 1.34 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit x.reshape(1, 1) 1.43 µs ± 1.01 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit x.reshape(1, 1) 1.42 µs ± 0.982 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` Five-argument, synthetic case (we expect, with enough Tensor arguments, for there to be a slowdown, as we scale `O(n)` with number of arguments, compared to old dispatcher which is `O(1)` with number of arguments): Before: ``` In [1]: import torch In [2]: x = torch.zeros(1) In [3]: %timeit torch._const5(x, x, x, x, x) 949 ns ± 1.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit torch._const5(x, x, x, x, x) 954 ns ± 1.96 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit torch._const5(x, x, x, x, x) 947 ns ± 0.601 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` After: ``` In [3]: %timeit torch._const5(x, x, x, x, x) 985 ns ± 9.11 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit torch._const5(x, x, x, x, x) 984 ns ± 1.17 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit torch._const5(x, x, x, x, x) 988 ns ± 0.555 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` Signed-off-by: Edward Z. Yang <[email protected]> Differential Revision: [D17265918](https://our.internmc.facebook.com/intern/diff/D17265918)
zdevito
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am backed up, but I skimmed this quickly and it generally looks fine. Relying on @smessmer for more careful review.
dzhulgakov
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome work!!!
| ts = ts | x.type_set(); | ||
| } | ||
| void operator()(at::ArrayRef<at::Tensor> xs) { | ||
| for (const auto& x : xs) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so is it official that dispatch looks into tensor lists beyond the first element?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it has to. Suppose you have lazy tensor and you stack a some normal tensors with lazy tensor; you have to get to the lazy tensor version in that case.
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. XLA companion patch at pytorch/xla#1031 Billing of changes: * ATenDispatch fallback code (i.e., what gets run if there is no entry for a function in the table) now lives out-of-line in a function `getFallbackOp`. This gave me an opportunity to write a more detailed error message, providing information about what registrations were available. There is a TODO in the fallback code, suggesting that we could automatically redispatch in the event that there is no handler for the key. But this is a bit of a design question, because it's not clear if automatic redispatch would cover up errors in the dispatch table (i.e., there *should* have been something registered at some key, but there wasn't.) * Collection of Tensor/TensorList arguments is done using the trusty old IterArgs helper class. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10--you can't actually move the whole thing because it is literally incompatible with other code in the codebase). So instead of calling `type_set()` to get the type set of the dispatch argument, now we just call `at::detail::multi_dispatch_tensor_type_set` on all of the tensor/tensor list arguments. * The code generator is adjusted to codegen collection of arguments as needed. There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. The new generated code looks like this: ``` inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) const { static auto table = globalATenDispatch().getOpTable("aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)"); return table->getOp<Tensor & (Tensor &, const Tensor &, bool)>(at::detail::multi_dispatch_tensor_type_set(*this, src))(const_cast<Tensor&>(*this), src, non_blocking); } ``` The key difference is that previously we wrote `type_set()` as argument to getOp; now it is a call to `multi_dispatch_tensor_type_set` which collects the type ids together. After turning on multi-dispatch, I had to refactor existing code which previously dispatched one place, but now dispatches somewhere else. The primary component affected by this is sparse. * Binary operations (add/sub/mul/div/addmm) now dispatch to sparse kernels even if you did add(dense, sparse). So I delete all the sparse handling code from dense kernels, and bulk up the sparse error handling to handle when the first argument is dense. In the case of addmm, I can eliminate the bridge code entirely (well, not quite: more on this below). I also updated the dispatch on sparse to actually point at sparse kernels. Pay special attention to the handling of `div_` by scalar: previously this logic lived in the "dense" `div_` implementation, but there is actually not any sparse kernel we dispatch to. I solved this particular problem by making a redispatch, but another valid approach would have been to add specific dispatches for sparse div on scalar. This codepath is poorly tested because it is only exercised from C++. * One minor annoyance is that because I now want separate dispatch for dense and sparse, I also need to replicate the `add`, `add_`, `add_out` trifecta on the sparse side. I opted for a compromise here: I wrote new a new `add_sparse` trifecta, but reused the implementation between CPU and CUDA. This means that I hav to do another dispatch once I get to `add_out`. The alternative would have been to do twice as many copies for CPU and CUDA (thereby eliminating the extra dispatch) but that seemed distinctly not worth it. * A lot of kernels in sparse assumed that the dispatch argument must be sparse. This is no longer true with dispatch, so I converted the asserts into plain error checking. This also means that we've perturbed the error message in the case of TestSparseOneOff.test_cuda_sparse_cpu_dense_add (I just updated the saved error message) * `addmm` is a little bit even more special: the bridge code also handled broadcasting. I replicated the broadcasting logic between CPU and CUDA implementations to avoid an extra dispatch. * `_sparse_addmm` gave me a bit of trouble, because I had forgotten why we had `torch.sparse.addmm` in the first place. But in the end, its changes followed along with the structural changes I made in addmm. I opted for an extra dispatch here for simplicity. * c10d has some Variable-Tensor confusion in its sparse code. I've worked around it by judiciously inserting "no variable type" guards, but a more correct fix would be to just solve the confusion entirely. Benchmark: Apply the following patch to the base commit and this commit: ``` diff --git a/aten/src/ATen/native/Const.cpp b/aten/src/ATen/native/Const.cpp new file mode 100644 index 0000000000..b66f4d3 --- /dev/null +++ b/aten/src/ATen/native/Const.cpp @@ -0,0 +1,10 @@ +#include <ATen/ATen.h> + +namespace at { +namespace native { + +Tensor _const5(const Tensor& self, const Tensor& second, const Tensor& third, const Tensor& fourth, const Tensor& fifth) { + return self; +} + +}} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b494ed7..fddae63 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5878,3 +5878,9 @@ dispatch: CPU: im2col_backward_cpu CUDA: im2col_backward_cuda + +# For benchmarking +- func: _const5(Tensor self, Tensor second, Tensor third, Tensor fourth, Tensor fifth) -> Tensor + variants: function + dispatch: + CPU: _const5 ``` Comparisons with timeit: One-argument, representative case: Before: ``` In [6]: %timeit x.reshape(1, 1) 1.46 µs ± 1.38 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [7]: %timeit x.reshape(1, 1) 1.48 µs ± 29.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [8]: %timeit x.reshape(1, 1) 1.52 µs ± 61.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` After: ``` In [3]: %timeit x.reshape(1, 1) 1.42 µs ± 1.34 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit x.reshape(1, 1) 1.43 µs ± 1.01 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit x.reshape(1, 1) 1.42 µs ± 0.982 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` Five-argument, synthetic case (we expect, with enough Tensor arguments, for there to be a slowdown, as we scale `O(n)` with number of arguments, compared to old dispatcher which is `O(1)` with number of arguments): Before: ``` In [1]: import torch In [2]: x = torch.zeros(1) In [3]: %timeit torch._const5(x, x, x, x, x) 949 ns ± 1.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit torch._const5(x, x, x, x, x) 954 ns ± 1.96 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit torch._const5(x, x, x, x, x) 947 ns ± 0.601 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` After: ``` In [3]: %timeit torch._const5(x, x, x, x, x) 985 ns ± 9.11 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit torch._const5(x, x, x, x, x) 984 ns ± 1.17 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit torch._const5(x, x, x, x, x) 988 ns ± 0.555 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` Signed-off-by: Edward Z. Yang <[email protected]> Differential Revision: [D17265918](https://our.internmc.facebook.com/intern/diff/D17265918) [ghstack-poisoned]
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. XLA companion patch at pytorch/xla#1031 Billing of changes: * ATenDispatch fallback code (i.e., what gets run if there is no entry for a function in the table) now lives out-of-line in a function `getFallbackOp`. This gave me an opportunity to write a more detailed error message, providing information about what registrations were available. There is a TODO in the fallback code, suggesting that we could automatically redispatch in the event that there is no handler for the key. But this is a bit of a design question, because it's not clear if automatic redispatch would cover up errors in the dispatch table (i.e., there *should* have been something registered at some key, but there wasn't.) * Collection of Tensor/TensorList arguments is done using the trusty old IterArgs helper class. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10--you can't actually move the whole thing because it is literally incompatible with other code in the codebase). So instead of calling `type_set()` to get the type set of the dispatch argument, now we just call `at::detail::multi_dispatch_tensor_type_set` on all of the tensor/tensor list arguments. * The code generator is adjusted to codegen collection of arguments as needed. There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. The new generated code looks like this: ``` inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) const { static auto table = globalATenDispatch().getOpTable("aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)"); return table->getOp<Tensor & (Tensor &, const Tensor &, bool)>(at::detail::multi_dispatch_tensor_type_set(*this, src))(const_cast<Tensor&>(*this), src, non_blocking); } ``` The key difference is that previously we wrote `type_set()` as argument to getOp; now it is a call to `multi_dispatch_tensor_type_set` which collects the type ids together. After turning on multi-dispatch, I had to refactor existing code which previously dispatched one place, but now dispatches somewhere else. The primary component affected by this is sparse. * Binary operations (add/sub/mul/div/addmm) now dispatch to sparse kernels even if you did add(dense, sparse). So I delete all the sparse handling code from dense kernels, and bulk up the sparse error handling to handle when the first argument is dense. In the case of addmm, I can eliminate the bridge code entirely (well, not quite: more on this below). I also updated the dispatch on sparse to actually point at sparse kernels. Pay special attention to the handling of `div_` by scalar: previously this logic lived in the "dense" `div_` implementation, but there is actually not any sparse kernel we dispatch to. I solved this particular problem by making a redispatch, but another valid approach would have been to add specific dispatches for sparse div on scalar. This codepath is poorly tested because it is only exercised from C++. * One minor annoyance is that because I now want separate dispatch for dense and sparse, I also need to replicate the `add`, `add_`, `add_out` trifecta on the sparse side. I opted for a compromise here: I wrote new a new `add_sparse` trifecta, but reused the implementation between CPU and CUDA. This means that I hav to do another dispatch once I get to `add_out`. The alternative would have been to do twice as many copies for CPU and CUDA (thereby eliminating the extra dispatch) but that seemed distinctly not worth it. * A lot of kernels in sparse assumed that the dispatch argument must be sparse. This is no longer true with dispatch, so I converted the asserts into plain error checking. This also means that we've perturbed the error message in the case of TestSparseOneOff.test_cuda_sparse_cpu_dense_add (I just updated the saved error message) * `addmm` is a little bit even more special: the bridge code also handled broadcasting. I replicated the broadcasting logic between CPU and CUDA implementations to avoid an extra dispatch. * `_sparse_addmm` gave me a bit of trouble, because I had forgotten why we had `torch.sparse.addmm` in the first place. But in the end, its changes followed along with the structural changes I made in addmm. I opted for an extra dispatch here for simplicity. * c10d has some Variable-Tensor confusion in its sparse code. I've worked around it by judiciously inserting "no variable type" guards, but a more correct fix would be to just solve the confusion entirely. Benchmark: Apply the following patch to the base commit and this commit: ``` diff --git a/aten/src/ATen/native/Const.cpp b/aten/src/ATen/native/Const.cpp new file mode 100644 index 0000000000..b66f4d3 --- /dev/null +++ b/aten/src/ATen/native/Const.cpp @@ -0,0 +1,10 @@ +#include <ATen/ATen.h> + +namespace at { +namespace native { + +Tensor _const5(const Tensor& self, const Tensor& second, const Tensor& third, const Tensor& fourth, const Tensor& fifth) { + return self; +} + +}} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b494ed7..fddae63 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5878,3 +5878,9 @@ dispatch: CPU: im2col_backward_cpu CUDA: im2col_backward_cuda + +# For benchmarking +- func: _const5(Tensor self, Tensor second, Tensor third, Tensor fourth, Tensor fifth) -> Tensor + variants: function + dispatch: + CPU: _const5 ``` Comparisons with timeit: One-argument, representative case: Before: ``` In [6]: %timeit x.reshape(1, 1) 1.46 µs ± 1.38 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [7]: %timeit x.reshape(1, 1) 1.48 µs ± 29.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [8]: %timeit x.reshape(1, 1) 1.52 µs ± 61.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` After: ``` In [3]: %timeit x.reshape(1, 1) 1.42 µs ± 1.34 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit x.reshape(1, 1) 1.43 µs ± 1.01 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit x.reshape(1, 1) 1.42 µs ± 0.982 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` Five-argument, synthetic case (we expect, with enough Tensor arguments, for there to be a slowdown, as we scale `O(n)` with number of arguments, compared to old dispatcher which is `O(1)` with number of arguments): Before: ``` In [1]: import torch In [2]: x = torch.zeros(1) In [3]: %timeit torch._const5(x, x, x, x, x) 949 ns ± 1.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit torch._const5(x, x, x, x, x) 954 ns ± 1.96 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit torch._const5(x, x, x, x, x) 947 ns ± 0.601 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` After: ``` In [3]: %timeit torch._const5(x, x, x, x, x) 985 ns ± 9.11 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit torch._const5(x, x, x, x, x) 984 ns ± 1.17 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit torch._const5(x, x, x, x, x) 988 ns ± 0.555 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` Signed-off-by: Edward Z. Yang <[email protected]> Differential Revision: [D17265918](https://our.internmc.facebook.com/intern/diff/D17265918) [ghstack-poisoned]
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. XLA companion patch at pytorch/xla#1031 Billing of changes: * ATenDispatch fallback code (i.e., what gets run if there is no entry for a function in the table) now lives out-of-line in a function `getFallbackOp`. This gave me an opportunity to write a more detailed error message, providing information about what registrations were available. There is a TODO in the fallback code, suggesting that we could automatically redispatch in the event that there is no handler for the key. But this is a bit of a design question, because it's not clear if automatic redispatch would cover up errors in the dispatch table (i.e., there *should* have been something registered at some key, but there wasn't.) * Collection of Tensor/TensorList arguments is done using the trusty old IterArgs helper class. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10--you can't actually move the whole thing because it is literally incompatible with other code in the codebase). So instead of calling `type_set()` to get the type set of the dispatch argument, now we just call `at::detail::multi_dispatch_tensor_type_set` on all of the tensor/tensor list arguments. * The code generator is adjusted to codegen collection of arguments as needed. There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. The new generated code looks like this: ``` inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) const { static auto table = globalATenDispatch().getOpTable("aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)"); return table->getOp<Tensor & (Tensor &, const Tensor &, bool)>(at::detail::multi_dispatch_tensor_type_set(*this, src))(const_cast<Tensor&>(*this), src, non_blocking); } ``` The key difference is that previously we wrote `type_set()` as argument to getOp; now it is a call to `multi_dispatch_tensor_type_set` which collects the type ids together. After turning on multi-dispatch, I had to refactor existing code which previously dispatched one place, but now dispatches somewhere else. The primary component affected by this is sparse. * Binary operations (add/sub/mul/div/addmm) now dispatch to sparse kernels even if you did add(dense, sparse). So I delete all the sparse handling code from dense kernels, and bulk up the sparse error handling to handle when the first argument is dense. In the case of addmm, I can eliminate the bridge code entirely (well, not quite: more on this below). I also updated the dispatch on sparse to actually point at sparse kernels. Pay special attention to the handling of `div_` by scalar: previously this logic lived in the "dense" `div_` implementation, but there is actually not any sparse kernel we dispatch to. I solved this particular problem by making a redispatch, but another valid approach would have been to add specific dispatches for sparse div on scalar. This codepath is poorly tested because it is only exercised from C++. * One minor annoyance is that because I now want separate dispatch for dense and sparse, I also need to replicate the `add`, `add_`, `add_out` trifecta on the sparse side. I opted for a compromise here: I wrote new a new `add_sparse` trifecta, but reused the implementation between CPU and CUDA. This means that I hav to do another dispatch once I get to `add_out`. The alternative would have been to do twice as many copies for CPU and CUDA (thereby eliminating the extra dispatch) but that seemed distinctly not worth it. * A lot of kernels in sparse assumed that the dispatch argument must be sparse. This is no longer true with dispatch, so I converted the asserts into plain error checking. This also means that we've perturbed the error message in the case of TestSparseOneOff.test_cuda_sparse_cpu_dense_add (I just updated the saved error message) * `addmm` is a little bit even more special: the bridge code also handled broadcasting. I replicated the broadcasting logic between CPU and CUDA implementations to avoid an extra dispatch. * `_sparse_addmm` gave me a bit of trouble, because I had forgotten why we had `torch.sparse.addmm` in the first place. But in the end, its changes followed along with the structural changes I made in addmm. I opted for an extra dispatch here for simplicity. * c10d has some Variable-Tensor confusion in its sparse code. I've worked around it by judiciously inserting "no variable type" guards, but a more correct fix would be to just solve the confusion entirely. Benchmark: Apply the following patch to the base commit and this commit: ``` diff --git a/aten/src/ATen/native/Const.cpp b/aten/src/ATen/native/Const.cpp new file mode 100644 index 0000000000..b66f4d3 --- /dev/null +++ b/aten/src/ATen/native/Const.cpp @@ -0,0 +1,10 @@ +#include <ATen/ATen.h> + +namespace at { +namespace native { + +Tensor _const5(const Tensor& self, const Tensor& second, const Tensor& third, const Tensor& fourth, const Tensor& fifth) { + return self; +} + +}} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b494ed7..fddae63 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5878,3 +5878,9 @@ dispatch: CPU: im2col_backward_cpu CUDA: im2col_backward_cuda + +# For benchmarking +- func: _const5(Tensor self, Tensor second, Tensor third, Tensor fourth, Tensor fifth) -> Tensor + variants: function + dispatch: + CPU: _const5 ``` Comparisons with timeit: One-argument, representative case: Before: ``` In [6]: %timeit x.reshape(1, 1) 1.46 µs ± 1.38 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [7]: %timeit x.reshape(1, 1) 1.48 µs ± 29.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [8]: %timeit x.reshape(1, 1) 1.52 µs ± 61.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` After: ``` In [3]: %timeit x.reshape(1, 1) 1.42 µs ± 1.34 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit x.reshape(1, 1) 1.43 µs ± 1.01 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit x.reshape(1, 1) 1.42 µs ± 0.982 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` Five-argument, synthetic case (we expect, with enough Tensor arguments, for there to be a slowdown, as we scale `O(n)` with number of arguments, compared to old dispatcher which is `O(1)` with number of arguments): Before: ``` In [1]: import torch In [2]: x = torch.zeros(1) In [3]: %timeit torch._const5(x, x, x, x, x) 949 ns ± 1.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit torch._const5(x, x, x, x, x) 954 ns ± 1.96 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit torch._const5(x, x, x, x, x) 947 ns ± 0.601 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` After: ``` In [3]: %timeit torch._const5(x, x, x, x, x) 985 ns ± 9.11 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit torch._const5(x, x, x, x, x) 984 ns ± 1.17 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit torch._const5(x, x, x, x, x) 988 ns ± 0.555 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` Signed-off-by: Edward Z. Yang <[email protected]> Differential Revision: [D17265918](https://our.internmc.facebook.com/intern/diff/D17265918) [ghstack-poisoned]
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. XLA companion patch at pytorch/xla#1031 Billing of changes: * ATenDispatch fallback code (i.e., what gets run if there is no entry for a function in the table) now lives out-of-line in a function `getFallbackOp`. This gave me an opportunity to write a more detailed error message, providing information about what registrations were available. There is a TODO in the fallback code, suggesting that we could automatically redispatch in the event that there is no handler for the key. But this is a bit of a design question, because it's not clear if automatic redispatch would cover up errors in the dispatch table (i.e., there *should* have been something registered at some key, but there wasn't.) * Collection of Tensor/TensorList arguments is done using the trusty old IterArgs helper class. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10--you can't actually move the whole thing because it is literally incompatible with other code in the codebase). So instead of calling `type_set()` to get the type set of the dispatch argument, now we just call `at::detail::multi_dispatch_tensor_type_set` on all of the tensor/tensor list arguments. * The code generator is adjusted to codegen collection of arguments as needed. There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. The new generated code looks like this: ``` inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) const { static auto table = globalATenDispatch().getOpTable("aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)"); return table->getOp<Tensor & (Tensor &, const Tensor &, bool)>(at::detail::multi_dispatch_tensor_type_set(*this, src))(const_cast<Tensor&>(*this), src, non_blocking); } ``` The key difference is that previously we wrote `type_set()` as argument to getOp; now it is a call to `multi_dispatch_tensor_type_set` which collects the type ids together. After turning on multi-dispatch, I had to refactor existing code which previously dispatched one place, but now dispatches somewhere else. The primary component affected by this is sparse. * Binary operations (add/sub/mul/div/addmm) now dispatch to sparse kernels even if you did add(dense, sparse). So I delete all the sparse handling code from dense kernels, and bulk up the sparse error handling to handle when the first argument is dense. In the case of addmm, I can eliminate the bridge code entirely (well, not quite: more on this below). I also updated the dispatch on sparse to actually point at sparse kernels. Pay special attention to the handling of `div_` by scalar: previously this logic lived in the "dense" `div_` implementation, but there is actually not any sparse kernel we dispatch to. I solved this particular problem by making a redispatch, but another valid approach would have been to add specific dispatches for sparse div on scalar. This codepath is poorly tested because it is only exercised from C++. * One minor annoyance is that because I now want separate dispatch for dense and sparse, I also need to replicate the `add`, `add_`, `add_out` trifecta on the sparse side. I opted for a compromise here: I wrote new a new `add_sparse` trifecta, but reused the implementation between CPU and CUDA. This means that I hav to do another dispatch once I get to `add_out`. The alternative would have been to do twice as many copies for CPU and CUDA (thereby eliminating the extra dispatch) but that seemed distinctly not worth it. * A lot of kernels in sparse assumed that the dispatch argument must be sparse. This is no longer true with dispatch, so I converted the asserts into plain error checking. This also means that we've perturbed the error message in the case of TestSparseOneOff.test_cuda_sparse_cpu_dense_add (I just updated the saved error message) * `addmm` is a little bit even more special: the bridge code also handled broadcasting. I replicated the broadcasting logic between CPU and CUDA implementations to avoid an extra dispatch. * `_sparse_addmm` gave me a bit of trouble, because I had forgotten why we had `torch.sparse.addmm` in the first place. But in the end, its changes followed along with the structural changes I made in addmm. I opted for an extra dispatch here for simplicity. * c10d has some Variable-Tensor confusion in its sparse code. I've worked around it by judiciously inserting "no variable type" guards, but a more correct fix would be to just solve the confusion entirely. Benchmark: Apply the following patch to the base commit and this commit: ``` diff --git a/aten/src/ATen/native/Const.cpp b/aten/src/ATen/native/Const.cpp new file mode 100644 index 0000000000..b66f4d3 --- /dev/null +++ b/aten/src/ATen/native/Const.cpp @@ -0,0 +1,10 @@ +#include <ATen/ATen.h> + +namespace at { +namespace native { + +Tensor _const5(const Tensor& self, const Tensor& second, const Tensor& third, const Tensor& fourth, const Tensor& fifth) { + return self; +} + +}} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b494ed7..fddae63 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5878,3 +5878,9 @@ dispatch: CPU: im2col_backward_cpu CUDA: im2col_backward_cuda + +# For benchmarking +- func: _const5(Tensor self, Tensor second, Tensor third, Tensor fourth, Tensor fifth) -> Tensor + variants: function + dispatch: + CPU: _const5 ``` Comparisons with timeit: One-argument, representative case: Before: ``` In [6]: %timeit x.reshape(1, 1) 1.46 µs ± 1.38 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [7]: %timeit x.reshape(1, 1) 1.48 µs ± 29.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [8]: %timeit x.reshape(1, 1) 1.52 µs ± 61.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` After: ``` In [3]: %timeit x.reshape(1, 1) 1.42 µs ± 1.34 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit x.reshape(1, 1) 1.43 µs ± 1.01 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit x.reshape(1, 1) 1.42 µs ± 0.982 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` Five-argument, synthetic case (we expect, with enough Tensor arguments, for there to be a slowdown, as we scale `O(n)` with number of arguments, compared to old dispatcher which is `O(1)` with number of arguments): Before: ``` In [1]: import torch In [2]: x = torch.zeros(1) In [3]: %timeit torch._const5(x, x, x, x, x) 949 ns ± 1.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit torch._const5(x, x, x, x, x) 954 ns ± 1.96 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit torch._const5(x, x, x, x, x) 947 ns ± 0.601 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` After: ``` In [3]: %timeit torch._const5(x, x, x, x, x) 985 ns ± 9.11 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit torch._const5(x, x, x, x, x) 984 ns ± 1.17 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit torch._const5(x, x, x, x, x) 988 ns ± 0.555 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` Signed-off-by: Edward Z. Yang <[email protected]> Differential Revision: [D17265918](https://our.internmc.facebook.com/intern/diff/D17265918) [ghstack-poisoned]
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. XLA companion patch at pytorch/xla#1031 Billing of changes: * ATenDispatch fallback code (i.e., what gets run if there is no entry for a function in the table) now lives out-of-line in a function `getFallbackOp`. This gave me an opportunity to write a more detailed error message, providing information about what registrations were available. There is a TODO in the fallback code, suggesting that we could automatically redispatch in the event that there is no handler for the key. But this is a bit of a design question, because it's not clear if automatic redispatch would cover up errors in the dispatch table (i.e., there *should* have been something registered at some key, but there wasn't.) * Collection of Tensor/TensorList arguments is done using the trusty old IterArgs helper class. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10--you can't actually move the whole thing because it is literally incompatible with other code in the codebase). So instead of calling `type_set()` to get the type set of the dispatch argument, now we just call `at::detail::multi_dispatch_tensor_type_set` on all of the tensor/tensor list arguments. * The code generator is adjusted to codegen collection of arguments as needed. There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. The new generated code looks like this: ``` inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) const { static auto table = globalATenDispatch().getOpTable("aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)"); return table->getOp<Tensor & (Tensor &, const Tensor &, bool)>(at::detail::multi_dispatch_tensor_type_set(*this, src))(const_cast<Tensor&>(*this), src, non_blocking); } ``` The key difference is that previously we wrote `type_set()` as argument to getOp; now it is a call to `multi_dispatch_tensor_type_set` which collects the type ids together. After turning on multi-dispatch, I had to refactor existing code which previously dispatched one place, but now dispatches somewhere else. The primary component affected by this is sparse. * Binary operations (add/sub/mul/div/addmm) now dispatch to sparse kernels even if you did add(dense, sparse). So I delete all the sparse handling code from dense kernels, and bulk up the sparse error handling to handle when the first argument is dense. In the case of addmm, I can eliminate the bridge code entirely (well, not quite: more on this below). I also updated the dispatch on sparse to actually point at sparse kernels. Pay special attention to the handling of `div_` by scalar: previously this logic lived in the "dense" `div_` implementation, but there is actually not any sparse kernel we dispatch to. I solved this particular problem by making a redispatch, but another valid approach would have been to add specific dispatches for sparse div on scalar. This codepath is poorly tested because it is only exercised from C++. * One minor annoyance is that because I now want separate dispatch for dense and sparse, I also need to replicate the `add`, `add_`, `add_out` trifecta on the sparse side. I opted for a compromise here: I wrote new a new `add_sparse` trifecta, but reused the implementation between CPU and CUDA. This means that I hav to do another dispatch once I get to `add_out`. The alternative would have been to do twice as many copies for CPU and CUDA (thereby eliminating the extra dispatch) but that seemed distinctly not worth it. * A lot of kernels in sparse assumed that the dispatch argument must be sparse. This is no longer true with dispatch, so I converted the asserts into plain error checking. This also means that we've perturbed the error message in the case of TestSparseOneOff.test_cuda_sparse_cpu_dense_add (I just updated the saved error message) * `addmm` is a little bit even more special: the bridge code also handled broadcasting. I replicated the broadcasting logic between CPU and CUDA implementations to avoid an extra dispatch. * `_sparse_addmm` gave me a bit of trouble, because I had forgotten why we had `torch.sparse.addmm` in the first place. But in the end, its changes followed along with the structural changes I made in addmm. I opted for an extra dispatch here for simplicity. * c10d has some Variable-Tensor confusion in its sparse code. I've worked around it by judiciously inserting "no variable type" guards, but a more correct fix would be to just solve the confusion entirely. Benchmark: Apply the following patch to the base commit and this commit: ``` diff --git a/aten/src/ATen/native/Const.cpp b/aten/src/ATen/native/Const.cpp new file mode 100644 index 0000000000..b66f4d3 --- /dev/null +++ b/aten/src/ATen/native/Const.cpp @@ -0,0 +1,10 @@ +#include <ATen/ATen.h> + +namespace at { +namespace native { + +Tensor _const5(const Tensor& self, const Tensor& second, const Tensor& third, const Tensor& fourth, const Tensor& fifth) { + return self; +} + +}} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b494ed7..fddae63 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5878,3 +5878,9 @@ dispatch: CPU: im2col_backward_cpu CUDA: im2col_backward_cuda + +# For benchmarking +- func: _const5(Tensor self, Tensor second, Tensor third, Tensor fourth, Tensor fifth) -> Tensor + variants: function + dispatch: + CPU: _const5 ``` Comparisons with timeit: One-argument, representative case: Before: ``` In [6]: %timeit x.reshape(1, 1) 1.46 µs ± 1.38 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [7]: %timeit x.reshape(1, 1) 1.48 µs ± 29.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [8]: %timeit x.reshape(1, 1) 1.52 µs ± 61.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` After: ``` In [3]: %timeit x.reshape(1, 1) 1.42 µs ± 1.34 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit x.reshape(1, 1) 1.43 µs ± 1.01 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit x.reshape(1, 1) 1.42 µs ± 0.982 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` Five-argument, synthetic case (we expect, with enough Tensor arguments, for there to be a slowdown, as we scale `O(n)` with number of arguments, compared to old dispatcher which is `O(1)` with number of arguments): Before: ``` In [1]: import torch In [2]: x = torch.zeros(1) In [3]: %timeit torch._const5(x, x, x, x, x) 949 ns ± 1.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit torch._const5(x, x, x, x, x) 954 ns ± 1.96 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit torch._const5(x, x, x, x, x) 947 ns ± 0.601 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` After: ``` In [3]: %timeit torch._const5(x, x, x, x, x) 985 ns ± 9.11 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit torch._const5(x, x, x, x, x) 984 ns ± 1.17 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit torch._const5(x, x, x, x, x) 988 ns ± 0.555 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` Signed-off-by: Edward Z. Yang <[email protected]> Differential Revision: [D17265918](https://our.internmc.facebook.com/intern/diff/D17265918) [ghstack-poisoned]
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. XLA companion patch at pytorch/xla#1031 Billing of changes: * ATenDispatch fallback code (i.e., what gets run if there is no entry for a function in the table) now lives out-of-line in a function `getFallbackOp`. This gave me an opportunity to write a more detailed error message, providing information about what registrations were available. There is a TODO in the fallback code, suggesting that we could automatically redispatch in the event that there is no handler for the key. But this is a bit of a design question, because it's not clear if automatic redispatch would cover up errors in the dispatch table (i.e., there *should* have been something registered at some key, but there wasn't.) * Collection of Tensor/TensorList arguments is done using the trusty old IterArgs helper class. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10--you can't actually move the whole thing because it is literally incompatible with other code in the codebase). So instead of calling `type_set()` to get the type set of the dispatch argument, now we just call `at::detail::multi_dispatch_tensor_type_set` on all of the tensor/tensor list arguments. * The code generator is adjusted to codegen collection of arguments as needed. There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. The new generated code looks like this: ``` inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) const { static auto table = globalATenDispatch().getOpTable("aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)"); return table->getOp<Tensor & (Tensor &, const Tensor &, bool)>(at::detail::multi_dispatch_tensor_type_set(*this, src))(const_cast<Tensor&>(*this), src, non_blocking); } ``` The key difference is that previously we wrote `type_set()` as argument to getOp; now it is a call to `multi_dispatch_tensor_type_set` which collects the type ids together. After turning on multi-dispatch, I had to refactor existing code which previously dispatched one place, but now dispatches somewhere else. The primary component affected by this is sparse. * Binary operations (add/sub/mul/div/addmm) now dispatch to sparse kernels even if you did add(dense, sparse). So I delete all the sparse handling code from dense kernels, and bulk up the sparse error handling to handle when the first argument is dense. In the case of addmm, I can eliminate the bridge code entirely (well, not quite: more on this below). I also updated the dispatch on sparse to actually point at sparse kernels. Pay special attention to the handling of `div_` by scalar: previously this logic lived in the "dense" `div_` implementation, but there is actually not any sparse kernel we dispatch to. I solved this particular problem by making a redispatch, but another valid approach would have been to add specific dispatches for sparse div on scalar. This codepath is poorly tested because it is only exercised from C++. * One minor annoyance is that because I now want separate dispatch for dense and sparse, I also need to replicate the `add`, `add_`, `add_out` trifecta on the sparse side. I opted for a compromise here: I wrote new a new `add_sparse` trifecta, but reused the implementation between CPU and CUDA. This means that I hav to do another dispatch once I get to `add_out`. The alternative would have been to do twice as many copies for CPU and CUDA (thereby eliminating the extra dispatch) but that seemed distinctly not worth it. * A lot of kernels in sparse assumed that the dispatch argument must be sparse. This is no longer true with dispatch, so I converted the asserts into plain error checking. This also means that we've perturbed the error message in the case of TestSparseOneOff.test_cuda_sparse_cpu_dense_add (I just updated the saved error message) * `addmm` is a little bit even more special: the bridge code also handled broadcasting. I replicated the broadcasting logic between CPU and CUDA implementations to avoid an extra dispatch. * `_sparse_addmm` gave me a bit of trouble, because I had forgotten why we had `torch.sparse.addmm` in the first place. But in the end, its changes followed along with the structural changes I made in addmm. I opted for an extra dispatch here for simplicity. * c10d has some Variable-Tensor confusion in its sparse code. I've worked around it by judiciously inserting "no variable type" guards, but a more correct fix would be to just solve the confusion entirely. Benchmark: Apply the following patch to the base commit and this commit: ``` diff --git a/aten/src/ATen/native/Const.cpp b/aten/src/ATen/native/Const.cpp new file mode 100644 index 0000000000..b66f4d3 --- /dev/null +++ b/aten/src/ATen/native/Const.cpp @@ -0,0 +1,10 @@ +#include <ATen/ATen.h> + +namespace at { +namespace native { + +Tensor _const5(const Tensor& self, const Tensor& second, const Tensor& third, const Tensor& fourth, const Tensor& fifth) { + return self; +} + +}} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b494ed7..fddae63 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5878,3 +5878,9 @@ dispatch: CPU: im2col_backward_cpu CUDA: im2col_backward_cuda + +# For benchmarking +- func: _const5(Tensor self, Tensor second, Tensor third, Tensor fourth, Tensor fifth) -> Tensor + variants: function + dispatch: + CPU: _const5 ``` Comparisons with timeit: One-argument, representative case: Before: ``` In [6]: %timeit x.reshape(1, 1) 1.46 µs ± 1.38 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [7]: %timeit x.reshape(1, 1) 1.48 µs ± 29.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [8]: %timeit x.reshape(1, 1) 1.52 µs ± 61.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` After: ``` In [3]: %timeit x.reshape(1, 1) 1.42 µs ± 1.34 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit x.reshape(1, 1) 1.43 µs ± 1.01 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit x.reshape(1, 1) 1.42 µs ± 0.982 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` Five-argument, synthetic case (we expect, with enough Tensor arguments, for there to be a slowdown, as we scale `O(n)` with number of arguments, compared to old dispatcher which is `O(1)` with number of arguments): Before: ``` In [1]: import torch In [2]: x = torch.zeros(1) In [3]: %timeit torch._const5(x, x, x, x, x) 949 ns ± 1.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit torch._const5(x, x, x, x, x) 954 ns ± 1.96 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit torch._const5(x, x, x, x, x) 947 ns ± 0.601 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` After: ``` In [3]: %timeit torch._const5(x, x, x, x, x) 985 ns ± 9.11 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit torch._const5(x, x, x, x, x) 984 ns ± 1.17 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit torch._const5(x, x, x, x, x) 988 ns ± 0.555 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` Signed-off-by: Edward Z. Yang <[email protected]> Differential Revision: [D17265918](https://our.internmc.facebook.com/intern/diff/D17265918) [ghstack-poisoned]
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. XLA companion patch at pytorch/xla#1031 Billing of changes: * ATenDispatch fallback code (i.e., what gets run if there is no entry for a function in the table) now lives out-of-line in a function `getFallbackOp`. This gave me an opportunity to write a more detailed error message, providing information about what registrations were available. There is a TODO in the fallback code, suggesting that we could automatically redispatch in the event that there is no handler for the key. But this is a bit of a design question, because it's not clear if automatic redispatch would cover up errors in the dispatch table (i.e., there *should* have been something registered at some key, but there wasn't.) * Collection of Tensor/TensorList arguments is done using the trusty old IterArgs helper class. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10--you can't actually move the whole thing because it is literally incompatible with other code in the codebase). So instead of calling `type_set()` to get the type set of the dispatch argument, now we just call `at::detail::multi_dispatch_tensor_type_set` on all of the tensor/tensor list arguments. * The code generator is adjusted to codegen collection of arguments as needed. There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. The new generated code looks like this: ``` inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) const { static auto table = globalATenDispatch().getOpTable("aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)"); return table->getOp<Tensor & (Tensor &, const Tensor &, bool)>(at::detail::multi_dispatch_tensor_type_set(*this, src))(const_cast<Tensor&>(*this), src, non_blocking); } ``` The key difference is that previously we wrote `type_set()` as argument to getOp; now it is a call to `multi_dispatch_tensor_type_set` which collects the type ids together. After turning on multi-dispatch, I had to refactor existing code which previously dispatched one place, but now dispatches somewhere else. The primary component affected by this is sparse. * Binary operations (add/sub/mul/div/addmm) now dispatch to sparse kernels even if you did add(dense, sparse). So I delete all the sparse handling code from dense kernels, and bulk up the sparse error handling to handle when the first argument is dense. In the case of addmm, I can eliminate the bridge code entirely (well, not quite: more on this below). I also updated the dispatch on sparse to actually point at sparse kernels. Pay special attention to the handling of `div_` by scalar: previously this logic lived in the "dense" `div_` implementation, but there is actually not any sparse kernel we dispatch to. I solved this particular problem by making a redispatch, but another valid approach would have been to add specific dispatches for sparse div on scalar. This codepath is poorly tested because it is only exercised from C++. * One minor annoyance is that because I now want separate dispatch for dense and sparse, I also need to replicate the `add`, `add_`, `add_out` trifecta on the sparse side. I opted for a compromise here: I wrote new a new `add_sparse` trifecta, but reused the implementation between CPU and CUDA. This means that I hav to do another dispatch once I get to `add_out`. The alternative would have been to do twice as many copies for CPU and CUDA (thereby eliminating the extra dispatch) but that seemed distinctly not worth it. * A lot of kernels in sparse assumed that the dispatch argument must be sparse. This is no longer true with dispatch, so I converted the asserts into plain error checking. This also means that we've perturbed the error message in the case of TestSparseOneOff.test_cuda_sparse_cpu_dense_add (I just updated the saved error message) * `addmm` is a little bit even more special: the bridge code also handled broadcasting. I replicated the broadcasting logic between CPU and CUDA implementations to avoid an extra dispatch. * `_sparse_addmm` gave me a bit of trouble, because I had forgotten why we had `torch.sparse.addmm` in the first place. But in the end, its changes followed along with the structural changes I made in addmm. I opted for an extra dispatch here for simplicity. * c10d has some Variable-Tensor confusion in its sparse code. I've worked around it by judiciously inserting "no variable type" guards, but a more correct fix would be to just solve the confusion entirely. Benchmark: Apply the following patch to the base commit and this commit: ``` diff --git a/aten/src/ATen/native/Const.cpp b/aten/src/ATen/native/Const.cpp new file mode 100644 index 0000000000..b66f4d3 --- /dev/null +++ b/aten/src/ATen/native/Const.cpp @@ -0,0 +1,10 @@ +#include <ATen/ATen.h> + +namespace at { +namespace native { + +Tensor _const5(const Tensor& self, const Tensor& second, const Tensor& third, const Tensor& fourth, const Tensor& fifth) { + return self; +} + +}} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b494ed7..fddae63 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5878,3 +5878,9 @@ dispatch: CPU: im2col_backward_cpu CUDA: im2col_backward_cuda + +# For benchmarking +- func: _const5(Tensor self, Tensor second, Tensor third, Tensor fourth, Tensor fifth) -> Tensor + variants: function + dispatch: + CPU: _const5 ``` Comparisons with timeit: One-argument, representative case: Before: ``` In [6]: %timeit x.reshape(1, 1) 1.46 µs ± 1.38 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [7]: %timeit x.reshape(1, 1) 1.48 µs ± 29.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [8]: %timeit x.reshape(1, 1) 1.52 µs ± 61.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` After: ``` In [3]: %timeit x.reshape(1, 1) 1.42 µs ± 1.34 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit x.reshape(1, 1) 1.43 µs ± 1.01 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit x.reshape(1, 1) 1.42 µs ± 0.982 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` Five-argument, synthetic case (we expect, with enough Tensor arguments, for there to be a slowdown, as we scale `O(n)` with number of arguments, compared to old dispatcher which is `O(1)` with number of arguments): Before: ``` In [1]: import torch In [2]: x = torch.zeros(1) In [3]: %timeit torch._const5(x, x, x, x, x) 949 ns ± 1.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit torch._const5(x, x, x, x, x) 954 ns ± 1.96 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit torch._const5(x, x, x, x, x) 947 ns ± 0.601 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` After: ``` In [3]: %timeit torch._const5(x, x, x, x, x) 985 ns ± 9.11 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit torch._const5(x, x, x, x, x) 984 ns ± 1.17 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit torch._const5(x, x, x, x, x) 988 ns ± 0.555 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` Signed-off-by: Edward Z. Yang <[email protected]> Differential Revision: [D17265918](https://our.internmc.facebook.com/intern/diff/D17265918) [ghstack-poisoned]
Summary: Pull Request resolved: pytorch/pytorch#25653 Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id. Billing of changes: * ATenDispatch fallback code (i.e., what gets run if there is no entry for a function in the table) now lives out-of-line in a function `getFallbackOp`. This gave me an opportunity to write a more detailed error message, providing information about what registrations were available. There is a TODO in the fallback code, suggesting that we could automatically redispatch in the event that there is no handler for the key. But this is a bit of a design question, because it's not clear if automatic redispatch would cover up errors in the dispatch table (i.e., there *should* have been something registered at some key, but there wasn't.) * Collection of Tensor/TensorList arguments is done using the trusty old IterArgs helper class. A minor bit of refactoring I had to do to get here was move the IterArgs functionality in torch/csrc/utils/variadic.h into ATen/core. There's some refactoring due on that file too (it has copies of some C++ helper pieces which already live in c10--you can't actually move the whole thing because it is literally incompatible with other code in the codebase). So instead of calling `type_set()` to get the type set of the dispatch argument, now we just call `at::detail::multi_dispatch_tensor_type_set` on all of the tensor/tensor list arguments. * The code generator is adjusted to codegen collection of arguments as needed. There is a little bit of a hack in the code generator to turn 'self' arguments into '*this'. I think this may be duplicated with some logic somewhere else but I have to double check. After turning on multi-dispatch, I had to refactor existing code which previously dispatched one place, but now dispatches somewhere else. The primary component affected by this is sparse. * Binary operations (add/sub/mul/div/addmm) now dispatch to sparse kernels even if you did add(dense, sparse). So I delete all the sparse handling code from dense kernels, and bulk up the sparse error handling to handle when the first argument is dense. In the case of addmm, I can eliminate the bridge code entirely (well, not quite: more on this below). I also updated the dispatch on sparse to actually point at sparse kernels. Pay special attention to the handling of `div_` by scalar: previously this logic lived in the "dense" `div_` implementation, but there is actually not any sparse kernel we dispatch to. I solved this particular problem by making a redispatch, but another valid approach would have been to add specific dispatches for sparse div on scalar. This codepath is poorly tested because it is only exercised from C++. * One minor annoyance is that because I now want separate dispatch for dense and sparse, I also need to replicate the `add`, `add_`, `add_out` trifecta on the sparse side. I opted for a compromise here: I wrote new a new `add_sparse` trifecta, but reused the implementation between CPU and CUDA. This means that I hav to do another dispatch once I get to `add_out`. The alternative would have been to do twice as many copies for CPU and CUDA (thereby eliminating the extra dispatch) but that seemed distinctly not worth it. * A lot of kernels in sparse assumed that the dispatch argument must be sparse. This is no longer true with dispatch, so I converted the asserts into plain error checking. This also means that we've perturbed the error message in the case of TestSparseOneOff.test_cuda_sparse_cpu_dense_add (I just updated the saved error message) * `addmm` is a little bit even more special: the bridge code also handled broadcasting. I replicated the broadcasting logic between CPU and CUDA implementations to avoid an extra dispatch. * `_sparse_addmm` gave me a bit of trouble, because I had forgotten why we had `torch.sparse.addmm` in the first place. But in the end, its changes followed along with the structural changes I made in addmm. I opted for an extra dispatch here for simplicity. * c10d has some Variable-Tensor confusion in its sparse code. I've worked around it by judiciously inserting "no variable type" guards, but a more correct fix would be to just solve the confusion entirely. Benchmark: Apply the following patch to the base commit and this commit: ``` diff --git a/aten/src/ATen/native/Const.cpp b/aten/src/ATen/native/Const.cpp new file mode 100644 index 0000000000..b66f4d3ece --- /dev/null +++ b/aten/src/ATen/native/Const.cpp @@ -0,0 +1,10 @@ +#include <ATen/ATen.h> + +namespace at { +namespace native { + +Tensor _const5(const Tensor& self, const Tensor& second, const Tensor& third, const Tensor& fourth, const Tensor& fifth) { + return self; +} + +}} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b494ed7950..fddae638bb 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5878,3 +5878,9 @@ dispatch: CPU: im2col_backward_cpu CUDA: im2col_backward_cuda + +# For benchmarking +- func: _const5(Tensor self, Tensor second, Tensor third, Tensor fourth, Tensor fifth) -> Tensor + variants: function + dispatch: + CPU: _const5 ``` Comparisons with timeit: One-argument, representative case: Before: ``` In [6]: %timeit x.reshape(1, 1) 1.46 µs ± 1.38 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [7]: %timeit x.reshape(1, 1) 1.48 µs ± 29.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [8]: %timeit x.reshape(1, 1) 1.52 µs ± 61.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` After: ``` In [3]: %timeit x.reshape(1, 1) 1.42 µs ± 1.34 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit x.reshape(1, 1) 1.43 µs ± 1.01 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit x.reshape(1, 1) 1.42 µs ± 0.982 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` Five-argument, synthetic case (we expect, with enough Tensor arguments, for there to be a slowdown, as we scale `O(n)` with number of arguments, compared to old dispatcher which is `O(1)` with number of arguments): Before: ``` In [1]: import torch In [2]: x = torch.zeros(1) In [3]: %timeit torch._const5(x, x, x, x, x) 949 ns ± 1.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit torch._const5(x, x, x, x, x) 954 ns ± 1.96 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit torch._const5(x, x, x, x, x) 947 ns ± 0.601 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` After: ``` In [3]: %timeit torch._const5(x, x, x, x, x) 985 ns ± 9.11 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [4]: %timeit torch._const5(x, x, x, x, x) 984 ns ± 1.17 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) In [5]: %timeit torch._const5(x, x, x, x, x) 988 ns ± 0.555 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) ``` Signed-off-by: Edward Z. Yang <[email protected]> Test Plan: Imported from OSS Differential Revision: D17265918 Pulled By: ezyang fbshipit-source-id: 221efe4e86a40f36abc81e2ebceaa7e251c90b3d
Stack from ghstack:
Instead of considering only the TensorTypeSet of the first argument, we collect all Tensor and TensorList arguments and union them together before computing the dispatch type id.
XLA companion patch at pytorch/xla#1031
Billing of changes:
getFallbackOp. This gave me an opportunity to write a more detailed error message, providing information about what registrations were available. There is a TODO in the fallback code, suggesting that we could automatically redispatch in the event that there is no handler for the key. But this is a bit of a design question, because it's not clear if automatic redispatch would cover up errors in the dispatch table (i.e., there should have been something registered at some key, but there wasn't.)type_set()to get the type set of the dispatch argument, now we just callat::detail::multi_dispatch_tensor_type_seton all of the tensor/tensor list arguments.The new generated code looks like this:
The key difference is that previously we wrote
type_set()as argument to getOp; now it is a call tomulti_dispatch_tensor_type_setwhich collects the type ids together.After turning on multi-dispatch, I had to refactor existing code which previously dispatched one place, but now dispatches somewhere else. The primary component affected by this is sparse.
div_by scalar: previously this logic lived in the "dense"div_implementation, but there is actually not any sparse kernel we dispatch to. I solved this particular problem by making a redispatch, but another valid approach would have been to add specific dispatches for sparse div on scalar. This codepath is poorly tested because it is only exercised from C++.add,add_,add_outtrifecta on the sparse side. I opted for a compromise here: I wrote new a newadd_sparsetrifecta, but reused the implementation between CPU and CUDA. This means that I hav to do another dispatch once I get toadd_out. The alternative would have been to do twice as many copies for CPU and CUDA (thereby eliminating the extra dispatch) but that seemed distinctly not worth it.addmmis a little bit even more special: the bridge code also handled broadcasting. I replicated the broadcasting logic between CPU and CUDA implementations to avoid an extra dispatch._sparse_addmmgave me a bit of trouble, because I had forgotten why we hadtorch.sparse.addmmin the first place. But in the end, its changes followed along with the structural changes I made in addmm. I opted for an extra dispatch here for simplicity.Benchmark:
Apply the following patch to the base commit and this commit:
Comparisons with timeit:
One-argument, representative case:
Before:
After:
Five-argument, synthetic case (we expect, with enough Tensor arguments, for there to be a slowdown, as we scale
O(n)with number of arguments, compared to old dispatcher which isO(1)with number of arguments):Before:
After:
Signed-off-by: Edward Z. Yang [email protected]
Differential Revision: D17265918