@@ -1471,10 +1471,18 @@ Tensor _inverse_helper_cuda_legacy(const Tensor& self) {

 Tensor _inverse_helper_cuda(const Tensor& self) {
 #ifdef USE_CUSOLVER
-  if ((self.dim() == 2) || (/* self.dim() > 2 && */ batchCount(self) <= 2) || !use_magma_) {
-    return _inverse_helper_cuda_lib(self);    // cusolver or cublas
-  } else {
-    return _inverse_helper_cuda_legacy(self); // magma-cuda
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Cusolver:
+      return _inverse_helper_cuda_lib(self);    // cusolver or cublas
+    case at::LinalgBackend::Magma:
+      return _inverse_helper_cuda_legacy(self); // magma-cuda
+    default:
+      if (batchCount(self) <= 2 || !use_magma_) {
+        return _inverse_helper_cuda_lib(self);    // cusolver or cublas
+      } else {
+        return _inverse_helper_cuda_legacy(self); // magma-cuda
+      }
   }
 #else
   return _inverse_helper_cuda_legacy(self); // magma-cuda
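The pattern repeated throughout this patch is visible in the first hunk: each CUDA helper first consults at::globalContext().linalgPreferredBackend() and only falls back to the old size-based heuristic when no backend is preferred. A minimal caller-side sketch of how the preference might be toggled from C++ follows; only the getter appears in this diff, so the setLinalgPreferredBackend setter name is an assumption.

#include <ATen/Context.h>

// Hypothetical caller-side sketch: force cuSOLVER for subsequent linalg
// dispatches, then restore the default heuristic-based routing.
// setLinalgPreferredBackend is assumed to exist alongside the getter used above.
void prefer_cusolver_example() {
  auto& ctx = at::globalContext();
  ctx.setLinalgPreferredBackend(at::LinalgBackend::Cusolver);  // take the Cusolver case above
  // ... run CUDA linalg ops here; they dispatch to the cuSOLVER path ...
  ctx.setLinalgPreferredBackend(at::LinalgBackend::Default);   // back to the heuristic default
}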
@@ -1503,10 +1511,18 @@ Tensor& _linalg_inv_out_helper_cuda(Tensor &result, Tensor& infos_lu, Tensor& in
   // This function calculates the inverse matrix in-place
   // result should be in column major order and contain matrices to invert
 #ifdef USE_CUSOLVER
-  if ((result.dim() == 2) || (/* result.dim() > 2 && */ batchCount(result) <= 2) || !use_magma_) {
-    return _linalg_inv_out_helper_cuda_lib(result, infos_lu, infos_getri);    // cusolver or cublas
-  } else {
-    return _linalg_inv_out_helper_cuda_legacy(result, infos_lu, infos_getri); // magma-cuda
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Cusolver:
+      return _linalg_inv_out_helper_cuda_lib(result, infos_lu, infos_getri);    // cusolver or cublas
+    case at::LinalgBackend::Magma:
+      return _linalg_inv_out_helper_cuda_legacy(result, infos_lu, infos_getri); // magma-cuda
+    default:
+      if (batchCount(result) <= 2 || !use_magma_) {
+        return _linalg_inv_out_helper_cuda_lib(result, infos_lu, infos_getri);    // cusolver or cublas
+      } else {
+        return _linalg_inv_out_helper_cuda_legacy(result, infos_lu, infos_getri); // magma-cuda
+      }
   }
 #else
   return _linalg_inv_out_helper_cuda_legacy(result, infos_lu, infos_getri); // magma-cuda
@@ -1600,10 +1616,18 @@ Tensor _cholesky_solve_helper_cuda_magma(const Tensor& self, const Tensor& A, bo
 // Batched cholesky_solve is dispatched to magma.
 Tensor _cholesky_solve_helper_cuda(const Tensor& self, const Tensor& A, bool upper) {
 #ifdef USE_CUSOLVER
-  if (batchCount(self) == 1 || !use_magma_) {
-    return _cholesky_solve_helper_cuda_cusolver(self, A, upper);
-  } else {
-    return _cholesky_solve_helper_cuda_magma(self, A, upper);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Cusolver:
+      return _cholesky_solve_helper_cuda_cusolver(self, A, upper);
+    case at::LinalgBackend::Magma:
+      return _cholesky_solve_helper_cuda_magma(self, A, upper);
+    default:
+      if (batchCount(self) == 1 || !use_magma_) {
+        return _cholesky_solve_helper_cuda_cusolver(self, A, upper);
+      } else {
+        return _cholesky_solve_helper_cuda_magma(self, A, upper);
+      }
   }
 #else
   return _cholesky_solve_helper_cuda_magma(self, A, upper);
@@ -1706,10 +1730,20 @@ void cholesky_helper_magma(const Tensor& input, bool upper, const Tensor& info)

 static void cholesky_kernel(const Tensor& input, const Tensor& info, bool upper) {
 #ifdef USE_CUSOLVER
-  if (batchCount(input) == 1 || !use_magma_ || use_cusolver_potrf_batched_) {
-    cholesky_helper_cusolver(input, upper, info);
-  } else {
-    cholesky_helper_magma(input, upper, info);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Cusolver:
+      cholesky_helper_cusolver(input, upper, info);
+      break;
+    case at::LinalgBackend::Magma:
+      cholesky_helper_magma(input, upper, info);
+      break;
+    default:
+      if (batchCount(input) == 1 || !use_magma_ || use_cusolver_potrf_batched_) {
+        cholesky_helper_cusolver(input, upper, info);
+      } else {
+        cholesky_helper_magma(input, upper, info);
+      }
   }
 #else
   cholesky_helper_magma(input, upper, info);
@@ -1777,10 +1811,19 @@ Tensor& cholesky_inverse_kernel_impl(Tensor &result, Tensor& infos, bool upper)
   // result should be in column major order and contain matrices to invert
   // the content of result is overwritten by 'apply_cholesky_inverse'
 #ifdef USE_CUSOLVER
-  if (batchCount(result) == 1 || !use_magma_) {
-    return cholesky_inverse_kernel_impl_cusolver(result, infos, upper);
-  } else {
-    return cholesky_inverse_kernel_impl_magma(result, infos, upper);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Cusolver:
+      return cholesky_inverse_kernel_impl_cusolver(result, infos, upper);
+    case at::LinalgBackend::Magma:
+      return cholesky_inverse_kernel_impl_magma(result, infos, upper);
+    default:
+      if (batchCount(result) == 1 ||
+          !use_magma_) {
+        return cholesky_inverse_kernel_impl_cusolver(result, infos, upper);
+      } else {
+        return cholesky_inverse_kernel_impl_magma(result, infos, upper);
+      }
   }
 #else
   return cholesky_inverse_kernel_impl_magma(result, infos, upper);
@@ -1944,20 +1987,39 @@ static void lu_batched_magma(const Tensor& input, const Tensor& pivots, const Te
 static void apply_lu(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
   int64_t batch_size = batchCount(input);
 #ifdef USE_CUSOLVER
-  // Use a heuristic to determine that cusolver is faster than MAGMA for the following sizes.
-  auto m = input.size(-2);
-  // exclude complex128 since nan_to_num_ does not work with it.
-  if ((batch_size == 1 || (batch_size <= 8 && m <= 16) || !use_magma_) && !input.is_complex()) {
-    lu_looped_cusolver(input, pivots, infos, compute_pivots);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Cusolver:
+      lu_looped_cusolver(input, pivots, infos, compute_pivots);
+      break;
+    case at::LinalgBackend::Magma:
+      if (batch_size == 1) {
+        lu_looped_magma(input, pivots, infos, compute_pivots);
+      } else {
+        lu_batched_magma(input, pivots, infos, compute_pivots);
+      }
+      break;
+    default:
+      // Use a heuristic to determine that cusolver is faster than MAGMA for the following sizes.
+      auto m = input.size(-2);
+      // exclude complex128 since nan_to_num_ does not work with it.
+      if ((batch_size == 1 ||
+          (batch_size <= 8 && m <= 16) ||
+          !use_magma_)
+          && !input.is_complex()) {
+        lu_looped_cusolver(input, pivots, infos, compute_pivots);
+      } else {
+        lu_batched_magma(input, pivots, infos, compute_pivots);
+      }
   }
 #else
   if (batch_size == 1) {
     lu_looped_magma(input, pivots, infos, compute_pivots);
   }
-#endif // USE_CUSOLVER
   else {
     lu_batched_magma(input, pivots, infos, compute_pivots);
   }
+#endif // USE_CUSOLVER
 }

 REGISTER_CUDA_DISPATCH(lu_stub, &apply_lu);
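For readers skimming the hunk above: the default branch keeps the earlier size heuristic, sending the LU factorization to looped cuSOLVER for a single matrix, for small batches of tiny matrices (batch_size <= 8 and m <= 16), or when MAGMA is unavailable, while complex inputs always fall back to batched MAGMA. A standalone sketch of that predicate follows; the helper name and the use_magma parameter are stand-ins for illustration and are not part of the patch.

#include <ATen/ATen.h>

// Illustrative restatement of the default-path heuristic in apply_lu above;
// `use_magma` stands in for the file-level use_magma_ flag. Not part of the patch.
static bool lu_prefers_cusolver(const at::Tensor& input, bool use_magma) {
  // batch count: product of all dimensions except the last two (matrix rows/cols)
  int64_t batch_size = 1;
  for (int64_t i = 0; i < input.dim() - 2; ++i) {
    batch_size *= input.size(i);
  }
  const auto m = input.size(-2);  // rows of each matrix in the batch
  const bool small_problem = batch_size == 1 || (batch_size <= 8 && m <= 16);
  // complex inputs are excluded because nan_to_num_ does not support complex128
  return (small_problem || !use_magma) && !input.is_complex();
}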
@@ -2064,12 +2126,12 @@ Tensor& orgqr_kernel_impl(Tensor& result, const Tensor& tau) {
   // See discussions in https://github.com/pytorch/pytorch/pull/51348 for comparison of cuSOLVER-MAGMA
   // and Windows failure.
   // For reference here is the MAGMA-based implementation: https://gist.github.com/IvanYashchuk/2db50002c9d3c1462ff769e6410ad983
-#if defined(USE_CUSOLVER)
-  return orgqr_helper_cusolver(result, tau); // cusolver
-#else
-  TORCH_CHECK(false, "Calling torch.orgqr on a CUDA tensor requires compiling ",
-    "PyTorch with cuSOLVER. Please use PyTorch built with cuSOLVER support.");
-#endif
+  #if defined(USE_CUSOLVER)
+    return orgqr_helper_cusolver(result, tau); // cusolver
+  #else
+    TORCH_CHECK(false, "Calling torch.orgqr on a CUDA tensor requires compiling ",
+      "PyTorch with cuSOLVER. Please use PyTorch built with cuSOLVER support.");
+  #endif
 }

 REGISTER_CUDA_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
@@ -2136,7 +2198,14 @@ void geqrf_magma(const Tensor& input, const Tensor& tau) {
 // This is a backend library dispatching helper function for calling looped batch implementation
 void geqrf_looped(const Tensor& input, const Tensor& tau) {
 #if defined(USE_CUSOLVER)
-  return geqrf_cusolver(input, tau);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Magma:
+      return geqrf_magma(input, tau);
+    case at::LinalgBackend::Cusolver:
+    default:
+      return geqrf_cusolver(input, tau);
+  }
 #else
   return geqrf_magma(input, tau);
 #endif
@@ -2273,9 +2342,16 @@ std::tuple<Tensor, Tensor> linalg_qr_helper_magma(const Tensor& self, c10::strin

 std::tuple<Tensor, Tensor> _linalg_qr_helper_cuda(const Tensor& input, c10::string_view mode) {
 #if defined(USE_CUSOLVER)
-  // _linalg_qr_helper_default is a generic function that is implemented using
-  // geqrf_stub and orgqr_stub. It dispatches to cuSOLVER for CUDA inputs if USE_CUSOLVER is defined
-  return _linalg_qr_helper_default(input, mode);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Magma:
+      return linalg_qr_helper_magma(input, mode);
+    case at::LinalgBackend::Cusolver:
+    default:
+      // _linalg_qr_helper_default is a generic function that is implemented using
+      // geqrf_stub and orgqr_stub. It dispatches to cuSOLVER for CUDA inputs if USE_CUSOLVER is defined
+      return _linalg_qr_helper_default(input, mode);
+  }
 #else
   return linalg_qr_helper_magma(input, mode);
 #endif
@@ -2432,7 +2508,15 @@ void linalg_eigh_magma(const Tensor& eigenvalues, const Tensor& eigenvectors, co

 void linalg_eigh_kernel(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
 #if defined(USE_CUSOLVER)
-  linalg_eigh_cusolver(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Magma:
+      linalg_eigh_magma(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
+      break;
+    case at::LinalgBackend::Cusolver:
+    default:
+      linalg_eigh_cusolver(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
+  }
 #else
   linalg_eigh_magma(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
 #endif
@@ -2731,7 +2815,14 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda_legacy(const Tensor& self, b

 std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda(const Tensor& self, bool some, bool compute_uv) {
 #ifdef USE_CUSOLVER
-  return _svd_helper_cuda_lib(self, some, compute_uv);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Magma:
+      return _svd_helper_cuda_legacy(self, some, compute_uv);
+    case at::LinalgBackend::Cusolver:
+    default:
+      return _svd_helper_cuda_lib(self, some, compute_uv);
+  }
 #else
   return _svd_helper_cuda_legacy(self, some, compute_uv);
 #endif
@@ -3046,10 +3137,17 @@ void linalg_lstsq_gels(const Tensor& A, const Tensor& B, const Tensor& /*infos*/

 void gels_looped(const Tensor& a, Tensor& b, Tensor& infos) {
 #if defined(USE_CUSOLVER)
-  // linalg_lstsq_gels is a generic function that is implemented using
-  // geqrf_stub, ormqr_stub, and triangular_solve_stub
-  // It dispatches to cuSOLVER for CUDA inputs if USE_CUSOLVER is defined
-  return linalg_lstsq_gels(a, b, infos);
+  auto preferred_backend = at::globalContext().linalgPreferredBackend();
+  switch (preferred_backend) {
+    case at::LinalgBackend::Magma:
+      return gels_magma(a, b, infos);
+    case at::LinalgBackend::Cusolver:
+    default:
+      // linalg_lstsq_gels is a generic function that is implemented using
+      // geqrf_stub, ormqr_stub, and triangular_solve_stub
+      // It dispatches to cuSOLVER for CUDA inputs if USE_CUSOLVER is defined
+      return linalg_lstsq_gels(a, b, infos);
+  }
 #else
   return gels_magma(a, b, infos);
 #endif