Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[ROCm] use hipSolver instead of MAGMA for Cholesky (#163977)
Currently, the Cholesky factorization and least-squares operations default to MAGMA when PyTorch is compiled for ROCm, which yields suboptimal performance. This change lets PyTorch rely on hipSOLVER instead of MAGMA.

@jeffdaily

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163977
Approved by: https://github.com/Skylion007
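For reference, a minimal Python sketch (not part of the PR) of the user-facing calls this dispatch change affects; on a ROCm build with the solver backend enabled, these now go through hipSOLVER rather than MAGMA:

    import torch

    # Sketch only: exercises the two operations the commit message mentions.
    # On a ROCm build, the HIP device is exposed through the "cuda" device type.
    device = "cuda"
    n = 512
    X = torch.randn(n, n, device=device)
    A = X @ X.mT + n * torch.eye(n, device=device)  # symmetric positive-definite

    L = torch.linalg.cholesky(A)                    # Cholesky factorization
    B = torch.randn(n, 4, device=device)
    solution = torch.linalg.lstsq(X, B).solution    # least-squares solve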
Committed by: PyTorch MergeBot
Parent: 7bbde0c094
Commit: 238dc65368
@@ -1238,7 +1238,7 @@ Tensor _cholesky_solve_helper_cuda_magma(const Tensor& self, const Tensor& A, bo
 // Todo: cusolverDn<T>potrsBatched only supports nrhs == 1 and does not have good performance.
 // Batched cholesky_solve is dispatched to magma.
 Tensor _cholesky_solve_helper_cuda(const Tensor& self, const Tensor& A, bool upper) {
-#if defined(USE_LINALG_SOLVER) && !defined(USE_ROCM)
+#if defined(USE_LINALG_SOLVER)
   auto preferred_backend = at::globalContext().linalgPreferredBackend();
   switch (preferred_backend) {
     case at::LinalgBackend::Cusolver:
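For context, the switch above dispatches on the globally preferred linalg backend. Below is a hedged sketch of how that preference can be inspected or overridden from Python; the "cusolver" choice is the path that maps to hipSOLVER on ROCm builds (an assumption based on the usual hipification of the solver bindings):

    import torch

    # Query the current preference (returns a torch._C._LinalgBackend value).
    print(torch.backends.cuda.preferred_linalg_library())

    # Prefer the cuSOLVER/hipSOLVER path; pass "magma" to force the MAGMA path,
    # or "default" to let ATen choose heuristically.
    torch.backends.cuda.preferred_linalg_library("cusolver")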
@@ -1352,7 +1352,7 @@ void cholesky_helper_magma(const Tensor& input, bool upper, const Tensor& info)
 }
 
 static void cholesky_kernel(const Tensor& input, const Tensor& info, bool upper) {
-#if defined(USE_LINALG_SOLVER) && !defined(USE_ROCM)
+#if defined(USE_LINALG_SOLVER)
   auto preferred_backend = at::globalContext().linalgPreferredBackend();
   switch (preferred_backend) {
     case at::LinalgBackend::Cusolver:
@@ -2709,7 +2709,7 @@ void linalg_lstsq_gels(const Tensor& A, const Tensor& B, const Tensor& /*infos*/
 }
 
 void gels_looped(const Tensor& a, Tensor& b, Tensor& infos) {
-#if defined(USE_LINALG_SOLVER) && !defined(USE_ROCM)
+#if defined(USE_LINALG_SOLVER)
   auto preferred_backend = at::globalContext().linalgPreferredBackend();
   switch (preferred_backend) {
     case at::LinalgBackend::Magma:
@@ -2733,7 +2733,7 @@ void lstsq_kernel(const Tensor& a, Tensor& b, Tensor& /*rank*/, Tensor& /*singul
   // first handle the underdetermined case (m < n)
   // this case is not supported by MAGMA or cuBLAS
   if (m < n) {
-#if defined(USE_LINALG_SOLVER) && !defined(USE_ROCM)
+#if defined(USE_LINALG_SOLVER)
     linalg_lstsq_gels(a, b, infos);
 #else
     TORCH_CHECK(
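As a usage note, here is a small sketch (not from the PR) of the underdetermined case handled in the hunk above; with the guard relaxed, ROCm builds now take the linalg_lstsq_gels path rather than the #else TORCH_CHECK branch:

    import torch

    # Underdetermined least squares (m < n): 16 equations, 64 unknowns.
    # Per the hunk above, this shape is routed through the gels path
    # (hipSOLVER-backed on ROCm builds).
    m, n = 16, 64
    A = torch.randn(m, n, device="cuda")
    B = torch.randn(m, 2, device="cuda")
    solution = torch.linalg.lstsq(A, B).solution  # minimum-norm solution, shape (n, 2)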