diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 0dea610bf822..70235665cb6f 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -16,6 +16,7 @@ from torch.quantization._quantized_conversions import ( from torch.testing import make_tensor from torch.testing._internal.common_cuda import ( SM53OrLater, + SM90OrLater, _get_torch_cuda_version, PLATFORM_SUPPORTS_FP8 ) @@ -664,6 +665,7 @@ class TestFP8MatmulCuda(TestCase): ) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) + @unittest.skipIf(not SM90OrLater, "rowwise implementation is currently sm90 specific") @skipIfRocm() @parametrize("base_dtype", [torch.bfloat16]) def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):