AMD/ROCm OCP Micro-scaling Format (mx-fp8/mx-fp4) Support (#151360)

- This pull request introduces support for the [OCP Micro-scaling (MX) format](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf), with a focus on compatibility with AMD **ROCm 7.0** and the **gfx950** architecture.

  This PR also establishes the foundation for enabling MX-FPX features in [TorchAO](https://github.com/pytorch/ao/issues/2229) on the AMD platform.

- Validation (**ROCm 7.0** + **gfx950** required):

  `111 relevant tests passing.`

  > PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k test_blockwise -v

  Co-author: @jagadish-amd. Thank you for leading the validation efforts on gfx950 with ROCm 7.0.
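
As background (not part of this PR's diff), here is a minimal illustrative sketch of how an MX block scale is derived per section 6.3 of the OCP spec, mirroring the `data_to_mx_scale` test helper updated below: each contiguous block of 32 elements shares one power-of-two scale stored as `torch.float8_e8m0fnu`.

```python
import torch

def mx_scale_e8m0(x: torch.Tensor, block_size: int = 32) -> torch.Tensor:
    # Illustrative only; edge cases (zeros, NaN) are not handled, matching
    # the caveat in the test helper. Assumes x.numel() % block_size == 0.
    max_abs = x.float().abs().reshape(-1, block_size).amax(dim=1)
    # Shared exponent: floor(log2(max_abs)) minus the element type's largest
    # representable power of 2 (8 for float8_e4m3fn, i.e. the mxfp8 recipe).
    exp = torch.floor(torch.log2(max_abs)) - 8
    exp = torch.clamp(exp, min=-127, max=127)   # e8m0 exponent bias is 127
    biased = (exp + 127).to(torch.uint8)        # store the biased exponent
    return biased.view(torch.float8_e8m0fnu)    # one scale per 1x32 block
```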

-----------------------------------

This pull request introduces support for new scalar types and scaling methods, particularly for ROCm 7.0 and gfx950, and refines testing for these features. Key changes include adding constraints for matrix dimensions, enabling block-wise scaling, and updating tests to accommodate new data types.

### Support for new scalar types and scaling methods:
* [`aten/src/ATen/cuda/CUDABlas.cpp`](diffhunk://#diff-74fcb26047c1df4024105d36ce22a36b77cf8cc93c28631d743e639b3d6066aeR1876-R1885): Added constraints for matrix dimensions when using `Float8_e8m0fnu` with block-wise scaling, ensuring dimensions are multiples of 32. Updated compatibility checks to support ROCm 7.0 for `Float8_e8m0fnu` and `Float8_e4m3fn`. [[1]](diffhunk://#diff-74fcb26047c1df4024105d36ce22a36b77cf8cc93c28631d743e639b3d6066aeR1876-R1885) [[2]](diffhunk://#diff-74fcb26047c1df4024105d36ce22a36b77cf8cc93c28631d743e639b3d6066aeL1913-R1934)

* [`aten/src/ATen/native/cuda/Blas.cpp`](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abR1276-R1290): Introduced block-wise scaling for `Float8_e8m0fnu`, with checks for ROCm 7.0 and GPU architecture `gfx950`. Added validation for supported scalar types and matrix dimensions. [[1]](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abR1276-R1290) [[2]](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abR1349-R1364)

### Updates to scalar type mappings:
* [`aten/src/ATen/cuda/CUDADataType.h`](diffhunk://#diff-9188bb13b1a49f459141f5f9b875593d1c5ce2beb5ad711fdbaf5bc7089ec015L93-R93): Extended scalar type mappings to support `Float4_e2m1fn_x2` for ROCm 7.0.

* [`aten/src/ATen/cuda/tunable/GemmHipblaslt.h`](diffhunk://#diff-bfa1a3b5d4bef1892bf50338775f3b0fd8cd31fc1868148f3968b98aefb68e3fR88-R96): Added a constexpr mapping for `Float4_e2m1fn_x2` based on ROCm version.
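
The `_x2` suffix indicates that `Float4_e2m1fn_x2` packs two 4-bit e2m1 values per byte. A small illustrative sketch (the `0b01000100` bit pattern, two packed 2.0 values, is the same one the tests below poke into `A` and `B`):

```python
import torch

# A logical 32x64 fp4 matrix occupies 32x32 bytes; view the packed bytes
# as the x2 dtype. Each nibble 0b0100 is the e2m1 encoding of 2.0.
packed = torch.full((32, 32), 0b01000100, dtype=torch.uint8)
fp4 = packed.view(torch.float4_e2m1fn_x2)
print(fp4.shape)  # torch.Size([32, 32]); each element holds two fp4 values
```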

### Enhancements to testing (@jagadish-amd):
* [`test/test_matmul_cuda.py`](diffhunk://#diff-3f31c52b48cfddf8f4617d809f7695b2e4a1c78656f8c4b5143a4b45d01fcf23R765-R766): Updated tests to include new scalar types (`Float4_e2m1fn_x2`) and recipes (`mxfp4`). Added logic to handle different scaling recipes and validate compatibility with ROCm and CUDA versions. [[1]](diffhunk://#diff-3f31c52b48cfddf8f4617d809f7695b2e4a1c78656f8c4b5143a4b45d01fcf23R765-R766) [[2]](diffhunk://#diff-3f31c52b48cfddf8f4617d809f7695b2e4a1c78656f8c4b5143a4b45d01fcf23L1331-R1356)

These changes improve compatibility with newer hardware and software versions, enhance functionality for matrix operations, and ensure robust testing for the added features.
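
As a rough illustration of the user-facing call these changes enable, a minimal mxfp8 sketch, assuming a gfx950 GPU with ROCm 7.0 (or an SM100+ GPU with CUDA 12.8+); shapes are arbitrary but kept at multiples of 32 per the new ROCm constraint:

```python
import torch

M, K, N = 128, 128, 128  # multiples of 32, as the new ROCm checks require
A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
# One float8_e8m0fnu scale per 1x32 block along K (all-ones for simplicity).
A_scale = torch.full((M, K // 32), 1.0, device="cuda", dtype=torch.float8_e8m0fnu)
B_scale = torch.full((N, K // 32), 1.0, device="cuda", dtype=torch.float8_e8m0fnu)
# On CUDA the scales are first swizzled via to_blocked(); per this PR, ROCm
# consumes them unswizzled (see the test_matmul_cuda.py change below).
out = torch._scaled_mm(A, B.t(), A_scale, B_scale, out_dtype=torch.bfloat16)
```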

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151360
Approved by: https://github.com/drisspg, https://github.com/malfet
Commit e389a08dcd (parent f2be3dc8da), authored by Peter Y. Yeh on 2025-08-18 16:43:09 +00:00 and committed by PyTorch MergeBot.
8 changed files with 156 additions and 63 deletions

aten/src/ATen/cuda/CUDABlas.cpp

@@ -1847,8 +1847,12 @@ int get_scale_mode(ScalingType scaling_type, ScalarType scale_dtype, bool use_fast_accum)
switch (scaling_type) {
case ScalingType::BlockWise1x32:
TORCH_CHECK(scale_dtype == kFloat8_e8m0fnu);
#if CUDA_VERSION >= 12080
#if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && ROCM_VERSION >= 70000)
#ifdef USE_ROCM
return HIPBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
#else
return CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
#endif // USE_ROCM
#else
TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales of 1x32 blocks is only supported for CUDA 12.8 and above");
#endif // if CUDA_VERSION >= 12080
@@ -1946,12 +1950,26 @@ void scaled_gemm(
// hipblaslt supported row-wise before cublas, and did so their own way (via
// the SCALE_POINTERSs), but then migrated to match how cublas does it (via
// the SCALE_MODEs). Here we check for this early custom mode.
bool use_rowwise = (mat1_scaling_type == ScalingType::RowWise && mat2_scaling_type == ScalingType::RowWise);
#if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
if (mat1_scaling_type == ScalingType::RowWise && mat2_scaling_type == ScalingType::RowWise) {
if (use_rowwise) {
matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
}
#endif // if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
else if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
#if ROCM_VERSION >= 70000
if (at::detail::getCUDAHooks().isGPUArch({"gfx950"})) {
// TODO: add constraints based on hipblaslt internals
TORCH_CHECK((m % 32 == 0) && (n % 32 == 0) && (k % 32 == 0),
"Matrix dimensions must be multiples of 32 for MX format. "
"Got m=", m, ", n=", n, ", k=", k);
}
#endif
}
#else
// rowwise isn't supported using cublaslt or older hipblaslt
TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
#endif // if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
computeDesc.setAttribute(matmulDescA, mat1_scale_ptr);
computeDesc.setAttribute(matmulDescB, mat2_scale_ptr);
if (result_scale_ptr != nullptr) {
@@ -1990,15 +2008,16 @@ void scaled_gemm(
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
}
// The SCALE_MODE attrs only exist in cuBLAS 12.8+ or in recent hipblaslt,
// but we must invoke get_scale_mode anyways to trigger the version checks.
[[maybe_unused]] int a_scale_mode = get_scale_mode(mat1_scaling_type, mat1_scale_dtype, use_fast_accum);
[[maybe_unused]] int b_scale_mode = get_scale_mode(mat2_scaling_type, mat2_scale_dtype, use_fast_accum);
#if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC))
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, a_scale_mode);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, b_scale_mode);
#endif
// For other data types, use the get_scale_mode function based on scaling type
// The SCALE_MODE attrs only exist in cuBLAS 12.8+/ROCm 7.0 or in recent hipblaslt,
// but we must invoke get_scale_mode anyways to trigger the version checks.
// Note that AMD/ROCm follows OCP Spec 1.0, which is different from NVIDIA's implementation. See get_scale_mode() for details.
[[maybe_unused]] int a_scale_mode = get_scale_mode(mat1_scaling_type, mat1_scale_dtype, use_fast_accum);
[[maybe_unused]] int b_scale_mode = get_scale_mode(mat2_scaling_type, mat2_scale_dtype, use_fast_accum);
#if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && ROCM_VERSION >= 70000 && defined(HIPBLASLT_OUTER_VEC))
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, a_scale_mode);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, b_scale_mode);
#endif // if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && ROCM_VERSION >= 70000 && defined(HIPBLASLT_OUTER_VEC))
CuBlasLtMatmulPreference preference;
auto ltworkspace = CublasLtWorkspace();

aten/src/ATen/cuda/CUDADataType.h

@@ -90,7 +90,7 @@ inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type)
case c10::ScalarType::Float8_e5m2fnuz:
return HIP_R_8F_E5M2_FNUZ;
#endif
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12080)
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12080) || (defined(USE_ROCM) && ROCM_VERSION >= 70000)
case c10::ScalarType::Float4_e2m1fn_x2:
return CUDA_R_4F_E2M1;
#endif

aten/src/ATen/cuda/tunable/GemmHipblaslt.h

@@ -85,6 +85,15 @@ constexpr hipDataType HipDataTypeFor<c10::Float8_e8m0fnu>() {
return static_cast<hipDataType>(500);
}
template <>
constexpr hipDataType HipDataTypeFor<c10::Float4_e2m1fn_x2>() {
#if ROCM_VERSION >= 70000
return HIP_R_4F_E2M1;
#else
return static_cast<hipDataType>(33);
#endif
}
template <typename T>
int GetBatchFromParams(const GemmParams<T>* params) {
return 1;

aten/src/ATen/native/cuda/Blas.cpp

@@ -1283,15 +1283,35 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
if (use_fast_accum) {
TORCH_CHECK(mat1.scalar_type() != ScalarType::Float4_e2m1fn_x2 && mat2.scalar_type() != ScalarType::Float4_e2m1fn_x2, "`use_fast_accum` is not supported when `mat1` or `mat2` tensors have the `Float4_e2m1fn_x2` dtype.");
}
#ifdef USE_ROCM
if (mat1.scalar_type() == ScalarType::Float4_e2m1fn_x2 || mat2.scalar_type() == ScalarType::Float4_e2m1fn_x2) {
TORCH_CHECK(ROCM_VERSION >= 70000, "Float4_e2m1fn_x2 is only supported for ROCm 7.0 and above");
}
if (mat1.scalar_type() == ScalarType::Float8_e5m2 || mat2.scalar_type() == ScalarType::Float8_e5m2) {
TORCH_CHECK(ROCM_VERSION >= 60500, "Float8_e5m2 is only supported for ROCm 6.5 and above");
}
if (mat1.scalar_type() == ScalarType::Float8_e4m3fn || mat2.scalar_type() == ScalarType::Float8_e4m3fn) {
TORCH_CHECK(ROCM_VERSION >= 60500, "Float8_e4m3fn is only supported for ROCm 6.5 and above");
}
#endif
if (bias) {
TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to Float32");
TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half,
"Bias must be either Half or BFloat16, but got ", bias->scalar_type());
TORCH_CHECK((out.scalar_type() != kFloat && out.scalar_type() != ScalarType::BFloat16) ||
bias->scalar_type() == ScalarType::BFloat16,
"Bias must be BFloat16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type());
TORCH_CHECK(out.scalar_type() != ScalarType::Half || bias->scalar_type() == ScalarType::Half,
"Bias must be Float16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type());
TORCH_CHECK(out.scalar_type() != kFloat,
"Bias is not supported when out_dtype is set to Float32");
TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 ||
bias->scalar_type() == ScalarType::Half,
"Bias must be BFloat16 or Half, but got ", bias->scalar_type());
TORCH_CHECK((out.scalar_type() != kFloat &&
out.scalar_type() != ScalarType::BFloat16) ||
bias->scalar_type() == ScalarType::BFloat16,
"Bias must be BFloat16 to compute ", out.scalar_type(),
" output, but got ", bias->scalar_type());
TORCH_CHECK(out.scalar_type() != ScalarType::Half ||
bias->scalar_type() == ScalarType::Half,
"Bias must be Float16 to compute ", out.scalar_type(),
" output, but got ", bias->scalar_type());
}
{
auto bias_ = bias.value_or(Tensor());
@@ -1353,6 +1373,22 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16,
"hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type());
}
else if (scaling_choice_a == ScalingType::BlockWise1x32 && scaling_choice_b == ScalingType::BlockWise1x32) {
#if ROCM_VERSION >= 70000
TORCH_CHECK(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
"Block-wise scaling for Float8_e8m0fnu is only supported on gfx950");
TORCH_CHECK(mat1.size(0) % 32 == 0 && mat1.size(1) % 32 == 0 &&
mat2.size(0) % 32 == 0 && mat2.size(1) % 32 == 0,
"Matrix dimensions must be multiples of 32 for block-wise scaling");
TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16 ||
out.scalar_type() == ScalarType::Half,
"Block-wise scaling only supports BFloat16 or Half output types");
#else
TORCH_CHECK(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
#endif
}
#endif
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result, scaling_choice_a, scaling_choice_b);
@@ -1430,12 +1466,14 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
params.k = args.k;
params.a = args.mata->data_ptr();
params.a_scale_ptr = args.scale_mata_ptr;
params.a_scale_dtype = args.scale_mata_dtype.value();
params.lda = args.lda;
params.a_dtype = args.mata->scalar_type();
params.a_scale_dtype = args.scale_mata_dtype.value();
params.a_scaling_type = args.scaling_mata_type.value();
params.b = args.matb->data_ptr();
params.b_scale_ptr = args.scale_matb_ptr;
params.b_scale_dtype = args.scale_matb_dtype.value();
params.ldb = args.ldb;
params.b_dtype = args.matb->scalar_type();
params.b_scale_dtype = args.scale_matb_dtype.value();

test/dynamo/test_repros.py

@@ -66,6 +66,7 @@ from torch.testing._internal.common_utils import (
parametrize,
serialTest,
skipIfHpu,
skipIfRocm,
skipIfWindows,
TEST_WITH_ROCM,
)
@@ -7405,6 +7406,7 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
out = f_compiled(x, s0, s1, s2)
self.assertEqual(out_ref, out)
@skipIfRocm
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "requires gpu with fp8 support")
@requires_cuda
def test_partitioner_saves_weights_for_bw(self):

test/test_matmul_cuda.py

@@ -918,6 +918,8 @@ def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# largest power of 2 representable in `torch.float8_e4m3fn`
F8E4M3_LARGEST_POW2 = 8
# largest power of 2 representable in `torch.float4_e2m1fn_x2`
FP4E2M1FN_LARGEST_POW2 = 1.0
# max value of `torch.float8_e4m3fn` (448)
F8E4M3_MAX_VAL = torch.finfo(torch.float8_e4m3fn).max
# exponent bias of `torch.float8_e8m0fnu`
@@ -926,14 +928,20 @@ F8E8M0_EXP_BIAS = 127
FP4_EBITS, FP4_MBITS = 2, 1
FP4_MAX_VAL = 6.0
def data_to_mx_scale(x, block_size):
def data_to_mx_scale(x, block_size, recipe):
# simple implementation of https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
# section 6.3, not all edge cases (such as NaN) are handled/tested
if recipe == "mxfp8":
largest_pow2 = F8E4M3_LARGEST_POW2
elif recipe == "mxfp4":
largest_pow2 = FP4E2M1FN_LARGEST_POW2
else:
raise ValueError(f"data_to_mx_scale(): Unsupported mx recipe: {recipe}")
orig_shape = x.shape
x = x.reshape(-1, block_size)
max_abs = torch.amax(torch.abs(x), 1)
largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs))
scale_e8m0_unbiased = largest_p2_lt_max_abs - F8E4M3_LARGEST_POW2
scale_e8m0_unbiased = largest_p2_lt_max_abs - largest_pow2
scale_e8m0_unbiased = torch.clamp(scale_e8m0_unbiased, -1 * F8E8M0_EXP_BIAS, F8E8M0_EXP_BIAS)
scale_e8m0_biased = scale_e8m0_unbiased + F8E8M0_EXP_BIAS
scale_e8m0_biased = scale_e8m0_biased.to(torch.uint8)
@@ -1415,6 +1423,7 @@ class TestFP8Matmul(TestCase):
self.assertGreaterEqual(float(cosine_sim), 0.999)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@unittest.skipIf(torch.version.hip is not None, "Float8_e4m3fn not supported on current ROCm CI setup (MI325X)")
@parametrize("which_dim_zero", [0, 1, 2])
@parametrize("use_torch_compile", [False, True])
def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None:
@@ -1553,23 +1562,24 @@ class TestFP8Matmul(TestCase):
(127, 96, 1024),
(1025, 128, 96)
], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
@parametrize("recipe", ["mxfp8", "nvfp4"])
def test_blockwise_mxfp8_nvfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
if recipe == "nvfp4" and fast_accum:
return unittest.skip("fast_accum not supported in nvfp4 cublas gemm, skipping")
@parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
return unittest.skip("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
device = "cuda"
M, K, N = mkn
if recipe == "nvfp4" and K % 32 != 0:
return unittest.skip("K must be divisible by 32 for nvfp4 cublas gemm, skipping")
if (recipe == "nvfp4" or recipe == "mxfp4") and K % 32 != 0:
return unittest.skip("K must be divisible by 32 for nvfp4/mxfp4 cublas gemm, skipping")
BLOCK_SIZE = 16 if recipe == "nvfp4" else 32
fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32)
require_exact_match = True
approx_match_sqnr_target = 22.0
if test_case_name == "a_eye_b_eye":
if not ((M == K) and (M == N)):
return unittest.skip("this test is only defined for M == K == N, skipping")
raise unittest.SkipTest("this test is only defined for M == K == N, skipping")
A_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
B_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
@@ -1578,11 +1588,11 @@ class TestFP8Matmul(TestCase):
B = B_ref.to(torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
else: # nvfp4
else: # nvfp4 # mxfp4
A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
elif test_case_name == "a_ones_b_ones":
A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -1593,11 +1603,11 @@ class TestFP8Matmul(TestCase):
B = B_ref.to(torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
else: # nvfp4
else: # nvfp4 # mxfp4
A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
elif test_case_name == "a_ones_modified_b_ones":
A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -1609,11 +1619,11 @@ class TestFP8Matmul(TestCase):
B = B_ref.to(torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
else: # nvfp4
else: # nvfp4 # mxfp4
A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
elif test_case_name == "a_ones_b_ones_modified":
A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -1625,11 +1635,11 @@ class TestFP8Matmul(TestCase):
B = B_ref.to(torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
else: # nvfp4
else: # nvfp4 # mxfp4
A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
elif test_case_name == "a_scale_modified_b_ones":
A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
@@ -1643,11 +1653,11 @@ class TestFP8Matmul(TestCase):
A_ref[1][0:BLOCK_SIZE] = 4
A[1][0:BLOCK_SIZE] = 2
A_scale[1][0] = 2
else: # nvfp4
else: # nvfp4 # mxfp4
A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
A_ref[1][0:BLOCK_SIZE] = 4
A.view(torch.uint8)[1][0:(BLOCK_SIZE // 2)] = 0b01000100
A_scale[1][0] = 2
@@ -1664,11 +1674,11 @@ class TestFP8Matmul(TestCase):
B_ref[1][0:BLOCK_SIZE] = 4
B[1][0:BLOCK_SIZE] = 2
B_scale[1][0] = 2
else: # nvfp4
else: # nvfp4 # mxfp4
A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
B_ref[1][0:BLOCK_SIZE] = 4
B.view(torch.uint8)[1][0:(BLOCK_SIZE // 2)] = 0b01000100
B_scale[1][0] = 2
@@ -1688,7 +1698,7 @@ class TestFP8Matmul(TestCase):
B = B_ref.to(torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
else: # nvfp4
else: # nvfp4 # mxfp4
# scales all-ones, element data random while being exactly representable in float4_e2m1fn_x2
# generate integers in [0, 16] and cast to bfloat16
A_ref = _floatx_unpacked_to_f32(
@@ -1703,8 +1713,8 @@ class TestFP8Matmul(TestCase):
).bfloat16()
A = _bfloat16_to_float4_e2m1fn_x2(A_ref)
B = _bfloat16_to_float4_e2m1fn_x2(B_ref)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn)
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
elif test_case_name == "data_random_scales_from_data":
if not K % BLOCK_SIZE == 0:
@@ -1716,17 +1726,18 @@ class TestFP8Matmul(TestCase):
if recipe == "mxfp8":
# Calculate scales based on the inputs
A_scale = data_to_mx_scale(A_ref, BLOCK_SIZE)
B_scale = data_to_mx_scale(B_ref, BLOCK_SIZE)
A_scale = data_to_mx_scale(A_ref, BLOCK_SIZE, recipe)
B_scale = data_to_mx_scale(B_ref, BLOCK_SIZE, recipe)
max_val = F8E4M3_MAX_VAL
min_val = -1 * max_val
A = (A_ref.reshape(-1, BLOCK_SIZE) / A_scale.reshape(M * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(M, K)
A = A.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
B = (B_ref.reshape(-1, BLOCK_SIZE) / B_scale.reshape(N * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(N, K)
B = B.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
else: # nvfp4
A_scale = data_to_nvfp4_scale(A_ref, BLOCK_SIZE)
B_scale = data_to_nvfp4_scale(B_ref, BLOCK_SIZE)
else: # nvfp4 # mxfp4
scale_func = data_to_mx_scale if recipe == "mxfp4" else data_to_nvfp4_scale
A_scale = scale_func(A_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None)
B_scale = scale_func(B_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None)
max_val = FP4_MAX_VAL
min_val = -1 * max_val
@@ -1737,13 +1748,14 @@ class TestFP8Matmul(TestCase):
B = B.clamp(min=min_val, max=max_val)
B = _bfloat16_to_float4_e2m1fn_x2(B)
approx_match_sqnr_target = 15.8
approx_match_sqnr_target = 12.0 if torch.version.hip else 15.8
C_ref = A_ref @ B_ref.t()
# convert to swizzled format
A_scale = to_blocked(A_scale)
B_scale = to_blocked(B_scale)
if not torch.version.hip:
A_scale = to_blocked(A_scale)
B_scale = to_blocked(B_scale)
C = torch._scaled_mm(
A,

torch/testing/_internal/common_cuda.py

@@ -120,12 +120,20 @@ def evaluate_platform_supports_fp8_grouped_gemm():
return SM90OrLater and not SM100OrLater
return False
def evaluate_platform_supports_mx_gemm():
if torch.cuda.is_available():
if torch.version.hip:
ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2])
if ROCM_VERSION >= (7, 0):
return 'gfx950' in torch.cuda.get_device_properties(0).gcnArchName
else:
return SM100OrLater
return False
PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_mx_gemm())
PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8())
PLATFORM_SUPPORTS_FP8_GROUPED_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_fp8_grouped_gemm())
PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: TEST_CUDA and SM100OrLater)
if TEST_NUMBA:
try:
import numba.cuda

torch/utils/hipify/cuda_to_hip_mappings.py

@@ -3999,6 +3999,7 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
("CUDA_C_64U", ("HIP_C_64U", CONV_TYPE, API_RUNTIME)),
("CUDA_R_8F_E4M3", ("HIP_R_8F_E4M3", CONV_TYPE, API_RUNTIME)),
("CUDA_R_8F_E5M2", ("HIP_R_8F_E5M2", CONV_TYPE, API_RUNTIME)),
("CUDA_R_4F_E2M1", ("HIP_R_4F_E2M1", CONV_TYPE, API_RUNTIME)),
(
"MAJOR_VERSION",
("hipLibraryMajorVersion", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED),
@@ -7693,6 +7694,10 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
("CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F", ("HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F", CONV_MATH_FUNC, API_BLAS)),
("CUBLASLT_MATMUL_DESC_AMAX_D_POINTER", ("HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER", CONV_MATH_FUNC, API_BLAS)),
("CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", ("HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", CONV_MATH_FUNC, API_BLAS)),
("CUBLASLT_MATMUL_DESC_A_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_A_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)),
("CUBLASLT_MATMUL_DESC_B_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_B_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)),
("CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0", ("HIPBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0", CONV_MATH_FUNC, API_BLAS)),
("CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3", ("HIPBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3", CONV_MATH_FUNC, API_BLAS)),
("cublasLtMatrixLayout_t", ("hipblasLtMatrixLayout_t", CONV_MATH_FUNC, API_BLAS)),
("cublasLtMatrixLayoutOpaque_t", ("hipblasLtMatrixLayoutOpaque_t", CONV_MATH_FUNC, API_BLAS)),
("cublasLtMatrixLayoutAttribute_t", ("hipblasLtMatrixLayoutAttribute_t", CONV_MATH_FUNC, API_BLAS)),