Update test_c10d_object_collectives.py with DistributedTestBase class (#145056)

# MOTIVATION
To generalize distributed test cases to non-CUDA devices, we leverage the DistributedTestBase class introduced in [PR #138216](https://github.com/pytorch/pytorch/pull/138216). This class derives from MultiProcessTestCase and abstracts the creation and deletion of process groups, along with other device-specific setup. In this PR, we extend these tests to support HPUs.

# CHANGES

- Replaced `MultiProcessTestCase` with the `DistributedTestBase` class.
- Extended test functionality to include support for HPUs.
- Used `instantiate_device_type_tests` with targeted attributes to generate device-specific test instances (see the sketch below).
- Applied the `skipIfHpu` decorator to skip tests that are not yet compatible with HPU devices.
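
For orientation, here is a condensed, hedged sketch of the pattern these changes follow; it is illustrative only and not part of the diff. `MySketchTest` is a hypothetical class name, and the helpers it relies on (`DistributedTestBase.create_pg`, `instantiate_device_type_tests`, `run_tests`) are the ones that appear in the diff below; the behavior attributed to `create_pg` follows the PR description above.

```python
# Illustrative sketch only. Assumes PyTorch's internal test helpers shown in
# the diff below; `MySketchTest` is a hypothetical name.
import torch.distributed as dist
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import DistributedTestBase
from torch.testing._internal.common_utils import run_tests


class MySketchTest(DistributedTestBase):
    def test_broadcast_object(self, device):
        # Per the PR description, DistributedTestBase abstracts process-group
        # creation; create_pg() sets up a group appropriate for `device`.
        self.create_pg(device=device)
        try:
            obj = [99 if self.rank == 0 else None]
            dist.broadcast_object_list(obj, src=0)
            self.assertEqual(99, obj[0])
        finally:
            dist.destroy_process_group()


# Generate one concrete test class per listed device type, so each test
# receives the matching `device` argument.
devices = ("cpu", "cuda", "hpu")
instantiate_device_type_tests(MySketchTest, globals(), only_for=devices)

if __name__ == "__main__":
    run_tests()
```

The per-device classes generated by `instantiate_device_type_tests` are what actually run; the generic class serves only as a template, which is why the diff calls it with `globals()` and `only_for=devices`.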

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145056
Approved by: https://github.com/kwen2501, https://github.com/guangyey
Author: amathewc
Date: 2025-02-13 03:57:59 +00:00
Committed by: PyTorch MergeBot
Parent: a9598337b7
Commit: 936df4571b


@@ -1,6 +1,5 @@
 # Owner(s): ["oncall: distributed"]
-import os
 import sys
 from functools import partial, wraps
@@ -12,8 +11,15 @@ if not dist.is_available():
     print("Distributed not available, skipping tests", file=sys.stderr)
     sys.exit(0)
-from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS
-from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
+from torch.testing._internal.common_device_type import instantiate_device_type_tests
+from torch.testing._internal.common_distributed import DistributedTestBase, TEST_SKIPS
+from torch.testing._internal.common_utils import (
+    run_tests,
+    skipIfHpu,
+    TEST_CUDA,
+    TEST_HPU,
+    TEST_WITH_DEV_DBG_ASAN,
+)
 if TEST_WITH_DEV_DBG_ASAN:
@@ -23,7 +29,16 @@ if TEST_WITH_DEV_DBG_ASAN:
     )
     sys.exit(0)
-BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
+if TEST_HPU:
+    DEVICE = "hpu"
+elif TEST_CUDA:
+    DEVICE = "cuda"
+else:
+    DEVICE = "cpu"
+device_module = torch.get_device_module(DEVICE)
+device_count = device_module.device_count()
+BACKEND = dist.get_default_backend_for_device(DEVICE)
 def with_comms(func=None):
@@ -34,59 +49,22 @@ def with_comms(func=None):
     @wraps(func)
     def wrapper(self, *args, **kwargs):
-        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
+        if DEVICE != "cpu" and device_count < self.world_size:
             sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
-        self.dist_init()
-        func(self)
-        self.destroy_comms()
+        kwargs["device"] = DEVICE
+        self.pg = self.create_pg(device=DEVICE)
+        try:
+            return func(self, *args, **kwargs)
+        finally:
+            torch.distributed.destroy_process_group()
     return wrapper
-class TestObjectCollectives(MultiProcessTestCase):
-    def setUp(self):
-        super().setUp()
-        os.environ["WORLD_SIZE"] = str(self.world_size)
-        os.environ["BACKEND"] = BACKEND
-        self._spawn_processes()
-    @property
-    def device(self):
-        return (
-            torch.device("cuda", self.rank % torch.cuda.device_count())
-            if BACKEND == dist.Backend.NCCL
-            else torch.device("cpu")
-        )
-    @property
-    def world_size(self):
-        if BACKEND == dist.Backend.NCCL:
-            return torch.cuda.device_count()
-        return super().world_size
-    @property
-    def process_group(self):
-        return dist.group.WORLD
-    def destroy_comms(self):
-        # Wait for all ranks to reach here before starting shutdown.
-        dist.barrier()
-        dist.destroy_process_group()
-    def dist_init(self):
-        dist.init_process_group(
-            backend=BACKEND,
-            world_size=self.world_size,
-            rank=self.rank,
-            init_method=f"file://{self.file_name}",
-        )
-        # set device for nccl pg for collectives
-        if BACKEND == "nccl":
-            torch.cuda.set_device(self.rank)
+class TestObjectCollectives(DistributedTestBase):
     @with_comms()
-    def test_all_gather_object(self):
+    def test_all_gather_object(self, device):
         output = [None] * dist.get_world_size()
         dist.all_gather_object(object_list=output, obj=self.rank)
@@ -94,7 +72,7 @@ class TestObjectCollectives(MultiProcessTestCase):
             self.assertEqual(i, v, f"rank: {self.rank}")
     @with_comms()
-    def test_gather_object(self):
+    def test_gather_object(self, device):
         output = [None] * dist.get_world_size() if self.rank == 0 else None
         dist.gather_object(obj=self.rank, object_gather_list=output)
@@ -102,8 +80,9 @@ class TestObjectCollectives(MultiProcessTestCase):
             for i, v in enumerate(output):
                 self.assertEqual(i, v, f"rank: {self.rank}")
+    @skipIfHpu
     @with_comms()
-    def test_send_recv_object_list(self):
+    def test_send_recv_object_list(self, device):
         val = 99 if self.rank == 0 else None
         object_list = [val] * dist.get_world_size()
         if self.rank == 0:
@@ -117,7 +96,7 @@ class TestObjectCollectives(MultiProcessTestCase):
             self.assertEqual(None, object_list[0])
     @with_comms()
-    def test_broadcast_object_list(self):
+    def test_broadcast_object_list(self, device):
         val = 99 if self.rank == 0 else None
         object_list = [val] * dist.get_world_size()
         # TODO test with broadcast_object_list's device argument
@@ -126,7 +105,7 @@ class TestObjectCollectives(MultiProcessTestCase):
         self.assertEqual(99, object_list[0])
     @with_comms()
-    def test_scatter_object_list(self):
+    def test_scatter_object_list(self, device):
         input_list = list(range(dist.get_world_size())) if self.rank == 0 else None
         output_list = [None]
         dist.scatter_object_list(
@@ -144,30 +123,34 @@ class TestObjectCollectives(MultiProcessTestCase):
         my_pg = dist.new_group(ranks, use_local_synchronization=True)
         return rank, ranks, my_pg
+    @skipIfHpu
     @with_comms()
-    def test_subpg_scatter_object(self):
+    def test_subpg_scatter_object(self, device):
         rank, ranks, my_pg = self.setup_sub_pg()
         out_list = [None]
         dist.scatter_object_list(out_list, ranks, src=ranks[0], group=my_pg)
         self.assertEqual(rank, out_list[0])
+    @skipIfHpu
     @with_comms()
-    def test_subpg_all_gather_object(self):
+    def test_subpg_all_gather_object(self, device):
         rank, ranks, my_pg = self.setup_sub_pg()
         out_list = [None] * len(ranks)
         dist.all_gather_object(out_list, rank, group=my_pg)
         self.assertEqual(ranks, out_list)
+    @skipIfHpu
     @with_comms()
-    def test_subpg_gather_object(self):
+    def test_subpg_gather_object(self, device):
         rank, ranks, my_pg = self.setup_sub_pg()
         out_list = [None] * len(ranks) if rank == ranks[0] else None
         dist.gather_object(rank, out_list, dst=ranks[0], group=my_pg)
         if rank == ranks[0]:
             self.assertEqual(ranks, out_list)
+    @skipIfHpu
     @with_comms()
-    def test_subpg_broadcast_object(self):
+    def test_subpg_broadcast_object(self, device):
         rank, ranks, my_pg = self.setup_sub_pg()
         out_list = [None]
         if rank == ranks[0]:
@@ -176,5 +159,7 @@ class TestObjectCollectives(MultiProcessTestCase):
         self.assertEqual(ranks[0], out_list[0])
+devices = ("cpu", "cuda", "hpu")
+instantiate_device_type_tests(TestObjectCollectives, globals(), only_for=devices)
 if __name__ == "__main__":
     run_tests()