
Commit ffd0003

IvanYashchuk authored and facebook-github-bot committed
Added support for complex input for torch.lu_solve (#46862)
Summary: `torch.lu_solve` now works for complex inputs both on CPU and GPU. I moved the existing tests to `test_linalg.py` and modified them to test complex dtypes, but I didn't modify/improve the body of the tests. Ref. #33152

Pull Request resolved: #46862
Reviewed By: nikithamalgifb
Differential Revision: D24543682
Pulled By: anjali411
fbshipit-source-id: 165bde39ef95cafebf976c5ba4b487297efe8433
1 parent 2ed3430 commit ffd0003
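For readers of this commit, a minimal usage sketch of the path described in the summary (not part of the PR; tensor sizes are illustrative):

import torch

# LU-factorize a complex matrix once and reuse the factors to solve A x = b.
A = torch.randn(3, 3, dtype=torch.complex128)
b = torch.randn(3, 2, dtype=torch.complex128)
LU_data, LU_pivots = torch.lu(A)           # partially pivoted LU factorization
x = torch.lu_solve(b, LU_data, LU_pivots)  # complex dtypes now supported
print(torch.allclose(A @ x, b))            # the solution satisfies A x = b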

File tree

5 files changed: +186 -110 lines changed

aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 1 addition & 1 deletion
@@ -1119,7 +1119,7 @@ Tensor _lu_solve_helper_cpu(const Tensor& self, const Tensor& LU_data, const Ten
   if (self.numel() == 0 || LU_data.numel() == 0) {
     return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   }
-  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lu_solve_cpu", [&]{
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "lu_solve_cpu", [&]{
     apply_lu_solve<scalar_t>(self_working_copy, LU_data_working_copy, LU_pivots_working_copy, infos);
   });
   if (self.dim() > 2) {

aten/src/ATen/native/cuda/BatchLinearAlgebra.cu

Lines changed: 36 additions & 1 deletion
@@ -998,6 +998,23 @@ void magmaLuSolve<float>(
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
+template<>
+void magmaLuSolve<c10::complex<double>>(
+    magma_int_t n, magma_int_t nrhs, c10::complex<double>* dA, magma_int_t ldda, magma_int_t* ipiv,
+    c10::complex<double>* dB, magma_int_t lddb, magma_int_t* info) {
+  MagmaStreamSyncGuard guard;
+  magma_zgetrs_gpu(MagmaNoTrans, n, nrhs, reinterpret_cast<magmaDoubleComplex*>(dA), ldda, ipiv, reinterpret_cast<magmaDoubleComplex*>(dB), lddb, info);
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
+template<>
+void magmaLuSolve<c10::complex<float>>(
+    magma_int_t n, magma_int_t nrhs, c10::complex<float>* dA, magma_int_t ldda, magma_int_t* ipiv,
+    c10::complex<float>* dB, magma_int_t lddb, magma_int_t* info) {
+  MagmaStreamSyncGuard guard;
+  magma_cgetrs_gpu(MagmaNoTrans, n, nrhs, reinterpret_cast<magmaFloatComplex*>(dA), ldda, ipiv, reinterpret_cast<magmaFloatComplex*>(dB), lddb, info);
+  AT_CUDA_CHECK(cudaGetLastError());
+}
 
 template<>
 void magmaLuSolveBatched<double>(

@@ -1016,6 +1033,24 @@ void magmaLuSolveBatched<float>(
   info = magma_sgetrs_batched(MagmaNoTrans, n, nrhs, dA_array, ldda, dipiv_array, dB_array, lddb, batchsize, magma_queue.get_queue());
   AT_CUDA_CHECK(cudaGetLastError());
 }
+
+template<>
+void magmaLuSolveBatched<c10::complex<double>>(
+    magma_int_t n, magma_int_t nrhs, c10::complex<double>** dA_array, magma_int_t ldda, magma_int_t** dipiv_array,
+    c10::complex<double>** dB_array, magma_int_t lddb, magma_int_t& info,
+    magma_int_t batchsize, const MAGMAQueue& magma_queue) {
+  info = magma_zgetrs_batched(MagmaNoTrans, n, nrhs, reinterpret_cast<magmaDoubleComplex**>(dA_array), ldda, dipiv_array, reinterpret_cast<magmaDoubleComplex**>(dB_array), lddb, batchsize, magma_queue.get_queue());
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
+template<>
+void magmaLuSolveBatched<c10::complex<float>>(
+    magma_int_t n, magma_int_t nrhs, c10::complex<float>** dA_array, magma_int_t ldda, magma_int_t** dipiv_array,
+    c10::complex<float>** dB_array, magma_int_t lddb, magma_int_t& info,
+    magma_int_t batchsize, const MAGMAQueue& magma_queue) {
+  info = magma_cgetrs_batched(MagmaNoTrans, n, nrhs, reinterpret_cast<magmaFloatComplex**>(dA_array), ldda, dipiv_array, reinterpret_cast<magmaFloatComplex**>(dB_array), lddb, batchsize, magma_queue.get_queue());
+  AT_CUDA_CHECK(cudaGetLastError());
+}
 #endif
 
 #define ALLOCATE_ARRAY(name, type, size) \

@@ -1986,7 +2021,7 @@ Tensor _lu_solve_helper_cuda(const Tensor& self, const Tensor& LU_data, const Te
   if (self.numel() == 0 || LU_data.numel() == 0) {
     return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   }
-  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lu_solve_cuda", [&]{
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "lu_solve_cuda", [&]{
     apply_lu_solve<scalar_t>(self_working_copy, LU_data_working_copy, LU_pivots_working_copy, info);
   });
   TORCH_CHECK(info == 0, "MAGMA lu_solve : invalid argument: ", -info);
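A hedged sketch of exercising the GPU path wired up above. It assumes a CUDA build of PyTorch with MAGMA; the residual check is done on CPU because, per the test comments below, complex batched matmul and norm were not yet available on CUDA when this PR landed:

import torch

if torch.cuda.is_available():
    A = torch.randn(4, 3, 3, dtype=torch.complex64, device='cuda')  # batch of complex matrices
    b = torch.randn(4, 3, 5, dtype=torch.complex64, device='cuda')
    LU_data, LU_pivots = torch.lu(A)
    # For a complex64 batch this should reach the magma_cgetrs_batched wrapper above.
    x = torch.lu_solve(b, LU_data, LU_pivots)
    # Residual check on CPU; atol mirrors the single-precision tolerance used by the tests.
    print(torch.allclose(torch.matmul(A.cpu(), x.cpu()), b.cpu(), atol=1e-3))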

test/test_linalg.py

Lines changed: 147 additions & 0 deletions
@@ -286,6 +286,153 @@ def test_kron_errors_and_warnings(self, device, dtype):
         with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"):
             torch.kron(a, b, out=out)
 
+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
+    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+    def test_lu_solve_batched_non_contiguous(self, device, dtype):
+        from numpy.linalg import solve
+        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
+
+        A = random_fullrank_matrix_distinct_singular_value(2, 2, dtype=dtype, device='cpu')
+        b = torch.randn(2, 2, 2, dtype=dtype, device='cpu')
+        x_exp = torch.as_tensor(solve(A.permute(0, 2, 1).numpy(), b.permute(2, 1, 0).numpy())).to(device)
+        A = A.to(device).permute(0, 2, 1)
+        b = b.to(device).permute(2, 1, 0)
+        assert not A.is_contiguous() and not b.is_contiguous(), "contiguous inputs"
+        LU_data, LU_pivots = torch.lu(A)
+        x = torch.lu_solve(b, LU_data, LU_pivots)
+        self.assertEqual(x, x_exp)
+
+    def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype):
+        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
+
+        b = torch.randn(*b_dims, dtype=dtype, device=device)
+        A = random_fullrank_matrix_distinct_singular_value(*A_dims, dtype=dtype).to(device)
+        LU_data, LU_pivots, info = torch.lu(A, get_infos=True, pivot=pivot)
+        self.assertEqual(info, torch.zeros_like(info))
+        return b, A, LU_data, LU_pivots
+
+    @skipCPUIfNoLapack
+    @skipCUDAIfNoMagma
+    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
+                        torch.float64: 1e-8, torch.complex128: 1e-8})
+    def test_lu_solve(self, device, dtype):
+        def sub_test(pivot):
+            for k, n in zip([2, 3, 5], [3, 5, 7]):
+                b, A, LU_data, LU_pivots = self.lu_solve_test_helper((n,), (n, k), pivot, device, dtype)
+                x = torch.lu_solve(b, LU_data, LU_pivots)
+                # TODO(@ivanyashchuk): remove this once 'norm_cuda' is available for complex dtypes
+                if not self.device_type == 'cuda' and not dtype.is_complex:
+                    self.assertLessEqual(abs(b.dist(A.mm(x), p=1)), self.precision)
+                self.assertEqual(b, A.mm(x))
+
+        sub_test(True)
+        if self.device_type == 'cuda':
+            sub_test(False)
+
+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
+    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
+                        torch.float64: 1e-8, torch.complex128: 1e-8})
+    def test_lu_solve_batched(self, device, dtype):
+        def sub_test(pivot):
+            def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
+                b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, pivot, device, dtype)
+                x_exp_list = []
+                for i in range(b_dims[0]):
+                    x_exp_list.append(torch.lu_solve(b[i], LU_data[i], LU_pivots[i]))
+                x_exp = torch.stack(x_exp_list)  # Stacked output
+                x_act = torch.lu_solve(b, LU_data, LU_pivots)  # Actual output
+                self.assertEqual(x_exp, x_act)  # Equality check
+                # TODO(@ivanyashchuk): remove this once batched matmul is available on CUDA for complex dtypes
+                if self.device_type == 'cuda' and dtype.is_complex:
+                    Ax_list = []
+                    for A_i, x_i in zip(A, x_act):
+                        Ax_list.append(torch.matmul(A_i, x_i))
+                    Ax = torch.stack(Ax_list)
+                else:
+                    Ax = torch.matmul(A, x_act)
+                self.assertLessEqual(abs(b.dist(Ax, p=1)), self.precision)  # Correctness check
+                # In addition to the norm, check the individual entries
+                # 'norm_cuda' is not implemented for complex dtypes
+                self.assertEqual(b, Ax)
+
+            for batchsize in [1, 3, 4]:
+                lu_solve_batch_test_helper((5, batchsize), (batchsize, 5, 10), pivot)
+
+            # Tests tensors with 0 elements
+            b = torch.randn(3, 0, 3, dtype=dtype, device=device)
+            A = torch.randn(3, 0, 0, dtype=dtype, device=device)
+            LU_data, LU_pivots = torch.lu(A)
+            self.assertEqual(torch.empty_like(b), b.lu_solve(LU_data, LU_pivots))
+
+        sub_test(True)
+        if self.device_type == 'cuda':
+            sub_test(False)
+
+    @slowTest
+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
+    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+    def test_lu_solve_batched_many_batches(self, device, dtype):
+        def run_test(A_dims, b_dims):
+            b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
+            x = torch.lu_solve(b, LU_data, LU_pivots)
+            # TODO(@ivanyashchuk): remove this once batched matmul is available on CUDA for complex dtypes
+            if self.device_type == 'cuda' and dtype.is_complex:
+                Ax_list = []
+                for A_i, x_i in zip(A, x):
+                    Ax_list.append(torch.matmul(A_i, x_i))
+                Ax = torch.stack(Ax_list)
+            else:
+                Ax = torch.matmul(A, x)
+            self.assertEqual(Ax, b.expand_as(Ax))
+
+        run_test((5, 65536), (65536, 5, 10))
+        run_test((5, 262144), (262144, 5, 10))
+
+    # TODO: once there is more support for complex dtypes on GPU, the above tests should be updated,
+    # particularly when RuntimeError: _th_bmm_out not supported on CUDAType for ComplexFloat
+    # and RuntimeError: "norm_cuda" not implemented for 'ComplexFloat' are fixed
+    @unittest.expectedFailure
+    @onlyCUDA
+    @skipCUDAIfNoMagma
+    @dtypes(torch.complex64, torch.complex128)
+    def test_lu_solve_batched_complex_xfailed(self, device, dtype):
+        A_dims = (3, 5)
+        b_dims = (5, 3, 2)
+        b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
+        x = torch.lu_solve(b, LU_data, LU_pivots)
+        b_ = torch.matmul(A, x)
+        self.assertEqual(b_, b.expand_as(b_))
+        self.assertLessEqual(abs(b.dist(torch.matmul(A, x), p=1)), 1e-4)
+
+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
+    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+    def test_lu_solve_batched_broadcasting(self, device, dtype):
+        from numpy.linalg import solve
+        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
+
+        def run_test(A_dims, b_dims, pivot=True):
+            A_matrix_size = A_dims[-1]
+            A_batch_dims = A_dims[:-2]
+            A = random_fullrank_matrix_distinct_singular_value(A_matrix_size, *A_batch_dims, dtype=dtype)
+            b = torch.randn(*b_dims, dtype=dtype)
+            x_exp = torch.as_tensor(solve(A.numpy(), b.numpy())).to(dtype=dtype, device=device)
+            A, b = A.to(device), b.to(device)
+            LU_data, LU_pivots = torch.lu(A, pivot=pivot)
+            x = torch.lu_solve(b, LU_data, LU_pivots)
+            self.assertEqual(x, x_exp)
+
+        # test against numpy.linalg.solve
+        run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6))  # no broadcasting
+        run_test((2, 1, 3, 4, 4), (4, 6))  # broadcasting b
+        run_test((4, 4), (2, 1, 3, 4, 2))  # broadcasting A
+        run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5))  # broadcasting A & b
+
     # This test confirms that torch.linalg.norm's dtype argument works
     # as expected, according to the function's documentation
     @skipCUDAIfNoMagma

test/test_torch.py

Lines changed: 0 additions & 108 deletions
@@ -9145,114 +9145,6 @@ def test_kthvalue(self, device, dtype):
         x = torch.tensor([2], device=device, dtype=dtype)
         self.assertEqual(x.squeeze().kthvalue(1), x.kthvalue(1))
 
-    @skipCUDAIfNoMagma
-    @skipCPUIfNoLapack
-    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
-    @dtypes(torch.double)
-    def test_lu_solve_batched_non_contiguous(self, device, dtype):
-        from numpy.linalg import solve
-        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
-
-        A = random_fullrank_matrix_distinct_singular_value(2, 2, dtype=dtype, device='cpu')
-        b = torch.randn(2, 2, 2, dtype=dtype, device='cpu')
-        x_exp = torch.as_tensor(solve(A.permute(0, 2, 1).numpy(), b.permute(2, 1, 0).numpy())).to(device)
-        A = A.to(device).permute(0, 2, 1)
-        b = b.to(device).permute(2, 1, 0)
-        assert not A.is_contiguous() and not b.is_contiguous(), "contiguous inputs"
-        LU_data, LU_pivots = torch.lu(A)
-        x = torch.lu_solve(b, LU_data, LU_pivots)
-        self.assertEqual(x, x_exp)
-
-    def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype):
-        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
-
-        b = torch.randn(*b_dims, dtype=dtype, device=device)
-        A = random_fullrank_matrix_distinct_singular_value(*A_dims, dtype=dtype, device=device)
-        LU_data, LU_pivots, info = torch.lu(A, get_infos=True, pivot=pivot)
-        self.assertEqual(info, torch.zeros_like(info))
-        return b, A, LU_data, LU_pivots
-
-    @skipCPUIfNoLapack
-    @skipCUDAIfNoMagma
-    @dtypes(torch.double)
-    def test_lu_solve(self, device, dtype):
-        def sub_test(pivot):
-            for k, n in zip([2, 3, 5], [3, 5, 7]):
-                b, A, LU_data, LU_pivots = self.lu_solve_test_helper((n,), (n, k), pivot, device, dtype)
-                x = torch.lu_solve(b, LU_data, LU_pivots)
-                self.assertLessEqual(b.dist(A.mm(x)), 1e-12)
-
-        sub_test(True)
-        if self.device_type == 'cuda':
-            sub_test(False)
-
-    @skipCUDAIfNoMagma
-    @skipCPUIfNoLapack
-    @dtypes(torch.double)
-    def test_lu_solve_batched(self, device, dtype):
-        def sub_test(pivot):
-            def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
-                b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, pivot, device, dtype)
-                x_exp_list = []
-                for i in range(b_dims[0]):
-                    x_exp_list.append(torch.lu_solve(b[i], LU_data[i], LU_pivots[i]))
-                x_exp = torch.stack(x_exp_list)  # Stacked output
-                x_act = torch.lu_solve(b, LU_data, LU_pivots)  # Actual output
-                self.assertEqual(x_exp, x_act)  # Equality check
-                self.assertLessEqual(b.dist(torch.matmul(A, x_act)), 1e-12)  # Correctness check
-
-            for batchsize in [1, 3, 4]:
-                lu_solve_batch_test_helper((5, batchsize), (batchsize, 5, 10), pivot)
-
-            # Tests tensors with 0 elements
-            b = torch.randn(3, 0, 3, dtype=dtype, device=device)
-            A = torch.randn(3, 0, 0, dtype=dtype, device=device)
-            LU_data, LU_pivots = torch.lu(A)
-            self.assertEqual(torch.empty_like(b), b.lu_solve(LU_data, LU_pivots))
-
-        sub_test(True)
-        if self.device_type == 'cuda':
-            sub_test(False)
-
-    @slowTest
-    @skipCUDAIfNoMagma
-    @skipCPUIfNoLapack
-    @dtypes(torch.double)
-    def test_lu_solve_batched_many_batches(self, device, dtype):
-        def run_test(A_dims, b_dims):
-            b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
-            x = torch.lu_solve(b, LU_data, LU_pivots)
-            b_ = torch.matmul(A, x)
-            self.assertEqual(b_, b.expand_as(b_))
-
-        run_test((5, 65536), (65536, 5, 10))
-        run_test((5, 262144), (262144, 5, 10))
-
-    @skipCUDAIfNoMagma
-    @skipCPUIfNoLapack
-    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
-    @dtypes(torch.double)
-    def test_lu_solve_batched_broadcasting(self, device, dtype):
-        from numpy.linalg import solve
-        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
-
-        def run_test(A_dims, b_dims, pivot=True):
-            A_matrix_size = A_dims[-1]
-            A_batch_dims = A_dims[:-2]
-            A = random_fullrank_matrix_distinct_singular_value(A_matrix_size, *A_batch_dims, dtype=dtype)
-            b = torch.randn(*b_dims, dtype=dtype)
-            x_exp = torch.as_tensor(solve(A.numpy(), b.numpy())).to(dtype=dtype, device=device)
-            A, b = A.to(device), b.to(device)
-            LU_data, LU_pivots = torch.lu(A, pivot=pivot)
-            x = torch.lu_solve(b, LU_data, LU_pivots)
-            self.assertEqual(x, x_exp)
-
-        # test against numpy.linalg.solve
-        run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6))  # no broadcasting
-        run_test((2, 1, 3, 4, 4), (4, 6))  # broadcasting b
-        run_test((4, 4), (2, 1, 3, 4, 2))  # broadcasting A
-        run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5))  # broadcasting A & b
-
     # Assert for illegal dtype would not be raised on XLA
     @onlyOnCPUAndCUDA
     def test_minmax_illegal_dtype(self, device):

torch/_torch_docs.py

Lines changed: 2 additions & 0 deletions
@@ -4478,6 +4478,8 @@ def merge_dicts(*dicts):
 Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted
 LU factorization of A from :meth:`torch.lu`.
 
+Supports real-valued and complex-valued inputs.
+
 Arguments:
     b (Tensor): the RHS tensor of size :math:`(*, m, k)`, where :math:`*`
                 is zero or more batch dimensions.
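To illustrate the documented shapes, a short sketch (not from the PR) where the RHS `b` of size `(*, m, k)` carries batch dimensions and is solved against a single factorized matrix, relying on the broadcasting exercised by test_lu_solve_batched_broadcasting:

import torch

A = torch.randn(4, 4, dtype=torch.complex64)        # single matrix of size (m, m)
b = torch.randn(2, 3, 4, 5, dtype=torch.complex64)  # RHS of size (*, m, k) with batch dims (2, 3)
LU_data, LU_pivots = torch.lu(A)
x = torch.lu_solve(b, LU_data, LU_pivots)            # LU factors broadcast over the batch dims of b
print(x.shape)                                       # torch.Size([2, 3, 4, 5])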
