@@ -286,6 +286,153 @@ def test_kron_errors_and_warnings(self, device, dtype):
         with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"):
             torch.kron(a, b, out=out)
 
+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
+    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+    def test_lu_solve_batched_non_contiguous(self, device, dtype):
+        from numpy.linalg import solve
+        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
+
+        A = random_fullrank_matrix_distinct_singular_value(2, 2, dtype=dtype, device='cpu')
+        b = torch.randn(2, 2, 2, dtype=dtype, device='cpu')
+        x_exp = torch.as_tensor(solve(A.permute(0, 2, 1).numpy(), b.permute(2, 1, 0).numpy())).to(device)
+        A = A.to(device).permute(0, 2, 1)
+        b = b.to(device).permute(2, 1, 0)
+        assert not A.is_contiguous() and not b.is_contiguous(), "contiguous inputs"
+        LU_data, LU_pivots = torch.lu(A)
+        x = torch.lu_solve(b, LU_data, LU_pivots)
+        self.assertEqual(x, x_exp)
+
+    def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype):
+        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
+
+        b = torch.randn(*b_dims, dtype=dtype, device=device)
+        A = random_fullrank_matrix_distinct_singular_value(*A_dims, dtype=dtype).to(device)
+        LU_data, LU_pivots, info = torch.lu(A, get_infos=True, pivot=pivot)
+        self.assertEqual(info, torch.zeros_like(info))
+        return b, A, LU_data, LU_pivots
+
+    @skipCPUIfNoLapack
+    @skipCUDAIfNoMagma
+    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
+                        torch.float64: 1e-8, torch.complex128: 1e-8})
+    def test_lu_solve(self, device, dtype):
+        def sub_test(pivot):
+            for k, n in zip([2, 3, 5], [3, 5, 7]):
+                b, A, LU_data, LU_pivots = self.lu_solve_test_helper((n,), (n, k), pivot, device, dtype)
+                x = torch.lu_solve(b, LU_data, LU_pivots)
+                # TODO(@ivanyashchuk): remove this once 'norm_cuda' is available for complex dtypes
+                if not (self.device_type == 'cuda' and dtype.is_complex):
+                    self.assertLessEqual(abs(b.dist(A.mm(x), p=1)), self.precision)
+                self.assertEqual(b, A.mm(x))
+
+        sub_test(True)
+        if self.device_type == 'cuda':
+            sub_test(False)
+
+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
+    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
+                        torch.float64: 1e-8, torch.complex128: 1e-8})
+    def test_lu_solve_batched(self, device, dtype):
+        def sub_test(pivot):
+            def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
+                b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, pivot, device, dtype)
+                x_exp_list = []
+                for i in range(b_dims[0]):
+                    x_exp_list.append(torch.lu_solve(b[i], LU_data[i], LU_pivots[i]))
+                x_exp = torch.stack(x_exp_list)  # Stacked output
+                x_act = torch.lu_solve(b, LU_data, LU_pivots)  # Actual output
+                self.assertEqual(x_exp, x_act)  # Equality check
+                # TODO(@ivanyashchuk): remove this once batched matmul is available on CUDA for complex dtypes
+                if self.device_type == 'cuda' and dtype.is_complex:
+                    Ax_list = []
+                    for A_i, x_i in zip(A, x_act):
+                        Ax_list.append(torch.matmul(A_i, x_i))
+                    Ax = torch.stack(Ax_list)
+                else:
+                    Ax = torch.matmul(A, x_act)
+                self.assertLessEqual(abs(b.dist(Ax, p=1)), self.precision)  # Correctness check
+                # In addition to the norm, check the individual entries
+                # 'norm_cuda' is not implemented for complex dtypes
+                self.assertEqual(b, Ax)
+
+            for batchsize in [1, 3, 4]:
+                lu_solve_batch_test_helper((5, batchsize), (batchsize, 5, 10), pivot)
+
+            # Tests tensors with 0 elements
+            b = torch.randn(3, 0, 3, dtype=dtype, device=device)
+            A = torch.randn(3, 0, 0, dtype=dtype, device=device)
+            LU_data, LU_pivots = torch.lu(A)
+            self.assertEqual(torch.empty_like(b), b.lu_solve(LU_data, LU_pivots))
+
+        sub_test(True)
+        if self.device_type == 'cuda':
+            sub_test(False)
+
+    @slowTest
+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
+    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+    def test_lu_solve_batched_many_batches(self, device, dtype):
+        def run_test(A_dims, b_dims):
+            b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
+            x = torch.lu_solve(b, LU_data, LU_pivots)
+            # TODO(@ivanyashchuk): remove this once batched matmul is available on CUDA for complex dtypes
+            if self.device_type == 'cuda' and dtype.is_complex:
+                Ax_list = []
+                for A_i, x_i in zip(A, x):
+                    Ax_list.append(torch.matmul(A_i, x_i))
+                Ax = torch.stack(Ax_list)
+            else:
+                Ax = torch.matmul(A, x)
+            self.assertEqual(Ax, b.expand_as(Ax))
+
+        run_test((5, 65536), (65536, 5, 10))
+        run_test((5, 262144), (262144, 5, 10))
+
+    # TODO: once there is more support for complex dtypes on GPU, the above tests should be updated,
+    # particularly when RuntimeError: _th_bmm_out not supported on CUDAType for ComplexFloat
+    # and RuntimeError: "norm_cuda" not implemented for 'ComplexFloat' are fixed
+    @unittest.expectedFailure
+    @onlyCUDA
+    @skipCUDAIfNoMagma
+    @dtypes(torch.complex64, torch.complex128)
+    def test_lu_solve_batched_complex_xfailed(self, device, dtype):
+        A_dims = (3, 5)
+        b_dims = (5, 3, 2)
+        b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
+        x = torch.lu_solve(b, LU_data, LU_pivots)
+        b_ = torch.matmul(A, x)
+        self.assertEqual(b_, b.expand_as(b_))
+        self.assertLessEqual(abs(b.dist(torch.matmul(A, x), p=1)), 1e-4)
+
+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
+    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+    def test_lu_solve_batched_broadcasting(self, device, dtype):
+        from numpy.linalg import solve
+        from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
+
+        def run_test(A_dims, b_dims, pivot=True):
+            A_matrix_size = A_dims[-1]
+            A_batch_dims = A_dims[:-2]
+            A = random_fullrank_matrix_distinct_singular_value(A_matrix_size, *A_batch_dims, dtype=dtype)
+            b = torch.randn(*b_dims, dtype=dtype)
+            x_exp = torch.as_tensor(solve(A.numpy(), b.numpy())).to(dtype=dtype, device=device)
+            A, b = A.to(device), b.to(device)
+            LU_data, LU_pivots = torch.lu(A, pivot=pivot)
+            x = torch.lu_solve(b, LU_data, LU_pivots)
+            self.assertEqual(x, x_exp)
+
+        # test against numpy.linalg.solve
+        run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6))  # no broadcasting
+        run_test((2, 1, 3, 4, 4), (4, 6))  # broadcasting b
+        run_test((4, 4), (2, 1, 3, 4, 2))  # broadcasting A
+        run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5))  # broadcasting A & b
+
     # This test confirms that torch.linalg.norm's dtype argument works
     # as expected, according to the function's documentation
     @skipCUDAIfNoMagma