diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index e907565753a3..4f04c6a27699 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -213,6 +213,48 @@ class ProcessGroupNCCLNoGPUTest(TestCase):
             c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
 
 
+class ProcessGroupNCCLInitTest(MultiProcessTestCase):
+    device_type = "cuda"
+
+    def setUp(self):
+        super().setUp()
+        self._spawn_processes()
+
+    def tearDown(self):
+        super().tearDown()
+        try:
+            os.remove(self.file_name)
+        except OSError:
+            pass
+
+    @property
+    def world_size(self):
+        dm = torch.get_device_module(self.device_type)
+        return dm.device_count()
+
+    @property
+    def device(self):
+        return torch.device(self.device_type, self.rank % self.world_size)
+
+    # Helper that supplies the init args the test infra always needs (rank,
+    # world size, and a FileStore); each init test adds the rest via kwargs.
+    def _init_process_group(self, **kwargs):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            rank=self.rank,
+            world_size=self.world_size,
+            store=store,
+            **kwargs,
+        )
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(1)
+    def test_init_wo_backend_str(self):
+        self._init_process_group(device_id=self.device)
+        x = torch.empty(1, device=self.device)
+        c10d.all_reduce(x)
+
+
 class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
     def _create_process_group_nccl(self, store, opts, device_id=None):
         # create nccl processgroup with opts
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 3e5c91984535..15becc083230 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -251,6 +251,7 @@ class Backend(str):
 
     backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI]
 
+    # Device type -> default backend; 3rd-party devices can register theirs here.
     default_device_backend_map: Dict[str, str] = {
         "cpu": GLOO,
         "cuda": NCCL,
@@ -1600,10 +1601,22 @@ def init_process_group(
     elif init_method is None:
         init_method = "env://"
 
-    if backend:
-        backend = Backend(backend)
-    else:
-        backend = Backend("undefined")
+    # If the user gave no backend string but did give a device id, e.g.
+    # >>> init_process_group(device_id=device)
+    # we infer the backend name from the device type.
+    if backend is None and device_id is not None:
+        # Note: 3rd-party devices can register a default backend by adding an
+        # entry to `Backend.default_device_backend_map` (None if unregistered).
+        backend = Backend.default_device_backend_map.get(device_id.type)
+
+    # If the backend is still unknown (no device id, or an unmapped type), e.g.
+    # >>> init_process_group()
+    # we set it to `undefined` and rely on lazy init.
+    if backend is None:
+        backend = "undefined"
+
+    # Convert the string into the `Backend` type
+    backend = Backend(backend)
 
     if timeout is None:
         timeout = _get_default_timeout(backend)