[Reland] Launch kernel on current stream & remove record_stream entirely (#150398)

Relanding #148590 due to merge conflict.

This PR has multiple changes to `ProcessGroupNCCL` (which unfortunately are related):
1. When `async_op=False`, we launch the collective directly on the "current" stream, instead of on a trampoline stream that joins back afterwards (see the usage sketch after this list).
- Resolves #147729
- Resolves #146881
- Also saves two event syncs (which carry overhead in the case of HIP) and one pybind call when we invoke `work.wait()` in distributed_c10d.py on behalf of the user.
2. Entirely remove `record_stream`; tensor lifetime against allocator recycling is now managed by CPU-side stashing.
- Resolves #147168
3. Remove tensor lifetime management when async_op=False; only use it when async_op=True.
4. To guard against users who never call `work.wait()`, we ask the watchdog to unstash tensors once it detects that a collective has completed, so that we do not hold tensor references forever. This is a safety net rather than a service guarantee; see the discussion [here](https://github.com/pytorch/pytorch/issues/147168#issuecomment-2660142460).
5. Profiles captured in async_op=False mode will look different -- collective kernels now show up on the same line as compute kernels.
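
A user-level sketch of items 1 and 4 (hedged: the collective, sizes, and setup here are arbitrary, and an initialized NCCL process group is assumed):

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl") has already run on each rank.
t = torch.ones(1024, device="cuda")

# async_op=False: the collective kernel is launched directly on the current
# stream -- no trampoline stream, no event syncs, no tensor stashing.
dist.all_reduce(t, async_op=False)

# async_op=True: the collective runs on a separate internal stream; the
# tensor is stashed (kept alive) until work.wait(), which also makes the
# current stream wait on the collective before t is used further.
work = dist.all_reduce(t, async_op=True)
work.wait()  # unstashes t; if never called, the watchdog is the safety net
```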

Joint work with @cenzhaometa who wants to remove the event sync overhead.

Squashed contents:

* [ptd][nccl] use current-stream as nccl-stream under async=False mode (#147820)
PTD's current workflow:
- PTD creates its own dedicated `ncclStream` for comm operations
- it first adds a dependency on the current stream (typically the compute stream) to ensure tensors are ready before invoking the collective
Such stream synchronization becomes expensive in the inference world (CPU overhead: 70us vs. GPU kernel time: 160us).
This diff:
- async=False [default]: use the current stream as the nccl stream, avoiding the stream-sync overhead
- async=True: retain the existing logic -- create a new nccl stream and let it wait on the current stream to ensure tensors are ready
- pass the async flag down from c10d to the NCCL PG
This shaves off 50% of the CPU overhead **(70us -> 35us)**, reducing total CPU+GPU time from **230us to 195us**, i.e. by **15%** (sketched below).
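
A rough sketch of the two stream patterns using public `torch.cuda` APIs (illustrative only; the real logic lives in C++ inside `ProcessGroupNCCL`, and `enqueue_collective` is a hypothetical stand-in for the NCCL launch):

```python
import torch

def enqueue_collective(t: torch.Tensor) -> None:
    """Hypothetical stand-in: launches a kernel on the current stream."""
    t.add_(0)

t = torch.ones(1024, device="cuda")
compute_stream = torch.cuda.current_stream()

# Old path (still used for async=True): a dedicated comm stream that first
# waits on the compute stream so input tensors are ready -- the sync whose
# CPU overhead this diff removes for async=False.
nccl_stream = torch.cuda.Stream()
nccl_stream.wait_stream(compute_stream)
with torch.cuda.stream(nccl_stream):
    enqueue_collective(t)

# New async=False path: enqueue on the current stream directly; ordering
# after the producing compute kernels is implicit, so no event sync is needed.
enqueue_collective(t)
```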

* [PGNCCL] Make avoid-record-stream default
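
For context on what is being avoided: `Tensor.record_stream` tells the CUDA caching allocator not to recycle a tensor's memory until the recorded stream's pending work drains. A hedged sketch of that (now removed) mechanism:

```python
import torch

s = torch.cuda.Stream()
t = torch.empty(1024, device="cuda")

with torch.cuda.stream(s):
    y = t * 2  # t is consumed on side stream s

# Allocator-level bookkeeping, the behavior PGNCCL no longer relies on:
# even after t is freed below, its memory is not handed to a new allocation
# until s's already-queued work completes.
t.record_stream(s)
del t
```

The replacement is explicit CPU-side reference holding (stashing); see the `TensorShelf` sketch further below.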

* [c10d] Add asyncOp argument to Ops

* Change python side wait

* Pass asyncOp at ProcessGroup level

* Watchdog unstashing tensors as a safety net

* Stash tensors for reduce_scatter_v and all_gather_v
Pull Request approved: https://github.com/pytorch/pytorch/pull/149753

* [c10d] Move unstashing from watchdog to main thread
Pull Request approved: https://github.com/pytorch/pytorch/pull/150079

* [PGNCCL][BE] Merge mutex into TensorShelf for encapsulation
Pull Request approved: https://github.com/pytorch/pytorch/pull/150130
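
A minimal Python analogue of the encapsulation idea (hypothetical; the real `TensorShelf` is a C++ class and its exact interface is not shown in this excerpt):

```python
import threading
import torch

class TensorShelf:
    """Holds references to in-flight collective tensors so the caching
    allocator cannot recycle their memory. The mutex lives inside the class
    because both the main thread (work.wait()) and the watchdog thread may
    stash/unstash concurrently."""

    def __init__(self) -> None:
        self._mutex = threading.Lock()
        self._tensors: list[torch.Tensor] = []

    def stash(self, tensors: list[torch.Tensor]) -> None:
        with self._mutex:
            self._tensors.extend(tensors)

    def unstash(self) -> None:
        with self._mutex:
            self._tensors.clear()  # drop refs; memory may now be recycled
```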

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150398
Approved by: https://github.com/atalman
commit 35c45a4a31
parent 7382654ebc
Author: Ke Wen
Date: 2025-03-31 23:58:44 -07:00
Committed by: PyTorch MergeBot
12 changed files with 521 additions and 363 deletions


@@ -224,6 +224,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
 const c10::intrusive_ptr<::c10d::ReduceOp>&,
 const std::optional<at::Tensor>& sparse_indices,
+bool,
 int64_t)>();
 auto work = std::get<1>(op.call(
@@ -231,6 +232,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
 c10::make_intrusive<ReduceOp>(opts.reduceOp),
 opts.sparseIndices,
+opts.asyncOp,
 opts.timeout.count()));
 if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -250,12 +252,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 at::TensorList,
 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
 const c10::intrusive_ptr<::c10d::ReduceOp>&,
+bool,
 int64_t)>();
 auto work = op.call(
 tensors,
 c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
 c10::make_intrusive<ReduceOp>(opts.reduceOp),
+opts.asyncOp,
 opts.timeout.count());
 if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -277,6 +281,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 const c10::intrusive_ptr<::c10d::ReduceOp>&,
 int64_t,
 int64_t,
+bool,
 int64_t)>();
 auto work = op.call(
 tensors,
@@ -284,6 +289,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 c10::make_intrusive<ReduceOp>(opts.reduceOp),
 opts.rootRank,
 opts.rootTensor,
+opts.asyncOp,
 opts.timeout.count());
 if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -306,12 +312,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 const std::vector<std::vector<at::Tensor>>&,
 at::TensorList,
 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
+bool,
 int64_t)>();
 auto work = std::get<1>(op.call(
 outputTensors,
 inputTensors,
 c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
+opts.asyncOp,
 opts.timeout.count()));
 if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -363,18 +371,19 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 std::vector<std::vector<at::Tensor>>& outputTensorLists,
 std::vector<at::Tensor>& inputTensors,
 const AllgatherOptions& opts = AllgatherOptions()) {
-static auto op =
-c10::Dispatcher::singleton()
-.findSchemaOrThrow("c10d::allgather_coalesced_", "")
-.typed<c10::intrusive_ptr<Work>(
-const std::vector<std::vector<at::Tensor>>&,
-const at::TensorList&,
-const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
+static auto op = c10::Dispatcher::singleton()
+.findSchemaOrThrow("c10d::allgather_coalesced_", "")
+.typed<c10::intrusive_ptr<Work>(
+const std::vector<std::vector<at::Tensor>>&,
+const at::TensorList&,
+const c10::intrusive_ptr<::c10d::ProcessGroup>&,
+bool)>();
 auto work = op.call(
 outputTensorLists,
 inputTensors,
-c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
+c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
+opts.asyncOp);
 if (c10d::allow_inflight_collective_as_graph_input()) {
 for (const auto& tensor_list : outputTensorLists) {
@@ -399,12 +408,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 .typed<c10::intrusive_ptr<Work>(
 const at::TensorList,
 const at::TensorList,
-const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
+const c10::intrusive_ptr<::c10d::ProcessGroup>&,
+bool)>();
 auto work = op.call(
 outputTensors,
 inputTensors,
-c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
+c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
+opts.asyncOp);
 if (c10d::allow_inflight_collective_as_graph_input()) {
 for (const auto& tensor : outputTensors) {
@@ -425,12 +436,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 const at::TensorList&,
 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
 int64_t,
+bool,
 int64_t)>();
 auto work = op.call(
 outputTensors,
 inputTensors,
 c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
 opts.rootRank,
+opts.asyncOp,
 opts.timeout.count());
 if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -487,12 +500,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 const std::vector<std::vector<at::Tensor>>&,
 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
 const c10::intrusive_ptr<::c10d::ReduceOp>&,
+bool,
 int64_t)>();
 auto work = std::get<1>(op.call(
 outputTensors,
 inputTensors,
 c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
 c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
+opts.asyncOp,
 opts.timeout.count()));
 if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -546,6 +561,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 const at::TensorList,
 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
 const c10::intrusive_ptr<::c10d::ReduceOp>&,
+bool,
 int64_t)>();
 auto work = op.call(
@@ -553,6 +569,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 inputTensors,
 c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
 c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
+opts.asyncOp,
 opts.timeout.count());
 if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -577,6 +594,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
 std::vector<int64_t>,
 std::vector<int64_t>,
+bool,
 int64_t)>();
 auto work = op.call(
 outputBuffer,
@@ -584,6 +602,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
 outputSplitSizes,
 inputSplitSizes,
+opts.asyncOp,
 opts.timeout.count());
 if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -604,11 +623,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 const at::TensorList&,
 const at::TensorList&,
 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
+bool,
 int64_t)>();
 auto work = std::get<1>(op.call(
 outputTensors,
 inputTensors,
 c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
+opts.asyncOp,
 opts.timeout.count()));
 if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -778,12 +799,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 at::Tensor,
 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
 const std::vector<int64_t>&,
+bool,
 int64_t)>();
 auto work = op.call(
 tensor,
 c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
 opts.device_ids,
+opts.asyncOp,
 opts.timeout.count());
 if (c10d::allow_inflight_collective_as_graph_input()) {
 c10d::register_work(tensor, work);