2 changes: 1 addition & 1 deletion aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -685,7 +685,7 @@ TORCH_META_FUNC(linalg_cholesky_ex)(const Tensor& A,
auto ndim = A_shape.size();

// L
auto L_strides = at::native::batched_matrix_contiguous_strides(A_shape, /*f-contig*=*/true);
auto L_strides = at::native::batched_matrix_contiguous_strides(A_shape, /*f-contig*=*/A.device().type() != at::kMPS);

Collaborator:
Why is MPS different?

Isalia20 (Author), Feb 12, 2025:

The MPS kernel assumes a row-major layout for the matrix it decomposes.
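
For context, a rough sketch of what the f-contig flag means for the last two dimensions of a batched matrix; this is an illustrative re-implementation, not the actual at::native::batched_matrix_contiguous_strides helper:

import torch

# Illustrative sketch (not the ATen helper): contiguous strides for a batched
# matrix whose last two dims are either column-major (f_contig) or row-major.
def batched_matrix_strides_sketch(shape, f_contig):
    *batch, m, n = shape
    # Column-major (Fortran) matrix strides are (1, m); row-major are (n, 1).
    mat_strides = (1, m) if f_contig else (n, 1)
    batch_strides = []
    step = m * n
    for b in reversed(batch):
        batch_strides.append(step)
        step *= b
    return tuple(reversed(batch_strides)) + mat_strides

shape = (4, 3, 3)
print(batched_matrix_strides_sketch(shape, f_contig=True))   # (9, 1, 3): LAPACK-style, CPU/CUDA
print(batched_matrix_strides_sketch(shape, f_contig=False))  # (9, 3, 1): row-major, what the MPS kernel expects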

Collaborator:

Can the kernel be made row-major/col-major agnostic, so as to preserve consistency across backends?

Isalia20 (Author):

I'll take a look ~next week to see if I can make it work for col-major so we don't need to make it row-major for MPS only. But why do we want to preserve consistency across backends? A lot of ops on MPS use a row-major layout and require a contiguous() call on the input before passing it to the MPS kernel.

Collaborator:

In linalg, LAPACK seems to be the source of truth, and it is written in Fortran, where column-major is the standard layout :(
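
For what it's worth, that convention is visible from Python on the CPU backend, consistent with the meta function above requesting F-contiguous output:

import torch

# Tiny check: the CPU factor comes back with column-major (Fortran) strides.
L = torch.linalg.cholesky(torch.eye(4))
print(L.stride())  # (1, 4)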

nikitaved (Collaborator), Feb 12, 2025:

I believe we can re-use the kernel without that much code change (i.e. no need to make it stride-agnostic for now). In the meta function we request C-contiguous when upper=False and F-contiguous when upper=True for MPS. Then we only need to remove the line upper ? out.transpose_(...) : out and probably replace it with upper ? out.triu_() : out.tril_(), or something along these lines. That should resolve the issue with out for now, before the kernel is adapted for better memory accesses in column-major mode...
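
A small sketch of the layout identity this suggestion seems to rely on (illustrative only, not the proposed implementation): a row-major buffer holding the lower factor, re-interpreted with column-major strides, reads back as the upper factor, so only a triangle cleanup would remain.

import torch

# Illustration: row-major lower factor L, seen through F-contiguous strides, is L.T.
A = torch.randn(4, 4)
A = A @ A.T + 4 * torch.eye(4)          # positive-definite input

L = torch.linalg.cholesky(A)            # lower factor
buf = L.contiguous().reshape(-1)        # row-major buffer, as the kernel would write it

as_col_major = torch.as_strided(buf, size=(4, 4), stride=(1, 4))
torch.testing.assert_close(as_col_major, L.mT)   # same memory read column-major == upper factor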

Isalia20 (Author):

I've tried it, but I'm afraid it doesn't work. I'll address this in a follow-up PR with the kernel change for column-major mode, rather than going down the rabbit hole now for a temporary fix.

set_output_strided(0, A_shape, L_strides, A.options(), {});

// info
4 changes: 2 additions & 2 deletions aten/src/ATen/native/mps/kernels/LinearAlgebra.metal
@@ -83,7 +83,7 @@ inline float blockReduceSum(

kernel void factorDiagonalBlock(
device float* A [[buffer(0)]],
device int* success [[buffer(1)]],
device int* info [[buffer(1)]],
constant uint& N [[buffer(2)]],
constant uint& NB [[buffer(3)]],
constant uint& k [[buffer(4)]],
@@ -142,7 +142,7 @@ kernel void factorDiagonalBlock(
if (linear_tid == 0) {
float diagVal = tile[kk][kk] - diagElt;
if (diagVal <= 0.0f) {
success[bid.x] = 0;
info[bid.x] = kk + 1;
return;
}
tile[kk][kk] = sqrt(diagVal);
65 changes: 47 additions & 18 deletions aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -20,6 +20,7 @@
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/cholesky_native.h>
#include <ATen/ops/linalg_cholesky_ex_native.h>
#include <ATen/ops/linalg_cholesky_native.h>
#include <ATen/ops/linalg_lu_factor_ex_native.h>
#include <ATen/ops/linalg_lu_factor_native.h>
@@ -1051,7 +1052,11 @@ static void lu_unpack_mps_impl(const Tensor& LU_data,
}
}

static Tensor& linalg_cholesky_mps_impl(const Tensor& input, bool upper, Tensor& out) {
static void linalg_cholesky_mps_impl(const Tensor& input,
bool upper,
bool check_errors,
const Tensor& out,
const Tensor& info) {
using namespace mps;

TORCH_CHECK(out.is_mps());
@@ -1061,9 +1066,11 @@ static void lu_unpack_mps_impl(const Tensor& LU_data,

if (input.numel() == 0 || out.numel() == 0) {
out.zero_();
return out;
return;
}
resize_output(out, input.sizes());
auto input_sizes = input.sizes();
resize_output(out, input_sizes);
resize_output(info, {input_sizes.begin(), input_sizes.end() - 2});
out.copy_(input);

int64_t ndim = out.dim();
@@ -1083,14 +1090,16 @@ static void lu_unpack_mps_impl(const Tensor& LU_data,
int64_t NB = std::min<int64_t>(32, N);
int64_t numBlocks = (N + NB - 1) / NB;

Tensor success = at::empty({B}, input.options().dtype(kInt)).fill_(1);
auto info_ = info.dim() >= 2 ? info.view({B}) : info;
auto info_sizes = info.sizes();
info_.fill_(0);

MTLSize threadGroupSize = MTLSizeMake(32, 8, 1);

@autoreleasepool {
dispatch_sync_with_rethrow(stream->queue(), ^() {
auto computeEncoder = stream->commandEncoder();
mtl_setArgs(computeEncoder, out, success, N, NB);
mtl_setArgs(computeEncoder, out, info_, N, NB);
for (int64_t k = 0; k < numBlocks; k++) {
[computeEncoder setComputePipelineState:factorDiagonalPSO];
mtl_setBytes(computeEncoder, k, 4);
@@ -1118,10 +1127,32 @@ static void lu_unpack_mps_impl(const Tensor& LU_data,
}
});
}

TORCH_CHECK(success.all().item<bool>(), "linalg.cholesky: Input matrix is not positive definite");
out.tril_(); //
return upper ? out.transpose_(ndim - 2, ndim - 1) : out;
int status;
if (check_errors) {
if (info_.dim() > 0) {
// batch case
for (const auto i : c10::irange(B)) {
status = info_[i].item<int>();
TORCH_CHECK(
status == 0,
"linalg.cholesky(): (Batch element ",
i,
"): The factorization could not be completed because the input is not positive-definite (the leading minor of order ",
status,
" is not positive-definite).");
}
} else {
// single matrix case(no batch size)
status = info.item<int>();
TORCH_CHECK(
status == 0,
"linalg.cholesky(): The factorization could not be completed because the input is not positive-definite (the leading minor of order ",
status,
" is not positive-definite).");
}
}
out.tril_();
upper ? out.transpose_(ndim - 2, ndim - 1) : out;
Review thread on lines +1154 to +1155:

nikitaved (Collaborator), Feb 11, 2025:

This will silently alter the stride structure of out if upper == true. It would be better as upper ? out.triu_() : out.tril_().

Isalia20 (Author):

That's not the same. The kernel does the decomposition in the lower part of the matrix. If you do out.triu_() instead of out.tril_() -> transpose, you get the upper part of the matrix, which isn't really the correct output.
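
A small Python illustration of this point; the "kernel" here is emulated with a mask and is not the Metal code:

import torch

# The in-place kernel overwrites only the lower triangle of `out`, so the
# strictly-upper entries still hold the original input.
A = torch.randn(4, 4)
A = A @ A.T + 4 * torch.eye(4)

L = torch.linalg.cholesky(A)
out = A.clone()
mask = torch.tril(torch.ones(4, 4)).bool()
out[mask] = L[mask]                           # emulate: lower triangle <- factor, upper untouched

upper_ref = torch.linalg.cholesky(A, upper=True)

# Correct post-processing for upper=True: keep the lower triangle, then transpose.
torch.testing.assert_close(out.tril().mT, upper_ref)

# out.triu_() alone keeps leftover input entries above the diagonal instead:
print(torch.allclose(out.triu(), upper_ref))  # False in general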

nikitaved (Collaborator), Feb 12, 2025:

Do you have some stride assumptions in the kernel, or is it stride-agnostic? If it is stride-agnostic, then the kernel could be run on the transposed variant.

Isalia20 (Author):

It assumes that the input is row-major (contiguous).

nikitaved (Collaborator), Feb 12, 2025:

out can be provided externally as column-major. What would happen in this case?

Isalia20 (Author):

I printed the data pointer inside the MPS function and outside in Python:

import torch

out = torch.rand(3, 3, 3, device="mps").permute(2, 1, 0)

x = torch.rand(3, 3, 3, device="mps")
x = x.mT @ x

data_ptr = out.data_ptr()
print(f"0x{data_ptr:x}")  # lowercase hex
torch.linalg.cholesky(x, out=out)
print(f"0x{out.data_ptr():x}")

Yields:

0x10a4d68d0
0x10fb19150
0x10a4d68d0

The first one is printed from Python, the second from C++ before launching the kernel, and the third from Python again. So yeah, confirmed.

nikitaved (Collaborator), Feb 12, 2025:

As per https://github.com/pytorch/pytorch/pull/146799/files#r1952464144, this is expected. Sorry for the confusion. But it seems we should have issues when out is contiguous and upper=True.

Isalia20 (Author):

No issues from what I checked:

import torch

out = torch.rand(3, 3, 3, device="mps").permute(2, 1, 0)

x = torch.rand(3, 3, 3, device="mps")
x = x.mT @ x

data_ptr = out.data_ptr()
print(f"0x{data_ptr:x}")  # lowercase hex
print(out.stride())
print(out.is_contiguous())
res1 = torch.linalg.cholesky(x, out=out, upper=True)
res2 = torch.linalg.cholesky(x.cpu(), out=out.cpu(), upper=True)
print(f"0x{out.data_ptr():x}")
print(out.stride())
torch.testing.assert_close(res1.cpu(), res2)

Output:

0x113f70cc0
(1, 3, 9)
False
0x114f3a510
0x113f70cc0
(1, 3, 9)

nikitaved (Collaborator), Feb 12, 2025:

@Isalia20, could you remove the permute so that out is contiguous? In the meta function, as per your modification, out is re-used only if it is contiguous.

Isalia20 (Author):

Ah I see the issue now:

0x10bc7b840
(9, 3, 1)
True
0x10bc7b840
0x10bc7b840
(9, 1, 3)
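
Presumably the output above comes from the earlier repro with the permute removed, along these lines (a sketch, not the exact script):

import torch

# `out` starts out contiguous (row-major) here.
out = torch.rand(3, 3, 3, device="mps")

x = torch.rand(3, 3, 3, device="mps")
x = x.mT @ x

print(f"0x{out.data_ptr():x}")
print(out.stride())              # (9, 3, 1)
print(out.is_contiguous())       # True
res1 = torch.linalg.cholesky(x, out=out, upper=True)
print(f"0x{out.data_ptr():x}")   # same pointer: the contiguous out is re-used
print(out.stride())              # (9, 1, 3): the in-place transpose_ silently altered the strides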

}
} // namespace mps

@@ -1285,21 +1316,19 @@ Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, cons

Tensor cholesky_mps(const Tensor& self, bool upper) {
auto out = at::empty_like(self, MemoryFormat::Contiguous);
mps::linalg_cholesky_mps_impl(self, upper, out);
cholesky_mps_out(self, upper, out);
return out;
}

Tensor& cholesky_mps_out(const Tensor& self, bool upper, Tensor& out) {
return mps::linalg_cholesky_mps_impl(self, upper, out);
}

Tensor& linalg_cholesky_out_mps(const Tensor& self, bool upper, Tensor& out) {
return mps::linalg_cholesky_mps_impl(self, upper, out);
auto info = at::empty({}, self.options().dtype(kInt));
mps::linalg_cholesky_mps_impl(self, upper, true, out, info);
return out;
}

Tensor linalg_cholesky_mps(const Tensor& self, bool upper) {
auto out = at::empty_like(self, MemoryFormat::Contiguous);
return mps::linalg_cholesky_mps_impl(self, upper, out);
TORCH_IMPL_FUNC(linalg_cholesky_ex_out_mps)
(const Tensor& self, bool upper, bool check_errors, const Tensor& L, const Tensor& info) {
mps::linalg_cholesky_mps_impl(self, upper, check_errors, L, info);
}

Tensor addbmm_mps(const Tensor& self,
7 changes: 1 addition & 6 deletions aten/src/ATen/native/native_functions.yaml
@@ -13901,18 +13901,13 @@
structured: True
dispatch:
CPU, CUDA: linalg_cholesky_ex_out
MPS: linalg_cholesky_ex_out_mps

- func: linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor
python_module: linalg
dispatch:
CompositeImplicitAutograd: linalg_cholesky
MPS: linalg_cholesky_mps

- func: linalg_cholesky.out(Tensor self, *, bool upper=False, Tensor(a!) out) -> Tensor(a!)
python_module: linalg
dispatch:
CompositeImplicitAutograd: linalg_cholesky_out
MPS: linalg_cholesky_out_mps

- func: linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor
python_module: linalg
30 changes: 25 additions & 5 deletions test/test_mps.py
@@ -699,7 +699,6 @@ def mps_ops_modifier(ops):
'index_reduceamin': None,
'kthvalue': None,
'lcm': None,
'linalg.cholesky_ex': None,
'linalg.cond': None,
'linalg.eigh': None,
'linalg.eigvalsh': None,
@@ -6525,14 +6524,23 @@ def test_sort(self):
atol=0, rtol=0
)

def test_cholesky(self):
def test_linalg_cholesky(self):
from torch.testing._internal.common_utils import random_hermitian_pd_matrix

def run_cholesky_test(size, *batch_dims, upper):
def run_cholesky_test(size, *batch_dims, upper=False, check_errors=False):
if check_errors:
# expect failure for non-positive definite matrix
input_mps = torch.eye(size, dtype=torch.float32, device="mps")
input_mps[0, 0] = -1
error_msg = r'The factorization could not be completed because the input is not positive-definite'
with self.assertRaisesRegex(RuntimeError, error_msg):
torch.linalg.cholesky_ex(input_mps, upper=upper, check_errors=check_errors)
return
# output checks for positive definite matrix
input_cpu = random_hermitian_pd_matrix(size, *batch_dims, dtype=torch.float32, device="cpu")
input_mps = input_cpu.to('mps')
output_cpu = torch.linalg.cholesky(input_cpu, upper=upper)
output_mps = torch.linalg.cholesky(input_mps, upper=upper)
output_cpu = torch.linalg.cholesky_ex(input_cpu, upper=upper)
output_mps = torch.linalg.cholesky_ex(input_mps, upper=upper)
Review thread on lines +6542 to +6543:

Collaborator:

Let us also check that info is the same since its behavior is altered?

Isalia20 (Author):

output_cpu and output_mps are tuples of L and info tensors, so assertEqual compares both of them. Do you mean adding a separate test where info might be > 1?
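
For reference, a minimal illustration (CPU) of the tuple being compared:

import torch

# cholesky_ex returns a named tuple (L, info), so comparing the whole return
# value covers both fields.
res = torch.linalg.cholesky_ex(torch.eye(3))
print(res.L)     # identity: the Cholesky factor of I
print(res.info)  # tensor(0, dtype=torch.int32): 0 means success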

Collaborator:

Yes, when erroring on non-psd inputs :)

Isalia20 (Author):

I'll do it a bit later today and also adapt the error message

Isalia20 (Author):

Added a better error message.

self.assertEqual(output_cpu, output_mps, atol=2e-5, rtol=1e-6)

# test with different even/odd matrix sizes
Expand All @@ -6548,6 +6556,18 @@ def run_cholesky_test(size, *batch_dims, upper):
# test >3D matrices
run_cholesky_test(128, 10, 10, upper=False)
run_cholesky_test(128, 2, 2, 2, 2, 10, 10, upper=True)
run_cholesky_test(32, 2, upper=False, check_errors=True)
run_cholesky_test(32, 2, upper=True, check_errors=True)

def test_linalg_cholesky_info(self):
# non psd matrix with leading minor of order 2 being not positive definite
A = torch.tensor([
[4.0, 1.0, 0.0],
[1.0, -2.0, 1.0],
[0.0, 1.0, 3.0]
], device="mps")
with self.assertRaisesRegex(RuntimeError, r'leading minor of order 2 is not positive-definite'):
torch.linalg.cholesky_ex(A, check_errors=True)

def test_upsample_nearest2d(self):
def helper(N, C, H, W, memory_format):
5 changes: 0 additions & 5 deletions tools/autograd/derivatives.yaml
@@ -410,11 +410,6 @@
self: cholesky_backward(grad, upper, L)
L: cholesky_jvp(self_t, L, upper)

# temporarily here before linalg_cholesky dispatches to linalg_cholesky_ex on MPS device
- name: linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor
self: cholesky_backward(grad, upper, result)
result: cholesky_jvp(self_t, result, upper)

- name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor
self, input2: cholesky_solve_backward(grad, self, input2, result, upper, grad_input_mask)
result: cholesky_solve_jvp(result, input2_p, input2_t, self_t, upper)