Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Support DeepSeek-style blockwise scaling scaled-mm for fp8 on Hopper+ (#158037)
cuBLAS added support for these blockwise scaling formats in CUDA 12.9. Calling into them is rather easy; the hardest part is allowing the lhs and rhs operands to have different scaling types, as that changes the whole call stack. The scaling format is still detected from the sizes of the scale tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158037
Approved by: https://github.com/eqy, https://github.com/drisspg
Committed by: PyTorch MergeBot
Parent: d76323d417
Commit: 39ac189808
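For orientation before the diff: as the message says, the scaling recipe passed to torch._scaled_mm is inferred from the sizes of the scale tensors rather than from an explicit flag. The sketch below pieces the relevant shapes together from the tests in this diff; it is illustrative only, not an authoritative statement of the _scaled_mm scale contract.

# Illustrative shapes only, derived from the tests further down in this diff.
import torch

M, K, N = 256, 512, 768  # lhs is (M, K) fp8, rhs is (K, N) fp8

# TensorWise: a single scalar scale per operand.
scale_a, scale_b = torch.ones(()), torch.ones(())

# RowWise: one scale per row of the lhs and per column of the rhs
# (these are the shapes the old error messages below spell out).
scale_a, scale_b = torch.ones(M, 1), torch.ones(1, N)

# Blockwise (DeepSeek-style), new in this PR: one scale per 1x128 or 128x128 block.
#   lhs, 1x128 blocks    -> scale_a of shape (M, K // 128)
#   lhs, 128x128 blocks  -> scale_a of shape (M // 128, K // 128)
#   rhs, 128x128 blocks  -> scale_b of shape (K // 128, N // 128), as the new test
#                           builds it via tensor_to_scale_block(...) followed by .t()
scale_a, scale_b = torch.ones(M, K // 128), torch.ones(K // 128, N // 128)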
@@ -785,7 +785,7 @@ def amax_to_scale(
     if float8_dtype == e4m3_type:
         res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
     elif float8_dtype == e5m2_type:
-        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
+        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
     else:
         raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
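The one-line fix above matters because the two fp8 formats have very different dynamic ranges, so scaling by E4M3_MAX_POS in the e5m2 branch left most of the e5m2 range unused. For reference (assuming the test's e4m3_type / e5m2_type map to the standard CUDA fp8 dtypes, as they do on non-ROCm builds):

import torch

# Maximum representable magnitude of each fp8 format.
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
print(torch.finfo(torch.float8_e5m2).max)    # 57344.0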
@@ -806,6 +806,20 @@ def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None):
 
     return amax_to_scale(amax, float8_dtype, x.dtype)
 
+def tensor_to_scale_block(
+    x: torch.Tensor,
+    float8_dtype: torch.dtype,
+    block_outer: int,
+    block_inner: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
+    amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
+    scale = torch.finfo(float8_dtype).max / amax
+    x = x.mul(scale).to(float8_dtype)
+    x = x.flatten(2, 3).flatten(0, 1)
+    scale = scale.flatten(2, 3).flatten(0, 1)
+    return x, scale
+
 def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
     # naive implementation: dq -> op -> q
     x_fp32 = x.to(torch.float) / x_scale
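A rough shape walk-through of the tensor_to_scale_block helper added above, using the 256x512 lhs that the new test constructs (a sketch, assuming a PyTorch build where casts to fp8 are available):

import torch

x = torch.randn(256, 512, dtype=torch.bfloat16)

# 128x128 tiles: one scale per tile -> scales of shape (256 // 128, 512 // 128) = (2, 4).
x_fp8, x_scales = tensor_to_scale_block(x, torch.float8_e4m3fn, 128, 128)
print(x_fp8.shape, x_scales.shape)  # torch.Size([256, 512]) torch.Size([2, 4])

# 1x128 blocks: one scale per row and per 128-wide slab -> (256, 4).
x_fp8, x_scales = tensor_to_scale_block(x, torch.float8_e4m3fn, 1, 128)
print(x_scales.shape)  # torch.Size([256, 4])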
@@ -814,6 +828,17 @@ def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
 
     return out_fp32.to(out_dtype)
 
+def mm_float8_emulated_block(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
+    x = x.unflatten(1, (x_scale.shape[1], -1)).unflatten(0, (x_scale.shape[0], -1))
+    y = y.unflatten(1, (y_scale.shape[1], -1)).unflatten(0, (y_scale.shape[0], -1))
+    x_fp32 = x.to(torch.float) / x_scale[:, None, :, None]
+    y_fp32 = y.to(torch.float) / y_scale[:, None, :, None]
+    x_fp32 = x_fp32.flatten(2, 3).flatten(0, 1)
+    y_fp32 = y_fp32.flatten(2, 3).flatten(0, 1)
+    out_fp32 = torch.mm(x_fp32, y_fp32)
+
+    return out_fp32.to(out_dtype)
+
 def addmm_float8_unwrapped(
     a_data: torch.Tensor,
     a_scale: torch.Tensor,
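The [:, None, :, None] indexing above is the per-block dequantization: after unflattening, each operand has shape (n_outer, block_outer, n_inner, block_inner), and inserting singleton dims lets the per-block scale broadcast over every element of its block. A small standalone round-trip check (a sketch, not part of the PR):

import torch

x = torch.randn(256, 512)
x_fp8, x_scale = tensor_to_scale_block(x, torch.float8_e4m3fn, 128, 128)

# Undo the per-block scaling the same way mm_float8_emulated_block does.
xb = x_fp8.unflatten(1, (x_scale.shape[1], -1)).unflatten(0, (x_scale.shape[0], -1))
x_dq = (xb.to(torch.float) / x_scale[:, None, :, None]).flatten(2, 3).flatten(0, 1)

# The dequantized tensor should match the original up to fp8 rounding error.
print((x_dq - x).abs().max())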
@@ -1237,11 +1262,7 @@ class TestFP8Matmul(TestCase):
         y_fp8 = y.to(e4m3_type).t()
 
         with self.assertRaisesRegex(
-            RuntimeError,
-            re.escape(
-                "For RowWise scaling, scale_a should be (1024, 1) and scale_b "
-                "should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)"
-            ),
+            RuntimeError, re.escape("Invalid scaling configuration")
         ):
             torch._scaled_mm(
                 x_fp8,
@@ -1252,11 +1273,7 @@ class TestFP8Matmul(TestCase):
             )
 
         with self.assertRaisesRegex(
-            RuntimeError,
-            re.escape(
-                " For RowWise scaling, scale_a should be (1024, 1) and scale_b "
-                "should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)"
-            ),
+            RuntimeError, re.escape("Invalid scaling configuration")
         ):
             torch._scaled_mm(
                 x_fp8,
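For reference, the removed messages above spell out what a well-formed RowWise call looks like for this test's M=1024, N=2048. A sketch of the valid configuration (whether it actually executes still depends on the platform's row-wise fp8 support):

# Hypothetical valid RowWise call, mirroring the shapes the old error text describes.
out = torch._scaled_mm(
    x_fp8,
    y_fp8,
    scale_a=torch.ones(1024, 1, device="cuda"),   # one scale per row of the lhs
    scale_b=torch.ones(1, 2048, device="cuda"),   # one scale per column of the rhs
    out_dtype=torch.bfloat16,
)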
@@ -1266,22 +1283,18 @@ class TestFP8Matmul(TestCase):
                 out_dtype=torch.bfloat16,
             )
         with self.assertRaisesRegex(
-            RuntimeError,
-            re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"),
+            RuntimeError, re.escape("Invalid scaling configuration")
         ):
             torch._scaled_mm(
                 x_fp8,
                 y_fp8,
                 scale_a=torch.ones((M), device="cuda"),
-                scale_b=torch.ones((N, N), device="cuda"),
+                scale_b=torch.ones((N, N, 1), device="cuda"),
                 out_dtype=torch.bfloat16,
             )
 
         with self.assertRaisesRegex(
-            RuntimeError,
-            re.escape(
-                "Both scale_a and scale_b must be contiguous for RowWise scaling."
-            ),
+            RuntimeError, re.escape("Invalid scaling configuration")
         ):
             torch._scaled_mm(
                 x_fp8,
@@ -1346,6 +1359,58 @@ class TestFP8Matmul(TestCase):
 
         torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
 
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
+    @unittest.skipIf(not SM90OrLater, "cuBLAS blockwise scaling requires sm90+")
+    @unittest.skipIf(
+        _get_torch_cuda_version() < (12, 9),
+        "cuBLAS blockwise scaling added in CUDA 12.9",
+    )
+    @parametrize("output_dtype", [torch.bfloat16, torch.float32])
+    @parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)])
+    def test_scaled_mm_vs_emulated_block_wise(self, output_dtype, lhs_block, rhs_block):
+        torch.manual_seed(42)
+
+        x = torch.randn(256, 512, device="cuda", dtype=output_dtype).pow(3)
+        y = torch.randn(768, 512, device="cuda", dtype=output_dtype).pow(3)
+
+        x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
+        y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)
+
+        # 1x128 blocks need scales to be outer-dim-major
+        if lhs_block == 1:
+            x_scales = x_scales.t().contiguous().t()
+        if rhs_block == 1:
+            y_scales = y_scales.t().contiguous().t()
+
+        # Calculate actual F8 mm
+        out_scaled_mm = mm_float8(
+            x_fp8, y_fp8.t(), a_scale=x_scales, b_scale=y_scales.t(), output_dtype=output_dtype
+        )
+
+        # Calculate emulated F8 mm
+        out_emulated = mm_float8_emulated_block(
+            x_fp8, x_scales, y_fp8.t(), y_scales.t(), output_dtype
+        )
+
+        cosine_sim = torch.nn.functional.cosine_similarity(
+            out_scaled_mm.flatten().float(), out_emulated.flatten().float(), dim=0
+        )
+        self.assertGreaterEqual(float(cosine_sim), 0.999)
+
+        if output_dtype in {torch.bfloat16, torch.float16}:
+            atol, rtol = 6e-1, 7e-2
+        else:
+            atol, rtol = 7e-1, 2e-3
+
+        self.assertEqual(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
+
+        # One last check against the full-precision reference, to ensure we
+        # didn't mess up the scaling itself and made the test trivial.
+        cosine_sim = torch.nn.functional.cosine_similarity(
+            out_scaled_mm.flatten().float(), (x @ y.t()).flatten().float(), dim=0
+        )
+        self.assertGreaterEqual(float(cosine_sim), 0.999)
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("which_dim_zero", [0, 1, 2])
     @parametrize("use_torch_compile", [False, True])
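One detail of the new test worth calling out: for 1x128 blocks the scales must be outer-dim-major, and the .t().contiguous().t() pattern keeps the logical shape while making the outer (row) dimension contiguous in memory. A small standalone stride illustration (not from the PR):

import torch

s = torch.randn(256, 4)                 # e.g. per-row, per-128-block scales
print(s.shape, s.stride())              # torch.Size([256, 4]) (4, 1)  -> row-major
s_outer = s.t().contiguous().t()
print(s_outer.shape, s_outer.stride())  # torch.Size([256, 4]) (1, 256) -> outer-dim-major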