[c10d] allow nonblocking wrap of ncclCommInitRankConfig (#118256)

Resolves #117749

Summary:
This PR was updated with the following intentions:

1. Identify eager-mode init (as opposed to lazy init); in that case, if NONBLOCKING mode is also enabled, NCCL comms are created without a guarantee that they are fully initialized.
2. Python users can do other work (e.g., model init) between invoking init_process_group and their first collective call (see the usage sketch after this list).
3. c10d guarantees/waits for communicators to be initialized before issuing the first collective call.
4. For NCCL collective calls, the contract between Python users and c10d is largely unchanged from blocking calls (c10d waits for the NCCL call to return ncclSuccess, or times out, whichever happens first).
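
A minimal usage sketch of this flow, assuming a torchrun-style launch (RANK set in the environment) and a build with nonblocking NCCL support; the environment variables and the device_id argument match the tests below, while the model and tensor here are illustrative only:

import os
import torch
import torch.distributed as dist

# Opt into nonblocking NCCL communicator creation (same env vars as the tests below use).
os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1"
os.environ["TORCH_NCCL_NONBLOCKING_TIMEOUT"] = "100"

rank = int(os.environ["RANK"])  # set by the launcher, e.g. torchrun
device = torch.device(f"cuda:{rank}")

# Binding a device_id triggers eager init: the NCCL comm starts initializing here,
# but init_process_group returns without waiting for it to finish.
dist.init_process_group("nccl", device_id=device)

# Other work (e.g., model init) can overlap with communicator initialization.
model = torch.nn.Linear(1024, 1024).to(device)

# c10d waits for the communicator to finish initializing (or times out)
# before issuing the first collective to the CUDA stream.
t = torch.ones(10, device=device)
dist.all_reduce(t)

dist.destroy_process_group()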

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118256
Approved by: https://github.com/kwen2501
Author: Shuqiang Zhang
Date: 2024-01-29 14:33:10 -08:00
Committed by: PyTorch MergeBot
Parent: e632d0c0dc
Commit: c7af626a26
4 changed files with 121 additions and 6 deletions

View File

@@ -1354,6 +1354,52 @@ class ProcessGroupNCCLTest(MultiProcessTestCase):
self.assertEqual(backend.comm_split_count(), 1)
self.assertEqual(tensor, original_tensor)
@requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_non_blocking_init(self):
# Test creating a pg using nonblocking mode but not eagerly
os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1"
os.environ["TORCH_NCCL_NONBLOCKING_TIMEOUT"] = "100"
store = c10d.FileStore(self.file_name, self.world_size)
device = self.rank_to_GPU[self.rank][0]
pg = self._create_process_group_nccl(store, self.opts())
backend = pg._get_backend(torch.device(device))
self.assertEqual(backend.comm_split_count(), 0)
reduce_tensor = torch.rand(10, 10, device=device)
# Run an allreduce, which should trigger a comm init for pg
pg.allreduce(reduce_tensor).wait()
new_pg = c10d.new_group()
# Even after pg's collective call, the new pg's comm is not initialized until its own collective calls
self.assertEqual(backend.comm_split_count(), 0)
broadcast_tensor = torch.tensor([self.rank]).cuda(device)
new_pg.broadcast(broadcast_tensor, 0).wait()
self.assertEqual(backend.comm_split_count(), 1)
@requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_non_blocking_with_eager_init(self):
# Test creating a pg eagerly with nonblocking mode when
# we've passed a specific device_id to init_process_group.
os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1"
os.environ["TORCH_NCCL_NONBLOCKING_TIMEOUT"] = "100"
store = c10d.FileStore(self.file_name, self.world_size)
device = torch.device(f'cuda:{self.rank}')
# Bind device to trigger eager init mode
pg = self._create_process_group_nccl(store, self.opts(), device_id=device)
backend = pg._get_backend(torch.device(device))
self.assertEqual(backend.comm_split_count(), 0)
reduce_tensor = torch.rand(10, 10, device=device)
# Run an allreduce; the comm should have already started initializing,
# but the allreduce is issued to the CUDA stream only after initialization succeeds
pg.allreduce(reduce_tensor).wait()
new_pg = c10d.new_group()
# Even after pg's collective call, the new pg's comm is not initialized until its own collective calls
self.assertEqual(backend.comm_split_count(), 0)
broadcast_tensor = torch.tensor([self.rank]).cuda(device)
new_pg.broadcast(broadcast_tensor, 0).wait()
self.assertEqual(backend.comm_split_count(), 1)
class DistributedDataParallelTest(
test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase

View File

@@ -9,6 +9,10 @@
#include <cuda_runtime.h>
#include <mutex>
namespace {
constexpr int64_t kCommInitBusyWaitMillis = 10;
} // namespace
namespace c10d {
ncclComm_t NCCLComm::getNcclComm() {
@@ -26,9 +30,39 @@ ncclComm_t NCCLComm::getNcclComm() {
". ",
commFailureMsg));
}
// only wait for initialization if nonblocking mode is enabled
if (!initialized_ && nccl_use_nonblocking()) {
waitUntilInitialized(nccl_nonblocking_timeout());
}
return ncclComm_;
}
void NCCLComm::waitUntilInitialized(int timeoutSecs) {
auto startTimepoint = std::chrono::steady_clock::now();
while (!initialized_) {
if (ncclComm_) {
ncclResult_t result;
ncclCommGetAsyncError(ncclComm_, &result);
if (result == ncclSuccess) {
LOG(INFO) << "Rank " << rank_ << ": NCCL communicator is initialized.";
initialized_ = true;
break;
}
}
auto currentTimepoint = std::chrono::steady_clock::now();
auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
currentTimepoint - startTimepoint)
.count();
if (timeElapsed > timeoutSecs) {
std::string err = "NCCL timeout in communicator initialization.";
TORCH_CHECK_WITH(DistBackendError, false, err);
}
std::this_thread::sleep_for(
std::chrono::milliseconds(kCommInitBusyWaitMillis));
}
}
std::string getNcclVersion() {
static c10::once_flag ncclGetVersionFlag;
static std::string versionString;

View File

@ -7,6 +7,7 @@
#include <memory>
#include <mutex>
#include <thread>
#include <ATen/ATen.h>
#include <c10/util/Exception.h>
@@ -86,6 +87,18 @@
} \
} while (0)
// Macro to throw on a non-successful NCCL return value for NONBLOCKING calls.
#define C10D_NCCL_CHECK_NONBLOCKING(cmd, failureReason) \
do { \
ncclResult_t result = cmd; \
if (result != ncclSuccess && result != ncclInProgress) { \
std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \
"\n" + getNcclErrorDetailStr(result, failureReason); \
TORCH_CHECK_WITH(DistBackendError, false, err); \
} \
} while (0)
// Macro to throw on a non-successful NCCL return value, non-blocking.
#define C10D_NCCL_CHECK_TIMEOUT(cmd, comm, failureReason) \
ncclResult_t result = cmd; \
@@ -209,7 +222,8 @@ class NCCLComm {
: ncclComm_(ncclComm),
aborted_(false),
ncclAsyncErr_(ncclSuccess),
commFailureReason_(c10::nullopt) {}
commFailureReason_(c10::nullopt),
initialized_(false) {}
NCCLComm() : NCCLComm(nullptr) {}
@@ -239,6 +253,7 @@ class NCCLComm {
c10::nullopt);
comm->ncclId_ = commId;
comm->rank_ = rank;
comm->initialized_ = true;
return comm;
}
@@ -249,21 +264,26 @@
ncclUniqueId commId,
ncclConfig_t& config) {
auto comm = std::make_shared<NCCLComm>();
bool isInitialized = false;
if (nccl_use_nonblocking()) {
config.blocking = 0;
C10D_NCCL_CHECK_TIMEOUT(
LOG(INFO) << "Rank " << rank
<< ": creating NCCL communicator in nonblocking mode";
C10D_NCCL_CHECK_NONBLOCKING(
ncclCommInitRankConfig(
&(comm->ncclComm_), numRanks, commId, rank, &config),
comm->ncclComm_,
c10::nullopt);
} else {
C10D_NCCL_CHECK(
ncclCommInitRankConfig(
&(comm->ncclComm_), numRanks, commId, rank, &config),
c10::nullopt);
// under blocking mode, comm is initialized after NCCL CHECK
isInitialized = true;
}
comm->ncclId_ = commId;
comm->rank_ = rank;
comm->initialized_ = isInitialized;
return comm;
}
#endif
@@ -280,6 +300,7 @@ class NCCLComm {
source->ncclComm_, color_id, rank, &(comm->ncclComm_), &config),
c10::nullopt);
++source->ncclCommSplitCounter_;
comm->rank_ = rank;
return comm;
}
#endif
@@ -303,6 +324,7 @@ class NCCLComm {
std::swap(ncclComm_, other.ncclComm_);
std::swap(aborted_, other.aborted_);
std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
std::swap(initialized_, other.initialized_);
}
ncclComm_t getNcclComm();
@@ -446,6 +468,8 @@ class NCCLComm {
friend class ProcessGroupNCCL;
protected:
// A helper function to wait until the communicator is initialized.
void waitUntilInitialized(int timeoutSecs);
ncclComm_t ncclComm_;
// Unique nccl_id for this communicator.
ncclUniqueId ncclId_;
@@ -458,6 +482,7 @@ class NCCLComm {
// Optional reason for communicator failure, provided by ProcessGroupNCCL for
// better error messaging.
c10::optional<std::string> commFailureReason_;
bool initialized_{false};
#ifdef NCCL_HAS_COMM_REGISTER
// Stores handlers for tensors registered by NCCL
std::unordered_map<void*, void*> registeredSegmentHandles_;

View File

@@ -1709,7 +1709,15 @@ std::exception_ptr ProcessGroupNCCL::checkForNCCLErrorsInternal(
*commFailureReason)));
}
ncclResult_t ncclAsyncErr = ncclComm->checkForNcclError();
// When nonblocking mode is enabled by TORCH_NCCL_USE_COMM_NONBLOCKING,
// ncclInProgress could be returned when there are pending NCCL calls.
// In this case, no exception should be thrown
#ifdef NCCL_HAS_COMM_NONBLOCKING
// ncclInProgress is defined only if NCCL_HAS_COMM_NONBLOCKING is defined
if (ncclAsyncErr != ncclSuccess && ncclAsyncErr != ncclInProgress) {
#else
if (ncclAsyncErr != ncclSuccess) {
#endif
return std::make_exception_ptr(C10_BUILD_ERROR(
DistBackendError,
"NCCL error: " + ncclGetErrorWithVersion(ncclAsyncErr) + "\n" +
@@ -1975,10 +1983,12 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
#ifndef NCCL_HAS_COMM_NONBLOCKING
C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
#else
if (!nccl_use_nonblocking()) {
C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
if (nccl_use_nonblocking()) {
// If we use nonblocking mode, allow communicators to be
// uninitialized/ncclInProgress until the first communication
C10D_NCCL_CHECK_NONBLOCKING(ncclGroupEnd(), c10::nullopt);
} else {
C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), ncclComms, c10::nullopt);
C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
}
#endif