mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[FP8][cuBLAS][H100] only test fp32 outputs for rowwise _scaled_mm
on H100 (#162022)
Only cuBLAS supports float32 output for rowwise scaling, and cuBLAS only supports rowwise scaling on SM 9.0. Intended to land after #161305. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162022 — Approved by: https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
264e7f68a0
commit
f8f230a801
@ -1530,22 +1530,34 @@ class TestFP8Matmul(TestCase):
|
||||
x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type)
|
||||
y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type)
|
||||
|
||||
# Calculate actual F8 mm
|
||||
out_scaled_mm = mm_float8(
|
||||
x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
|
||||
)
|
||||
def test():
|
||||
# Calculate actual F8 mm
|
||||
out_scaled_mm = mm_float8(
|
||||
x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
|
||||
)
|
||||
|
||||
# Calculate emulated F8 mm
|
||||
out_emulated = mm_float8_emulated(
|
||||
x_fp8, x_scales, y_fp8, y_scales, output_dtype
|
||||
)
|
||||
# Calculate emulated F8 mm
|
||||
out_emulated = mm_float8_emulated(
|
||||
x_fp8, x_scales, y_fp8, y_scales, output_dtype
|
||||
)
|
||||
|
||||
if base_dtype in {torch.bfloat16, torch.float16}:
|
||||
atol, rtol = 7e-2, 7e-2
|
||||
if base_dtype in {torch.bfloat16, torch.float16}:
|
||||
atol, rtol = 7e-2, 7e-2
|
||||
else:
|
||||
atol, rtol = 2e-3, 2e-3
|
||||
|
||||
self.assertEqual(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
|
||||
|
||||
# only cuBLAS supports rowwise with fp32 output and cuBLAS only supports
|
||||
# rowwise on SM 9.0
|
||||
if torch.cuda.get_device_capability != (9, 0) and output_dtype == torch.float:
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Only bf16 high precision output types are supported for row-wise scaling."
|
||||
):
|
||||
test()
|
||||
else:
|
||||
atol, rtol = 2e-3, 2e-3
|
||||
|
||||
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
|
||||
test()
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
|
||||
@unittest.skipIf(not IS_SM90, "cuBLAS blockwise scaling requires sm90+")
|
||||
|
Reference in New Issue
Block a user