mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[c10d] allow nonblocking wrap of ncclCommInitRankConfig (#118256)
resolve #117749 Summary: Updated the PR with the following intentions: 1. identify eagerMode init (as opposed to lazy init), in which case we will create NCCL comms without guarantees that they are fully initialized if NONBLOCKING mode is also enabled. 2. Python users can do their other works (e.g., model init) between invoking init_process_group and their first collective call. 3. c10D would guarantee/wait for communicators to be initialized before issuing the first collective call. 4. For NCCL collective calls, the contract between python users and c10d is not changed much from blocking calls (C10d would wait the NCCL call to be ncclSuccess, or timeout, whichever happens first). Pull Request resolved: https://github.com/pytorch/pytorch/pull/118256 Approved by: https://github.com/kwen2501
This commit is contained in:
committed by
PyTorch MergeBot
parent
e632d0c0dc
commit
c7af626a26
@ -1354,6 +1354,52 @@ class ProcessGroupNCCLTest(MultiProcessTestCase):
|
||||
self.assertEqual(backend.comm_split_count(), 1)
|
||||
self.assertEqual(tensor, original_tensor)
|
||||
|
||||
@requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
def test_non_blocking_init(self):
|
||||
# Test creating a pg using nonblocking mode but not eagerly
|
||||
os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1"
|
||||
os.environ["TORCH_NCCL_NONBLOCKING_TIMEOUT"] = "100"
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
device = self.rank_to_GPU[self.rank][0]
|
||||
pg = self._create_process_group_nccl(store, self.opts())
|
||||
backend = pg._get_backend(torch.device(device))
|
||||
self.assertEqual(backend.comm_split_count(), 0)
|
||||
reduce_tensor = torch.rand(10, 10, device=device)
|
||||
# Run an allreduce, which should trigger a comm init for pg
|
||||
pg.allreduce(reduce_tensor).wait()
|
||||
new_pg = c10d.new_group()
|
||||
# even after pg's collective call, new pg's comm is not initialized until its own collectcive calls
|
||||
self.assertEqual(backend.comm_split_count(), 0)
|
||||
broadcast_tensor = torch.tensor([self.rank]).cuda(device)
|
||||
new_pg.broadcast(broadcast_tensor, 0).wait()
|
||||
self.assertEqual(backend.comm_split_count(), 1)
|
||||
|
||||
|
||||
@requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
def test_non_blocking_with_eager_init(self):
|
||||
# Test creating a pg eagerly with nonblocking mode when
|
||||
# we've passed a specific device_id to init_process_group.
|
||||
os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1"
|
||||
os.environ["TORCH_NCCL_NONBLOCKING_TIMEOUT"] = "100"
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
device = torch.device(f'cuda:{self.rank}')
|
||||
# bound device to triger eager init mode
|
||||
pg = self._create_process_group_nccl(store, self.opts(), device_id=device)
|
||||
backend = pg._get_backend(torch.device(device))
|
||||
self.assertEqual(backend.comm_split_count(), 0)
|
||||
reduce_tensor = torch.rand(10, 10, device=device)
|
||||
# Run an allreduce, comm should have already started initilizaing,
|
||||
# but allreduce is issued to CUDA STREAM only after the initialization is a success
|
||||
pg.allreduce(reduce_tensor).wait()
|
||||
new_pg = c10d.new_group()
|
||||
# even after pg's collective call, new pg's comm is not initialized until its own collectcive calls
|
||||
self.assertEqual(backend.comm_split_count(), 0)
|
||||
broadcast_tensor = torch.tensor([self.rank]).cuda(device)
|
||||
new_pg.broadcast(broadcast_tensor, 0).wait()
|
||||
self.assertEqual(backend.comm_split_count(), 1)
|
||||
|
||||
|
||||
class DistributedDataParallelTest(
|
||||
test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
|
||||
|
@ -9,6 +9,10 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include <mutex>
|
||||
|
||||
namespace {
|
||||
constexpr int64_t kCommInitBusyWaitMillis = 10;
|
||||
} // namespace
|
||||
|
||||
namespace c10d {
|
||||
|
||||
ncclComm_t NCCLComm::getNcclComm() {
|
||||
@ -26,9 +30,39 @@ ncclComm_t NCCLComm::getNcclComm() {
|
||||
". ",
|
||||
commFailureMsg));
|
||||
}
|
||||
// only wait for initialization if nonblocking mode is enabled
|
||||
if (!initialized_ && nccl_use_nonblocking()) {
|
||||
waitUntilInitialized(nccl_nonblocking_timeout());
|
||||
}
|
||||
|
||||
return ncclComm_;
|
||||
}
|
||||
|
||||
void NCCLComm::waitUntilInitialized(int timeoutSecs) {
|
||||
auto startTimepoint = std::chrono::steady_clock::now();
|
||||
while (!initialized_) {
|
||||
if (ncclComm_) {
|
||||
ncclResult_t result;
|
||||
ncclCommGetAsyncError(ncclComm_, &result);
|
||||
if (result == ncclSuccess) {
|
||||
LOG(INFO) << "Rank " << rank_ << ": NCCL communicator is initialized.";
|
||||
initialized_ = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
auto currentTimepoint = std::chrono::steady_clock::now();
|
||||
auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
|
||||
currentTimepoint - startTimepoint)
|
||||
.count();
|
||||
if (timeElapsed > timeoutSecs) {
|
||||
std::string err = "NCCL timeout in communicator initialization.";
|
||||
TORCH_CHECK_WITH(DistBackendError, false, err);
|
||||
}
|
||||
std::this_thread::sleep_for(
|
||||
std::chrono::milliseconds(kCommInitBusyWaitMillis));
|
||||
}
|
||||
}
|
||||
|
||||
std::string getNcclVersion() {
|
||||
static c10::once_flag ncclGetVersionFlag;
|
||||
static std::string versionString;
|
||||
|
@ -7,6 +7,7 @@
|
||||
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/util/Exception.h>
|
||||
@ -86,6 +87,18 @@
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// Macro to throw on a non-successful NCCL return value for NONBLOCKING calls.
|
||||
#define C10D_NCCL_CHECK_NONBLOCKING(cmd, failureReason) \
|
||||
do { \
|
||||
ncclResult_t result = cmd; \
|
||||
if (result != ncclSuccess && result != ncclInProgress) { \
|
||||
std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
|
||||
std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \
|
||||
"\n" + getNcclErrorDetailStr(result, failureReason); \
|
||||
TORCH_CHECK_WITH(DistBackendError, false, err); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// Macro to throw on a non-successful NCCL return value, non-blocking.
|
||||
#define C10D_NCCL_CHECK_TIMEOUT(cmd, comm, failureReason) \
|
||||
ncclResult_t result = cmd; \
|
||||
@ -209,7 +222,8 @@ class NCCLComm {
|
||||
: ncclComm_(ncclComm),
|
||||
aborted_(false),
|
||||
ncclAsyncErr_(ncclSuccess),
|
||||
commFailureReason_(c10::nullopt) {}
|
||||
commFailureReason_(c10::nullopt),
|
||||
initialized_(false) {}
|
||||
|
||||
NCCLComm() : NCCLComm(nullptr) {}
|
||||
|
||||
@ -239,6 +253,7 @@ class NCCLComm {
|
||||
c10::nullopt);
|
||||
comm->ncclId_ = commId;
|
||||
comm->rank_ = rank;
|
||||
comm->initialized_ = true;
|
||||
return comm;
|
||||
}
|
||||
|
||||
@ -249,21 +264,26 @@ class NCCLComm {
|
||||
ncclUniqueId commId,
|
||||
ncclConfig_t& config) {
|
||||
auto comm = std::make_shared<NCCLComm>();
|
||||
bool isInitialized = false;
|
||||
if (nccl_use_nonblocking()) {
|
||||
config.blocking = 0;
|
||||
C10D_NCCL_CHECK_TIMEOUT(
|
||||
LOG(INFO) << "Rank " << rank
|
||||
<< ": creating NCCL communicator in nonblocking mode";
|
||||
C10D_NCCL_CHECK_NONBLOCKING(
|
||||
ncclCommInitRankConfig(
|
||||
&(comm->ncclComm_), numRanks, commId, rank, &config),
|
||||
comm->ncclComm_,
|
||||
c10::nullopt);
|
||||
} else {
|
||||
C10D_NCCL_CHECK(
|
||||
ncclCommInitRankConfig(
|
||||
&(comm->ncclComm_), numRanks, commId, rank, &config),
|
||||
c10::nullopt);
|
||||
// under blocking mode, comm is initialized after NCCL CHECK
|
||||
isInitialized = true;
|
||||
}
|
||||
comm->ncclId_ = commId;
|
||||
comm->rank_ = rank;
|
||||
comm->initialized_ = isInitialized;
|
||||
return comm;
|
||||
}
|
||||
#endif
|
||||
@ -280,6 +300,7 @@ class NCCLComm {
|
||||
source->ncclComm_, color_id, rank, &(comm->ncclComm_), &config),
|
||||
c10::nullopt);
|
||||
++source->ncclCommSplitCounter_;
|
||||
comm->rank_ = rank;
|
||||
return comm;
|
||||
}
|
||||
#endif
|
||||
@ -303,6 +324,7 @@ class NCCLComm {
|
||||
std::swap(ncclComm_, other.ncclComm_);
|
||||
std::swap(aborted_, other.aborted_);
|
||||
std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
|
||||
std::swap(initialized_, other.initialized_);
|
||||
}
|
||||
|
||||
ncclComm_t getNcclComm();
|
||||
@ -446,6 +468,8 @@ class NCCLComm {
|
||||
friend class ProcessGroupNCCL;
|
||||
|
||||
protected:
|
||||
// a helper function to wait until the communicator is initialized;
|
||||
void waitUntilInitialized(int timeoutSecs);
|
||||
ncclComm_t ncclComm_;
|
||||
// Unique nccl_id for this communicator.
|
||||
ncclUniqueId ncclId_;
|
||||
@ -458,6 +482,7 @@ class NCCLComm {
|
||||
// Optional reason for communicator failure, provided by ProcessGroupNCCL for
|
||||
// better error messaging.
|
||||
c10::optional<std::string> commFailureReason_;
|
||||
bool initialized_{false};
|
||||
#ifdef NCCL_HAS_COMM_REGISTER
|
||||
// Stores handlers for tensors registered by NCCL
|
||||
std::unordered_map<void*, void*> registeredSegmentHandles_;
|
||||
|
@ -1709,7 +1709,15 @@ std::exception_ptr ProcessGroupNCCL::checkForNCCLErrorsInternal(
|
||||
*commFailureReason)));
|
||||
}
|
||||
ncclResult_t ncclAsyncErr = ncclComm->checkForNcclError();
|
||||
// When nonblocking mode is enabled by TORCH_NCCL_USE_COMM_NONBLOCKING,
|
||||
// ncclInProgress could be returned when there are pending NCCL calls.
|
||||
// In this case, no exception should be thrown
|
||||
#ifdef NCCL_HAS_COMM_NONBLOCKING
|
||||
// ncclInProgress is defined only if NCCL_HAS_COMM_NONBLOCKING is defined
|
||||
if (ncclAsyncErr != ncclSuccess && ncclAsyncErr != ncclInProgress) {
|
||||
#else
|
||||
if (ncclAsyncErr != ncclSuccess) {
|
||||
#endif
|
||||
return std::make_exception_ptr(C10_BUILD_ERROR(
|
||||
DistBackendError,
|
||||
"NCCL error: " + ncclGetErrorWithVersion(ncclAsyncErr) + "\n" +
|
||||
@ -1975,10 +1983,12 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
|
||||
#ifndef NCCL_HAS_COMM_NONBLOCKING
|
||||
C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
|
||||
#else
|
||||
if (!nccl_use_nonblocking()) {
|
||||
C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
|
||||
if (nccl_use_nonblocking()) {
|
||||
// If we use nonblocking mode, allow communicators to be
|
||||
// uninitialized/ncclInProgress until the first communication
|
||||
C10D_NCCL_CHECK_NONBLOCKING(ncclGroupEnd(), c10::nullopt);
|
||||
} else {
|
||||
C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), ncclComms, c10::nullopt);
|
||||
C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
Reference in New Issue
Block a user