From f497a0039c0d94a530f141731bcbb65910c9bcfa Mon Sep 17 00:00:00 2001
From: ankurneog
Date: Fri, 22 Nov 2024 11:01:50 +0000
Subject: [PATCH] API to retrieve default distributed backend from device
 (#140536)

# Motivation
The distributed APIs rely on backend names for process group creation. To
abstract these names away from process group creation, this PR adds an API
that returns the default distributed backend for a given device.

Device code must register its device and backend via
`torch.distributed.Backend.register_backend`, or update the map below, prior
to using the API:
```
torch.distributed.Backend.default_device_backend_map["device"] = "distributed_backend"
```
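For example, a minimal sketch of the intended flow (the device name `mydev`
and backend name `mccl` are placeholders for illustration, not names that
exist in PyTorch):
```
import torch.distributed as dist

# Hypothetical out-of-tree device/backend pair, normally registered by the
# device module before any distributed API is called.
dist.Backend.default_device_backend_map["mydev"] = "mccl"

# Callers can now resolve the backend without hard-coding its name.
backend = dist.get_default_backend_for_device("mydev")  # -> "mccl"
```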
An example of its use is included in the new test file, which can also be
used to exercise the abstracted APIs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140536
Approved by: https://github.com/kwen2501
---
 test/distributed/test_backends.py     | 51 +++++++++++++++++++++++++++
 torch/distributed/distributed_c10d.py | 24 +++++++++++++
 2 files changed, 75 insertions(+)
 create mode 100644 test/distributed/test_backends.py

diff --git a/test/distributed/test_backends.py b/test/distributed/test_backends.py
new file mode 100644
index 000000000000..baf78bb62db1
--- /dev/null
+++ b/test/distributed/test_backends.py
@@ -0,0 +1,51 @@
+# Owner(s): ["oncall: distributed"]
+
+import os
+
+import torch.distributed as dist
+from torch.testing._internal.common_device_type import instantiate_device_type_tests
+from torch.testing._internal.common_utils import run_tests, TestCase
+
+
+"""
+Common backend API tests.
+"""
+
+
+class TestMiscCollectiveUtils(TestCase):
+    def test_device_to_backend_mapping(self, device) -> None:
+        """
+        Test the device-to-backend mapping.
+        """
+        if "cuda" in device:
+            assert dist.get_default_backend_for_device(device) == "nccl"
+        elif "cpu" in device:
+            assert dist.get_default_backend_for_device(device) == "gloo"
+        elif "hpu" in device:
+            assert dist.get_default_backend_for_device(device) == "hccl"
+        else:
+            with self.assertRaises(ValueError):
+                dist.get_default_backend_for_device(device)
+
+    def test_create_pg(self, device) -> None:
+        """
+        Test process group creation.
+        """
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "29500"
+
+        backend = dist.get_default_backend_for_device(device)
+        dist.init_process_group(
+            backend=backend, rank=0, world_size=1, init_method="env://"
+        )
+        pg = dist.distributed_c10d._get_default_group()
+        backend_pg = pg._get_backend_name()
+        assert backend_pg == backend
+        dist.destroy_process_group()
+
+
+devices = ["cpu", "cuda", "hpu"]
+instantiate_device_type_tests(TestMiscCollectiveUtils, globals(), only_for=devices)
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 0afe8cbc34d7..aed7af4b4e6b 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -76,6 +76,7 @@ __all__ = [
     "gather_object",
     "get_backend_config",
     "get_backend",
+    "get_default_backend_for_device",
     "get_rank",
     "get_world_size",
     "get_pg_count",
@@ -1345,6 +1346,29 @@ def get_backend(group: Optional[ProcessGroup] = None) -> Backend:
     return Backend(not_none(pg_store)[0])
 
 
+def get_default_backend_for_device(device: Union[str, torch.device]) -> str:
+    """
+    Return the default backend for the given device.
+
+    Args:
+        device (Union[str, torch.device]): The device to get the default backend for.
+
+    Returns:
+        The default backend for the given device, as a lowercase string.
+
+    """
+    if isinstance(device, torch.device):
+        device_str = device.type
+    else:
+        device_str = device.split(":")[0]
+
+    backend = Backend.default_device_backend_map.get(device_str)
+    if backend is None:
+        raise ValueError(f"Default backend not registered for device: {device}")
+
+    return backend
+
+
 def _get_process_group_uid(pg: ProcessGroup) -> int:
     backend = None
     try: