[c10d] Extend NCCL communicator splitting to more use cases (#114916)

Previously we could only use `ncclCommSplit` when we knew all backends were connected on all shards (due to the need to perform a NOCOLOR split), which in practice meant we could only use it for subgroups that were copies of the entire world.

This change allows for specifying a bound device id to `init_process_group` which tells the pg and its backends that the specified device, and the specified device only, will be associated with this rank.

This guarantee lets us do an early connect (which we could not previously do due to how ProcessGroupNCCL infers devices based on tensors and not the rank number).  And by doing the early connect, we have the guarantee ranks are connected and can perform nocolor splits when needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114916
Approved by: https://github.com/kwen2501
This commit is contained in:
Chip Turner
2023-12-07 15:13:01 +00:00
committed by PyTorch MergeBot
parent a6736ac851
commit 78b945484b
7 changed files with 232 additions and 36 deletions

View File

@ -210,14 +210,15 @@ class ProcessGroupNCCLNoGPUTest(TestCase):
class ProcessGroupNCCLTest(MultiProcessTestCase):
def _create_process_group_nccl(self, store, opts):
def _create_process_group_nccl(self, store, opts, device_id=None):
# create nccl processgroup with opts
c10d.init_process_group(
"nccl",
world_size=self.world_size,
rank=self.rank,
store=store,
pg_options=opts)
pg_options=opts,
device_id=device_id)
pg = c10d.distributed_c10d._get_default_group()
return pg
@ -1286,6 +1287,8 @@ class ProcessGroupNCCLTest(MultiProcessTestCase):
@requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
def test_comm_split_optimization(self):
# Test the optimization of new groups that contain all world
# ranks use the "transparent" `ncclCommSplit` optimization.
store = c10d.FileStore(self.file_name, self.world_size)
pg = self._create_process_group_nccl(store, self.opts())
@ -1305,6 +1308,31 @@ class ProcessGroupNCCLTest(MultiProcessTestCase):
ng.broadcast(tensor, 0)
self.assertEqual(backend.comm_split_count(), 1)
@requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
def test_comm_split_subgroup(self):
# Test `ncclCommSplit` for smaller subgroups of the world when
# we've passed a specific device_id to init_process_group.
store = c10d.FileStore(self.file_name, self.world_size)
device = torch.device(f'cuda:{self.rank}')
pg = self._create_process_group_nccl(store, self.opts(), device_id=device)
backend = pg._get_backend(torch.device(device))
tensor = torch.full((1,), self.rank).cuda(device)
original_tensor = tensor.clone()
ng = c10d.new_group([0])
# rank 0 hasn't split yet, but rank 1 did for the
# nocolor... so split count matches rank count coincidentally
# in each of the proceses this test spawned!
self.assertEqual(backend.comm_split_count(), self.rank)
if self.rank == 0:
dist.broadcast(tensor, 0, group=ng)
# now everyone has split because rank 0 has performed a comm
self.assertEqual(backend.comm_split_count(), 1)
self.assertEqual(tensor, original_tensor)
class DistributedDataParallelTest(
test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
):

View File

@ -57,6 +57,10 @@ class TORCH_API Backend : public torch::CustomClassHolder {
return reinterpret_cast<std::intptr_t>(this);
}
virtual bool supportsSplitting() const {
return false;
}
virtual void startCoalescing() {
TORCH_CHECK(
false,
@ -365,6 +369,25 @@ class TORCH_API Backend : public torch::CustomClassHolder {
return pg_name_;
}
// See similar functions in ProcessGroup.hpp for context.
c10::optional<at::Device> getBoundDeviceId() const {
return bound_device_id_;
}
// Perform an eager connect to the specified device if the backend supports
// it.
virtual void eagerConnectSingleDevice(at::Device device) {
// no-op in the default case; this is an optimization some
// backends may perform
}
void setBoundDeviceId(c10::optional<at::Device> device) {
if (device) {
TORCH_CHECK(device->has_index(), "setBoundDeviceId must have an index");
}
bound_device_id_ = device;
}
protected:
// Implementations of this interface need to call this to setup
// appropriate logging etc.
@ -378,6 +401,8 @@ class TORCH_API Backend : public torch::CustomClassHolder {
std::string pg_name_;
std::function<void(std::shared_ptr<WorkInfo>)> onCompletionHook_;
c10::optional<at::Device> bound_device_id_;
};
} // namespace c10d

View File

@ -631,11 +631,15 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
backendTypeToBackend_.end()) {
auto existingBackend = backendTypeToBackend_.at(backendType);
deviceTypeToBackend_[deviceType] = existingBackend;
TORCH_CHECK(
existingBackend->getBoundDeviceId() ==
(*backend)->getBoundDeviceId());
} else {
// check if backend has value
if (backend.has_value()) {
deviceTypeToBackend_[deviceType] = backend.value();
backendTypeToBackend_[backendType] = backend.value();
(*backend)->setBoundDeviceId(bound_device_id_);
}
}
}
@ -694,6 +698,22 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
void release_resources() override;
// ProcessGroups optionally can be "bound" to a specific device.
// Currently this is only for nccl and allows for some opt-in
// optimizations such as automatic use of ncclCommSplit. The device
// is specified in `init_process_group` and eventually makes it
// here and then down into the actual backend instances.
c10::optional<at::Device> getBoundDeviceId() const {
return bound_device_id_;
}
void setBoundDeviceId(c10::optional<at::Device> device) {
if (device) {
TORCH_CHECK(device->has_index(), "setBoundDeviceId must have an index");
}
bound_device_id_ = device;
}
protected:
// Implementations of this interface need to call this to setup
// appropriate logging etc.
@ -716,6 +736,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
deviceTypeToBackend_;
std::unordered_map<BackendType, c10::intrusive_ptr<Backend>>
backendTypeToBackend_;
c10::optional<at::Device> bound_device_id_;
};
} // namespace c10d

View File

@ -220,14 +220,6 @@ std::vector<at::Device> getDeviceList(const std::vector<at::Tensor>& tensors) {
return res;
}
// Return CUDA device with ordinal given by input rank.
at::Device getDeviceForRank(int rank) {
TORCH_CHECK_WITH(ValueError, rank >= 0, "Invalid rank ", rank);
auto numGPUs = at::cuda::getNumGPUs();
int16_t deviceIdx = static_cast<int16_t>(rank % numGPUs);
return at::Device(at::DeviceType::CUDA, deviceIdx);
}
// [Sync Streams] Helper that lets the input ncclStreams to wait for the current
// stream. NCCL communications run on ncclStreams, but input tensors are
// allocated on different streams (i.e., current streams). Communications on
@ -349,6 +341,23 @@ std::string dump_nccl_trace() {
return NCCLTraceBuffer::get()->dump();
}
// Return CUDA device with ordinal given by input rank. If we aren't
// bound to a specific device, there is no strict guarantee that this
// heuristic is the correct assignment of ranks to GPUs that Python
// layers use, but in practice it tends to be. Fortunately we don't
// rely on this for correctness of any tensor operations, just for
// ancillary uses like health checks and barriers.
at::Device ProcessGroupNCCL::guessDeviceForRank() const {
TORCH_CHECK_WITH(ValueError, rank_ >= 0, "Invalid rank ", rank_);
if (getBoundDeviceId()) {
return *getBoundDeviceId();
} else {
auto numGPUs = at::cuda::getNumGPUs();
int16_t deviceIdx = static_cast<int16_t>(rank_ % numGPUs);
return at::Device(at::DeviceType::CUDA, deviceIdx);
}
}
const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 1000;
constexpr int64_t kSynchronizeBusyWaitMillis = 10;
thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0;
@ -785,6 +794,8 @@ ProcessGroupNCCL::ProcessGroupNCCL(
<< ", TIMEOUT(ms): " << options_->timeout.count()
<< ", USE_HIGH_PRIORITY_STREAM: "
<< options_->is_high_priority_stream
<< ", SPLIT_FROM: " << options_->split_from
<< ", SPLIT_COLOR: " << options_->split_color
<< ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug
#ifdef NCCL_HAS_COMM_REGISTER
<< ", TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: "
@ -822,6 +833,32 @@ ProcessGroupNCCL::ProcessGroupNCCL(
}
}
void ProcessGroupNCCL::eagerConnectSingleDevice(at::Device device) {
std::vector<at::Device> rankDevices = {device};
const auto key = getKeyFromDevices(rankDevices);
LOG(INFO) << "Eagerly connecting nccl backend with device " << device;
getNCCLComm(key, rankDevices, OpType::ALLREDUCE);
}
void ProcessGroupNCCL::performNocolorSplit(at::Device device) {
// If our backend doesn't support splitting, this is a no-op for
// ranks not in the new subgroup (and ranks that would be in it will
// just use a new communicator rather than split).
#ifdef NCCL_HAS_COMM_SPLIT
std::vector<at::Device> rankDevices = {device};
const auto key = getKeyFromDevices(rankDevices);
LOG(INFO) << "Performing nocolor split on backend device " << device
<< ", key " << key << ", i am " << this;
auto comm = getNCCLComm(key, rankDevices, OpType::ALLREDUCE);
TORCH_CHECK_WITH(
DistBackendError,
comm.size() == 1,
"exactly one communicator found for device ",
device);
NCCLComm::split(comm[0].get(), NCCL_SPLIT_NOCOLOR, rank_, options_->config);
#endif
}
void ProcessGroupNCCL::runHealthCheck() {
// Run health check in a separate thread and wait on CV to handle timeouts,
// since majority of getNCCLComm failures are hangs.
@ -836,7 +873,8 @@ void ProcessGroupNCCL::runHealthCheck() {
HealthCheckData healthCheckData;
auto t = std::thread([&healthCheckData, this]() {
try {
std::vector<at::Device> rankDevice = {getDeviceForRank(rank_)};
std::vector<at::Device> rankDevice = {guessDeviceForRank()};
const auto key = getKeyFromDevices(rankDevice);
// OpType does not matter, only need to set to not go through send/recv
// path.
@ -1615,6 +1653,17 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
"Not able to create/get the NCCL Communicator since "
"the GPU devices are not known");
}
if (bound_device_id_) {
for (const auto& device : devices) {
if (*bound_device_id_ != device) {
LOG(ERROR) << "Tensor found on device " << device
<< " but backend constrained to " << *bound_device_id_;
C10_THROW_ERROR(
DistBackendError,
"Attempt to perform collective on tensor not on device passed to init_process_group");
}
}
}
for (auto& device : devices) {
usedDeviceIdxs_.insert(device.index());
@ -3336,7 +3385,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::barrier(const BarrierOptions& opts) {
" to perform barrier as devices used by this process are currently unknown. ",
"This can potentially cause a hang if this rank to GPU mapping is incorrect.",
"Specify device_ids in barrier() to force use of a particular device.");
devices.emplace_back(getDeviceForRank(rank_));
devices.emplace_back(guessDeviceForRank());
} else {
for (auto usedDeviceIdx : usedDeviceIdxs_) {
devices.emplace_back(at::DeviceType::CUDA, usedDeviceIdx);

View File

@ -398,6 +398,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
return std::string(NCCL_BACKEND_NAME);
}
bool supportsSplitting() const override {
return true;
}
void startCoalescing() override;
c10::intrusive_ptr<Work> endCoalescing() override;
@ -542,6 +546,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
void shutdown();
void eagerConnectSingleDevice(at::Device device) override;
void performNocolorSplit(at::Device device);
protected:
// Helper that broadcasts nccl unique ID to all ranks through the store
void broadcastUniqueNCCLID(
@ -644,6 +652,15 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// object might get destroyed before the WorkNCCL object.
void ncclCommWatchdog();
// Return the CUDA device most likely associated with this backend.
// If we aren't bound to a specific device, there is no strict
// guarantee that this heuristic is the correct assignment of ranks
// to GPUs that Python layers use, but in practice it tends to be.
// Fortunately we don't rely on this for correctness of any tensor
// operations, just for ancillary uses like health checks and
// barriers.
at::Device guessDeviceForRank() const;
// Performs a health check by initializing dummy NCCL communicators and then
// destroying them. This will help indicate and signal any NCCL-related issues
// prior to the first collective. The actual initialization and subsequent

View File

@ -1807,6 +1807,10 @@ Arguments:
"group_name",
&::c10d::ProcessGroup::getGroupName,
"(Gets this process group name. It's cluster unique)")
.def_property(
"bound_device_id",
&::c10d::ProcessGroup::getBoundDeviceId,
&::c10d::ProcessGroup::setBoundDeviceId)
.def("boxed", [](c10::intrusive_ptr<::c10d::ProcessGroup> self) {
return torch::jit::toPyObject(c10::IValue(std::move(self)));
})
@ -1871,6 +1875,10 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
.def("rank", &::c10d::Backend::getRank)
.def("size", &::c10d::Backend::getSize)
.def("name", &::c10d::Backend::getBackendName)
.def_property_readonly(
"supports_splitting",
&::c10d::Backend::supportsSplitting,
"(test whether the backend supports splitting)")
.def(
"broadcast",
&::c10d::Backend::broadcast,
@ -2140,6 +2148,9 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
py::arg("timeout") = ::c10d::kUnsetTimeout,
py::arg("wait_all_ranks") = false,
py::call_guard<py::gil_scoped_release>())
.def(
"eager_connect_single_device",
&::c10d::Backend::eagerConnectSingleDevice)
.def(
"_get_backend_name",
&::c10d::Backend::getBackendName,
@ -2300,7 +2311,14 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
py::arg("timeout_mil_sec"),
py::call_guard<py::gil_scoped_release>())
.def_property_readonly(
"options", &::c10d::ProcessGroupNCCL::getOptions);
"options", &::c10d::ProcessGroupNCCL::getOptions)
.def_property(
"bound_device_id",
&::c10d::ProcessGroupNCCL::getBoundDeviceId,
&::c10d::ProcessGroupNCCL::setBoundDeviceId)
.def(
"perform_nocolor_split",
&::c10d::ProcessGroupNCCL::performNocolorSplit);
#ifdef NCCL_HAS_COMM_CTA_CGA
py::class_<ncclConfig_t>(

View File

@ -1056,6 +1056,7 @@ def init_process_group(
store: Optional[Store] = None,
group_name: str = "",
pg_options: Optional[Any] = None,
device_id: Optional[torch.device] = None,
):
"""
Initialize the default distributed process group.
@ -1189,7 +1190,8 @@ def init_process_group(
store,
group_name,
pg_options=pg_options,
timeout=timeout
timeout=timeout,
device_id=device_id,
)
_update_default_pg(default_pg)
@ -1218,6 +1220,26 @@ def init_process_group(
# default devices and messes up NCCL internal state.
_store_based_barrier(rank, store, group_name, world_size, timeout)
def _get_split_source(pg):
split_from = None
if pg.bound_device_id:
split_from = pg._get_backend(pg.bound_device_id)
elif pg is _world.default_pg:
try:
split_from = pg._get_backend(torch.device("cuda"))
except RuntimeError:
# no cuda device associated with this backend
pass
if not split_from or not split_from.supports_splitting:
return None
# If necessary, find a backend to split from by peeling process
# group wrappers from our potentially wrapped process group.
while isinstance(split_from, _ProcessGroupWrapper):
split_from = split_from.wrapped_pg
return split_from
def _new_process_group_helper(
group_size,
@ -1228,7 +1250,8 @@ def _new_process_group_helper(
group_name,
pg_options=None,
timeout=None,
pg_tag=None
pg_tag=None,
device_id=None,
):
"""
Create a new distributed process group.
@ -1247,6 +1270,10 @@ def _new_process_group_helper(
"created, please use a different group name"
)
if device_id is not None and (device_id.index is None or device_id.type != 'cuda'):
raise ValueError("init_process_group device_id parameter must be a cuda device with an "
"id, e.g. cuda:0, not just cuda or cpu")
# Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value
_check_valid_timeout(timeout)
@ -1260,17 +1287,41 @@ def _new_process_group_helper(
# The list of group ranks is empty if we're creating the default group.
is_default_group = len(global_ranks_in_group) == 0
# nccl and potentially other backends allow creation of
# communicators based on pre-existing ones, which can save
# initialization time. Due to lazy initialization of
# communicators in some backends, we have to be careful and only
# split when we *know* the backends already are connected _on all
# ranks_. We can only know this if the group we are making is the
# entire world or if we have bound a device id to the world (which
# causes early connection initialization).
if (is_initialized() and
(len(global_ranks_in_group) == _world.default_pg.size() or _world.default_pg.bound_device_id)):
split_from = _get_split_source(_world.default_pg)
else:
split_from = None
# If this is a subgroup (which means group_ranks is specified),
# we check if the current process is a member of the new group.
if not is_default_group:
global_rank = _get_default_group().rank()
if global_rank not in global_ranks_in_group:
# If we are using `ncclCommSplit` (or similar split from
# other APIs) to create the communicator, we will need to
# call `ncclCommSplit` on *all* ranks in this new group's
# parent group, even those not in the new group. This is
# a requirement of the NCCL API as otherwise we would get
# out of sync.
if split_from:
split_from.perform_nocolor_split(_world.default_pg.bound_device_id)
return GroupMember.NON_GROUP_MEMBER, None
prefix_store = PrefixStore(f"{group_name}/", store)
base_pg_options = ProcessGroup.Options(backend=str(backend))
base_pg_options._timeout = timeout
pg: ProcessGroup = ProcessGroup(prefix_store, group_rank, group_size, base_pg_options)
if device_id:
pg.bound_device_id = device_id
backend_config = BackendConfig(backend)
for device, backend_str in backend_config.get_device_backend_map().items():
# Use the group name as prefix in the default store, such that
@ -1315,24 +1366,6 @@ def _new_process_group_helper(
pg_options.is_high_priority_stream = False
pg_options._timeout = timeout
# If our new group includes all ranks, we can reduce
# overhead by splitting the communicator (`nccCommSplit`).
# TODO: support this in the general case by calling
# `nccCommSplit` with `NCCL_SPLIT_NOCOLOR` for the ranks
# not in the communicator.
split_from = None
if (
is_initialized()
and _world.default_pg._get_backend_name() == Backend.NCCL
and len(global_ranks_in_group) == _world.default_pg.size()
):
# If possible, find a backend to split from by peeling
# process group wrappers from the world's default pg.
split_from = _world.default_pg._get_backend(_get_pg_default_device())
while isinstance(split_from, _ProcessGroupWrapper):
split_from = split_from.wrapped_pg
if split_from:
pg_options.split_from = split_from
pg_options.split_color = _process_group_color(global_ranks_in_group)
@ -1411,6 +1444,10 @@ def _new_process_group_helper(
pg._register_backend(torch.device(device), backend_type, backend_class)
if device_id and pg._get_backend(device_id).supports_splitting:
eager_backend = pg._get_backend(device_id)
eager_backend.eager_connect_single_device(device_id)
# update global state
assert group_name is not None
_world.pg_map[pg] = (backend, prefix_store)