[MPS] cholesky ex version #146799
Conversation

🔗 Helpful links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/146799. Note: links to docs will display an error until the docs builds have been completed.

✅ No failures as of commit 4ea9ce5 with merge base 91c4bf3. This comment was automatically generated by Dr. CI and updates every 15 minutes.

Attention! native_functions.yaml was changed. If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs: one which adds the new C++ functionality, and one which makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.

To add the ciflow label, please first approve the workflows that are awaiting approval. This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

```cpp
out.tril_();
upper ? out.transpose_(ndim - 2, ndim - 1) : out;
```

This will silently alter the stride structure of out if upper == true. It would be better as upper ? out.triu_() : out.tril_().
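
To illustrate the concern, here is a minimal sketch (my own illustration, not code from the PR) showing that an in-place transpose_ changes the strides of the caller-visible tensor, while an in-place triu_ leaves them untouched:

```python
import torch

out = torch.empty(3, 3)
print(out.stride())   # (3, 1): row-major, as the caller allocated it
out.transpose_(0, 1)  # in-place transpose swaps strides, no data is moved
print(out.stride())   # (1, 3): the caller's `out` is now column-major

out2 = torch.empty(3, 3)
out2.triu_()          # in-place triu only zeroes elements below the diagonal
print(out2.stride())  # (3, 1): strides are unchanged
```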

That's not the same. The kernel does the decomposition in the lower part of the matrix. If you do out.triu_() instead of out.tril_() -> transpose, then you get the upper part of the matrix, which isn't really the correct output.
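
A small sketch of why the two differ (hypothetical values, not taken from the kernel): the kernel leaves whatever was in the buffer above the diagonal, so triu_() would keep that garbage, whereas tril_() followed by a transpose yields the correct upper factor Lᵀ:

```python
import torch

# Pretend the kernel wrote a lower factor L into the buffer and left
# garbage (9s) above the diagonal.
L = torch.tensor([[4., 0., 0.],
                  [2., 3., 0.],
                  [1., 2., 5.]])
buf = L + torch.triu(torch.full_like(L, 9.), diagonal=1)

upper_good = buf.tril().transpose(0, 1)  # tril -> transpose: exactly L.T
upper_bad = buf.triu()                   # keeps the garbage above the diagonal
print(torch.equal(upper_good, L.T))      # True
print(torch.equal(upper_bad, L.T))       # False
```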

Do you have some stride assumptions in the kernel, or is it stride-agnostic? If it is stride-agnostic, then the kernel could be run on the transposed variant.

It assumes that the input is row-major (contiguous).

out can be provided externally as column-major. What would happen in this case?

I printed the data ptr inside the MPS function and outside in Python:

```python
import torch

out = torch.rand(3, 3, 3, device="mps").permute(2, 1, 0)
x = torch.rand(3, 3, 3, device="mps")
x = x.mT @ x
data_ptr = out.data_ptr()
print(f"0x{data_ptr:x}")  # lowercase hex
torch.linalg.cholesky(x, out=out)
print(f"0x{out.data_ptr():x}")
```

Yields:

```
0x10a4d68d0
0x10fb19150
0x10a4d68d0
```

The first is printed from Python, the second from C++ right before launching the kernel, and the third from Python again. So yes, confirmed: the kernel runs on a different (contiguous) buffer, while the out tensor visible from Python keeps its original data pointer.

As per https://github.com/pytorch/pytorch/pull/146799/files#r1952464144, this is expected. Sorry for the confusion. But it seems we should have issues when out is contiguous and upper=True.

No issues from what I can see:

```python
import torch

out = torch.rand(3, 3, 3, device="mps").permute(2, 1, 0)
x = torch.rand(3, 3, 3, device="mps")
x = x.mT @ x
data_ptr = out.data_ptr()
print(f"0x{data_ptr:x}")  # lowercase hex
print(out.stride())
print(out.is_contiguous())
res1 = torch.linalg.cholesky(x, out=out, upper=True)
res2 = torch.linalg.cholesky(x.cpu(), out=out.cpu(), upper=True)
print(f"0x{out.data_ptr():x}")
print(out.stride())
torch.testing.assert_close(res1.cpu(), res2)
```

```
0x113f70cc0
(1, 3, 9)
False
0x114f3a510
0x113f70cc0
(1, 3, 9)
```

@Isalia20, could you remove the permute so that out is contiguous? In the meta function, as per your modification, out is reused only if it is contiguous.

Ah, I see the issue now (out goes in C-contiguous with strides (9, 3, 1) and comes back with transposed strides (9, 1, 3)):

```
0x10bc7b840
(9, 3, 1)
True
0x10bc7b840
0x10bc7b840
(9, 1, 3)
```

Looks good, thank you! I have left some comments regarding the silent stride-altering behavior in out and the values of the info vector.

```diff
  // L
- auto L_strides = at::native::batched_matrix_contiguous_strides(A_shape, /*f-contig*=*/true);
+ auto L_strides = at::native::batched_matrix_contiguous_strides(A_shape, /*f-contig*=*/A.device().type() != at::kMPS);
```

Why is MPS different?

The MPS kernel assumes a row-major layout for the matrix in which it does the decomposition.

Can the kernel be made row-major/col-major agnostic, so as to preserve consistency across backends?

I'll take a look ~next week to see if I can make it work for col-major so we don't need to make it row-major for MPS only. But why do we want to preserve consistency across backends? Lots of ops on MPS use a row-major layout and require a contiguous() call before the tensor is passed to an MPS kernel.

In linalg, LAPACK seems to be the source of truth, and it is written in Fortran, where col-major is the standard layout :(
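
For context, a short sketch (my own illustration, not code from the PR) of what the two layouts look like in PyTorch stride terms; the column-major strides below are what the f-contig branch of batched_matrix_contiguous_strides is meant to produce for LAPACK-style backends:

```python
import torch

n = 4
A = torch.empty(2, n, n)     # C-contiguous batch of matrices
print(A.stride())            # (16, 4, 1): each matrix is row-major

F = A.mT.contiguous().mT     # same logical shape, column-major per matrix
print(F.stride())            # (16, 1, 4): LAPACK-style (Fortran) layout
print(F.mT.is_contiguous())  # True: the transposed view is C-contiguous
```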

I believe we can reuse the kernel without that much code change (i.e. no need to make it stride-agnostic for now). In the meta function we request C-contiguous when upper=False and F-contiguous when upper=True for MPS. Then we only need to remove the line upper ? out.transpose_(...) : out (and probably replace it with upper ? out.tril_() : out.triu_(), or something along these lines). That should resolve the issue with out for now, before the kernel is adapted for better memory accesses in column-major mode...

I've tried it, but I'm afraid it doesn't work. I'll address this in a follow-up PR with the kernel change for column-major mode rather than going down the rabbit hole now for a temporary fix.

Thanks, I'll address the comments a little later today.

```python
output_cpu = torch.linalg.cholesky_ex(input_cpu, upper=upper)
output_mps = torch.linalg.cholesky_ex(input_mps, upper=upper)
```

Let us also check that info is the same since its behavior is altered?

output_cpu and output_mps are each a tuple of the L and info tensors, so assertEqual compares both of them. Do you mean to add a separate test where info might be nonzero?

Yes, when erroring on non-PSD inputs :)
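
Something along these lines could work as that test (a sketch, not the PR's test code; it assumes MPS follows LAPACK's convention that info = k means the order-k leading minor is not positive-definite):

```python
import torch

# A symmetric matrix that is not positive-definite: the order-2 leading
# minor is negative, so cholesky_ex should report info == 2.
x = torch.eye(3)
x[1, 1] = -1.0

L_cpu, info_cpu = torch.linalg.cholesky_ex(x)
L_mps, info_mps = torch.linalg.cholesky_ex(x.to("mps"))

# info should match across backends (nonzero signals the failure).
torch.testing.assert_close(info_cpu, info_mps.cpu())
assert info_cpu.item() == 2
```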

I'll do it a bit later today and also adapt the error message.

Added a better error message.

@pytorchbot merge -f "MPS is green"

Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

PR pytorch#145701 didn't have the experimental version of cholesky. This PR adds that version.

Pull Request resolved: pytorch#146799. Approved by: https://github.com/malfet