add device generalisation support for distributed tests (#152471)

### MOTIVATION
To generalize distributed test cases so they also run on non-CUDA accelerator devices (HPU, XPU)

### CHANGES

- test/distributed/optim/test_zero_redundancy_optimizer.py
- test/distributed/test_c10d_logger.py
- test/distributed/test_compute_comm_reordering.py

Replaced hard-coded device names with get_devtype from torch.testing._internal.common_fsdp, and switched these test cases from MultiProcessTestCase to DistributedTestBase so they can reuse its helper functions (see the sketch below).
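
For illustration only (this sketch is not part of the diff), a test migrated onto these helpers might look roughly like the following; the test class, tensor shape, and assertion are hypothetical:

```python
import torch
import torch.distributed as dist

from torch.testing._internal.common_distributed import DistributedTestBase
from torch.testing._internal.common_fsdp import get_devtype

# get_devtype() resolves the available accelerator type ("cuda", "hpu", "xpu", ...)
# instead of a hard-coded "cuda" string.
device_type = str(get_devtype())


class ExampleCollectiveTest(DistributedTestBase):  # hypothetical test class
    def test_all_reduce(self):
        # create_pg() picks a matching backend via self.backend(device)
        # (nccl/hccl/xccl, or gloo as fallback) and initializes the default
        # process group for this rank.
        self.create_pg(device_type)
        t = torch.ones(4, device=f"{device_type}:{self.rank}")
        dist.all_reduce(t)
        self.assertEqual(t, torch.full_like(t, self.world_size))
```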

- torch/testing/_internal/common_distributed.py

Extended the common utility functions (skip decorators, accelerator-backend checks, and DistributedTestBase/process-group helpers) to cover non-CUDA accelerators; a usage sketch follows.
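
As a hedged sketch (not part of the diff), the extended helpers could be combined like this in a test file; the class and method names below are invented:

```python
from torch.testing._internal.common_distributed import (
    DistributedTestBase,
    at_least_x_gpu,
    requires_accelerator_dist_backend,
    requires_ddp_rank,
    skip_if_no_gpu,
)


class ExampleBackendTest(DistributedTestBase):  # hypothetical test class
    @requires_accelerator_dist_backend()  # skip unless NCCL, XCCL, or HCCL is available
    @skip_if_no_gpu  # now also counts HPU/XPU devices against WORLD_SIZE
    def test_collective(self):
        if not at_least_x_gpu(2):
            self.skipTest("needs at least two accelerator devices")
        # requires_ddp_rank(device) checks whether the device type is in
        # DDP_RANK_DEVICES ("cuda", "xpu").
        if requires_ddp_rank("cuda"):
            ...
```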

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152471
Approved by: https://github.com/d4l3k
Hari Krishna Sai Kodali
2025-06-20 07:35:42 +00:00
committed by PyTorch MergeBot
parent 0aed855b2b
commit e1f28fe17b
5 changed files with 243 additions and 224 deletions


@@ -39,6 +39,7 @@ from torch.testing._internal.common_utils import (
retry_on_connect_failures,
skip_but_pass_in_sandcastle,
skip_but_pass_in_sandcastle_if,
TEST_CUDA,
TEST_HPU,
TEST_WITH_ROCM,
TEST_WITH_TSAN,
@@ -55,6 +56,10 @@ from torch.testing._internal.distributed.multi_threaded_pg import (
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
ACCELERATOR_DIST_BACKENDS = ["nccl", "xccl", "hccl"]
DDP_RANK_DEVICES = ["cuda", "xpu"]
HAS_ACCELERATOR = TEST_CUDA or TEST_HPU or TEST_XPU
class TestSkip(NamedTuple):
exit_code: int
@@ -109,21 +114,25 @@ class DistTestCases:
backend_feature["xpu"] = {"xccl"}
def requires_ddp_rank(device):
return device in DDP_RANK_DEVICES
def skip_if_no_gpu(func):
"""Skips if the world size exceeds the number of GPUs, ensuring that if the
test is run, each rank has its own GPU via ``torch.cuda.device(rank)``."""
@wraps(func)
def wrapper(*args, **kwargs):
-if not torch.cuda.is_available():
+if not (TEST_CUDA or TEST_HPU or TEST_XPU):
sys.exit(TEST_SKIPS["no_cuda"].exit_code)
world_size = int(os.environ["WORLD_SIZE"])
-if torch.cuda.device_count() < world_size:
+if TEST_CUDA and torch.cuda.device_count() < world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
-if TEST_HPU and torch.hpu.device_count < world_size:
+if TEST_HPU and torch.hpu.device_count() < world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
+if TEST_XPU and torch.xpu.device_count() < world_size:
+sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
-if TEST_XPU and torch.xpu.device_count < world_size:
-sys.exit(TEST_SKIPS[f"multi-xpu-{world_size}"].exit_code)
return func(*args, **kwargs)
@@ -189,7 +198,13 @@ def import_transformers_or_skip():
def at_least_x_gpu(x):
-return torch.cuda.is_available() and torch.cuda.device_count() >= x
+if TEST_CUDA and torch.cuda.device_count() >= x:
+return True
+if TEST_HPU and torch.hpu.device_count() >= x:
+return True
+if TEST_XPU and torch.xpu.device_count() >= x:
+return True
+return False
def skip_if_lt_x_gpu(x):
@@ -355,6 +370,35 @@ def requires_mpi():
)
def requires_accelerator_dist_backend(backends=None):
"""
Decorator to skip tests if no accelerator communication backend (NCCL, XCCL, HCCL) is available.
Args:
backends (Optional[List[str]]): Specific accelerator backends to check (e.g., ["nccl", "xccl", "hccl"]).
If None, checks all supported accelerator backends (NCCL, XCCL, HCCL).
Returns:
callable: A decorator that skips the test if no specified accelerator backend is available.
"""
if backends is None:
backends = ACCELERATOR_DIST_BACKENDS
backend_available = any(
{
"nccl": c10d.is_nccl_available,
"xccl": c10d.is_xccl_available,
"hccl": lambda: TEST_HPU,
}.get(backend, lambda: False)()
for backend in backends
)
return skip_but_pass_in_sandcastle_if(
not backend_available,
f"No accelerator communication backend available among {backends}",
)
def requires_multicast_support():
has_multicast_support = (
torch.cuda.is_available()
@@ -968,9 +1012,14 @@ class MultiProcessTestCase(TestCase):
class DistributedTestBase(MultiProcessTestCase):
def setUp(self):
super().setUp()
os.environ["WORLD_SIZE"] = str(self.world_size)
self._spawn_processes()
def tearDown(self):
try:
torch.distributed.destroy_process_group()
except AssertionError:
pass
try:
os.remove(self.file_name)
except OSError:
@@ -986,12 +1035,14 @@ class DistributedTestBase(MultiProcessTestCase):
else:
return "gloo"
-def create_pg(self, device):
+def create_pg(self, device, world_size=None):
+if world_size is None:
+world_size = self.world_size
num_visible_devices = torch.get_device_module(device).device_count()
store = torch.distributed.FileStore(self.file_name, num_visible_devices)
torch.distributed.init_process_group(
backend=self.backend(device),
-world_size=self.world_size,
+world_size=world_size,
rank=self.rank,
store=store,
)
@@ -1404,7 +1455,9 @@ class SaveForwardInputsModel(nn.Module):
@contextmanager
-def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True, fake_pg=False):
+def _dynamo_dist_per_rank_init(
+rank, world_size, backend="nccl", init_pg=True, fake_pg=False
+):
# To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
# Just manually implement the most important part of the dynamo behavior to reset/clear.
if not fake_pg:
@@ -1421,7 +1474,7 @@ def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True, fake_pg=False):
store=store,
)
else:
c10d.init_process_group("nccl", rank=rank, world_size=world_size)
c10d.init_process_group(backend=backend, rank=rank, world_size=world_size)
torch._dynamo.reset()
torch._dynamo.utils.counters.clear()
try:
@@ -1465,7 +1518,7 @@ class DynamoDistributedSingleProcTestCase(torch._dynamo.test_case.TestCase):
super().tearDownClass()
-class DynamoDistributedMultiProcTestCase(MultiProcessTestCase):
+class DynamoDistributedMultiProcTestCase(DistributedTestBase):
"""
Use this for tests that actually run on multiple GPUs.
@@ -1476,20 +1529,9 @@ class DynamoDistributedMultiProcTestCase(MultiProcessTestCase):
sparingly for integration tests.
"""
-def setUp(self):
-super().setUp()
-self._spawn_processes()
-def tearDown(self):
-super().tearDown()
-try:
-os.remove(self.file_name)
-except OSError:
-pass
@property
def world_size(self) -> int:
-return torch.cuda.device_count()
+return torch.accelerator.device_count()
@classmethod
def _run(