aten/src/ATen/native/mps/operations/LinearAlgebra.mm (2 changes: 1 addition & 1 deletion)
@@ -134,7 +134,7 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
"linalg.lu_factor(): MPS doesn't support complex types.");
TORCH_CHECK(pivot, "linalg.lu_factor(): MPS doesn't allow pivot == False.");

Tensor A_t = A;
Tensor A_t = A.contiguous();
uint64_t aRows = A_t.size(-2);
uint64_t aCols = A_t.size(-1);
uint64_t aElemSize = A_t.element_size();
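Note on the change above: `Tensor A_t = A;` becomes `Tensor A_t = A.contiguous();`, so the MPS LU-factorization path always works on a densely packed copy of the input. A transposed view such as `A.mT` shares storage with the original tensor and is not contiguous, which is presumably the case the new `.mT` tests below exercise. A minimal CPU-only sketch (illustration only, not part of the PR) of the kind of input this covers:

import torch

A = torch.randn(4, 4)
At = A.mT                                # transposed view: same storage, swapped strides
print(At.is_contiguous())                # False -- the layout the kernel now handles via .contiguous()
print(At.contiguous().is_contiguous())   # True -- .contiguous() materializes a packed row-major copy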
aten/src/ATen/native/native_functions.yaml (2 changes: 1 addition & 1 deletion)
@@ -13986,7 +13986,7 @@
- func: _linalg_det.result(Tensor A, *, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots)
structured: True
dispatch:
CPU, CUDA: _linalg_det_out
CPU, CUDA, MPS: _linalg_det_out

- func: linalg_det(Tensor A) -> Tensor
python_module: linalg
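With MPS added to the dispatch list for `_linalg_det.result`, the structured kernel that `torch.linalg.det` (and `torch.det`) is implemented on top of becomes available for MPS tensors. A hedged usage sketch, assuming a machine where the MPS backend is available (not part of the PR):

import torch

if torch.backends.mps.is_available():
    A_cpu = torch.randn(3, 4, 4)                         # a batch of square matrices
    A_mps = A_cpu.to("mps")
    det_cpu = torch.linalg.det(A_cpu)
    det_mps = torch.linalg.det(A_mps)                    # now routed to the MPS kernel
    print(torch.allclose(det_cpu, det_mps.cpu(), atol=1e-5))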
test/test_mps.py (39 changes: 38 additions & 1 deletion)
@@ -91,6 +91,7 @@ def mps_ops_grad_modifier(ops):
'lu': [torch.float16, torch.float32], # missing `aten::lu_unpack`.
'linalg.lu_factor': [torch.float16, torch.float32], # missing `aten::lu_unpack`.
'linalg.lu_factor_ex': [torch.float16, torch.float32], # missing `aten::lu_unpack`.
'linalg.det': [torch.float16, torch.float32], # missing aten::lu_solve.out
'aminmax': [torch.float32, torch.float16],
'special.i1': [torch.float16], # "i1_backward" not implemented for 'Half'

@@ -696,7 +697,6 @@ def mps_ops_modifier(ops):
'lcm': None,
'linalg.cholesky_ex': None,
'linalg.cond': None,
'linalg.det': None,
'linalg.eigh': None,
'linalg.eigvalsh': None,
'linalg.householder_product': None,
@@ -2728,6 +2728,10 @@ def run_lu_factor_ex_test(size, *batch_dims, check_errors):
out_mps = torch.linalg.lu_factor_ex(input_mps, check_errors=check_errors)
self.assertEqual(out_cpu, out_mps)

out_cpu = torch.linalg.lu_factor_ex(input_cpu.mT, check_errors=check_errors)
out_mps = torch.linalg.lu_factor_ex(input_mps.mT, check_errors=check_errors)
self.assertEqual(out_cpu, out_mps)

# test with different even/odd matrix sizes
matrix_sizes = [1, 2, 3, 4]
# even/odd batch sizes
@@ -2741,6 +2745,39 @@ def run_lu_factor_ex_test(size, *batch_dims, check_errors):
run_lu_factor_ex_test(32, 10, 10, check_errors=False)
run_lu_factor_ex_test(32, 2, 2, 2, 2, 10, 10, check_errors=True)

def test_linalg_det(self):
from torch.testing._internal.common_utils import make_fullrank_matrices_with_distinct_singular_values

make_fullrank = make_fullrank_matrices_with_distinct_singular_values
make_arg = partial(make_fullrank, device="cpu", dtype=torch.float32)

def run_det_test(size, *batch_dims):
input_cpu = make_arg(*batch_dims, size, size)
input_mps = input_cpu.to('mps')
out_cpu = torch.linalg.det(input_cpu)
out_mps = torch.linalg.det(input_mps)
self.assertEqual(out_cpu, out_mps)

# non-contiguous matrices
input_cpu_T = input_cpu.mT
input_mps_T = input_mps.mT
out_cpu_T = torch.linalg.det(input_cpu_T)
out_mps_T = torch.linalg.det(input_mps_T)
self.assertEqual(out_cpu_T, out_mps_T)

# test with different even/odd matrix sizes
matrix_sizes = [2, 3, 4]
# even/odd batch sizes
batch_sizes = [1, 2, 4]

for size in matrix_sizes:
for batch_size in batch_sizes:
run_det_test(size, batch_size)

# test >3D matrices
run_det_test(32, 10, 10)
run_det_test(32, 2, 2, 10, 10)

def test_layer_norm(self):
# TODO: Test non-contiguous
def helper(input_shape, normalized_shape, eps=1e-05, elementwise_affine=True, dtype=torch.float32):
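A closing note on the new test: `make_fullrank_matrices_with_distinct_singular_values` is presumably used so that the determinants stay comfortably away from zero and the CPU and MPS results can be compared within default tolerances. A rough standalone approximation of such a generator (this is only a sketch of the idea, not how the internal helper is actually implemented):

import torch

def make_well_conditioned(*batch, n, dtype=torch.float32):
    A = torch.randn(*batch, n, n, dtype=dtype)
    Q, _ = torch.linalg.qr(A)                      # batch of orthogonal matrices
    s = torch.linspace(1.0, 2.0, n, dtype=dtype)   # distinct, well-separated singular values
    return Q * s                                   # scale columns; singular values become exactly s

M = make_well_conditioned(2, 3, n=4)
print(torch.linalg.det(M).abs())                   # every value is well away from zero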