[PG/nccl] Simplify uniqueHash management (#156790)

Summary:

ncclUniqueID is only relevant when a comm is created using ncclCommCreate or ncclCommCreateConfig.  If a comm is created with ncclCommSplit, this field is unset, causing its usage to create unexpected behavior.

This patch creates a unique hash key for each comm, irrespective of how the comm is created.

Test Plan:

CI

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156790
Approved by: https://github.com/fduwjj, https://github.com/kwen2501
This commit is contained in:
Pavan Balaji
2025-06-25 20:06:03 +00:00
committed by PyTorch MergeBot
parent 070aa59e49
commit e99a2a2dba
3 changed files with 31 additions and 21 deletions

View File

@ -39,8 +39,24 @@ NCCLComm::NCCLComm(NCCLComm&& other) {
std::swap(deviceIndex_, other.deviceIndex_);
}
ncclUniqueId NCCLComm::getNcclId() {
return ncclId_;
void NCCLComm::setUniqueHash(ncclUniqueId ncclId) {
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(&ncclId);
fmt::memory_buffer buf;
buf.reserve(NCCL_UNIQUE_ID_BYTES * 2); // 2 hex chars per byte
for (int i = 0; i < NCCL_UNIQUE_ID_BYTES; ++i) {
fmt::format_to(
std::back_inserter(buf), "{:02x}", static_cast<int>(bytes[i]));
}
this->uniqueHash_ = fmt::to_string(buf);
}
void NCCLComm::setUniqueHash(std::string hash) {
this->uniqueHash_ = std::move(hash);
}
std::string NCCLComm::getUniqueHash() {
return uniqueHash_;
}
std::shared_ptr<NCCLComm> NCCLComm::create(
@ -53,7 +69,7 @@ std::shared_ptr<NCCLComm> NCCLComm::create(
C10D_NCCL_CHECK(
ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank),
std::nullopt);
comm->ncclId_ = commId;
comm->setUniqueHash(commId);
comm->rank_ = rank;
comm->deviceIndex_ = deviceIndex;
comm->initialized_ = true;
@ -78,7 +94,7 @@ std::shared_ptr<NCCLComm> NCCLComm::create(
ncclCommInitRankConfig(
&(comm->ncclComm_), numRanks, commId, rank, &config),
std::nullopt);
comm->ncclId_ = commId;
comm->setUniqueHash(commId);
comm->rank_ = rank;
comm->deviceIndex_ = deviceIndex;
// Under blocking mode, comm is initialized immediately after NCCL init
@ -112,7 +128,7 @@ std::shared_ptr<NCCLComm> NCCLComm::create_scalable(
// Only the first ncclUniqueId will be used to create the
// communicator hash id, which is used to identify the communicator
// in the log file and in the replay tool.
comm->ncclId_ = commIds[0];
comm->setUniqueHash(commIds[0]);
comm->rank_ = rank;
comm->deviceIndex_ = deviceIndex;
comm->initialized_ = !comm->nonBlocking_;
@ -237,6 +253,9 @@ std::shared_ptr<NCCLComm> NCCLComm::split(
// Child comm should be on the same device as parent comm
comm->deviceIndex_ = source->deviceIndex_;
comm->nonBlocking_ = config.blocking == 0;
comm->setUniqueHash(
source->getUniqueHash() + ":" +
std::to_string(source->ncclCommSplitCounter_));
LOG(INFO) << "Rank " << source->rank_ << ": created child comm "
<< comm->repr() << " with color_id " << color_id;
return comm;

View File

@ -259,6 +259,10 @@ class NCCLComm {
~NCCLComm() noexcept;
void setUniqueHash(ncclUniqueId ncclId);
void setUniqueHash(std::string hash);
std::string getUniqueHash();
static std::shared_ptr<NCCLComm> create(
int numRanks,
int rank,
@ -295,7 +299,6 @@ class NCCLComm {
std::unordered_map<std::string, std::string> ncclCommDump();
#endif
ncclUniqueId getNcclId();
at::DeviceIndex getDeviceIndex();
// Must not be copyable
@ -355,8 +358,8 @@ class NCCLComm {
friend class ProcessGroupNCCL;
protected:
// Unique nccl_id for this communicator.
ncclUniqueId ncclId_{};
// Unique hash for this communicator.
std::string uniqueHash_;
bool aborted_{false};
uint64_t ncclCommSplitCounter_{0};
ncclResult_t ncclAsyncErr_{ncclSuccess};

View File

@ -217,17 +217,6 @@ void syncStream(
ncclEvent.block(ncclStream);
}
// Given a ncclUniqueId, convert it to a string representation that can be put
// in the store.
std::string buildNcclUniqueIdStr(const ncclUniqueId& ncclID) {
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(&ncclID);
std::ostringstream oss;
for (const auto i : c10::irange(NCCL_UNIQUE_ID_BYTES)) {
oss << std::hex << static_cast<int>(bytes[i]);
}
return oss.str();
}
std::string getNcclAbortedCommStoreKey(const std::string& ncclIdStr) {
return std::string(kNCCLAbortedCommStoreKey) + ":" + ncclIdStr;
}
@ -382,8 +371,7 @@ static std::
}
}
for (auto& ncclComm : allNCCLComms) {
std::string ncclUniqueIDStr = buildNcclUniqueIdStr(ncclComm->getNcclId());
ncclDumpMap[ncclUniqueIDStr] = ncclComm->ncclCommDump();
ncclDumpMap[ncclComm->getUniqueHash()] = ncclComm->ncclCommDump();
}
return ncclDumpMap;
#else