Skip symmetric memory tests calling _scaled_mm on CCC < 8.9 (#164251)

This avoids these tests failing on, e.g., A100 GPUs with:
> RuntimeError: torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+
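For context, SM89OrLater (imported below from torch.testing._internal.common_cuda) gates on the same requirement the error message states. A minimal sketch of an equivalent guard, assuming a plain CUDA capability check; the helper name and test class below are illustrative, and the ROCm MI300+ path is omitted:

    import unittest
    import torch

    def _has_sm89_or_later() -> bool:
        # True on Ada (8.9) and Hopper (9.0+) parts; False on e.g. A100 (8.0).
        # The real SM89OrLater flag also covers ROCm; that is omitted here.
        if not torch.cuda.is_available():
            return False
        return torch.cuda.get_device_capability() >= (8, 9)

    class _Example(unittest.TestCase):
        # Mirrors the decorator added by this commit.
        @unittest.skipUnless(_has_sm89_or_later(), "Requires compute capability >= 8.9")
        def test_scaled_mm_path(self):
            ...  # real tests exercise torch._scaled_mm via symmetric memory ops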

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164251
Approved by: https://github.com/Skylion007, https://github.com/kwen2501
Author:    Alexander Grund
Date:      2025-10-01 03:26:17 +00:00
Committer: PyTorch MergeBot
Parent:    fa90090735
Commit:    8bb71c07c4

@@ -4,7 +4,7 @@ import itertools
 import os
 import random
 from contextlib import nullcontext
-from unittest import skip, skipIf
+from unittest import skip, skipIf, skipUnless
 
 import torch
 import torch.distributed as dist
@@ -25,6 +25,7 @@ from torch.distributed._symmetric_memory import (
 from torch.testing._internal.common_cuda import (
     _get_torch_cuda_version,
     SM100OrLater,
+    SM89OrLater,
     SM90OrLater,
     xfailIfSM100OrLater,
 )
@@ -430,6 +431,7 @@ class AsyncTPTest(MultiProcContinuousTest):
         not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
     )
     @skip_if_lt_x_gpu(2)
+    @skipUnless(SM89OrLater, "Requires compute capability >= 8.9")
     @parametrize("gather_dim", [0, 1])
     @parametrize(
         "scale_mode", ["tensor-wise", "row-wise-replicated", "row-wise-sharded"]
@@ -545,6 +547,7 @@ class AsyncTPTest(MultiProcContinuousTest):
     @skip_if_rocm_multiprocess  # AsyncTP support changed _fused_scaled_matmul_reduce_scatter_fallback API, need more changes
     @skip_if_lt_x_gpu(2)
+    @skipUnless(SM89OrLater, "Requires compute capability >= 8.9")
     @parametrize("scatter_dim", [0, 1])
     @parametrize("rowwise", [True, False])
     def test_fused_scaled_matmul_reduce_scatter(