mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[PG/nccl] Simplify uniqueHash management (#156790)
Summary: ncclUniqueID is only relevant when a comm is created using ncclCommCreate or ncclCommCreateConfig. If a comm is created with ncclCommSplit, this field is unset, causing its usage to create unexpected behavior. This patch creates a unique hash key for each comm, irrespective of how the comm is created. Test Plan: CI Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/156790 Approved by: https://github.com/fduwjj, https://github.com/kwen2501
This commit is contained in:
committed by
PyTorch MergeBot
parent
070aa59e49
commit
e99a2a2dba
@ -39,8 +39,24 @@ NCCLComm::NCCLComm(NCCLComm&& other) {
|
||||
std::swap(deviceIndex_, other.deviceIndex_);
|
||||
}
|
||||
|
||||
ncclUniqueId NCCLComm::getNcclId() {
|
||||
return ncclId_;
|
||||
void NCCLComm::setUniqueHash(ncclUniqueId ncclId) {
|
||||
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(&ncclId);
|
||||
|
||||
fmt::memory_buffer buf;
|
||||
buf.reserve(NCCL_UNIQUE_ID_BYTES * 2); // 2 hex chars per byte
|
||||
for (int i = 0; i < NCCL_UNIQUE_ID_BYTES; ++i) {
|
||||
fmt::format_to(
|
||||
std::back_inserter(buf), "{:02x}", static_cast<int>(bytes[i]));
|
||||
}
|
||||
this->uniqueHash_ = fmt::to_string(buf);
|
||||
}
|
||||
|
||||
void NCCLComm::setUniqueHash(std::string hash) {
|
||||
this->uniqueHash_ = std::move(hash);
|
||||
}
|
||||
|
||||
std::string NCCLComm::getUniqueHash() {
|
||||
return uniqueHash_;
|
||||
}
|
||||
|
||||
std::shared_ptr<NCCLComm> NCCLComm::create(
|
||||
@ -53,7 +69,7 @@ std::shared_ptr<NCCLComm> NCCLComm::create(
|
||||
C10D_NCCL_CHECK(
|
||||
ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank),
|
||||
std::nullopt);
|
||||
comm->ncclId_ = commId;
|
||||
comm->setUniqueHash(commId);
|
||||
comm->rank_ = rank;
|
||||
comm->deviceIndex_ = deviceIndex;
|
||||
comm->initialized_ = true;
|
||||
@ -78,7 +94,7 @@ std::shared_ptr<NCCLComm> NCCLComm::create(
|
||||
ncclCommInitRankConfig(
|
||||
&(comm->ncclComm_), numRanks, commId, rank, &config),
|
||||
std::nullopt);
|
||||
comm->ncclId_ = commId;
|
||||
comm->setUniqueHash(commId);
|
||||
comm->rank_ = rank;
|
||||
comm->deviceIndex_ = deviceIndex;
|
||||
// Under blocking mode, comm is initialized immediately after NCCL init
|
||||
@ -112,7 +128,7 @@ std::shared_ptr<NCCLComm> NCCLComm::create_scalable(
|
||||
// Only the first ncclUniqueId will be used to create the
|
||||
// communicator hash id, which is used to identify the communicator
|
||||
// in the log file and in the replay tool.
|
||||
comm->ncclId_ = commIds[0];
|
||||
comm->setUniqueHash(commIds[0]);
|
||||
comm->rank_ = rank;
|
||||
comm->deviceIndex_ = deviceIndex;
|
||||
comm->initialized_ = !comm->nonBlocking_;
|
||||
@ -237,6 +253,9 @@ std::shared_ptr<NCCLComm> NCCLComm::split(
|
||||
// Child comm should be on the same device as parent comm
|
||||
comm->deviceIndex_ = source->deviceIndex_;
|
||||
comm->nonBlocking_ = config.blocking == 0;
|
||||
comm->setUniqueHash(
|
||||
source->getUniqueHash() + ":" +
|
||||
std::to_string(source->ncclCommSplitCounter_));
|
||||
LOG(INFO) << "Rank " << source->rank_ << ": created child comm "
|
||||
<< comm->repr() << " with color_id " << color_id;
|
||||
return comm;
|
||||
|
@ -259,6 +259,10 @@ class NCCLComm {
|
||||
|
||||
~NCCLComm() noexcept;
|
||||
|
||||
void setUniqueHash(ncclUniqueId ncclId);
|
||||
void setUniqueHash(std::string hash);
|
||||
std::string getUniqueHash();
|
||||
|
||||
static std::shared_ptr<NCCLComm> create(
|
||||
int numRanks,
|
||||
int rank,
|
||||
@ -295,7 +299,6 @@ class NCCLComm {
|
||||
std::unordered_map<std::string, std::string> ncclCommDump();
|
||||
#endif
|
||||
|
||||
ncclUniqueId getNcclId();
|
||||
at::DeviceIndex getDeviceIndex();
|
||||
|
||||
// Must not be copyable
|
||||
@ -355,8 +358,8 @@ class NCCLComm {
|
||||
friend class ProcessGroupNCCL;
|
||||
|
||||
protected:
|
||||
// Unique nccl_id for this communicator.
|
||||
ncclUniqueId ncclId_{};
|
||||
// Unique hash for this communicator.
|
||||
std::string uniqueHash_;
|
||||
bool aborted_{false};
|
||||
uint64_t ncclCommSplitCounter_{0};
|
||||
ncclResult_t ncclAsyncErr_{ncclSuccess};
|
||||
|
@ -217,17 +217,6 @@ void syncStream(
|
||||
ncclEvent.block(ncclStream);
|
||||
}
|
||||
|
||||
// Given a ncclUniqueId, convert it to a string representation that can be put
|
||||
// in the store.
|
||||
std::string buildNcclUniqueIdStr(const ncclUniqueId& ncclID) {
|
||||
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(&ncclID);
|
||||
std::ostringstream oss;
|
||||
for (const auto i : c10::irange(NCCL_UNIQUE_ID_BYTES)) {
|
||||
oss << std::hex << static_cast<int>(bytes[i]);
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
std::string getNcclAbortedCommStoreKey(const std::string& ncclIdStr) {
|
||||
return std::string(kNCCLAbortedCommStoreKey) + ":" + ncclIdStr;
|
||||
}
|
||||
@ -382,8 +371,7 @@ static std::
|
||||
}
|
||||
}
|
||||
for (auto& ncclComm : allNCCLComms) {
|
||||
std::string ncclUniqueIDStr = buildNcclUniqueIdStr(ncclComm->getNcclId());
|
||||
ncclDumpMap[ncclUniqueIDStr] = ncclComm->ncclCommDump();
|
||||
ncclDumpMap[ncclComm->getUniqueHash()] = ncclComm->ncclCommDump();
|
||||
}
|
||||
return ncclDumpMap;
|
||||
#else
|
||||
|
Reference in New Issue
Block a user