Compare commits

...

1 Commit

8f52dccd56 [FP8] Additional fp8 rowwise-scaling testing
Summary:

While working on a nasty numerics bug involving rowwise-scaled mm
calls, I wrote a bunch of tests to convince myself that everything was
OK. That testing is useful to OSS, so here it is.

Note: for the main rowwise numerics testing, don't use fp16 as the output
type - it overflows too easily and readily infs-out the tests.
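
As a rough illustration of the overflow (a standalone sketch, not part of this
PR's tests): fp16's max finite value is ~65504, so even a modest-K reduction of
O(1) products saturates when cast to fp16, while bf16/fp32 outputs stay finite.

```
import torch

# K * 16 = 262144 > 65504 (fp16 max), so the fp16-cast result is inf,
# while the fp32 (or bf16) result is perfectly representable.
K = 16384
x = torch.full((1, K), 4.0)
y = torch.full((K, 1), 4.0)

out_fp32 = x @ y
out_fp16 = (x @ y).to(torch.float16)
print(out_fp32.item(), out_fp16.item())  # 262144.0 inf
```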

Test Plan:

```
pytest -sv -k "test_float8_rowwise_scaling_numerics" test/test_scaled_matmul_cuda.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>

[ghstack-poisoned]
2025-11-17 09:07:00 -08:00


@ -420,7 +420,10 @@ def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
y: The tensor to compare to the original tensor.
"""
Ps = torch.norm(x)
Pn = torch.norm(x - y)
# If x == y, then Pn == 0 and Ps / Pn is inf - prevent this
eps = 1e-8
Pn = torch.norm(x - y + eps)
return 20 * torch.log10(Ps / Pn)
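
A quick sanity check of the eps guard (a self-contained sketch that mirrors the
compute_error body above, not the full test file): with identical inputs the
unguarded formula returned inf, while the guarded version returns a large but
finite SNR.

```
import torch

def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Mirrors the snippet above: SNR of y vs. x, in dB.
    Ps = torch.norm(x)
    eps = 1e-8
    Pn = torch.norm(x - y + eps)
    return 20 * torch.log10(Ps / Pn)

x = torch.randn(1024)
print(compute_error(x, x).item())         # large but finite instead of inf
print(compute_error(x, x + 1e-3).item())  # ~60 dB for a 1e-3 perturbation
```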
@ -1122,6 +1125,145 @@ class TestFP8Matmul(TestCase):
out_fp8_s = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b, out_dtype=e4m3_type, use_fast_accum=True)
self.assertEqual(out_fp8, out_fp8_s)
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific")
@parametrize("M,K,N", [
# Small test case
(32, 128, 64),
# Medium case, medium-K
(8192, 2048, 8192),
# Medium case, larger-K
(8192, 16384, 4096),
# Large test case from a debug effort
(323162, 512, 2048),
])
# Note(slayton58): fp16 doesn't have the range needed as an output type - really want bf16/fp32 here
@parametrize("base_dtype", [torch.bfloat16, ])
@parametrize("use_fast_accum", [True, False])
@parametrize("bias_dtype", [None, "same", torch.float32], name_fn=lambda t: f"{t}")
@with_tf32_off
def test_float8_rowwise_scaling_numerics(
self,
device,
M: int,
K: int,
N: int,
base_dtype: torch.dtype,
use_fast_accum: bool,
bias_dtype: torch.dtype | None
) -> None:
# Fp32 out_dtype is only supported by cuBLAS, which only started
# shipping row-wise kernels in CUDA 12.9, and only for sm90+.
if base_dtype is torch.float32:
if torch.version.hip:
raise unittest.SkipTest("hipblaslt rowwise _scaled_mm only supports BFloat16")
if _get_torch_cuda_version() < (12, 9):
raise unittest.SkipTest("Need CUDA 12.9+ for row-wise fp8 w/ cuBLAS")
if torch.cuda.get_device_capability() < (9, 0):
raise unittest.SkipTest("Need sm90+ for row-wise fp8 w/ cuBLAS")
if base_dtype is torch.float16:
if torch.version.hip:
raise unittest.SkipTest("hipblaslt rowwise _scaled_mm only supports BFloat16")
if torch.cuda.get_device_capability() < (9, 0):
raise unittest.SkipTest("Need sm90+ for row-wise fp8 w/ cuBLAS")
x = torch.randn((M, K), device=device).to(base_dtype)
y = torch.randn((N, K), device=device).to(base_dtype)
bias = None
if bias_dtype is not None:
bias_dtype_ = y.dtype if bias_dtype == "same" else torch.float32
bias = torch.randn((N,), device=device, dtype=bias_dtype_)
e4m3_type = torch.float8_e4m3fn
x_scales = tensor_to_scale(x, torch.float8_e4m3fn, dim=1).float()
y_scales = tensor_to_scale(y, torch.float8_e4m3fn, dim=1).float()
x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type)
y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type)
x_scales = x_scales.reciprocal()
y_scales = y_scales.reciprocal()
x_dq = x_fp8.to(torch.float32) * x_scales
y_dq = y_fp8.to(torch.float32) * y_scales
# Actual fp8 rowwise call
out_fp8 = scaled_mm_wrap(
x_fp8,
y_fp8.t(),
scale_a=x_scales,
scale_b=y_scales.t().contiguous(),
out_dtype=base_dtype,
use_fast_accum=use_fast_accum,
bias=bias.to(base_dtype) if bias is not None else bias,
)
# full high-precision gemm
out_hp = x @ y.t()
if bias is not None:
out_hp += bias.float()
# dequant gemm - high-precision gemm w/low-precision inputs
# Isolates gemm-error from quantization-error
out_dq = x_dq @ y_dq.t()
if bias is not None:
out_dq += bias.float()
out_dq = out_dq.to(base_dtype)
# Emulate rowwise scaling w/an unscaled fp8 x fp8 -> fp32 gemm, then apply
# the rowwise scales afterwards
# Isolates rowwise-kernel-specific problems, as this runs through a different kernel than the rowwise path
def emulate_fp8_rowwise(x_fp8, y_fp8, x_scales, y_scales, bias=None):
scale_tmp = torch.ones((1, ), device=x_fp8.device, dtype=torch.float32)
out_tmp = scaled_mm_wrap(
x_fp8,
y_fp8.t(),
scale_a=scale_tmp,
scale_b=scale_tmp,
out_dtype=torch.float32,
)
total_scale = torch.outer(x_scales.squeeze(), y_scales.squeeze())
out = out_tmp * total_scale
if bias is not None:
out += bias
return out.to(base_dtype)
out_emulated = emulate_fp8_rowwise(x_fp8, y_fp8, x_scales, y_scales, bias)
# Compute SNR across different combinations (vs. fp8 result)
# Note: result is data-dependent, so pick a somewhat arbitrary value for dq/emulated
# - fp8 vs. dq and fp8 vs. emulated should be high
# - fp8 vs. hp should be lower (due to capturing quant error as well)
snr_limit_hp = 25.
snr_limit_dq_emulated = 60
def check_snr(Ps, Pq, limit):
snr = compute_error(Ps, Pq)
self.assertEqual(True, torch.isfinite(snr).any().item())
self.assertGreaterEqual(snr.item(), limit)
check_snr(out_fp8, out_dq, snr_limit_dq_emulated)
check_snr(out_fp8, out_emulated, snr_limit_dq_emulated)
check_snr(out_fp8, out_hp, snr_limit_hp)
# Quick check of values (so we're not just looking at sqnr)
# - Check fp8 vs. QDQ
if base_dtype in {torch.bfloat16, torch.float16}:
atol, rtol = 7e-2, 7e-2
else:
atol, rtol = 2e-3, 2e-3
self.assertEqual(out_fp8, out_dq, atol=atol, rtol=rtol)
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific")
@ -1240,6 +1382,10 @@ class TestFP8Matmul(TestCase):
):
e5m2()
# Note: this is a basic test on a single, small problem size.
# Mostly replaced by test_float8_rowwise_scaling_numerics, but this *does* test
# fp16 output, which isn't tested elsewhere because it's hard to keep it from
# overflowing and giving nonsense test results :)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific")
@parametrize("base_dtype", [torch.bfloat16, torch.float16, torch.float32])