[cuBLASLt][FP8] cuBLASLt appears to support float8 rowwise-scaling on H100 (#161305)

Following #157905, I think the preprocessor guard around
```
  TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
```
was never updated, which causes `float8` tests to fail. It also appears that cuBLASLt accepts two inputs with `e4m3` and `e5m2` dtypes simultaneously, so that check is removed here as well...
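
For context, a minimal sketch (not part of this PR; shapes, values, and the tensor names are purely illustrative) of the kind of rowwise-scaled call these checks gate, including the mixed `e4m3`/`e5m2` case that, per the updated test below, is expected to work on H100 with CUDA 12.9+:

```python
import torch

# Illustrative shapes only; requires an sm90 GPU and an FP8-capable build.
M, K, N = 16, 128, 32
x = torch.ones(M, K, device="cuda").to(torch.float8_e4m3fn)
# Mixed dtypes: e4m3 for A and e5m2 for B, which the removed check used to reject up front.
y = torch.ones(N, K, device="cuda").to(torch.float8_e5m2).t()  # mat2 must be column-major
out = torch._scaled_mm(
    x,
    y,
    scale_a=torch.ones((M, 1), device="cuda"),  # one scale per row of A
    scale_b=torch.ones((1, N), device="cuda"),  # one scale per column of B
    out_dtype=torch.bfloat16,
)
```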

CC @lw

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161305
Approved by: https://github.com/Skylion007, https://github.com/drisspg, https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Author:    Eddie Yan
Date:      2025-09-05 16:55:07 +00:00
Committed: PyTorch MergeBot
Parent:    b2c7b9ad2d
Commit:    c2a3024617
3 changed files with 27 additions and 12 deletions

```diff
@@ -1937,11 +1937,11 @@ void scaled_gemm(
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
   cublasLtMatmulDescAttributes_t matmulDescA = CUBLASLT_MATMUL_DESC_A_SCALE_POINTER;
   cublasLtMatmulDescAttributes_t matmulDescB = CUBLASLT_MATMUL_DESC_B_SCALE_POINTER;
+#if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
   // hipblaslt supported row-wise before cublas, and did so their own way (via
   // the SCALE_POINTERSs), but then migrated to match how cublas does it (via
   // the SCALE_MODEs). Here we check for this early custom mode.
   bool use_rowwise = (mat1_scaling_type == ScalingType::RowWise && mat2_scaling_type == ScalingType::RowWise);
-#if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
   if (use_rowwise) {
     matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
     matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
@@ -1956,8 +1956,12 @@ void scaled_gemm(
     }
 #endif
   }
-#else
-  // rowwise isn't supported using cublaslt or older hipblaslt
+#elif (CUDA_VERSION < 12090) && !defined(USE_ROCM)
+  // hipblaslt supported row-wise before cublas, and did so their own way (via
+  // the SCALE_POINTERSs), but then migrated to match how cublas does it (via
+  // the SCALE_MODEs). Here we check for this early custom mode.
+  bool use_rowwise = (mat1_scaling_type == ScalingType::RowWise && mat2_scaling_type == ScalingType::RowWise);
+  // rowwise isn't supported using older cublaslt or older hipblaslt
   TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
 #endif // if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
   computeDesc.setAttribute(matmulDescA, mat1_scale_ptr);
```
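
With the new guard, the assert only remains for non-ROCm builds compiled against CUDA older than 12.9 (and for hipBLASLt builds without the vec-ext path). A rough runtime mirror of that condition, following the check the updated test below uses, might look like the sketch here; the helper name is made up for illustration and is not part of this PR:

```python
import torch

def rowwise_fp8_via_cublaslt_expected() -> bool:
    """Hypothetical helper: rowwise FP8 via cuBLASLt is expected on Hopper + CUDA 12.9+."""
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False  # CPU-only or ROCm builds take a different path
    # The string comparison mirrors the test's own `torch.version.cuda >= "12.9"` check.
    return torch.cuda.get_device_capability() == (9, 0) and torch.version.cuda >= "12.9"
```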

```diff
@@ -465,7 +465,10 @@ class TestFP8Lowering(TestCase):
         # autotuning for the compiled case, the results can be different because of
         # the way blocks of results are accumulated (float addition not associative), so
         # setting a small absolute tolerance in these tests
-        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
+        if dtype == torch.bfloat16:
+            self.assertEqual(y_eager, y_compiled, rtol=5e-2, atol=0.07)
+        else:
+            self.assertEqual(y_eager, y_compiled, rtol=1e-2, atol=0.05)
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @unittest.skipIf(
@@ -611,7 +614,7 @@
         )
         self.assertEqual(y_eager.dtype, dtype)
         self.assertEqual(y_compiled.dtype, dtype)
-        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
+        torch.testing.assert_close(y_eager, y_compiled, rtol=5e-2, atol=0.07)
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @unittest.skipIf(
@@ -744,7 +747,7 @@
         )
         self.assertEqual(y_eager.dtype, dtype)
         self.assertEqual(y_compiled.dtype, dtype)
-        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.07)
+        torch.testing.assert_close(y_eager, y_compiled, rtol=5e-2, atol=0.07)
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("M", (1, 3, 33, 257, 1024))
```
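
The looser tolerances follow from the comment kept in the first hunk: floating-point addition is not associative, so eager and compiled kernels that accumulate blocks in different orders can round differently. A tiny standalone illustration of that effect (not taken from the PR):

```python
import torch

a = torch.tensor(1e8, dtype=torch.float32)
b = torch.tensor(-1e8, dtype=torch.float32)
c = torch.tensor(1.0, dtype=torch.float32)
print((a + b) + c)  # tensor(1.)
print(a + (b + c))  # tensor(0.)  same operands, different grouping, different result
```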

```diff
@@ -1315,18 +1315,26 @@ class TestFP8Matmul(TestCase):
             out_dtype=torch.bfloat16,
         )
 
-        # Note re.compile is used, not re.escape. This is to accommodate fn vs fnuz type message.
-        with self.assertRaisesRegex(
-            RuntimeError,
-            r"Expected b\.dtype\(\) == at::kFloat8_e4m3fnu?z? to be true, but got false\.",
-        ):
-            torch._scaled_mm(
+        def e5m2():
+            out = torch._scaled_mm(
                 x_fp8,
                 y_fp8.to(e5m2_type),
                 scale_a=torch.ones((M, 1), device="cuda"),
                 scale_b=torch.ones((1, N), device="cuda"),
                 out_dtype=torch.bfloat16,
             )
+            return out
+
+        if torch.cuda.get_device_capability() == (9, 0) and torch.version.cuda and torch.version.cuda >= "12.9":
+            out = e5m2()
+            self.assertEqual(out, torch.ones_like(out) * 128.)
+        else:
+            # Note re.compile is used, not re.escape. This is to accommodate fn vs fnuz type message.
+            with self.assertRaisesRegex(
+                RuntimeError,
+                r"Expected b\.dtype\(\) == at::kFloat8_e4m3fnu?z? to be true, but got false\.",
+            ):
+                e5m2()
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
     @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific")
```
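
The expected value of 128 in the new success branch presumably comes from the contraction dimension: with all-ones operands and unit rowwise scales, every output element is just a sum of K ones, and the test's expected value suggests K is 128 there. A plain float32 reference of the same arithmetic, with shapes assumed only for illustration:

```python
import torch

M, K, N = 16, 128, 32  # K = 128 inferred from the expected value; not the test's actual parametrization
ref = torch.ones(M, K) @ torch.ones(K, N)  # every entry is a sum of K ones
assert torch.equal(ref, torch.full((M, N), 128.0))
```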