[ROCm] enable batched eigen decomposition (syevD_batched) on ROCm (#154525)

This PR implements batched eigen decomposition (`syevD_batched`) on ROCm by calling rocSolver directly.
cuSolver does not expose a batched `syevD`, and neither does hipSolver, so a direct call to rocSolver is required.

`syevD_batched` will be used on ROCm if all the following conditions are met:
- `rocSolver version >= 3.26`
- input data type is `float` or `double`
- batch size >= 2

Otherwise, non-batched `syevD` will be used on ROCm (complex data types, batch size == 1, or rocSolver < 3.26). A usage sketch follows below.
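As an illustration only (not part of the PR), here is a minimal libtorch sketch of an input that meets all three conditions and so would take the batched path on a ROCm build with rocSolver >= 3.26; the shapes and variable names are assumptions, and `torch::kCUDA` is assumed to map to the HIP device on ROCm builds:

```cpp
#include <torch/torch.h>

int main() {
  // Batch of 8 float matrices on the ROCm device
  auto a = torch::randn({8, 64, 64}, torch::device(torch::kCUDA).dtype(torch::kFloat));
  auto sym = a + a.transpose(-2, -1);  // symmetrize, as linalg::eigh expects
  // batch size 8 >= 2 and dtype float => batched syevD path (given rocSolver >= 3.26)
  auto [eigenvalues, eigenvectors] = torch::linalg::eigh(sym, "L");
  return 0;
}
```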

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154525
Approved by: https://github.com/Mellonta
Dmitry Nikolaev
2025-06-17 19:20:36 +00:00
committed by PyTorch MergeBot
parent ec08eb8ba2
commit 5eb5c3700b


@@ -28,6 +28,18 @@
#include <ATen/ops/zeros.h>
#endif
#if defined(USE_ROCM)
#include <rocsolver/rocsolver.h>
#include <ATen/cuda/tunable/GemmRocblas.h>
#define PYTORCH_ROCSOLVER_VERSION \
(ROCSOLVER_VERSION_MAJOR * 10000 + ROCSOLVER_VERSION_MINOR * 100 + ROCSOLVER_VERSION_PATCH)
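// e.g. rocSolver 3.26.0 encodes as 3 * 10000 + 26 * 100 + 0 = 32600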
#if (PYTORCH_ROCSOLVER_VERSION >= 32600)
#define ROCSOLVER_SYEVD_BATCHED_ENABLED 1
#else
#define ROCSOLVER_SYEVD_BATCHED_ENABLED 0
#endif
#endif // defined(USE_ROCM)
namespace at::native {
static cublasOperation_t to_cublas(TransposeType trans) {
@@ -1204,6 +1216,115 @@ Tensor& orgqr_helper_cusolver(Tensor& result, const Tensor& tau) {
return result;
}
#if defined(USE_ROCM) && ROCSOLVER_SYEVD_BATCHED_ENABLED
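// Type-dispatch shim: routes float/double to the matching rocSolver strided-batched syevd symbol.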
template <typename scalar_t>
rocblas_status _rocsolver_syevd_strided_batched(
    rocblas_handle handle,
    const rocblas_evect evect,
    const rocblas_fill uplo,
    const rocblas_int n,
    scalar_t* A,
    const rocblas_int lda,
    const rocblas_stride strideA,
    scalar_t* D,
    const rocblas_stride strideD,
    scalar_t* E,
    const rocblas_stride strideE,
    rocblas_int* info,
    const rocblas_int batch_count);

template <>
rocblas_status _rocsolver_syevd_strided_batched<float>(
    rocblas_handle handle,
    const rocblas_evect evect,
    const rocblas_fill uplo,
    const rocblas_int n,
    float* A,
    const rocblas_int lda,
    const rocblas_stride strideA,
    float* D,
    const rocblas_stride strideD,
    float* E,
    const rocblas_stride strideE,
    rocblas_int* info,
    const rocblas_int batch_count) {
  return rocsolver_ssyevd_strided_batched(
      handle, evect, uplo, n, A, lda, strideA, D, strideD, E, strideE, info, batch_count);
}

template <>
rocblas_status _rocsolver_syevd_strided_batched<double>(
    rocblas_handle handle,
    const rocblas_evect evect,
    const rocblas_fill uplo,
    const rocblas_int n,
    double* A,
    const rocblas_int lda,
    const rocblas_stride strideA,
    double* D,
    const rocblas_stride strideD,
    double* E,
    const rocblas_stride strideE,
    rocblas_int* info,
    const rocblas_int batch_count) {
  return rocsolver_dsyevd_strided_batched(
      handle, evect, uplo, n, A, lda, strideA, D, strideD, E, strideE, info, batch_count);
}
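// One rocSolver call decomposes the whole batch via strided layout; infos receives one status per matrix.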
template <typename scalar_t>
static void apply_syevd_batched_rocsolver(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
  using value_t = typename c10::scalar_value_type<scalar_t>::type;

  auto uplo = upper ? rocblas_fill::rocblas_fill_upper : rocblas_fill::rocblas_fill_lower;
  auto evect = compute_eigenvectors ? rocblas_evect::rocblas_evect_original : rocblas_evect::rocblas_evect_none;

  int64_t n = vectors.size(-1);
  int64_t lda = std::max<int64_t>(1, n);
  int64_t batch_size = batchCount(vectors);

  auto vectors_stride = matrixStride(vectors);
  auto values_stride = n;

  auto vectors_data = vectors.data_ptr<scalar_t>();
  auto values_data = values.data_ptr<value_t>();
  auto infos_data = infos.data_ptr<int>();

  // E workspace: syevd uses n elements per matrix for the internal tridiagonal form
  auto work_stride = n;
  auto work_size = work_stride * batch_size;

  // allocate workspace storage on device
  auto& allocator = *at::cuda::getCUDADeviceAllocator();
  auto work_data = allocator.allocate(sizeof(scalar_t) * work_size);

  rocblas_handle handle = static_cast<rocblas_handle>(at::cuda::getCurrentCUDASolverDnHandle());

  // rocsolver will manage the workspace size automatically
  if (!rocblas_is_managing_device_memory(handle)) {
    TORCH_ROCBLAS_CHECK(rocblas_set_workspace(handle, nullptr, 0));
  }

  TORCH_ROCBLAS_CHECK(_rocsolver_syevd_strided_batched<scalar_t>(
      handle,
      evect,
      uplo,
      n,
      vectors_data,
      lda,
      vectors_stride,
      values_data,
      values_stride,
      static_cast<scalar_t*>(work_data.get()),
      work_stride,
      infos_data,
      batch_size));
}
#endif // USE_ROCM && ROCSOLVER_SYEVD_BATCHED_ENABLED
template <typename scalar_t>
static void apply_syevd(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
using value_t = typename c10::scalar_value_type<scalar_t>::type;
@@ -1475,11 +1596,22 @@ static void linalg_eigh_cusolver_syevj_batched(const Tensor& eigenvalues, const
});
}
#if defined(USE_ROCM) && ROCSOLVER_SYEVD_BATCHED_ENABLED
static void linalg_eigh_rocsolver_syevd_batched(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
  AT_DISPATCH_FLOATING_TYPES(eigenvectors.scalar_type(), "linalg_eigh_cuda", [&]() {
    apply_syevd_batched_rocsolver<scalar_t>(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
  });
}
#endif // USE_ROCM && ROCSOLVER_SYEVD_BATCHED_ENABLED
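// Backend chooser: prefer the batched rocSolver path when compiled in and applicable, else fall back to non-batched syevd.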
void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
#if defined(USE_ROCM)
#if ROCSOLVER_SYEVD_BATCHED_ENABLED
  if (batchCount(eigenvectors) > 1 && (eigenvectors.scalar_type() == at::kFloat || eigenvectors.scalar_type() == at::kDouble))
    linalg_eigh_rocsolver_syevd_batched(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
  else // not ROCSOLVER_SYEVD_BATCHED_ENABLED or batch==1 or complex input
#endif // ROCSOLVER_SYEVD_BATCHED_ENABLED
    // syevj has larger numerical errors than syevd
    linalg_eigh_cusolver_syevd(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
#else // not USE_ROCM
  if (use_cusolver_syevj_batched_ && batchCount(eigenvectors) > 1 && eigenvectors.size(-1) <= 32) {
    // Use syevjBatched for batched matrix operation when matrix size <= 32
    // See https://github.com/pytorch/pytorch/pull/53040#issuecomment-788264724