[MPS] cholesky ex version #146799
Changes from all commits
```diff
@@ -20,6 +20,7 @@
 #include <ATen/ops/baddbmm_native.h>
 #include <ATen/ops/bmm_native.h>
 #include <ATen/ops/cholesky_native.h>
+#include <ATen/ops/linalg_cholesky_ex_native.h>
 #include <ATen/ops/linalg_cholesky_native.h>
 #include <ATen/ops/linalg_lu_factor_ex_native.h>
 #include <ATen/ops/linalg_lu_factor_native.h>
```
```diff
@@ -1051,7 +1052,11 @@ static void lu_unpack_mps_impl(const Tensor& LU_data,
   }
 }
 
-static Tensor& linalg_cholesky_mps_impl(const Tensor& input, bool upper, Tensor& out) {
+static void linalg_cholesky_mps_impl(const Tensor& input,
+                                     bool upper,
+                                     bool check_errors,
+                                     const Tensor& out,
+                                     const Tensor& info) {
   using namespace mps;
 
   TORCH_CHECK(out.is_mps());
```
```diff
@@ -1061,9 +1066,11 @@ static void lu_unpack_mps_impl(const Tensor& LU_data,
 
   if (input.numel() == 0 || out.numel() == 0) {
     out.zero_();
-    return out;
+    return;
   }
-  resize_output(out, input.sizes());
+  auto input_sizes = input.sizes();
+  resize_output(out, input_sizes);
+  resize_output(info, {input_sizes.begin(), input_sizes.end() - 2});
   out.copy_(input);
 
   int64_t ndim = out.dim();
```
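For context on the added `resize_output(info, {input_sizes.begin(), input_sizes.end() - 2})`: `info` is sized to the batch dimensions of the input, i.e. the input shape with the trailing `(n, n)` dropped. A quick CPU illustration (not part of this diff) of the resulting shapes:

```python
import torch

# Illustrative only: torch.linalg.cholesky_ex returns an int32 `info` tensor whose
# shape equals the batch dimensions of the input (input shape minus the last two dims).
A = torch.eye(4).expand(2, 3, 4, 4).contiguous()   # batch of identity matrices
L, info = torch.linalg.cholesky_ex(A)
print(L.shape)      # torch.Size([2, 3, 4, 4])
print(info.shape)   # torch.Size([2, 3]) -- one status value per matrix in the batch
```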
```diff
@@ -1083,14 +1090,16 @@ static void lu_unpack_mps_impl(const Tensor& LU_data,
   int64_t NB = std::min<int64_t>(32, N);
   int64_t numBlocks = (N + NB - 1) / NB;
 
-  Tensor success = at::empty({B}, input.options().dtype(kInt)).fill_(1);
+  auto info_ = info.dim() >= 2 ? info.view({B}) : info;
+  auto info_sizes = info.sizes();
+  info_.fill_(0);
 
   MTLSize threadGroupSize = MTLSizeMake(32, 8, 1);
 
   @autoreleasepool {
     dispatch_sync_with_rethrow(stream->queue(), ^() {
       auto computeEncoder = stream->commandEncoder();
-      mtl_setArgs(computeEncoder, out, success, N, NB);
+      mtl_setArgs(computeEncoder, out, info_, N, NB);
       for (int64_t k = 0; k < numBlocks; k++) {
         [computeEncoder setComputePipelineState:factorDiagonalPSO];
         mtl_setBytes(computeEncoder, k, 4);
```
```diff
@@ -1118,10 +1127,32 @@ static void lu_unpack_mps_impl(const Tensor& LU_data,
       }
     });
   }
 
-  TORCH_CHECK(success.all().item<bool>(), "linalg.cholesky: Input matrix is not positive definite");
-  out.tril_();
-  return upper ? out.transpose_(ndim - 2, ndim - 1) : out;
+  int status;
+  if (check_errors) {
+    if (info_.dim() > 0) {
+      // batch case
+      for (const auto i : c10::irange(B)) {
+        status = info_[i].item<int>();
+        TORCH_CHECK(
+            status == 0,
+            "linalg.cholesky(): (Batch element ",
+            i,
+            "): The factorization could not be completed because the input is not positive-definite (the leading minor of order ",
+            status,
+            " is not positive-definite).");
+      }
+    } else {
+      // single matrix case (no batch dims)
+      status = info.item<int>();
+      TORCH_CHECK(
+          status == 0,
+          "linalg.cholesky(): The factorization could not be completed because the input is not positive-definite (the leading minor of order ",
+          status,
+          " is not positive-definite).");
+    }
+  }
+  out.tril_();
+  upper ? out.transpose_(ndim - 2, ndim - 1) : out;
```
Comment on lines +1154 to +1155

Collaborator: This will silently alter the stride structure of

Author: That's not the same. The kernel does the decomposition in the lower part of the matrix. If you do out.triu_() instead of out.tril_() -> transpose, then you get the upper part of the matrix, which isn't really the correct output.

Collaborator: Do you have some stride assumptions in the kernel, or is it stride-agnostic? If it is stride-agnostic, then the kernel could be run on the transposed variant.

Author: It assumes that the input is row-major (contiguous).

Author: I printed the data ptr inside the MPS function and outside in Python. Yields: the first one being the print from Python, the 2nd one being from just before launching the kernel in C++, and the 3rd one being again from Python. So yeah, confirmed.

Collaborator: As per https://github.com/pytorch/pytorch/pull/146799/files#r1952464144, this is expected. Sorry for the confusion. But we should have issues when

Author: No issues from what I check:

Collaborator: @Isalia20, could you remove

Author: Ah I see the issue now:
```diff
 }
 } // namespace mps
```
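To ground the tril_()/transpose_ discussion above: for a real symmetric positive-definite matrix, the upper Cholesky factor is the transpose of the lower one, which is the mathematical fact the kernel's post-processing relies on. This is a CPU-only sketch of that relationship (it does not address the stride concern raised in the thread):

```python
import torch

# CPU sketch, real-valued input: the upper factor equals the transpose of the
# lower factor, so computing L and transposing gives a valid U with U^T U == A.
A = torch.randn(5, 5)
A = A @ A.T + 5 * torch.eye(5)                 # make A symmetric positive-definite
L = torch.linalg.cholesky(A, upper=False)
U = torch.linalg.cholesky(A, upper=True)
print(torch.allclose(U, L.T, atol=1e-6))       # True
print(torch.allclose(U.T @ U, A, atol=1e-5))   # True
```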
```diff
@@ -1285,21 +1316,19 @@ Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, cons
 
 Tensor cholesky_mps(const Tensor& self, bool upper) {
   auto out = at::empty_like(self, MemoryFormat::Contiguous);
-  mps::linalg_cholesky_mps_impl(self, upper, out);
+  cholesky_mps_out(self, upper, out);
   return out;
 }
 
 Tensor& cholesky_mps_out(const Tensor& self, bool upper, Tensor& out) {
-  return mps::linalg_cholesky_mps_impl(self, upper, out);
-}
-
-Tensor& linalg_cholesky_out_mps(const Tensor& self, bool upper, Tensor& out) {
-  return mps::linalg_cholesky_mps_impl(self, upper, out);
+  auto info = at::empty({}, self.options().dtype(kInt));
+  mps::linalg_cholesky_mps_impl(self, upper, true, out, info);
+  return out;
 }
 
-Tensor linalg_cholesky_mps(const Tensor& self, bool upper) {
-  auto out = at::empty_like(self, MemoryFormat::Contiguous);
-  return mps::linalg_cholesky_mps_impl(self, upper, out);
+TORCH_IMPL_FUNC(linalg_cholesky_ex_out_mps)
+(const Tensor& self, bool upper, bool check_errors, const Tensor& L, const Tensor& info) {
+  mps::linalg_cholesky_mps_impl(self, upper, check_errors, L, info);
 }
 
 Tensor addbmm_mps(const Tensor& self,
```
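For orientation (not part of the diff): at the Python level, torch.cholesky is the legacy op backed by cholesky_mps / cholesky_mps_out above, while torch.linalg.cholesky is, to the best of my understanding, a composite that routes through the _ex variant this PR moves the MPS backend onto. A small CPU sketch of how the entry points relate for a positive-definite input:

```python
import torch

# CPU sketch of how the user-facing entry points relate.
A = torch.randn(4, 4)
A = A @ A.T + 4 * torch.eye(4)                 # make A positive-definite
L1 = torch.linalg.cholesky(A)                  # raises on non-positive-definite input
L2, info = torch.linalg.cholesky_ex(A)         # only raises if check_errors=True
print(torch.allclose(L1, L2))                  # True
print(info.item())                             # 0 -> factorization succeeded
```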
```diff
@@ -699,7 +699,6 @@ def mps_ops_modifier(ops):
         'index_reduceamin': None,
         'kthvalue': None,
         'lcm': None,
-        'linalg.cholesky_ex': None,
         'linalg.cond': None,
         'linalg.eigh': None,
         'linalg.eigvalsh': None,
```
```diff
@@ -6525,14 +6524,23 @@ def test_sort(self):
                 atol=0, rtol=0
             )
 
-    def test_cholesky(self):
+    def test_linalg_cholesky(self):
         from torch.testing._internal.common_utils import random_hermitian_pd_matrix
 
-        def run_cholesky_test(size, *batch_dims, upper):
+        def run_cholesky_test(size, *batch_dims, upper=False, check_errors=False):
+            if check_errors:
+                # expect failure for non-positive definite matrix
+                input_mps = torch.eye(size, dtype=torch.float32, device="mps")
+                input_mps[0, 0] = -1
+                error_msg = r'The factorization could not be completed because the input is not positive-definite'
+                with self.assertRaisesRegex(RuntimeError, error_msg):
+                    torch.linalg.cholesky_ex(input_mps, upper=upper, check_errors=check_errors)
+                return
+            # output checks for positive definite matrix
             input_cpu = random_hermitian_pd_matrix(size, *batch_dims, dtype=torch.float32, device="cpu")
             input_mps = input_cpu.to('mps')
-            output_cpu = torch.linalg.cholesky(input_cpu, upper=upper)
-            output_mps = torch.linalg.cholesky(input_mps, upper=upper)
+            output_cpu = torch.linalg.cholesky_ex(input_cpu, upper=upper)
+            output_mps = torch.linalg.cholesky_ex(input_mps, upper=upper)
```
Comment on lines +6542 to +6543

Collaborator: Let us also check that

Author: output_cpu and output_mps is a tuple of L and info tensors, so assertEqual is comparing both of them. Do you mean to add a separate test where info might be > 1?

Collaborator: Yes, when erroring on non-psd inputs :)

Author: I'll do it a bit later today and also adapt the error message.

Author: Added better error message.
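A CPU sketch (not from the PR) of the behaviour under discussion: with check_errors=False the call never raises and the per-matrix info entries flag the non-positive-definite inputs, while check_errors=True turns a non-zero entry into an error:

```python
import torch

# Batched example: one positive-definite matrix and one that is not.
A = torch.stack([
    torch.eye(3),                                 # positive-definite
    torch.diag(torch.tensor([1.0, -1.0, 1.0])),   # not positive-definite
])
L, info = torch.linalg.cholesky_ex(A, check_errors=False)
print(info)   # tensor([0, 2], dtype=torch.int32): second matrix fails at order 2
# With check_errors=True the same call raises, naming the failing batch element:
# torch.linalg.cholesky_ex(A, check_errors=True)
```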
```diff
 
             self.assertEqual(output_cpu, output_mps, atol=2e-5, rtol=1e-6)
 
         # test with different even/odd matrix sizes
```
```diff
@@ -6548,6 +6556,18 @@ def run_cholesky_test(size, *batch_dims, upper):
         # test >3D matrices
         run_cholesky_test(128, 10, 10, upper=False)
         run_cholesky_test(128, 2, 2, 2, 2, 10, 10, upper=True)
+        run_cholesky_test(32, 2, upper=False, check_errors=True)
+        run_cholesky_test(32, 2, upper=True, check_errors=True)
+
+    def test_linalg_cholesky_info(self):
+        # non psd matrix with leading minor of order 2 being not positive definite
+        A = torch.tensor([
+            [4.0, 1.0, 0.0],
+            [1.0, -2.0, 1.0],
+            [0.0, 1.0, 3.0]
+        ], device="mps")
+        with self.assertRaisesRegex(RuntimeError, r'leading minor of order 2 is not positive-definite'):
+            torch.linalg.cholesky_ex(A, check_errors=True)
 
     def test_upsample_nearest2d(self):
         def helper(N, C, H, W, memory_format):
```
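As a sanity check (not part of the PR itself), the expected info value for the matrix used in test_linalg_cholesky_info can be reproduced on CPU: its 2x2 leading minor has negative determinant, so the factorization fails at order 2.

```python
import torch

# The matrix from test_linalg_cholesky_info, evaluated on CPU.
A = torch.tensor([
    [4.0, 1.0, 0.0],
    [1.0, -2.0, 1.0],
    [0.0, 1.0, 3.0],
])
for k in range(1, 4):
    print(k, torch.linalg.det(A[:k, :k]).item())
# 1 4.0
# 2 -9.0    <- leading minor of order 2 is not positive-definite
# 3 -31.0
L, info = torch.linalg.cholesky_ex(A, check_errors=False)
print(info.item())   # 2
```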
Collaborator: Why is MPS different?

Author: The MPS kernel assumes row-major layout for the matrix where it does the decomposition.

Collaborator: Can the kernel be made row-major/col-major agnostic so as to preserve consistency across backends?

Author: I'll take a look ~next week to see if I can make it work for col-major so we don't need to make it row-major for MPS only, but why do we want to preserve consistency across backends? A lot of ops on MPS use a row-major layout and require a contiguous call before passing the tensor to an MPS kernel.

Collaborator: In linalg, LAPACK seems to be the source of truth, and it is written in Fortran where col-major is the standard layout :(

Collaborator: I believe we can re-use the kernel without that much code change (i.e. no need to make it stride-agnostic for now). In the Meta function we request C-contiguous when upper=False and F-contiguous when upper=True for MPS. Then we only need to remove the line upper ? out.transpose_(...) : out (and probably replace it with out.tril_() : out.triu_()). Or something along these lines. Should resolve the issue with out for now, before the kernel is adapted for better memory accesses in column-major mode...

Author: I've tried it, but I'm afraid it doesn't work. I'll address this in the follow-up PR with the kernel change for column-major mode rather than going down the rabbit hole now for a temporary fix.
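A small CPU sketch of the layout point discussed above: a row-major (C-contiguous) matrix and its transpose share the same storage, and the transposed view is what a column-major (Fortran/LAPACK-style) consumer sees; an in-place transpose_ only swaps strides.

```python
import torch

# Row-major vs column-major: same buffer, different strides.
A = torch.arange(6.0).reshape(2, 3)    # C-contiguous
print(A.stride())                       # (3, 1): elements of a row are adjacent in memory
At = A.t()                              # no copy, just swapped strides
print(At.stride())                      # (1, 3): elements of a column are adjacent
print(At.is_contiguous())               # False -- not row-major
print(A.data_ptr() == At.data_ptr())    # True  -- transpose_() changes strides, not storage
```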