mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-16 15:34:57 +08:00
refactor
This commit is contained in:
@ -604,10 +604,8 @@ class TestDecomp(TestCase):
|
||||
# Check both "torch._scaled_mm" and "_scaled_mm" as op.name could be either
|
||||
if op.name in ("torch._scaled_mm", "_scaled_mm"):
|
||||
if torch.version.cuda is not None:
|
||||
cuda_version = tuple(int(x) for x in torch.version.cuda.split("."))
|
||||
if cuda_version >= (13, 0):
|
||||
if _get_torch_cuda_version() >= (13, 0):
|
||||
self.skipTest("xfail on CUDA 13.0+ until nullptr issue is fixed")
|
||||
|
||||
self.do_cross_ref(device, dtype, op, run_all=True)
|
||||
|
||||
def test_uniform(self, device):
|
||||
|
||||
Reference in New Issue
Block a user