Conversation
FWIW, under the reviewer name there is a link that switches your PR to draft.
Force-pushed 6ea496a to 78319cb.
Force-pushed 86e454e to 4cb81ab.
torch_xla/csrc/aten_xla_type.cpp (outdated diff excerpt):

```cpp
XLA_FN_COUNTER("xla::");
// Einsum operations with more than 2 operands, like bilinear operations, are
// not currently supported in XLA
if (tensors.size() > 2) {
```
@bdhirsh This is a bit tricky: we want to override einsum, but we can only support certain kinds of einsum. Looking at https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_autograd_ops.cpp#L31, it seems like we actually need to let pt/xla take care of falling back for both forward and backward. However, there is no such thing as einsum_backward; what are we going to do with the backward case?
On second thought, we should not fall back but instead call the at::native function, like https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_xla_type.cpp#L3086, which will decompose einsum into smaller ops. I think this should solve the backward issue for unsupported einsums too.
Yes, I think you want to fall back by calling into at::native::einsum. Falling back to CPU seems like a big pessimisation; instead, you just want to fall back to the existing decomposition in core, and run XLA on the decomposed ops.
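As a rough illustration of why the decomposition route works (this is NOT at::native::einsum's actual algorithm, just a sketch; the helper name is hypothetical, and it assumes single-letter indices with an explicit `->`): an einsum over N operands can be planned as a left-to-right chain of binary einsums, where each intermediate keeps only the indices still needed by a later operand or by the final output. Each resulting step is a 2-operand einsum that a backend lowering can handle.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical sketch: decompose an N-operand einsum equation into a
// left-to-right chain of binary einsum equations. Each intermediate result
// keeps only the indices that a later operand or the final output still
// mentions; everything else is contracted away at that step.
std::vector<std::string> PlanBinaryEinsums(const std::string& equation) {
  std::size_t arrow = equation.find("->");  // assumes an explicit output
  std::string out = equation.substr(arrow + 2);
  std::vector<std::string> inputs;
  std::string cur;
  for (char c : equation.substr(0, arrow)) {
    if (c == ',') {
      inputs.push_back(cur);
      cur.clear();
    } else {
      cur.push_back(c);
    }
  }
  inputs.push_back(cur);

  std::vector<std::string> steps;
  std::string acc = inputs[0];
  for (std::size_t i = 1; i < inputs.size(); ++i) {
    std::string next_out;
    if (i + 1 == inputs.size()) {
      // Last contraction produces exactly the requested output.
      next_out = out;
    } else {
      // Keep an index iff a later operand or the final output mentions it.
      std::string needed = out;
      for (std::size_t j = i + 1; j < inputs.size(); ++j) needed += inputs[j];
      for (char idx : acc + inputs[i]) {
        if (needed.find(idx) != std::string::npos &&
            next_out.find(idx) == std::string::npos) {
          next_out.push_back(idx);
        }
      }
    }
    steps.push_back(acc + "," + inputs[i] + "->" + next_out);
    acc = next_out;
  }
  return steps;
}
```

For example, `ab,bc,cd->ad` (a 3-operand case the lowering cannot take directly) becomes the chain `ab,bc->ac` followed by `ac,cd->ad`.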
I also... think that should fix the requires_grad issue that you're seeing, but it's worth confirming 😛.
Updated to at::native::einsum here, but there's still the issue of requires_grad for cases with 1 or 2 operands. We want to use aten_autograd_ops::EinsumAutogradFunction::apply there to leverage XLA's einsum implementation.
ah, so it's working in the case where you call at::native::einsum, but it's losing the requires_grad in the case where it calls the custom autograd function?
If it's an issue with the autograd function, what I would check first is:
(1) Did you implement the autograd function the same way that the existing ones in pytorch_xla are implemented (e.g. max_pool2d)?
(2) If so, can you repro the issue with max_pool2d too?
As far as I know, this einsum autograd function is implemented the same way as the max_pool functions. However, I cannot reproduce the issue with max_pool2d.
Max pool 2d was overridden as an autograd function in #2236. Since then it looks like we've removed scripts/gen.py, but I believe all I have to do to set up the code generation now is add einsum to the autograd section of xla_native_functions.yaml. Could there be additional setup required, either in this repo or the parent repo? CC @JackCaoG
It doesn't look like we explicitly do anything in the max_pool implementations to forward requires_grad to the output. Where is that happening?
@bdhirsh I think one difference is that pytorch does not have a backward formula for einsum; it will always get dispatched to smaller ops. Not sure if this will affect the behavior of requires_grad.
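For what it's worth, even without a dedicated einsum_backward, the gradient of a well-behaved binary einsum is itself an einsum with the output and one operand swapping roles: for C = einsum("ij,jk->ik", A, B) and upstream gradient W, grad_A = einsum("ik,jk->ij", W, B). A toy numerical check of that identity against finite differences (illustrative code only, not the PR's implementation):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

using Mat = std::vector<std::vector<double>>;

// Scalar loss L = sum(W * C) where C = einsum("ij,jk->ik", A, B),
// i.e. C is the matrix product A * B and W plays the upstream gradient.
double Loss(const Mat& A, const Mat& B, const Mat& W) {
  double loss = 0.0;
  for (std::size_t i = 0; i < A.size(); ++i)
    for (std::size_t k = 0; k < B[0].size(); ++k) {
      double c = 0.0;
      for (std::size_t j = 0; j < B.size(); ++j) c += A[i][j] * B[j][k];
      loss += W[i][k] * c;
    }
  return loss;
}

// Analytic gradient via the "swapped" einsum:
// grad_A[i][j] = sum_k W[i][k] * B[j][k]   (einsum "ik,jk->ij")
Mat GradA(const Mat& B, const Mat& W) {
  Mat g(W.size(), std::vector<double>(B.size(), 0.0));
  for (std::size_t i = 0; i < W.size(); ++i)
    for (std::size_t j = 0; j < B.size(); ++j)
      for (std::size_t k = 0; k < W[0].size(); ++k)
        g[i][j] += W[i][k] * B[j][k];
  return g;
}

// Largest deviation between the analytic gradient and central differences.
double MaxGradError(Mat A, const Mat& B, const Mat& W) {
  const double eps = 1e-5;
  Mat g = GradA(B, W);
  double max_err = 0.0;
  for (std::size_t i = 0; i < A.size(); ++i)
    for (std::size_t j = 0; j < A[0].size(); ++j) {
      A[i][j] += eps;
      double up = Loss(A, B, W);
      A[i][j] -= 2 * eps;
      double down = Loss(A, B, W);
      A[i][j] += eps;  // restore
      max_err =
          std::max(max_err, std::fabs((up - down) / (2 * eps) - g[i][j]));
    }
  return max_err;
}
```

Since the loss is linear in A, the central difference should agree with the analytic einsum gradient essentially to machine precision.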
Right now, there's an issue with the result of einsum not requiring grad, even when the inputs require grad. I believe this may be because there is no backward formula for einsum. We can also recreate this with cpp tests.
@bdhirsh Can you help on #3843 (comment)? I am actually not sure where the grad_fn gets dropped.
Force-pushed 4cb81ab to 5fdf370.
@ezyang I am wondering if you have any idea regarding #3843 (comment). This is blocking one of our benchmark experiments.
I'll take a look, but what I expected is for you to override AutogradXLA and use the C++ custom autograd function API to set up your derivative. Do you have all this?
Yup, we do. I was trying to look up how that is supposed to work. The issue right now is that after we do the forward pass, the output does not have a grad_fn.
I am able to confirm that with this PR, the forward pass computes the correct result but the output does not have a grad_fn.
If I do the same thing to maxpool2d, which is another function whose backward we overwrite using an autograd function, I don't see this problem.
I think the problem is that torch::autograd::Function doesn't support variable_list. You need to add support for this in torch/csrc/autograd/custom_function.h.
Force-pushed 5fdf370 to 674f2cd.
After adding support for TensorLists upstream in #84583, this is unblocked.
Force-pushed 5b325ce to c998b82.
Diff excerpt:

```cpp
  output_shape = shape_one;
}

return output_shape;
```
can't we just call BuildEinsumBackward once and then walk through the return vector and decide if we need a tuple shape?
Done, but I needed to add an InferOutputShapes helper to handle the std::vector<xla::XlaOp> output from BuildEinsumBackward.
Force-pushed c998b82 to e449a12.
Force-pushed e449a12 to 36c3d10.
Force-pushed 36c3d10 to 3369eed.
This is required to unblock pytorch/xla#3843, which lowers the einsum op for pytorch/xla. Because one method input parameter is a TensorList, we need to support TensorLists here so that we can support einsum gradients.

Pull Request resolved: #84583
Approved by: https://github.com/soulitzer
Implements op lowering for einsum and einsum backward when (1) there are at most 2 inputs and (2) no equation (forward or backward) has an index in one element (input or output) which is absent from all other elements. When these conditions are not met, we fall back to the at::native implementation, which will break the einsum op down into its constituent operations.

If we want to relax condition (2), then we need a change in XLA to support those kinds of einsum equations. Currently, such equations lead to an INVALID_ARGUMENT status when trying to get the shape of the output. Likewise, if we want to relax condition (1), we either need a change in XLA or a change in the upstream to break down einsums with 3 or more inputs.
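Condition (2) can be checked directly from the equation string: every index must appear in at least two of the elements (the comma-separated inputs plus the output). The helper below is a hypothetical sketch of such a check, not the PR's actual code; it assumes single-letter indices, no ellipses, and an explicit `->`.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical check for condition (2): returns true iff every index in the
// equation occurs in at least two elements (inputs or output). An index that
// appears in only one element would require einsum support XLA currently
// rejects with INVALID_ARGUMENT when inferring the output shape.
bool EinsumIndicesShared(const std::string& equation) {
  // Split "lhs->rhs" into elements: each comma-separated input plus the output.
  std::size_t arrow = equation.find("->");  // assumes an explicit output
  std::vector<std::string> elements;
  std::string cur;
  for (char c : equation.substr(0, arrow)) {
    if (c == ',') {
      elements.push_back(cur);
      cur.clear();
    } else {
      cur.push_back(c);
    }
  }
  elements.push_back(cur);
  elements.push_back(equation.substr(arrow + 2));

  // Every index must occur in at least two distinct elements.
  for (const std::string& elem : elements) {
    for (char idx : elem) {
      int count = 0;
      for (const std::string& other : elements) {
        if (other.find(idx) != std::string::npos) ++count;
      }
      if (count < 2) return false;
    }
  }
  return true;
}
```

Under this check, `ab,bc->ac` is lowerable, while a pure reduction like `ij->i` (where `j` appears in only one element) would fall back to the at::native decomposition.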