API to retrieve default distributed backend from device (#140536)

# Motivation
The distributed APIs rely on backend names for process group creation.
To abstract references to these names out of PG creation, an API is added that returns the default distributed backend for a device.
Device code would need to register its device and backend via `torch.distributed.Backend.register_backend`, or update the map directly with `torch.distributed.Backend.default_device_backend_map["device"] = "distributed_backend"`, prior to using the API.

An example of its use is added in the test file (which can also be used to exercise the abstracted API); a minimal sketch follows.
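As a quick illustration outside the diff, a hypothetical out-of-tree device could seed the map and then query it. The `my_device`/`my_ccl` names below are placeholders, not real backends:

```python
import torch.distributed as dist

# Placeholder names for a hypothetical out-of-tree device and its
# collectives backend; real integrations would use their actual names.
dist.Backend.default_device_backend_map["my_device"] = "my_ccl"

# The lookup keys on the device type, so an index suffix like ":0" is ignored.
assert dist.get_default_backend_for_device("my_device:0") == "my_ccl"
assert dist.get_default_backend_for_device("cpu") == "gloo"  # built-in default
```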

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140536
Approved by: https://github.com/kwen2501
ankurneog
2024-11-22 11:01:50 +00:00
committed by PyTorch MergeBot
parent 7d89a8d385
commit f497a0039c
2 changed files with 75 additions and 0 deletions


@@ -0,0 +1,51 @@
# Owner(s): ["oncall: distributed"]

import os

import torch.distributed as dist
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase


"""
common backend API tests
"""


class TestMiscCollectiveUtils(TestCase):
    def test_device_to_backend_mapping(self, device) -> None:
        """
        Test device to backend mapping
        """
        if "cuda" in device:
            assert dist.get_default_backend_for_device(device) == "nccl"
        elif "cpu" in device:
            assert dist.get_default_backend_for_device(device) == "gloo"
        elif "hpu" in device:
            assert dist.get_default_backend_for_device(device) == "hccl"
        else:
            with self.assertRaises(ValueError):
                dist.get_default_backend_for_device(device)

    def test_create_pg(self, device) -> None:
        """
        Test create process group
        """
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29500"

        backend = dist.get_default_backend_for_device(device)
        dist.init_process_group(
            backend=backend, rank=0, world_size=1, init_method="env://"
        )
        pg = dist.distributed_c10d._get_default_group()
        backend_pg = pg._get_backend_name()
        assert backend_pg == backend
        dist.destroy_process_group()


devices = ["cpu", "cuda", "hpu"]
instantiate_device_type_tests(TestMiscCollectiveUtils, globals(), only_for=devices)

if __name__ == "__main__":
    run_tests()


@@ -76,6 +76,7 @@ __all__ = [
    "gather_object",
    "get_backend_config",
    "get_backend",
    "get_default_backend_for_device",
    "get_rank",
    "get_world_size",
    "get_pg_count",
@@ -1345,6 +1346,29 @@ def get_backend(group: Optional[ProcessGroup] = None) -> Backend:
    return Backend(not_none(pg_store)[0])


def get_default_backend_for_device(device: Union[str, torch.device]) -> str:
    """
    Return the default backend for the given device.

    Args:
        device (Union[str, torch.device]): The device to get the default backend for.

    Returns:
        The default backend for the given device as a lower case string.
    """
    if isinstance(device, torch.device):
        device_str = device.type
    else:
        # The device string may carry an index suffix, e.g. "cuda:0".
        device_str = device.split(":")[0]

    backend = Backend.default_device_backend_map.get(device_str)
    if backend is None:
        raise ValueError(f"Default backend not registered for device: {device}")

    return backend


def _get_process_group_uid(pg: ProcessGroup) -> int:
    backend = None
    try:
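For context beyond the diff above, a minimal single-process sketch of the intended call pattern, mirroring the new test; the localhost rendezvous values are illustrative:

```python
import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")

# Pick the backend from the device instead of hardcoding "nccl"/"gloo".
device = "cuda" if torch.cuda.is_available() else "cpu"
backend = dist.get_default_backend_for_device(device)

dist.init_process_group(backend=backend, rank=0, world_size=1, init_method="env://")
assert dist.get_backend() == backend
dist.destroy_process_group()
```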