[FP8][cuBLAS][H100] only test fp32 outputs for rowwise _scaled_mm on H100 (#162022)

Only cuBLAS supports float32 output for row-wise scaled `_scaled_mm`, and cuBLAS only supports row-wise scaling on SM 9.0, so fp32 outputs are only tested on H100; elsewhere the test expects the corresponding RuntimeError.

Intended to land after #161305

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162022
Approved by: https://github.com/ngimel
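
For context, a minimal sketch (not part of this PR) of the behavior the test exercises: with row-wise scales, torch._scaled_mm accepts a float32 out_dtype only on the cuBLAS path, which is only taken on SM 9.0 (H100); on other architectures it is expected to raise the error quoted in the test. Shapes, scale layouts, and tensor names below are illustrative assumptions, not taken from the PR.

# Sketch: row-wise scaled FP8 matmul asking for a float32 output.
# Expected to succeed on SM 9.0 (cuBLAS row-wise path) and to raise
# "Only bf16 high precision output types are supported for row-wise scaling."
# elsewhere. Shapes and unit scales here are illustrative only.
import torch

M, K, N = 64, 128, 32
x = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
y = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
# Row-wise scaling: one fp32 scale per row of x (M, 1) and per output column (1, N).
scale_x = torch.full((M, 1), 1.0, device="cuda", dtype=torch.float32)
scale_y = torch.full((1, N), 1.0, device="cuda", dtype=torch.float32)

try:
    out = torch._scaled_mm(
        x, y.t(), scale_a=scale_x, scale_b=scale_y, out_dtype=torch.float32
    )
    print("float32 output supported on", torch.cuda.get_device_capability())
except RuntimeError as e:
    print("float32 output rejected:", e)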
Eddie Yan
2025-09-19 15:18:13 +00:00
committed by PyTorch MergeBot
parent 264e7f68a0
commit f8f230a801


@@ -1530,22 +1530,34 @@ class TestFP8Matmul(TestCase):
         x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type)
         y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type)

-        # Calculate actual F8 mm
-        out_scaled_mm = mm_float8(
-            x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
-        )
+        def test():
+            # Calculate actual F8 mm
+            out_scaled_mm = mm_float8(
+                x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
+            )

-        # Calculate emulated F8 mm
-        out_emulated = mm_float8_emulated(
-            x_fp8, x_scales, y_fp8, y_scales, output_dtype
-        )
+            # Calculate emulated F8 mm
+            out_emulated = mm_float8_emulated(
+                x_fp8, x_scales, y_fp8, y_scales, output_dtype
+            )

-        if base_dtype in {torch.bfloat16, torch.float16}:
-            atol, rtol = 7e-2, 7e-2
-        else:
-            atol, rtol = 2e-3, 2e-3
+            if base_dtype in {torch.bfloat16, torch.float16}:
+                atol, rtol = 7e-2, 7e-2
+            else:
+                atol, rtol = 2e-3, 2e-3

-        self.assertEqual(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
+            torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
+
+        # only cuBLAS supports rowwise with fp32 output and cuBLAS only supports
+        # rowwise on SM 9.0
+        if torch.cuda.get_device_capability() != (9, 0) and output_dtype == torch.float:
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "Only bf16 high precision output types are supported for row-wise scaling."
+            ):
+                test()
+        else:
+            test()

     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
     @unittest.skipIf(not IS_SM90, "cuBLAS blockwise scaling requires sm90+")