Files
pytorch/test/distributed/test_backends.py
ankurneog f497a0039c API to retrieve default distributed backend from device (#140536)
# Motivation
The distributed APIs rely on backend names to create process groups.
To abstract references to these names out of PG creation, an API is added to get the default distributed backend for a device.
The device code needs to register its device and backend via ```torch.distributed.Backend.register_backend``` or update the map ```torch.distributed.Backend.default_device_backend_map["device"] = "distributed_backend"``` prior to using the API.

An example of its use is added in the test file (which can be used to check the abstracted APIs).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140536
Approved by: https://github.com/kwen2501
2024-11-22 11:01:53 +00:00

52 lines
1.5 KiB
Python

# Owner(s): ["oncall: distributed"]
import os
import torch.distributed as dist
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase
"""
Common tests for device-agnostic distributed backend APIs.
"""
class TestMiscCollectiveUtils(TestCase):
    """Device-generic tests for the default distributed backend lookup.

    Every test method takes a ``device`` string argument in addition to
    ``self``; the tests query ``dist.get_default_backend_for_device`` for
    that device.
    """

    def test_device_to_backend_mapping(self, device) -> None:
        """Check the default backend reported for ``device``.

        ``device`` is matched by substring, in the order cuda, cpu, hpu.
        A device matching none of them is expected to make the lookup
        raise ``ValueError``.
        """
        # Insertion order is the match precedence (cuda, then cpu, then hpu).
        expected_backends = {"cuda": "nccl", "cpu": "gloo", "hpu": "hccl"}
        for device_type, expected in expected_backends.items():
            if device_type in device:
                # assertEqual instead of a bare assert: it is not stripped
                # under ``python -O`` and reports both values on failure.
                self.assertEqual(
                    dist.get_default_backend_for_device(device), expected
                )
                return
        with self.assertRaises(ValueError):
            dist.get_default_backend_for_device(device)

    def test_create_pg(self, device) -> None:
        """Create a single-rank process group with the device's default backend.

        Initialises the default process group (rank 0, world size 1) using
        the ``env://`` init method, checks that the group reports the same
        backend name that was requested, and always destroys the group
        afterwards.
        """
        # Rendezvous settings for the env:// init method.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29500"
        backend = dist.get_default_backend_for_device(device)
        dist.init_process_group(
            backend=backend, rank=0, world_size=1, init_method="env://"
        )
        try:
            default_pg = dist.distributed_c10d._get_default_group()
            self.assertEqual(default_pg._get_backend_name(), backend)
        finally:
            # Tear down even when the check fails, so a leftover default
            # group cannot break later init_process_group calls.
            dist.destroy_process_group()
# Device types for which the generic test class is instantiated.
devices = ["cpu", "cuda", "hpu"]
instantiate_device_type_tests(
    TestMiscCollectiveUtils,
    globals(),
    only_for=devices,
)


if __name__ == "__main__":
    run_tests()