mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
# Motivation The distributed APIs rely on backend names to create process groups. To abstract references to these names out of process group (PG) creation, an API is added that returns the default distributed backend for a device. Device code needs to register its device and backend via ```torch.distributed.Backend.register_backend```, or update the map ```torch.distributed.Backend.default_device_backend_map["device"] = "distributed_backend"```, before using the API. An example of its use is added in the test file (which can also be used to check the abstracted APIs). Pull Request resolved: https://github.com/pytorch/pytorch/pull/140536 Approved by: https://github.com/kwen2501
52 lines
1.5 KiB
Python
52 lines
1.5 KiB
Python
# Owner(s): ["oncall: distributed"]
|
|
|
|
import os
|
|
|
|
import torch.distributed as dist
|
|
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
|
|
|
|
|
"""
|
|
common backend API tests
|
|
"""
|
|
|
|
|
|
class TestMiscCollectiveUtils(TestCase):
|
|
def test_device_to_backend_mapping(self, device) -> None:
|
|
"""
|
|
Test device to backend mapping
|
|
"""
|
|
if "cuda" in device:
|
|
assert dist.get_default_backend_for_device(device) == "nccl"
|
|
elif "cpu" in device:
|
|
assert dist.get_default_backend_for_device(device) == "gloo"
|
|
elif "hpu" in device:
|
|
assert dist.get_default_backend_for_device(device) == "hccl"
|
|
else:
|
|
with self.assertRaises(ValueError):
|
|
dist.get_default_backend_for_device(device)
|
|
|
|
def test_create_pg(self, device) -> None:
|
|
"""
|
|
Test create process group
|
|
"""
|
|
os.environ["MASTER_ADDR"] = "localhost"
|
|
os.environ["MASTER_PORT"] = "29500"
|
|
|
|
backend = dist.get_default_backend_for_device(device)
|
|
dist.init_process_group(
|
|
backend=backend, rank=0, world_size=1, init_method="env://"
|
|
)
|
|
pg = dist.distributed_c10d._get_default_group()
|
|
backend_pg = pg._get_backend_name()
|
|
assert backend_pg == backend
|
|
dist.destroy_process_group()
|
|
|
|
|
|
# Device types handed to the device-type test instantiation as ``only_for``.
devices = ["cpu", "cuda", "hpu"]
instantiate_device_type_tests(
    TestMiscCollectiveUtils,
    globals(),
    only_for=devices,
)


# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
|