[Reland] Launch kernel on current stream & remove record_stream entirely (#150398)

Relanding #148590 due to merge conflict.

This PR makes multiple changes to `ProcessGroupNCCL` (which, unfortunately, are interrelated):
1. When `async_op=False`, we launch the collective directly on the "current" stream, instead of launching it on a trampoline stream and joining back (see the usage sketch after this list).
- Resolves #147729
- Resolves #146881
- Also saves two event syncs (which have overhead in the case of HIP) and one pybind call when we call `work.wait()` in distributed_c10d.py on behalf of the user.
2. Entirely remove `record_stream` and use CPU-side stashing to manage tensor lifetimes against premature recycling.
- Resolves #147168
3. Remove tensor lifetime management when `async_op=False`; only use it when `async_op=True`.
4. To guard against the user never calling `work.wait()`, we ask the watchdog to unstash tensors once it detects completion of the collective, so that we do not hold references to those tensors forever. This is a safety net rather than a service guarantee; see the discussion [here](https://github.com/pytorch/pytorch/issues/147168#issuecomment-2660142460).
5. Profiles in `async_op=False` mode will look different -- collective kernels now show up on the same line as compute kernels.
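
A minimal usage sketch of the resulting user-facing semantics (assuming an already-initialized NCCL process group; tensor names are illustrative):

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl") has already been called and this
# rank owns a single CUDA device.
x = torch.ones(1024, device="cuda")

# async_op=False (default): the collective kernel is launched on the
# *current* CUDA stream, so later kernels on that stream are ordered after
# it; no work handle or extra stream sync is needed.
dist.all_reduce(x)
y = x * 2  # safe: enqueued on the same stream, after the all-reduce

# async_op=True: the collective runs on an internal NCCL stream and the
# tensors are stashed (CPU-side) until work.wait(), which blocks the
# current stream on the collective's completion.
work = dist.all_reduce(x, async_op=True)
work.wait()
z = x * 2  # safe only after wait()
```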

Joint work with @cenzhaometa who wants to remove the event sync overhead.

Squashed contents:

* [ptd][nccl] use current-stream as nccl-stream under async=False mode (#147820)
PTD's current workflow:
- PTD creates its own dedicated `ncclStream` for comm operations
- it first adds a dependency on the current stream (typically the compute stream) to ensure tensors are ready before invoking the collective
Such stream synchronization becomes expensive in the inference world (CPU overhead: 70us vs. GPU kernel time: 160us).
This diff:
- async=False [default]: use the current stream as the nccl stream and avoid the stream-sync overhead
- async=True: retain the existing logic -- create a new nccl stream and let it wait on the current stream to ensure tensors are ready
- pass `async` down from c10d to the NCCL PG
This shaves off 50% of the CPU overhead **(70us -> 35us)**, reducing the total CPU/GPU time from **230us to 195us** (about 15%). A stream-selection sketch follows this list.

* [PGNCCL] Make avoid-record-stream default

* [c10d] Add asyncOp argument to Ops

* Change python side wait

* Pass asyncOp at ProcessGroup level

* Watchdog unstashing tensors as a safety net

* Stash tensors for reduce_scatter_v and all_gather_v
Pull Request approved: https://github.com/pytorch/pytorch/pull/149753

* [c10d] Move unstashing from watchdog to main thread
Pull Request approved: https://github.com/pytorch/pytorch/pull/150079

* [PGNCCL][BE] Merge mutex into TensorShelf for encapsulation
Pull Request approved: https://github.com/pytorch/pytorch/pull/150130
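
For illustration, a hedged sketch of the stream choice described in the first squashed commit above; the helper `launch_collective` and the `launch_fn` callback are hypothetical, and only the `torch.cuda` stream APIs are real:

```python
import torch

def launch_collective(launch_fn, async_op: bool, comm_stream: torch.cuda.Stream):
    """Illustrative only; launch_fn(stream) is assumed to enqueue the NCCL
    kernel on the given stream."""
    if not async_op:
        # async_op=False: launch directly on the current (compute) stream;
        # no trampoline stream or event dependency is needed.
        launch_fn(torch.cuda.current_stream())
    else:
        # async_op=True: keep the dedicated comm stream, but make it wait on
        # the current stream so inputs are ready before the collective runs.
        comm_stream.wait_stream(torch.cuda.current_stream())
        launch_fn(comm_stream)
```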

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150398
Approved by: https://github.com/atalman
Ke Wen
2025-03-31 23:58:44 -07:00
committed by PyTorch MergeBot
parent 7382654ebc
commit 35c45a4a31
12 changed files with 521 additions and 363 deletions

View File

@ -363,6 +363,9 @@ class TestDebugInfoWriter : public c10d::DebugInfoWriter {
};
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
// Note (kwen2501) 03/07/2025
// TODO: re-enable
GTEST_SKIP() << "Skipping test as the trace write seems unstable.";
int heartBeatIntervalInSec = 2;
std::string timeInterval = std::to_string(heartBeatIntervalInSec);
ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "0", 1) == 0);

View File

@ -733,6 +733,32 @@ class ProcessGroupNCCLOpTest(MultiProcContinousTest):
# fails the check because the dtype is different
reduce_scatter_base(output_t, tensor)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_reduce_scatter_v(self):
device = torch.device("cuda", self.rank_to_GPU[self.rank][0])
# A list of tensors with different sizes
input_list = [torch.ones(i, device=device) for i in range(self.world_size)]
# The i-th output should have size i
output = torch.zeros(self.rank, device=device)
work = c10d.reduce_scatter(output, input_list, group=self.pg, async_op=True)
expected = torch.ones(self.rank, device=device) * self.world_size
work.wait()
self.assertEqual(expected, output)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_all_gather_v(self):
device = torch.device("cuda", self.rank_to_GPU[self.rank][0])
# A list of tensors with different sizes
output_list = [torch.zeros(i, device=device) for i in range(self.world_size)]
# The i-th input has size i, filled with value i
input = torch.ones(self.rank, device=device) * self.rank
work = c10d.all_gather(output_list, input, group=self.pg, async_op=True)
expected = [torch.ones(i, device=device) * i for i in range(self.world_size)]
work.wait()
self.assertEqual(expected, output_list)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_reduce_scatter_ops(self):

View File

@ -126,6 +126,9 @@ ALLOW_LIST = [
("aten::reduce_scatter_tensor", datetime.date(9999, 1, 30)),
("aten::all_gather_into_tensor", datetime.date(9999, 1, 30)),
("aten::all_reduce", datetime.date(9999, 1, 30)),
# These ops are defined in torch/csrc/distributed/c10d/Ops.cpp
# TODO: add back restriction when c10d ops can be exported
("c10d::.*", datetime.date(9999, 1, 1)),
]
ALLOW_LIST_COMPILED = [

View File

@ -2,7 +2,7 @@
# mypy: disable-error-code="type-arg"
from datetime import timedelta
from enum import Enum
from typing import Any, overload
from typing import Any, Optional, overload
import torch
from torch import Tensor
@ -139,6 +139,8 @@ class BroadcastOptions:
class AllreduceOptions:
reduceOp: ReduceOp
timeout: timedelta
asyncOp: bool
sparseIndices: Optional[Tensor]
class AllreduceCoalescedOptions(AllreduceOptions): ...
@ -147,6 +149,7 @@ class ReduceOptions:
rootRank: int
rootTensor: int
timeout: timedelta
asyncOp: bool
class AllgatherOptions:
timeout: timedelta
@ -155,6 +158,7 @@ class AllgatherOptions:
class GatherOptions:
rootRank: int
timeout: timedelta
asyncOp: bool
class ScatterOptions:
rootRank: int
@ -170,9 +174,11 @@ class BarrierOptions:
device_ids: list[int]
device: torch.device
timeout: timedelta
asyncOp: bool
class AllToAllOptions:
timeout: timedelta
asyncOp: bool
class Store:
def set(self, key: str, value: str): ...

View File

@ -17,37 +17,37 @@ TORCH_LIBRARY(c10d, m) {
.def("wait", [](const c10::intrusive_ptr<Work>& self) { self->wait(); });
m.class_<ReduceOp>("ReduceOp").def(torch::init<>());
m.def(
"broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
"broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
m.def(
"allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
"allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
m.def(
"allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work");
"allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
m.def(
"allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[][], __torch__.torch.classes.c10d.Work)");
"allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor[][], __torch__.torch.classes.c10d.Work)");
m.def(
"_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)");
"_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)");
m.def(
"allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work");
"allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) -> __torch__.torch.classes.c10d.Work");
m.def(
"allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work");
"allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) -> __torch__.torch.classes.c10d.Work");
m.def(
"reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
"reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
m.def(
"_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)");
"_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)");
m.def(
"reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work");
"reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
m.def(
"reduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int root_rank, int root_tensor, int timeout) -> __torch__.torch.classes.c10d.Work");
"reduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
m.def(
"gather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int timeout) -> __torch__.torch.classes.c10d.Work");
"gather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
m.def(
"scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
"scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
m.def(
"alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
"alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
m.def(
"alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] output_split_sizes, int[] input_split_sizes, int timeout) -> __torch__.torch.classes.c10d.Work");
"alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] output_split_sizes, int[] input_split_sizes, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
m.def(
"barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout) -> __torch__.torch.classes.c10d.Work");
"barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
m.def(
"monitored_barrier_(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout, bool wait_all_ranks) -> ()");
m.def(
@ -118,6 +118,7 @@ IMPL_RECV_ANY_SOURCE(PrivateUse1)
const c10::intrusive_ptr<ReduceOp>& reduce_op, \
int64_t root_rank, \
int64_t root_tensor, \
bool asyncOp, \
int64_t timeout) { \
auto tensor_vec = tensors.vec(); \
return process_group->getBackend(c10::DeviceType::DEV) \
@ -127,7 +128,8 @@ IMPL_RECV_ANY_SOURCE(PrivateUse1)
*reduce_op.get(), \
root_rank, \
root_tensor, \
std::chrono::milliseconds(timeout)}); \
std::chrono::milliseconds(timeout), \
asyncOp}); \
}
IMPL_REDUCE(CPU)
@ -169,12 +171,13 @@ IMPL_BROADCAST(PrivateUse1)
const c10::intrusive_ptr<ProcessGroup>& process_group, \
const c10::intrusive_ptr<ReduceOp>& reduce_op, \
const std::optional<at::Tensor>& sparse_indices, \
bool asyncOp, \
int64_t timeout) { \
auto tensor_vec = tensors.vec(); \
auto work = process_group->getBackend(c10::DeviceType::DEV) -> allreduce( \
tensor_vec, \
AllreduceOptions{ \
*reduce_op.get(), std::chrono::milliseconds(timeout)}); \
*reduce_op.get(), std::chrono::milliseconds(timeout), asyncOp}); \
return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>( \
std::move(tensor_vec), work); \
}
@ -188,11 +191,13 @@ IMPL_ALLREDUCE(PrivateUse1)
at::TensorList tensors, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
const c10::intrusive_ptr<ReduceOp>& reduce_op, \
bool asyncOp, \
int64_t timeout) { \
auto tensor_vec = tensors.vec(); \
AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; \
opts.reduceOp = *reduce_op.get(); \
opts.timeout = std::chrono::milliseconds(timeout); \
opts.asyncOp = asyncOp; \
return process_group->getBackend(c10::DeviceType::DEV) \
->allreduce_coalesced(tensor_vec, opts); \
}
@ -209,12 +214,13 @@ IMPL_ALLREDUCE_COALESCED(PrivateUse1)
const std::vector<std::vector<at::Tensor>>& output_tensors, \
at::TensorList input_tensors, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
bool asyncOp, \
int64_t timeout) { \
auto input_tensors_vec = input_tensors.vec(); \
auto work = process_group->getBackend(c10::DeviceType::DEV) -> allgather( \
const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors), \
input_tensors_vec, \
AllgatherOptions{std::chrono::milliseconds(timeout)}); \
AllgatherOptions{std::chrono::milliseconds(timeout), asyncOp}); \
return std:: \
tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>( \
output_tensors, work); \
@ -249,12 +255,16 @@ IMPL__ALLGATHER_BASE(PrivateUse1)
c10::intrusive_ptr<Work> allgather_coalesced_##DEV( \
const std::vector<std::vector<at::Tensor>>& output_lists, \
const at::TensorList& input_list, \
const c10::intrusive_ptr<ProcessGroup>& process_group) { \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
bool asyncOp) { \
auto input_list_vec = input_list.vec(); \
auto opts = AllgatherOptions{}; \
opts.asyncOp = asyncOp; \
return process_group->getBackend(c10::DeviceType::DEV) \
->allgather_coalesced( \
const_cast<std::vector<std::vector<at::Tensor>>&>(output_lists), \
input_list_vec); \
input_list_vec, \
opts); \
}
IMPL_ALLGATHER_COALESCED(CPU)
@ -265,11 +275,14 @@ IMPL_ALLGATHER_COALESCED(PrivateUse1)
c10::intrusive_ptr<c10d::Work> allgather_into_tensor_coalesced_##DEV( \
at::TensorList outputs, \
at::TensorList inputs, \
const c10::intrusive_ptr<ProcessGroup>& process_group) { \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
bool asyncOp) { \
auto output_vec = outputs.vec(); \
auto input_vec = inputs.vec(); \
auto opts = AllgatherOptions{}; \
opts.asyncOp = asyncOp; \
return process_group->getBackend(c10::DeviceType::DEV) \
->allgather_into_tensor_coalesced(output_vec, input_vec); \
->allgather_into_tensor_coalesced(output_vec, input_vec, opts); \
}
IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CPU)
@ -283,6 +296,7 @@ IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1)
const std::vector<std::vector<at::Tensor>>& input_tensors, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
const c10::intrusive_ptr<ReduceOp>& reduce_op, \
bool asyncOp, \
int64_t timeout) { \
auto output_tensors_vec = output_tensors.vec(); \
auto work = \
@ -290,7 +304,9 @@ IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1)
output_tensors_vec, \
const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors), \
ReduceScatterOptions{ \
*reduce_op.get(), std::chrono::milliseconds(timeout)}); \
*reduce_op.get(), \
std::chrono::milliseconds(timeout), \
asyncOp}); \
return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>( \
output_tensors_vec, work); \
}
@ -329,6 +345,7 @@ IMPL__REDUCE_SCATTER_BASE(PrivateUse1)
at::TensorList inputs, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
const c10::intrusive_ptr<ReduceOp>& reduce_op, \
bool asyncOp, \
int64_t timeout) { \
auto output_vec = outputs.vec(); \
auto input_vec = inputs.vec(); \
@ -337,7 +354,9 @@ IMPL__REDUCE_SCATTER_BASE(PrivateUse1)
output_vec, \
input_vec, \
ReduceScatterOptions{ \
*reduce_op.get(), std::chrono::milliseconds(timeout)}); \
*reduce_op.get(), \
std::chrono::milliseconds(timeout), \
asyncOp}); \
}
IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CPU)
@ -350,13 +369,15 @@ IMPL_REDUCE_SCATTER_TENSOR_COALESCED(PrivateUse1)
const at::TensorList& input_tensors, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
int64_t root_rank, \
bool asyncOp, \
int64_t timeout) { \
auto input_tensors_vec = input_tensors.vec(); \
return process_group->getBackend(c10::DeviceType::DEV) \
->gather( \
const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors), \
input_tensors_vec, \
GatherOptions{root_rank, std::chrono::milliseconds(timeout)}); \
GatherOptions{ \
root_rank, std::chrono::milliseconds(timeout), asyncOp}); \
}
IMPL_GATHER(CPU)
@ -391,13 +412,14 @@ IMPL_SCATTER(PrivateUse1)
const at::TensorList& output_tensors, \
const at::TensorList& input_tensors, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
bool asyncOp, \
int64_t timeout) { \
auto output_tensors_vec = output_tensors.vec(); \
auto input_tensors_vec = input_tensors.vec(); \
auto work = process_group->getBackend(c10::DeviceType::DEV) -> alltoall( \
output_tensors_vec, \
input_tensors_vec, \
AllToAllOptions{std::chrono::milliseconds(timeout)}); \
AllToAllOptions{std::chrono::milliseconds(timeout), asyncOp}); \
return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>( \
std::move(output_tensors_vec), work); \
}
@ -406,21 +428,22 @@ IMPL_ALLTOALL(CPU)
IMPL_ALLTOALL(CUDA)
IMPL_ALLTOALL(PrivateUse1)
#define IMPL_ALLTOALL_BASE(DEV) \
c10::intrusive_ptr<Work> alltoall_base_##DEV( \
at::Tensor& output, \
at::Tensor& input, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
std::vector<int64_t> output_split_sizes, \
std::vector<int64_t> input_split_sizes, \
int64_t timeout) { \
return process_group->getBackend(c10::DeviceType::DEV) \
->alltoall_base( \
output, \
input, \
output_split_sizes, \
input_split_sizes, \
AllToAllOptions{std::chrono::milliseconds(timeout)}); \
#define IMPL_ALLTOALL_BASE(DEV) \
c10::intrusive_ptr<Work> alltoall_base_##DEV( \
at::Tensor& output, \
at::Tensor& input, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
std::vector<int64_t> output_split_sizes, \
std::vector<int64_t> input_split_sizes, \
bool asyncOp, \
int64_t timeout) { \
return process_group->getBackend(c10::DeviceType::DEV) \
->alltoall_base( \
output, \
input, \
output_split_sizes, \
input_split_sizes, \
AllToAllOptions{std::chrono::milliseconds(timeout), asyncOp}); \
}
IMPL_ALLTOALL_BASE(CPU)
@ -428,15 +451,18 @@ IMPL_ALLTOALL_BASE(CUDA)
IMPL_ALLTOALL_BASE(PrivateUse1)
// NOLINTBEGIN(performance-unnecessary-value-param)
#define IMPL_BARRIER(DEV) \
c10::intrusive_ptr<Work> barrier##DEV( \
at::Tensor /* unused */, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
const std::vector<int64_t>& device_ids, \
int64_t timeout) { \
return process_group->getBackend(c10::DeviceType::DEV) \
->barrier( \
BarrierOptions{device_ids, std::chrono::milliseconds(timeout)}); \
#define IMPL_BARRIER(DEV) \
c10::intrusive_ptr<Work> barrier##DEV( \
at::Tensor /* unused */, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
const std::vector<int64_t>& device_ids, \
bool asyncOp, \
int64_t timeout) { \
auto opts = BarrierOptions{}; \
opts.device_ids = device_ids; \
opts.timeout = std::chrono::milliseconds(timeout); \
opts.asyncOp = asyncOp; \
return process_group->getBackend(c10::DeviceType::DEV)->barrier(opts); \
}
IMPL_BARRIER(CPU)
@ -464,6 +490,7 @@ allreduce_sparse_cuda_(
const c10::intrusive_ptr<ProcessGroup>& process_group,
const c10::intrusive_ptr<ReduceOp>& reduce_op,
const std::optional<at::Tensor>& sparse_indices,
bool asyncOp,
int64_t timeout) {
auto tensor_vec = tensors.vec();
auto work = process_group->getBackend(c10::DeviceType::CUDA)
@ -472,6 +499,7 @@ allreduce_sparse_cuda_(
AllreduceOptions{
*reduce_op,
std::chrono::milliseconds(timeout),
asyncOp,
sparse_indices});
return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(

View File

@ -224,6 +224,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
const std::optional<at::Tensor>& sparse_indices,
bool,
int64_t)>();
auto work = std::get<1>(op.call(
@ -231,6 +232,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.sparseIndices,
opts.asyncOp,
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -250,12 +252,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
bool,
int64_t)>();
auto work = op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.asyncOp,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -277,6 +281,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ReduceOp>&,
int64_t,
int64_t,
bool,
int64_t)>();
auto work = op.call(
tensors,
@ -284,6 +289,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.rootRank,
opts.rootTensor,
opts.asyncOp,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -306,12 +312,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const std::vector<std::vector<at::Tensor>>&,
at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
bool,
int64_t)>();
auto work = std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.asyncOp,
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -363,18 +371,19 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
std::vector<std::vector<at::Tensor>>& outputTensorLists,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::allgather_coalesced_", "")
.typed<c10::intrusive_ptr<Work>(
const std::vector<std::vector<at::Tensor>>&,
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::allgather_coalesced_", "")
.typed<c10::intrusive_ptr<Work>(
const std::vector<std::vector<at::Tensor>>&,
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
bool)>();
auto work = op.call(
outputTensorLists,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.asyncOp);
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor_list : outputTensorLists) {
@ -399,12 +408,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
.typed<c10::intrusive_ptr<Work>(
const at::TensorList,
const at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
bool)>();
auto work = op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.asyncOp);
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : outputTensors) {
@ -425,12 +436,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t,
bool,
int64_t)>();
auto work = op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.rootRank,
opts.asyncOp,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -487,12 +500,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const std::vector<std::vector<at::Tensor>>&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
bool,
int64_t)>();
auto work = std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
opts.asyncOp,
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -546,6 +561,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
bool,
int64_t)>();
auto work = op.call(
@ -553,6 +569,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
opts.asyncOp,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -577,6 +594,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
std::vector<int64_t>,
std::vector<int64_t>,
bool,
int64_t)>();
auto work = op.call(
outputBuffer,
@ -584,6 +602,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
outputSplitSizes,
inputSplitSizes,
opts.asyncOp,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -604,11 +623,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const at::TensorList&,
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
bool,
int64_t)>();
auto work = std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.asyncOp,
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -778,12 +799,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
at::Tensor,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const std::vector<int64_t>&,
bool,
int64_t)>();
auto work = op.call(
tensor,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.device_ids,
opts.asyncOp,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
c10d::register_work(tensor, work);

View File

@ -440,6 +440,36 @@ std::ostream& operator<<(
return output << workInfo;
}
/* Implementation of TensorShelf class */
void TensorShelf::stash(std::vector<at::Tensor>& tensors) {
std::lock_guard<std::mutex> lock(mutex_);
tVector_.insert(tVector_.end(), tensors.begin(), tensors.end());
}
void TensorShelf::stash(TensorShelf& other) {
std::vector<at::Tensor>& otherVec = other.get();
this->stash(otherVec);
}
void TensorShelf::unstash() {
this->clear();
}
bool TensorShelf::empty() {
std::lock_guard<std::mutex> lock(mutex_);
return tVector_.empty();
}
void TensorShelf::clear() {
std::lock_guard<std::mutex> lock(mutex_);
tVector_.clear();
}
std::vector<at::Tensor>& TensorShelf::get() {
return tVector_;
}
ProcessGroupNCCL::WorkNCCL::WorkNCCL(
std::string pgUID,
std::string pgDesc,
@ -482,6 +512,8 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(
}
futureWorkResult_ =
c10::make_intrusive<at::ivalue::Future>(c10::AnyEnumType::get());
// other functions expect an initialized ptr
stashed_for_allocator_safety_ = std::make_shared<TensorShelf>();
}
ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
@ -503,6 +535,11 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
numelIn_(w.numelIn_),
numelOut_(w.numelOut_),
store_(w.store_),
// Note: the `work` returned to the user and the `work` enqueued to the
// watchdog share the pointer to the tensor stash. At least one of them
// should clean the tensor stash, the earlier the better, i.e. preferably
// the user calling `work.wait` rather than the watchdog detecting work
// completion.
stashed_for_allocator_safety_(w.stashed_for_allocator_safety_),
futureWorkResult_(w.futureWorkResult_),
timingEnabled_(w.timingEnabled_),
trace_id_(w.trace_id_),
@ -700,10 +737,9 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeStream() {
auto currentStream = at::cuda::getCurrentCUDAStream(device_.index());
// Block the current stream on the NCCL stream
ncclEndEvent_->block(currentStream);
if (avoidRecordStreams_) {
stashed_for_allocator_safety_->clear();
}
// Unstage the stashed tensors so that CachingAllocator can recycle them
// THIS MUST HAPPEN AFTER THE BLOCKING CALL ABOVE
stashed_for_allocator_safety_->unstash();
}
// Same as calling synchronize() when blockingWait_ is false
@ -919,7 +955,10 @@ ProcessGroupNCCL::ProcessGroupNCCL(
enableTiming_.store(
getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_);
#endif // ENABLE_NCCL_ERROR_CHECKING
avoidRecordStreams_ = getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false);
if (getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false)) {
TORCH_WARN_ONCE(
"TORCH_NCCL_AVOID_RECORD_STREAMS is the default now, this environment variable is thus deprecated.");
}
#ifdef NCCL_HAS_COMM_REGISTER
useTensorRegisterAllocatorHook_ =
getCvarBool(TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK, false);
@ -2309,6 +2348,23 @@ void ProcessGroupNCCL::watchdogHandler() {
// Clean up completed work
if (work.isCompleted()) {
// In case the user didn't call `work.wait()` on async collectives, the
// watchdog would unstash the stashed tensors when it detects completion
// of the collective, to prevent ProcessGroupNCCL from holding references
// to those tensors forever.
// work.stashed_for_allocator_safety_->unstash();
// Update: it seems directly unstashing from the watchdog thread can cause
// some rare problems. We thus move the unstashing to the main thread,
// triggered by the next user call; see `workEnqueue`. But `work` is going
// to be destructed, so we transfer the work's shelf to a shelves
// structure owned by the PG.
if (!work.stashed_for_allocator_safety_->empty()) {
std::lock_guard<std::mutex> lock(shelvesMutex_);
// We are just pushing back a shared_ptr here, so the cost should be
// minimal
shelvesToUnstash_.push_back(work.stashed_for_allocator_safety_);
}
// Work status logging for desync debug
desyncDebugger_.logWorkEnd(work);
@ -3043,6 +3099,7 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
enableTiming_.load(),
cudaEventCacheEnabled_.load(),
dist_debug_level_);
if (record) {
bool isP2P = isP2POp(opType);
// Ideally record every work that we enqueue, rather than every work we
@ -3122,6 +3179,17 @@ void ProcessGroupNCCL::assignTimeoutToWork(
void ProcessGroupNCCL::workEnqueue(
const c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
// We clean up the TensorShelves in case the user hasn't called
// `work.wait()`. This has nothing to do with the new work being enqueued;
// we are just using a place that would be triggered by the next user call.
{
std::lock_guard<std::mutex> lock(shelvesMutex_);
for (auto& shelf : shelvesToUnstash_) {
shelf->unstash();
}
shelvesToUnstash_.clear();
}
// in blockingWait_ mode, we don't need watchdog thread, so no need to enqueue
// the work
if (!terminateProcessGroup_.load() && !blockingWait_) {
@ -3158,6 +3226,7 @@ void ProcessGroupNCCL::startCoalescing() {
coalescedDevice_.set_index(-1);
coalescedComm_ = nullptr;
coalescedTensors_.clear();
coalescing_state_ |= CoalActive;
groupStart();
}
@ -3200,10 +3269,12 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing(OpType optype) {
enqueue);
work->ncclComm_ = comm;
work->blockingWait_ = blockingWait_;
work->avoidRecordStreams_ = avoidRecordStreams_;
work->store_ = store_;
assignTimeoutToWork(work, options_);
// Hand over references to tensors during coalescing to work's stash
work->stashed_for_allocator_safety_->stash(coalescedTensors_);
// Record start before ncclGroupEnd
if (work->timingEnabled_) {
work->ncclStartEvent_->record(ncclStream);
@ -3219,19 +3290,17 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing(OpType optype) {
// TODO(eqy): is this still necessary if avoidRecordStreams_ is set?
work->ncclEndEvent_->record(ncclStream);
if (avoidRecordStreams_) {
// other functions expect an initialized ptr if avoidRecordStreams_ is set
work->stashed_for_allocator_safety_ =
std::make_shared<std::vector<at::Tensor>>();
}
if (enqueue) {
workEnqueue(work);
}
// Reset coalescing state
coalescing_state_ = 0;
coalescedComm_ = nullptr;
return work;
coalescedTensors_.clear();
// If in async mode, return work; otherwise, kernel is enqueued on current
// stream, no need to return work
return coalescedAsync_ ? work : nullptr;
}
c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing() {
@ -3264,11 +3333,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
PreProcess pre,
PostProcess post,
OpType opType,
bool asyncOp,
const char* profilingTitle,
bool avoidRecordStreams,
bool nanCheck) {
// Environment setting by the user may add onto collective call's option
avoidRecordStreams |= avoidRecordStreams_;
nanCheck &= enableNanCheck_;
auto device = getDevice(inputs[0]);
@ -3309,13 +3377,17 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
} else {
TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG);
}
coalescedAsync_ = asyncOp;
}
// Used many times below, so we stash the unordered_map lookup
auto ncclStream = ncclStreams_.at(key);
// First let NCCL streams wait for input tensors allocation streams
syncStream(device, ncclEvents_[key], ncclStream);
// In asyncOp=false [default] mode, we use currentStream as ncclStream;
// otherwise, we use a separate ncclStream and let it sync on currentStream.
auto ncclStream = asyncOp ? ncclStreams_.at(key)
: at::cuda::getCurrentCUDAStream(device.index());
if (asyncOp) {
// First let NCCL streams wait for input tensors allocation streams
syncStream(device, ncclEvents_[key], ncclStream);
}
bool enqueue =
!coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None;
@ -3325,9 +3397,19 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
// Store references to outputs to be used by WorkNCCL::result and operator<<.
work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
if (avoidRecordStreams) {
work->stashed_for_allocator_safety_ =
std::make_shared<std::vector<at::Tensor>>(inputs);
// If we are performing sync operations, i.e. enqueuing the kernel onto the
// "current" stream, we don't need to do anything for tensor lifetime
// management. Otherwise, we need to stash the tensors until `work.wait()`.
if (asyncOp) {
// First select which shelf to stash onto: to `work` if single collective;
// to an inflight shelf if coalescing.
if (coalescing_state_) {
coalescedTensors_.stash(inputs);
coalescedTensors_.stash(outputs);
} else {
work->stashed_for_allocator_safety_->stash(inputs);
work->stashed_for_allocator_safety_->stash(outputs);
}
}
if (nanCheck) {
@ -3353,21 +3435,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
// operations where `inputs' and `outputs' are not the same.
//
// See [Sync Streams].
if (!avoidRecordStreams) {
for (const auto& input : inputs) {
if (!input.is_sparse()) {
c10::cuda::CUDACachingAllocator::recordStream(
input.storage().data_ptr(), ncclStream);
} else {
// for sparse input case record streams on both index and value
// tensors
c10::cuda::CUDACachingAllocator::recordStream(
input.values().storage().data_ptr(), ncclStream);
c10::cuda::CUDACachingAllocator::recordStream(
input.indices().storage().data_ptr(), ncclStream);
}
}
}
// Not all collectives have the same signature, e.g, all-reduce take in a Tensor
// as the input and output while all-to-all take in a vector of Tensors as input
@ -3419,7 +3486,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
// Set appropriate work parameters.
work->blockingWait_ = blockingWait_;
work->avoidRecordStreams_ = avoidRecordStreams;
work->store_ = store_;
assignTimeoutToWork(work, options_);
// Record size info for debug. We only record the size on the first device as
@ -3437,7 +3503,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
workEnqueue(work);
}
return work;
return asyncOp ? work : nullptr;
}
template <typename Fn>
@ -3446,11 +3512,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
std::vector<at::Tensor>& outputs,
Fn fn,
OpType opType,
const char* profilingTitle,
bool avoidRecordStreams) {
// Environment setting by the user may add onto collective call's option
avoidRecordStreams |= avoidRecordStreams_;
bool asyncOp,
const char* profilingTitle) {
// Currently, the API permits one scenario where inputs.size() and
// outputs.size() are > 0.
// 1. If the call was a _coalesced call, all inputs must be on the same
@ -3496,13 +3559,17 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
} else {
TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG);
}
coalescedAsync_ = asyncOp;
}
// Used many times below, so we stash the unordered_map lookup
auto ncclStream = ncclStreams_.at(key);
// First let NCCL streams wait for input tensors allocation streams
syncStream(device, ncclEvents_[key], ncclStream);
// In asyncOp=false [default] mode, we use currentStream as ncclStream;
// otherwise, we use a separate ncclStream and let it sync on currentStream.
auto ncclStream = asyncOp ? ncclStreams_.at(key)
: at::cuda::getCurrentCUDAStream(device.index());
if (asyncOp) {
// First let NCCL streams wait for input tensors allocation streams
syncStream(device, ncclEvents_[key], ncclStream);
}
auto work = initWork(
device,
@ -3517,9 +3584,12 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
// Store references to outputs to be used by WorkNCCL::result and operator<<.
work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
if (avoidRecordStreams) {
work->stashed_for_allocator_safety_ =
std::make_shared<std::vector<at::Tensor>>(inputs);
// If we are performing sync operations, i.e. enqueuing the kernel onto the
// "current" stream, we don't need to do anything for tensor lifetime
// management. Otherwise, we need to stash the tensors until `work.wait()`.
if (asyncOp) {
work->stashed_for_allocator_safety_->stash(inputs);
work->stashed_for_allocator_safety_->stash(outputs);
}
// Start event should only be recorded before the ncclGroupStart() (which
@ -3545,27 +3615,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
{
torch::cuda::nccl::AutoNcclGroup nccl_group_guard(comm, useNonblocking());
for (const auto i : c10::irange(inputs.size())) {
// Both `inputs' and `outputs' are created on a worker stream and used in
// different ncclStreams. Hence, both must record the ncclStream to
// prevent being freed before the collective finishes.
//
// We only record `inputs' here, and leave recording `outputs' to `fn' for
// operations where `inputs' and `outputs' are not the same.
//
// See [Sync Streams].
if (!avoidRecordStreams) {
if (!inputs[i].is_sparse()) {
c10::cuda::CUDACachingAllocator::recordStream(
inputs[i].storage().data_ptr(), ncclStream);
} else {
// for sparse input case record streams on both index and value
// tensors
c10::cuda::CUDACachingAllocator::recordStream(
inputs[i].values().storage().data_ptr(), ncclStream);
c10::cuda::CUDACachingAllocator::recordStream(
inputs[i].indices().storage().data_ptr(), ncclStream);
}
}
#ifndef NCCL_HAS_COMM_NONBLOCKING
C10D_NCCL_CHECK(
fn(inputs[i], outputs[i], comm, ncclStream),
@ -3606,7 +3655,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
// Set appropriate work parameters.
work->blockingWait_ = blockingWait_;
work->avoidRecordStreams_ = avoidRecordStreams;
work->store_ = store_;
assignTimeoutToWork(work, options_);
// Record size info for debug. We only record the size on the first device as
@ -3637,7 +3685,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
// it, since interactions with it by usercode won't behave normally - they
// won't observe work completion, for instance. Will this lead to silent
// problems during capture?
return work;
return asyncOp ? work : nullptr;
}
template <typename Fn, typename PreProcess, typename PostProcess>
@ -3655,13 +3703,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
// to wait() on the returned handle, so ProcessGroupNCCL can't know
// when it's safe to release the input back to the allocator,
// and the present call has no way to know it's not an isend.
// Therefore, we warn and fall back to the typical recordStream logic:
if (avoidRecordStreams_) {
TORCH_WARN_ONCE(
"TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point "
"collectives.");
}
// Therefore, we warn and fall back to the typical recordStream logic.
// TODO( kwen2501 ): revisit this when we have a better solution.
auto device = getDevice(tensor);
at::cuda::OptionalCUDAGuard gpuGuard(device);
@ -3716,6 +3759,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
} else {
TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG);
}
// For now, P2P ops are always put on internal stream
coalescedAsync_ = true;
}
// Used many times below, so we stash the unordered_map lookup
@ -3887,8 +3932,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
PreProcess pre,
PostProcess post,
OpType opType,
bool asyncOp,
const char* profilingTitle,
bool avoidRecordStreams,
bool nanCheck) {
auto inputs = std::vector<at::Tensor>{input};
auto outputs = std::vector<at::Tensor>{output};
@ -3899,8 +3944,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
pre,
post,
opType,
asyncOp,
profilingTitle,
avoidRecordStreams,
nanCheck);
}
@ -3910,8 +3955,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
at::Tensor& output,
Fn fn,
OpType opType,
bool asyncOp,
const char* profilingTitle,
bool avoidRecordStreams,
bool nanCheck) {
auto inputs = std::vector<at::Tensor>{input};
auto outputs = std::vector<at::Tensor>{output};
@ -3924,8 +3969,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
[](at::cuda::CUDAStream&,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
opType,
asyncOp,
profilingTitle,
avoidRecordStreams,
nanCheck);
}
@ -3977,6 +4022,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_sparse(
auto recvIndices = indices[0] * colSize;
// prevent output and recvIndices from being freed
// TODO: not changing the lifetime management of outputs this time,
// revisit later
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
c10::cuda::CUDACachingAllocator::recordStream(
@ -4008,6 +4055,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_sparse(
}
},
OpType::_ALLREDUCE_SPARSE,
opts.asyncOp,
"nccl:all_reduce_sparse");
return work;
#else
@ -4042,6 +4090,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_impl(
stream.stream());
},
OpType::ALLREDUCE,
opts.asyncOp,
profilingTitle);
}
@ -4142,6 +4191,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_coalesced(
stream.stream());
},
OpType::COALESCED,
opts.asyncOp,
"nccl:allreduce_coalesced");
}
@ -4173,12 +4223,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::broadcast(
globalRankStride_, // globalRankStride_
this->getSize()); // worldSize
// avoidRecordStreams_ note: collective() will stash tensors.
bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
const auto root = opts.rootRank + opts.rootTensor;
bool nanCheck = (root == rank_);
// avoidRecordStreams_ note: collective() will stash tensors.
return collective(
tensor,
tensor,
@ -4195,8 +4243,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::broadcast(
stream.stream());
},
OpType::BROADCAST,
opts.asyncOp,
"nccl:broadcast",
avoidRecordStreams,
nanCheck);
}
@ -4235,8 +4283,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_broadcast_oop(
stream.stream());
},
OpType::BROADCAST,
opts.asyncOp,
"nccl:_broadcast_oop",
/*avoidRecordStreams=*/false,
nanCheck);
}
@ -4295,6 +4343,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce(
stream.stream());
},
OpType::REDUCE,
opts.asyncOp,
"nccl:reduce");
}
@ -4336,6 +4385,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_oop(
stream.stream());
},
OpType::REDUCE,
opts.asyncOp,
"nccl:_reduce_oop");
}
@ -4379,10 +4429,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
// See [We actually don't need to stash anything here].
return ncclAllGather(
input.data_ptr(),
output.data_ptr(),
@ -4398,27 +4445,27 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
// - inputTensors is stashed onto work->stashed_for_allocator_safety_
// in collective().
// - outputFlattened is stashed onto work->outputs_ in collective().
// - User-facing outputTensors should be held by the user until after
// waiting on work_, or the call makes no sense.
// So all participating tensors are accounted for, and won't be
// released back to their allocation streams until after work_ is
// waited on.
},
[&](at::cuda::CUDAStream& ncclStream,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
// User-facing outputTensors should be held by the user until after
// waiting on work_, or the call makes no sense. We do a stashing here
// in case the user doesn't hold the outputTensors in downstream code,
// which can cause an early recycle by the CachingAllocator and lead to
// a segfault or data corruption.
if (opts.asyncOp) {
work->stashed_for_allocator_safety_->stash(outputTensors_);
}
// Copy the flattened output tensors to the outputs.
at::cuda::CUDAStreamGuard guard(ncclStream);
for (const auto j : c10::irange(outputTensors_.size())) {
// See [Sync Streams].
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
outputTensors_[j].storage().data_ptr(), ncclStream);
}
// See [We actually don't need to stash anything here].
outputTensors_[j].copy_(
outputFlattened[static_cast<int64_t>(j)], true);
}
},
OpType::ALLGATHER,
opts.asyncOp,
"nccl:all_gather");
} else {
const auto num_reduces = outputTensors_.size();
@ -4426,7 +4473,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
for (const int64_t i : c10::irange(static_cast<int64_t>(num_reduces))) {
auto& output = outputTensors_[i];
auto& input = (i == rank_) ? inputTensor : output;
auto broadcastOpts = BroadcastOptions{i, int64_t(0), opts.timeout};
auto broadcastOpts =
BroadcastOptions{i, int64_t(0), opts.timeout, opts.asyncOp};
_broadcast_oop(output, input, broadcastOpts);
}
auto work = endCoalescing(OpType::ALLGATHER);
@ -4482,6 +4530,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather_into_tensor_coalesced(
stream.stream());
},
OpType::COALESCED,
opts.asyncOp,
"nccl:all_gather_into_tensor_coalesced");
}
@ -4527,10 +4576,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
const auto ncclDataType = getNcclDataType(input.scalar_type());
const auto ncclReduceOp =
getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
@ -4545,27 +4590,18 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
},
[&](at::cuda::CUDAStream& ncclStream,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
if (avoidRecordStreams_) {
// We only need to stash inputTensors.
// - inputFlattened is stashed onto
// work->stashed_for_allocator_safety_
// in collective().
// - User-facing outputTensors is stashed onto work->outputs_ in
// collective(),
// and should also be held by the user until after waiting on
// work_.
auto& v = work->stashed_for_allocator_safety_;
v->insert(v->end(), inputTensors_.begin(), inputTensors_.end());
// We only need to stash inputTensors.
// - inputFlattened is stashed onto
// work->stashed_for_allocator_safety_ in collective().
// - User-facing outputTensors is stashed onto work->outputs_ in
// collective(), and should also be held by the user until after
// waiting on work_.
if (opts.asyncOp) {
work->stashed_for_allocator_safety_->stash(inputTensors_);
}
// Copy the input tensors to the flattened inputs.
at::cuda::CUDAStreamGuard guard(ncclStream);
for (const auto j : c10::irange(inputTensors_.size())) {
// See [Sync Streams].
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
inputTensors_[j].storage().data_ptr(), ncclStream);
}
inputFlattened[static_cast<int64_t>(j)].copy_(
inputTensors_[j], true);
}
@ -4573,6 +4609,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
[&](at::cuda::CUDAStream&,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
OpType::REDUCE_SCATTER,
opts.asyncOp,
"nccl:reduce_scatter");
} else {
const auto num_reduces = inputTensors_.size();
@ -4584,7 +4621,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
opts.reduceOp,
static_cast<int64_t>(i),
static_cast<int64_t>(0),
opts.timeout};
opts.timeout,
opts.asyncOp};
_reduce_oop(output, input, reduceOpts);
}
auto work = endCoalescing(OpType::REDUCE_SCATTER);
@ -4638,7 +4676,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
// stream so that the caching allocator can reuse memory pool for this stream
// in a clever way. This setting is added for libraries like FSDP which uses
// `reduce_scatter_tensor`.
bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
return collective(
inputTensor,
@ -4647,10 +4684,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
if (!avoidRecordStreams) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
auto ncclDataType = getNcclDataType(input.scalar_type());
auto ncclReduceOp =
getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
@ -4664,8 +4697,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
stream.stream());
},
OpType::_REDUCE_SCATTER_BASE,
"nccl:_reduce_scatter_base",
avoidRecordStreams);
opts.asyncOp,
"nccl:_reduce_scatter_base");
}
c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced(
@ -4702,10 +4735,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced(
at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
auto ncclDataType = getNcclDataType(input.scalar_type());
auto ncclReduceOp =
getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
@ -4719,6 +4748,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced(
stream.stream());
},
OpType::COALESCED,
opts.asyncOp,
"nccl:reduce_scatter_tensor_coalesced");
}
@ -4797,13 +4827,28 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::barrier(const BarrierOptions& opts) {
at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat));
// All reduce to achieve the barrier
auto work = allreduce_impl(barrierTensor, "nccl:all_reduce_barrier");
AllreduceOptions arOpts = AllreduceOptions();
arOpts.asyncOp = opts.asyncOp;
auto work = allreduce_impl(barrierTensor, "nccl:all_reduce_barrier", arOpts);
// Work will take over barrierTensors
auto ncclWork = dynamic_cast<ProcessGroupNCCL::WorkNCCL*>(work.get());
TORCH_CHECK(ncclWork);
ncclWork->isBarrierOp_ = true;
return work;
if (opts.asyncOp) {
// Work will take over barrierTensors
auto ncclWork = dynamic_cast<ProcessGroupNCCL::WorkNCCL*>(work.get());
// If the user specified async, the work should not be nullptr
TORCH_CHECK(ncclWork);
// Put a marker here so that `work.wait()` issued by the user does the
// barrier-specific thing: a CPU sync
ncclWork->isBarrierOp_ = true;
return work;
}
// Otherwise, we are in sync mode; we directly wait here.
// (It is a CPU wait for the barrier.)
auto currentStream = at::cuda::getCurrentCUDAStream(barDevIdx);
// CUDAStream wrapper will correctly use a DeviceGuard here
currentStream.synchronize();
// No work to return
return nullptr;
}
c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
@ -4811,7 +4856,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
at::Tensor& inputTensor,
std::vector<int64_t>& outputSplitSizes,
std::vector<int64_t>& inputSplitSizes,
const AllToAllOptions& /* unused */) {
const AllToAllOptions& opts) {
check_gpu_single_tensor(outputTensor);
check_gpu_single_tensor(inputTensor);
if (outputSplitSizes.empty() && inputSplitSizes.empty()) {
@ -4842,16 +4887,12 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
// See [Sync Streams].
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
torch::cuda::nccl::all2all_single_equal_split(
input, output, this->getSize(), comm, stream);
return ncclSuccess;
},
OpType::ALLTOALL_BASE,
opts.asyncOp,
"nccl:all_to_all");
} else {
c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
@ -4893,10 +4934,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
c10d::computeLengthsAndOffsets(
outputSplitSizes, output, &recv_lengths, &recv_offsets);
// See [Sync Streams].
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
torch::cuda::nccl::all2all_single_unequal_split(
input.data_ptr(),
send_lengths.data(),
@ -4911,6 +4948,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
return ncclSuccess;
},
OpType::ALLTOALL_BASE,
opts.asyncOp,
"nccl:all_to_all");
}
}
@ -4918,7 +4956,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall(
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllToAllOptions& /* unused */) {
const AllToAllOptions& opts) {
int64_t input_total_numel = 0;
int64_t output_total_numel = 0;
@ -4963,18 +5001,11 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall(
return ncclSuccess;
},
[&](at::cuda::CUDAStream&,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
if (avoidRecordStreams_) {
// inputTensor0 and outputTensor0 are stashed redundantly by
// collective(), but that's ok.
auto& v = work->stashed_for_allocator_safety_;
v->insert(v->end(), inputTensors.begin(), inputTensors.end());
v->insert(v->end(), outputTensors.begin(), outputTensors.end());
}
},
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
[](at::cuda::CUDAStream&,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
OpType::ALLTOALL,
opts.asyncOp,
"nccl:all_to_all");
}
@ -5172,14 +5203,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::gather(
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
const auto root = opts.rootRank;
if (getRank() == root) {
if (!avoidRecordStreams_) {
for (auto const& output : outputs) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
}
}
torch::cuda::nccl::gather(
inputTensor, outputs, comm, stream, static_cast<int32_t>(root));
return ncclSuccess;
@ -5189,6 +5212,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::gather(
[](at::cuda::CUDAStream&,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
OpType::GATHER,
opts.asyncOp,
"nccl:gather");
}
@ -5257,8 +5281,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::scatter(
// avoidRecordStreams_ note: collective() will stash outputTensors and
// inputs, which == inputTensors[0] on the root rank where it matters.
bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
const auto root = opts.rootRank;
bool nanCheck = (rank_ == root);
@ -5270,14 +5292,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::scatter(
at::Tensor& /* unused */,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
if (getRank() == root) {
if (!avoidRecordStreams) {
for (auto const& input : inputs) {
c10::cuda::CUDACachingAllocator::recordStream(
input.storage().data_ptr(), stream);
}
}
}
torch::cuda::nccl::scatter(
inputs, outputTensor, comm, stream, static_cast<int32_t>(root));
return ncclSuccess;
@ -5287,8 +5301,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::scatter(
[](at::cuda::CUDAStream&,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
OpType::SCATTER,
opts.asyncOp,
"nccl:scatter",
avoidRecordStreams,
nanCheck);
}
@ -5344,7 +5358,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base(
// stream so that the caching allocator can reuse memory pool for this stream
// in a clever way. This setting is added for libraries like FSDP which uses
// `all_gather_into_tensor`.
bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
return collective(
input_tensor,
@ -5353,10 +5366,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base(
at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
if (!avoidRecordStreams) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
return ncclAllGather(
input.data_ptr(),
output.data_ptr(),
@ -5366,8 +5375,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base(
stream.stream());
},
OpType::_ALLGATHER_BASE,
"nccl:_all_gather_base",
avoidRecordStreams);
opts.asyncOp,
"nccl:_all_gather_base");
}
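Since `all_gather_into_tensor` is the path that libraries like FSDP exercise, here is a hedged usage sketch of both modes under this scheme (shapes and device setup are illustrative; an initialized NCCL group is assumed):

import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl") and torch.cuda.set_device(rank).
world_size = dist.get_world_size()
inp = torch.ones(4, device="cuda")
out = torch.empty(4 * world_size, device="cuda")

# Sync mode: launched on the current stream; no recordStream call is involved.
dist.all_gather_into_tensor(out, inp)

# Async mode: the returned Work stashes references to inp/out until wait().
work = dist.all_gather_into_tensor(out, inp, async_op=True)
work.wait()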
// Create a memory allocator for NCCL. This allocator is used to allocate memory


@ -235,6 +235,34 @@ struct DumpPipe {
};
#endif
// A shelf for stashing tensors between op call and `work.wait()`.
// Used in case of async ops.
class TensorShelf {
public:
// Stash tensors so that CachingAllocator cannot recycle them prematurely.
void stash(std::vector<at::Tensor>& tensors);
// Stash tensors from another shelf.
void stash(TensorShelf& other);
// Unstash the stashed tensors so that CachingAllocator can recycle them.
// Same as `clear()`.
void unstash();
// Whether shelf is empty.
bool empty();
// Clear the shelf.
void clear();
protected:
// Get the inner tensor vector. Use with caution as it is not protected by
// mutex.
std::vector<at::Tensor>& get();
private:
std::vector<at::Tensor> tVector_;
// Need a mutex to protect `tVector_` because it can be accessed from both
// the main thread and the watchdog thread.
std::mutex mutex_;
};
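For readers more familiar with Python, the shelf above is essentially a mutex-protected list of tensor references; the following is an illustrative Python analogue (a sketch only, not the actual C++ implementation):

import threading
import torch

class TensorShelfSketch:
    # Illustrative analogue of TensorShelf: hold tensor references between
    # op issue and work.wait() so the caching allocator cannot recycle them.

    def __init__(self) -> None:
        self._tensors: list[torch.Tensor] = []
        self._mutex = threading.Lock()

    def stash(self, tensors: list[torch.Tensor]) -> None:
        with self._mutex:
            self._tensors.extend(tensors)

    def unstash(self) -> None:
        # Same as clear(): drop the references so the storage can be reused.
        self.clear()

    def empty(self) -> bool:
        with self._mutex:
            return len(self._tensors) == 0

    def clear(self) -> None:
        with self._mutex:
            self._tensors.clear()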
// ProcessGroupNCCL implements NCCL bindings for c10d.
//
// All functions of the class are expected to be called in the same order
@ -382,9 +410,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Clone of blockingWait_ from ProcessGroupNCCL.
bool blockingWait_{false};
// Clone of avoidRecordStreams_ from ProcessGroupNCCL.
bool avoidRecordStreams_{false};
// Clone of opTimeout_ from ProcessGroupNCCL.
std::chrono::milliseconds opTimeout_{};
@ -448,7 +473,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// caching allocator safety without any recordStream calls.
// For in-place collectives, some refs stashed here may alias outputs_,
// but that doesn't do any harm.
std::shared_ptr<std::vector<at::Tensor>> stashed_for_allocator_safety_;
std::shared_ptr<TensorShelf> stashed_for_allocator_safety_;
// The future returned by getFuture.
c10::intrusive_ptr<at::ivalue::Future> future_;
@ -889,8 +914,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
at::Tensor& output,
Fn fn,
OpType opType,
bool asyncOp,
const char* profilingTitle = nullptr,
bool avoidRecordStreams = false,
bool nanCheck = true);
template <typename Fn, typename PreProcess, typename PostProcess>
@ -901,8 +926,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
PreProcess pre,
PostProcess post,
OpType opType,
bool asyncOp,
const char* profilingTitle = nullptr,
bool avoidRecordStreams = false,
bool nanCheck = true);
template <typename Fn, typename PreProcess, typename PostProcess>
@ -913,8 +938,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
PreProcess pre,
PostProcess post,
OpType opType,
bool asyncOp,
const char* profilingTitle = nullptr,
bool avoidRecordStreams = false,
bool nanCheck = true);
template <typename Fn>
@ -923,8 +948,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
std::vector<at::Tensor>& output,
Fn fn,
OpType opType,
const char* profilingTitle = nullptr,
bool avoidRecordStreams = false);
bool asyncOp,
const char* profilingTitle = nullptr);
// Helper that encapsulates work shared across point-to-point communication
// primitives. It is the same structure as the helper used for collective
@ -1233,6 +1258,22 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Stores communicators for all collectives run inside a coalescing block
std::shared_ptr<NCCLComm> coalescedComm_ = nullptr;
// Whether the coalesced calls are sync or async.
bool coalescedAsync_;
// Keeps track of input and output tensors while coalescing is in flight.
// These tensors are handed over to WorkNCCL's stash when coalescing ends.
TensorShelf coalescedTensors_;
// Some ops may have completed, but the user still hasn't called `work.wait()`.
// When the watchdog detects this, it transfers the TensorShelf from `work` to
// this `shelves` structure. The next time a ProcessGroupNCCL method runs on
// the main thread, we clear the shelves in one shot. We do it this way because
// having the watchdog (a side thread) unstash the shelf directly appears to
// cause problems.
std::vector<std::shared_ptr<TensorShelf>> shelvesToUnstash_;
std::mutex shelvesMutex_;
// Whether or not wait() and synchronize() are blocking operations that wait
// for the operation to complete.
bool blockingWait_ = false;
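The watchdog-to-main-thread handoff described above follows a simple producer/consumer pattern; a hedged Python sketch of the same idea (class and method names here are illustrative; only shelvesToUnstash_/shelvesMutex_ mirror real members):

import threading

class ShelfHandoffSketch:
    def __init__(self) -> None:
        self._shelves_to_unstash = []           # mirrors shelvesToUnstash_
        self._shelves_mutex = threading.Lock()  # mirrors shelvesMutex_

    def on_collective_completed(self, shelf) -> None:
        # Watchdog thread: queue the shelf; do not unstash from here.
        with self._shelves_mutex:
            self._shelves_to_unstash.append(shelf)

    def drain(self) -> None:
        # Main thread, at the start of ProcessGroup methods: clear in one shot.
        with self._shelves_mutex:
            pending, self._shelves_to_unstash = self._shelves_to_unstash, []
        for shelf in pending:
            shelf.unstash()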


@ -122,6 +122,7 @@ struct BroadcastOptions {
struct AllreduceOptions {
ReduceOp reduceOp = ReduceOp::SUM;
std::chrono::milliseconds timeout = kUnsetTimeout;
bool asyncOp = true;
std::optional<at::Tensor> sparseIndices = std::nullopt;
};
@ -132,6 +133,7 @@ struct ReduceOptions {
int64_t rootRank = 0;
int64_t rootTensor = 0;
std::chrono::milliseconds timeout = kUnsetTimeout;
bool asyncOp = true;
};
struct AllgatherOptions {
@ -142,6 +144,7 @@ struct AllgatherOptions {
struct GatherOptions {
int64_t rootRank = 0;
std::chrono::milliseconds timeout = kUnsetTimeout;
bool asyncOp = true;
};
struct ScatterOptions {
@ -158,12 +161,14 @@ struct ReduceScatterOptions {
struct AllToAllOptions {
std::chrono::milliseconds timeout = kUnsetTimeout;
bool asyncOp = true;
};
struct BarrierOptions {
std::vector<int64_t> device_ids;
std::chrono::milliseconds timeout = kUnsetTimeout;
std::optional<at::Device> device;
bool asyncOp = true;
};
struct DistributedBackendOptions {


@ -999,20 +999,23 @@ This class does not support ``__members__`` property.)");
py::class_<::c10d::AllreduceOptions>(module, "AllreduceOptions")
.def(py::init<>())
.def_readwrite("reduceOp", &::c10d::AllreduceOptions::reduceOp)
.def_readwrite("timeout", &::c10d::AllreduceOptions::timeout);
.def_readwrite("timeout", &::c10d::AllreduceOptions::timeout)
.def_readwrite("asyncOp", &::c10d::AllreduceOptions::asyncOp);
py::class_<::c10d::AllreduceCoalescedOptions>(
module, "AllreduceCoalescedOptions")
.def(py::init<>())
.def_readwrite("reduceOp", &::c10d::AllreduceCoalescedOptions::reduceOp)
.def_readwrite("timeout", &::c10d::AllreduceCoalescedOptions::timeout);
.def_readwrite("timeout", &::c10d::AllreduceCoalescedOptions::timeout)
.def_readwrite("asyncOp", &::c10d::AllreduceCoalescedOptions::asyncOp);
py::class_<::c10d::ReduceOptions>(module, "ReduceOptions")
.def(py::init<>())
.def_readwrite("reduceOp", &::c10d::ReduceOptions::reduceOp)
.def_readwrite("rootRank", &::c10d::ReduceOptions::rootRank)
.def_readwrite("rootTensor", &::c10d::ReduceOptions::rootTensor)
.def_readwrite("timeout", &::c10d::ReduceOptions::timeout);
.def_readwrite("timeout", &::c10d::ReduceOptions::timeout)
.def_readwrite("asyncOp", &::c10d::ReduceOptions::asyncOp);
py::class_<::c10d::AllgatherOptions>(module, "AllgatherOptions")
.def(py::init<>())
@ -1022,7 +1025,8 @@ This class does not support ``__members__`` property.)");
py::class_<::c10d::GatherOptions>(module, "GatherOptions")
.def(py::init<>())
.def_readwrite("rootRank", &::c10d::GatherOptions::rootRank)
.def_readwrite("timeout", &::c10d::GatherOptions::timeout);
.def_readwrite("timeout", &::c10d::GatherOptions::timeout)
.def_readwrite("asyncOp", &::c10d::GatherOptions::asyncOp);
py::class_<::c10d::ScatterOptions>(module, "ScatterOptions")
.def(py::init<>())
@ -1040,11 +1044,13 @@ This class does not support ``__members__`` property.)");
.def(py::init<>())
.def_readwrite("device_ids", &::c10d::BarrierOptions::device_ids)
.def_readwrite("timeout", &::c10d::BarrierOptions::timeout)
.def_readwrite("device", &::c10d::BarrierOptions::device);
.def_readwrite("device", &::c10d::BarrierOptions::device)
.def_readwrite("asyncOp", &::c10d::BarrierOptions::asyncOp);
py::class_<::c10d::AllToAllOptions>(module, "AllToAllOptions")
.def(py::init<>())
.def_readwrite("timeout", &::c10d::AllToAllOptions::timeout);
.def_readwrite("timeout", &::c10d::AllToAllOptions::timeout)
.def_readwrite("asyncOp", &::c10d::AllToAllOptions::asyncOp);
py::class_<::c10d::DistributedBackendOptions>(
module, "_DistributedBackendOptions")


@ -2501,7 +2501,7 @@ class _CoalescingManager:
def __init__(self) -> None:
self.works: list[Work] = []
def append(self, work: Work):
def append(self, work: Optional[Work] = None):
if work:
self.works.append(work)
@ -2514,7 +2514,7 @@ class _CoalescingManager:
def _coalescing_manager(
group: Optional[ProcessGroup] = None,
device: Optional[torch.device] = None,
async_ops: Optional[bool] = False,
async_ops: bool = False,
):
"""
Context manager used to coalesce collectives or P2P operations when possible.
@ -2553,6 +2553,7 @@ def _coalescing_manager(
group._start_coalescing(device)
cm = _CoalescingManager()
yield cm
work = None
op_list = _world.pg_coalesce_state.pop(group)
if op_list:
# Collectives supporting "Fast Path" coalescing are captured.
@ -2566,6 +2567,7 @@ def _coalescing_manager(
tensors = [op.tensor for op in op_list]
all_reduce_opts = AllreduceCoalescedOptions()
all_reduce_opts.reduceOp = not_none(op_list[0].redop)
all_reduce_opts.asyncOp = async_ops
work = group.allreduce_coalesced(tensors, all_reduce_opts)
elif op0 == all_gather_into_tensor:
inputs = []
@ -2573,6 +2575,8 @@ def _coalescing_manager(
for op in op_list:
inputs.append(op.tensor)
outputs.append(not_none(op.dst_tensor))
all_gather_opts = AllgatherOptions()
all_gather_opts.asyncOp = async_ops
work = group.allgather_into_tensor_coalesced(outputs, inputs)
elif op0 == reduce_scatter_tensor:
inputs = []
@ -2582,6 +2586,7 @@ def _coalescing_manager(
outputs.append(not_none(op.dst_tensor))
reduce_opts = ReduceScatterOptions()
reduce_opts.reduceOp = not_none(op_list[0].redop)
reduce_opts.asyncOp = async_ops
work = group.reduce_scatter_tensor_coalesced(outputs, inputs, reduce_opts)
else:
raise AssertionError(
@ -2594,9 +2599,12 @@ def _coalescing_manager(
work = group._end_coalescing(device)
if async_ops:
cm.append(work) # type: ignore[possibly-undefined]
else:
work.wait() # type: ignore[possibly-undefined]
cm.append(work)
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait()
# Otherwise, the backend has sync'ed at CPP level
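For reference, a usage sketch of the coalescing manager with the new async_ops plumbing (tensor shapes are illustrative; an initialized NCCL group and a set CUDA device are assumed):

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _coalescing_manager

tensors = [torch.ones(1024, device="cuda") for _ in range(4)]

# Sync mode: the coalesced work is waited on (or synced at the C++ level)
# when the context manager exits.
with _coalescing_manager():
    for t in tensors:
        dist.all_reduce(t)

# Async mode: the caller owns the wait on the coalesced work.
with _coalescing_manager(async_ops=True) as cm:
    for t in tensors:
        dist.all_reduce(t, async_op=True)
cm.wait()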
class _TimeEstimator:
@ -2772,8 +2780,11 @@ def broadcast(
work = group.broadcast([tensor], opts)
if async_op:
return work
else:
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger
@ -2853,6 +2864,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
opts = AllreduceOptions()
opts.reduceOp = op
opts.asyncOp = async_op
if group is None:
group = _get_default_group()
@ -2869,8 +2881,11 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
if async_op:
return work
else:
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait()
# Otherwise, the backend has sync'ed at CPP level
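The recurring `elif work is not None: work.wait()` pattern keeps the user-visible contract unchanged; a minimal sketch of both modes (assuming an initialized NCCL group and a CUDA device set per rank):

import torch
import torch.distributed as dist

t = torch.ones(8, device="cuda")

# async_op=False: the collective is launched on the current stream; any
# needed sync happens at the C++ level and the call returns None.
dist.all_reduce(t)

# async_op=True: a Work handle is returned and the caller owns the wait.
work = dist.all_reduce(t, async_op=True)
work.wait()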
@_exception_logger
@ -2929,13 +2944,17 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):
opts = AllreduceCoalescedOptions()
opts.reduceOp = op
opts.asyncOp = async_op
group = group or _get_default_group()
work = group.allreduce_coalesced(tensors, opts)
if async_op:
return work.get_future()
else:
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger
@ -2980,11 +2999,15 @@ def reduce(
opts = ReduceOptions()
opts.reduceOp = op
opts.rootRank = group_dst
opts.asyncOp = async_op
work = group.reduce([tensor], opts)
if async_op:
return work
else:
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait()
# Otherwise, the backend has sync'ed at CPP level
def _object_to_tensor(obj, device, group):
@ -3783,12 +3806,17 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):
tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)
group = group or _get_default_group()
work = group.allgather([tensor_list], [tensor])
opts = AllgatherOptions()
opts.asyncOp = async_op
work = group.allgather([tensor_list], [tensor], opts)
if async_op:
return work
else:
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger
@ -3891,8 +3919,11 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal
if async_op:
return work
else:
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger
@ -4002,12 +4033,17 @@ def all_gather_coalesced(
]
group = group or _get_default_group()
work = group.allgather_coalesced(output_tensor_lists, input_tensor_list)
opts = AllgatherOptions()
opts.asyncOp = async_op
work = group.allgather_coalesced(output_tensor_lists, input_tensor_list, opts)
if async_op:
return work.get_future()
else:
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait()
# Otherwise, the backend has sync'ed at CPP level
def _validate_output_list_for_rank(my_rank, dst, gather_list):
@ -4093,12 +4129,16 @@ def gather(
opts = GatherOptions()
opts.rootRank = group_dst
opts.asyncOp = async_op
work = group.gather(output_tensors, input_tensors, opts)
if async_op:
return work
else:
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger
@ -4199,8 +4239,11 @@ def scatter(
if async_op:
return work
else:
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger
@ -4232,14 +4275,18 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=Fal
opts = ReduceScatterOptions()
opts.reduceOp = op
opts.asyncOp = async_op
group = group or _get_default_group()
work = group.reduce_scatter([output], [input_list], opts)
if async_op:
return work
else:
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger
@ -4336,8 +4383,11 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F
if async_op:
return work
else:
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait()
# Otherwise, the backend has sync'ed at CPP level
@deprecated(
@ -4490,6 +4540,7 @@ def all_to_all_single(
return
opts = AllToAllOptions()
opts.asyncOp = async_op
_check_single_tensor(output, "output")
_check_single_tensor(input, "input")
_ensure_all_tensors_same_dtype(output, input)
@ -4509,8 +4560,11 @@ def all_to_all_single(
if async_op:
return work
else:
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger
@ -4611,6 +4665,7 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
return
opts = AllToAllOptions()
opts.asyncOp = async_op
_check_tensor_list(output_tensor_list, "output_tensor_list")
_check_tensor_list(input_tensor_list, "input_tensor_list")
_ensure_all_tensors_same_dtype(output_tensor_list, input_tensor_list)
@ -4627,8 +4682,11 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
if async_op:
return work
else:
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger
@ -4659,6 +4717,7 @@ def barrier(
opts = BarrierOptions()
opts.device = torch.device(_get_object_coll_device(group))
opts.asyncOp = async_op
if device_ids is not None:
if isinstance(device_ids, list):
opts.device_ids = device_ids
@ -4672,8 +4731,11 @@ def barrier(
if async_op:
return work
else:
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait()
# Otherwise, the backend has sync'ed at CPP level
def monitored_barrier(


@ -96,7 +96,7 @@ try:
import torchvision
HAS_TORCHVISION = True
except ImportError:
except Exception: # Covering both ImportError and RuntimeError
HAS_TORCHVISION = False
if sys.platform == "win32":
@ -8310,50 +8310,14 @@ class DistributedTest:
def test_compute_bucket_assignment_by_size_sparse_error_with_logger(self):
self._test_compute_bucket_assignment_by_size(use_logger=True)
def _determine_expected_error_verify_model_across_rank(
self, group_to_use, diff_num_params=False
):
# When running with NCCL backend, we don't expect an error on rank 0,
# rather, it will be taken down by TORCH_NCCL_ASYNC_ERROR_HANDLING. When
# running with Gloo or with debug mode wrapper, we expect the error
# to be caught inline.
# All ranks report same error when there is a # of parameter
# mismatch since we use allgather in the impl.
if diff_num_params:
expected_err = "DDP expects same model across all ranks"
ctx = self.assertRaisesRegex(RuntimeError, expected_err)
return ctx, expected_err
is_detail_dbg_mode = dist.get_debug_level() == dist.DebugLevel.DETAIL
if self.rank == 0:
if (
dist.get_backend(group_to_use) == dist.Backend.NCCL
and not is_detail_dbg_mode
):
expected_err = "caught collective operation timeout"
ctx = self.assertRaisesRegex(RuntimeError, expected_err)
else:
expected_err = None
ctx = self.assertRaises(RuntimeError)
else:
expected_err = "appears not to match"
ctx = self.assertRaisesRegex(RuntimeError, expected_err)
return ctx, expected_err
def _test_verify_model_across_rank(self, use_logger):
group_gloo = dist.new_group(
timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
)
# Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
# determinism.
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
group_to_use = dist.new_group(
backend=dist.get_backend(), timeout=timedelta(seconds=5)
)
torch.cuda.set_device(self.rank)
ctx, expected_err = self._determine_expected_error_verify_model_across_rank(
group_to_use
)
# Create a valid model. The constructor initializes the logger that we use later.
net = EmbeddingNetDifferentParams(0)
@ -8371,7 +8335,8 @@ class DistributedTest:
net.module.lin = nn.Linear(100 if self.rank == 0 else 10, 1)
# if we pass a logger we can verify that it was logged
with ctx:
caught = 0
try:
if use_logger:
_verify_param_shape_across_processes(
net.process_group, list(net.parameters()), net.logger
@ -8380,18 +8345,13 @@ class DistributedTest:
_verify_param_shape_across_processes(
net.process_group, list(net.parameters())
)
# Should only be run by rank 0, and blocking_wait catches and
# reports exception.
dist.barrier(group_to_use)
except Exception:
caught = 1
# We don't check when self.rank != 0 because the logger doesn't log
# the error "Caught collective operation" as that is not thrown in the reducer.
if use_logger and self.rank != 0:
verify_ddp_error_logged(net, expected_err)
# Perform gloo-based barrier to ensure one rank doesn't exit test
# early which causes failure with Barrier.sync.
dist.barrier(group_gloo)
# As long as there is one rank catching the exception
t = torch.Tensor([caught])
dist.all_reduce(t, group=group_gloo)
self.assertGreater(t, 0)
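The reworked tests rely on a simple cross-rank aggregation trick: each rank records whether it caught an exception, then an all_reduce over the gloo group asserts that at least one rank did. A standalone sketch of that pattern (names are illustrative):

import torch
import torch.distributed as dist

def assert_some_rank_raised(fn, group_gloo) -> None:
    # Pass as long as at least one rank raised while running fn().
    caught = 0
    try:
        fn()
    except Exception:
        caught = 1
    # The gloo all_reduce also acts as a barrier, so no rank exits early.
    t = torch.tensor([caught], dtype=torch.float32)
    dist.all_reduce(t, group=group_gloo)
    assert t.item() > 0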
@require_backend_is_available(DistTestCases.backend_feature["gpu"])
@skip_but_pass_in_sandcastle_if(
@ -8409,20 +8369,19 @@ class DistributedTest:
def test_verify_model_across_rank_without_logger(self):
self._test_verify_model_across_rank(use_logger=False)
def _run_test_ddp_model_with_diff_params(self, ctx, net, ddp_group, group_gloo):
with ctx:
def _run_test_ddp_model_with_diff_params(self, net, ddp_group, group_gloo):
caught = 0
try:
net = torch.nn.parallel.DistributedDataParallel(
net.to(self.rank), device_ids=[self.rank], process_group=ddp_group
)
# Should only be run by rank 0, and blocking_wait catches and
# reports exception.
dist.barrier(ddp_group)
except Exception:
caught = 1
# can't use verify_ddp_error_logged here because net was never properly constructed
# Perform gloo-based barrier to ensure one rank doesn't exit test
# early which causes failure with Barrier.sync.
dist.barrier(group_gloo)
# As long as there is one rank catching the exception
t = torch.Tensor([caught])
dist.all_reduce(t, group=group_gloo)
self.assertGreater(t, 0)
@require_backend_is_available(DistTestCases.backend_feature["gpu"])
@skip_but_pass_in_sandcastle_if(
@ -8433,21 +8392,15 @@ class DistributedTest:
group_gloo = dist.new_group(
timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
)
# Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
# determinism.
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
group_to_use = dist.new_group(
backend=dist.get_backend(), timeout=timedelta(seconds=10)
)
torch.cuda.set_device(self.rank)
ctx, _expected_err = self._determine_expected_error_verify_model_across_rank(
group_to_use
)
# Creates network with different sized embedding table on different
# ranks. This should throw an error during DDP init.
net = EmbeddingNetDifferentParams(self.rank)
self._run_test_ddp_model_with_diff_params(
ctx, net, group_to_use, group_gloo
net, group_to_use, group_gloo
)
@require_backend_is_available(DistTestCases.backend_feature["gpu"])
@ -8459,16 +8412,10 @@ class DistributedTest:
group_gloo = dist.new_group(
timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
)
# Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
# determinism.
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
group_to_use = dist.new_group(
backend=dist.get_backend(), timeout=timedelta(seconds=10)
)
torch.cuda.set_device(self.rank)
ctx, _expected_err = self._determine_expected_error_verify_model_across_rank(
group_to_use, diff_num_params=True
)
# Creates network with diff # of param across ranks, reducer should
# recognize this and throw appropriate error.
@ -8477,7 +8424,6 @@ class DistributedTest:
)
self._run_test_ddp_model_with_diff_params(
ctx,
net,
group_to_use,
group_gloo,