[CUTLASS][FP8] Skip scaled_mm rowwise test on sm89 (#133612)
Rowwise implementation currently uses sm90-specific features, incl. TMA.

CC @drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133612
Approved by: https://github.com/Skylion007
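For context, a minimal sketch of the kind of compute-capability gate that SM90OrLater provides, built only from public torch CUDA APIs. The helper _is_sm90_or_later, the RowwiseGateExample class, and the placeholder test are illustrative names, not the test suite's actual code; the real change is the diff that follows.

import unittest

import torch


def _is_sm90_or_later() -> bool:
    # Hypothetical helper: True only when a CUDA device with compute
    # capability >= 9.0 (Hopper / sm90) is present, so sm89 (Ada) is excluded.
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() >= (9, 0)


class RowwiseGateExample(unittest.TestCase):
    @unittest.skipIf(not _is_sm90_or_later(),
                     "rowwise scaled_mm path relies on sm90 features such as TMA")
    def test_rowwise_placeholder(self):
        # Placeholder body; the real test is
        # TestFP8MatmulCuda.test_scaled_mm_vs_emulated_row_wise in the diff below.
        pass


if __name__ == "__main__":
    unittest.main()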
@@ -16,6 +16,7 @@ from torch.quantization._quantized_conversions import (
 from torch.testing import make_tensor
 from torch.testing._internal.common_cuda import (
     SM53OrLater,
+    SM90OrLater,
     _get_torch_cuda_version,
     PLATFORM_SUPPORTS_FP8
 )
@@ -664,6 +665,7 @@ class TestFP8MatmulCuda(TestCase):
         )

     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
+    @unittest.skipIf(not SM90OrLater, "rowwise implementation is currently sm90 specific")
     @skipIfRocm()
     @parametrize("base_dtype", [torch.bfloat16])
     def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):