Enables barrier to support the specified device (#99589)

Enables barrier to support a specified device, e.g. a CUDA or custom device. There is some discussion here: https://github.com/pytorch/pytorch/issues/97938#issue-1646833919

Today, there are two limitations of barrier:
The first is that barrier does not support custom devices:
fbdb86c174/torch/csrc/distributed/c10d/ProcessGroup.hpp (L512-L522)

The second is that there is a special validation for NCCL when device_ids is not None, which bakes in an assumption that device ids are bound to CUDA and NCCL, and also hinders custom devices:
789070986c/torch/distributed/distributed_c10d.py (L3504-L3508)
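For illustration, a minimal sketch of the call pattern this change targets; the single-rank gloo group, address, and port are placeholders so the snippet is self-contained, and the custom-device case would additionally require a registered third-party backend (not shown):

    # Minimal sketch: after this change, barrier() derives its tensor device
    # from the process group instead of assuming a CUDA/NCCL setup.
    import torch.distributed as dist

    dist.init_process_group(
        backend="gloo",              # CPU backend; no CUDA assumption needed
        init_method="tcp://127.0.0.1:29500",
        rank=0,
        world_size=1,
    )
    dist.barrier()                   # no NCCL-specific device handling required
    dist.destroy_process_group()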

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99589
Approved by: https://github.com/kwen2501
Author: shaoyf42
Date: 2023-05-17 05:25:59 +00:00
Committed by: PyTorch MergeBot
Parent: 6261aa5c8d
Commit: 97180aca5e
6 changed files with 52 additions and 41 deletions


@@ -2355,16 +2355,6 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
return skip_but_pass_in_sandcastle("Test requires world_size of at least 4")
self._test_sequence_num_incremented_subgroup("gloo")
@requires_gloo()
def test_gloo_barrier_device_ids(self):
store = c10d.FileStore(self.file_name, self.world_size)
c10d.init_process_group(
backend="gloo", rank=self.rank, world_size=self.world_size, store=store
)
with self.assertRaisesRegex(RuntimeError, "device_ids not supported"):
c10d.barrier(device_ids=[self.rank])
@skip_if_lt_x_gpu(2)
@requires_gloo()
def test_gloo_warn_not_in_group(self):


@@ -528,7 +528,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const BarrierOptions& opts = BarrierOptions()) {
static at::Tensor tensor;
// TODO: if nccl was specified then use it
if (backendType_ == c10d::ProcessGroup::BackendType::NCCL) {
auto device = opts.device;
if (device.has_value()) {
// set device tensor from argument
tensor = at::empty(
{1},
at::TensorOptions().device(device.value()).dtype(at::kByte));
} else if (backendType_ == c10d::ProcessGroup::BackendType::NCCL) {
// set cuda tensor
tensor = at::empty(
{1},


@@ -156,6 +156,7 @@ struct AllToAllOptions {
struct BarrierOptions {
std::vector<int64_t> device_ids;
std::chrono::milliseconds timeout = kUnsetTimeout;
c10::optional<at::Device> device;
};
struct DistributedBackendOptions {


@@ -832,7 +832,8 @@ This class does not support ``__members__`` property.)");
py::class_<::c10d::BarrierOptions>(module, "BarrierOptions")
.def(py::init<>())
.def_readwrite("device_ids", &::c10d::BarrierOptions::device_ids)
.def_readwrite("timeout", &::c10d::BarrierOptions::timeout);
.def_readwrite("timeout", &::c10d::BarrierOptions::timeout)
.def_readwrite("device", &::c10d::BarrierOptions::device);
py::class_<::c10d::AllToAllOptions>(module, "AllToAllOptions")
.def(py::init<>())
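With the binding above, the new field is also visible from Python. A rough sketch follows; the import path and the CUDA device here are assumptions for illustration, not part of this diff:

    # Sketch: BarrierOptions now carries an optional device alongside the
    # existing device_ids and timeout fields.
    import torch
    from torch._C._distributed_c10d import BarrierOptions

    opts = BarrierOptions()
    opts.device = torch.device("cuda:0")  # new optional field added by this PR
    opts.device_ids = [0]                 # pre-existing field
    print(opts.device, opts.device_ids)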


@@ -453,7 +453,7 @@ class _World:
def __init__(self):
self._default_pg = None
self._pg_coalesce_state: Dict[ProcessGroup, List[Union[_CollOp, P2POp]]] = {}
self._pg_object_coll_device: Dict[ProcessGroup, torch.device] = {}
self._pg_default_device: Dict[ProcessGroup, torch.device] = {}
@property
def default_pg(self):
@@ -540,8 +540,8 @@ class _World:
return self._pg_coalesce_state
@property
def pg_object_coll_device(self) -> Dict[ProcessGroup, torch.device]:
return self._pg_object_coll_device
def pg_default_device(self) -> Dict[ProcessGroup, torch.device]:
return self._pg_default_device
_world = _World()
"""Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it"""
@@ -573,11 +573,28 @@ _default_pg_init_method = None
STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key"
def _get_object_coll_device(group: Optional[ProcessGroup] = None):
def _get_pg_default_device(group: Optional[ProcessGroup] = None):
"""
Returns the device to use with ``group`` for control flow usage (object collectives, barrier).
There are selection rules:
1. If user specifies exactly one backend in ``init_process_group`` call:
use that backend
2. Else if user specifies multiple "device:backend" pairs in init_process_group:
If "cpu" is among those pairs, use "cpu" (because the object is in cpu memory);
Otherwise, use the first backend (sort of a random pick).
Args:
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
Returns:
torch.device: The device to use with ``group``.
"""
group = group or _get_default_group()
if group in _world.pg_object_coll_device:
if group in _world.pg_default_device:
# Previously searched and cached; just return
return _world.pg_object_coll_device[group]
return _world.pg_default_device[group]
if not isinstance(group, ProcessGroup):
# Provide backward compatibility to cases where `group` passed in is
@@ -589,8 +606,8 @@ def _get_object_coll_device(group: Optional[ProcessGroup] = None):
"of PyTorch Distributed instead."
)
# Most users create Gloo with private API for object collectives
_world.pg_object_coll_device[group] = torch.device("cpu")
return _world.pg_object_coll_device[group]
_world.pg_default_device[group] = torch.device("cpu")
return _world.pg_default_device[group]
"""
``group._device_types`` is a property pybind that returns the devices
@@ -601,26 +618,26 @@ def _get_object_coll_device(group: Optional[ProcessGroup] = None):
if len(devices) == 1:
# User fixed exactly one backend in `init_process_group`
_world.pg_object_coll_device[group] = devices[0]
_world.pg_default_device[group] = devices[0]
elif len(devices) == 0:
# No backend has been registered with this PG (maybe because no
# collective has been run?) We pick cpu as the default and hopefully
# this would lazily init Gloo or other available cpu backend.
_world.pg_object_coll_device[group] = torch.device("cpu")
_world.pg_default_device[group] = torch.device("cpu")
elif torch.device("cpu") in devices:
# There are multiple backends in this PG and cpu is among them.
# cpu is preferred as the object is in cpu memory. No need for device
# copy.
_world.pg_object_coll_device[group] = torch.device("cpu")
_world.pg_default_device[group] = torch.device("cpu")
else:
# No cpu in the backend list. Randomly pick the first backend
_world.pg_object_coll_device[group] = devices[0]
_world.pg_default_device[group] = devices[0]
logger.info(
f"Using device {_world.pg_object_coll_device[group]} for object " # noqa: G004
f"Using device {_world.pg_default_device[group]} for object " # noqa: G004
"collectives."
)
return _world.pg_object_coll_device[group]
return _world.pg_default_device[group]
# Environment variable to control whether we do a barrier after process group
@@ -1363,7 +1380,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
_world.pg_to_tag.clear()
_world.tags_to_pg.clear()
_world.pg_coalesce_state.clear()
_world.pg_object_coll_device.clear()
_world.pg_default_device.clear()
# when process group doesn't have an explicit name (only WORLD (default)
# process group can have an explicit name), we use global _world.group_count
@@ -1379,8 +1396,8 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
del _world.pg_names[pg]
del _world.pg_group_ranks[pg]
del _world.pg_backend_config[pg]
if pg in _world.pg_object_coll_device:
del _world.pg_object_coll_device[pg]
if pg in _world.pg_default_device:
del _world.pg_default_device[pg]
if pg in _world.pg_coalesce_state.keys():
warnings.warn(
"Some coalesced collectives haven't been launched when "
@@ -2339,7 +2356,7 @@ def all_gather_object(object_list, obj, group=None):
_warn_not_in_group("all_gather_object")
return
current_device = _get_object_coll_device(group)
current_device = _get_pg_default_device(group)
input_tensor, local_size = _object_to_tensor(obj, current_device)
# Gather all local sizes. This is so that we can find the max size, and index
@@ -2440,7 +2457,7 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
# Ensure object_gather_list is specified appropriately.
my_rank = get_rank()
_validate_output_list_for_rank(my_rank, dst, object_gather_list)
current_device = _get_object_coll_device(group)
current_device = _get_pg_default_device(group)
input_tensor, local_size = _object_to_tensor(obj, current_device)
# Gather all local sizes. This is so that we can find the max size, and index
@@ -2554,7 +2571,7 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
# ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the
# case it is not ``None`` we move the size and object tensors to be
# broadcasted to this device.
current_device = device or _get_object_coll_device(group)
current_device = device or _get_pg_default_device(group)
my_rank = get_rank()
# Serialize object_list elements to tensors on src rank.
if my_rank == src:
@@ -2659,7 +2676,7 @@ def scatter_object_list(
)
my_rank = get_rank()
pg_device = _get_object_coll_device(group)
pg_device = _get_pg_default_device(group)
if my_rank == src:
tensor_list, tensor_sizes = zip(
*[_object_to_tensor(obj, pg_device) for obj in scatter_object_input_list]
@@ -3623,7 +3640,6 @@ def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None):
the default process group will be used.
async_op (bool, optional): Whether this op should be an async op
device_ids ([int], optional): List of device/GPU ids.
Valid only for NCCL backend.
Returns:
Async work handle, if async_op is set to True.
@@ -3634,11 +3650,8 @@ def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None):
return
opts = BarrierOptions()
opts.device = _get_pg_default_device(group)
if device_ids is not None:
if get_backend(group) != Backend.NCCL:
raise RuntimeError(
f"Function argument device_ids not supported for the selected backend {get_backend(group)}"
)
if isinstance(device_ids, list):
opts.device_ids = device_ids
else:
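To make the selection rules of _get_pg_default_device above concrete, a small sketch (single-rank gloo group with an illustrative address/port):

    # Sketch: for a CPU-only (gloo) group, the default device used by barrier
    # and the object collectives resolves to cpu.
    import torch.distributed as dist
    from torch.distributed.distributed_c10d import _get_pg_default_device

    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29501",
        rank=0,
        world_size=1,
    )
    print(_get_pg_default_device())  # expected: device(type='cpu')
    dist.destroy_process_group()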


@@ -381,7 +381,7 @@ class WorldData:
tags_to_pg: Dict[str, List[dist.ProcessGroup]]
pg_to_tag: Dict[dist.ProcessGroup, str]
pg_coalesce_state: Dict[dist.ProcessGroup, List[Union[_CollOp, P2POp]]]
pg_object_coll_device: Dict[dist.ProcessGroup, torch.device]
pg_default_device: Dict[dist.ProcessGroup, torch.device]
class ThreadLocalWorld:
@@ -437,8 +437,8 @@ class ThreadLocalWorld:
return self._get_world().pg_coalesce_state
@property
def pg_object_coll_device(self) -> Dict[dist.ProcessGroup, torch.device]:
return self._get_world().pg_object_coll_device
def pg_default_device(self) -> Dict[dist.ProcessGroup, torch.device]:
return self._get_world().pg_default_device
_old_pg_world = None