[ROCm] enable grouped gemm fallback (#162419)

Enables bf16 group gemm alternative path as described in #161366
Fast path will be enabled in future through CK integration.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162419
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
Jeff Daily
2025-09-09 20:04:54 +00:00
committed by PyTorch MergeBot
parent d22d916719
commit b477fb106f
2 changed files with 8 additions and 18 deletions

View File

@ -1080,16 +1080,6 @@ static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=fals
#endif
}
static bool _grouped_mm_allowed_device() {
#ifdef USE_ROCM
return false;
#else
auto dprops = at::cuda::getCurrentDeviceProperties();
// CUDA capability 8.0 and greater
return dprops->major >= 8;
#endif
}
#ifdef USE_ROCM
static bool _scaled_mm_is_fnuz() {
return at::detail::getCUDAHooks().isGPUArch({"gfx942"});
@ -1786,14 +1776,19 @@ Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,
std::optional<c10::ScalarType> out_dtype) {
#ifndef USE_ROCM
_grouped_mm_validate_inputs(mat_a, mat_b, offs, bias, out_dtype);
bool a_b_and_out_are_bf16 = (
mat_a.dtype() == at::kBFloat16 &&
mat_b.dtype() == at::kBFloat16 &&
out_dtype.value_or(at::kBFloat16) == at::kBFloat16
);
#ifndef USE_ROCM
bool use_fast_path = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true) && a_b_and_out_are_bf16;
#else
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
bool use_fast_path = false;
#endif
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
if (use_fast_path) {
@ -1803,9 +1798,6 @@ std::optional<c10::ScalarType> out_dtype) {
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
}
return out;
#else
TORCH_CHECK(false, "grouped gemm is not supported on ROCM")
#endif
}
Tensor _bmm_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype) {

View File

@ -316,7 +316,6 @@ class TestMatmulCuda(TestCase):
self.assertEqual(agrad, a.grad)
self.assertEqual(bgrad, b.grad)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM120OrLater
@unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
@parametrize("strided", [False, True])
@ -355,7 +354,6 @@ class TestMatmulCuda(TestCase):
start = offs_cpu[i]
self.grouped_mm_helper(alist, blist, gO, agradlist, bgradlist, out)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM120OrLater
@unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
@parametrize("strided", [False, True])
@ -412,7 +410,6 @@ class TestMatmulCuda(TestCase):
self.grouped_mm_helper(alist, b, gOlist, agradlist, bgradlist, outlist)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM120OrLater
@unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
@parametrize("strided", [False, True])
@ -447,7 +444,6 @@ class TestMatmulCuda(TestCase):
out.backward(gO)
self.grouped_mm_helper(a, b, gO, a.grad, b.grad, out)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM120OrLater
@unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
@parametrize("strided", [False, True])
@ -455,6 +451,8 @@ class TestMatmulCuda(TestCase):
@parametrize("b_row_major", [False, True])
@dtypes(torch.bfloat16, torch.float32, torch.float16)
def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype):
if TEST_WITH_ROCM and a_row_major and b_row_major and dtype in [torch.bfloat16, torch.float16]:
self.skipTest("failed using hipblaslt on rocm 6.4.2")
device = "cuda"
s_int = int(strided)
m, n, k, n_groups = 16, 32, 64, 4