mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Skip symmetric memory tests calling _scaled_mm
on CUDA compute capability (CC) < 8.9 (#164251)
This avoids them failing on e.g. A100 GPUs with: "RuntimeError: torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+". Pull Request resolved: https://github.com/pytorch/pytorch/pull/164251 Approved by: https://github.com/Skylion007, https://github.com/kwen2501
This commit is contained in:
committed by
PyTorch MergeBot
parent
fa90090735
commit
8bb71c07c4
@ -4,7 +4,7 @@ import itertools
|
||||
import os
|
||||
import random
|
||||
from contextlib import nullcontext
|
||||
from unittest import skip, skipIf
|
||||
from unittest import skip, skipIf, skipUnless
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -25,6 +25,7 @@ from torch.distributed._symmetric_memory import (
|
||||
from torch.testing._internal.common_cuda import (
|
||||
_get_torch_cuda_version,
|
||||
SM100OrLater,
|
||||
SM89OrLater,
|
||||
SM90OrLater,
|
||||
xfailIfSM100OrLater,
|
||||
)
|
||||
@ -430,6 +431,7 @@ class AsyncTPTest(MultiProcContinuousTest):
|
||||
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
|
||||
)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skipUnless(SM89OrLater, "Requires compute capability >= 8.9")
|
||||
@parametrize("gather_dim", [0, 1])
|
||||
@parametrize(
|
||||
"scale_mode", ["tensor-wise", "row-wise-replicated", "row-wise-sharded"]
|
||||
@ -545,6 +547,7 @@ class AsyncTPTest(MultiProcContinuousTest):
|
||||
|
||||
@skip_if_rocm_multiprocess # AsyncTP support changed _fused_scaled_matmul_reduce_scatter_fallback API, need more changes
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skipUnless(SM89OrLater, "Requires compute capability >= 8.9")
|
||||
@parametrize("scatter_dim", [0, 1])
|
||||
@parametrize("rowwise", [True, False])
|
||||
def test_fused_scaled_matmul_reduce_scatter(
|
||||
|
Reference in New Issue
Block a user