mirror of
https://github.com/pytorch/pytorch.git
[Reland] Launch kernel on current stream & remove record_stream
entirely (#150398)
Relanding #148590 due to a merge conflict.

This PR makes multiple (unfortunately interrelated) changes to `ProcessGroupNCCL`:

1. When async_op=False, we launch the collective directly on the "current" stream instead of on a trampoline stream that is joined back afterwards (a usage sketch follows below).
   - Resolves #147729
   - Resolves #146881
   - Also saves two event syncs (which carry overhead on HIP) and one pybind call when we invoke `work.wait()` in distributed_c10d.py on behalf of the user.
2. Entirely remove `record_stream` and use CPU-side stashing to keep tensors alive until the collective completes, so the caching allocator cannot recycle them (an illustrative sketch of the stashing idea also follows below).
   - Resolves #147168
3. Remove tensor lifetime management when async_op=False; only use it when async_op=True.
4. To guard against users who never call `work.wait()`, the watchdog unstashes tensors once it detects that a collective has completed, so we do not hold references to tensors forever. This is a safety net rather than a service guarantee; see the discussion [here](https://github.com/pytorch/pytorch/issues/147168#issuecomment-2660142460).
5. Profiles in async_op=False mode will look different: collective kernels now show up on the same line as compute kernels.

Joint work with @cenzhaometa, who wants to remove the event-sync overhead.

Squashed contents:

* [ptd][nccl] use current-stream as nccl-stream under async=False mode (#147820)

  PTD's current workflow:
  - PTD creates its own dedicated `ncclStream` for comm operations
  - it first adds a dependency on the current stream (typically the compute stream) to ensure tensors are ready before invoking the collective

  Such stream synchronization becomes expensive in the inference world (CPU overhead: 70us vs. GPU kernel time: 160us).

  This diff:
  - async=False [default]: use the current stream as the nccl stream and avoid the stream-sync overhead
  - async=True: retain the existing logic of creating a new nccl stream and letting it wait on the current stream, to ensure tensors are ready
  - pass `async` down from c10d to the NCCL PG

  This shaves off 50% of the CPU overhead **(70us -> 35us)** and reduces total CPU/GPU time from **230us to 195us** (a 15% reduction).

* [PGNCCL] Make avoid-record-stream default
* [c10d] Add asyncOp argument to Ops
* Change python side wait
* Pass asyncOp at ProcessGroup level
* Watchdog unstashing tensors as a safety net
* Stash tensors for reduce_scatter_v and all_gather_v
  Pull Request approved: https://github.com/pytorch/pytorch/pull/149753
* [c10d] Move unstashing from watchdog to main thread
  Pull Request approved: https://github.com/pytorch/pytorch/pull/150079
* [PGNCCL][BE] Merge mutex into TensorShelf for encapsulation
  Pull Request approved: https://github.com/pytorch/pytorch/pull/150130

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150398
Approved by: https://github.com/atalman
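For context on item 1 (not part of the PR text): the public Python API is unchanged, only the stream behavior underneath differs. Below is a minimal usage sketch, assuming a multi-GPU NCCL job launched with `torchrun` (names such as `LOCAL_RANK` come from the launcher, not from this PR):

```python
# Minimal usage sketch, not code from this PR.
# Run with e.g.: torchrun --nproc-per-node=2 demo.py
import os
import torch
import torch.distributed as dist

def main() -> None:
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    t = torch.ones(1024, device="cuda")

    # async_op=False (default): with this change the NCCL kernel is enqueued on the
    # current CUDA stream, so later work on that stream is ordered after it without
    # extra event syncs or record_stream calls.
    dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=False)
    y = t * 2  # same stream, runs after the collective

    # async_op=True: the collective still runs on a separate internal NCCL stream;
    # work.wait() makes the current stream wait before the result is consumed.
    work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
    work.wait()
    z = t + 1

    torch.cuda.synchronize()
    print(f"rank {dist.get_rank()}: y[0]={y[0].item()}, z[0]={z[0].item()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

As item 5 notes, this also changes what a profiler trace looks like: in async_op=False mode the collective kernel now appears on the same stream row as the surrounding compute kernels.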
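Items 2-4 replace `record_stream` with CPU-side stashing (the `TensorShelf` mentioned in the squashed commits). The sketch below is purely illustrative Python under assumed names (`TensorShelf`, `stash`, `unstash`); the real implementation is C++ inside `ProcessGroupNCCL`, not this code:

```python
# Illustrative sketch only: the actual mechanism is the C++ TensorShelf inside
# ProcessGroupNCCL; the names below mirror the idea, not the real API.
import threading
from typing import List
import torch

class TensorShelf:
    """Holds CPU-side references to tensors used by in-flight collectives so the
    caching allocator cannot recycle their memory while a kernel may still use it."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._tensors: List[torch.Tensor] = []

    def stash(self, tensors: List[torch.Tensor]) -> None:
        with self._lock:
            self._tensors.extend(tensors)  # strong refs keep the storage alive

    def unstash(self) -> None:
        with self._lock:
            self._tensors.clear()  # drop refs once the collective has completed

# Usage idea: stash() when an async_op=True collective is enqueued; unstash() when
# work.wait() completes, or from the watchdog as a safety net if the user never
# calls wait(). With async_op=False no stashing is needed, since the collective
# runs on the current stream.
```

The diff below adds the corresponding `asyncOp` bool to each collective's dispatcher schema and call site in `ProcessGroup`.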
```diff
@@ -224,6 +224,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
 const c10::intrusive_ptr<::c10d::ReduceOp>&,
 const std::optional<at::Tensor>& sparse_indices,
+bool,
 int64_t)>();

 auto work = std::get<1>(op.call(
@@ -231,6 +232,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
 c10::make_intrusive<ReduceOp>(opts.reduceOp),
 opts.sparseIndices,
+opts.asyncOp,
 opts.timeout.count()));

 if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -250,12 +252,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 at::TensorList,
 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
 const c10::intrusive_ptr<::c10d::ReduceOp>&,
+bool,
 int64_t)>();

 auto work = op.call(
 tensors,
 c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
 c10::make_intrusive<ReduceOp>(opts.reduceOp),
+opts.asyncOp,
 opts.timeout.count());

 if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -277,6 +281,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 const c10::intrusive_ptr<::c10d::ReduceOp>&,
 int64_t,
 int64_t,
+bool,
 int64_t)>();
 auto work = op.call(
 tensors,
@@ -284,6 +289,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 c10::make_intrusive<ReduceOp>(opts.reduceOp),
 opts.rootRank,
 opts.rootTensor,
+opts.asyncOp,
 opts.timeout.count());

 if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -306,12 +312,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 const std::vector<std::vector<at::Tensor>>&,
 at::TensorList,
 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
+bool,
 int64_t)>();

 auto work = std::get<1>(op.call(
 outputTensors,
 inputTensors,
 c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
+opts.asyncOp,
 opts.timeout.count()));

 if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -363,18 +371,19 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 std::vector<std::vector<at::Tensor>>& outputTensorLists,
 std::vector<at::Tensor>& inputTensors,
 const AllgatherOptions& opts = AllgatherOptions()) {
-static auto op =
-c10::Dispatcher::singleton()
-.findSchemaOrThrow("c10d::allgather_coalesced_", "")
-.typed<c10::intrusive_ptr<Work>(
-const std::vector<std::vector<at::Tensor>>&,
-const at::TensorList&,
-const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
+static auto op = c10::Dispatcher::singleton()
+.findSchemaOrThrow("c10d::allgather_coalesced_", "")
+.typed<c10::intrusive_ptr<Work>(
+const std::vector<std::vector<at::Tensor>>&,
+const at::TensorList&,
+const c10::intrusive_ptr<::c10d::ProcessGroup>&,
+bool)>();

 auto work = op.call(
 outputTensorLists,
 inputTensors,
-c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
+c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
+opts.asyncOp);

 if (c10d::allow_inflight_collective_as_graph_input()) {
 for (const auto& tensor_list : outputTensorLists) {
@@ -399,12 +408,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 .typed<c10::intrusive_ptr<Work>(
 const at::TensorList,
 const at::TensorList,
-const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
+const c10::intrusive_ptr<::c10d::ProcessGroup>&,
+bool)>();

 auto work = op.call(
 outputTensors,
 inputTensors,
-c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
+c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
+opts.asyncOp);

 if (c10d::allow_inflight_collective_as_graph_input()) {
 for (const auto& tensor : outputTensors) {
@@ -425,12 +436,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 const at::TensorList&,
 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
 int64_t,
+bool,
 int64_t)>();
 auto work = op.call(
 outputTensors,
 inputTensors,
 c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
 opts.rootRank,
+opts.asyncOp,
 opts.timeout.count());

 if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -487,12 +500,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 const std::vector<std::vector<at::Tensor>>&,
 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
 const c10::intrusive_ptr<::c10d::ReduceOp>&,
+bool,
 int64_t)>();
 auto work = std::get<1>(op.call(
 outputTensors,
 inputTensors,
 c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
 c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
+opts.asyncOp,
 opts.timeout.count()));

 if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -546,6 +561,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 const at::TensorList,
 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
 const c10::intrusive_ptr<::c10d::ReduceOp>&,
+bool,
 int64_t)>();

 auto work = op.call(
@@ -553,6 +569,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 inputTensors,
 c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
 c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
+opts.asyncOp,
 opts.timeout.count());

 if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -577,6 +594,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
 std::vector<int64_t>,
 std::vector<int64_t>,
+bool,
 int64_t)>();
 auto work = op.call(
 outputBuffer,
@@ -584,6 +602,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
 outputSplitSizes,
 inputSplitSizes,
+opts.asyncOp,
 opts.timeout.count());

 if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -604,11 +623,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 const at::TensorList&,
 const at::TensorList&,
 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
+bool,
 int64_t)>();
 auto work = std::get<1>(op.call(
 outputTensors,
 inputTensors,
 c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
+opts.asyncOp,
 opts.timeout.count()));

 if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -778,12 +799,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 at::Tensor,
 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
 const std::vector<int64_t>&,
+bool,
 int64_t)>();

 auto work = op.call(
 tensor,
 c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
 opts.device_ids,
+opts.asyncOp,
 opts.timeout.count());
 if (c10d::allow_inflight_collective_as_graph_input()) {
 c10d::register_work(tensor, work);
```