[ROCm] Add scaled_mm v2 support. (#165528)

Add MX FP4 support in Blas.cpp.
Update the scale_kernel_dispatch array and the ScaledGemmImplementation enum to include MXFP4 support.
Update the tests in test_scaled_matmul_cuda.py accordingly.

PYTORCH_TEST_WITH_ROCM=1 python test/test_scaled_matmul_cuda.py -v -k test_blockwise
115 tests passed.
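
For context, a minimal sketch of the operand and scale layout the new ROCm MXFP4 path consumes. The shapes, the pack-by-bit-cast trick, and the tensor names are illustrative assumptions, not code from this change; the tests build packed FP4 data with a bfloat16-to-float4_e2m1fn_x2 helper instead.

import torch

# Illustrative GEMM dims; K chosen so K and K // 2 are both multiples of the 32-element MX block.
M, K, N = 1024, 512, 2048

# float4_e2m1fn_x2 packs two FP4 values per byte, so the stored K dim is K // 2.
# Bit-cast zero-filled uint8 storage just to get tensors of the right dtype and shape.
x_fp4 = torch.zeros(M, K // 2, dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
y_fp4 = torch.zeros(N, K // 2, dtype=torch.uint8).view(torch.float4_e2m1fn_x2)

# One float8_e8m0fnu scale per 32 unpacked elements along K, per row.
# 127 is the biased e8m0 exponent encoding 2**0 == 1.0.
scale_x = torch.full((M * (K // 32),), 127, dtype=torch.uint8).view(torch.float8_e8m0fnu)
scale_y = torch.full((N * (K // 32),), 127, dtype=torch.uint8).view(torch.float8_e8m0fnu)

# On ROCm these scales are consumed as-is (SwizzleType.NO_SWIZZLE); the CUDA MX
# paths instead expect the 32x4x4 swizzled layout.
print(x_fp4.shape, scale_x.numel())  # torch.Size([1024, 256]) 16384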

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165528
Approved by: https://github.com/jeffdaily

Jagadish Krishnamoorthy authored on 2025-10-16 18:36:37 +00:00; committed by PyTorch MergeBot
parent 86fd4fc23e
commit 7669ac9402
2 changed files with 123 additions and 11 deletions

test/test_scaled_matmul_cuda.py

@@ -152,15 +152,34 @@ def infer_scale_swizzle(mat, scale):
     ):
         return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4
-    # MX
+    # MXFP4 w/o swizzle
     if (
-        scale.numel()
-        == round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
-        or scale.numel()
-        == round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4)
+        scale.numel() == 2 * math.ceil(mat.shape[0] // 32) * mat.shape[1]
+        or scale.numel() == 2 * math.ceil(mat.shape[1] // 32) * mat.shape[0]
+        and mat.dtype == torch.float4_e2m1fn_x2
         and scale.dtype == torch.float8_e8m0fnu
     ):
-        return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
+        return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
+
+    if not torch.version.hip:
+        # MXFP8 w/ swizzle
+        if (
+            scale.numel()
+            == round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
+            or scale.numel()
+            == round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4)
+            and scale.dtype == torch.float8_e8m0fnu
+        ):
+            return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
+    else:
+        # MXFP8 w/o swizzle
+        if (
+            scale.numel() == math.ceil(mat.shape[0] // 32) * mat.shape[1]
+            or scale.numel() == math.ceil(mat.shape[1] // 32) * mat.shape[0]
+            and scale.dtype == torch.float8_e8m0fnu
+        ):
+            return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE

     return None, None
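
For reference, here is a standalone sketch of the scale element counts these branches check, for a single (M, K) operand with 1x32 blocks along K. The dims are illustrative, and the local round_up below is assumed to behave like the test file's round-up-to-multiple helper.

# round_up is assumed to round n up to the nearest multiple of m,
# mirroring the helper used by infer_scale_swizzle above.
def round_up(n: int, m: int) -> int:
    return ((n + m - 1) // m) * m

M, K = 1000, 512                 # K chosen so K and K // 2 are both multiples of 32
blocks_per_row = K // 32         # e8m0 scales per row of the operand

# MXFP4 (ROCm, unswizzled): data is packed float4_e2m1fn_x2 and stored as (M, K // 2),
# which is where the factor of 2 in the branch above comes from.
mxfp4_numel = 2 * ((K // 2) // 32) * M                              # 16000 == M * blocks_per_row

# MXFP8 on ROCm (unswizzled): plain one-scale-per-32-element-block layout.
mxfp8_rocm_numel = blocks_per_row * M                               # 16000

# MXFP8 on CUDA (SWIZZLE_32_4_4): rows padded to a multiple of 128 and scale
# columns padded to a multiple of 4 for the swizzled layout.
mxfp8_cuda_numel = round_up(M, 128) * round_up(blocks_per_row, 4)   # 1024 * 16 == 16384

print(mxfp4_numel, mxfp8_rocm_numel, mxfp8_cuda_numel)
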
@@ -1489,7 +1508,7 @@ class TestFP8Matmul(TestCase):
         assert sqnr.item() > approx_match_sqnr_target

     @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM or IS_WINDOWS, mx_skip_msg)
-    @parametrize("recipe", ["mxfp8", "nvfp4"])
+    @parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
     def test_blockwise_mxfp8_nvfp4_error_messages(self, device, recipe) -> None:
         M, K, N = (1024, 512, 2048)
         BLOCK_SIZE_K = 16 if recipe == "nvfp4" else 32
@@ -1503,7 +1522,7 @@ class TestFP8Matmul(TestCase):
         if recipe == "mxfp8":
             x_lowp = x.to(e4m3_type)
             y_lowp = y.to(e4m3_type).t()
-        else:  # nvfp4
+        else:  # nvfp4 #mxfp4
             x_lowp = _bfloat16_to_float4_e2m1fn_x2(x.bfloat16())
             y_lowp = _bfloat16_to_float4_e2m1fn_x2(y.bfloat16()).t()
@@ -1517,7 +1536,10 @@
             if recipe == "nvfp4"
             else ScalingType.BlockWise1x32
         )
-        swizzle = SwizzleType.SWIZZLE_32_4_4
+        if torch.version.hip:
+            swizzle = SwizzleType.NO_SWIZZLE
+        else:
+            swizzle = SwizzleType.SWIZZLE_32_4_4

         # Test wrong scale tensor size for scale_a with correct dtype
         with self.assertRaisesRegex(