Revert "[c10d] PGNCCL refactor part 2: Simplify ProcessGroupNCCL into single-device style (#119421)"
This reverts commit f3e7d809936d9f1bf63102e8afe241e13ed8766a. Reverted https://github.com/pytorch/pytorch/pull/119421 on behalf of https://github.com/DanilBaibak due to Broken trunk ([comment](https://github.com/pytorch/pytorch/pull/119421#issuecomment-1938169747))
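For orientation before the hunks: #119421 had collapsed ProcessGroupNCCL's per-work bookkeeping from one-entry-per-device vectors down to a single device and a single communicator per work object; reverting it restores the vector-based, multi-device style seen throughout the diff below. A minimal, self-contained sketch of that shape difference only (FakeComm, WorkMultiDevice, and WorkSingleDevice are invented stand-ins for illustration, not c10d types):

// Illustrative only: mirrors the single-device vs. multi-device member shapes.
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

struct FakeComm { // stand-in for an NCCL communicator handle
  explicit FakeComm(int device) : device_(device) {}
  int device_;
};

// Multi-device style (restored by this revert): one work object carries a
// vector of devices and a parallel vector of communicators.
struct WorkMultiDevice {
  explicit WorkMultiDevice(std::vector<int> devices)
      : devices_(std::move(devices)) {
    for (int d : devices_) {
      comms_.push_back(std::make_shared<FakeComm>(d));
    }
  }
  std::vector<int> devices_;
  std::vector<std::shared_ptr<FakeComm>> comms_;
};

// Single-device style (the state being reverted): exactly one device and one
// communicator per work object, no vectors.
struct WorkSingleDevice {
  explicit WorkSingleDevice(int device)
      : device_(device), comm_(std::make_shared<FakeComm>(device)) {}
  int device_;
  std::shared_ptr<FakeComm> comm_;
};

int main() {
  WorkMultiDevice multi({0}); // still a vector even with one GPU per rank
  WorkSingleDevice single(0); // scalar members only
  std::cout << "multi-device work tracks " << multi.comms_.size()
            << " comm(s); single-device work tracks exactly one comm on device "
            << single.comm_->device_ << "\n";
  return 0;
}

The same scalar-to-vector swap repeats in every hunk below: the WorkNCCL members, getNCCLComm, the communicator maps, and AutoNcclGroup.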
@@ -20,18 +20,20 @@ constexpr int kNcclErrorHandlingVersion = 2400;
 class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
  public:
   WorkNCCLSimulateErrors(
-      at::Device& device,
+      const std::vector<at::Device>& devices,
       bool simulate_error,
       int rank,
       c10d::OpType opType,
       uint64_t seq)
-      : WorkNCCL(device, rank, opType, seq), simulateError_(simulate_error) {}
+      : WorkNCCL(devices, rank, opType, seq), simulateError_(simulate_error) {}

-  std::exception_ptr checkForNCCLErrors() override {
+  std::exception_ptr checkForNCCLErrors(
+      const std::vector<std::shared_ptr<c10d::NCCLComm>>& ncclComms)
+      const override {
     if (simulateError_) {
       return std::make_exception_ptr(std::runtime_error("Error"));
     }
-    return c10d::ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors();
+    return c10d::ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors(ncclComms);
   }

  private:
@@ -48,11 +50,11 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
       : ProcessGroupNCCL(store, rank, size, opts), simulateError_(false) {}

   std::exception_ptr checkForNCCLErrors(
-      std::shared_ptr<c10d::NCCLComm>& ncclComm) override {
+      const std::vector<std::shared_ptr<c10d::NCCLComm>>& ncclComms) override {
     if (simulateError_) {
       return std::make_exception_ptr(std::runtime_error("Error"));
     }
-    return c10d::ProcessGroupNCCL::checkForNCCLErrors(ncclComm);
+    return c10d::ProcessGroupNCCL::checkForNCCLErrors(ncclComms);
   }

   std::chrono::duration<int64_t, std::milli> getWatchdogSleepInterval() {
@@ -61,14 +63,14 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
   }

   c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
-      at::Device& device,
+      std::vector<at::Device> devices,
       int rank,
       c10d::OpType opType,
       const char* profilingTitle,
       const std::vector<at::Tensor>& inputs = {},
       const std::vector<at::Tensor>& outputs = {}) override {
     return c10::make_intrusive<WorkNCCLSimulateErrors>(
-        device, simulateError_, rank, opType, seq_);
+        devices, simulateError_, rank, opType, seq_);
   }

   size_t getNCCLCommCacheSize() {
@@ -90,12 +92,12 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
 class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
  public:
   WorkNCCLTimedoutErrors(
-      at::Device& device,
+      const std::vector<at::Device>& devices,
       bool set_timedout_error,
       int rank,
       c10d::OpType opType,
       uint64_t seq)
-      : WorkNCCL(device, rank, opType, seq),
+      : WorkNCCL(devices, rank, opType, seq),
         setTimedoutError_(set_timedout_error) {}

  private:
@@ -122,14 +124,14 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors {
         setTimedoutError_(false) {}

   c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
-      at::Device& device,
+      std::vector<at::Device> devices,
       int rank,
       c10d::OpType opType,
       const char* profilingTitle,
       const std::vector<at::Tensor>& inputs = {},
       const std::vector<at::Tensor>& outputs = {}) override {
     return c10::make_intrusive<WorkNCCLTimedoutErrors>(
-        device, setTimedoutError_, rank, opType, seq_);
+        devices, setTimedoutError_, rank, opType, seq_);
   }

   void setTimedoutError() {
@@ -2947,10 +2947,6 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
     def blocking_wait_error_msg(self):
         return "timeout"

-    @property
-    def remote_error_msg(self):
-        return "remote process exit"
-
     def _run_all_reduce(self, pg):
         pg.allreduce(torch.rand(10).cuda(self.rank))

@@ -2999,9 +2995,8 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
         process_group.allreduce(torch.rand(10).cuda(self.rank))
         if self.rank == 0:
             work = process_group.allreduce(torch.rand(10).cuda(self.rank))
-            with self.assertRaisesRegex(dist.DistBackendError, self.remote_error_msg):
-                # Previously this should timeout; but with newer NCCL version,
-                # it seems NCCL would detect that the peer rank has exited
+            with self.assertRaisesRegex(dist.DistBackendError, self.blocking_wait_error_msg):
+                # Operation would time out in blocking mode.
                 work.wait(timeout=timedelta(seconds=self.op_timeout_sec))
             # Run some GPU operations to make sure cuda has not gotten stuck.
             # It was observed cuda could get stuck if NCCL communicators were
@@ -3069,9 +3064,8 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
         )
         process_group.barrier().wait()
         if self.rank == 0:
-            with self.assertRaisesRegex(dist.DistBackendError, self.remote_error_msg):
-                # Previously this should timeout; but with newer NCCL version,
-                # it seems NCCL would detect that the peer rank has exited
+            with self.assertRaisesRegex(dist.DistBackendError, self.blocking_wait_error_msg):
+                # This should timeout
                 process_group.barrier().wait(timeout=timedelta(seconds=self.op_timeout_sec))

     def _run_invalid_nccl_blocking_wait_env(self, val):
@@ -415,18 +415,20 @@ AutoNcclGroup::AutoNcclGroup() {
   (c10::cuda::getFreeMutex())->lock();
 #endif
   comm_nonblocking_ = false;
-  comm_ = nullptr;
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
   detail::NCCL_CHECK(ncclGroupStart());
 #endif
 }

-AutoNcclGroup::AutoNcclGroup(ncclComm_t comm, bool comm_nonblocking) {
+AutoNcclGroup::AutoNcclGroup(
+    std::vector<ncclComm_t>& comms,
+    bool comm_nonblocking) {
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)
   // nccl < 2.0 cannot be called concurrently with cudaFree
   (c10::cuda::getFreeMutex())->lock();
 #endif
-  comm_ = comm;
+  // TODO(eqy): can we make comms_ reference?
+  comms_ = comms;
   comm_nonblocking_ = comm_nonblocking;
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
   detail::NCCL_CHECK(ncclGroupStart());
@@ -435,10 +437,10 @@ AutoNcclGroup::AutoNcclGroup(ncclComm_t comm, bool comm_nonblocking) {

 AutoNcclGroup::~AutoNcclGroup() noexcept(false) {
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
-  if (comm_nonblocking_ && comm_ != nullptr) {
-    detail::NCCL_CHECK_TIMEOUT(ncclGroupEnd(), comm_);
-  } else {
+  if (!comm_nonblocking_) {
     detail::NCCL_CHECK(ncclGroupEnd());
+  } else {
+    detail::NCCL_CHECK_TIMEOUT(ncclGroupEnd(), comms_);
   }
 #endif
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)
@@ -76,9 +76,9 @@ enum class ncclDataType {
 // manages group and lock lifetimes.
 struct AutoNcclGroup {
   AutoNcclGroup();
-  AutoNcclGroup(ncclComm_t comm, bool comm_nonblocking);
+  AutoNcclGroup(std::vector<ncclComm_t>& comms, bool comm_nonblocking);
   ~AutoNcclGroup() noexcept(false);
-  ncclComm_t comm_;
+  std::vector<ncclComm_t> comms_;
   bool comm_nonblocking_;
 };

@@ -126,10 +126,11 @@
     TORCH_CHECK_WITH(DistBackendError, false, err); \
   }

-#define C10D_NCCL_CHECK_TIMEOUT_GROUPEND(cmd, comm, failureReason) \
+#define C10D_NCCL_CHECK_TIMEOUT_GROUPEND(cmd, comms_, failureReason) \
   ncclResult_t state = cmd; \
   auto startTimepoint = std::chrono::steady_clock::now(); \
   if (state == ncclInProgress) { \
+    for (const auto i : c10::irange(comms_.size())) { \
       do { \
         if (nccl_nonblocking_timeout() > 0) { \
           auto currentTimepoint = std::chrono::steady_clock::now(); \
@@ -144,8 +145,12 @@
             TORCH_CHECK_WITH(DistBackendError, false, err); \
           } \
         } \
-        ncclCommGetAsyncError(comm->getNcclComm(), &state); \
+        ncclCommGetAsyncError(comms_[i]->getNcclComm(), &state); \
       } while (state == ncclInProgress); \
+      if (state != ncclSuccess) { \
+        break; /* fall through to failed case */ \
+      } \
+    } \
   } \
   if (state != ncclSuccess) { \
     std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
File diff suppressed because it is too large
@@ -177,7 +177,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {

     // Constructor takes a list of CUDA devices
     WorkNCCL(
-        at::Device& device,
+        const std::vector<at::Device>& devices,
         int rank,
         OpType opType,
         uint64_t seq,
@@ -214,7 +214,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
     void synchronize() override;

     // Synchronize streams by blocking each on the NCCL stream
-    void synchronizeStream();
+    void synchronizeStreams();

     // Helper function to handle exception (throw if needed).
     void handleException(ErrorHandlingMode asyncErrorHandling);
@@ -245,20 +245,22 @@ class TORCH_API ProcessGroupNCCL : public Backend {

    protected:
     // The cached list of CUDA devices to operate on
-    at::Device device_;
+    std::vector<at::Device> devices_;

-    // The start CUDA event of NCCL operator tracking this work item. These
-    // start CUDA events are needed by desync debugging if enabled.
-    std::shared_ptr<at::cuda::CUDAEvent> ncclStartEvent_;
+    // The start CUDA events of NCCL operator tracking this work item on
+    // multiple CUDA devices. These start CUDA events are needed by desync
+    // debugging if enabled.
+    std::shared_ptr<std::vector<at::cuda::CUDAEvent>> ncclStartEvents_;

-    // The end CUDA event of NCCL operator tracking this work item.
-    std::shared_ptr<at::cuda::CUDAEvent> ncclEndEvent_;
+    // The end CUDA events of NCCL operator tracking this work item on
+    // multiple CUDA devices.
+    std::shared_ptr<std::vector<at::cuda::CUDAEvent>> ncclEndEvents_;

-    // The NCCL communicator used for this work item.
-    std::shared_ptr<NCCLComm> ncclComm_;
+    // The NCCL communicators used for this work item.
+    std::vector<std::shared_ptr<NCCLComm>> ncclComms_;

     // Tensors used for barrier op
-    at::Tensor barrierTensor_;
+    std::vector<at::Tensor> barrierTensors_;

     // Clone of blockingWait_ from ProcessGroupNCCL.
     bool blockingWait_ = false;
@@ -286,7 +288,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {

     // Wrapper method for the static checkForNCCLErrors which can be overridden
     // for tests.
-    virtual std::exception_ptr checkForNCCLErrors();
+    virtual std::exception_ptr checkForNCCLErrors(
+        const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) const;

     friend std::ostream& operator<<(
         std::ostream& output,
@@ -318,7 +321,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
     // TORCH_NCCL_AVOID_RECORD_STREAMS implementation helper.
     // Stores references to participating non-output tensors (ie inputs,
     // flattened intermediates).
-    // We'll clear this list in synchronizeStream, just after user-facing
+    // We'll clear this list in synchronizeStreams, just after user-facing
     // stream(s) are synced with the nccl work stream(s).
     // By keeping these refs (as well as outputs_) alive until after the
     // collective's work rejoins the user-facing streams, we achieve
@@ -413,16 +416,13 @@ class TORCH_API ProcessGroupNCCL : public Backend {

   c10::intrusive_ptr<Work> endCoalescing() override;

-  // For specifying a composite optype, such as ALLGATHER and REDUCE_SCATTER
-  c10::intrusive_ptr<Work> endCoalescing(OpType optype);
-
   c10::intrusive_ptr<Work> broadcast(
       std::vector<at::Tensor>& tensors,
       const BroadcastOptions& opts = BroadcastOptions()) override;

   c10::intrusive_ptr<Work> _broadcast_oop(
-      at::Tensor& outputTensors,
-      at::Tensor& inputTensors,
+      std::vector<at::Tensor>& outputTensors,
+      std::vector<at::Tensor>& inputTensors,
       const BroadcastOptions& opts = BroadcastOptions());

   c10::intrusive_ptr<Work> allreduce_sparse(
@@ -443,8 +443,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
       const ReduceOptions& opts = ReduceOptions()) override;

   c10::intrusive_ptr<Work> _reduce_oop(
-      at::Tensor& outputTensors,
-      at::Tensor& inputTensors,
+      std::vector<at::Tensor>& outputTensors,
+      std::vector<at::Tensor>& inputTensors,
       const ReduceOptions& opts = ReduceOptions());

   c10::intrusive_ptr<Work> allgather(
@@ -511,8 +511,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {

   void groupEnd();

-  void groupEndNonblocking(std::shared_ptr<NCCLComm> comm);
+  void groupEndNonblocking(std::vector<std::shared_ptr<NCCLComm>> comms);

+  // Unsupported Ops
   c10::intrusive_ptr<Work> gather(
       std::vector<std::vector<at::Tensor>>& outputTensors,
       std::vector<at::Tensor>& inputTensors,
@@ -523,7 +524,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
       std::vector<std::vector<at::Tensor>>& inputTensors,
       const ScatterOptions& opts = ScatterOptions()) override;

-  // Unsupported Ops
   c10::intrusive_ptr<Work> recvAnysource(
       std::vector<at::Tensor>& tensors,
       int tag) override;
@@ -549,7 +549,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {

   // Helper function for iteratively aborting communicators in the provided map
   void abortCommsFromMap(
-      std::unordered_map<std::string, std::shared_ptr<NCCLComm>>& ncclCommsMap,
+      std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>&
+          ncclCommsMap,
       c10::optional<std::string> abortReason);

   c10::intrusive_ptr<intra_node_comm::IntraNodeComm> initIntraNodeComm();
@@ -574,19 +575,19 @@ class TORCH_API ProcessGroupNCCL : public Backend {

   // Helper that either looks up the cached NCCL communicators or creates
   // a new set of NCCL communicators as a cache entry
-  std::shared_ptr<NCCLComm> getNCCLComm(
-      const std::string& deviceKey,
-      at::Device& device,
+  std::vector<std::shared_ptr<NCCLComm>>& getNCCLComm(
+      const std::string& devicesKey,
+      const std::vector<at::Device>& devices,
       OpType opType,
       int p2pRank = 0,
       bool isSendRecvSelf = false);

   // Wrapper method which can be overridden for tests.
   virtual std::exception_ptr checkForNCCLErrors(
-      std::shared_ptr<NCCLComm>& ncclComm);
+      const std::vector<std::shared_ptr<NCCLComm>>& ncclComms);

   virtual c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
-      at::Device& device,
+      std::vector<at::Device> devices,
       int rank,
       OpType opType,
       const char* profilingTitle = nullptr,
@@ -605,8 +606,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // void {pre,post}(std::vector<at::cuda::CUDAStream&>);
   template <typename Fn>
   c10::intrusive_ptr<Work> collective(
-      at::Tensor& input,
-      at::Tensor& output,
+      std::vector<at::Tensor>& input,
+      std::vector<at::Tensor>& output,
       Fn fn,
       OpType opType,
       const char* profilingTitle = nullptr,
@@ -614,20 +615,11 @@ class TORCH_API ProcessGroupNCCL : public Backend {

   template <typename Fn, typename PreProcess, typename PostProcess>
   c10::intrusive_ptr<Work> collective(
-      at::Tensor& input,
-      at::Tensor& output,
-      Fn fn,
-      PreProcess pre,
-      PostProcess post,
-      OpType opType,
-      const char* profilingTitle = nullptr,
-      bool avoidRecordStreams = false);
-
-  template <typename Fn>
-  c10::intrusive_ptr<Work> collectiveCoalesced(
       std::vector<at::Tensor>& input,
       std::vector<at::Tensor>& output,
       Fn fn,
+      PreProcess pre,
+      PostProcess post,
       OpType opType,
       const char* profilingTitle = nullptr,
       bool avoidRecordStreams = false);
@@ -637,15 +629,14 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // communication primitives.
   template <typename Fn>
   c10::intrusive_ptr<Work> pointToPoint(
-      at::Tensor& tensor,
+      std::vector<at::Tensor>& tensor,
       Fn fn,
       int peer,
       OpType opType,
       const char* profilingTitle = nullptr);
-
   template <typename Fn, typename PreProcess, typename PostProcess>
   c10::intrusive_ptr<Work> pointToPoint(
-      at::Tensor& tensor,
+      std::vector<at::Tensor>& tensor,
       Fn fn,
       int peer,
       OpType opType,
@@ -654,13 +645,13 @@ class TORCH_API ProcessGroupNCCL : public Backend {
       const char* profilingTitle);

   c10::intrusive_ptr<Work> allreduce_impl(
-      at::Tensor& tensor,
+      std::vector<at::Tensor>& tensors,
       const AllreduceOptions& opts = AllreduceOptions());

   // Checks for NCCL errors on each of the communicators and returns an
   // appropriate exception_ptr (nullptr if no errors).
   static std::exception_ptr checkForNCCLErrorsInternal(
-      std::shared_ptr<NCCLComm>& ncclComm);
+      const std::vector<std::shared_ptr<NCCLComm>>& ncclComms);

   // Function that runs as part of a separate thread and checks for errors on
   // NCCL communicators. We need a separate thread to check for NCCL errors
@@ -803,14 +794,16 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // communication, the key will be "1:2" on both processes. Note: this is for
   // the scenario where there is only 1 GPU per process. When it comes to
   // multiple GPUs per process, this part may need to redesigned.
-  std::unordered_map<std::string, std::shared_ptr<NCCLComm>> devNCCLCommMap_;
+  std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>
+      devNCCLCommMap_;

   // The NCCL communicators currently in process of being initialized.
-  std::unordered_map<std::string, std::shared_ptr<NCCLComm>>
+  std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>
       inInitializationCommMap_;

   // Map from ncclUniqueId to appropriate communicator.
-  std::unordered_map<std::string, std::shared_ptr<NCCLComm>> ncclIdToCommMap_;
+  std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>
+      ncclIdToCommMap_;

   // Mutex to guard maps like devNCCLCommMap_ and ncclIdToCommMap_.
   std::mutex mutex_;
@@ -888,10 +881,11 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   void workEnqueue(c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>);

   // The CUDA streams used by NCCL kernels
-  std::unordered_map<std::string, at::cuda::CUDAStream> ncclStreams_;
+  std::unordered_map<std::string, std::vector<at::cuda::CUDAStream>>
+      ncclStreams_;

   // The CUDA events used to sync NCCL streams
-  std::unordered_map<std::string, at::cuda::CUDAEvent> ncclEvents_;
+  std::unordered_map<std::string, std::vector<at::cuda::CUDAEvent>> ncclEvents_;

   // Device Indexes used for all collectives in this group
   std::set<int> usedDeviceIdxs_;
@@ -900,10 +894,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   int coalescing_state_ = 0;

   // Stores device indexes for all collectives run inside a coalescing block
-  std::vector<at::Device> coalescedDevices_;
+  std::vector<std::vector<at::Device>> coalescedDevices_;

   // Stores communicators for all collectives run inside a coalescing block
-  std::vector<std::shared_ptr<NCCLComm>> coalescedComms_;
+  std::vector<std::vector<std::shared_ptr<NCCLComm>>> coalescedComms_;

   // map from the key: "group name + pg counter (ID)" to the
   // unique NCCL ID count. This needs to be group and pg specific
@@ -269,13 +269,19 @@ inline std::string retrieveDesyncReport(
 #ifdef USE_C10D_NCCL

 /* Helper used by work::getDuration() and nccl flight recorder */
-float getDurationFromEvent(
-    at::cuda::CUDAEvent& ncclStartEvent,
-    at::cuda::CUDAEvent& ncclEndEvent) {
+float getDurationFromFirstEvent(
+    const std::vector<at::cuda::CUDAEvent>& ncclStartEvents,
+    const std::vector<at::cuda::CUDAEvent>& ncclEndEvents) {
   TORCH_CHECK(
-      ncclEndEvent.query(),
+      ncclStartEvents.size() == 1,
+      "getDuration only works for single device per ProcessGroup, but found multiple start events.");
+  TORCH_CHECK(
+      ncclEndEvents.size() == 1,
+      "getDuration only works for single device per ProcessGroup, but found multiple end events.");
+  TORCH_CHECK(
+      ncclEndEvents[0].query(),
       "getDuration can only be called after work is succeeded.")
-  return ncclStartEvent.elapsed_time(ncclEndEvent);
+  return ncclStartEvents[0].elapsed_time(ncclEndEvents[0]);
 }

 DebugInfoWriter::~DebugInfoWriter() = default;
@@ -379,7 +385,7 @@ struct NCCLTraceBuffer {
     capture_cpp_stack_ = getCvarBool({"TORCH_NCCL_TRACE_CPP_STACK"}, false);
     enabled_ = max_entries_ > 0;
   }
-  using Event = at::cuda::CUDAEvent;
+  using EventList = std::vector<at::cuda::CUDAEvent>;
   struct Entry {
     size_t id_; // incremented id in the trace buffer
     // used to figure out where in the circular entries
@@ -393,7 +399,7 @@ struct NCCLTraceBuffer {
     // we borrow pointers to start_ and end_ so we can query the state
     // on reporting. However, once the event is completed, the call
     // to `complete` will clear these.
-    Event *start_, *end_;
+    EventList *start_, *end_;

     // timestamp when the entry was created, likely close to the time the work
     // was 'enqueued'- not necessarily started
@@ -433,8 +439,8 @@ struct NCCLTraceBuffer {
       const char* profiling_name,
       const std::vector<at::Tensor>& inputs,
      const std::vector<at::Tensor>& outputs,
-      Event* start,
-      Event* end) {
+      EventList* start,
+      EventList* end) {
     if (!enabled_) {
       return c10::nullopt;
     }
@@ -477,13 +483,25 @@ struct NCCLTraceBuffer {

   void update_state(Entry& r) {
     if (r.start_ != nullptr) {
-      bool started = r.start_->query();
+      bool started = true;
+      for (auto& ev : *r.start_) {
+        if (!ev.query()) {
+          started = false;
+          break;
+        }
+      }
       if (started && !r.time_discovered_started_) {
         r.time_discovered_started_ = c10::getTime();
       }
     }
     if (r.end_ != nullptr) {
-      bool completed = r.end_->query();
+      bool completed = true;
+      for (auto& ev : *r.end_) {
+        if (!ev.query()) {
+          completed = false;
+          break;
+        }
+      }
       if (completed && !r.time_discovered_completed_) {
         r.time_discovered_completed_ = c10::getTime();
       }
@@ -522,8 +540,8 @@ struct NCCLTraceBuffer {
     }

     bool can_compute_duration = false;
-    Event* startEvent = nullptr;
-    Event* endEvent = nullptr;
+    EventList* startEvents = nullptr;
+    EventList* endEvents = nullptr;
     c10::optional<float> duration = c10::nullopt;

     std::unique_lock<std::mutex> guard(mutex_);
@@ -535,8 +553,8 @@ struct NCCLTraceBuffer {
       if (compute_duration) {
        can_compute_duration = entry.time_discovered_completed_.has_value() &&
            entry.start_ && entry.end_;
-        startEvent = entry.start_;
-        endEvent = entry.end_;
+        startEvents = entry.start_;
+        endEvents = entry.end_;
       }
     }

@@ -545,7 +563,7 @@ struct NCCLTraceBuffer {
     // cudaEventDuration() can hang, and we need to acquire the lock before we
     // can dump(), which we never want to block.
     guard.unlock();
-    duration = getDurationFromEvent(*startEvent, *endEvent);
+    duration = getDurationFromFirstEvent(*startEvents, *endEvents);
     guard.lock();

     // Refresh the entry ref, see if it has been overwritten