From cca33d50b91db6291deabe2cc3963ca8179b849b Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 9 Dec 2024 10:56:33 -0800 Subject: [PATCH] [PGNCCL] Use long/short wait for different non-blocking calls (#142291) In nonblocking mode, we always check if the NCCL communicator is ready between issuing commands to it. Today this is done by the `waitReady()` function. Unfortunately, the `waitReady()` function is burned with `C10D_NCCL_CHECK_TIMEOUT_SLEEP` which would sleep for an interval between two consecutive checks. While this is nice when waiting for comm init or finalize, it degrades performance of collective calls (which would almost certainly return success immediately.) This PR adds a `bool longInterval` argument to `waitReady` and let call site determine whether long wait is likely; if not, `waitReady` would use `sched_yield()` to more eagerly check for readiness. Thanks @eqy for reporting an issue that small collectives has perf impact in nonblocking mode. Pull Request resolved: https://github.com/pytorch/pytorch/pull/142291 Approved by: https://github.com/eqy, https://github.com/fduwjj --- torch/csrc/distributed/c10d/NCCLUtils.cpp | 16 +++++++++++++--- torch/csrc/distributed/c10d/NCCLUtils.hpp | 8 +++++++- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 3 ++- 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 7630b7e553aa..f93b78dbc458 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -25,7 +25,9 @@ ncclComm_t NCCLComm::getNcclComm() { } // In non-blocking mode, ensure comm is ready. if (nonBlocking_) { - waitReady(); + // Wait with long interval if communicator is being initialized. + bool longInterval = !initialized_; + waitReady(longInterval); // ncclComm_ should be initialized by now } if (!initialized_) { @@ -38,12 +40,20 @@ ncclComm_t NCCLComm::getNcclComm() { return ncclComm_; } -void NCCLComm::waitReady() { +// Wait for the communicator to be ready. This is a blocking function. +// Arguments: +// longInterval: if true, wait with sleep of an interval; otherwise, wait +// with `sched_yield` which is faster (but acquires CPU more frequently). +void NCCLComm::waitReady(bool longInterval) { LockType lock(mutex_); if (aborted_) return; // If timeout is reached, throw an exception. - C10D_NCCL_CHECK_TIMEOUT_SLEEP(ncclInProgress, ncclComm_, std::nullopt); + if (longInterval) { + C10D_NCCL_CHECK_TIMEOUT_SLEEP(ncclInProgress, ncclComm_, std::nullopt); + } else { + C10D_NCCL_CHECK_TIMEOUT(ncclInProgress, ncclComm_, std::nullopt); + } } // TODO: why do we have `!defined(FBCODE_CAFFE2)` here? diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index 3bfd97cd4cfb..77690ba4a7be 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -329,7 +329,13 @@ class NCCLComm { // Wait for the communicator to be ready. This is a blocking function. // Useful in nonblocking mode: NCCL requires the communicator to be ready // before issuing a second command. - void waitReady(); + // Arguments: + // longInterval: if true, wait with sleep of an interval; otherwise, wait + // with `sched_yield` which is faster (but acquires CPU more frequently). + // Use `longInterval=true` when waiting for initialization or finalize to + // complete. Use `longInterval=false` when waiting collective call to return + // ncclSuccess. + void waitReady(bool longInterval); std::optional getNcclCommFailureReason() const { LockType lock(mutex_); diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 9edc7d49d532..f6e7a1f0f7b3 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1450,7 +1450,8 @@ void ProcessGroupNCCL::shutdown() { // timeout is reach, this will throw an exception. for (auto& it : devNCCLCommMap_) { auto& ncclComm = it.second; - ncclComm->waitReady(); + // Use long interval to avoid acquiring CPU too frequently + ncclComm->waitReady(true); } // Tell watchdog to (1) flush its queue and (2) do not use comm objects // anymore because I am going to destroy them now