diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 5337f47ac045..09c376b7f258 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -3516,17 +3516,6 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
 
             c10d.barrier(device_ids=[self.rank])
 
-    @requires_nccl()
-    @skip_if_lt_x_gpu(2)
-    def test_nccl_barrier_device_ids_function_argument(self):
-        store = c10d.FileStore(self.file_name, self.world_size)
-        c10d.init_process_group(
-            backend="nccl", rank=self.rank, world_size=self.world_size, store=store
-        )
-
-        with self.assertRaisesRegex(TypeError, "Invalid function argument"):
-            c10d.barrier(device_ids=self.rank)
-
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
     def test_unwaited(self) -> None:
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 5db84f50b5a0..79c3241d8e16 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -4730,7 +4730,7 @@ def barrier(
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         async_op (bool, optional): Whether this op should be an async op
-        device_ids ([int], optional): List of device/GPU ids.
+        device_ids ([int], optional): List of device/GPU ids. Only one id is expected.
 
     Returns:
         Async work handle, if async_op is set to True.
@@ -4738,22 +4738,35 @@ def barrier(
 
     .. note:: `ProcessGroupNCCL` now blocks the cpu thread till the completion of the barrier collective.
     """
+    group = group or _get_default_group()
+
     if _rank_not_in_group(group):
         _warn_not_in_group("barrier")
         return
 
     opts = BarrierOptions()
-    opts.device = torch.device(_get_object_coll_device(group))
     opts.asyncOp = async_op
-    if device_ids is not None:
-        if isinstance(device_ids, list):
-            opts.device_ids = device_ids
-        else:
-            raise TypeError(
-                "Invalid function argument: device_ids type should be List[int]"
-            )
+    # Detect the accelerator on the machine. If no accelerator is available, it
+    # returns CPU.
+    device = torch._C._get_accelerator()
+    if isinstance(device_ids, list):
+        opts.device_ids = device_ids
+        # Use only the first device id.
+        opts.device = torch.device(device.type, device_ids[0])
+    elif getattr(group, "bound_device_id", None) is not None:
+        # Use the device id from `init_process_group(device_id=...)`.
+        opts.device = group.bound_device_id  # type: ignore[assignment]
+    elif device.type == "cpu" or _get_object_coll_device(group) == "cpu":
+        opts.device = torch.device("cpu")
+    else:
+        # Use the current device set by the user. If the user did not set one,
+        # this may fall back to device 0, causing issues like hangs or all
+        # processes creating a context on device 0.
+        opts.device = device
+        warnings.warn(  # warn only once
+            "No device id is provided via `init_process_group` or `barrier`. Using the current device set by the user."
+        )
 
-    group = group or _get_default_group()
     work = group.barrier(opts=opts)
 
     if async_op:
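
For context, a minimal usage sketch of the new device resolution order (not part of the diff): `barrier()` now takes its device from `device_ids` if given, else from the group's bound device set via `init_process_group(device_id=...)`, else from the detected accelerator / current device (emitting the warning in that last case). The two-process NCCL harness below is illustrative only; the `run` helper, env vars, and port are assumptions, not part of the patch.

```python
import os

import torch
import torch.distributed as dist


def run(rank: int, world_size: int) -> None:
    # Hypothetical single-host rendezvous settings for the sketch.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    # Binding a device id here lets barrier() resolve the device via
    # group.bound_device_id, so no fallback warning is emitted and no
    # context gets created on device 0 by other ranks.
    dist.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size,
        device_id=torch.device("cuda", rank),
    )
    dist.barrier()  # device resolved from the bound device id
    dist.barrier(device_ids=[rank])  # explicit; only the first id is used
    dist.destroy_process_group()


if __name__ == "__main__":
    # Illustrative 2-GPU launch; assumes at least 2 CUDA devices.
    torch.multiprocessing.spawn(run, args=(2,), nprocs=2)
```

One behavioral consequence visible in the diff: a non-list `device_ids` (e.g. `device_ids=self.rank`) no longer raises `TypeError: Invalid function argument`; the `isinstance(device_ids, list)` check simply falls through to the remaining branches, which is why `test_nccl_barrier_device_ids_function_argument` is removed.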