Support DeepSeek-style blockwise scaling scaled-mm for fp8 on Hopper+ (#158037)

cuBLAS added support for them in CUDA 12.9. It's rather easy to call into them, the hardest thing is allowing the lhs and rhs operands to have different scaling types, as that changes the whole callstack.

The scaling format is still detected from the sizes of the scale tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158037
Approved by: https://github.com/eqy, https://github.com/drisspg
This commit is contained in:
Luca Wehrstedt
2025-07-17 08:19:55 +00:00
committed by PyTorch MergeBot
parent d76323d417
commit 39ac189808
8 changed files with 359 additions and 200 deletions

View File

@ -785,7 +785,7 @@ def amax_to_scale(
if float8_dtype == e4m3_type:
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
elif float8_dtype == e5m2_type:
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
else:
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
@ -806,6 +806,20 @@ def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None):
return amax_to_scale(amax, float8_dtype, x.dtype)
def tensor_to_scale_block(
x: torch.Tensor,
float8_dtype: torch.dtype,
block_outer: int,
block_inner: int,
) -> tuple[torch.Tensor, torch.Tensor]:
x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
scale = torch.finfo(float8_dtype).max / amax
x = x.mul(scale).to(float8_dtype)
x = x.flatten(2, 3).flatten(0, 1)
scale = scale.flatten(2, 3).flatten(0, 1)
return x, scale
def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
# naive implementation: dq -> op -> q
x_fp32 = x.to(torch.float) / x_scale
@ -814,6 +828,17 @@ def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
return out_fp32.to(out_dtype)
def mm_float8_emulated_block(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
x = x.unflatten(1, (x_scale.shape[1], -1)).unflatten(0, (x_scale.shape[0], -1))
y = y.unflatten(1, (y_scale.shape[1], -1)).unflatten(0, (y_scale.shape[0], -1))
x_fp32 = x.to(torch.float) / x_scale[:, None, :, None]
y_fp32 = y.to(torch.float) / y_scale[:, None, :, None]
x_fp32 = x_fp32.flatten(2, 3).flatten(0, 1)
y_fp32 = y_fp32.flatten(2, 3).flatten(0, 1)
out_fp32 = torch.mm(x_fp32, y_fp32)
return out_fp32.to(out_dtype)
def addmm_float8_unwrapped(
a_data: torch.Tensor,
a_scale: torch.Tensor,
@ -1237,11 +1262,7 @@ class TestFP8Matmul(TestCase):
y_fp8 = y.to(e4m3_type).t()
with self.assertRaisesRegex(
RuntimeError,
re.escape(
"For RowWise scaling, scale_a should be (1024, 1) and scale_b "
"should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)"
),
RuntimeError, re.escape("Invalid scaling configuration")
):
torch._scaled_mm(
x_fp8,
@ -1252,11 +1273,7 @@ class TestFP8Matmul(TestCase):
)
with self.assertRaisesRegex(
RuntimeError,
re.escape(
" For RowWise scaling, scale_a should be (1024, 1) and scale_b "
"should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)"
),
RuntimeError, re.escape("Invalid scaling configuration")
):
torch._scaled_mm(
x_fp8,
@ -1266,22 +1283,18 @@ class TestFP8Matmul(TestCase):
out_dtype=torch.bfloat16,
)
with self.assertRaisesRegex(
RuntimeError,
re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"),
RuntimeError, re.escape("Invalid scaling configuration")
):
torch._scaled_mm(
x_fp8,
y_fp8,
scale_a=torch.ones((M), device="cuda"),
scale_b=torch.ones((N, N), device="cuda"),
scale_b=torch.ones((N, N, 1), device="cuda"),
out_dtype=torch.bfloat16,
)
with self.assertRaisesRegex(
RuntimeError,
re.escape(
"Both scale_a and scale_b must be contiguous for RowWise scaling."
),
RuntimeError, re.escape("Invalid scaling configuration")
):
torch._scaled_mm(
x_fp8,
@ -1346,6 +1359,58 @@ class TestFP8Matmul(TestCase):
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not SM90OrLater, "cuBLAS blockwise scaling requires sm90+")
@unittest.skipIf(
_get_torch_cuda_version() < (12, 9),
"cuBLAS blockwise scaling added in CUDA 12.9",
)
@parametrize("output_dtype", [torch.bfloat16, torch.float32])
@parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)])
def test_scaled_mm_vs_emulated_block_wise(self, output_dtype, lhs_block, rhs_block):
torch.manual_seed(42)
x = torch.randn(256, 512, device="cuda", dtype=output_dtype).pow(3)
y = torch.randn(768, 512, device="cuda", dtype=output_dtype).pow(3)
x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)
# 1x128 blocks need scales to be outer-dim-major
if lhs_block == 1:
x_scales = x_scales.t().contiguous().t()
if rhs_block == 1:
y_scales = y_scales.t().contiguous().t()
# Calculate actual F8 mm
out_scaled_mm = mm_float8(
x_fp8, y_fp8.t(), a_scale=x_scales, b_scale=y_scales.t(), output_dtype=output_dtype
)
# Calculate emulated F8 mm
out_emulated = mm_float8_emulated_block(
x_fp8, x_scales, y_fp8.t(), y_scales.t(), output_dtype
)
cosine_sim = torch.nn.functional.cosine_similarity(
out_scaled_mm.flatten().float(), out_emulated.flatten().float(), dim=0
)
self.assertGreaterEqual(float(cosine_sim), 0.999)
if output_dtype in {torch.bfloat16, torch.float16}:
atol, rtol = 6e-1, 7e-2
else:
atol, rtol = 7e-1, 2e-3
self.assertEqual(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
# One last check against the full-precision reference, to ensure we
# didn't mess up the scaling itself and made the test trivial.
cosine_sim = torch.nn.functional.cosine_similarity(
out_scaled_mm.flatten().float(), (x @ y.t()).flatten().float(), dim=0
)
self.assertGreaterEqual(float(cosine_sim), 0.999)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("which_dim_zero", [0, 1, 2])
@parametrize("use_torch_compile", [False, True])