Revert "Support DeepSeek-style blockwise scaling scaled-mm for fp8 on Hopper+ (#158037)"
This reverts commit bc65253369933160a2da3fc786d027a572faf6b7. Reverted https://github.com/pytorch/pytorch/pull/158037 on behalf of https://github.com/lw due to OSX failures are real ([comment](https://github.com/pytorch/pytorch/pull/158037#issuecomment-3079042171))
```diff
@@ -785,7 +785,7 @@ def amax_to_scale(
     if float8_dtype == e4m3_type:
         res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
     elif float8_dtype == e5m2_type:
-        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
+        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
     else:
         raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
 
```
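For context on the helper this hunk touches: `amax_to_scale` turns a tensor's absolute maximum into a scale so that the largest value lands at the fp8 format's maximum representable value. A minimal numeric sketch of the same formula (the `EPS` value and the use of `torch.finfo` in place of the test's `E4M3_MAX_POS`/`E5M2_MAX_POS` constants are assumptions for illustration):

```python
import torch

EPS = 1e-12  # assumed small clamp value, mirroring the helper's guard against amax == 0
amax = torch.tensor([3.0, 0.0])  # per-tensor absolute maxima

e4m3_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
e5m2_max = torch.finfo(torch.float8_e5m2).max    # 57344.0

# The scale is chosen so that amax * scale hits the format's maximum;
# the clamp keeps all-zero inputs from dividing by zero.
scale_e4m3 = e4m3_max / torch.clamp(amax, min=EPS)
scale_e5m2 = e5m2_max / torch.clamp(amax, min=EPS)
print(scale_e4m3)  # ~[149.33, 4.48e+14]
print(scale_e5m2)  # ~[19114.67, 5.73e+16]
```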
```diff
@@ -806,20 +806,6 @@ def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None):
     return amax_to_scale(amax, float8_dtype, x.dtype)
 
 
-def tensor_to_scale_block(
-    x: torch.Tensor,
-    float8_dtype: torch.dtype,
-    block_outer: int,
-    block_inner: int,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
-    amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
-    scale = torch.finfo(float8_dtype).max / amax
-    x = x.mul(scale).to(float8_dtype)
-    x = x.flatten(2, 3).flatten(0, 1)
-    scale = scale.flatten(2, 3).flatten(0, 1)
-    return x, scale
-
 def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
     # naive implementation: dq -> op -> q
     x_fp32 = x.to(torch.float) / x_scale
```
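The removed `tensor_to_scale_block` helper quantizes a 2-D tensor in `block_outer` x `block_inner` tiles, producing one scale per tile. The core of it is the `unflatten`/`amax`/`flatten` shape dance; a small CPU-runnable sketch of that shape flow, mirroring the removed helper (the 256x512 input and 128x128 tiles are illustrative, and a PyTorch build with float8 dtypes is assumed):

```python
import torch

# Shape walkthrough for 128x128 blockwise scales over a (256, 512) tensor.
x = torch.randn(256, 512)
block_outer, block_inner = 128, 128

# (256, 512) -> (2, 128, 4, 128): dims 0 and 2 index tiles, dims 1 and 3 index within a tile.
blocks = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
print(blocks.shape)  # torch.Size([2, 128, 4, 128])

# One absolute max per tile, then a scale that maps that max to the fp8 maximum.
amax = blocks.abs().amax(dim=[1, 3], keepdim=True).float()
scale = torch.finfo(torch.float8_e4m3fn).max / amax

# Flattening back gives (256, 512) fp8 data and a (2, 4) grid of per-tile scales.
x_fp8 = blocks.mul(scale).to(torch.float8_e4m3fn).flatten(2, 3).flatten(0, 1)
scales = scale.flatten(2, 3).flatten(0, 1)
print(x_fp8.shape, scales.shape)  # torch.Size([256, 512]) torch.Size([2, 4])
```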
```diff
@@ -828,17 +814,6 @@ def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
     return out_fp32.to(out_dtype)
 
 
-def mm_float8_emulated_block(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
-    x = x.unflatten(1, (x_scale.shape[1], -1)).unflatten(0, (x_scale.shape[0], -1))
-    y = y.unflatten(1, (y_scale.shape[1], -1)).unflatten(0, (y_scale.shape[0], -1))
-    x_fp32 = x.to(torch.float) / x_scale[:, None, :, None]
-    y_fp32 = y.to(torch.float) / y_scale[:, None, :, None]
-    x_fp32 = x_fp32.flatten(2, 3).flatten(0, 1)
-    y_fp32 = y_fp32.flatten(2, 3).flatten(0, 1)
-    out_fp32 = torch.mm(x_fp32, y_fp32)
-
-    return out_fp32.to(out_dtype)
-
 def addmm_float8_unwrapped(
     a_data: torch.Tensor,
     a_scale: torch.Tensor,
```
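The removed `mm_float8_emulated_block` is the reference path: it dequantizes each operand by broadcasting its per-tile scale back over the tile, then runs a plain fp32 `torch.mm`. A sketch of that dequantization broadcast for one operand (a float input stands in for the fp8 data so it runs anywhere; shapes are illustrative):

```python
import torch

# One scale per 128x128 tile of a (256, 512) operand.
x = torch.randn(256, 512)
x_scale = torch.rand(2, 4) + 0.5  # (256/128, 512/128) tile scales, illustrative values

# Regroup the data into tiles using the scale grid's shape, as the removed helper does.
x_blocks = x.unflatten(1, (x_scale.shape[1], -1)).unflatten(0, (x_scale.shape[0], -1))
print(x_blocks.shape)  # torch.Size([2, 128, 4, 128])

# x_scale[:, None, :, None] has shape (2, 1, 4, 1), so each scale broadcasts over its tile.
x_dq = (x_blocks / x_scale[:, None, :, None]).flatten(2, 3).flatten(0, 1)
print(x_dq.shape)  # torch.Size([256, 512])
```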
```diff
@@ -1262,7 +1237,11 @@ class TestFP8Matmul(TestCase):
         y_fp8 = y.to(e4m3_type).t()
 
         with self.assertRaisesRegex(
-            RuntimeError, re.escape("Invalid scaling configuration")
+            RuntimeError,
+            re.escape(
+                "For RowWise scaling, scale_a should be (1024, 1) and scale_b "
+                "should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)"
+            ),
         ):
             torch._scaled_mm(
                 x_fp8,
```
```diff
@@ -1273,7 +1252,11 @@ class TestFP8Matmul(TestCase):
             )
 
         with self.assertRaisesRegex(
-            RuntimeError, re.escape("Invalid scaling configuration")
+            RuntimeError,
+            re.escape(
+                " For RowWise scaling, scale_a should be (1024, 1) and scale_b "
+                "should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)"
+            ),
         ):
             torch._scaled_mm(
                 x_fp8,
```
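The restored messages spell out the row-wise contract: for an fp8 matmul producing a (1024, 2048) output, `scale_a` carries one scale per row of the left operand and `scale_b` one scale per column of the right operand, both as 2-D float32 tensors. A hedged sketch of a shape-valid call (requires a CUDA device with fp8 row-wise `_scaled_mm` support, e.g. H100; the sizes and all-ones scales are illustrative):

```python
import torch

M, K, N = 1024, 512, 2048
device = "cuda"  # assumes an fp8-capable GPU

x_fp8 = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
# The right operand is created row-major as (N, K) and transposed, so it is a
# column-major (K, N) tensor, matching the y.to(e4m3_type).t() pattern in the test.
y_fp8 = torch.randn(N, K, device=device).to(torch.float8_e4m3fn).t()

scale_a = torch.ones(M, 1, device=device)  # (1024, 1): one scale per row of x
scale_b = torch.ones(1, N, device=device)  # (1, 2048): one scale per column of y

out = torch._scaled_mm(x_fp8, y_fp8, scale_a=scale_a, scale_b=scale_b,
                       out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([1024, 2048])
```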
```diff
@@ -1283,18 +1266,22 @@ class TestFP8Matmul(TestCase):
                 out_dtype=torch.bfloat16,
             )
         with self.assertRaisesRegex(
-            RuntimeError, re.escape("Invalid scaling configuration")
+            RuntimeError,
+            re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"),
         ):
             torch._scaled_mm(
                 x_fp8,
                 y_fp8,
                 scale_a=torch.ones((M), device="cuda"),
-                scale_b=torch.ones((N, N, 1), device="cuda"),
+                scale_b=torch.ones((N, N), device="cuda"),
                 out_dtype=torch.bfloat16,
             )
 
         with self.assertRaisesRegex(
-            RuntimeError, re.escape("Invalid scaling configuration")
+            RuntimeError,
+            re.escape(
+                "Both scale_a and scale_b must be contiguous for RowWise scaling."
+            ),
         ):
             torch._scaled_mm(
                 x_fp8,
```
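The other two restored checks are purely about the scale tensors' layout: scales must be 2-D for non-tensor-wise recipes, and contiguous for row-wise scaling. A small CPU sketch of how a shape-correct but non-contiguous scale can arise, and the usual fixes (the slicing pattern here is an assumption for illustration, not what the hidden test lines do):

```python
import torch

M = 1024

# A (M, 1) view sliced out of a wider tensor keeps the parent's strides,
# so it is not contiguous even though its shape looks right.
scale_a = torch.ones(M, 2)[:, :1]
print(scale_a.shape, scale_a.stride(), scale_a.is_contiguous())
# torch.Size([1024, 1]) (2, 1) False

# .contiguous() materializes a dense copy, satisfying the row-wise contiguity check.
print(scale_a.contiguous().is_contiguous())  # True

# Likewise a 1-D scale such as torch.ones(M) trips the "must be 2-dimensional"
# check; reshaping to (M, 1) is the usual fix.
print(torch.ones(M).reshape(M, 1).shape)  # torch.Size([1024, 1])
```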
```diff
@@ -1359,58 +1346,6 @@ class TestFP8Matmul(TestCase):
 
         torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
 
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
-    @unittest.skipIf(not SM90OrLater, "cuBLAS blockwise scaling requires sm90+")
-    @unittest.skipIf(
-        _get_torch_cuda_version() < (12, 9),
-        "cuBLAS blockwise scaling added in CUDA 12.9",
-    )
-    @parametrize("output_dtype", [torch.bfloat16, torch.float32])
-    @parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)])
-    def test_scaled_mm_vs_emulated_block_wise(self, output_dtype, lhs_block, rhs_block):
-        torch.manual_seed(42)
-
-        x = torch.randn(256, 512, device="cuda", dtype=output_dtype).pow(3)
-        y = torch.randn(768, 512, device="cuda", dtype=output_dtype).pow(3)
-
-        x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
-        y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)
-
-        # 1x128 blocks need scales to be outer-dim-major
-        if lhs_block == 1:
-            x_scales = x_scales.t().contiguous().t()
-        if rhs_block == 1:
-            y_scales = y_scales.t().contiguous().t()
-
-        # Calculate actual F8 mm
-        out_scaled_mm = mm_float8(
-            x_fp8, y_fp8.t(), a_scale=x_scales, b_scale=y_scales.t(), output_dtype=output_dtype
-        )
-
-        # Calculate emulated F8 mm
-        out_emulated = mm_float8_emulated_block(
-            x_fp8, x_scales, y_fp8.t(), y_scales.t(), output_dtype
-        )
-
-        cosine_sim = torch.nn.functional.cosine_similarity(
-            out_scaled_mm.flatten().float(), out_emulated.flatten().float(), dim=0
-        )
-        self.assertGreaterEqual(float(cosine_sim), 0.999)
-
-        if output_dtype in {torch.bfloat16, torch.float16}:
-            atol, rtol = 6e-1, 7e-2
-        else:
-            atol, rtol = 7e-1, 2e-3
-
-        self.assertEqual(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
-
-        # One last check against the full-precision reference, to ensure we
-        # didn't mess up the scaling itself and made the test trivial.
-        cosine_sim = torch.nn.functional.cosine_similarity(
-            out_scaled_mm.flatten().float(), (x @ y.t()).flatten().float(), dim=0
-        )
-        self.assertGreaterEqual(float(cosine_sim), 0.999)
-
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("which_dim_zero", [0, 1, 2])
     @parametrize("use_torch_compile", [False, True])
```
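One detail of the removed test worth spelling out is the `t().contiguous().t()` idiom behind the "1x128 blocks need scales to be outer-dim-major" comment: it keeps the logical shape of the scale grid but rewrites its memory layout to column-major. A small sketch (sizes illustrative):

```python
import torch

# One scale per 1x128 block of a (256, 512) tensor -> a (256, 4) scale grid.
scales = torch.rand(256, 4)
print(scales.stride())  # (4, 1): row-major

# Transpose, make that transposed view dense, transpose back: same (256, 4) shape,
# but the data is now laid out column-major ("outer-dim-major").
outer_major = scales.t().contiguous().t()
print(outer_major.shape, outer_major.stride())  # torch.Size([256, 4]) (1, 256)
```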