Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Enables barrier to support the specified device (#99589)
Enables barrier to support the specified device, e.g. cuda or a custom device. There is some discussion here: https://github.com/pytorch/pytorch/issues/97938#issue-1646833919
Today, there are two limitations of barrier. One is that barrier does not support custom devices: fbdb86c174/torch/csrc/distributed/c10d/ProcessGroup.hpp (L512-L522)
The second is that there is a special validation for NCCL when device_ids is not None, which assumes the CUDA/NCCL binding and also hinders custom devices: 789070986c/torch/distributed/distributed_c10d.py (L3504-L3508)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99589 Approved by: https://github.com/kwen2501
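For orientation, a minimal usage sketch (not part of this PR; assumes a job launched with torchrun so MASTER_ADDR/RANK/WORLD_SIZE are set and GPUs are visible). After this change, barrier() allocates its synchronization tensor on the process group's default device instead of assuming CUDA/NCCL, so multi-backend and custom-device groups take the same path:

# Minimal sketch, assuming a torchrun launch and visible GPUs.
import torch
import torch.distributed as dist

# One process group with per-device backends ("device:backend" pairs).
dist.init_process_group(backend="cpu:gloo,cuda:nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# barrier() now fills BarrierOptions.device from the group's default device
# rather than special-casing NCCL.
dist.barrier()

dist.destroy_process_group()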
Committed by: PyTorch MergeBot
Parent: 6261aa5c8d
Commit: 97180aca5e
@@ -2355,16 +2355,6 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
             return skip_but_pass_in_sandcastle("Test requires world_size of at least 4")
         self._test_sequence_num_incremented_subgroup("gloo")
 
-    @requires_gloo()
-    def test_gloo_barrier_device_ids(self):
-        store = c10d.FileStore(self.file_name, self.world_size)
-        c10d.init_process_group(
-            backend="gloo", rank=self.rank, world_size=self.world_size, store=store
-        )
-
-        with self.assertRaisesRegex(RuntimeError, "device_ids not supported"):
-            c10d.barrier(device_ids=[self.rank])
-
     @skip_if_lt_x_gpu(2)
     @requires_gloo()
     def test_gloo_warn_not_in_group(self):
@@ -528,7 +528,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
       const BarrierOptions& opts = BarrierOptions()) {
     static at::Tensor tensor;
     // TODO: if nccl was specified then use it
-    if (backendType_ == c10d::ProcessGroup::BackendType::NCCL) {
+    auto device = opts.device;
+    if (device.has_value()) {
+      // set device tensor from argument
+      tensor = at::empty(
+          {1},
+          at::TensorOptions().device(device.value()).dtype(at::kByte));
+    } else if (backendType_ == c10d::ProcessGroup::BackendType::NCCL) {
       // set cuda tensor
       tensor = at::empty(
           {1},
@@ -156,6 +156,7 @@ struct AllToAllOptions {
 struct BarrierOptions {
   std::vector<int64_t> device_ids;
   std::chrono::milliseconds timeout = kUnsetTimeout;
+  c10::optional<at::Device> device;
 };
 
 struct DistributedBackendOptions {
@@ -832,7 +832,8 @@ This class does not support ``__members__`` property.)");
   py::class_<::c10d::BarrierOptions>(module, "BarrierOptions")
       .def(py::init<>())
       .def_readwrite("device_ids", &::c10d::BarrierOptions::device_ids)
-      .def_readwrite("timeout", &::c10d::BarrierOptions::timeout);
+      .def_readwrite("timeout", &::c10d::BarrierOptions::timeout)
+      .def_readwrite("device", &::c10d::BarrierOptions::device);
 
   py::class_<::c10d::AllToAllOptions>(module, "AllToAllOptions")
       .def(py::init<>())
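The binding above exposes the new field to Python. An illustrative sketch of the resulting (private) binding surface, with names taken from this diff; not a public API:

# Sketch only: torch._C._distributed_c10d is a private module and may change.
import torch
from torch._C._distributed_c10d import BarrierOptions

opts = BarrierOptions()
opts.device_ids = [0]                 # pre-existing field
opts.device = torch.device("cuda:0")  # optional device added by this PR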
@@ -453,7 +453,7 @@ class _World:
     def __init__(self):
         self._default_pg = None
         self._pg_coalesce_state: Dict[ProcessGroup, List[Union[_CollOp, P2POp]]] = {}
-        self._pg_object_coll_device: Dict[ProcessGroup, torch.device] = {}
+        self._pg_default_device: Dict[ProcessGroup, torch.device] = {}
 
     @property
     def default_pg(self):
@@ -540,8 +540,8 @@ class _World:
         return self._pg_coalesce_state
 
     @property
-    def pg_object_coll_device(self) -> Dict[ProcessGroup, torch.device]:
-        return self._pg_object_coll_device
+    def pg_default_device(self) -> Dict[ProcessGroup, torch.device]:
+        return self._pg_default_device
 
 _world = _World()
 """Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it"""
@@ -573,11 +573,28 @@ _default_pg_init_method = None
 STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key"
 
 
-def _get_object_coll_device(group: Optional[ProcessGroup] = None):
+def _get_pg_default_device(group: Optional[ProcessGroup] = None):
+    """
+    Returns the device to use with ``group`` for control flow usage (object collectives, barrier).
+    There are selection rules:
+        1. If user specifies exactly one backend in ``init_process_group`` call:
+            use that backend
+        2. Else if user specifies multiple "device:backend" pairs in init_process_group:
+            If "cpu" is among those pairs, use "cpu" (because the object is in cpu memory);
+            Otherwise, use the first backend (sort of a random pick).
+
+    Args:
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+
+    Returns:
+        torch.device: The device to use with ``group``.
+
+    """
     group = group or _get_default_group()
-    if group in _world.pg_object_coll_device:
+    if group in _world.pg_default_device:
         # Previously searched and cached; just return
-        return _world.pg_object_coll_device[group]
+        return _world.pg_default_device[group]
 
     if not isinstance(group, ProcessGroup):
         # Provide backward compatibility to cases where `group` passed in is
@@ -589,8 +606,8 @@ def _get_object_coll_device(group: Optional[ProcessGroup] = None):
             "of PyTorch Distributed instead."
         )
         # Most users create Gloo with private API for object collectives
-        _world.pg_object_coll_device[group] = torch.device("cpu")
-        return _world.pg_object_coll_device[group]
+        _world.pg_default_device[group] = torch.device("cpu")
+        return _world.pg_default_device[group]
 
     """
     ``group._device_types`` is a property pybind that returns the devices
@@ -601,26 +618,26 @@ def _get_object_coll_device(group: Optional[ProcessGroup] = None):
 
     if len(devices) == 1:
         # User fixed exactly one backend in `init_process_group`
-        _world.pg_object_coll_device[group] = devices[0]
+        _world.pg_default_device[group] = devices[0]
     elif len(devices) == 0:
         # No backend has been registered with this PG (maybe because no
         # collective has been run?) We pick cpu as the default and hopefully
         # this would lazily init Gloo or other available cpu backend.
-        _world.pg_object_coll_device[group] = torch.device("cpu")
+        _world.pg_default_device[group] = torch.device("cpu")
     elif torch.device("cpu") in devices:
         # There are multiple backends in this PG and cpu is among them.
         # cpu is preferred as the object is in cpu memory. No need for device
         # copy.
-        _world.pg_object_coll_device[group] = torch.device("cpu")
+        _world.pg_default_device[group] = torch.device("cpu")
     else:
         # No cpu in the backend list. Randomly pick the first backend
-        _world.pg_object_coll_device[group] = devices[0]
+        _world.pg_default_device[group] = devices[0]
 
     logger.info(
-        f"Using device {_world.pg_object_coll_device[group]} for object "  # noqa: G004
+        f"Using device {_world.pg_default_device[group]} for object "  # noqa: G004
         "collectives."
     )
-    return _world.pg_object_coll_device[group]
+    return _world.pg_default_device[group]
 
 
 # Environment variable to control whether we do a barrier after process group
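A rough illustration of the selection rules documented in the new _get_pg_default_device docstring (private helper; behavior paraphrased from this diff, assuming a torchrun launch):

# Illustrative sketch of the device-selection rules (names from this diff):
#   exactly one backend in init_process_group      -> that backend's device
#   multiple "device:backend" pairs including cpu  -> cpu (objects live in cpu memory)
#   no backend registered on the group yet         -> cpu (lazily init a cpu backend)
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_pg_default_device

dist.init_process_group(backend="gloo")   # cpu-capable backend
print(_get_pg_default_device())           # -> device(type='cpu')
dist.destroy_process_group()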
@@ -1363,7 +1380,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
         _world.pg_to_tag.clear()
         _world.tags_to_pg.clear()
         _world.pg_coalesce_state.clear()
-        _world.pg_object_coll_device.clear()
+        _world.pg_default_device.clear()
 
         # when process group doesn't have an explicit name (only WORLD (default)
         # process group can have an explicit name), we use global _world.group_count
@@ -1379,8 +1396,8 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
         del _world.pg_names[pg]
         del _world.pg_group_ranks[pg]
         del _world.pg_backend_config[pg]
-        if pg in _world.pg_object_coll_device:
-            del _world.pg_object_coll_device[pg]
+        if pg in _world.pg_default_device:
+            del _world.pg_default_device[pg]
         if pg in _world.pg_coalesce_state.keys():
             warnings.warn(
                 "Some coalesced collectives haven't been launched when "
@@ -2339,7 +2356,7 @@ def all_gather_object(object_list, obj, group=None):
         _warn_not_in_group("all_gather_object")
         return
 
-    current_device = _get_object_coll_device(group)
+    current_device = _get_pg_default_device(group)
     input_tensor, local_size = _object_to_tensor(obj, current_device)
 
     # Gather all local sizes. This is so that we can find the max size, and index
@@ -2440,7 +2457,7 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
     # Ensure object_gather_list is specified appropriately.
     my_rank = get_rank()
     _validate_output_list_for_rank(my_rank, dst, object_gather_list)
-    current_device = _get_object_coll_device(group)
+    current_device = _get_pg_default_device(group)
     input_tensor, local_size = _object_to_tensor(obj, current_device)
 
     # Gather all local sizes. This is so that we can find the max size, and index
@@ -2554,7 +2571,7 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
     # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the
     # case it is not ``None`` we move the size and object tensors to be
     # broadcasted to this device.
-    current_device = device or _get_object_coll_device(group)
+    current_device = device or _get_pg_default_device(group)
     my_rank = get_rank()
     # Serialize object_list elements to tensors on src rank.
     if my_rank == src:
@@ -2659,7 +2676,7 @@ def scatter_object_list(
         )
 
     my_rank = get_rank()
-    pg_device = _get_object_coll_device(group)
+    pg_device = _get_pg_default_device(group)
     if my_rank == src:
         tensor_list, tensor_sizes = zip(
             *[_object_to_tensor(obj, pg_device) for obj in scatter_object_input_list]
@@ -3623,7 +3640,6 @@ def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None):
             the default process group will be used.
         async_op (bool, optional): Whether this op should be an async op
         device_ids ([int], optional): List of device/GPU ids.
-                                      Valid only for NCCL backend.
 
     Returns:
         Async work handle, if async_op is set to True.
@@ -3634,11 +3650,8 @@ def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None):
         return
 
     opts = BarrierOptions()
+    opts.device = _get_pg_default_device(group)
     if device_ids is not None:
-        if get_backend(group) != Backend.NCCL:
-            raise RuntimeError(
-                f"Function argument device_ids not supported for the selected backend {get_backend(group)}"
-            )
         if isinstance(device_ids, list):
             opts.device_ids = device_ids
         else:
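Usage-wise, device_ids remains an NCCL-oriented hint, while other backends now receive the device through BarrierOptions.device instead of being rejected by the check removed above. A sketch, assuming a single-node torchrun launch with NCCL:

# Sketch: pin the NCCL barrier to this rank's GPU via device_ids; for other
# backends the device now travels via opts.device (filled from _get_pg_default_device).
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
local_device = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_device)
dist.barrier(device_ids=[local_device])
dist.destroy_process_group()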
@@ -381,7 +381,7 @@ class WorldData:
     tags_to_pg: Dict[str, List[dist.ProcessGroup]]
     pg_to_tag: Dict[dist.ProcessGroup, str]
     pg_coalesce_state: Dict[dist.ProcessGroup, List[Union[_CollOp, P2POp]]]
-    pg_object_coll_device: Dict[dist.ProcessGroup, torch.device]
+    pg_default_device: Dict[dist.ProcessGroup, torch.device]
 
 
 class ThreadLocalWorld:
@@ -437,8 +437,8 @@ class ThreadLocalWorld:
         return self._get_world().pg_coalesce_state
 
     @property
-    def pg_object_coll_device(self) -> Dict[dist.ProcessGroup, torch.device]:
-        return self._get_world().pg_object_coll_device
+    def pg_default_device(self) -> Dict[dist.ProcessGroup, torch.device]:
+        return self._get_world().pg_default_device
 
 
 _old_pg_world = None