Revert "[PGNCCL] Launch kernel on current stream & remove record_stream entirely (#148590)"

This reverts commit ef6296e7f20d744a0cfed81cab573d60204e7626.

Reverted https://github.com/pytorch/pytorch/pull/148590 on behalf of https://github.com/izaitsevfb due to reverted internally, see D71292427 ([comment](https://github.com/pytorch/pytorch/pull/148590#issuecomment-2731114626))
Author: PyTorch MergeBot
Date:   2025-03-17 22:43:15 +00:00
Parent: a16ada41b9
Commit: afa1eda901
11 changed files with 362 additions and 411 deletions
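Context for the revert: #148590 switched ProcessGroupNCCL from CUDA caching-allocator `recordStream` calls to explicit tensor stashing for allocator safety; this revert restores the `recordStream` path. As a minimal sketch (not the literal ProcessGroupNCCL code), `recordStream` tells the caching allocator that a tensor's memory is still in use by work queued on a side stream:

```cpp
#include <ATen/ATen.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>

// Minimal sketch: mark `t` as in use on `stream` so the CachingAllocator
// will not recycle its memory until all work already queued on `stream`
// has completed. This is the primitive the revert restores.
void markUsedOnStream(const at::Tensor& t, c10::cuda::CUDAStream stream) {
  c10::cuda::CUDACachingAllocator::recordStream(t.storage().data_ptr(), stream);
}
```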


@@ -382,6 +382,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
     // Clone of blockingWait_ from ProcessGroupNCCL.
     bool blockingWait_{false};
+    // Clone of avoidRecordStreams_ from ProcessGroupNCCL.
+    bool avoidRecordStreams_{false};
     // Clone of opTimeout_ from ProcessGroupNCCL.
     std::chrono::milliseconds opTimeout_{};
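The re-added `avoidRecordStreams_` flag (settable via the `TORCH_NCCL_AVOID_RECORD_STREAMS` environment variable) lets users opt out of `recordStream` and fall back to stashing references instead. A hedged sketch of how the flag plausibly gates the two paths inside a collective launch (assumed shape, not the verbatim code):

```cpp
// Assumed shape, not verbatim: protect inputs either by recordStream
// or by stashing owning references until the work completes.
if (!avoidRecordStreams_) {
  for (const auto& input : inputs) {
    c10::cuda::CUDACachingAllocator::recordStream(
        input.storage().data_ptr(), ncclStream);
  }
} else {
  work->stashed_for_allocator_safety_->insert(
      work->stashed_for_allocator_safety_->end(),
      inputs.begin(),
      inputs.end());
}
```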
@@ -428,13 +431,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
     // exception_ptr.
     bool finishedGPUExecutionInternal() const;
-    // Stash tensors so that CachingAllocator cannot recycle them prematurely.
-    // Used in case of async ops.
-    void stashTensors(std::vector<at::Tensor>& tensors);
-    // Unstage the stashed tensors so that CachingAllocator can recycle them
-    void unstashTensors();
     // Reference to the store so that we can write aborted communicators
     // to the store.
     c10::intrusive_ptr<Store> store_;
@@ -454,9 +450,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
     // For in-place collectives, some refs stashed here may alias outputs_,
     // but that doesn't do any harm.
     std::shared_ptr<std::vector<at::Tensor>> stashed_for_allocator_safety_;
-    // Need a mutex to protect stashed_for_allocator_safety_ because it can be
-    // accessed from both main thread and watchdog thread.
-    std::mutex stashMutex_;
     // The future returned by getFuture.
     c10::intrusive_ptr<at::ivalue::Future> future_;
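The removed `stashTensors`/`unstashTensors` pair and `stashMutex_` implemented #148590's alternative scheme: hold owning references until the collective finishes, guarded by a mutex because the main thread stashes at launch while the watchdog thread unstashes on completion. A hypothetical, self-contained sketch of that pattern (names and bodies inferred from the declarations in the two hunks above, not the actual implementation):

```cpp
#include <ATen/ATen.h>
#include <mutex>
#include <vector>

// Hypothetical sketch inferred from the removed declarations.
struct TensorStash {
  // Protects `stash`: the main thread stashes at launch time, the
  // watchdog thread unstashes once the GPU work has completed.
  std::mutex mu;
  // Owning references keep the CachingAllocator from recycling memory.
  std::vector<at::Tensor> stash;

  void stashTensors(const std::vector<at::Tensor>& tensors) {
    std::lock_guard<std::mutex> lock(mu);
    stash.insert(stash.end(), tensors.begin(), tensors.end());
  }

  void unstashTensors() {
    std::lock_guard<std::mutex> lock(mu);
    stash.clear();
  }
};
```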
@@ -885,8 +878,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
       at::Tensor& output,
       Fn fn,
       OpType opType,
-      bool asyncOp,
       const char* profilingTitle = nullptr,
+      bool avoidRecordStreams = false,
       bool nanCheck = true);
 
   template <typename Fn, typename PreProcess, typename PostProcess>
@@ -897,8 +890,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
       PreProcess pre,
       PostProcess post,
       OpType opType,
-      bool asyncOp,
       const char* profilingTitle = nullptr,
+      bool avoidRecordStreams = false,
       bool nanCheck = true);
 
   template <typename Fn, typename PreProcess, typename PostProcess>
@@ -909,8 +902,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
       PreProcess pre,
       PostProcess post,
       OpType opType,
-      bool asyncOp,
       const char* profilingTitle = nullptr,
+      bool avoidRecordStreams = false,
       bool nanCheck = true);
 
   template <typename Fn>
@@ -919,8 +912,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
       std::vector<at::Tensor>& output,
       Fn fn,
       OpType opType,
-      bool asyncOp,
-      const char* profilingTitle = nullptr);
+      const char* profilingTitle = nullptr,
+      bool avoidRecordStreams = false);
 
   // Helper that encapsulates work shared across point-to-point communication
   // primitives. It is the same structure as the helper used for collective
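Across the four helper overloads above, the revert swaps the leading `bool asyncOp` parameter back for a trailing, defaulted `bool avoidRecordStreams = false` after `profilingTitle`. A hedged sketch of a call site under the restored signature (assumed shape; the NCCL launch lambda is elided):

```cpp
// Assumed call-site shape under the restored signature; the real
// allreduce lambda body is elided.
auto work = collective(
    input,
    output,
    [&](at::Tensor& in,
        at::Tensor& out,
        ncclComm_t comm,
        at::cuda::CUDAStream& stream) {
      // launch the NCCL kernel here, e.g. ncclAllReduce(...)
      return ncclSuccess;
    },
    OpType::ALLREDUCE,
    "nccl:all_reduce",          // profilingTitle
    /*avoidRecordStreams=*/false,
    /*nanCheck=*/true);
```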
@@ -1229,9 +1222,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // Stores communicators for all collectives run inside a coalescing block
   std::shared_ptr<NCCLComm> coalescedComm_ = nullptr;
-  // Whether the coalesced calls are sync or async.
-  bool coalescedAsync_;
   // Whether or not wait() and synchronize() are blocking operations that wait
   // for the operation to complete.
   bool blockingWait_ = false;
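`coalescedAsync_` was added by #148590 to remember whether the calls batched inside a `startCoalescing()`/`endCoalescing()` block were launched asynchronously; with the revert it is gone, since stream handling no longer depends on a per-block async flag. A simplified, assumed sketch of the state it tracked (not the actual ProcessGroupNCCL members):

```cpp
// Simplified, assumed sketch of per-coalescing-block bookkeeping.
struct CoalescingState {
  bool active = false; // inside a startCoalescing()/endCoalescing() block
  bool async = false;  // what the removed coalescedAsync_ recorded

  void start() {
    active = true;
    async = false;
  }
  void record(bool asyncOp) {
    if (active) {
      async = async || asyncOp;
    }
  }
  bool finish() { // returns whether the block ran async
    active = false;
    return async;
  }
};
```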