[ROCm] Add FP8 rowwise support to _scaled_grouped_mm + Submodule update (#159075)

Summary:

In this PR we integrate the [FBGEMM AMD FP8 rowwise scaling grouped GEMM kernel](https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped) to add support for the `_scaled_grouped_mm` API on AMD. `_scaled_grouped_mm` is [currently supported on Nvidia](9faef3d17c/aten/src/ATen/native/cuda/Blas.cpp (L1614)), this PR aims to bring parity to AMD. Related: [[RFC]: PyTorch Low-Precision GEMMs Public API](https://github.com/pytorch/pytorch/issues/157950#top) #157950.

The kernel is developed using the Composable Kernel framework. Only MI300X is currently supported. In the near future we plan to add support for MI350X as well. For data types we support FP8 e3m4.

The kernel support will be gated with the `USE_FBGEMM_GENAI` flag. We hope to enable this by default for relevant AMD builds.

Note we also update submodule `third_party/fbgemm` to 0adf62831 for the required updates from fbgemm.

Test Plan:

**Hipify & build**
```
python tools/amd_build/build_amd.py
USE_FBGEMM_GENAI=1 python setup.py develop
```

**Unit tests**
```
python test/test_matmul_cuda.py -- TestFP8MatmulCUDA
Ran 488 tests in 32.969s
OK (skipped=454)
```

**Performance Sample**
| G  | M | N | K | Runtime Ms | GB/S | TFLOPS |
| --  | -- | -- | -- | -- | -- | -- |
| 128 | 1 | 2048 | 5120 | 0.37| 3590 | 7.17 |
| 128 | 64 | 2048 | 5120 | 0.51| 2792 | 338.34 |
| 128 | 128 | 2048 | 5120 | 0.66| 2272 | 522.72 |
| 128 | 1 | 5120 | 1024 | 0.21| 3224 | 6.43 |
| 128 | 64 | 5120 | 1024 | 0.29| 2590 | 291.40 |
| 128 | 128 | 5120 | 1024 | 0.40| 2165 | 434.76 |
| 128 | 1 | 4096 | 4096 | 0.69| 3126 | 6.25 |
| 128 | 64 | 4096 | 4096 | 0.85| 2655 | 324.66 |
| 128 | 128 | 4096 | 4096 | 1.10| 2142 | 501.40 |
| 128 | 1 | 8192 | 8192 | 2.45| 3508 | 7.01 |
| 128 | 64 | 8192 | 8192 | 3.27| 2692 | 336.74 |
| 128 | 128 | 8192 | 8192 | 4.04| 2224 | 543.76 |
| 16 | 1 | 2048 | 5120 | 0.04| 3928 | 7.85 |
| 16 | 64 | 2048 | 5120 | 0.05| 3295 | 399.29 |
| 16 | 128 | 2048 | 5120 | 0.07| 2558 | 588.69 |
| 16 | 1 | 5120 | 1024 | 0.03| 3119 | 6.23 |
| 16 | 64 | 5120 | 1024 | 0.03| 2849 | 320.62 |
| 16 | 128 | 5120 | 1024 | 0.05| 2013 | 404.11 |
| 16 | 1 | 4096 | 4096 | 0.06| 4512 | 9.02 |
| 16 | 64 | 4096 | 4096 | 0.09| 3124 | 381.95 |
| 16 | 128 | 4096 | 4096 | 0.13| 2340 | 547.67 |
| 16 | 1 | 8192 | 8192 | 0.32| 3374 | 6.75 |
| 16 | 64 | 8192 | 8192 | 0.42| 2593 | 324.28 |
| 16 | 128 | 8192 | 8192 | 0.53| 2120 | 518.36 |

- Using ROCm 6.4.1
- Collected through `triton.testing.do_bench_cudagraph`

**Binary size with gfx942 arch**
Before: 116103856 Jul 23 14:12 build/lib/libtorch_hip.so
After:  118860960 Jul 23 14:29 build/lib/libtorch_hip.so
The difference is 2757104 bytes (~2.6 MiB).

Reviewers: @drisspg @ngimel @jwfromm @jeffdaily

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159075
Approved by: https://github.com/drisspg
This commit is contained in:
Chris Thi
2025-07-30 23:53:58 +00:00
committed by PyTorch MergeBot
parent 25c3a7e317
commit c400c8e2e0
9 changed files with 150 additions and 48 deletions

View File

@ -872,6 +872,14 @@ cmake_dependent_option(
"USE_CUDA OR USE_ROCM;NOT MSVC"
OFF)
cmake_dependent_option(
USE_FBGEMM_GENAI
"Whether to build FBGEMM GenAI quantized GEMM kernels.\
Will be disabled if not supported by the platform"
OFF
"USE_CUDA OR USE_ROCM"
OFF)
# CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem
# Eff Attention won't
cmake_dependent_option(
@ -905,6 +913,10 @@ if(USE_FBGEMM)
string(APPEND CMAKE_CXX_FLAGS " -DUSE_FBGEMM")
endif()
if(USE_FBGEMM_GENAI)
string(APPEND CMAKE_CXX_FLAGS " -DUSE_FBGEMM_GENAI")
endif()
if(USE_PYTORCH_QNNPACK)
string(APPEND CMAKE_CXX_FLAGS " -DUSE_PYTORCH_QNNPACK")
endif()

View File

@ -247,6 +247,50 @@ if(USE_MEM_EFF_ATTENTION)
list(APPEND ATen_ATTENTION_KERNEL_SRCS ${mem_eff_attention_cuda_kernels_cu})
endif()
IF(USE_FBGEMM_GENAI AND USE_ROCM AND NOT "gfx942" IN_LIST PYTORCH_ROCM_ARCH)
message(WARNING "Unsupported ROCM arch for FBGEMM GenAI, will set USE_FBGEMM_GENAI to OFF")
set(USE_FBGEMM_GENAI off)
endif()
# FBGEMM GenAI
IF(USE_FBGEMM_GENAI)
set(FBGEMM_THIRD_PARTY ${PROJECT_SOURCE_DIR}/third_party/fbgemm/external/)
set(FBGEMM_GENAI_DIR ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize)
if(USE_ROCM)
# Only include the kernels we want to build to avoid increasing binary size.
file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
"${FBGEMM_GENAI_DIR}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
"${FBGEMM_GENAI_DIR}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
# Add additional HIPCC compiler flags for performance
set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
-mllvm
-amdgpu-coerce-illegal-types=1
-mllvm
-enable-post-misched=0
-mllvm
-greedy-reverse-local-assignment=1
-fhip-new-launch-api)
hip_add_library(
fbgemm_genai STATIC
${fbgemm_genai_native_rocm_hip}
HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)
target_include_directories(fbgemm_genai PUBLIC
# FBGEMM version of Composable Kernel is used due to some customizations
${FBGEMM_THIRD_PARTY}/composable_kernel/include
${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
${FBGEMM_GENAI_DIR}/include/
${FBGEMM_GENAI_DIR}/common/include/
)
endif()
endif()
# XNNPACK
file(GLOB native_xnnpack "native/xnnpack/*.cpp")

View File

@ -21,6 +21,10 @@
#include <ATen/native/cuda/GroupMM.h>
#include <ATen/ceil_div.h>
#ifdef USE_FBGEMM_GENAI
#include <fbgemm_gpu/torch_ops.h>
#endif
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
@ -1216,7 +1220,7 @@ std::pair<ScalingType, ScalingType> get_joint_scaling(
// - `scale_a`: a tensor with the inverse scale of `mat1`, whose shape/strides/dtype depend on the scaling scheme
// - `scale_b`: a tensor with the inverse scale of `mat2`, whose shape/strides/dtype depend on the scaling scheme
// - `scale_result`: a scalar tensor with the scale of the output, only utilized if the output is a float8 type
// - `use_fast_accum`: if true, enables fast float8 accumulation
// - `use_fast_accum`: if true, enables fast float8 accumulation. Backends may ignore this option if not applicable.
// - `out`: a reference to the output tensor
Tensor&
@ -1525,6 +1529,7 @@ namespace {
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
#ifndef USE_ROCM
// For TMA transfers, strides of output tensor have to be either
// 1, or aligned to 16 bytes.
const auto last_dim = out_size.size() - 1;
@ -1536,9 +1541,10 @@ namespace {
} else {
out_stride = {out_size[1] * size_padded, size_padded, 1};
}
auto out = at::empty_strided(out_size, out_stride, mat_a.options().dtype(out_dtype_));
return out;
return at::empty_strided(out_size, out_stride, mat_a.options().dtype(out_dtype_));
#else
return at::empty(out_size, mat_a.options().dtype(out_dtype_));
#endif
}
bool check_valid_strides_and_return_transposed(const Tensor& mat) {
@ -1619,12 +1625,9 @@ const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum) {
#ifndef USE_ROCM
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true);
TORCH_CHECK(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = 9.0");
bool allowed_device = _scaled_mm_allowed_device();
TORCH_CHECK(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = 9.0, or ROCm MI300+");
TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type());
TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type());
TORCH_CHECK(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
TORCH_CHECK(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
TORCH_CHECK(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
@ -1664,6 +1667,10 @@ bool use_fast_accum) {
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype);
#ifndef USE_ROCM
TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type());
TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type());
at::cuda::detail::f8f8bf16_grouped_mm(
mat_a,
mat_b,
@ -1674,12 +1681,23 @@ bool use_fast_accum) {
use_fast_accum,
out);
return out;
#else
TORCH_CHECK(false, "grouped gemm is not supported on ROCM")
#ifdef USE_FBGEMM_GENAI
TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_a.scalar_type());
TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_b.scalar_type());
fbgemm_gpu::f8f8bf16_rowwise_grouped_mm(
mat_a,
// FBGEMM expects B matrix shape to be (.., N, K)
mat_b.transpose(-2, -1),
scale_a,
scale_b,
offs,
out);
return out;
#else
TORCH_CHECK(false, "grouped gemm is not supported without USE_FBGEMM_GENAI on ROCM")
#endif
#endif
}

View File

@ -1771,6 +1771,10 @@ if(USE_ROCM)
target_link_libraries(torch_hip PUBLIC torch_cpu_library ${Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS})
target_link_libraries(torch_hip PRIVATE ${Caffe2_HIP_DEPENDENCY_LIBS})
if(USE_FBGEMM_GENAI)
target_link_libraries(torch_hip PRIVATE fbgemm_genai)
endif()
# Since PyTorch files contain HIP headers, this is also needed to capture the includes.
# ROCM_INCLUDE_DIRS is defined in LoadHIP.cmake
target_include_directories(torch_hip PRIVATE ${Caffe2_HIP_INCLUDE} ${ROCM_INCLUDE_DIRS})

View File

@ -58,6 +58,9 @@
# USE_FBGEMM=0
# disables the FBGEMM build
#
# USE_FBGEMM_GENAI=1
# enables the FBGEMM GenAI kernels to build
#
# USE_KINETO=0
# disables usage of libkineto library for profiling
#

View File

@ -27,6 +27,7 @@ from torch.testing._internal.common_cuda import (
xfailIfSM120OrLater,
_get_torch_cuda_version,
PLATFORM_SUPPORTS_FP8,
PLATFORM_SUPPORTS_FP8_GROUPED_GEMM,
PLATFORM_SUPPORTS_MX_GEMM,
IS_SM90,
)
@ -768,6 +769,7 @@ class TestMatmulCuda(TestCase):
torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum
f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
f8_grouped_msg = "FP8 grouped is only supported on SM90 and MI300+ devices"
mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+"
# avoid division by zero when calculating scale
@ -1845,17 +1847,16 @@ class TestFP8Matmul(TestCase):
# _scaled_mm() already has more combinations of parameters than
# _scaled_grouped_mm(), for supporting more than one inputs layout
# combinations.
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM100OrLater
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8_GROUPED_GEMM, f8_grouped_msg)
@parametrize("fast_accum", [False, True])
@parametrize("strided", [False, True])
# AMD does not support non-contiguous inputs yet
@parametrize("strided", [False] + ([True] if torch.version.cuda else []))
def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided):
device = "cuda"
fp8_dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn
m, n, k, n_groups = 16, 32, 64, 4
a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(fp8_dtype)[:, :k * n_groups]
b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(fp8_dtype)[:, :k * n_groups]
scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32)
scale_b = torch.rand(n * n_groups, device=device, dtype=torch.float32)
offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)
@ -1874,17 +1875,17 @@ class TestFP8Matmul(TestCase):
self.scaled_grouped_mm_helper(alist, blist, ascalelist, bscalelist, out, fast_accum)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM100OrLater
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8_GROUPED_GEMM, f8_grouped_msg)
@parametrize("fast_accum", [False, True])
@parametrize("strided", [False, True])
# AMD does not support non-contiguous inputs yet
@parametrize("strided", [False] + ([True] if torch.version.cuda else []))
def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided):
device = "cuda"
fp8_dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn
m, n, k, n_groups = 16, 32, 64, 4
s_int = int(strided)
a = torch.randn(m * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
a = torch.randn(m * n_groups, k * (1 + s_int), device=device).to(fp8_dtype)[:, :k]
b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(fp8_dtype)[::(1 + s_int), :, :k]
self.assertTrue(a.is_contiguous() is not strided)
self.assertTrue(b.is_contiguous() is not strided)
for check_zero_size in (True, False):
@ -1896,7 +1897,6 @@ class TestFP8Matmul(TestCase):
offs[0] = offs[1]
scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32)
scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
f = torch._scaled_grouped_mm
out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
@ -1912,17 +1912,17 @@ class TestFP8Matmul(TestCase):
self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM100OrLater
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8_GROUPED_GEMM, f8_grouped_msg)
@parametrize("fast_accum", [False, True])
@parametrize("strided", [False, True])
# AMD does not support non-contiguous inputs yet
@parametrize("strided", [False] + ([True] if torch.version.cuda else []))
def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided):
device = "cuda"
fp8_dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn
m, n, k, n_groups = 16, 32, 64, 4
s_int = int(strided)
a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(fp8_dtype)[::(1 + s_int), :, :k]
b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(fp8_dtype)[::(1 + s_int), :, :k]
self.assertTrue(a.is_contiguous() is not strided)
self.assertTrue(b.is_contiguous() is not strided)
scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
@ -1935,17 +1935,17 @@ class TestFP8Matmul(TestCase):
self.scaled_grouped_mm_helper(a, b, scale_a, scale_b, out, fast_accum)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM100OrLater
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8_GROUPED_GEMM, f8_grouped_msg)
@parametrize("fast_accum", [False, True])
@parametrize("strided", [False, True])
# AMD does not support non-contiguous inputs yet
@parametrize("strided", [False] + ([True] if torch.version.cuda else []))
def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided):
device = "cuda"
fp8_dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn
m, n, k, n_groups = 16, 32, 64, 4
s_int = int(strided)
a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
b = torch.randn(n * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(fp8_dtype)[::(1 + s_int), :, :k]
b = torch.randn(n * n_groups, k * (1 + s_int), device=device).to(fp8_dtype)[:, :k]
self.assertTrue(a.is_contiguous() is not strided)
self.assertTrue(b.is_contiguous() is not strided)
scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)

View File

@ -7313,13 +7313,18 @@ def _create_grouped_mm_output_tensor(mat1, mat2, offs, out_dtype):
out_dtype = out_dtype or mat1.dtype
if torch.version.cuda:
alignment = 16 // out_dtype.itemsize
size_padded = (out_size[-1] + alignment - 1) // alignment * alignment
if mat1_is_2d == mat2_is_2d:
out_stride = [out_size[1] * size_padded, size_padded, 1]
else:
out_stride = [size_padded, 1]
out = torch.empty_strided(out_size, out_stride, dtype=out_dtype, device=mat1.device)
out = torch.empty_strided(
out_size, out_stride, dtype=out_dtype, device=mat1.device
)
else:
out = torch.empty(out_size, dtype=out_dtype, device=mat1.device)
return out
@ -7345,8 +7350,9 @@ def _meta_grouped_mm_common(
# aten/src/ATen/native/cuda/Blas.cpp.
if scaled:
fp8_dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn
torch._check(
mat_a.dtype == torch.float8_e4m3fn and mat_b.dtype == torch.float8_e4m3fn,
mat_a.dtype == fp8_dtype and mat_b.dtype == fp8_dtype,
lambda: f"Expected inputs of E4M3 FP8 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.",
)
else:

View File

@ -107,8 +107,23 @@ def evaluate_platform_supports_fp8():
return SM90OrLater or torch.cuda.get_device_capability() == (8, 9)
return False
def evaluate_platform_supports_fp8_grouped_gemm():
if torch.cuda.is_available():
if torch.version.hip:
if "USE_FBGEMM_GENAI" not in torch.__config__.show():
return False
archs = ['gfx942']
for arch in archs:
if arch in torch.cuda.get_device_properties(0).gcnArchName:
return True
else:
return SM90OrLater and not SM100OrLater
return False
PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8())
PLATFORM_SUPPORTS_FP8_GROUPED_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_fp8_grouped_gemm())
PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: TEST_CUDA and SM100OrLater)
if TEST_NUMBA: