This commit is contained in:
Ting Lu
2025-11-14 02:34:29 -08:00
parent 383bae5707
commit f12cbe11db

View File

@ -604,10 +604,8 @@ class TestDecomp(TestCase):
# Check both "torch._scaled_mm" and "_scaled_mm" as op.name could be either
if op.name in ("torch._scaled_mm", "_scaled_mm"):
if torch.version.cuda is not None:
cuda_version = tuple(int(x) for x in torch.version.cuda.split("."))
if cuda_version >= (13, 0):
if _get_torch_cuda_version() >= (13, 0):
self.skipTest("xfail on CUDA 13.0+ until nullptr issue is fixed")
self.do_cross_ref(device, dtype, op, run_all=True)
def test_uniform(self, device):