[ROCm] Add scaled_mm v2 support. (#165528)
Add MX FP4 support in Blas.cpp. Update the scale_kernel_dispatch array and the ScaledGemmImplementation enum to include MXFP4 support, and modify the tests under test_scaled_matmul_cuda accordingly.

Verified with: PYTORCH_TEST_WITH_ROCM=1 python test/test_scaled_matmul_cuda.py -v -k test_blockwise (115 tests passed).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165528
Approved by: https://github.com/jeffdaily
Committed by: PyTorch MergeBot
Parent: 86fd4fc23e
Commit: 7669ac9402
@@ -152,15 +152,34 @@ def infer_scale_swizzle(mat, scale):
     ):
         return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4

-    # MX
+    # MXFP4 w/o swizzle
     if (
-        scale.numel()
-        == round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
-        or scale.numel()
-        == round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4)
+        scale.numel() == 2 * math.ceil(mat.shape[0] // 32) * mat.shape[1]
+        or scale.numel() == 2 * math.ceil(mat.shape[1] // 32) * mat.shape[0]
+        and mat.dtype == torch.float4_e2m1fn_x2
         and scale.dtype == torch.float8_e8m0fnu
     ):
-        return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
+        return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
+
+    if not torch.version.hip:
+        # MXFP8 w/ swizzle
+        if (
+            scale.numel()
+            == round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
+            or scale.numel()
+            == round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4)
+            and scale.dtype == torch.float8_e8m0fnu
+        ):
+            return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
+
+    else:
+        # MXFP8 w/o swizzle
+        if (
+            scale.numel() == math.ceil(mat.shape[0] // 32) * mat.shape[1]
+            or scale.numel() == math.ceil(mat.shape[1] // 32) * mat.shape[0]
+            and scale.dtype == torch.float8_e8m0fnu
+        ):
+            return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE

     return None, None
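For orientation, a minimal sketch (not part of the diff) of the scale-count arithmetic behind the new un-swizzled MXFP4 branch, assuming a logical (M, K) operand stored as torch.float4_e2m1fn_x2 with the packing along K; the sizes are arbitrary examples, and the check's two or-branches cover both operand orientations:

    import math

    M, K = 1024, 512        # example logical operand size, not taken from the PR
    BLOCK = 32              # MX block size: one float8_e8m0fnu scale per 32 elements
    packed_cols = K // 2    # float4_e2m1fn_x2 stores two FP4 values per element

    per_block_count = M * (K // BLOCK)                           # one scale per 1x32 logical block
    via_packed_shape = 2 * math.ceil(M // BLOCK) * packed_cols   # form matched by the check above

    assert per_block_count == via_packed_shape == 16384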
@@ -1489,7 +1508,7 @@ class TestFP8Matmul(TestCase):
         assert sqnr.item() > approx_match_sqnr_target

     @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM or IS_WINDOWS, mx_skip_msg)
-    @parametrize("recipe", ["mxfp8", "nvfp4"])
+    @parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
     def test_blockwise_mxfp8_nvfp4_error_messages(self, device, recipe) -> None:
         M, K, N = (1024, 512, 2048)
         BLOCK_SIZE_K = 16 if recipe == "nvfp4" else 32
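A small illustrative check (not in the test file) of how that parametrize list resolves per build; it relies only on torch.version.hip being a version string on ROCm builds and None on CUDA builds:

    import torch

    recipes = ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"]
    # ROCm builds: torch.version.hip is a version string (truthy) -> ["mxfp8", "mxfp4"]
    # CUDA builds: torch.version.hip is None (falsy)              -> ["mxfp8", "nvfp4"]
    print(recipes)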
@@ -1503,7 +1522,7 @@ class TestFP8Matmul(TestCase):
         if recipe == "mxfp8":
             x_lowp = x.to(e4m3_type)
             y_lowp = y.to(e4m3_type).t()
-        else:  # nvfp4
+        else:  # nvfp4 #mxfp4
             x_lowp = _bfloat16_to_float4_e2m1fn_x2(x.bfloat16())
             y_lowp = _bfloat16_to_float4_e2m1fn_x2(y.bfloat16()).t()

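For context, a rough shape sketch of the fp4 path. This is an assumption for illustration, not taken from the diff: packing to torch.float4_e2m1fn_x2 halves the last dimension, since two 4-bit values share one element, and y is assumed to be materialized as (N, K) before the transpose.

    M, K, N = 1024, 512, 2048        # sizes used by the test above

    x_logical_shape = (M, K)         # bfloat16 x before packing
    x_packed_shape  = (M, K // 2)    # assumed shape after _bfloat16_to_float4_e2m1fn_x2
    y_packed_shape  = (N, K // 2)    # y packed the same way, then transposed with .t()
    print(x_packed_shape, y_packed_shape)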
@@ -1517,7 +1536,10 @@ class TestFP8Matmul(TestCase):
             if recipe == "nvfp4"
             else ScalingType.BlockWise1x32
         )
-        swizzle = SwizzleType.SWIZZLE_32_4_4
+        if torch.version.hip:
+            swizzle = SwizzleType.NO_SWIZZLE
+        else:
+            swizzle = SwizzleType.SWIZZLE_32_4_4

         # Test wrong scale tensor size for scale_a with correct dtype
         with self.assertRaisesRegex(
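To make the platform split concrete, here is the scale-element count each branch of the first hunk implies for an (M, K) operand scaled along K, with deliberately unaligned example sizes so the swizzled padding is visible (round_up is a local stand-in for the helper used in the test file; the sizes are not from the test):

    import math

    def round_up(x, y):
        # local stand-in: round x up to the next multiple of y
        return ((x + y - 1) // y) * y

    M, K = 100, 96   # deliberately unaligned example sizes

    no_swizzle = M * math.ceil(K // 32)                            # ROCm branch: 100 * 3 = 300
    swizzled = round_up(M, 128) * round_up(math.ceil(K // 32), 4)  # CUDA branch: 128 * 4 = 512

    # The SWIZZLE_32_4_4 layout pads rows to a multiple of 128 and scale columns to a multiple of 4.
    print(no_swizzle, swizzled)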