Revert "[PGNCCL] Launch kernel on current stream & remove record_stream entirely (#148590)"

This reverts commit ef6296e7f20d744a0cfed81cab573d60204e7626.

Reverted https://github.com/pytorch/pytorch/pull/148590 on behalf of https://github.com/izaitsevfb because the change was reverted internally; see D71292427 ([comment](https://github.com/pytorch/pytorch/pull/148590#issuecomment-2731114626))
Author: PyTorch MergeBot
Date:   2025-03-17 22:43:15 +00:00
parent a16ada41b9
commit afa1eda901
11 changed files with 362 additions and 411 deletions
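
The reverted PR (#148590) made sync-mode (async_op=False) NCCL collectives launch on the caller's current CUDA stream and replaced recordStream-based tensor lifetime management with stashing tensors on the Work object; this revert restores the previous behavior. Below is a minimal, hedged sketch of the user-facing API the diff touches (assumes the script is launched via torchrun with a CUDA device; the comments describe intent, not a guaranteed contract):

```python
import torch
import torch.distributed as dist

# Assumes torchrun launched this script and a CUDA device is available.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

t = torch.ones(4, device="cuda")

# async_op=True: the collective returns a Work handle; the tensors must
# stay alive until wait() completes.
work = dist.all_reduce(t, async_op=True)
work.wait()

# async_op=False: blocking from the caller's point of view. After this
# revert, the NCCL backend again runs on its own stream and uses
# recordStream (see the C++ hunks below) to keep the caching allocator
# from recycling `t` prematurely.
dist.all_reduce(t)

dist.destroy_process_group()
```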

View File

@ -363,9 +363,6 @@ class TestDebugInfoWriter : public c10d::DebugInfoWriter {
};
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
// Note (kwen2501) 03/07/2025
// TODO: re-enable
GTEST_SKIP() << "Skipping test as the trace write seems unstable.";
int heartBeatIntervalInSec = 2;
std::string timeInterval = std::to_string(heartBeatIntervalInSec);
ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "0", 1) == 0);

View File

@ -126,9 +126,6 @@ ALLOW_LIST = [
("aten::reduce_scatter_tensor", datetime.date(9999, 1, 30)),
("aten::all_gather_into_tensor", datetime.date(9999, 1, 30)),
("aten::all_reduce", datetime.date(9999, 1, 30)),
# These ops are defined in torch/csrc/distributed/c10d/Ops.cpp
# TODO: add back restriction when c10d ops can be exported
("c10d::.*", datetime.date(9999, 1, 1)),
]
ALLOW_LIST_COMPILED = [

View File

@ -2,7 +2,7 @@
# mypy: disable-error-code="type-arg"
from datetime import timedelta
from enum import Enum
from typing import Any, Optional, overload
from typing import Any, overload
import torch
from torch import Tensor
@ -139,8 +139,6 @@ class BroadcastOptions:
class AllreduceOptions:
reduceOp: ReduceOp
timeout: timedelta
asyncOp: bool
sparseIndices: Optional[Tensor]
class AllreduceCoalescedOptions(AllreduceOptions): ...
@ -149,7 +147,6 @@ class ReduceOptions:
rootRank: int
rootTensor: int
timeout: timedelta
asyncOp: bool
class AllgatherOptions:
timeout: timedelta
@ -158,7 +155,6 @@ class AllgatherOptions:
class GatherOptions:
rootRank: int
timeout: timedelta
asyncOp: bool
class ScatterOptions:
rootRank: int
@ -174,11 +170,9 @@ class BarrierOptions:
device_ids: list[int]
device: torch.device
timeout: timedelta
asyncOp: bool
class AllToAllOptions:
timeout: timedelta
asyncOp: bool
class Store:
def set(self, key: str, value: str): ...
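
As a hedged illustration of the stub above, the option structs can be constructed and populated from Python via the same binding module; after this revert, asyncOp is no longer a field:

```python
from datetime import timedelta

# Bindings live in the module this stub file describes.
from torch._C._distributed_c10d import AllreduceOptions, ReduceOp

opts = AllreduceOptions()
opts.reduceOp = ReduceOp.SUM          # still a field
opts.timeout = timedelta(seconds=60)  # still a field
# opts.asyncOp = True                 # removed by this revert; AttributeError
```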

View File

@ -17,37 +17,37 @@ TORCH_LIBRARY(c10d, m) {
.def("wait", [](const c10::intrusive_ptr<Work>& self) { self->wait(); });
m.class_<ReduceOp>("ReduceOp").def(torch::init<>());
m.def(
"broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
"broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
m.def(
"allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
"allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
m.def(
"allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
"allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work");
m.def(
"allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor[][], __torch__.torch.classes.c10d.Work)");
"allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[][], __torch__.torch.classes.c10d.Work)");
m.def(
"_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)");
"_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)");
m.def(
"allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) -> __torch__.torch.classes.c10d.Work");
"allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work");
m.def(
"allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) -> __torch__.torch.classes.c10d.Work");
"allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work");
m.def(
"reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
"reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
m.def(
"_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)");
"_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)");
m.def(
"reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
"reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work");
m.def(
"reduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
"reduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int root_rank, int root_tensor, int timeout) -> __torch__.torch.classes.c10d.Work");
m.def(
"gather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
"gather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int timeout) -> __torch__.torch.classes.c10d.Work");
m.def(
"scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
"scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
m.def(
"alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
"alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
m.def(
"alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] output_split_sizes, int[] input_split_sizes, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
"alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] output_split_sizes, int[] input_split_sizes, int timeout) -> __torch__.torch.classes.c10d.Work");
m.def(
"barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
"barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout) -> __torch__.torch.classes.c10d.Work");
m.def(
"monitored_barrier_(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout, bool wait_all_ranks) -> ()");
m.def(
@ -118,7 +118,6 @@ IMPL_RECV_ANY_SOURCE(PrivateUse1)
const c10::intrusive_ptr<ReduceOp>& reduce_op, \
int64_t root_rank, \
int64_t root_tensor, \
bool asyncOp, \
int64_t timeout) { \
auto tensor_vec = tensors.vec(); \
return process_group->getBackend(c10::DeviceType::DEV) \
@ -128,8 +127,7 @@ IMPL_RECV_ANY_SOURCE(PrivateUse1)
*reduce_op.get(), \
root_rank, \
root_tensor, \
std::chrono::milliseconds(timeout), \
asyncOp}); \
std::chrono::milliseconds(timeout)}); \
}
IMPL_REDUCE(CPU)
@ -171,13 +169,12 @@ IMPL_BROADCAST(PrivateUse1)
const c10::intrusive_ptr<ProcessGroup>& process_group, \
const c10::intrusive_ptr<ReduceOp>& reduce_op, \
const std::optional<at::Tensor>& sparse_indices, \
bool asyncOp, \
int64_t timeout) { \
auto tensor_vec = tensors.vec(); \
auto work = process_group->getBackend(c10::DeviceType::DEV) -> allreduce( \
tensor_vec, \
AllreduceOptions{ \
*reduce_op.get(), std::chrono::milliseconds(timeout), asyncOp}); \
*reduce_op.get(), std::chrono::milliseconds(timeout)}); \
return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>( \
std::move(tensor_vec), work); \
}
@ -191,13 +188,11 @@ IMPL_ALLREDUCE(PrivateUse1)
at::TensorList tensors, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
const c10::intrusive_ptr<ReduceOp>& reduce_op, \
bool asyncOp, \
int64_t timeout) { \
auto tensor_vec = tensors.vec(); \
AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; \
opts.reduceOp = *reduce_op.get(); \
opts.timeout = std::chrono::milliseconds(timeout); \
opts.asyncOp = asyncOp; \
return process_group->getBackend(c10::DeviceType::DEV) \
->allreduce_coalesced(tensor_vec, opts); \
}
@ -214,13 +209,12 @@ IMPL_ALLREDUCE_COALESCED(PrivateUse1)
const std::vector<std::vector<at::Tensor>>& output_tensors, \
at::TensorList input_tensors, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
bool asyncOp, \
int64_t timeout) { \
auto input_tensors_vec = input_tensors.vec(); \
auto work = process_group->getBackend(c10::DeviceType::DEV) -> allgather( \
const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors), \
input_tensors_vec, \
AllgatherOptions{std::chrono::milliseconds(timeout), asyncOp}); \
AllgatherOptions{std::chrono::milliseconds(timeout)}); \
return std:: \
tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>( \
output_tensors, work); \
@ -255,16 +249,12 @@ IMPL__ALLGATHER_BASE(PrivateUse1)
c10::intrusive_ptr<Work> allgather_coalesced_##DEV( \
const std::vector<std::vector<at::Tensor>>& output_lists, \
const at::TensorList& input_list, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
bool asyncOp) { \
const c10::intrusive_ptr<ProcessGroup>& process_group) { \
auto input_list_vec = input_list.vec(); \
auto opts = AllgatherOptions{}; \
opts.asyncOp = asyncOp; \
return process_group->getBackend(c10::DeviceType::DEV) \
->allgather_coalesced( \
const_cast<std::vector<std::vector<at::Tensor>>&>(output_lists), \
input_list_vec, \
opts); \
input_list_vec); \
}
IMPL_ALLGATHER_COALESCED(CPU)
@ -275,14 +265,11 @@ IMPL_ALLGATHER_COALESCED(PrivateUse1)
c10::intrusive_ptr<c10d::Work> allgather_into_tensor_coalesced_##DEV( \
at::TensorList outputs, \
at::TensorList inputs, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
bool asyncOp) { \
const c10::intrusive_ptr<ProcessGroup>& process_group) { \
auto output_vec = outputs.vec(); \
auto input_vec = inputs.vec(); \
auto opts = AllgatherOptions{}; \
opts.asyncOp = asyncOp; \
return process_group->getBackend(c10::DeviceType::DEV) \
->allgather_into_tensor_coalesced(output_vec, input_vec, opts); \
->allgather_into_tensor_coalesced(output_vec, input_vec); \
}
IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CPU)
@ -296,7 +283,6 @@ IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1)
const std::vector<std::vector<at::Tensor>>& input_tensors, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
const c10::intrusive_ptr<ReduceOp>& reduce_op, \
bool asyncOp, \
int64_t timeout) { \
auto output_tensors_vec = output_tensors.vec(); \
auto work = \
@ -304,9 +290,7 @@ IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1)
output_tensors_vec, \
const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors), \
ReduceScatterOptions{ \
*reduce_op.get(), \
std::chrono::milliseconds(timeout), \
asyncOp}); \
*reduce_op.get(), std::chrono::milliseconds(timeout)}); \
return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>( \
output_tensors_vec, work); \
}
@ -345,7 +329,6 @@ IMPL__REDUCE_SCATTER_BASE(PrivateUse1)
at::TensorList inputs, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
const c10::intrusive_ptr<ReduceOp>& reduce_op, \
bool asyncOp, \
int64_t timeout) { \
auto output_vec = outputs.vec(); \
auto input_vec = inputs.vec(); \
@ -354,9 +337,7 @@ IMPL__REDUCE_SCATTER_BASE(PrivateUse1)
output_vec, \
input_vec, \
ReduceScatterOptions{ \
*reduce_op.get(), \
std::chrono::milliseconds(timeout), \
asyncOp}); \
*reduce_op.get(), std::chrono::milliseconds(timeout)}); \
}
IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CPU)
@ -369,15 +350,13 @@ IMPL_REDUCE_SCATTER_TENSOR_COALESCED(PrivateUse1)
const at::TensorList& input_tensors, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
int64_t root_rank, \
bool asyncOp, \
int64_t timeout) { \
auto input_tensors_vec = input_tensors.vec(); \
return process_group->getBackend(c10::DeviceType::DEV) \
->gather( \
const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors), \
input_tensors_vec, \
GatherOptions{ \
root_rank, std::chrono::milliseconds(timeout), asyncOp}); \
GatherOptions{root_rank, std::chrono::milliseconds(timeout)}); \
}
IMPL_GATHER(CPU)
@ -412,14 +391,13 @@ IMPL_SCATTER(PrivateUse1)
const at::TensorList& output_tensors, \
const at::TensorList& input_tensors, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
bool asyncOp, \
int64_t timeout) { \
auto output_tensors_vec = output_tensors.vec(); \
auto input_tensors_vec = input_tensors.vec(); \
auto work = process_group->getBackend(c10::DeviceType::DEV) -> alltoall( \
output_tensors_vec, \
input_tensors_vec, \
AllToAllOptions{std::chrono::milliseconds(timeout), asyncOp}); \
AllToAllOptions{std::chrono::milliseconds(timeout)}); \
return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>( \
std::move(output_tensors_vec), work); \
}
@ -428,22 +406,21 @@ IMPL_ALLTOALL(CPU)
IMPL_ALLTOALL(CUDA)
IMPL_ALLTOALL(PrivateUse1)
#define IMPL_ALLTOALL_BASE(DEV) \
c10::intrusive_ptr<Work> alltoall_base_##DEV( \
at::Tensor& output, \
at::Tensor& input, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
std::vector<int64_t> output_split_sizes, \
std::vector<int64_t> input_split_sizes, \
bool asyncOp, \
int64_t timeout) { \
return process_group->getBackend(c10::DeviceType::DEV) \
->alltoall_base( \
output, \
input, \
output_split_sizes, \
input_split_sizes, \
AllToAllOptions{std::chrono::milliseconds(timeout), asyncOp}); \
#define IMPL_ALLTOALL_BASE(DEV) \
c10::intrusive_ptr<Work> alltoall_base_##DEV( \
at::Tensor& output, \
at::Tensor& input, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
std::vector<int64_t> output_split_sizes, \
std::vector<int64_t> input_split_sizes, \
int64_t timeout) { \
return process_group->getBackend(c10::DeviceType::DEV) \
->alltoall_base( \
output, \
input, \
output_split_sizes, \
input_split_sizes, \
AllToAllOptions{std::chrono::milliseconds(timeout)}); \
}
IMPL_ALLTOALL_BASE(CPU)
@ -451,18 +428,15 @@ IMPL_ALLTOALL_BASE(CUDA)
IMPL_ALLTOALL_BASE(PrivateUse1)
// NOLINTBEGIN(performance-unnecessary-value-param)
#define IMPL_BARRIER(DEV) \
c10::intrusive_ptr<Work> barrier##DEV( \
at::Tensor /* unused */, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
const std::vector<int64_t>& device_ids, \
bool asyncOp, \
int64_t timeout) { \
auto opts = BarrierOptions{}; \
opts.device_ids = device_ids; \
opts.timeout = std::chrono::milliseconds(timeout); \
opts.asyncOp = asyncOp; \
return process_group->getBackend(c10::DeviceType::DEV)->barrier(opts); \
#define IMPL_BARRIER(DEV) \
c10::intrusive_ptr<Work> barrier##DEV( \
at::Tensor /* unused */, \
const c10::intrusive_ptr<ProcessGroup>& process_group, \
const std::vector<int64_t>& device_ids, \
int64_t timeout) { \
return process_group->getBackend(c10::DeviceType::DEV) \
->barrier( \
BarrierOptions{device_ids, std::chrono::milliseconds(timeout)}); \
}
IMPL_BARRIER(CPU)
@ -490,7 +464,6 @@ allreduce_sparse_cuda_(
const c10::intrusive_ptr<ProcessGroup>& process_group,
const c10::intrusive_ptr<ReduceOp>& reduce_op,
const std::optional<at::Tensor>& sparse_indices,
bool asyncOp,
int64_t timeout) {
auto tensor_vec = tensors.vec();
auto work = process_group->getBackend(c10::DeviceType::CUDA)
@ -499,7 +472,6 @@ allreduce_sparse_cuda_(
AllreduceOptions{
*reduce_op,
std::chrono::milliseconds(timeout),
asyncOp,
sparse_indices});
return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
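
To see the effect of these schema changes from Python, one can print the registered schemas through the dispatcher; a hedged sketch (assumes a build with distributed support):

```python
import torch  # the c10d ops are registered when the torch library loads

# The printed signatures should match the TORCH_LIBRARY strings above,
# i.e. without the async_op/asyncOp argument after this revert.
print(torch.ops.c10d.allreduce_.default._schema)
print(torch.ops.c10d.barrier.default._schema)
```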

View File

@ -224,7 +224,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
const std::optional<at::Tensor>& sparse_indices,
bool,
int64_t)>();
auto work = std::get<1>(op.call(
@ -232,7 +231,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.sparseIndices,
opts.asyncOp,
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -252,14 +250,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
bool,
int64_t)>();
auto work = op.call(
tensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.asyncOp,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -281,7 +277,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ReduceOp>&,
int64_t,
int64_t,
bool,
int64_t)>();
auto work = op.call(
tensors,
@ -289,7 +284,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
c10::make_intrusive<ReduceOp>(opts.reduceOp),
opts.rootRank,
opts.rootTensor,
opts.asyncOp,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -312,14 +306,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const std::vector<std::vector<at::Tensor>>&,
at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
bool,
int64_t)>();
auto work = std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.asyncOp,
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -371,19 +363,18 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
std::vector<std::vector<at::Tensor>>& outputTensorLists,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::allgather_coalesced_", "")
.typed<c10::intrusive_ptr<Work>(
const std::vector<std::vector<at::Tensor>>&,
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
bool)>();
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("c10d::allgather_coalesced_", "")
.typed<c10::intrusive_ptr<Work>(
const std::vector<std::vector<at::Tensor>>&,
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
auto work = op.call(
outputTensorLists,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.asyncOp);
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor_list : outputTensorLists) {
@ -408,14 +399,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
.typed<c10::intrusive_ptr<Work>(
const at::TensorList,
const at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
bool)>();
const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
auto work = op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.asyncOp);
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
if (c10d::allow_inflight_collective_as_graph_input()) {
for (const auto& tensor : outputTensors) {
@ -436,14 +425,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
int64_t,
bool,
int64_t)>();
auto work = op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.rootRank,
opts.asyncOp,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -500,14 +487,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const std::vector<std::vector<at::Tensor>>&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
bool,
int64_t)>();
auto work = std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
opts.asyncOp,
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -561,7 +546,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const at::TensorList,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const c10::intrusive_ptr<::c10d::ReduceOp>&,
bool,
int64_t)>();
auto work = op.call(
@ -569,7 +553,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
opts.asyncOp,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -594,7 +577,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
std::vector<int64_t>,
std::vector<int64_t>,
bool,
int64_t)>();
auto work = op.call(
outputBuffer,
@ -602,7 +584,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
outputSplitSizes,
inputSplitSizes,
opts.asyncOp,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -623,13 +604,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const at::TensorList&,
const at::TensorList&,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
bool,
int64_t)>();
auto work = std::get<1>(op.call(
outputTensors,
inputTensors,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.asyncOp,
opts.timeout.count()));
if (c10d::allow_inflight_collective_as_graph_input()) {
@ -799,14 +778,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
at::Tensor,
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
const std::vector<int64_t>&,
bool,
int64_t)>();
auto work = op.call(
tensor,
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
opts.device_ids,
opts.asyncOp,
opts.timeout.count());
if (c10d::allow_inflight_collective_as_graph_input()) {
c10d::register_work(tensor, work);
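
These dispatcher stubs are what the bound ProcessGroup methods route through; a hedged usage sketch from Python (assumes an initialized NCCL group; AllreduceOptions as in the stub file above):

```python
import torch
import torch.distributed as dist
from torch._C._distributed_c10d import AllreduceOptions, ReduceOp

# Assumes dist.init_process_group("nccl") has already been called.
pg = dist.distributed_c10d._get_default_group()

t = torch.ones(8, device="cuda")
opts = AllreduceOptions()
opts.reduceOp = ReduceOp.SUM

# ProcessGroup.allreduce dispatches c10d::allreduce_ (see Ops.cpp above)
# and returns a Work handle.
work = pg.allreduce([t], opts)
work.wait()
```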

View File

@ -496,8 +496,6 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(
}
futureWorkResult_ =
c10::make_intrusive<at::ivalue::Future>(c10::AnyEnumType::get());
// other functions expect an initialized ptr
stashed_for_allocator_safety_ = std::make_shared<std::vector<at::Tensor>>();
}
ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
@ -519,11 +517,6 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
numelIn_(w.numelIn_),
numelOut_(w.numelOut_),
store_(w.store_),
// Note: the `work` returned to the user and the `work` enqueued to the
// watchdog share the pointer to the tensor stash. At least one of them
// should clean the tensor stash, the earlier the better, i.e. the user
// calling `work.wait` rather than the watchdog detecting work completion.
stashed_for_allocator_safety_(w.stashed_for_allocator_safety_),
futureWorkResult_(w.futureWorkResult_),
timingEnabled_(w.timingEnabled_),
trace_id_(w.trace_id_),
@ -717,25 +710,14 @@ void ProcessGroupNCCL::WorkNCCL::synchronize() {
}
}
void ProcessGroupNCCL::WorkNCCL::stashTensors(
std::vector<at::Tensor>& tensors) {
std::lock_guard<std::mutex> lock(stashMutex_);
stashed_for_allocator_safety_->insert(
stashed_for_allocator_safety_->end(), tensors.begin(), tensors.end());
}
void ProcessGroupNCCL::WorkNCCL::unstashTensors() {
std::lock_guard<std::mutex> lock(stashMutex_);
stashed_for_allocator_safety_->clear();
}
void ProcessGroupNCCL::WorkNCCL::synchronizeStream() {
auto currentStream = at::cuda::getCurrentCUDAStream(device_.index());
// Block the current stream on the NCCL stream
ncclEndEvent_->block(currentStream);
// Unstage the stashed tensors so that CachingAllocator can recycle them
// THIS MUST HAPPEN AFTER THE BLOCKING CALL ABOVE
unstashTensors();
if (avoidRecordStreams_) {
stashed_for_allocator_safety_->clear();
}
}
// Same as calling synchronize() when blockingWait_ is false
@ -951,10 +933,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
enableTiming_.store(
getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_);
#endif // ENABLE_NCCL_ERROR_CHECKING
if (getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false)) {
TORCH_WARN_ONCE(
"TORCH_NCCL_AVOID_RECORD_STREAMS is the default now, this environment variable is thus deprecated.");
}
avoidRecordStreams_ = getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false);
#ifdef NCCL_HAS_COMM_REGISTER
useTensorRegisterAllocatorHook_ =
getCvarBool(TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK, false);
@ -2344,12 +2323,6 @@ void ProcessGroupNCCL::watchdogHandler() {
// Clean up completed work
if (work.isCompleted()) {
// In case the user didn't call `work.wait()` on async collectives, the
// watchdog unstashes the stashed tensors when it detects completion of
// the collective, to prevent ProcessGroupNCCL from holding references
// to those tensors forever.
work.unstashTensors();
// Work status logging for desync debug
desyncDebugger_.logWorkEnd(work);
@ -3084,7 +3057,6 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
enableTiming_.load(),
cudaEventCacheEnabled_.load(),
dist_debug_level_);
if (record) {
bool isP2P = isP2POp(opType);
// Ideally record every work that we enqueue, rather than every work we
@ -3242,6 +3214,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing(OpType optype) {
enqueue);
work->ncclComm_ = comm;
work->blockingWait_ = blockingWait_;
work->avoidRecordStreams_ = avoidRecordStreams_;
work->store_ = store_;
assignTimeoutToWork(work, options_);
@ -3260,16 +3233,19 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing(OpType optype) {
// TODO(eqy): is this still necessary if avoidRecordStreams_ is set?
work->ncclEndEvent_->record(ncclStream);
if (avoidRecordStreams_) {
// other functions expect an initialized ptr if avoidRecordStreams_ is set
work->stashed_for_allocator_safety_ =
std::make_shared<std::vector<at::Tensor>>();
}
if (enqueue) {
workEnqueue(work);
}
// Reset coalescing state
coalescing_state_ = 0;
coalescedComm_ = nullptr;
// If in async mode, return work; otherwise, kernel is enqueued on current
// stream, no need to return work
return coalescedAsync_ ? work : nullptr;
return work;
}
c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing() {
@ -3285,10 +3261,11 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
PreProcess pre,
PostProcess post,
OpType opType,
bool asyncOp,
const char* profilingTitle,
bool avoidRecordStreams,
bool nanCheck) {
// Environment setting by the user may add onto collective call's option
avoidRecordStreams |= avoidRecordStreams_;
nanCheck &= enableNanCheck_;
auto device = getDevice(inputs[0]);
@ -3329,17 +3306,13 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
} else {
TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG);
}
coalescedAsync_ = asyncOp;
}
// in asyncOp=false [default] mode, we use currentStream as ncclStream
// otherwise, we use separate ncclStream and let it sync on currentStream
auto ncclStream = asyncOp ? ncclStreams_.at(key)
: at::cuda::getCurrentCUDAStream(device.index());
if (asyncOp) {
// First let NCCL streams wait for input tensors allocation streams
syncStream(device, ncclEvents_[key], ncclStream);
}
// Used many times below, so we stash the unordered_map lookup
auto ncclStream = ncclStreams_.at(key);
// First let NCCL streams wait for input tensors allocation streams
syncStream(device, ncclEvents_[key], ncclStream);
bool enqueue =
!coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None;
@ -3349,12 +3322,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
// Store references to outputs to be used by WorkNCCL::result and operator<<.
work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
// If we are performing sync operations, i.e. enqueuing the kernel onto the
// "current" stream, we don't need to do anything for tensor lifetime
// management. Otherwise, we need to stash the tensors until `work.wait()`.
if (asyncOp) {
work->stashTensors(inputs);
work->stashTensors(outputs);
if (avoidRecordStreams) {
work->stashed_for_allocator_safety_ =
std::make_shared<std::vector<at::Tensor>>(inputs);
}
if (nanCheck) {
@ -3380,6 +3350,21 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
// operations where `inputs' and `outputs' are not the same.
//
// See [Sync Streams].
if (!avoidRecordStreams) {
for (const auto& input : inputs) {
if (!input.is_sparse()) {
c10::cuda::CUDACachingAllocator::recordStream(
input.storage().data_ptr(), ncclStream);
} else {
// for sparse input case record streams on both index and value
// tensors
c10::cuda::CUDACachingAllocator::recordStream(
input.values().storage().data_ptr(), ncclStream);
c10::cuda::CUDACachingAllocator::recordStream(
input.indices().storage().data_ptr(), ncclStream);
}
}
}
// Not all collectives have the same signature, e.g, all-reduce take in a Tensor
// as the input and output while all-to-all take in a vector of Tensors as input
@ -3431,6 +3416,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
// Set appropriate work parameters.
work->blockingWait_ = blockingWait_;
work->avoidRecordStreams_ = avoidRecordStreams;
work->store_ = store_;
assignTimeoutToWork(work, options_);
// Record size info for debug. We only record the size on the first device as
@ -3448,7 +3434,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
workEnqueue(work);
}
return asyncOp ? work : nullptr;
return work;
}
template <typename Fn>
@ -3457,8 +3443,11 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
std::vector<at::Tensor>& outputs,
Fn fn,
OpType opType,
bool asyncOp,
const char* profilingTitle) {
const char* profilingTitle,
bool avoidRecordStreams) {
// Environment setting by the user may add onto collective call's option
avoidRecordStreams |= avoidRecordStreams_;
// Currently, the API permits one scenario where inputs.size() and
// outputs.size() are > 0.
// 1. If the call was a _coalesced call, all inputs must be on the same
@ -3504,17 +3493,13 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
} else {
TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG);
}
coalescedAsync_ = asyncOp;
}
// in asyncOp=false [default] mode, we use currentStream as ncclStream
// otherwise, we use separate ncclStream and let it sync on currentStream
auto ncclStream = asyncOp ? ncclStreams_.at(key)
: at::cuda::getCurrentCUDAStream(device.index());
if (asyncOp) {
// First let NCCL streams wait for input tensors allocation streams
syncStream(device, ncclEvents_[key], ncclStream);
}
// Used many times below, so we stash the unordered_map lookup
auto ncclStream = ncclStreams_.at(key);
// First let NCCL streams wait for input tensors allocation streams
syncStream(device, ncclEvents_[key], ncclStream);
auto work = initWork(
device,
@ -3529,12 +3514,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
// Store references to outputs to be used by WorkNCCL::result and operator<<.
work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
// If we are performing sync operations, i.e. enqueuing the kernel onto the
// "current" stream, we don't need to do anything for tensor lifetime
// management. Otherwise, we need to stash the tensors until `work.wait()`.
if (asyncOp) {
work->stashTensors(inputs);
work->stashTensors(outputs);
if (avoidRecordStreams) {
work->stashed_for_allocator_safety_ =
std::make_shared<std::vector<at::Tensor>>(inputs);
}
// Start event should only be recorded before the ncclGroupStart() (which
@ -3560,6 +3542,27 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
{
torch::cuda::nccl::AutoNcclGroup nccl_group_guard(comm, useNonblocking());
for (const auto i : c10::irange(inputs.size())) {
// Both `inputs' and `outputs' are created on a worker stream and used in
// different ncclStreams. Hence, both must record the ncclStream to
// prevent being freed before the collective finishes.
//
// We only record `inputs' here, and leave recording `outputs' to `fn' for
// operations where `inputs' and `outputs' are not the same.
//
// See [Sync Streams].
if (!avoidRecordStreams) {
if (!inputs[i].is_sparse()) {
c10::cuda::CUDACachingAllocator::recordStream(
inputs[i].storage().data_ptr(), ncclStream);
} else {
// for sparse input case record streams on both index and value
// tensors
c10::cuda::CUDACachingAllocator::recordStream(
inputs[i].values().storage().data_ptr(), ncclStream);
c10::cuda::CUDACachingAllocator::recordStream(
inputs[i].indices().storage().data_ptr(), ncclStream);
}
}
#ifndef NCCL_HAS_COMM_NONBLOCKING
C10D_NCCL_CHECK(
fn(inputs[i], outputs[i], comm, ncclStream),
@ -3600,6 +3603,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
// Set appropriate work parameters.
work->blockingWait_ = blockingWait_;
work->avoidRecordStreams_ = avoidRecordStreams;
work->store_ = store_;
assignTimeoutToWork(work, options_);
// Record size info for debug. We only record the size on the first device as
@ -3630,7 +3634,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
// it, since interactions with it by usercode won't behave normally - they
// won't observe work completion, for instance. Will this lead to silent
// problems during capture?
return asyncOp ? work : nullptr;
return work;
}
template <typename Fn, typename PreProcess, typename PostProcess>
@ -3648,8 +3652,13 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
// to wait() on the returned handle, so ProcessGroupNCCL can't know
// when it's safe to release the input back to the allocator,
// and the present call has no way to know it's not an isend.
// Therefore, we warn and fall back to the typical recordStream logic.
// TODO( kwen2501 ): revisit this when we have a better solution.
// Therefore, we warn and fall back to the typical recordStream logic:
if (avoidRecordStreams_) {
TORCH_WARN_ONCE(
"TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point "
"collectives.");
}
auto device = getDevice(tensor);
at::cuda::OptionalCUDAGuard gpuGuard(device);
@ -3704,8 +3713,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
} else {
TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG);
}
// For now, P2P ops are always put on internal stream
coalescedAsync_ = true;
}
// Used many times below, so we stash the unordered_map lookup
@ -3877,8 +3884,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
PreProcess pre,
PostProcess post,
OpType opType,
bool asyncOp,
const char* profilingTitle,
bool avoidRecordStreams,
bool nanCheck) {
auto inputs = std::vector<at::Tensor>{input};
auto outputs = std::vector<at::Tensor>{output};
@ -3889,8 +3896,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
pre,
post,
opType,
asyncOp,
profilingTitle,
avoidRecordStreams,
nanCheck);
}
@ -3900,8 +3907,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
at::Tensor& output,
Fn fn,
OpType opType,
bool asyncOp,
const char* profilingTitle,
bool avoidRecordStreams,
bool nanCheck) {
auto inputs = std::vector<at::Tensor>{input};
auto outputs = std::vector<at::Tensor>{output};
@ -3914,8 +3921,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
[](at::cuda::CUDAStream&,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
opType,
asyncOp,
profilingTitle,
avoidRecordStreams,
nanCheck);
}
@ -3967,8 +3974,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_sparse(
auto recvIndices = indices[0] * colSize;
// prevent output and recvIndices from being freed
// TODO: not changing the lifetime management of outputs this time,
// revisit later
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
c10::cuda::CUDACachingAllocator::recordStream(
@ -4000,7 +4005,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_sparse(
}
},
OpType::_ALLREDUCE_SPARSE,
opts.asyncOp,
"nccl:all_reduce_sparse");
return work;
#else
@ -4035,7 +4039,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_impl(
stream.stream());
},
OpType::ALLREDUCE,
opts.asyncOp,
profilingTitle);
}
@ -4136,7 +4139,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_coalesced(
stream.stream());
},
OpType::COALESCED,
opts.asyncOp,
"nccl:allreduce_coalesced");
}
@ -4168,10 +4170,12 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::broadcast(
globalRankStride_, // globalRankStride_
this->getSize()); // worldSize
// avoidRecordStreams_ note: collective() will stash tensors.
bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
const auto root = opts.rootRank + opts.rootTensor;
bool nanCheck = (root == rank_);
// avoidRecordStreams_ note: collective() will stash tensors.
return collective(
tensor,
tensor,
@ -4188,8 +4192,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::broadcast(
stream.stream());
},
OpType::BROADCAST,
opts.asyncOp,
"nccl:broadcast",
avoidRecordStreams,
nanCheck);
}
@ -4228,8 +4232,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_broadcast_oop(
stream.stream());
},
OpType::BROADCAST,
opts.asyncOp,
"nccl:_broadcast_oop",
/*avoidRecordStreams=*/false,
nanCheck);
}
@ -4288,7 +4292,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce(
stream.stream());
},
OpType::REDUCE,
opts.asyncOp,
"nccl:reduce");
}
@ -4330,7 +4333,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_oop(
stream.stream());
},
OpType::REDUCE,
opts.asyncOp,
"nccl:_reduce_oop");
}
@ -4374,7 +4376,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
// See [We actually don't need to stash anything here].
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
return ncclAllGather(
input.data_ptr(),
output.data_ptr(),
@ -4390,27 +4395,27 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
// - inputTensors is stashed onto work->stashed_for_allocator_safety_
// in collective().
// - outputFlattened is stashed onto work->outputs_ in collective().
// - User-facing outputTensors should be held by the user until after
// waiting on work_, or the call makes no sense.
// So all participating tensors are accounted for, and won't be
// released back to their allocation streams until after work_ is
// waited on.
},
[&](at::cuda::CUDAStream& ncclStream,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
// User-facing outputTensors should be held by the user until after
// waiting on work_, or the call makes no sense. We stash them here in
// case the user doesn't hold the outputTensors in downstream code, which
// could cause an early recycle by the CachingAllocator and lead to a
// segfault or data corruption.
if (opts.asyncOp) {
work->stashTensors(outputTensors_);
}
// Copy the flattened output tensors to the outputs.
at::cuda::CUDAStreamGuard guard(ncclStream);
for (const auto j : c10::irange(outputTensors_.size())) {
// See [We actually don't need to stash anything here].
// See [Sync Streams].
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
outputTensors_[j].storage().data_ptr(), ncclStream);
}
outputTensors_[j].copy_(
outputFlattened[static_cast<int64_t>(j)], true);
}
},
OpType::ALLGATHER,
opts.asyncOp,
"nccl:all_gather");
} else {
const auto num_reduces = outputTensors_.size();
@ -4418,8 +4423,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
for (const int64_t i : c10::irange(static_cast<int64_t>(num_reduces))) {
auto& output = outputTensors_[i];
auto& input = (i == rank_) ? inputTensor : output;
auto broadcastOpts =
BroadcastOptions{i, int64_t(0), opts.timeout, opts.asyncOp};
auto broadcastOpts = BroadcastOptions{i, int64_t(0), opts.timeout};
_broadcast_oop(output, input, broadcastOpts);
}
auto work = endCoalescing(OpType::ALLGATHER);
@ -4475,7 +4479,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather_into_tensor_coalesced(
stream.stream());
},
OpType::COALESCED,
opts.asyncOp,
"nccl:all_gather_into_tensor_coalesced");
}
@ -4521,6 +4524,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
const auto ncclDataType = getNcclDataType(input.scalar_type());
const auto ncclReduceOp =
getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
@ -4535,18 +4542,27 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
},
[&](at::cuda::CUDAStream& ncclStream,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
// We only need to stash inputTensors.
// - inputFlattened is stashed onto
// work->stashed_for_allocator_safety_ in collective().
// - User-facing outputTensors is stashed onto work->outputs_ in
// collective(), and should also be held by the user until after
// waiting on work_.
if (opts.asyncOp) {
work->stashTensors(inputTensors_);
if (avoidRecordStreams_) {
// We only need to stash inputTensors.
// - inputFlattened is stashed onto
// work->stashed_for_allocator_safety_
// in collective().
// - User-facing outputTensors is stashed onto work->outputs_ in
// collective(),
// and should also be held by the user until after waiting on
// work_.
auto& v = work->stashed_for_allocator_safety_;
v->insert(v->end(), inputTensors_.begin(), inputTensors_.end());
}
// Copy the input tensors to the flattened inputs.
at::cuda::CUDAStreamGuard guard(ncclStream);
for (const auto j : c10::irange(inputTensors_.size())) {
// See [Sync Streams].
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
inputTensors_[j].storage().data_ptr(), ncclStream);
}
inputFlattened[static_cast<int64_t>(j)].copy_(
inputTensors_[j], true);
}
@ -4554,7 +4570,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
[&](at::cuda::CUDAStream&,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
OpType::REDUCE_SCATTER,
opts.asyncOp,
"nccl:reduce_scatter");
} else {
const auto num_reduces = inputTensors_.size();
@ -4566,8 +4581,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
opts.reduceOp,
static_cast<int64_t>(i),
static_cast<int64_t>(0),
opts.timeout,
opts.asyncOp};
opts.timeout};
_reduce_oop(output, input, reduceOpts);
}
auto work = endCoalescing(OpType::REDUCE_SCATTER);
@ -4621,6 +4635,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
// stream so that the caching allocator can reuse memory pool for this stream
// in a clever way. This setting is added for libraries like FSDP which uses
// `reduce_scatter_tensor`.
bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
return collective(
inputTensor,
@ -4629,6 +4644,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
if (!avoidRecordStreams) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
auto ncclDataType = getNcclDataType(input.scalar_type());
auto ncclReduceOp =
getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
@ -4642,8 +4661,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
stream.stream());
},
OpType::_REDUCE_SCATTER_BASE,
opts.asyncOp,
"nccl:_reduce_scatter_base");
"nccl:_reduce_scatter_base",
avoidRecordStreams);
}
c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced(
@ -4680,6 +4699,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced(
at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
auto ncclDataType = getNcclDataType(input.scalar_type());
auto ncclReduceOp =
getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
@ -4693,7 +4716,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced(
stream.stream());
},
OpType::COALESCED,
opts.asyncOp,
"nccl:reduce_scatter_tensor_coalesced");
}
@ -4772,28 +4794,13 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::barrier(const BarrierOptions& opts) {
at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat));
// All reduce to achieve the barrier
AllreduceOptions arOpts = AllreduceOptions();
arOpts.asyncOp = opts.asyncOp;
auto work = allreduce_impl(barrierTensor, "nccl:all_reduce_barrier", arOpts);
auto work = allreduce_impl(barrierTensor, "nccl:all_reduce_barrier");
if (opts.asyncOp) {
// Work will take over barrierTensors
auto ncclWork = dynamic_cast<ProcessGroupNCCL::WorkNCCL*>(work.get());
// If user specified async, the work should not be nullptr
TORCH_CHECK(ncclWork);
// Put a marker here so that `work.wait()` issue by users does
// barrier-specific thing: CPU sync
ncclWork->isBarrierOp_ = true;
return work;
}
// Otherwise, we are in sync mode, we directly wait here.
// (It is a CPU wait for barrier)
auto currentStream = at::cuda::getCurrentCUDAStream(barDevIdx);
// CUDAStream wrapper will correctly use a DeviceGuard here
currentStream.synchronize();
// No work to return
return nullptr;
// Work will take over barrierTensors
auto ncclWork = dynamic_cast<ProcessGroupNCCL::WorkNCCL*>(work.get());
TORCH_CHECK(ncclWork);
ncclWork->isBarrierOp_ = true;
return work;
}
c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
@ -4801,7 +4808,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
at::Tensor& inputTensor,
std::vector<int64_t>& outputSplitSizes,
std::vector<int64_t>& inputSplitSizes,
const AllToAllOptions& opts) {
const AllToAllOptions& /* unused */) {
check_gpu_single_tensor(outputTensor);
check_gpu_single_tensor(inputTensor);
if (outputSplitSizes.empty() && inputSplitSizes.empty()) {
@ -4832,12 +4839,16 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
// See [Sync Streams].
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
torch::cuda::nccl::all2all_single_equal_split(
input, output, this->getSize(), comm, stream);
return ncclSuccess;
},
OpType::ALLTOALL_BASE,
opts.asyncOp,
"nccl:all_to_all");
} else {
c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
@ -4879,6 +4890,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
c10d::computeLengthsAndOffsets(
outputSplitSizes, output, &recv_lengths, &recv_offsets);
// See [Sync Streams].
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
torch::cuda::nccl::all2all_single_unequal_split(
input.data_ptr(),
send_lengths.data(),
@ -4893,7 +4908,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
return ncclSuccess;
},
OpType::ALLTOALL_BASE,
opts.asyncOp,
"nccl:all_to_all");
}
}
@ -4901,7 +4915,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall(
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllToAllOptions& opts) {
const AllToAllOptions& /* unused */) {
std::vector<int64_t> inSplitSizes;
std::vector<int64_t> outSplitSizes;
int64_t total_numel = 0;
@ -4948,11 +4962,18 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall(
return ncclSuccess;
},
[&](at::cuda::CUDAStream&,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
if (avoidRecordStreams_) {
// inputTensor0 and outputTensor0 are stashed redundantly by
// collective(), but that's ok.
auto& v = work->stashed_for_allocator_safety_;
v->insert(v->end(), inputTensors.begin(), inputTensors.end());
v->insert(v->end(), outputTensors.begin(), outputTensors.end());
}
},
[](at::cuda::CUDAStream&,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
OpType::ALLTOALL,
opts.asyncOp,
"nccl:all_to_all");
}
@ -5150,6 +5171,14 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::gather(
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
const auto root = opts.rootRank;
if (getRank() == root) {
if (!avoidRecordStreams_) {
for (auto const& output : outputs) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
}
}
torch::cuda::nccl::gather(
inputTensor, outputs, comm, stream, static_cast<int32_t>(root));
return ncclSuccess;
@ -5159,7 +5188,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::gather(
[](at::cuda::CUDAStream&,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
OpType::GATHER,
opts.asyncOp,
"nccl:gather");
}
@ -5228,6 +5256,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::scatter(
// avoidRecordStreams_ note: collective() will stash outputTensors and
// inputs, which == inputTensors[0] on the root rank where it matters.
bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
const auto root = opts.rootRank;
bool nanCheck = (rank_ == root);
@ -5239,6 +5269,14 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::scatter(
at::Tensor& /* unused */,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
if (getRank() == root) {
if (!avoidRecordStreams) {
for (auto const& input : inputs) {
c10::cuda::CUDACachingAllocator::recordStream(
input.storage().data_ptr(), stream);
}
}
}
torch::cuda::nccl::scatter(
inputs, outputTensor, comm, stream, static_cast<int32_t>(root));
return ncclSuccess;
@ -5248,8 +5286,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::scatter(
[](at::cuda::CUDAStream&,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
OpType::SCATTER,
opts.asyncOp,
"nccl:scatter",
avoidRecordStreams,
nanCheck);
}
@ -5305,6 +5343,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base(
// stream so that the caching allocator can reuse memory pool for this stream
// in a clever way. This setting is added for libraries like FSDP which uses
// `all_gather_into_tensor`.
bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
return collective(
input_tensor,
@ -5313,6 +5352,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base(
at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
if (!avoidRecordStreams) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
return ncclAllGather(
input.data_ptr(),
output.data_ptr(),
@ -5322,8 +5365,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base(
stream.stream());
},
OpType::_ALLGATHER_BASE,
opts.asyncOp,
"nccl:_all_gather_base");
"nccl:_all_gather_base",
avoidRecordStreams);
}
// Create a memory allocator for NCCL. This allocator is used to allocate memory
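
The recordStream calls re-added throughout this file follow the same pattern that Python exposes as Tensor.record_stream; below is a minimal illustrative sketch (assumes a CUDA device). Note that TORCH_NCCL_AVOID_RECORD_STREAMS=1, honored again after this revert, makes ProcessGroupNCCL stash tensors on the Work object instead of recording streams.

```python
import torch

side = torch.cuda.Stream()
x = torch.randn(1 << 20, device="cuda")

with torch.cuda.stream(side):
    y = x * 2  # kernel enqueued on `side`; may still be running below

# Tell the caching allocator that `x` is in use on `side`, so its block is
# not handed to another allocation until `side` has passed this point.
x.record_stream(side)
del x  # safe: reuse of the block is deferred until `side` catches up

torch.cuda.synchronize()
print(y.sum().item())
```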

View File

@ -382,6 +382,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Clone of blockingWait_ from ProcessGroupNCCL.
bool blockingWait_{false};
// Clone of avoidRecordStreams_ from ProcessGroupNCCL.
bool avoidRecordStreams_{false};
// Clone of opTimeout_ from ProcessGroupNCCL.
std::chrono::milliseconds opTimeout_{};
@ -428,13 +431,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// exception_ptr.
bool finishedGPUExecutionInternal() const;
// Stash tensors so that CachingAllocator cannot recycle them prematurely.
// Used in case of async ops.
void stashTensors(std::vector<at::Tensor>& tensors);
// Unstage the stashed tensors so that CachingAllocator can recycle them
void unstashTensors();
// Reference to the store so that we can write aborted communicators
// to the store.
c10::intrusive_ptr<Store> store_;
@ -454,9 +450,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// For in-place collectives, some refs stashed here may alias outputs_,
// but that doesn't do any harm.
std::shared_ptr<std::vector<at::Tensor>> stashed_for_allocator_safety_;
// Need a mutex to protect stashed_for_allocator_safety_ because it can be
// accessed from both main thread and watchdog thread.
std::mutex stashMutex_;
// The future returned by getFuture.
c10::intrusive_ptr<at::ivalue::Future> future_;
@ -885,8 +878,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
at::Tensor& output,
Fn fn,
OpType opType,
bool asyncOp,
const char* profilingTitle = nullptr,
bool avoidRecordStreams = false,
bool nanCheck = true);
template <typename Fn, typename PreProcess, typename PostProcess>
@ -897,8 +890,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
PreProcess pre,
PostProcess post,
OpType opType,
bool asyncOp,
const char* profilingTitle = nullptr,
bool avoidRecordStreams = false,
bool nanCheck = true);
template <typename Fn, typename PreProcess, typename PostProcess>
@ -909,8 +902,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
PreProcess pre,
PostProcess post,
OpType opType,
bool asyncOp,
const char* profilingTitle = nullptr,
bool avoidRecordStreams = false,
bool nanCheck = true);
template <typename Fn>
@ -919,8 +912,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
std::vector<at::Tensor>& output,
Fn fn,
OpType opType,
bool asyncOp,
const char* profilingTitle = nullptr);
const char* profilingTitle = nullptr,
bool avoidRecordStreams = false);
// Helper that encapsulates work shared across point-to-point communication
// primitives. It is the same structure as the helper used for collective
@ -1229,9 +1222,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Stores communicators for all collectives run inside a coalescing block
std::shared_ptr<NCCLComm> coalescedComm_ = nullptr;
// Whether the coalesced calls are sync or async.
bool coalescedAsync_;
// Whether or not wait() and synchronize() are blocking operations that wait
// for the operation to complete.
bool blockingWait_ = false;

View File

@ -122,7 +122,6 @@ struct BroadcastOptions {
struct AllreduceOptions {
ReduceOp reduceOp = ReduceOp::SUM;
std::chrono::milliseconds timeout = kUnsetTimeout;
bool asyncOp = true;
std::optional<at::Tensor> sparseIndices = std::nullopt;
};
@ -133,7 +132,6 @@ struct ReduceOptions {
int64_t rootRank = 0;
int64_t rootTensor = 0;
std::chrono::milliseconds timeout = kUnsetTimeout;
bool asyncOp = true;
};
struct AllgatherOptions {
@ -144,7 +142,6 @@ struct AllgatherOptions {
struct GatherOptions {
int64_t rootRank = 0;
std::chrono::milliseconds timeout = kUnsetTimeout;
bool asyncOp = true;
};
struct ScatterOptions {
@ -161,14 +158,12 @@ struct ReduceScatterOptions {
struct AllToAllOptions {
std::chrono::milliseconds timeout = kUnsetTimeout;
bool asyncOp = true;
};
struct BarrierOptions {
std::vector<int64_t> device_ids;
std::chrono::milliseconds timeout = kUnsetTimeout;
std::optional<at::Device> device;
bool asyncOp = true;
};
struct DistributedBackendOptions {

View File

@ -999,23 +999,20 @@ This class does not support ``__members__`` property.)");
py::class_<::c10d::AllreduceOptions>(module, "AllreduceOptions")
.def(py::init<>())
.def_readwrite("reduceOp", &::c10d::AllreduceOptions::reduceOp)
.def_readwrite("timeout", &::c10d::AllreduceOptions::timeout)
.def_readwrite("asyncOp", &::c10d::AllreduceOptions::asyncOp);
.def_readwrite("timeout", &::c10d::AllreduceOptions::timeout);
py::class_<::c10d::AllreduceCoalescedOptions>(
module, "AllreduceCoalescedOptions")
.def(py::init<>())
.def_readwrite("reduceOp", &::c10d::AllreduceCoalescedOptions::reduceOp)
.def_readwrite("timeout", &::c10d::AllreduceCoalescedOptions::timeout)
.def_readwrite("asyncOp", &::c10d::AllreduceCoalescedOptions::asyncOp);
.def_readwrite("timeout", &::c10d::AllreduceCoalescedOptions::timeout);
py::class_<::c10d::ReduceOptions>(module, "ReduceOptions")
.def(py::init<>())
.def_readwrite("reduceOp", &::c10d::ReduceOptions::reduceOp)
.def_readwrite("rootRank", &::c10d::ReduceOptions::rootRank)
.def_readwrite("rootTensor", &::c10d::ReduceOptions::rootTensor)
.def_readwrite("timeout", &::c10d::ReduceOptions::timeout)
.def_readwrite("asyncOp", &::c10d::ReduceOptions::asyncOp);
.def_readwrite("timeout", &::c10d::ReduceOptions::timeout);
py::class_<::c10d::AllgatherOptions>(module, "AllgatherOptions")
.def(py::init<>())
@ -1025,8 +1022,7 @@ This class does not support ``__members__`` property.)");
py::class_<::c10d::GatherOptions>(module, "GatherOptions")
.def(py::init<>())
.def_readwrite("rootRank", &::c10d::GatherOptions::rootRank)
.def_readwrite("timeout", &::c10d::GatherOptions::timeout)
.def_readwrite("asyncOp", &::c10d::GatherOptions::asyncOp);
.def_readwrite("timeout", &::c10d::GatherOptions::timeout);
py::class_<::c10d::ScatterOptions>(module, "ScatterOptions")
.def(py::init<>())
@ -1044,13 +1040,11 @@ This class does not support ``__members__`` property.)");
.def(py::init<>())
.def_readwrite("device_ids", &::c10d::BarrierOptions::device_ids)
.def_readwrite("timeout", &::c10d::BarrierOptions::timeout)
.def_readwrite("device", &::c10d::BarrierOptions::device)
.def_readwrite("asyncOp", &::c10d::BarrierOptions::asyncOp);
.def_readwrite("device", &::c10d::BarrierOptions::device);
py::class_<::c10d::AllToAllOptions>(module, "AllToAllOptions")
.def(py::init<>())
.def_readwrite("timeout", &::c10d::AllToAllOptions::timeout)
.def_readwrite("asyncOp", &::c10d::AllToAllOptions::asyncOp);
.def_readwrite("timeout", &::c10d::AllToAllOptions::timeout);
py::class_<::c10d::DistributedBackendOptions>(
module, "_DistributedBackendOptions")

View File

@ -2500,7 +2500,7 @@ class _CoalescingManager:
def __init__(self) -> None:
self.works: list[Work] = []
def append(self, work: Optional[Work] = None):
def append(self, work: Work):
if work:
self.works.append(work)
@ -2513,7 +2513,7 @@ class _CoalescingManager:
def _coalescing_manager(
group: Optional[ProcessGroup] = None,
device: Optional[torch.device] = None,
async_ops: bool = False,
async_ops: Optional[bool] = False,
):
"""
Context manager used to coalesce collectives or P2P operations when possible.
@ -2552,7 +2552,6 @@ def _coalescing_manager(
group._start_coalescing(device)
cm = _CoalescingManager()
yield cm
work = None
op_list = _world.pg_coalesce_state.pop(group)
if op_list:
# Collectives supporting "Fast Path" coalescing are captured.
@ -2566,7 +2565,6 @@ def _coalescing_manager(
tensors = [op.tensor for op in op_list]
all_reduce_opts = AllreduceCoalescedOptions()
all_reduce_opts.reduceOp = not_none(op_list[0].redop)
all_reduce_opts.asyncOp = async_ops
work = group.allreduce_coalesced(tensors, all_reduce_opts)
elif op0 == all_gather_into_tensor:
inputs = []
@ -2574,8 +2572,6 @@ def _coalescing_manager(
for op in op_list:
inputs.append(op.tensor)
outputs.append(not_none(op.dst_tensor))
all_gather_opts = AllgatherOptions()
all_gather_opts.asyncOp = async_ops
work = group.allgather_into_tensor_coalesced(outputs, inputs)
elif op0 == reduce_scatter_tensor:
inputs = []
@ -2585,7 +2581,6 @@ def _coalescing_manager(
outputs.append(not_none(op.dst_tensor))
reduce_opts = ReduceScatterOptions()
reduce_opts.reduceOp = not_none(op_list[0].redop)
reduce_opts.asyncOp = async_ops
work = group.reduce_scatter_tensor_coalesced(outputs, inputs, reduce_opts)
else:
raise AssertionError(
@ -2598,12 +2593,9 @@ def _coalescing_manager(
work = group._end_coalescing(device)
if async_ops:
cm.append(work)
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait()
# Otherwise, the backend has sync'ed at CPP level
cm.append(work) # type: ignore[possibly-undefined]
else:
work.wait() # type: ignore[possibly-undefined]
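The hunks above adjust how _coalescing_manager hands completed work back to the caller. A usage sketch of the (underscore-private) context manager, assuming a process group has been initialized on every rank:

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _coalescing_manager

def coalesced_allreduce(tensors: list[torch.Tensor]) -> None:
    with _coalescing_manager(async_ops=True) as cm:
        for t in tensors:
            dist.all_reduce(t)   # captured and issued as one coalesced call
    cm.wait()                    # wait on the Work(s) collected by the manager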
def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]:
@ -2728,11 +2720,8 @@ def broadcast(
work = group.broadcast([tensor], opts)
if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level
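The `work is not None` fallback removed here existed for backends that synchronize at the C++ level and return no handle; the same edit repeats for every collective below. From the caller's perspective the two modes look like this, a sketch assuming an initialized group and a tensor valid for the backend:

import torch
import torch.distributed as dist

def reduce_sync(t: torch.Tensor) -> None:
    # async_op=False: the wrapper waits on the Work handle before returning.
    dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=False)

def reduce_async(t: torch.Tensor) -> torch.Tensor:
    # async_op=True: the Work handle is returned; the caller decides when to wait.
    work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
    # ... overlap other computation here ...
    work.wait()   # the result in `t` is ready to consume after this
    return t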
@_exception_logger
@ -2812,7 +2801,6 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
opts = AllreduceOptions()
opts.reduceOp = op
opts.asyncOp = async_op
if group is None:
group = _get_default_group()
@ -2829,11 +2817,8 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger
@ -2892,17 +2877,13 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):
opts = AllreduceCoalescedOptions()
opts.reduceOp = op
opts.asyncOp = async_op
group = group or _get_default_group()
work = group.allreduce_coalesced(tensors, opts)
if async_op:
return work.get_future()
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger
@ -2947,15 +2928,11 @@ def reduce(
opts = ReduceOptions()
opts.reduceOp = op
opts.rootRank = group_dst
opts.asyncOp = async_op
work = group.reduce([tensor], opts)
if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level
def _object_to_tensor(obj, device, group):
@ -3754,17 +3731,12 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):
tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)
group = group or _get_default_group()
opts = AllgatherOptions()
opts.asyncOp = async_op
work = group.allgather([tensor_list], [tensor], opts)
work = group.allgather([tensor_list], [tensor])
if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger
@ -3871,11 +3843,8 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal
if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level
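For the tensor-based gather touched above, a short usage sketch, assuming equally sized per-rank shards and an initialized group:

import torch
import torch.distributed as dist

def gather_shards(shard: torch.Tensor) -> torch.Tensor:
    # Every rank must pass a shard of identical shape and dtype.
    world_size = dist.get_world_size()
    out = torch.empty(world_size * shard.numel(), dtype=shard.dtype, device=shard.device)
    dist.all_gather_into_tensor(out, shard)   # synchronous by default
    return out.view(world_size, *shard.shape)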
@_exception_logger
@ -3985,17 +3954,12 @@ def all_gather_coalesced(
]
group = group or _get_default_group()
opts = AllgatherOptions()
opts.asyncOp = async_op
work = group.allgather_coalesced(output_tensor_lists, input_tensor_list, opts)
work = group.allgather_coalesced(output_tensor_lists, input_tensor_list)
if async_op:
return work.get_future()
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level
def _validate_output_list_for_rank(my_rank, dst, gather_list):
@ -4082,16 +4046,12 @@ def gather(
opts = GatherOptions()
opts.rootRank = group_dst
opts.asyncOp = async_op
work = group.gather(output_tensors, input_tensors, opts)
if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger
@ -4193,11 +4153,8 @@ def scatter(
if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger
@ -4229,18 +4186,14 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=Fal
opts = ReduceScatterOptions()
opts.reduceOp = op
opts.asyncOp = async_op
group = group or _get_default_group()
work = group.reduce_scatter([output], [input_list], opts)
if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger
@ -4340,11 +4293,8 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F
if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level
@deprecated(
@ -4497,7 +4447,6 @@ def all_to_all_single(
return
opts = AllToAllOptions()
opts.asyncOp = async_op
_check_single_tensor(output, "output")
_check_single_tensor(input, "input")
_ensure_all_tensors_same_dtype(output, input)
@ -4517,11 +4466,8 @@ def all_to_all_single(
if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger
@ -4622,7 +4568,6 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
return
opts = AllToAllOptions()
opts.asyncOp = async_op
_check_tensor_list(output_tensor_list, "output_tensor_list")
_check_tensor_list(input_tensor_list, "input_tensor_list")
_ensure_all_tensors_same_dtype(output_tensor_list, input_tensor_list)
@ -4639,11 +4584,8 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger
@ -4674,7 +4616,6 @@ def barrier(
opts = BarrierOptions()
opts.device = torch.device(_get_object_coll_device(group))
opts.asyncOp = async_op
if device_ids is not None:
if isinstance(device_ids, list):
opts.device_ids = device_ids
@ -4688,11 +4629,8 @@ def barrier(
if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level
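The barrier hunk drops the asyncOp plumbing but keeps device selection. A small caller-side sketch, assuming NCCL with one GPU per rank:

import torch.distributed as dist

def sync_all_ranks(rank: int) -> None:
    # Pin the barrier to this rank's GPU (NCCL barriers run a small collective).
    work = dist.barrier(device_ids=[rank], async_op=True)
    work.wait()   # same net effect as calling barrier with async_op=False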
def monitored_barrier(

View File

@ -96,7 +96,7 @@ try:
import torchvision
HAS_TORCHVISION = True
except Exception: # Covering both ImportError and RuntimeError
except ImportError:
HAS_TORCHVISION = False
if sys.platform == "win32":
@ -8310,14 +8310,50 @@ class DistributedTest:
def test_compute_bucket_assignment_by_size_sparse_error_with_logger(self):
self._test_compute_bucket_assignment_by_size(use_logger=True)
def _determine_expected_error_verify_model_across_rank(
self, group_to_use, diff_num_params=False
):
# When running with NCCL backend, we don't expect an error on rank 0,
# rather, it will be taken down by TORCH_NCCL_ASYNC_ERROR_HANDLING. When
# running with Gloo or with debug mode wrapper, we expect the error
# to be caught inline.
# All ranks report the same error when there is a mismatch in the number of
# parameters, since we use allgather in the impl.
if diff_num_params:
expected_err = "DDP expects same model across all ranks"
ctx = self.assertRaisesRegex(RuntimeError, expected_err)
return ctx, expected_err
is_detail_dbg_mode = dist.get_debug_level() == dist.DebugLevel.DETAIL
if self.rank == 0:
if (
dist.get_backend(group_to_use) == dist.Backend.NCCL
and not is_detail_dbg_mode
):
expected_err = "caught collective operation timeout"
ctx = self.assertRaisesRegex(RuntimeError, expected_err)
else:
expected_err = None
ctx = self.assertRaises(RuntimeError)
else:
expected_err = "appears not to match"
ctx = self.assertRaisesRegex(RuntimeError, expected_err)
return ctx, expected_err
def _test_verify_model_across_rank(self, use_logger):
group_gloo = dist.new_group(
timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
)
# Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
# determinism.
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
group_to_use = dist.new_group(
backend=dist.get_backend(), timeout=timedelta(seconds=5)
)
torch.cuda.set_device(self.rank)
ctx, expected_err = self._determine_expected_error_verify_model_across_rank(
group_to_use
)
# Create a valid model. The constructor initializes the logger that we use later.
net = EmbeddingNetDifferentParams(0)
@ -8335,8 +8371,7 @@ class DistributedTest:
net.module.lin = nn.Linear(100 if self.rank == 0 else 10, 1)
# if we pass a logger we can verify that it was logged
caught = 0
try:
with ctx:
if use_logger:
_verify_param_shape_across_processes(
net.process_group, list(net.parameters()), net.logger
@ -8345,13 +8380,18 @@ class DistributedTest:
_verify_param_shape_across_processes(
net.process_group, list(net.parameters())
)
except Exception:
caught = 1
# Should only be run by rank 0, and blocking_wait catches and
# reports exception.
dist.barrier(group_to_use)
# As long as there is one rank catching the exception
t = torch.Tensor([caught])
dist.all_reduce(t, group=group_gloo)
self.assertGreater(t, 0)
# We don't check when self.rank != 0 because the logger doesn't log
# the error "Caught collective operation" as that is not thrown in the reducer.
if use_logger and self.rank != 0:
verify_ddp_error_logged(net, expected_err)
# Perform gloo-based barrier to ensure one rank doesn't exit test
# early which causes failure with Barrier.sync.
dist.barrier(group_gloo)
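The restored test bodies set TORCH_NCCL_BLOCKING_WAIT so that a mismatched collective surfaces as a catchable RuntimeError on the offending rank instead of the process being torn down asynchronously. A condensed sketch of that setup, with the model construction and assertions elided:

import os
from datetime import timedelta
import torch.distributed as dist

def make_blocking_nccl_group():
    # Must be set before the NCCL communicator for the new group is created.
    os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
    # A short timeout turns a hung collective into a RuntimeError to assert on.
    return dist.new_group(backend=dist.get_backend(), timeout=timedelta(seconds=5))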
@require_backend_is_available(DistTestCases.backend_feature["gpu"])
@skip_but_pass_in_sandcastle_if(
@ -8369,19 +8409,20 @@ class DistributedTest:
def test_verify_model_across_rank_without_logger(self):
self._test_verify_model_across_rank(use_logger=False)
def _run_test_ddp_model_with_diff_params(self, net, ddp_group, group_gloo):
caught = 0
try:
def _run_test_ddp_model_with_diff_params(self, ctx, net, ddp_group, group_gloo):
with ctx:
net = torch.nn.parallel.DistributedDataParallel(
net.to(self.rank), device_ids=[self.rank], process_group=ddp_group
)
except Exception:
caught = 1
# Should only be run by rank 0, and blocking_wait catches and
# reports exception.
dist.barrier(ddp_group)
# As long as there is one rank catching the exception
t = torch.Tensor([caught])
dist.all_reduce(t, group=group_gloo)
self.assertGreater(t, 0)
# can't use verify_ddp_error_logged here because net was never properly constructed
# Perform gloo-based barrier to ensure one rank doesn't exit test
# early which causes failure with Barrier.sync.
dist.barrier(group_gloo)
@require_backend_is_available(DistTestCases.backend_feature["gpu"])
@skip_but_pass_in_sandcastle_if(
@ -8392,15 +8433,21 @@ class DistributedTest:
group_gloo = dist.new_group(
timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
)
# Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
# determinism.
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
group_to_use = dist.new_group(
backend=dist.get_backend(), timeout=timedelta(seconds=10)
)
torch.cuda.set_device(self.rank)
ctx, _expected_err = self._determine_expected_error_verify_model_across_rank(
group_to_use
)
# Creates network with different sized embedding table on different
# ranks. This should throw an error during DDP init.
net = EmbeddingNetDifferentParams(self.rank)
self._run_test_ddp_model_with_diff_params(
net, group_to_use, group_gloo
ctx, net, group_to_use, group_gloo
)
@require_backend_is_available(DistTestCases.backend_feature["gpu"])
@ -8412,10 +8459,16 @@ class DistributedTest:
group_gloo = dist.new_group(
timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
)
# Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
# determinism.
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
group_to_use = dist.new_group(
backend=dist.get_backend(), timeout=timedelta(seconds=10)
)
torch.cuda.set_device(self.rank)
ctx, _expected_err = self._determine_expected_error_verify_model_across_rank(
group_to_use, diff_num_params=True
)
# Creates network with diff # of param across ranks, reducer should
# recognize this and throw appropriate error.
@ -8424,6 +8477,7 @@ class DistributedTest:
)
self._run_test_ddp_model_with_diff_params(
ctx,
net,
group_to_use,
group_gloo,