test_matmul_cuda: Refine MX test skipping (#161009)

Replace return unittest.skip(...) with raise unittest.SkipTest(...) so that the test suite correctly reports these tests as skipped: unittest.skip() only returns a decorator, so returning it from a test body does not skip anything and the test is silently recorded as passed.
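For illustration, a minimal, self-contained sketch of the difference (plain unittest, with made-up test names; not the PyTorch test file):

import unittest

class SkipDemo(unittest.TestCase):
    def test_return_skip(self):
        # unittest.skip() returns a decorator; returning it from a test body
        # does nothing, so the runner records this test as passed.
        return unittest.skip("reason is never reported")

    def test_raise_skiptest(self):
        # Raising SkipTest tells the runner to mark the test as skipped
        # and to report the reason.
        raise unittest.SkipTest("reported as skipped")

if __name__ == "__main__":
    unittest.main()

Run with python -m unittest: the first test shows up as passed, and only the second is reported as skipped.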

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161009
Approved by: https://github.com/jeffdaily
Jagadish Krishnamoorthy
2025-08-20 00:47:42 +00:00
committed by PyTorch MergeBot
parent a3a82e3da8
commit 543896fcf3

@@ -1565,12 +1565,12 @@ class TestFP8Matmul(TestCase):
     @parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
     def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
         if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
-            return unittest.skip("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
+            raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
         device = "cuda"
         M, K, N = mkn
         if (recipe == "nvfp4" or recipe == "mxfp4") and K % 32 != 0:
-            return unittest.skip("K must be divisible by 32 for nvfp4/mxfp4 cublas gemm, skipping")
+            raise unittest.SkipTest("K must be divisible by 32 for nvfp4/mxfp4 cublas gemm, skipping")
         fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
         BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32)
@@ -1718,7 +1718,7 @@ class TestFP8Matmul(TestCase):
         elif test_case_name == "data_random_scales_from_data":
             if not K % BLOCK_SIZE == 0:
-                return unittest.skip(f"this test is only defined for K a multiple of {BLOCK_SIZE}, skipping")
+                raise unittest.SkipTest(f"this test is only defined for K a multiple of {BLOCK_SIZE}, skipping")
             require_exact_match = False
             # random data, scales from data
             A_ref = torch.randn((M, K), device=device, dtype=torch.bfloat16) * 1000