From 264e7f68a095b5b2f10e7ca66a3bbb16f6e56e28 Mon Sep 17 00:00:00 2001 From: Jagadish Krishnamoorthy Date: Fri, 19 Sep 2025 12:29:52 +0000 Subject: [PATCH] [ROCm] Fix mx fp8 and fp4 code after scaling refactor changes. (#163127) PR #151360 added mx fp8 and fp4 support on ROCm. 1. However, on recent upstream, scaling function in Blas.cpp along with test_matmul_cuda changes triggered failures. This patch corrects is_blockwise_1x32_scaling function code. 2. Fixes the m, n, k dimensions for ROCm mx case. 3. Modify FP4E2M1FN_LARGEST_POW2 (largest power of 2 representable in `torch.float4_e2m1fn_x2`) to 2. This resulted in higher SQNR value for mx fp4 test. Testing result on gfx950 w/ ROCm7.0 PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k test_blockwise -v Ran 452 tests in 22.698s OK passed 111 This is same as before. (when PR 151360 was merged) Pull Request resolved: https://github.com/pytorch/pytorch/pull/163127 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily --- aten/src/ATen/cuda/CUDABlas.cpp | 4 ++-- aten/src/ATen/native/cuda/Blas.cpp | 23 +++++++++++++++++------ test/test_matmul_cuda.py | 21 ++++++++++++++------- 3 files changed, 33 insertions(+), 15 deletions(-) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index a81d34df4d64..c6d307481b74 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -1954,8 +1954,8 @@ void scaled_gemm( #if ROCM_VERSION >= 70000 if (at::detail::getCUDAHooks().isGPUArch({"gfx950"})) { // TODO: add constraints based on hipblaslt internals - TORCH_CHECK((m % 32 == 0) && (n % 32 == 0) && (k % 32 == 0), - "Matrix dimensions must be multiples of 32 for MX format. " + TORCH_CHECK((m % 16 == 0) && (n % 16 == 0) && (k % 128 == 0), + "M, N must be multiples of 16 and K should be multiple of 128 for MX format. " "Got m=", m, ", n=", n, ", k=", k); } #endif diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 1dab8c19c700..6dea2fe647b2 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -1138,9 +1138,14 @@ bool is_blockwise_1x16_scaling(const at::Tensor& t, const at::Tensor& scale) { bool is_blockwise_1x32_scaling(const at::Tensor& t, const at::Tensor& scale) { // TODO: We might want to enforce some structure on the shapes of the scale // tensors - return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat8_e8m0fnu - && scale.numel() == round_up(t.size(0), 128) * round_up(ceil_div(t.size(1), 32), 4) - && scale.is_contiguous()); + bool is_fp8_path = (isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat8_e8m0fnu + && scale.numel() == round_up(t.size(0), 128) * round_up(ceil_div(t.size(1), 32), 4)); + bool is_packed_fp4_path = false; +#ifdef USE_ROCM + is_packed_fp4_path = (t.scalar_type() == ScalarType::Float4_e2m1fn_x2 && scale.scalar_type() == at::kFloat8_e8m0fnu + && scale.numel() == round_up(t.size(0), 128) * round_up(ceil_div(t.size(1) * 2, 32), 4)); +#endif + return (is_fp8_path || is_packed_fp4_path) && scale.is_contiguous(); } bool is_blockwise_1x128_scaling(const at::Tensor& t, const at::Tensor& scale) { @@ -1381,9 +1386,15 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, TORCH_CHECK(at::detail::getCUDAHooks().isGPUArch({"gfx950"}), "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950"); - TORCH_CHECK(mat1.size(0) % 32 == 0 && mat1.size(1) % 32 == 0 && - mat2.size(0) % 32 == 0 && mat2.size(1) % 32 == 0, - "Matrix dimensions must be multiples of 32 for block-wise scaling"); + int packed_factor = 1; + if (mat1.scalar_type() == ScalarType::Float4_e2m1fn_x2) { + // For float4 data type, each byte stores two 4-bit floating-point values, + // effectively packing two elements into one byte. + packed_factor = 2; + } + TORCH_CHECK(mat1.size(0) % 16 == 0 && (mat1.size(1) * packed_factor) % 128 == 0 && + mat2.size(1) % 16 == 0, + "M, N must be multiples of 16 and K must be multiple of 128 for block-wise scaling"); TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16 || out.scalar_type() == ScalarType::Half, diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index ea73ccfd5b37..96f252834076 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -926,7 +926,7 @@ def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # largest power of 2 representable in `torch.float8_e4m3fn` F8E4M3_LARGEST_POW2 = 8 # largest power of 2 representable in `torch.float4_e2m1fn_x2` -FP4E2M1FN_LARGEST_POW2 = 1.0 +FP4E2M1FN_LARGEST_POW2 = 2.0 # max value of `torch.float8_e4m3fn` (448) F8E4M3_MAX_VAL = torch.finfo(torch.float8_e4m3fn).max # exponent bias of `torch.float8_e8m0fnu` @@ -1746,8 +1746,12 @@ class TestFP8Matmul(TestCase): device = "cuda" M, K, N = mkn - if (recipe == "nvfp4" or recipe == "mxfp4") and K % 32 != 0: - raise unittest.SkipTest("K must be divisible by 32 for nvfp4/mxfp4 cublas gemm, skipping") + if recipe == "nvfp4" and K % 32 != 0: + raise unittest.SkipTest("K must be divisible by 32 for nvfp4 cublas gemm, skipping") + + if torch.version.hip: + if not (M % 16 == 0 and K % 128 == 0 and N % 16 == 0): + raise unittest.SkipTest("M and N must be multiples of 16 and K must be multiple of 128 on ROCm, skipping") fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32) @@ -1912,9 +1916,12 @@ class TestFP8Matmul(TestCase): B = (B_ref.reshape(-1, BLOCK_SIZE) / B_scale.reshape(N * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(N, K) B = B.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn) else: # nvfp4 # mxfp4 - scale_func = data_to_mx_scale if recipe == "mxfp4" else data_to_nvfp4_scale - A_scale = scale_func(*([A_ref, BLOCK_SIZE] + recipe if recipe == "mxfp4" else [A_ref, BLOCK_SIZE])) - B_scale = scale_func(*([B_ref, BLOCK_SIZE] + recipe if recipe == "mxfp4" else [B_ref, BLOCK_SIZE])) + if recipe == "mxfp4": + A_scale = data_to_mx_scale(A_ref, BLOCK_SIZE, recipe) + B_scale = data_to_mx_scale(B_ref, BLOCK_SIZE, recipe) + else: + A_scale = data_to_nvfp4_scale(A_ref, BLOCK_SIZE) + B_scale = data_to_nvfp4_scale(B_ref, BLOCK_SIZE) max_val = FP4_MAX_VAL min_val = -1 * max_val @@ -1925,7 +1932,7 @@ class TestFP8Matmul(TestCase): B = B.clamp(min=min_val, max=max_val) B = _bfloat16_to_float4_e2m1fn_x2(B) - approx_match_sqnr_target = 12.0 if torch.version.hip else 15.8 + approx_match_sqnr_target = 15 if torch.version.hip else 15.8 C_ref = A_ref @ B_ref.t()