Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Update test_c10d_object_collectives.py with DistributedTestBase class (#145056)
# MOTIVATION
To generalize distributed test cases for non-CUDA devices, we leverage the DistributedTestBase class introduced in [PR #138216](https://github.com/pytorch/pytorch/pull/138216). This class is derived from MultiProcessTestCase and abstracts the creation and deletion of process groups, along with other device-specific functionality. In this PR, we extend the scope of these tests to support HPUs.

# CHANGES
- Replaced MultiProcessTestCase with the DistributedTestBase class.
- Extended test functionality to include support for HPUs.
- Used instantiate_device_type_tests with targeted attributes to generate device-specific test instances.
- Applied the skipIfHpu decorator to skip tests that are not yet compatible with HPU devices.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145056
Approved by: https://github.com/kwen2501, https://github.com/guangyey
committed by PyTorch MergeBot
parent a9598337b7
commit 936df4571b
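For context, here is a minimal sketch of the pattern this PR adopts. It is illustrative only: `MyCollectiveTest` and `test_broadcast_object` are made-up names, while `DistributedTestBase`, `create_pg`, `instantiate_device_type_tests`, and `run_tests` are the helpers actually exercised in the diff below.

```python
# Minimal sketch of a DistributedTestBase-style test (hypothetical class and
# method names; helper behavior assumed to match the diff below).
import torch.distributed as dist
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import DistributedTestBase
from torch.testing._internal.common_utils import run_tests


class MyCollectiveTest(DistributedTestBase):
    def test_broadcast_object(self, device):
        # create_pg (inherited from DistributedTestBase) initializes a process
        # group using the default backend for `device`, e.g. gloo for cpu.
        self.create_pg(device=device)
        obj = [42 if self.rank == 0 else None]
        dist.broadcast_object_list(obj, src=0)
        self.assertEqual(42, obj[0])
        dist.destroy_process_group()


# Expands MyCollectiveTest into one class per entry in only_for.
instantiate_device_type_tests(MyCollectiveTest, globals(), only_for=("cpu", "cuda", "hpu"))

if __name__ == "__main__":
    run_tests()
```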
@@ -1,6 +1,5 @@
 # Owner(s): ["oncall: distributed"]
 
-import os
 import sys
 from functools import partial, wraps
 
@@ -12,8 +11,15 @@ if not dist.is_available():
     print("Distributed not available, skipping tests", file=sys.stderr)
     sys.exit(0)
 
-from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS
-from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
+from torch.testing._internal.common_device_type import instantiate_device_type_tests
+from torch.testing._internal.common_distributed import DistributedTestBase, TEST_SKIPS
+from torch.testing._internal.common_utils import (
+    run_tests,
+    skipIfHpu,
+    TEST_CUDA,
+    TEST_HPU,
+    TEST_WITH_DEV_DBG_ASAN,
+)
 
 
 if TEST_WITH_DEV_DBG_ASAN:
@@ -23,7 +29,16 @@ if TEST_WITH_DEV_DBG_ASAN:
     )
     sys.exit(0)
 
-BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
+if TEST_HPU:
+    DEVICE = "hpu"
+elif TEST_CUDA:
+    DEVICE = "cuda"
+else:
+    DEVICE = "cpu"
+
+device_module = torch.get_device_module(DEVICE)
+device_count = device_module.device_count()
+BACKEND = dist.get_default_backend_for_device(DEVICE)
 
 
 def with_comms(func=None):
@@ -34,59 +49,22 @@ def with_comms(func=None):
 
     @wraps(func)
    def wrapper(self, *args, **kwargs):
-        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
+        if DEVICE != "cpu" and device_count < self.world_size:
             sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
-        self.dist_init()
-        func(self)
-        self.destroy_comms()
+
+        kwargs["device"] = DEVICE
+        self.pg = self.create_pg(device=DEVICE)
+        try:
+            return func(self, *args, **kwargs)
+        finally:
+            torch.distributed.destroy_process_group()
 
     return wrapper
 
 
-class TestObjectCollectives(MultiProcessTestCase):
-    def setUp(self):
-        super().setUp()
-        os.environ["WORLD_SIZE"] = str(self.world_size)
-        os.environ["BACKEND"] = BACKEND
-        self._spawn_processes()
-
-    @property
-    def device(self):
-        return (
-            torch.device("cuda", self.rank % torch.cuda.device_count())
-            if BACKEND == dist.Backend.NCCL
-            else torch.device("cpu")
-        )
-
-    @property
-    def world_size(self):
-        if BACKEND == dist.Backend.NCCL:
-            return torch.cuda.device_count()
-        return super().world_size
-
-    @property
-    def process_group(self):
-        return dist.group.WORLD
-
-    def destroy_comms(self):
-        # Wait for all ranks to reach here before starting shutdown.
-        dist.barrier()
-        dist.destroy_process_group()
-
-    def dist_init(self):
-        dist.init_process_group(
-            backend=BACKEND,
-            world_size=self.world_size,
-            rank=self.rank,
-            init_method=f"file://{self.file_name}",
-        )
-
-        # set device for nccl pg for collectives
-        if BACKEND == "nccl":
-            torch.cuda.set_device(self.rank)
-
+class TestObjectCollectives(DistributedTestBase):
     @with_comms()
-    def test_all_gather_object(self):
+    def test_all_gather_object(self, device):
         output = [None] * dist.get_world_size()
         dist.all_gather_object(object_list=output, obj=self.rank)
 
@@ -94,7 +72,7 @@ class TestObjectCollectives(MultiProcessTestCase):
             self.assertEqual(i, v, f"rank: {self.rank}")
 
     @with_comms()
-    def test_gather_object(self):
+    def test_gather_object(self, device):
         output = [None] * dist.get_world_size() if self.rank == 0 else None
         dist.gather_object(obj=self.rank, object_gather_list=output)
 
@@ -102,8 +80,9 @@ class TestObjectCollectives(MultiProcessTestCase):
         for i, v in enumerate(output):
             self.assertEqual(i, v, f"rank: {self.rank}")
 
+    @skipIfHpu
     @with_comms()
-    def test_send_recv_object_list(self):
+    def test_send_recv_object_list(self, device):
         val = 99 if self.rank == 0 else None
         object_list = [val] * dist.get_world_size()
         if self.rank == 0:
@@ -117,7 +96,7 @@ class TestObjectCollectives(MultiProcessTestCase):
             self.assertEqual(None, object_list[0])
 
     @with_comms()
-    def test_broadcast_object_list(self):
+    def test_broadcast_object_list(self, device):
         val = 99 if self.rank == 0 else None
         object_list = [val] * dist.get_world_size()
         # TODO test with broadcast_object_list's device argument
@@ -126,7 +105,7 @@ class TestObjectCollectives(MultiProcessTestCase):
         self.assertEqual(99, object_list[0])
 
     @with_comms()
-    def test_scatter_object_list(self):
+    def test_scatter_object_list(self, device):
         input_list = list(range(dist.get_world_size())) if self.rank == 0 else None
         output_list = [None]
         dist.scatter_object_list(
@@ -144,30 +123,34 @@ class TestObjectCollectives(MultiProcessTestCase):
         my_pg = dist.new_group(ranks, use_local_synchronization=True)
         return rank, ranks, my_pg
 
+    @skipIfHpu
     @with_comms()
-    def test_subpg_scatter_object(self):
+    def test_subpg_scatter_object(self, device):
         rank, ranks, my_pg = self.setup_sub_pg()
         out_list = [None]
         dist.scatter_object_list(out_list, ranks, src=ranks[0], group=my_pg)
         self.assertEqual(rank, out_list[0])
 
+    @skipIfHpu
     @with_comms()
-    def test_subpg_all_gather_object(self):
+    def test_subpg_all_gather_object(self, device):
         rank, ranks, my_pg = self.setup_sub_pg()
         out_list = [None] * len(ranks)
         dist.all_gather_object(out_list, rank, group=my_pg)
         self.assertEqual(ranks, out_list)
 
+    @skipIfHpu
     @with_comms()
-    def test_subpg_gather_object(self):
+    def test_subpg_gather_object(self, device):
         rank, ranks, my_pg = self.setup_sub_pg()
         out_list = [None] * len(ranks) if rank == ranks[0] else None
         dist.gather_object(rank, out_list, dst=ranks[0], group=my_pg)
         if rank == ranks[0]:
             self.assertEqual(ranks, out_list)
 
+    @skipIfHpu
     @with_comms()
-    def test_subpg_broadcast_object(self):
+    def test_subpg_broadcast_object(self, device):
         rank, ranks, my_pg = self.setup_sub_pg()
         out_list = [None]
         if rank == ranks[0]:
@@ -176,5 +159,7 @@ class TestObjectCollectives(MultiProcessTestCase):
         self.assertEqual(ranks[0], out_list[0])
 
 
+devices = ("cpu", "cuda", "hpu")
+instantiate_device_type_tests(TestObjectCollectives, globals(), only_for=devices)
 if __name__ == "__main__":
     run_tests()
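As a usage note, `instantiate_device_type_tests` conventionally generates one class per device with device-suffixed test names. The names below are inferred from that convention, not taken from this diff:

```python
# Inferred generated test names (per common_device_type's usual convention):
#   TestObjectCollectivesCPU.test_all_gather_object_cpu
#   TestObjectCollectivesCUDA.test_all_gather_object_cuda
#   TestObjectCollectivesHPU.test_all_gather_object_hpu
#
# A single device variant can then be selected with unittest's -k filter, e.g.:
#   python test_c10d_object_collectives.py -k test_all_gather_object_cpu
```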