[ROCm] enable grouped gemm fallback (#162419)
Enables the bf16 grouped gemm alternative (fallback) path as described in #161366. The fast path will be enabled in the future through CK integration.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162419
Approved by: https://github.com/jeffdaily
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Committed by: PyTorch MergeBot
Parent: d22d916719
Commit: b477fb106f
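For orientation, here is a minimal usage sketch of the grouped gemm entry point this change unblocks on ROCm. It assumes the private op torch._grouped_mm with the (mat_a, mat_b, offs, out_dtype) calling convention exercised by the tests below; the shapes, keyword names, and layout are illustrative assumptions, not taken from this PR.

# Sketch only: grouped bf16 GEMM via the private op exercised by the tests below.
# Assumes torch._grouped_mm(mat_a, mat_b, offs=..., out_dtype=...); on ROCm this
# now routes to the mm/bmm fallback path instead of raising.
import torch

m, k, n, n_groups = 16, 64, 32, 4
device = "cuda"  # ROCm/HIP devices are also exposed as "cuda" in PyTorch

# 2D A stacked along rows and grouped by offsets; one 3D B with a matrix per group.
a = torch.randn(m * n_groups, k, device=device, dtype=torch.bfloat16)
b = torch.randn(n_groups, k, n, device=device, dtype=torch.bfloat16)
offs = torch.arange(m, m * n_groups + 1, m, device=device, dtype=torch.int32)

out = torch._grouped_mm(a, b, offs=offs, out_dtype=torch.bfloat16)
print(out.shape)  # (m * n_groups, n)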
@@ -1080,16 +1080,6 @@ static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=false
 #endif
 }
 
-static bool _grouped_mm_allowed_device() {
-#ifdef USE_ROCM
-  return false;
-#else
-  auto dprops = at::cuda::getCurrentDeviceProperties();
-  // CUDA capability 8.0 and greater
-  return dprops->major >= 8;
-#endif
-}
-
 #ifdef USE_ROCM
 static bool _scaled_mm_is_fnuz() {
   return at::detail::getCUDAHooks().isGPUArch({"gfx942"});
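The helper deleted in this hunk encoded a simple device-capability gate. For reference, a hedged Python-side equivalent of the check it performed on the CUDA branch (illustration only, not part of the change):

# Illustration only: the Python-side analogue of the capability gate the deleted
# helper performed on CUDA (it returned false unconditionally on ROCm).
import torch

major, _minor = torch.cuda.get_device_capability()
grouped_mm_allowed = major >= 8  # CUDA compute capability 8.0 and greater
print(grouped_mm_allowed)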
@@ -1786,14 +1776,19 @@ Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
 const std::optional<at::Tensor>& offs,
 const std::optional<at::Tensor>& bias,
 std::optional<c10::ScalarType> out_dtype) {
-#ifndef USE_ROCM
   _grouped_mm_validate_inputs(mat_a, mat_b, offs, bias, out_dtype);
   bool a_b_and_out_are_bf16 = (
     mat_a.dtype() == at::kBFloat16 &&
     mat_b.dtype() == at::kBFloat16 &&
     out_dtype.value_or(at::kBFloat16) == at::kBFloat16
   );
+#ifndef USE_ROCM
   bool use_fast_path = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true) && a_b_and_out_are_bf16;
+#else
+  // _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
+  // the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
+  bool use_fast_path = false;
+#endif
   const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
   Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
   if (use_fast_path) {
@@ -1803,9 +1798,6 @@ std::optional<c10::ScalarType> out_dtype) {
     _grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
   }
   return out;
-#else
-  TORCH_CHECK(false, "grouped gemm is not supported on ROCM")
-#endif
 }
 
 Tensor _bmm_dtype_cuda(const Tensor& batch1, const Tensor& batch2, const at::ScalarType out_dtype) {
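To make the in-diff comment concrete (the fallback "is just calling typical mm/bmm"), here is a hedged Python reference of the per-group semantics for the 2D-A / 3D-B case. grouped_mm_reference is a hypothetical helper written for this note, not code from the PR.

# Hypothetical reference of the grouped gemm semantics the fallback implements
# by looping over groups with ordinary matmuls (a sketch, not PyTorch source).
import torch

def grouped_mm_reference(a, b, offs):
    # a: (total_m, k), b: (n_groups, k, n), offs: cumulative row ends into a
    outs, start = [], 0
    for i in range(b.shape[0]):
        end = int(offs[i])
        outs.append(a[start:end] @ b[i])  # one ordinary mm per group
        start = end
    return torch.cat(outs, dim=0)  # (total_m, n)

Each group reduces to a standard (m_i, k) x (k, n) matmul, which is why the fallback is expected to be safe on any ROCm GPU.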
@@ -316,7 +316,6 @@ class TestMatmulCuda(TestCase):
             self.assertEqual(agrad, a.grad)
             self.assertEqual(bgrad, b.grad)
 
-    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @xfailIfSM120OrLater
     @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
     @parametrize("strided", [False, True])
@@ -355,7 +354,6 @@ class TestMatmulCuda(TestCase):
             start = offs_cpu[i]
         self.grouped_mm_helper(alist, blist, gO, agradlist, bgradlist, out)
 
-    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @xfailIfSM120OrLater
     @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
     @parametrize("strided", [False, True])
@@ -412,7 +410,6 @@ class TestMatmulCuda(TestCase):
         self.grouped_mm_helper(alist, b, gOlist, agradlist, bgradlist, outlist)
 
 
-    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @xfailIfSM120OrLater
     @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
     @parametrize("strided", [False, True])
@@ -447,7 +444,6 @@ class TestMatmulCuda(TestCase):
         out.backward(gO)
         self.grouped_mm_helper(a, b, gO, a.grad, b.grad, out)
 
-    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @xfailIfSM120OrLater
     @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
     @parametrize("strided", [False, True])
@@ -455,6 +451,8 @@ class TestMatmulCuda(TestCase):
     @parametrize("b_row_major", [False, True])
     @dtypes(torch.bfloat16, torch.float32, torch.float16)
     def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype):
+        if TEST_WITH_ROCM and a_row_major and b_row_major and dtype in [torch.bfloat16, torch.float16]:
+            self.skipTest("failed using hipblaslt on rocm 6.4.2")
         device = "cuda"
         s_int = int(strided)
         m, n, k, n_groups = 16, 32, 64, 4
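To exercise the affected tests locally, one option (assuming a PyTorch source checkout containing test/test_matmul_cuda.py and pytest installed) is:

# Run only the grouped gemm tests touched above; the -k filter is an assumption
# about the test names, adjust it to match your checkout if needed.
import pytest

pytest.main(["-v", "-k", "grouped_gemm", "test/test_matmul_cuda.py"])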