mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
test_matmul_cuda: Refine MX test skipping (#161009)
Replace return unittest.skip with raise unittest.SkipTest to ensure that the test suite correctly reports skipped tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161009 Approved by: https://github.com/jeffdaily
This commit is contained in:
committed by
PyTorch MergeBot
parent
a3a82e3da8
commit
543896fcf3
@ -1565,12 +1565,12 @@ class TestFP8Matmul(TestCase):
|
||||
@parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
|
||||
def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
|
||||
if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
|
||||
return unittest.skip("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
|
||||
raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
|
||||
|
||||
device = "cuda"
|
||||
M, K, N = mkn
|
||||
if (recipe == "nvfp4" or recipe == "mxfp4") and K % 32 != 0:
|
||||
return unittest.skip("K must be divisible by 32 for nvfp4/mxfp4 cublas gemm, skipping")
|
||||
raise unittest.SkipTest("K must be divisible by 32 for nvfp4/mxfp4 cublas gemm, skipping")
|
||||
|
||||
fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
|
||||
BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32)
|
||||
@ -1718,7 +1718,7 @@ class TestFP8Matmul(TestCase):
|
||||
|
||||
elif test_case_name == "data_random_scales_from_data":
|
||||
if not K % BLOCK_SIZE == 0:
|
||||
return unittest.skip(f"this test is only defined for K a multiple of {BLOCK_SIZE}, skipping")
|
||||
raise unittest.SkipTest(f"this test is only defined for K a multiple of {BLOCK_SIZE}, skipping")
|
||||
require_exact_match = False
|
||||
# random data, scales from data
|
||||
A_ref = torch.randn((M, K), device=device, dtype=torch.bfloat16) * 1000
|
||||
|
Reference in New Issue
Block a user