mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
API to retrieve default distributed backend from device (#140536)
# Motivation The distributed APIs rely on backend names for the creation of process groups. To abstract these backend names away from process-group creation, an API is added that returns the default distributed backend for a given device. Device code must register its device and backend via ```torch.distributed.Backend.register_backend``` or update the map ``` torch.distributed.Backend.default_device_backend_map["device"] = "distributed_backend" ``` before using the API. An example of its use is included in the test file (which can also be used to check the abstracted APIs). Pull Request resolved: https://github.com/pytorch/pytorch/pull/140536 Approved by: https://github.com/kwen2501
This commit is contained in:
committed by
PyTorch MergeBot
parent
7d89a8d385
commit
f497a0039c
51
test/distributed/test_backends.py
Normal file
51
test/distributed/test_backends.py
Normal file
@ -0,0 +1,51 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import os
|
||||
|
||||
import torch.distributed as dist
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
"""
|
||||
common backend API tests
|
||||
"""
|
||||
|
||||
|
||||
class TestMiscCollectiveUtils(TestCase):
|
||||
def test_device_to_backend_mapping(self, device) -> None:
|
||||
"""
|
||||
Test device to backend mapping
|
||||
"""
|
||||
if "cuda" in device:
|
||||
assert dist.get_default_backend_for_device(device) == "nccl"
|
||||
elif "cpu" in device:
|
||||
assert dist.get_default_backend_for_device(device) == "gloo"
|
||||
elif "hpu" in device:
|
||||
assert dist.get_default_backend_for_device(device) == "hccl"
|
||||
else:
|
||||
with self.assertRaises(ValueError):
|
||||
dist.get_default_backend_for_device(device)
|
||||
|
||||
def test_create_pg(self, device) -> None:
|
||||
"""
|
||||
Test create process group
|
||||
"""
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = "29500"
|
||||
|
||||
backend = dist.get_default_backend_for_device(device)
|
||||
dist.init_process_group(
|
||||
backend=backend, rank=0, world_size=1, init_method="env://"
|
||||
)
|
||||
pg = dist.distributed_c10d._get_default_group()
|
||||
backend_pg = pg._get_backend_name()
|
||||
assert backend_pg == backend
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
# Device types the tests above should be instantiated for; only_for below
# restricts generation to exactly these.
devices = ["cpu", "cuda", "hpu"]

# Generates device-specific variants of TestMiscCollectiveUtils (one class
# per device type) into this module's globals() so the test runner picks
# them up.
instantiate_device_type_tests(TestMiscCollectiveUtils, globals(), only_for=devices)


if __name__ == "__main__":
    run_tests()
|
@ -76,6 +76,7 @@ __all__ = [
|
||||
"gather_object",
|
||||
"get_backend_config",
|
||||
"get_backend",
|
||||
"get_default_backend_for_device",
|
||||
"get_rank",
|
||||
"get_world_size",
|
||||
"get_pg_count",
|
||||
@ -1345,6 +1346,29 @@ def get_backend(group: Optional[ProcessGroup] = None) -> Backend:
|
||||
return Backend(not_none(pg_store)[0])
|
||||
|
||||
|
||||
def get_default_backend_for_device(device: Union[str, torch.device]) -> str:
|
||||
"""
|
||||
Return the default backend for the given device.
|
||||
|
||||
Args:
|
||||
Union[str, torch.device]: The device to get the default backend for.
|
||||
|
||||
Returns:
|
||||
The default backend for the given device as a lower case string.
|
||||
|
||||
"""
|
||||
if isinstance(device, torch.device):
|
||||
device_str = device.type
|
||||
else:
|
||||
device_str = device.split(":")[0]
|
||||
|
||||
backend = Backend.default_device_backend_map.get(device_str)
|
||||
if backend is None:
|
||||
raise ValueError(f"Default backend not registered for device : {device}")
|
||||
|
||||
return backend
|
||||
|
||||
|
||||
def _get_process_group_uid(pg: ProcessGroup) -> int:
|
||||
backend = None
|
||||
try:
|
||||
|
Reference in New Issue
Block a user