Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00).

MOTIVATION: To generalize distributed test cases for non-CUDA devices, this file leverages the DistributedTestBase class introduced in PR #138216 (https://github.com/pytorch/pytorch/pull/138216). That class is derived from MultiProcessTestCase and abstracts the creation/deletion of process groups and other device-specific functionality. This change extends the scope of these tests to support HPUs.

CHANGES: Replaced MultiProcessTestCase with the DistributedTestBase class. Extended test functionality to include support for HPUs. Utilized instantiate_device_type_tests with targeted attributes to generate device-specific test instances. Applied the skipIfHpu decorator to skip tests that are not yet compatible with HPU devices.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145056
Approved by: https://github.com/kwen2501, https://github.com/guangyey

166 lines · 4.9 KiB · Python
# Owner(s): ["oncall: distributed"]
import sys

from functools import partial, wraps

import torch
import torch.distributed as dist

if not dist.is_available():
|
|
print("Distributed not available, skipping tests", file=sys.stderr)
|
|
sys.exit(0)
|
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import DistributedTestBase, TEST_SKIPS
from torch.testing._internal.common_utils import (
    run_tests,
    skipIfHpu,
    TEST_CUDA,
    TEST_HPU,
    TEST_WITH_DEV_DBG_ASAN,
)

if TEST_WITH_DEV_DBG_ASAN:
|
|
print(
|
|
"Skip dev-asan as torch + multiprocessing spawn have known issues",
|
|
file=sys.stderr,
|
|
)
|
|
sys.exit(0)
|
|
|
|
if TEST_HPU:
|
|
DEVICE = "hpu"
|
|
elif TEST_CUDA:
|
|
DEVICE = "cuda"
|
|
else:
|
|
DEVICE = "cpu"
|
|
|
|
device_module = torch.get_device_module(DEVICE)
|
|
device_count = device_module.device_count()
|
|
BACKEND = dist.get_default_backend_for_device(DEVICE)
|
|
|
|
|
|
def with_comms(func=None):
|
|
if func is None:
|
|
return partial(
|
|
with_comms,
|
|
)
|
|
|
|
@wraps(func)
|
|
def wrapper(self, *args, **kwargs):
|
|
if DEVICE != "cpu" and device_count < self.world_size:
|
|
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
|
|
|
|
kwargs["device"] = DEVICE
|
|
self.pg = self.create_pg(device=DEVICE)
|
|
try:
|
|
return func(self, *args, **kwargs)
|
|
finally:
|
|
torch.distributed.destroy_process_group()
|
|
|
|
return wrapper
|
|
|
|
|
|
class TestObjectCollectives(DistributedTestBase):
|
|
@with_comms()
|
|
def test_all_gather_object(self, device):
|
|
output = [None] * dist.get_world_size()
|
|
dist.all_gather_object(object_list=output, obj=self.rank)
|
|
|
|
for i, v in enumerate(output):
|
|
self.assertEqual(i, v, f"rank: {self.rank}")
|
|
|
|
@with_comms()
|
|
def test_gather_object(self, device):
|
|
output = [None] * dist.get_world_size() if self.rank == 0 else None
|
|
dist.gather_object(obj=self.rank, object_gather_list=output)
|
|
|
|
if self.rank == 0:
|
|
for i, v in enumerate(output):
|
|
self.assertEqual(i, v, f"rank: {self.rank}")
|
|
|
|
@skipIfHpu
|
|
@with_comms()
|
|
def test_send_recv_object_list(self, device):
|
|
val = 99 if self.rank == 0 else None
|
|
object_list = [val] * dist.get_world_size()
|
|
if self.rank == 0:
|
|
dist.send_object_list(object_list, 1)
|
|
if self.rank == 1:
|
|
dist.recv_object_list(object_list, 0)
|
|
|
|
if self.rank < 2:
|
|
self.assertEqual(99, object_list[0])
|
|
else:
|
|
self.assertEqual(None, object_list[0])
|
|
|
|
@with_comms()
|
|
def test_broadcast_object_list(self, device):
|
|
val = 99 if self.rank == 0 else None
|
|
object_list = [val] * dist.get_world_size()
|
|
# TODO test with broadcast_object_list's device argument
|
|
dist.broadcast_object_list(object_list=object_list)
|
|
|
|
self.assertEqual(99, object_list[0])
|
|
|
|
@with_comms()
|
|
def test_scatter_object_list(self, device):
|
|
input_list = list(range(dist.get_world_size())) if self.rank == 0 else None
|
|
output_list = [None]
|
|
dist.scatter_object_list(
|
|
scatter_object_output_list=output_list, scatter_object_input_list=input_list
|
|
)
|
|
|
|
self.assertEqual(self.rank, output_list[0])
|
|
|
|
# Test Object Collectives With Sub Pg
|
|
|
|
def setup_sub_pg(self):
|
|
rank = dist.get_rank()
|
|
base_rank = rank - (rank % 2)
|
|
ranks = [base_rank, base_rank + 1]
|
|
my_pg = dist.new_group(ranks, use_local_synchronization=True)
|
|
return rank, ranks, my_pg
|
|
|
|
@skipIfHpu
|
|
@with_comms()
|
|
def test_subpg_scatter_object(self, device):
|
|
rank, ranks, my_pg = self.setup_sub_pg()
|
|
out_list = [None]
|
|
dist.scatter_object_list(out_list, ranks, src=ranks[0], group=my_pg)
|
|
self.assertEqual(rank, out_list[0])
|
|
|
|
@skipIfHpu
|
|
@with_comms()
|
|
def test_subpg_all_gather_object(self, device):
|
|
rank, ranks, my_pg = self.setup_sub_pg()
|
|
out_list = [None] * len(ranks)
|
|
dist.all_gather_object(out_list, rank, group=my_pg)
|
|
self.assertEqual(ranks, out_list)
|
|
|
|
@skipIfHpu
|
|
@with_comms()
|
|
def test_subpg_gather_object(self, device):
|
|
rank, ranks, my_pg = self.setup_sub_pg()
|
|
out_list = [None] * len(ranks) if rank == ranks[0] else None
|
|
dist.gather_object(rank, out_list, dst=ranks[0], group=my_pg)
|
|
if rank == ranks[0]:
|
|
self.assertEqual(ranks, out_list)
|
|
|
|
@skipIfHpu
|
|
@with_comms()
|
|
def test_subpg_broadcast_object(self, device):
|
|
rank, ranks, my_pg = self.setup_sub_pg()
|
|
out_list = [None]
|
|
if rank == ranks[0]:
|
|
out_list[0] = rank
|
|
dist.broadcast_object_list(out_list, src=ranks[0], group=my_pg)
|
|
self.assertEqual(ranks[0], out_list[0])
|
|
|
|
|
|
devices = ("cpu", "cuda", "hpu")
|
|
instantiate_device_type_tests(TestObjectCollectives, globals(), only_for=devices)
|
|
if __name__ == "__main__":
|
|
run_tests()
|