Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[ROCm] add gfx1150 gfx1151 to supported gemm lists (#164744)
This is one of a few PRs needed to address https://github.com/pytorch/pytorch/pull/164744 fully.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164744
Approved by: https://github.com/jeffdaily
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
361c5d362c
commit
44a5d41993
@ -483,8 +483,8 @@ at::BlasBackend Context::blasPreferredBackend() {
|
||||
#if ROCM_VERSION >= 60300
|
||||
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
|
||||
#endif
|
||||
#if ROCM_VERSION >= 60500
|
||||
"gfx950"
|
||||
#if ROCM_VERSION >= 70000
|
||||
"gfx950", "gfx1150", "gfx1151"
|
||||
#endif
|
||||
};
|
||||
for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) {
|
||||
|
@ -1270,7 +1270,7 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
|
||||
}
|
||||
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
|
||||
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
||||
if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) { //no CK GEMM version for gfx1100
|
||||
if (at::detail::getCUDAHooks().isGPUArch({"gfx11", "gfx12"})) { //no CK GEMM version
|
||||
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
|
||||
} else{
|
||||
at::native::gemm_internal_ck<float>(CUDABLAS_GEMM_ARGS(float));
|
||||
|
@ -285,8 +285,8 @@ static bool isSupportedHipLtROCmArch(int index) {
|
||||
#if ROCM_VERSION >= 60300
|
||||
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
|
||||
#endif
|
||||
#if ROCM_VERSION >= 60500
|
||||
"gfx950"
|
||||
#if ROCM_VERSION >= 70000
|
||||
"gfx950", "gfx1150", "gfx1151"
|
||||
#endif
|
||||
};
|
||||
return at::detail::getCUDAHooks().isGPUArch(archs, index);
|
||||
|
@ -772,13 +772,21 @@ void dispatch_bfloat16_gemm_wmma(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
|
||||
template <>
|
||||
void gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
std::string_view arch(dprops->gcnArchName);
|
||||
if (arch == "gfx1100") {
|
||||
static const std::vector<std::string> wmma_archs = {
|
||||
"gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201",
|
||||
#if ROCM_VERSION >= 70000
|
||||
"gfx1150", "gfx1151"
|
||||
#endif
|
||||
};
|
||||
if (at::detail::getCUDAHooks().isGPUArch(wmma_archs)) {
|
||||
dispatch_bfloat16_gemm_wmma(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
} else{
|
||||
}
|
||||
else if (at::detail::getCUDAHooks().isGPUArch({"gfx9"})) {
|
||||
dispatch_bfloat16_gemm(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "gemm_internal_ck<at::BFloat16> unsupported gfx arch");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -599,11 +599,21 @@ void dispatch_half_gemm_wmma(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
|
||||
template <>
|
||||
void gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) {
|
||||
static const std::vector<std::string> wmma_archs = {
|
||||
"gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201",
|
||||
#if ROCM_VERSION >= 70000
|
||||
"gfx1150", "gfx1151"
|
||||
#endif
|
||||
};
|
||||
if (at::detail::getCUDAHooks().isGPUArch(wmma_archs)) {
|
||||
dispatch_half_gemm_wmma(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
} else{
|
||||
}
|
||||
else if (at::detail::getCUDAHooks().isGPUArch({"gfx9"})) {
|
||||
dispatch_half_gemm(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "gemm_internal_ck<at::Half> unsupported gfx arch");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
Reference in New Issue
Block a user