Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[cuBLASLt][FP8] cuBLASLt appears to support float8 rowwise-scaling on H100 (#161305)

Following #157905 I think the macro around ```TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");``` was never updated, which would cause `float8` tests to fail. It also appears that cuBLASLt accepts two inputs with `e4m3` and `e5m2` dtypes simultaneously, so that check is removed here as well.

CC @lw

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161305
Approved by: https://github.com/Skylion007, https://github.com/drisspg, https://github.com/jeffdaily
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
commit c2a3024617 (parent b2c7b9ad2d)
committed by PyTorch MergeBot
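For context, here is a minimal sketch of the kind of call the updated test below exercises: a row-wise-scaled `torch._scaled_mm` with mixed `e4m3`/`e5m2` operands. The shapes `M, K, N` and the all-ones data are illustrative assumptions; on hardware without cuBLASLt row-wise support (per this PR, anything short of H100 with CUDA 12.9+) the same call still raises the dtype error the test asserts.

```python
import torch

# Illustrative only: mixed float8 operand dtypes with row-wise scales.
# Requires a CUDA device with cuBLASLt row-wise FP8 support (H100, CUDA >= 12.9);
# elsewhere torch._scaled_mm raises the RuntimeError the updated test expects.
M, K, N = 256, 128, 64
x_fp8 = torch.ones(M, K, device="cuda").to(torch.float8_e4m3fn)    # e4m3 operand, row-major
y_fp8 = torch.ones(N, K, device="cuda").to(torch.float8_e5m2).t()  # e5m2 operand, column-major (K, N)
out = torch._scaled_mm(
    x_fp8,
    y_fp8,
    scale_a=torch.ones((M, 1), device="cuda"),  # one scale per row of x  -> row-wise scaling
    scale_b=torch.ones((1, N), device="cuda"),  # one scale per column of y
    out_dtype=torch.bfloat16,
)
# With all-ones inputs and unit scales every entry equals K; the new test checks
# for 128. in the same way.
print(out.shape, out[0, 0].item())
```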
@@ -1937,11 +1937,11 @@ void scaled_gemm(
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
   cublasLtMatmulDescAttributes_t matmulDescA = CUBLASLT_MATMUL_DESC_A_SCALE_POINTER;
   cublasLtMatmulDescAttributes_t matmulDescB = CUBLASLT_MATMUL_DESC_B_SCALE_POINTER;
+#if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
   // hipblaslt supported row-wise before cublas, and did so their own way (via
   // the SCALE_POINTERSs), but then migrated to match how cublas does it (via
   // the SCALE_MODEs). Here we check for this early custom mode.
   bool use_rowwise = (mat1_scaling_type == ScalingType::RowWise && mat2_scaling_type == ScalingType::RowWise);
-#if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
   if (use_rowwise) {
     matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
     matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
@@ -1956,8 +1956,12 @@ void scaled_gemm(
     }
 #endif
   }
-#else
-  // rowwise isn't supported using cublaslt or older hipblaslt
+#elif (CUDA_VERSION < 12090) && !defined(USE_ROCM)
+  // hipblaslt supported row-wise before cublas, and did so their own way (via
+  // the SCALE_POINTERSs), but then migrated to match how cublas does it (via
+  // the SCALE_MODEs). Here we check for this early custom mode.
+  bool use_rowwise = (mat1_scaling_type == ScalingType::RowWise && mat2_scaling_type == ScalingType::RowWise);
+  // rowwise isn't supported using older cublaslt or older hipblaslt
   TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
 #endif // if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
   computeDesc.setAttribute(matmulDescA, mat1_scale_ptr);
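Read together, the two hunks above leave three cases: the legacy hipBLASLt `*_SCALE_POINTER_VEC_EXT` path on ROCm, a hard `TORCH_INTERNAL_ASSERT` only on cuBLASLt older than CUDA 12.9, and fall-through to the newer SCALE_MODE handling otherwise. A rough Python-side sketch of when the row-wise path is therefore expected to work; the helper name is invented for illustration, and the capability/version checks mirror the gate used by the new test further down:

```python
import torch

def lt_rowwise_fp8_expected() -> bool:
    """Illustrative predicate only, not a PyTorch API: it roughly mirrors the
    compile-time guard above plus the H100 gate used by the new TestFP8Matmul test."""
    if torch.version.hip is not None:
        # ROCm builds: hipBLASLt has its own row-wise scaling support.
        return True
    if torch.version.cuda is None or not torch.cuda.is_available():
        return False
    major, minor = (int(p) for p in torch.version.cuda.split(".")[:2])
    # cuBLASLt row-wise scaling needs CUDA >= 12.9; the test additionally
    # restricts itself to compute capability 9.0 (H100).
    return (major, minor) >= (12, 9) and torch.cuda.get_device_capability() == (9, 0)
```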
@@ -465,7 +465,10 @@ class TestFP8Lowering(TestCase):
         # autotuning for the compiled case, the results can be different because of
         # the way blocks of results are accumulated (float addition not associative), so
         # setting a small absolute tolerance in these tests
-        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
+        if dtype == torch.bfloat16:
+            self.assertEqual(y_eager, y_compiled, rtol=5e-2, atol=0.07)
+        else:
+            self.assertEqual(y_eager, y_compiled, rtol=1e-2, atol=0.05)

     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @unittest.skipIf(
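The comment about float addition not being associative is the whole reason exact comparison is replaced by a tolerance here; a standalone illustration in plain Python, unrelated to the PR itself:

```python
# Floating-point addition is not associative, so accumulating the same partial
# products in a different block order (as autotuned kernels do) can change the
# low-order bits of the result.
a, b, c = 0.1, 0.2, 0.3
print((a + b) + c)                  # 0.6000000000000001
print(a + (b + c))                  # 0.6
print((a + b) + c == a + (b + c))   # False -> hence rtol/atol instead of exact equality
```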
@@ -611,7 +614,7 @@ class TestFP8Lowering(TestCase):
         )
         self.assertEqual(y_eager.dtype, dtype)
         self.assertEqual(y_compiled.dtype, dtype)
-        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
+        torch.testing.assert_close(y_eager, y_compiled, rtol=5e-2, atol=0.07)

     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @unittest.skipIf(
@@ -744,7 +747,7 @@ class TestFP8Lowering(TestCase):
         )
         self.assertEqual(y_eager.dtype, dtype)
         self.assertEqual(y_compiled.dtype, dtype)
-        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.07)
+        torch.testing.assert_close(y_eager, y_compiled, rtol=5e-2, atol=0.07)

     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("M", (1, 3, 33, 257, 1024))
@@ -1315,18 +1315,26 @@ class TestFP8Matmul(TestCase):
             out_dtype=torch.bfloat16,
         )

-        # Note re.compile is used, not re.escape. This is to accommodate fn vs fnuz type message.
-        with self.assertRaisesRegex(
-            RuntimeError,
-            r"Expected b\.dtype\(\) == at::kFloat8_e4m3fnu?z? to be true, but got false\.",
-        ):
-            torch._scaled_mm(
+        def e5m2():
+            out = torch._scaled_mm(
                 x_fp8,
                 y_fp8.to(e5m2_type),
                 scale_a=torch.ones((M, 1), device="cuda"),
                 scale_b=torch.ones((1, N), device="cuda"),
                 out_dtype=torch.bfloat16,
             )
+            return out
+
+        if torch.cuda.get_device_capability() == (9, 0) and torch.version.cuda and torch.version.cuda >= "12.9":
+            out = e5m2()
+            self.assertEqual(out, torch.ones_like(out) * 128.)
+        else:
+            # Note re.compile is used, not re.escape. This is to accommodate fn vs fnuz type message.
+            with self.assertRaisesRegex(
+                RuntimeError,
+                r"Expected b\.dtype\(\) == at::kFloat8_e4m3fnu?z? to be true, but got false\.",
+            ):
+                e5m2()

     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
     @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific")
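The "Note re.compile is used, not re.escape" comment survives in the else-branch above; the pattern's optional `u?z?` suffix is what lets one assertion cover both the CUDA `e4m3fn` and ROCm `e4m3fnuz` error messages, for example:

```python
import re

pattern = r"Expected b\.dtype\(\) == at::kFloat8_e4m3fnu?z? to be true, but got false\."
cuda_msg = "Expected b.dtype() == at::kFloat8_e4m3fn to be true, but got false."
rocm_msg = "Expected b.dtype() == at::kFloat8_e4m3fnuz to be true, but got false."
# assertRaisesRegex matches with re.search, so both spellings satisfy the pattern.
print(bool(re.search(pattern, cuda_msg)), bool(re.search(pattern, rocm_msg)))  # True True
```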