mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[CUDA][float8][TF32] Disable tf32 for vs. emulated rowwise comparison (#162387)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162387 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
f03d635dc6
commit
5eb35d2ab8
@ -32,6 +32,7 @@ from torch.testing._internal.common_cuda import (
|
||||
PLATFORM_SUPPORTS_MX_GEMM,
|
||||
PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM,
|
||||
IS_SM90,
|
||||
with_tf32_off,
|
||||
)
|
||||
from torch.testing._internal.common_device_type import (
|
||||
dtypes,
|
||||
@ -1510,6 +1511,7 @@ class TestFP8Matmul(TestCase):
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
|
||||
@unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific")
|
||||
@parametrize("base_dtype", [torch.bfloat16, torch.float32])
|
||||
@with_tf32_off
|
||||
def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
|
||||
# Fp32 out_dtype is only supported by cuBLAS, which however only started
|
||||
# shipping row-wise kernels in CUDA 12.9, and only for sm90+.
|
||||
|
Reference in New Issue
Block a user