[ROCm] add gfx1150 gfx1151 to supported gemm lists (#164744)

This is one of a few PRs needed to fully address https://github.com/pytorch/pytorch/pull/164744.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164744
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
Jeff Daily
2025-10-07 00:02:20 +00:00
committed by PyTorch MergeBot
parent 361c5d362c
commit 44a5d41993
5 changed files with 29 additions and 11 deletions

View File

@ -483,8 +483,8 @@ at::BlasBackend Context::blasPreferredBackend() {
#if ROCM_VERSION >= 60300
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
#endif
#if ROCM_VERSION >= 60500
"gfx950"
#if ROCM_VERSION >= 70000
"gfx950", "gfx1150", "gfx1151"
#endif
};
for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) {

View File

@ -1270,7 +1270,7 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
}
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) { //no CK GEMM version for gfx1100
if (at::detail::getCUDAHooks().isGPUArch({"gfx11", "gfx12"})) { //no CK GEMM version
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
} else{
at::native::gemm_internal_ck<float>(CUDABLAS_GEMM_ARGS(float));

View File

@ -285,8 +285,8 @@ static bool isSupportedHipLtROCmArch(int index) {
#if ROCM_VERSION >= 60300
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
#endif
#if ROCM_VERSION >= 60500
"gfx950"
#if ROCM_VERSION >= 70000
"gfx950", "gfx1150", "gfx1151"
#endif
};
return at::detail::getCUDAHooks().isGPUArch(archs, index);

View File

@ -772,13 +772,21 @@ void dispatch_bfloat16_gemm_wmma(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
template <>
void gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
auto dprops = at::cuda::getCurrentDeviceProperties();
std::string_view arch(dprops->gcnArchName);
if (arch == "gfx1100") {
static const std::vector<std::string> wmma_archs = {
"gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201",
#if ROCM_VERSION >= 70000
"gfx1150", "gfx1151"
#endif
};
if (at::detail::getCUDAHooks().isGPUArch(wmma_archs)) {
dispatch_bfloat16_gemm_wmma(CUDABLAS_GEMM_ARGS(at::BFloat16));
} else{
}
else if (at::detail::getCUDAHooks().isGPUArch({"gfx9"})) {
dispatch_bfloat16_gemm(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
else {
TORCH_CHECK(false, "gemm_internal_ck<at::BFloat16> unsupported gfx arch");
}
}
} // namespace at::native

View File

@ -599,11 +599,21 @@ void dispatch_half_gemm_wmma(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
template <>
void gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) {
static const std::vector<std::string> wmma_archs = {
"gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201",
#if ROCM_VERSION >= 70000
"gfx1150", "gfx1151"
#endif
};
if (at::detail::getCUDAHooks().isGPUArch(wmma_archs)) {
dispatch_half_gemm_wmma(CUDABLAS_GEMM_ARGS(at::Half));
} else{
}
else if (at::detail::getCUDAHooks().isGPUArch({"gfx9"})) {
dispatch_half_gemm(CUDABLAS_GEMM_ARGS(at::Half));
}
else {
TORCH_CHECK(false, "gemm_internal_ck<at::Half> unsupported gfx arch");
}
}
} // namespace at::native