mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[c10d] Support optional backend if device_id provided (#140963)
Citing @malfet's [comment](https://github.com/pytorch/pytorch/pull/136343#pullrequestreview-2318792396) in https://github.com/pytorch/pytorch/pull/136343: > It would be great, if users do not have to modify their programs for every new backend, but rather use `with torch.device('xpu'):` and keep the rest of the code unchanged. This PR makes the backend specification ("nccl", "gloo") optional when the user provides a `device_id` to `init_process_group` (the acceptance of `device_id` has been previously supported for the purpose of eager init). New user experience: ``` device = torch.device(device_type, rank % device_count) dist.init_process_group(device_id=device) ``` The line of `device = torch.device(...)` is needed anyway because the user would use it for tensor creation etc. Pull Request resolved: https://github.com/pytorch/pytorch/pull/140963 Approved by: https://github.com/wconstab
This commit is contained in:
@ -213,6 +213,48 @@ class ProcessGroupNCCLNoGPUTest(TestCase):
|
||||
c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
|
||||
|
||||
|
||||
class ProcessGroupNCCLInitTest(MultiProcessTestCase):
    """Tests for ``init_process_group`` when the backend string is omitted.

    Verifies that passing only ``device_id`` lets c10d infer the backend
    (e.g. "nccl" for CUDA devices) from the default device-backend map.
    """

    # Device type used to build per-rank devices; NOTE(review): tests below
    # assume a CUDA build with NCCL available (enforced by @requires_nccl).
    device_type = "cuda"

    def setUp(self):
        super().setUp()
        # MultiProcessTestCase infra: fork/spawn one process per rank.
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        # Best-effort cleanup of the FileStore backing file; it may not
        # exist if the test failed before the store was created.
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def world_size(self):
        # One rank per visible device of `device_type`.
        dm = torch.get_device_module(self.device_type)
        return dm.device_count()

    @property
    def device(self):
        # Round-robin map of rank -> local device.
        return torch.device(self.device_type, self.rank % self.world_size)

    # A helper with the must-needed init args for test infra.
    # kwargs can be filled in by individual init tests.
    def _init_process_group(self, **kwargs):
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            rank=self.rank,
            world_size=self.world_size,
            store=store,
            **kwargs,
        )

    @requires_nccl()
    @skip_if_lt_x_gpu(1)
    def test_init_wo_backend_str(self):
        # No backend string: c10d should infer "nccl" from the CUDA device.
        self._init_process_group(device_id=self.device)
        # A collective succeeding proves the inferred backend actually works.
        x = torch.empty(1, device=self.device)
        c10d.all_reduce(x)
|
||||
|
||||
|
||||
class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
||||
def _create_process_group_nccl(self, store, opts, device_id=None):
|
||||
# create nccl processgroup with opts
|
||||
|
@ -251,6 +251,7 @@ class Backend(str):
|
||||
|
||||
backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI]
|
||||
|
||||
# 3rd-party devices can register the default backend support here
|
||||
default_device_backend_map: Dict[str, str] = {
|
||||
"cpu": GLOO,
|
||||
"cuda": NCCL,
|
||||
@ -1600,10 +1601,22 @@ def init_process_group(
|
||||
elif init_method is None:
|
||||
init_method = "env://"
|
||||
|
||||
if backend:
|
||||
# If user did not provide a backend string but provided a device id, e.g.
|
||||
# >>> init_process_group(device_id=device)
|
||||
# we try to figure out the backend name based on the device type.
|
||||
if backend is None and device_id is not None:
|
||||
# Note: 3rd-party devices can register default backend through the
|
||||
# default map below.
|
||||
backend = Backend.default_device_backend_map.get(device_id.type)
|
||||
|
||||
# If we still cannot figure it out, e.g.
|
||||
# >>> init_process_group()
|
||||
# we set it to `undefined` and rely on lazy init.
|
||||
if backend is None:
|
||||
backend = "undefined"
|
||||
|
||||
# Convert string into `Backend` type
|
||||
backend = Backend(backend)
|
||||
else:
|
||||
backend = Backend("undefined")
|
||||
|
||||
if timeout is None:
|
||||
timeout = _get_default_timeout(backend)
|
||||
|
Reference in New Issue
Block a user