[ROCm] Fix mx fp8 and fp4 code after scaling refactor changes. (#163127)

PR #151360 added mx fp8 and fp4 support on ROCm.
1. However, on recent upstream, scaling function in Blas.cpp along with test_matmul_cuda changes triggered failures.
This patch corrects is_blockwise_1x32_scaling function code.

2. Fixes the m, n, k dimensions for ROCm mx case.

3.  Modify FP4E2M1FN_LARGEST_POW2 (largest power of 2 representable in `torch.float4_e2m1fn_x2`) to 2.
This resulted in higher SQNR value for mx fp4 test.

Testing result on gfx950 w/ ROCm7.0

PYTORCH_TEST_WITH_ROCM=1 python test/test_matmul_cuda.py -k test_blockwise -v Ran 452 tests in 22.698s
OK passed 111
This is same as before. (when PR 151360 was merged)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163127
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
Jagadish Krishnamoorthy
2025-09-19 12:29:52 +00:00
committed by PyTorch MergeBot
parent bee362c381
commit 264e7f68a0
3 changed files with 33 additions and 15 deletions

View File

@ -926,7 +926,7 @@ def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# largest power of 2 representable in `torch.float8_e4m3fn`
F8E4M3_LARGEST_POW2 = 8
# largest power of 2 representable in `torch.float4_e2m1fn_x2`
FP4E2M1FN_LARGEST_POW2 = 1.0
FP4E2M1FN_LARGEST_POW2 = 2.0
# max value of `torch.float8_e4m3fn` (448)
F8E4M3_MAX_VAL = torch.finfo(torch.float8_e4m3fn).max
# exponent bias of `torch.float8_e8m0fnu`
@ -1746,8 +1746,12 @@ class TestFP8Matmul(TestCase):
device = "cuda"
M, K, N = mkn
if (recipe == "nvfp4" or recipe == "mxfp4") and K % 32 != 0:
raise unittest.SkipTest("K must be divisible by 32 for nvfp4/mxfp4 cublas gemm, skipping")
if recipe == "nvfp4" and K % 32 != 0:
raise unittest.SkipTest("K must be divisible by 32 for nvfp4 cublas gemm, skipping")
if torch.version.hip:
if not (M % 16 == 0 and K % 128 == 0 and N % 16 == 0):
raise unittest.SkipTest("M and N must be multiples of 16 and K must be multiple of 128 on ROCm, skipping")
fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32)
@ -1912,9 +1916,12 @@ class TestFP8Matmul(TestCase):
B = (B_ref.reshape(-1, BLOCK_SIZE) / B_scale.reshape(N * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(N, K)
B = B.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
else: # nvfp4 # mxfp4
scale_func = data_to_mx_scale if recipe == "mxfp4" else data_to_nvfp4_scale
A_scale = scale_func(*([A_ref, BLOCK_SIZE] + recipe if recipe == "mxfp4" else [A_ref, BLOCK_SIZE]))
B_scale = scale_func(*([B_ref, BLOCK_SIZE] + recipe if recipe == "mxfp4" else [B_ref, BLOCK_SIZE]))
if recipe == "mxfp4":
A_scale = data_to_mx_scale(A_ref, BLOCK_SIZE, recipe)
B_scale = data_to_mx_scale(B_ref, BLOCK_SIZE, recipe)
else:
A_scale = data_to_nvfp4_scale(A_ref, BLOCK_SIZE)
B_scale = data_to_nvfp4_scale(B_ref, BLOCK_SIZE)
max_val = FP4_MAX_VAL
min_val = -1 * max_val
@ -1925,7 +1932,7 @@ class TestFP8Matmul(TestCase):
B = B.clamp(min=min_val, max=max_val)
B = _bfloat16_to_float4_e2m1fn_x2(B)
approx_match_sqnr_target = 12.0 if torch.version.hip else 15.8
approx_match_sqnr_target = 15 if torch.version.hip else 15.8
C_ref = A_ref @ B_ref.t()