mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[PGNCCL] Launch kernel on current stream & remove record_stream
entirely (#148590)"
This reverts commit ef6296e7f20d744a0cfed81cab573d60204e7626. Reverted https://github.com/pytorch/pytorch/pull/148590 on behalf of https://github.com/izaitsevfb due to reverted internally, see D71292427 ([comment](https://github.com/pytorch/pytorch/pull/148590#issuecomment-2731114626))
This commit is contained in:
@ -363,9 +363,6 @@ class TestDebugInfoWriter : public c10d::DebugInfoWriter {
|
||||
};
|
||||
|
||||
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
|
||||
// Note (kwen2501) 03/07/2025
|
||||
// TODO: re-enable
|
||||
GTEST_SKIP() << "Skipping test as the trace write seems unstable.";
|
||||
int heartBeatIntervalInSec = 2;
|
||||
std::string timeInterval = std::to_string(heartBeatIntervalInSec);
|
||||
ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "0", 1) == 0);
|
||||
|
@ -126,9 +126,6 @@ ALLOW_LIST = [
|
||||
("aten::reduce_scatter_tensor", datetime.date(9999, 1, 30)),
|
||||
("aten::all_gather_into_tensor", datetime.date(9999, 1, 30)),
|
||||
("aten::all_reduce", datetime.date(9999, 1, 30)),
|
||||
# These ops are defined in torch/csrc/distributed/c10d/Ops.cpp
|
||||
# TODO: add back restriction when c10d ops can be exported
|
||||
("c10d::.*", datetime.date(9999, 1, 1)),
|
||||
]
|
||||
|
||||
ALLOW_LIST_COMPILED = [
|
||||
|
@ -2,7 +2,7 @@
|
||||
# mypy: disable-error-code="type-arg"
|
||||
from datetime import timedelta
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, overload
|
||||
from typing import Any, overload
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@ -139,8 +139,6 @@ class BroadcastOptions:
|
||||
class AllreduceOptions:
|
||||
reduceOp: ReduceOp
|
||||
timeout: timedelta
|
||||
asyncOp: bool
|
||||
sparseIndices: Optional[Tensor]
|
||||
|
||||
class AllreduceCoalescedOptions(AllreduceOptions): ...
|
||||
|
||||
@ -149,7 +147,6 @@ class ReduceOptions:
|
||||
rootRank: int
|
||||
rootTensor: int
|
||||
timeout: timedelta
|
||||
asyncOp: bool
|
||||
|
||||
class AllgatherOptions:
|
||||
timeout: timedelta
|
||||
@ -158,7 +155,6 @@ class AllgatherOptions:
|
||||
class GatherOptions:
|
||||
rootRank: int
|
||||
timeout: timedelta
|
||||
asyncOp: bool
|
||||
|
||||
class ScatterOptions:
|
||||
rootRank: int
|
||||
@ -174,11 +170,9 @@ class BarrierOptions:
|
||||
device_ids: list[int]
|
||||
device: torch.device
|
||||
timeout: timedelta
|
||||
asyncOp: bool
|
||||
|
||||
class AllToAllOptions:
|
||||
timeout: timedelta
|
||||
asyncOp: bool
|
||||
|
||||
class Store:
|
||||
def set(self, key: str, value: str): ...
|
||||
|
@ -17,37 +17,37 @@ TORCH_LIBRARY(c10d, m) {
|
||||
.def("wait", [](const c10::intrusive_ptr<Work>& self) { self->wait(); });
|
||||
m.class_<ReduceOp>("ReduceOp").def(torch::init<>());
|
||||
m.def(
|
||||
"broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
|
||||
"broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
|
||||
m.def(
|
||||
"allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
|
||||
"allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
|
||||
m.def(
|
||||
"allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
|
||||
"allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work");
|
||||
m.def(
|
||||
"allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor[][], __torch__.torch.classes.c10d.Work)");
|
||||
"allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[][], __torch__.torch.classes.c10d.Work)");
|
||||
m.def(
|
||||
"_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)");
|
||||
"_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)");
|
||||
m.def(
|
||||
"allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) -> __torch__.torch.classes.c10d.Work");
|
||||
"allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work");
|
||||
m.def(
|
||||
"allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) -> __torch__.torch.classes.c10d.Work");
|
||||
"allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work");
|
||||
m.def(
|
||||
"reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
|
||||
"reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
|
||||
m.def(
|
||||
"_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)");
|
||||
"_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)");
|
||||
m.def(
|
||||
"reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
|
||||
"reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work");
|
||||
m.def(
|
||||
"reduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
|
||||
"reduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int root_rank, int root_tensor, int timeout) -> __torch__.torch.classes.c10d.Work");
|
||||
m.def(
|
||||
"gather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
|
||||
"gather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int timeout) -> __torch__.torch.classes.c10d.Work");
|
||||
m.def(
|
||||
"scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
|
||||
"scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
|
||||
m.def(
|
||||
"alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
|
||||
"alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
|
||||
m.def(
|
||||
"alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] output_split_sizes, int[] input_split_sizes, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
|
||||
"alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] output_split_sizes, int[] input_split_sizes, int timeout) -> __torch__.torch.classes.c10d.Work");
|
||||
m.def(
|
||||
"barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
|
||||
"barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout) -> __torch__.torch.classes.c10d.Work");
|
||||
m.def(
|
||||
"monitored_barrier_(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout, bool wait_all_ranks) -> ()");
|
||||
m.def(
|
||||
@ -118,7 +118,6 @@ IMPL_RECV_ANY_SOURCE(PrivateUse1)
|
||||
const c10::intrusive_ptr<ReduceOp>& reduce_op, \
|
||||
int64_t root_rank, \
|
||||
int64_t root_tensor, \
|
||||
bool asyncOp, \
|
||||
int64_t timeout) { \
|
||||
auto tensor_vec = tensors.vec(); \
|
||||
return process_group->getBackend(c10::DeviceType::DEV) \
|
||||
@ -128,8 +127,7 @@ IMPL_RECV_ANY_SOURCE(PrivateUse1)
|
||||
*reduce_op.get(), \
|
||||
root_rank, \
|
||||
root_tensor, \
|
||||
std::chrono::milliseconds(timeout), \
|
||||
asyncOp}); \
|
||||
std::chrono::milliseconds(timeout)}); \
|
||||
}
|
||||
|
||||
IMPL_REDUCE(CPU)
|
||||
@ -171,13 +169,12 @@ IMPL_BROADCAST(PrivateUse1)
|
||||
const c10::intrusive_ptr<ProcessGroup>& process_group, \
|
||||
const c10::intrusive_ptr<ReduceOp>& reduce_op, \
|
||||
const std::optional<at::Tensor>& sparse_indices, \
|
||||
bool asyncOp, \
|
||||
int64_t timeout) { \
|
||||
auto tensor_vec = tensors.vec(); \
|
||||
auto work = process_group->getBackend(c10::DeviceType::DEV) -> allreduce( \
|
||||
tensor_vec, \
|
||||
AllreduceOptions{ \
|
||||
*reduce_op.get(), std::chrono::milliseconds(timeout), asyncOp}); \
|
||||
*reduce_op.get(), std::chrono::milliseconds(timeout)}); \
|
||||
return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>( \
|
||||
std::move(tensor_vec), work); \
|
||||
}
|
||||
@ -191,13 +188,11 @@ IMPL_ALLREDUCE(PrivateUse1)
|
||||
at::TensorList tensors, \
|
||||
const c10::intrusive_ptr<ProcessGroup>& process_group, \
|
||||
const c10::intrusive_ptr<ReduceOp>& reduce_op, \
|
||||
bool asyncOp, \
|
||||
int64_t timeout) { \
|
||||
auto tensor_vec = tensors.vec(); \
|
||||
AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; \
|
||||
opts.reduceOp = *reduce_op.get(); \
|
||||
opts.timeout = std::chrono::milliseconds(timeout); \
|
||||
opts.asyncOp = asyncOp; \
|
||||
return process_group->getBackend(c10::DeviceType::DEV) \
|
||||
->allreduce_coalesced(tensor_vec, opts); \
|
||||
}
|
||||
@ -214,13 +209,12 @@ IMPL_ALLREDUCE_COALESCED(PrivateUse1)
|
||||
const std::vector<std::vector<at::Tensor>>& output_tensors, \
|
||||
at::TensorList input_tensors, \
|
||||
const c10::intrusive_ptr<ProcessGroup>& process_group, \
|
||||
bool asyncOp, \
|
||||
int64_t timeout) { \
|
||||
auto input_tensors_vec = input_tensors.vec(); \
|
||||
auto work = process_group->getBackend(c10::DeviceType::DEV) -> allgather( \
|
||||
const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors), \
|
||||
input_tensors_vec, \
|
||||
AllgatherOptions{std::chrono::milliseconds(timeout), asyncOp}); \
|
||||
AllgatherOptions{std::chrono::milliseconds(timeout)}); \
|
||||
return std:: \
|
||||
tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>( \
|
||||
output_tensors, work); \
|
||||
@ -255,16 +249,12 @@ IMPL__ALLGATHER_BASE(PrivateUse1)
|
||||
c10::intrusive_ptr<Work> allgather_coalesced_##DEV( \
|
||||
const std::vector<std::vector<at::Tensor>>& output_lists, \
|
||||
const at::TensorList& input_list, \
|
||||
const c10::intrusive_ptr<ProcessGroup>& process_group, \
|
||||
bool asyncOp) { \
|
||||
const c10::intrusive_ptr<ProcessGroup>& process_group) { \
|
||||
auto input_list_vec = input_list.vec(); \
|
||||
auto opts = AllgatherOptions{}; \
|
||||
opts.asyncOp = asyncOp; \
|
||||
return process_group->getBackend(c10::DeviceType::DEV) \
|
||||
->allgather_coalesced( \
|
||||
const_cast<std::vector<std::vector<at::Tensor>>&>(output_lists), \
|
||||
input_list_vec, \
|
||||
opts); \
|
||||
input_list_vec); \
|
||||
}
|
||||
|
||||
IMPL_ALLGATHER_COALESCED(CPU)
|
||||
@ -275,14 +265,11 @@ IMPL_ALLGATHER_COALESCED(PrivateUse1)
|
||||
c10::intrusive_ptr<c10d::Work> allgather_into_tensor_coalesced_##DEV( \
|
||||
at::TensorList outputs, \
|
||||
at::TensorList inputs, \
|
||||
const c10::intrusive_ptr<ProcessGroup>& process_group, \
|
||||
bool asyncOp) { \
|
||||
const c10::intrusive_ptr<ProcessGroup>& process_group) { \
|
||||
auto output_vec = outputs.vec(); \
|
||||
auto input_vec = inputs.vec(); \
|
||||
auto opts = AllgatherOptions{}; \
|
||||
opts.asyncOp = asyncOp; \
|
||||
return process_group->getBackend(c10::DeviceType::DEV) \
|
||||
->allgather_into_tensor_coalesced(output_vec, input_vec, opts); \
|
||||
->allgather_into_tensor_coalesced(output_vec, input_vec); \
|
||||
}
|
||||
|
||||
IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CPU)
|
||||
@ -296,7 +283,6 @@ IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1)
|
||||
const std::vector<std::vector<at::Tensor>>& input_tensors, \
|
||||
const c10::intrusive_ptr<ProcessGroup>& process_group, \
|
||||
const c10::intrusive_ptr<ReduceOp>& reduce_op, \
|
||||
bool asyncOp, \
|
||||
int64_t timeout) { \
|
||||
auto output_tensors_vec = output_tensors.vec(); \
|
||||
auto work = \
|
||||
@ -304,9 +290,7 @@ IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1)
|
||||
output_tensors_vec, \
|
||||
const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors), \
|
||||
ReduceScatterOptions{ \
|
||||
*reduce_op.get(), \
|
||||
std::chrono::milliseconds(timeout), \
|
||||
asyncOp}); \
|
||||
*reduce_op.get(), std::chrono::milliseconds(timeout)}); \
|
||||
return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>( \
|
||||
output_tensors_vec, work); \
|
||||
}
|
||||
@ -345,7 +329,6 @@ IMPL__REDUCE_SCATTER_BASE(PrivateUse1)
|
||||
at::TensorList inputs, \
|
||||
const c10::intrusive_ptr<ProcessGroup>& process_group, \
|
||||
const c10::intrusive_ptr<ReduceOp>& reduce_op, \
|
||||
bool asyncOp, \
|
||||
int64_t timeout) { \
|
||||
auto output_vec = outputs.vec(); \
|
||||
auto input_vec = inputs.vec(); \
|
||||
@ -354,9 +337,7 @@ IMPL__REDUCE_SCATTER_BASE(PrivateUse1)
|
||||
output_vec, \
|
||||
input_vec, \
|
||||
ReduceScatterOptions{ \
|
||||
*reduce_op.get(), \
|
||||
std::chrono::milliseconds(timeout), \
|
||||
asyncOp}); \
|
||||
*reduce_op.get(), std::chrono::milliseconds(timeout)}); \
|
||||
}
|
||||
|
||||
IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CPU)
|
||||
@ -369,15 +350,13 @@ IMPL_REDUCE_SCATTER_TENSOR_COALESCED(PrivateUse1)
|
||||
const at::TensorList& input_tensors, \
|
||||
const c10::intrusive_ptr<ProcessGroup>& process_group, \
|
||||
int64_t root_rank, \
|
||||
bool asyncOp, \
|
||||
int64_t timeout) { \
|
||||
auto input_tensors_vec = input_tensors.vec(); \
|
||||
return process_group->getBackend(c10::DeviceType::DEV) \
|
||||
->gather( \
|
||||
const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors), \
|
||||
input_tensors_vec, \
|
||||
GatherOptions{ \
|
||||
root_rank, std::chrono::milliseconds(timeout), asyncOp}); \
|
||||
GatherOptions{root_rank, std::chrono::milliseconds(timeout)}); \
|
||||
}
|
||||
|
||||
IMPL_GATHER(CPU)
|
||||
@ -412,14 +391,13 @@ IMPL_SCATTER(PrivateUse1)
|
||||
const at::TensorList& output_tensors, \
|
||||
const at::TensorList& input_tensors, \
|
||||
const c10::intrusive_ptr<ProcessGroup>& process_group, \
|
||||
bool asyncOp, \
|
||||
int64_t timeout) { \
|
||||
auto output_tensors_vec = output_tensors.vec(); \
|
||||
auto input_tensors_vec = input_tensors.vec(); \
|
||||
auto work = process_group->getBackend(c10::DeviceType::DEV) -> alltoall( \
|
||||
output_tensors_vec, \
|
||||
input_tensors_vec, \
|
||||
AllToAllOptions{std::chrono::milliseconds(timeout), asyncOp}); \
|
||||
AllToAllOptions{std::chrono::milliseconds(timeout)}); \
|
||||
return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>( \
|
||||
std::move(output_tensors_vec), work); \
|
||||
}
|
||||
@ -428,22 +406,21 @@ IMPL_ALLTOALL(CPU)
|
||||
IMPL_ALLTOALL(CUDA)
|
||||
IMPL_ALLTOALL(PrivateUse1)
|
||||
|
||||
#define IMPL_ALLTOALL_BASE(DEV) \
|
||||
c10::intrusive_ptr<Work> alltoall_base_##DEV( \
|
||||
at::Tensor& output, \
|
||||
at::Tensor& input, \
|
||||
const c10::intrusive_ptr<ProcessGroup>& process_group, \
|
||||
std::vector<int64_t> output_split_sizes, \
|
||||
std::vector<int64_t> input_split_sizes, \
|
||||
bool asyncOp, \
|
||||
int64_t timeout) { \
|
||||
return process_group->getBackend(c10::DeviceType::DEV) \
|
||||
->alltoall_base( \
|
||||
output, \
|
||||
input, \
|
||||
output_split_sizes, \
|
||||
input_split_sizes, \
|
||||
AllToAllOptions{std::chrono::milliseconds(timeout), asyncOp}); \
|
||||
#define IMPL_ALLTOALL_BASE(DEV) \
|
||||
c10::intrusive_ptr<Work> alltoall_base_##DEV( \
|
||||
at::Tensor& output, \
|
||||
at::Tensor& input, \
|
||||
const c10::intrusive_ptr<ProcessGroup>& process_group, \
|
||||
std::vector<int64_t> output_split_sizes, \
|
||||
std::vector<int64_t> input_split_sizes, \
|
||||
int64_t timeout) { \
|
||||
return process_group->getBackend(c10::DeviceType::DEV) \
|
||||
->alltoall_base( \
|
||||
output, \
|
||||
input, \
|
||||
output_split_sizes, \
|
||||
input_split_sizes, \
|
||||
AllToAllOptions{std::chrono::milliseconds(timeout)}); \
|
||||
}
|
||||
|
||||
IMPL_ALLTOALL_BASE(CPU)
|
||||
@ -451,18 +428,15 @@ IMPL_ALLTOALL_BASE(CUDA)
|
||||
IMPL_ALLTOALL_BASE(PrivateUse1)
|
||||
|
||||
// NOLINTBEGIN(performance-unnecessary-value-param)
|
||||
#define IMPL_BARRIER(DEV) \
|
||||
c10::intrusive_ptr<Work> barrier##DEV( \
|
||||
at::Tensor /* unused */, \
|
||||
const c10::intrusive_ptr<ProcessGroup>& process_group, \
|
||||
const std::vector<int64_t>& device_ids, \
|
||||
bool asyncOp, \
|
||||
int64_t timeout) { \
|
||||
auto opts = BarrierOptions{}; \
|
||||
opts.device_ids = device_ids; \
|
||||
opts.timeout = std::chrono::milliseconds(timeout); \
|
||||
opts.asyncOp = asyncOp; \
|
||||
return process_group->getBackend(c10::DeviceType::DEV)->barrier(opts); \
|
||||
#define IMPL_BARRIER(DEV) \
|
||||
c10::intrusive_ptr<Work> barrier##DEV( \
|
||||
at::Tensor /* unused */, \
|
||||
const c10::intrusive_ptr<ProcessGroup>& process_group, \
|
||||
const std::vector<int64_t>& device_ids, \
|
||||
int64_t timeout) { \
|
||||
return process_group->getBackend(c10::DeviceType::DEV) \
|
||||
->barrier( \
|
||||
BarrierOptions{device_ids, std::chrono::milliseconds(timeout)}); \
|
||||
}
|
||||
|
||||
IMPL_BARRIER(CPU)
|
||||
@ -490,7 +464,6 @@ allreduce_sparse_cuda_(
|
||||
const c10::intrusive_ptr<ProcessGroup>& process_group,
|
||||
const c10::intrusive_ptr<ReduceOp>& reduce_op,
|
||||
const std::optional<at::Tensor>& sparse_indices,
|
||||
bool asyncOp,
|
||||
int64_t timeout) {
|
||||
auto tensor_vec = tensors.vec();
|
||||
auto work = process_group->getBackend(c10::DeviceType::CUDA)
|
||||
@ -499,7 +472,6 @@ allreduce_sparse_cuda_(
|
||||
AllreduceOptions{
|
||||
*reduce_op,
|
||||
std::chrono::milliseconds(timeout),
|
||||
asyncOp,
|
||||
sparse_indices});
|
||||
|
||||
return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
|
||||
|
@ -224,7 +224,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
|
||||
const c10::intrusive_ptr<::c10d::ReduceOp>&,
|
||||
const std::optional<at::Tensor>& sparse_indices,
|
||||
bool,
|
||||
int64_t)>();
|
||||
|
||||
auto work = std::get<1>(op.call(
|
||||
@ -232,7 +231,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
|
||||
c10::make_intrusive<ReduceOp>(opts.reduceOp),
|
||||
opts.sparseIndices,
|
||||
opts.asyncOp,
|
||||
opts.timeout.count()));
|
||||
|
||||
if (c10d::allow_inflight_collective_as_graph_input()) {
|
||||
@ -252,14 +250,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
at::TensorList,
|
||||
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
|
||||
const c10::intrusive_ptr<::c10d::ReduceOp>&,
|
||||
bool,
|
||||
int64_t)>();
|
||||
|
||||
auto work = op.call(
|
||||
tensors,
|
||||
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
|
||||
c10::make_intrusive<ReduceOp>(opts.reduceOp),
|
||||
opts.asyncOp,
|
||||
opts.timeout.count());
|
||||
|
||||
if (c10d::allow_inflight_collective_as_graph_input()) {
|
||||
@ -281,7 +277,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
const c10::intrusive_ptr<::c10d::ReduceOp>&,
|
||||
int64_t,
|
||||
int64_t,
|
||||
bool,
|
||||
int64_t)>();
|
||||
auto work = op.call(
|
||||
tensors,
|
||||
@ -289,7 +284,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
c10::make_intrusive<ReduceOp>(opts.reduceOp),
|
||||
opts.rootRank,
|
||||
opts.rootTensor,
|
||||
opts.asyncOp,
|
||||
opts.timeout.count());
|
||||
|
||||
if (c10d::allow_inflight_collective_as_graph_input()) {
|
||||
@ -312,14 +306,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
const std::vector<std::vector<at::Tensor>>&,
|
||||
at::TensorList,
|
||||
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
|
||||
bool,
|
||||
int64_t)>();
|
||||
|
||||
auto work = std::get<1>(op.call(
|
||||
outputTensors,
|
||||
inputTensors,
|
||||
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
|
||||
opts.asyncOp,
|
||||
opts.timeout.count()));
|
||||
|
||||
if (c10d::allow_inflight_collective_as_graph_input()) {
|
||||
@ -371,19 +363,18 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
std::vector<std::vector<at::Tensor>>& outputTensorLists,
|
||||
std::vector<at::Tensor>& inputTensors,
|
||||
const AllgatherOptions& opts = AllgatherOptions()) {
|
||||
static auto op = c10::Dispatcher::singleton()
|
||||
.findSchemaOrThrow("c10d::allgather_coalesced_", "")
|
||||
.typed<c10::intrusive_ptr<Work>(
|
||||
const std::vector<std::vector<at::Tensor>>&,
|
||||
const at::TensorList&,
|
||||
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
|
||||
bool)>();
|
||||
static auto op =
|
||||
c10::Dispatcher::singleton()
|
||||
.findSchemaOrThrow("c10d::allgather_coalesced_", "")
|
||||
.typed<c10::intrusive_ptr<Work>(
|
||||
const std::vector<std::vector<at::Tensor>>&,
|
||||
const at::TensorList&,
|
||||
const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
|
||||
|
||||
auto work = op.call(
|
||||
outputTensorLists,
|
||||
inputTensors,
|
||||
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
|
||||
opts.asyncOp);
|
||||
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
|
||||
|
||||
if (c10d::allow_inflight_collective_as_graph_input()) {
|
||||
for (const auto& tensor_list : outputTensorLists) {
|
||||
@ -408,14 +399,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
.typed<c10::intrusive_ptr<Work>(
|
||||
const at::TensorList,
|
||||
const at::TensorList,
|
||||
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
|
||||
bool)>();
|
||||
const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
|
||||
|
||||
auto work = op.call(
|
||||
outputTensors,
|
||||
inputTensors,
|
||||
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
|
||||
opts.asyncOp);
|
||||
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
|
||||
|
||||
if (c10d::allow_inflight_collective_as_graph_input()) {
|
||||
for (const auto& tensor : outputTensors) {
|
||||
@ -436,14 +425,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
const at::TensorList&,
|
||||
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
|
||||
int64_t,
|
||||
bool,
|
||||
int64_t)>();
|
||||
auto work = op.call(
|
||||
outputTensors,
|
||||
inputTensors,
|
||||
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
|
||||
opts.rootRank,
|
||||
opts.asyncOp,
|
||||
opts.timeout.count());
|
||||
|
||||
if (c10d::allow_inflight_collective_as_graph_input()) {
|
||||
@ -500,14 +487,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
const std::vector<std::vector<at::Tensor>>&,
|
||||
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
|
||||
const c10::intrusive_ptr<::c10d::ReduceOp>&,
|
||||
bool,
|
||||
int64_t)>();
|
||||
auto work = std::get<1>(op.call(
|
||||
outputTensors,
|
||||
inputTensors,
|
||||
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
|
||||
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
|
||||
opts.asyncOp,
|
||||
opts.timeout.count()));
|
||||
|
||||
if (c10d::allow_inflight_collective_as_graph_input()) {
|
||||
@ -561,7 +546,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
const at::TensorList,
|
||||
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
|
||||
const c10::intrusive_ptr<::c10d::ReduceOp>&,
|
||||
bool,
|
||||
int64_t)>();
|
||||
|
||||
auto work = op.call(
|
||||
@ -569,7 +553,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
inputTensors,
|
||||
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
|
||||
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
|
||||
opts.asyncOp,
|
||||
opts.timeout.count());
|
||||
|
||||
if (c10d::allow_inflight_collective_as_graph_input()) {
|
||||
@ -594,7 +577,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
|
||||
std::vector<int64_t>,
|
||||
std::vector<int64_t>,
|
||||
bool,
|
||||
int64_t)>();
|
||||
auto work = op.call(
|
||||
outputBuffer,
|
||||
@ -602,7 +584,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
|
||||
outputSplitSizes,
|
||||
inputSplitSizes,
|
||||
opts.asyncOp,
|
||||
opts.timeout.count());
|
||||
|
||||
if (c10d::allow_inflight_collective_as_graph_input()) {
|
||||
@ -623,13 +604,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
const at::TensorList&,
|
||||
const at::TensorList&,
|
||||
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
|
||||
bool,
|
||||
int64_t)>();
|
||||
auto work = std::get<1>(op.call(
|
||||
outputTensors,
|
||||
inputTensors,
|
||||
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
|
||||
opts.asyncOp,
|
||||
opts.timeout.count()));
|
||||
|
||||
if (c10d::allow_inflight_collective_as_graph_input()) {
|
||||
@ -799,14 +778,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
at::Tensor,
|
||||
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
|
||||
const std::vector<int64_t>&,
|
||||
bool,
|
||||
int64_t)>();
|
||||
|
||||
auto work = op.call(
|
||||
tensor,
|
||||
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
|
||||
opts.device_ids,
|
||||
opts.asyncOp,
|
||||
opts.timeout.count());
|
||||
if (c10d::allow_inflight_collective_as_graph_input()) {
|
||||
c10d::register_work(tensor, work);
|
||||
|
@ -496,8 +496,6 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(
|
||||
}
|
||||
futureWorkResult_ =
|
||||
c10::make_intrusive<at::ivalue::Future>(c10::AnyEnumType::get());
|
||||
// other functions expect an initialized ptr
|
||||
stashed_for_allocator_safety_ = std::make_shared<std::vector<at::Tensor>>();
|
||||
}
|
||||
|
||||
ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
|
||||
@ -519,11 +517,6 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
|
||||
numelIn_(w.numelIn_),
|
||||
numelOut_(w.numelOut_),
|
||||
store_(w.store_),
|
||||
// Note: the `work` returned to user and the `work` enqueued to watchdog
|
||||
// share the pointer to the tensor stash. At least one of them should
|
||||
// clean the tensor stash, the earlier the better, i.e. user calling
|
||||
// `work.wait` than watchdog detecting work completion.
|
||||
stashed_for_allocator_safety_(w.stashed_for_allocator_safety_),
|
||||
futureWorkResult_(w.futureWorkResult_),
|
||||
timingEnabled_(w.timingEnabled_),
|
||||
trace_id_(w.trace_id_),
|
||||
@ -717,25 +710,14 @@ void ProcessGroupNCCL::WorkNCCL::synchronize() {
|
||||
}
|
||||
}
|
||||
|
||||
void ProcessGroupNCCL::WorkNCCL::stashTensors(
|
||||
std::vector<at::Tensor>& tensors) {
|
||||
std::lock_guard<std::mutex> lock(stashMutex_);
|
||||
stashed_for_allocator_safety_->insert(
|
||||
stashed_for_allocator_safety_->end(), tensors.begin(), tensors.end());
|
||||
}
|
||||
|
||||
void ProcessGroupNCCL::WorkNCCL::unstashTensors() {
|
||||
std::lock_guard<std::mutex> lock(stashMutex_);
|
||||
stashed_for_allocator_safety_->clear();
|
||||
}
|
||||
|
||||
void ProcessGroupNCCL::WorkNCCL::synchronizeStream() {
|
||||
auto currentStream = at::cuda::getCurrentCUDAStream(device_.index());
|
||||
// Block the current stream on the NCCL stream
|
||||
ncclEndEvent_->block(currentStream);
|
||||
// Unstage the stashed tensors so that CachingAllocator can recycle them
|
||||
// THIS MUST HAPPEN AFTER THE BLOCKING CALL ABOVE
|
||||
unstashTensors();
|
||||
|
||||
if (avoidRecordStreams_) {
|
||||
stashed_for_allocator_safety_->clear();
|
||||
}
|
||||
}
|
||||
|
||||
// Same as calling synchronize() when blockingWait_ is false
|
||||
@ -951,10 +933,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
|
||||
enableTiming_.store(
|
||||
getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_);
|
||||
#endif // ENABLE_NCCL_ERROR_CHECKING
|
||||
if (getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false)) {
|
||||
TORCH_WARN_ONCE(
|
||||
"TORCH_NCCL_AVOID_RECORD_STREAMS is the default now, this environment variable is thus deprecated.");
|
||||
}
|
||||
avoidRecordStreams_ = getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false);
|
||||
#ifdef NCCL_HAS_COMM_REGISTER
|
||||
useTensorRegisterAllocatorHook_ =
|
||||
getCvarBool(TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK, false);
|
||||
@ -2344,12 +2323,6 @@ void ProcessGroupNCCL::watchdogHandler() {
|
||||
|
||||
// Clean up completed work
|
||||
if (work.isCompleted()) {
|
||||
// In case user didn't call `work.wait()` with async collectives,
|
||||
// watchdog would unstage the stashed tensors when detecting completion
|
||||
// of the collective, to prevent ProcessGroupNCCL from holding reference
|
||||
// to those tensors forever.
|
||||
work.unstashTensors();
|
||||
|
||||
// Work status logging for desync debug
|
||||
desyncDebugger_.logWorkEnd(work);
|
||||
|
||||
@ -3084,7 +3057,6 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
|
||||
enableTiming_.load(),
|
||||
cudaEventCacheEnabled_.load(),
|
||||
dist_debug_level_);
|
||||
|
||||
if (record) {
|
||||
bool isP2P = isP2POp(opType);
|
||||
// Ideally record every work that we enqueue, rather than every work we
|
||||
@ -3242,6 +3214,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing(OpType optype) {
|
||||
enqueue);
|
||||
work->ncclComm_ = comm;
|
||||
work->blockingWait_ = blockingWait_;
|
||||
work->avoidRecordStreams_ = avoidRecordStreams_;
|
||||
work->store_ = store_;
|
||||
assignTimeoutToWork(work, options_);
|
||||
|
||||
@ -3260,16 +3233,19 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing(OpType optype) {
|
||||
// TODO(eqy): is this still necessary if avoidRecordStreams_ is set?
|
||||
work->ncclEndEvent_->record(ncclStream);
|
||||
|
||||
if (avoidRecordStreams_) {
|
||||
// other functions expect an initialized ptr if avoidRecordStreams_ is set
|
||||
work->stashed_for_allocator_safety_ =
|
||||
std::make_shared<std::vector<at::Tensor>>();
|
||||
}
|
||||
|
||||
if (enqueue) {
|
||||
workEnqueue(work);
|
||||
}
|
||||
|
||||
// Reset coalescing state
|
||||
coalescing_state_ = 0;
|
||||
coalescedComm_ = nullptr;
|
||||
// If in async mode, return work; otherwise, kernel is enqueued on current
|
||||
// stream, no need to return work
|
||||
return coalescedAsync_ ? work : nullptr;
|
||||
return work;
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing() {
|
||||
@ -3285,10 +3261,11 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
|
||||
PreProcess pre,
|
||||
PostProcess post,
|
||||
OpType opType,
|
||||
bool asyncOp,
|
||||
const char* profilingTitle,
|
||||
bool avoidRecordStreams,
|
||||
bool nanCheck) {
|
||||
// Environment setting by the user may add onto collective call's option
|
||||
avoidRecordStreams |= avoidRecordStreams_;
|
||||
nanCheck &= enableNanCheck_;
|
||||
|
||||
auto device = getDevice(inputs[0]);
|
||||
@ -3329,17 +3306,13 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
|
||||
} else {
|
||||
TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG);
|
||||
}
|
||||
coalescedAsync_ = asyncOp;
|
||||
}
|
||||
|
||||
// in asyncOp=false [default] mode, we use currentStream as ncclStream
|
||||
// otherwise, we use separate ncclStream and let it sync on currentStream
|
||||
auto ncclStream = asyncOp ? ncclStreams_.at(key)
|
||||
: at::cuda::getCurrentCUDAStream(device.index());
|
||||
if (asyncOp) {
|
||||
// First let NCCL streams wait for input tensors allocation streams
|
||||
syncStream(device, ncclEvents_[key], ncclStream);
|
||||
}
|
||||
// Used many times below, so we stash the unordered_map lookup
|
||||
auto ncclStream = ncclStreams_.at(key);
|
||||
|
||||
// First let NCCL streams wait for input tensors allocation streams
|
||||
syncStream(device, ncclEvents_[key], ncclStream);
|
||||
|
||||
bool enqueue =
|
||||
!coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None;
|
||||
@ -3349,12 +3322,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
|
||||
// Store references to outputs to be used by WorkNCCL::result and operator<<.
|
||||
work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
|
||||
|
||||
// If we are performing sync operations, i.e. equeuing kernel onto "current"
|
||||
// stream, we don't need to do anything for tensor lifetime management.
|
||||
// Otherwise, we need to stage the tensors will `work.wait()`.
|
||||
if (asyncOp) {
|
||||
work->stashTensors(inputs);
|
||||
work->stashTensors(outputs);
|
||||
if (avoidRecordStreams) {
|
||||
work->stashed_for_allocator_safety_ =
|
||||
std::make_shared<std::vector<at::Tensor>>(inputs);
|
||||
}
|
||||
|
||||
if (nanCheck) {
|
||||
@ -3380,6 +3350,21 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
|
||||
// operations where `inputs' and `outputs' are not the same.
|
||||
//
|
||||
// See [Sync Streams].
|
||||
if (!avoidRecordStreams) {
|
||||
for (const auto& input : inputs) {
|
||||
if (!input.is_sparse()) {
|
||||
c10::cuda::CUDACachingAllocator::recordStream(
|
||||
input.storage().data_ptr(), ncclStream);
|
||||
} else {
|
||||
// for sparse input case record streams on both index and value
|
||||
// tensors
|
||||
c10::cuda::CUDACachingAllocator::recordStream(
|
||||
input.values().storage().data_ptr(), ncclStream);
|
||||
c10::cuda::CUDACachingAllocator::recordStream(
|
||||
input.indices().storage().data_ptr(), ncclStream);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Not all collectives have the same signature, e.g, all-reduce take in a Tensor
|
||||
// as the input and output while all-to-all take in a vector of Tensors as input
|
||||
@ -3431,6 +3416,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
|
||||
|
||||
// Set appropriate work parameters.
|
||||
work->blockingWait_ = blockingWait_;
|
||||
work->avoidRecordStreams_ = avoidRecordStreams;
|
||||
work->store_ = store_;
|
||||
assignTimeoutToWork(work, options_);
|
||||
// Record size info for debug. We only record the size on the first device as
|
||||
@ -3448,7 +3434,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
|
||||
workEnqueue(work);
|
||||
}
|
||||
|
||||
return asyncOp ? work : nullptr;
|
||||
return work;
|
||||
}
|
||||
|
||||
template <typename Fn>
|
||||
@ -3457,8 +3443,11 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
|
||||
std::vector<at::Tensor>& outputs,
|
||||
Fn fn,
|
||||
OpType opType,
|
||||
bool asyncOp,
|
||||
const char* profilingTitle) {
|
||||
const char* profilingTitle,
|
||||
bool avoidRecordStreams) {
|
||||
// Environment setting by the user may add onto collective call's option
|
||||
avoidRecordStreams |= avoidRecordStreams_;
|
||||
|
||||
// Currently, the API permits one scenario where inputs.size() and
|
||||
// outputs.size() are > 0.
|
||||
// 1. If the call was a _coalesced call, all inputs must be on the same
|
||||
@ -3504,17 +3493,13 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
|
||||
} else {
|
||||
TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG);
|
||||
}
|
||||
coalescedAsync_ = asyncOp;
|
||||
}
|
||||
|
||||
// in asyncOp=false [default] mode, we use currentStream as ncclStream
|
||||
// otherwise, we use separate ncclStream and let it sync on currentStream
|
||||
auto ncclStream = asyncOp ? ncclStreams_.at(key)
|
||||
: at::cuda::getCurrentCUDAStream(device.index());
|
||||
if (asyncOp) {
|
||||
// First let NCCL streams wait for input tensors allocation streams
|
||||
syncStream(device, ncclEvents_[key], ncclStream);
|
||||
}
|
||||
// Used many times below, so we stash the unordered_map lookup
|
||||
auto ncclStream = ncclStreams_.at(key);
|
||||
|
||||
// First let NCCL streams wait for input tensors allocation streams
|
||||
syncStream(device, ncclEvents_[key], ncclStream);
|
||||
|
||||
auto work = initWork(
|
||||
device,
|
||||
@ -3529,12 +3514,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
|
||||
// Store references to outputs to be used by WorkNCCL::result and operator<<.
|
||||
work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
|
||||
|
||||
// If we are performing sync operations, i.e. equeuing kernel onto "current"
|
||||
// stream, we don't need to do anything for tensor lifetime management.
|
||||
// Otherwise, we need to stage the tensors will `work.wait()`.
|
||||
if (asyncOp) {
|
||||
work->stashTensors(inputs);
|
||||
work->stashTensors(outputs);
|
||||
if (avoidRecordStreams) {
|
||||
work->stashed_for_allocator_safety_ =
|
||||
std::make_shared<std::vector<at::Tensor>>(inputs);
|
||||
}
|
||||
|
||||
// Start event should only be recorded before the ncclGroupStart() (which
|
||||
@ -3560,6 +3542,27 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
|
||||
{
|
||||
torch::cuda::nccl::AutoNcclGroup nccl_group_guard(comm, useNonblocking());
|
||||
for (const auto i : c10::irange(inputs.size())) {
|
||||
// Both `inputs' and `outputs' are created on a worker stream and used in
|
||||
// different ncclStreams. Hence, both must record the ncclStream to
|
||||
// prevent being freed before the collective finishes.
|
||||
//
|
||||
// We only record `inputs' here, and leave recording `outputs' to `fn' for
|
||||
// operations where `inputs' and `outputs' are not the same.
|
||||
//
|
||||
// See [Sync Streams].
|
||||
if (!avoidRecordStreams) {
|
||||
if (!inputs[i].is_sparse()) {
|
||||
c10::cuda::CUDACachingAllocator::recordStream(
|
||||
inputs[i].storage().data_ptr(), ncclStream);
|
||||
} else {
|
||||
// for sparse input case record streams on both index and value
|
||||
// tensors
|
||||
c10::cuda::CUDACachingAllocator::recordStream(
|
||||
inputs[i].values().storage().data_ptr(), ncclStream);
|
||||
c10::cuda::CUDACachingAllocator::recordStream(
|
||||
inputs[i].indices().storage().data_ptr(), ncclStream);
|
||||
}
|
||||
}
|
||||
#ifndef NCCL_HAS_COMM_NONBLOCKING
|
||||
C10D_NCCL_CHECK(
|
||||
fn(inputs[i], outputs[i], comm, ncclStream),
|
||||
@ -3600,6 +3603,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
|
||||
|
||||
// Set appropriate work parameters.
|
||||
work->blockingWait_ = blockingWait_;
|
||||
work->avoidRecordStreams_ = avoidRecordStreams;
|
||||
work->store_ = store_;
|
||||
assignTimeoutToWork(work, options_);
|
||||
// Record size info for debug. We only record the size on the first device as
|
||||
@ -3630,7 +3634,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
|
||||
// it, since interactions with it by usercode won't behave normally - they
|
||||
// won't observe work completion, for instance. Will this lead to silent
|
||||
// problems during capture?
|
||||
return asyncOp ? work : nullptr;
|
||||
return work;
|
||||
}
|
||||
|
||||
template <typename Fn, typename PreProcess, typename PostProcess>
|
||||
@ -3648,8 +3652,13 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
|
||||
// to wait() on the returned handle, so ProcessGroupNCCL can't know
|
||||
// when it's safe to release the input back to the allocator,
|
||||
// and the present call has no way to know it's not an isend.
|
||||
// Therefore, we warn and fall back to the typical recordStream logic.
|
||||
// TODO( kwen2501 ): revisit this when we have a better solution.
|
||||
// Therefore, we warn and fall back to the typical recordStream logic:
|
||||
if (avoidRecordStreams_) {
|
||||
TORCH_WARN_ONCE(
|
||||
"TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point "
|
||||
"collectives.");
|
||||
}
|
||||
|
||||
auto device = getDevice(tensor);
|
||||
at::cuda::OptionalCUDAGuard gpuGuard(device);
|
||||
|
||||
@ -3704,8 +3713,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
|
||||
} else {
|
||||
TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG);
|
||||
}
|
||||
// For now, P2P ops are always put on internal stream
|
||||
coalescedAsync_ = true;
|
||||
}
|
||||
|
||||
// Used many times below, so we stash the unordered_map lookup
|
||||
@ -3877,8 +3884,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
|
||||
PreProcess pre,
|
||||
PostProcess post,
|
||||
OpType opType,
|
||||
bool asyncOp,
|
||||
const char* profilingTitle,
|
||||
bool avoidRecordStreams,
|
||||
bool nanCheck) {
|
||||
auto inputs = std::vector<at::Tensor>{input};
|
||||
auto outputs = std::vector<at::Tensor>{output};
|
||||
@ -3889,8 +3896,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
|
||||
pre,
|
||||
post,
|
||||
opType,
|
||||
asyncOp,
|
||||
profilingTitle,
|
||||
avoidRecordStreams,
|
||||
nanCheck);
|
||||
}
|
||||
|
||||
@ -3900,8 +3907,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
|
||||
at::Tensor& output,
|
||||
Fn fn,
|
||||
OpType opType,
|
||||
bool asyncOp,
|
||||
const char* profilingTitle,
|
||||
bool avoidRecordStreams,
|
||||
bool nanCheck) {
|
||||
auto inputs = std::vector<at::Tensor>{input};
|
||||
auto outputs = std::vector<at::Tensor>{output};
|
||||
@ -3914,8 +3921,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
|
||||
[](at::cuda::CUDAStream&,
|
||||
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
|
||||
opType,
|
||||
asyncOp,
|
||||
profilingTitle,
|
||||
avoidRecordStreams,
|
||||
nanCheck);
|
||||
}
|
||||
|
||||
@ -3967,8 +3974,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_sparse(
|
||||
auto recvIndices = indices[0] * colSize;
|
||||
|
||||
// prevent output and recvIndices from being freed
|
||||
// TODO: not changing the lifetime management of outputs this time,
|
||||
// revisit later
|
||||
c10::cuda::CUDACachingAllocator::recordStream(
|
||||
output.storage().data_ptr(), stream);
|
||||
c10::cuda::CUDACachingAllocator::recordStream(
|
||||
@ -4000,7 +4005,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_sparse(
|
||||
}
|
||||
},
|
||||
OpType::_ALLREDUCE_SPARSE,
|
||||
opts.asyncOp,
|
||||
"nccl:all_reduce_sparse");
|
||||
return work;
|
||||
#else
|
||||
@ -4035,7 +4039,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_impl(
|
||||
stream.stream());
|
||||
},
|
||||
OpType::ALLREDUCE,
|
||||
opts.asyncOp,
|
||||
profilingTitle);
|
||||
}
|
||||
|
||||
@ -4136,7 +4139,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_coalesced(
|
||||
stream.stream());
|
||||
},
|
||||
OpType::COALESCED,
|
||||
opts.asyncOp,
|
||||
"nccl:allreduce_coalesced");
|
||||
}
|
||||
|
||||
@ -4168,10 +4170,12 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::broadcast(
|
||||
globalRankStride_, // globalRankStride_
|
||||
this->getSize()); // worldSize
|
||||
|
||||
// avoidRecordStreams_ note: collective() will stash tensors.
|
||||
bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
|
||||
|
||||
const auto root = opts.rootRank + opts.rootTensor;
|
||||
bool nanCheck = (root == rank_);
|
||||
|
||||
// avoidRecordStreams_ note: collective() will stash tensors.
|
||||
return collective(
|
||||
tensor,
|
||||
tensor,
|
||||
@ -4188,8 +4192,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::broadcast(
|
||||
stream.stream());
|
||||
},
|
||||
OpType::BROADCAST,
|
||||
opts.asyncOp,
|
||||
"nccl:broadcast",
|
||||
avoidRecordStreams,
|
||||
nanCheck);
|
||||
}
|
||||
|
||||
@ -4228,8 +4232,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_broadcast_oop(
|
||||
stream.stream());
|
||||
},
|
||||
OpType::BROADCAST,
|
||||
opts.asyncOp,
|
||||
"nccl:_broadcast_oop",
|
||||
/*avoidRecordStreams=*/false,
|
||||
nanCheck);
|
||||
}
|
||||
|
||||
@ -4288,7 +4292,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce(
|
||||
stream.stream());
|
||||
},
|
||||
OpType::REDUCE,
|
||||
opts.asyncOp,
|
||||
"nccl:reduce");
|
||||
}
|
||||
|
||||
@ -4330,7 +4333,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_oop(
|
||||
stream.stream());
|
||||
},
|
||||
OpType::REDUCE,
|
||||
opts.asyncOp,
|
||||
"nccl:_reduce_oop");
|
||||
}
|
||||
|
||||
@ -4374,7 +4376,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
|
||||
at::Tensor& output,
|
||||
ncclComm_t comm,
|
||||
at::cuda::CUDAStream& stream) {
|
||||
// See [We actually don't need to stash anything here].
|
||||
if (!avoidRecordStreams_) {
|
||||
c10::cuda::CUDACachingAllocator::recordStream(
|
||||
output.storage().data_ptr(), stream);
|
||||
}
|
||||
return ncclAllGather(
|
||||
input.data_ptr(),
|
||||
output.data_ptr(),
|
||||
@ -4390,27 +4395,27 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
|
||||
// - inputTensors is stashed onto work->stashed_for_allocator_safety_
|
||||
// in collective().
|
||||
// - outputFlattened is stashed onto work->outputs_ in collective().
|
||||
// - User-facing outputTensors should be held by the user until after
|
||||
// waiting on work_, or the call makes no sense.
|
||||
// So all participating tensors are accounted for, and won't be
|
||||
// released back to their allocation streams until after work_ is
|
||||
// waited on.
|
||||
},
|
||||
[&](at::cuda::CUDAStream& ncclStream,
|
||||
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
|
||||
// User-facing outputTensors should be held by the user until after
|
||||
// waiting on work_, or the call makes no sense. We do a stashing here
|
||||
// in case user doesn't hold the outputTensors in downstream code,
|
||||
// which can cause an early recyle by the CachingAllocator, which can
|
||||
// lead to segfault or data corruption.
|
||||
if (opts.asyncOp) {
|
||||
work->stashTensors(outputTensors_);
|
||||
}
|
||||
// Copy the flattened output tensors to the outputs.
|
||||
at::cuda::CUDAStreamGuard guard(ncclStream);
|
||||
for (const auto j : c10::irange(outputTensors_.size())) {
|
||||
// See [We actually don't need to stash anything here].
|
||||
// See [Sync Streams].
|
||||
if (!avoidRecordStreams_) {
|
||||
c10::cuda::CUDACachingAllocator::recordStream(
|
||||
outputTensors_[j].storage().data_ptr(), ncclStream);
|
||||
}
|
||||
outputTensors_[j].copy_(
|
||||
outputFlattened[static_cast<int64_t>(j)], true);
|
||||
}
|
||||
},
|
||||
OpType::ALLGATHER,
|
||||
opts.asyncOp,
|
||||
"nccl:all_gather");
|
||||
} else {
|
||||
const auto num_reduces = outputTensors_.size();
|
||||
@ -4418,8 +4423,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
|
||||
for (const int64_t i : c10::irange(static_cast<int64_t>(num_reduces))) {
|
||||
auto& output = outputTensors_[i];
|
||||
auto& input = (i == rank_) ? inputTensor : output;
|
||||
auto broadcastOpts =
|
||||
BroadcastOptions{i, int64_t(0), opts.timeout, opts.asyncOp};
|
||||
auto broadcastOpts = BroadcastOptions{i, int64_t(0), opts.timeout};
|
||||
_broadcast_oop(output, input, broadcastOpts);
|
||||
}
|
||||
auto work = endCoalescing(OpType::ALLGATHER);
|
||||
@ -4475,7 +4479,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather_into_tensor_coalesced(
|
||||
stream.stream());
|
||||
},
|
||||
OpType::COALESCED,
|
||||
opts.asyncOp,
|
||||
"nccl:all_gather_into_tensor_coalesced");
|
||||
}
|
||||
|
||||
@ -4521,6 +4524,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
|
||||
at::Tensor& output,
|
||||
ncclComm_t comm,
|
||||
at::cuda::CUDAStream& stream) {
|
||||
if (!avoidRecordStreams_) {
|
||||
c10::cuda::CUDACachingAllocator::recordStream(
|
||||
output.storage().data_ptr(), stream);
|
||||
}
|
||||
const auto ncclDataType = getNcclDataType(input.scalar_type());
|
||||
const auto ncclReduceOp =
|
||||
getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
|
||||
@ -4535,18 +4542,27 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
|
||||
},
|
||||
[&](at::cuda::CUDAStream& ncclStream,
|
||||
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
|
||||
// We only need to stash inputTensors.
|
||||
// - inputFlattened is stashed onto
|
||||
// work->stashed_for_allocator_safety_ in collective().
|
||||
// - User-facing outputTensors is stashed onto work->outputs_ in
|
||||
// collective(), and should also be held by the user until after
|
||||
// waiting on work_.
|
||||
if (opts.asyncOp) {
|
||||
work->stashTensors(inputTensors_);
|
||||
if (avoidRecordStreams_) {
|
||||
// We only need to stash inputTensors.
|
||||
// - inputFlattened is stashed onto
|
||||
// work->stashed_for_allocator_safety_
|
||||
// in collective().
|
||||
// - User-facing outputTensors is stashed onto work->outputs_ in
|
||||
// collective(),
|
||||
// and should also be held by the user until after waiting on
|
||||
// work_.
|
||||
auto& v = work->stashed_for_allocator_safety_;
|
||||
v->insert(v->end(), inputTensors_.begin(), inputTensors_.end());
|
||||
}
|
||||
|
||||
// Copy the input tensors to the flattened inputs.
|
||||
at::cuda::CUDAStreamGuard guard(ncclStream);
|
||||
for (const auto j : c10::irange(inputTensors_.size())) {
|
||||
// See [Sync Streams].
|
||||
if (!avoidRecordStreams_) {
|
||||
c10::cuda::CUDACachingAllocator::recordStream(
|
||||
inputTensors_[j].storage().data_ptr(), ncclStream);
|
||||
}
|
||||
inputFlattened[static_cast<int64_t>(j)].copy_(
|
||||
inputTensors_[j], true);
|
||||
}
|
||||
@ -4554,7 +4570,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
|
||||
[&](at::cuda::CUDAStream&,
|
||||
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
|
||||
OpType::REDUCE_SCATTER,
|
||||
opts.asyncOp,
|
||||
"nccl:reduce_scatter");
|
||||
} else {
|
||||
const auto num_reduces = inputTensors_.size();
|
||||
@ -4566,8 +4581,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
|
||||
opts.reduceOp,
|
||||
static_cast<int64_t>(i),
|
||||
static_cast<int64_t>(0),
|
||||
opts.timeout,
|
||||
opts.asyncOp};
|
||||
opts.timeout};
|
||||
_reduce_oop(output, input, reduceOpts);
|
||||
}
|
||||
auto work = endCoalescing(OpType::REDUCE_SCATTER);
|
||||
@ -4621,6 +4635,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
|
||||
// stream so that the caching allocator can reuse memory pool for this stream
|
||||
// in a clever way. This setting is added for libraries like FSDP which uses
|
||||
// `reduce_scatter_tensor`.
|
||||
bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
|
||||
|
||||
return collective(
|
||||
inputTensor,
|
@ -4629,6 +4644,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
if (!avoidRecordStreams) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
auto ncclDataType = getNcclDataType(input.scalar_type());
auto ncclReduceOp =
getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
@ -4642,8 +4661,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
stream.stream());
},
OpType::_REDUCE_SCATTER_BASE,
opts.asyncOp,
"nccl:_reduce_scatter_base");
"nccl:_reduce_scatter_base",
avoidRecordStreams);
}

c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced(
@ -4680,6 +4699,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced(
at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
auto ncclDataType = getNcclDataType(input.scalar_type());
auto ncclReduceOp =
getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
@ -4693,7 +4716,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced(
stream.stream());
},
OpType::COALESCED,
opts.asyncOp,
"nccl:reduce_scatter_tensor_coalesced");
}

@ -4772,28 +4794,13 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::barrier(const BarrierOptions& opts) {
at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat));

// All reduce to achieve the barrier
AllreduceOptions arOpts = AllreduceOptions();
arOpts.asyncOp = opts.asyncOp;
auto work = allreduce_impl(barrierTensor, "nccl:all_reduce_barrier", arOpts);
auto work = allreduce_impl(barrierTensor, "nccl:all_reduce_barrier");

if (opts.asyncOp) {
// Work will take over barrierTensors
auto ncclWork = dynamic_cast<ProcessGroupNCCL::WorkNCCL*>(work.get());
// If user specified async, the work should not be nullptr
TORCH_CHECK(ncclWork);
// Put a marker here so that `work.wait()` issue by users does
// barrier-specific thing: CPU sync
ncclWork->isBarrierOp_ = true;
return work;
}

// Otherwise, we are in sync mode, we directly wait here.
// (It is a CPU wait for barrier)
auto currentStream = at::cuda::getCurrentCUDAStream(barDevIdx);
// CUDAStream wrapper will correctly use a DeviceGuard here
currentStream.synchronize();
// No work to return
return nullptr;
// Work will take over barrierTensors
auto ncclWork = dynamic_cast<ProcessGroupNCCL::WorkNCCL*>(work.get());
TORCH_CHECK(ncclWork);
ncclWork->isBarrierOp_ = true;
return work;
}

c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
@ -4801,7 +4808,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
at::Tensor& inputTensor,
std::vector<int64_t>& outputSplitSizes,
std::vector<int64_t>& inputSplitSizes,
const AllToAllOptions& opts) {
const AllToAllOptions& /* unused */) {
check_gpu_single_tensor(outputTensor);
check_gpu_single_tensor(inputTensor);
if (outputSplitSizes.empty() && inputSplitSizes.empty()) {
@ -4832,12 +4839,16 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
// See [Sync Streams].
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
torch::cuda::nccl::all2all_single_equal_split(
input, output, this->getSize(), comm, stream);
return ncclSuccess;
},
OpType::ALLTOALL_BASE,
opts.asyncOp,
"nccl:all_to_all");
} else {
c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
@ -4879,6 +4890,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
c10d::computeLengthsAndOffsets(
outputSplitSizes, output, &recv_lengths, &recv_offsets);
// See [Sync Streams].
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
torch::cuda::nccl::all2all_single_unequal_split(
input.data_ptr(),
send_lengths.data(),
@ -4893,7 +4908,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
return ncclSuccess;
},
OpType::ALLTOALL_BASE,
opts.asyncOp,
"nccl:all_to_all");
}
}
@ -4901,7 +4915,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall(
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllToAllOptions& opts) {
const AllToAllOptions& /* unused */) {
std::vector<int64_t> inSplitSizes;
std::vector<int64_t> outSplitSizes;
int64_t total_numel = 0;
@ -4948,11 +4962,18 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall(
return ncclSuccess;
},
[&](at::cuda::CUDAStream&,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
if (avoidRecordStreams_) {
// inputTensor0 and outputTensor0 are stashed redundantly by
// collective(), but that's ok.
auto& v = work->stashed_for_allocator_safety_;
v->insert(v->end(), inputTensors.begin(), inputTensors.end());
v->insert(v->end(), outputTensors.begin(), outputTensors.end());
}
},
[](at::cuda::CUDAStream&,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
OpType::ALLTOALL,
opts.asyncOp,
"nccl:all_to_all");
}

@ -5150,6 +5171,14 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::gather(
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
const auto root = opts.rootRank;
if (getRank() == root) {
if (!avoidRecordStreams_) {
for (auto const& output : outputs) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
}
}
torch::cuda::nccl::gather(
inputTensor, outputs, comm, stream, static_cast<int32_t>(root));
return ncclSuccess;
@ -5159,7 +5188,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::gather(
[](at::cuda::CUDAStream&,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
OpType::GATHER,
opts.asyncOp,
"nccl:gather");
}

@ -5228,6 +5256,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::scatter(

// avoidRecordStreams_ note: collective() will stash outputTensors and
// inputs, which == inputTensors[0] on the root rank where it matters.
bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);

const auto root = opts.rootRank;
bool nanCheck = (rank_ == root);

@ -5239,6 +5269,14 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::scatter(
at::Tensor& /* unused */,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
if (getRank() == root) {
if (!avoidRecordStreams) {
for (auto const& input : inputs) {
c10::cuda::CUDACachingAllocator::recordStream(
input.storage().data_ptr(), stream);
}
}
}
torch::cuda::nccl::scatter(
inputs, outputTensor, comm, stream, static_cast<int32_t>(root));
return ncclSuccess;
@ -5248,8 +5286,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::scatter(
[](at::cuda::CUDAStream&,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
OpType::SCATTER,
opts.asyncOp,
"nccl:scatter",
avoidRecordStreams,
nanCheck);
}

@ -5305,6 +5343,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base(
// stream so that the caching allocator can reuse memory pool for this stream
// in a clever way. This setting is added for libraries like FSDP which uses
// `all_gather_into_tensor`.
bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);

return collective(
input_tensor,
@ -5313,6 +5352,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base(
at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
if (!avoidRecordStreams) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
return ncclAllGather(
input.data_ptr(),
output.data_ptr(),
@ -5322,8 +5365,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base(
stream.stream());
},
OpType::_ALLGATHER_BASE,
opts.asyncOp,
"nccl:_all_gather_base");
"nccl:_all_gather_base",
avoidRecordStreams);
}

// Create a memory allocator for NCCL. This allocator is used to allocate memory
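For readers following the recordStream calls in the hunks above, here is a minimal Python-level sketch (not part of this diff) of the same allocator contract, using the public Tensor.record_stream API; the tensor and stream names are illustrative only.

# Minimal sketch, not part of the diff: the Python-level analogue of the
# recordStream(...) calls above. record_stream tells the CUDA caching
# allocator not to hand out `out`'s memory to another allocation until the
# work already queued on `side` has finished.
import torch

if torch.cuda.is_available():
    side = torch.cuda.Stream()
    out = torch.empty(1024, device="cuda")
    with torch.cuda.stream(side):
        out.add_(1)          # kernel enqueued on the side stream
    out.record_stream(side)  # keep the allocation alive for that stream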
@ -382,6 +382,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Clone of blockingWait_ from ProcessGroupNCCL.
bool blockingWait_{false};

// Clone of avoidRecordStreams_ from ProcessGroupNCCL.
bool avoidRecordStreams_{false};

// Clone of opTimeout_ from ProcessGroupNCCL.
std::chrono::milliseconds opTimeout_{};

@ -428,13 +431,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// exception_ptr.
bool finishedGPUExecutionInternal() const;

// Stash tensors so that CachingAllocator cannot recycle them prematurely.
// Used in case of async ops.
void stashTensors(std::vector<at::Tensor>& tensors);

// Unstage the stashed tensors so that CachingAllocator can recycle them
void unstashTensors();

// Reference to the store so that we can write aborted communicators
// to the store.
c10::intrusive_ptr<Store> store_;
@ -454,9 +450,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// For in-place collectives, some refs stashed here may alias outputs_,
// but that doesn't do any harm.
std::shared_ptr<std::vector<at::Tensor>> stashed_for_allocator_safety_;
// Need a mutex to protect stashed_for_allocator_safety_ because it can be
// accessed from both main thread and watchdog thread.
std::mutex stashMutex_;

// The future returned by getFuture.
c10::intrusive_ptr<at::ivalue::Future> future_;
@ -885,8 +878,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
at::Tensor& output,
Fn fn,
OpType opType,
bool asyncOp,
const char* profilingTitle = nullptr,
bool avoidRecordStreams = false,
bool nanCheck = true);

template <typename Fn, typename PreProcess, typename PostProcess>
@ -897,8 +890,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
PreProcess pre,
PostProcess post,
OpType opType,
bool asyncOp,
const char* profilingTitle = nullptr,
bool avoidRecordStreams = false,
bool nanCheck = true);

template <typename Fn, typename PreProcess, typename PostProcess>
@ -909,8 +902,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
PreProcess pre,
PostProcess post,
OpType opType,
bool asyncOp,
const char* profilingTitle = nullptr,
bool avoidRecordStreams = false,
bool nanCheck = true);

template <typename Fn>
@ -919,8 +912,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
std::vector<at::Tensor>& output,
Fn fn,
OpType opType,
bool asyncOp,
const char* profilingTitle = nullptr);
const char* profilingTitle = nullptr,
bool avoidRecordStreams = false);

// Helper that encapsulates work shared across point-to-point communication
// primitives. It is the same structure as the helper used for collective
@ -1229,9 +1222,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Stores communicators for all collectives run inside a coalescing block
std::shared_ptr<NCCLComm> coalescedComm_ = nullptr;

// Whether the coalesced calls are sync or async.
bool coalescedAsync_;

// Whether or not wait() and synchronize() are blocking operations that wait
// for the operation to complete.
bool blockingWait_ = false;
@ -122,7 +122,6 @@ struct BroadcastOptions {
struct AllreduceOptions {
ReduceOp reduceOp = ReduceOp::SUM;
std::chrono::milliseconds timeout = kUnsetTimeout;
bool asyncOp = true;
std::optional<at::Tensor> sparseIndices = std::nullopt;
};

@ -133,7 +132,6 @@ struct ReduceOptions {
int64_t rootRank = 0;
int64_t rootTensor = 0;
std::chrono::milliseconds timeout = kUnsetTimeout;
bool asyncOp = true;
};

struct AllgatherOptions {
@ -144,7 +142,6 @@ struct AllgatherOptions {
struct GatherOptions {
int64_t rootRank = 0;
std::chrono::milliseconds timeout = kUnsetTimeout;
bool asyncOp = true;
};

struct ScatterOptions {
@ -161,14 +158,12 @@ struct ReduceScatterOptions {

struct AllToAllOptions {
std::chrono::milliseconds timeout = kUnsetTimeout;
bool asyncOp = true;
};

struct BarrierOptions {
std::vector<int64_t> device_ids;
std::chrono::milliseconds timeout = kUnsetTimeout;
std::optional<at::Device> device;
bool asyncOp = true;
};

struct DistributedBackendOptions {
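The option structs above surface in Python through the pybind registrations in the next hunks. A minimal usage sketch follows; it assumes an already-created ProcessGroup handle `pg` (hypothetical here) and that AllreduceOptions and ReduceOp are re-exported by torch.distributed, as in current releases.

# Minimal sketch under the assumptions stated above; `pg` is hypothetical.
from datetime import timedelta

import torch
import torch.distributed as dist

def allreduce_with_options(pg: dist.ProcessGroup, t: torch.Tensor) -> None:
    opts = dist.AllreduceOptions()        # mirrors c10d::AllreduceOptions
    opts.reduceOp = dist.ReduceOp.SUM     # ReduceOp field from the struct
    opts.timeout = timedelta(seconds=30)  # std::chrono::milliseconds field
    pg.allreduce([t], opts).wait()        # the call returns a Work handle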
@ -999,23 +999,20 @@ This class does not support ``__members__`` property.)");
py::class_<::c10d::AllreduceOptions>(module, "AllreduceOptions")
.def(py::init<>())
.def_readwrite("reduceOp", &::c10d::AllreduceOptions::reduceOp)
.def_readwrite("timeout", &::c10d::AllreduceOptions::timeout)
.def_readwrite("asyncOp", &::c10d::AllreduceOptions::asyncOp);
.def_readwrite("timeout", &::c10d::AllreduceOptions::timeout);

py::class_<::c10d::AllreduceCoalescedOptions>(
module, "AllreduceCoalescedOptions")
.def(py::init<>())
.def_readwrite("reduceOp", &::c10d::AllreduceCoalescedOptions::reduceOp)
.def_readwrite("timeout", &::c10d::AllreduceCoalescedOptions::timeout)
.def_readwrite("asyncOp", &::c10d::AllreduceCoalescedOptions::asyncOp);
.def_readwrite("timeout", &::c10d::AllreduceCoalescedOptions::timeout);

py::class_<::c10d::ReduceOptions>(module, "ReduceOptions")
.def(py::init<>())
.def_readwrite("reduceOp", &::c10d::ReduceOptions::reduceOp)
.def_readwrite("rootRank", &::c10d::ReduceOptions::rootRank)
.def_readwrite("rootTensor", &::c10d::ReduceOptions::rootTensor)
.def_readwrite("timeout", &::c10d::ReduceOptions::timeout)
.def_readwrite("asyncOp", &::c10d::ReduceOptions::asyncOp);
.def_readwrite("timeout", &::c10d::ReduceOptions::timeout);

py::class_<::c10d::AllgatherOptions>(module, "AllgatherOptions")
.def(py::init<>())
@ -1025,8 +1022,7 @@ This class does not support ``__members__`` property.)");
py::class_<::c10d::GatherOptions>(module, "GatherOptions")
.def(py::init<>())
.def_readwrite("rootRank", &::c10d::GatherOptions::rootRank)
.def_readwrite("timeout", &::c10d::GatherOptions::timeout)
.def_readwrite("asyncOp", &::c10d::GatherOptions::asyncOp);
.def_readwrite("timeout", &::c10d::GatherOptions::timeout);

py::class_<::c10d::ScatterOptions>(module, "ScatterOptions")
.def(py::init<>())
@ -1044,13 +1040,11 @@ This class does not support ``__members__`` property.)");
.def(py::init<>())
.def_readwrite("device_ids", &::c10d::BarrierOptions::device_ids)
.def_readwrite("timeout", &::c10d::BarrierOptions::timeout)
.def_readwrite("device", &::c10d::BarrierOptions::device)
.def_readwrite("asyncOp", &::c10d::BarrierOptions::asyncOp);
.def_readwrite("device", &::c10d::BarrierOptions::device);

py::class_<::c10d::AllToAllOptions>(module, "AllToAllOptions")
.def(py::init<>())
.def_readwrite("timeout", &::c10d::AllToAllOptions::timeout)
.def_readwrite("asyncOp", &::c10d::AllToAllOptions::asyncOp);
.def_readwrite("timeout", &::c10d::AllToAllOptions::timeout);

py::class_<::c10d::DistributedBackendOptions>(
module, "_DistributedBackendOptions")
@ -2500,7 +2500,7 @@ class _CoalescingManager:
def __init__(self) -> None:
self.works: list[Work] = []

def append(self, work: Optional[Work] = None):
def append(self, work: Work):
if work:
self.works.append(work)

@ -2513,7 +2513,7 @@ class _CoalescingManager:
def _coalescing_manager(
group: Optional[ProcessGroup] = None,
device: Optional[torch.device] = None,
async_ops: bool = False,
async_ops: Optional[bool] = False,
):
"""
Context manager used to coalesce collectives or P2P operations when possible.
@ -2552,7 +2552,6 @@ def _coalescing_manager(
group._start_coalescing(device)
cm = _CoalescingManager()
yield cm
work = None
op_list = _world.pg_coalesce_state.pop(group)
if op_list:
# Collectives supporting "Fast Path" coalescing are captured.
@ -2566,7 +2565,6 @@ def _coalescing_manager(
tensors = [op.tensor for op in op_list]
all_reduce_opts = AllreduceCoalescedOptions()
all_reduce_opts.reduceOp = not_none(op_list[0].redop)
all_reduce_opts.asyncOp = async_ops
work = group.allreduce_coalesced(tensors, all_reduce_opts)
elif op0 == all_gather_into_tensor:
inputs = []
@ -2574,8 +2572,6 @@ def _coalescing_manager(
for op in op_list:
inputs.append(op.tensor)
outputs.append(not_none(op.dst_tensor))
all_gather_opts = AllgatherOptions()
all_gather_opts.asyncOp = async_ops
work = group.allgather_into_tensor_coalesced(outputs, inputs)
elif op0 == reduce_scatter_tensor:
inputs = []
@ -2585,7 +2581,6 @@ def _coalescing_manager(
outputs.append(not_none(op.dst_tensor))
reduce_opts = ReduceScatterOptions()
reduce_opts.reduceOp = not_none(op_list[0].redop)
reduce_opts.asyncOp = async_ops
work = group.reduce_scatter_tensor_coalesced(outputs, inputs, reduce_opts)
else:
raise AssertionError(
@ -2598,12 +2593,9 @@ def _coalescing_manager(
work = group._end_coalescing(device)

if async_ops:
cm.append(work)
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait()
# Otherwise, the backend has sync'ed at CPP level
cm.append(work)  # type: ignore[possibly-undefined]
else:
work.wait()  # type: ignore[possibly-undefined]


def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]:
@ -2728,11 +2720,8 @@ def broadcast(
work = group.broadcast([tensor], opts)
if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level


@_exception_logger
@ -2812,7 +2801,6 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):

opts = AllreduceOptions()
opts.reduceOp = op
opts.asyncOp = async_op
if group is None:
group = _get_default_group()

@ -2829,11 +2817,8 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):

if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level


@_exception_logger
@ -2892,17 +2877,13 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):

opts = AllreduceCoalescedOptions()
opts.reduceOp = op
opts.asyncOp = async_op
group = group or _get_default_group()
work = group.allreduce_coalesced(tensors, opts)

if async_op:
return work.get_future()
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level


@_exception_logger
@ -2947,15 +2928,11 @@ def reduce(
opts = ReduceOptions()
opts.reduceOp = op
opts.rootRank = group_dst
opts.asyncOp = async_op
work = group.reduce([tensor], opts)
if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level


def _object_to_tensor(obj, device, group):
@ -3754,17 +3731,12 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):
tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)

group = group or _get_default_group()
opts = AllgatherOptions()
opts.asyncOp = async_op
work = group.allgather([tensor_list], [tensor], opts)
work = group.allgather([tensor_list], [tensor])

if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level


@_exception_logger
@ -3871,11 +3843,8 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal

if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level


@_exception_logger
@ -3985,17 +3954,12 @@ def all_gather_coalesced(
]

group = group or _get_default_group()
opts = AllgatherOptions()
opts.asyncOp = async_op
work = group.allgather_coalesced(output_tensor_lists, input_tensor_list, opts)
work = group.allgather_coalesced(output_tensor_lists, input_tensor_list)

if async_op:
return work.get_future()
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level


def _validate_output_list_for_rank(my_rank, dst, gather_list):
@ -4082,16 +4046,12 @@ def gather(

opts = GatherOptions()
opts.rootRank = group_dst
opts.asyncOp = async_op
work = group.gather(output_tensors, input_tensors, opts)

if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level


@_exception_logger
@ -4193,11 +4153,8 @@ def scatter(

if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level


@_exception_logger
@ -4229,18 +4186,14 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=Fal

opts = ReduceScatterOptions()
opts.reduceOp = op
opts.asyncOp = async_op

group = group or _get_default_group()
work = group.reduce_scatter([output], [input_list], opts)

if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level


@_exception_logger
@ -4340,11 +4293,8 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F

if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level


@deprecated(
@ -4497,7 +4447,6 @@ def all_to_all_single(
return

opts = AllToAllOptions()
opts.asyncOp = async_op
_check_single_tensor(output, "output")
_check_single_tensor(input, "input")
_ensure_all_tensors_same_dtype(output, input)
@ -4517,11 +4466,8 @@ def all_to_all_single(

if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level


@_exception_logger
@ -4622,7 +4568,6 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
return

opts = AllToAllOptions()
opts.asyncOp = async_op
_check_tensor_list(output_tensor_list, "output_tensor_list")
_check_tensor_list(input_tensor_list, "input_tensor_list")
_ensure_all_tensors_same_dtype(output_tensor_list, input_tensor_list)
@ -4639,11 +4584,8 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False

if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level


@_exception_logger
@ -4674,7 +4616,6 @@ def barrier(

opts = BarrierOptions()
opts.device = torch.device(_get_object_coll_device(group))
opts.asyncOp = async_op
if device_ids is not None:
if isinstance(device_ids, list):
opts.device_ids = device_ids
@ -4688,11 +4629,8 @@ def barrier(

if async_op:
return work
elif (
work is not None
): # Backward compatible with backends that don't sync at CPP level
else:
work.wait()
# Otherwise, the backend has sync'ed at CPP level


def monitored_barrier(
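At the user-facing API level, the functions above all follow the same async_op contract; a minimal sketch, assuming a process group has already been initialized, with the function name chosen only for illustration:

# Minimal sketch, assuming dist.init_process_group(...) was already called.
import torch
import torch.distributed as dist

def sum_across_ranks(t: torch.Tensor, overlap: bool) -> None:
    if overlap:
        work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
        # ... overlap independent computation here ...
        work.wait()  # block until the collective has completed
    else:
        # async_op=False: the call itself performs the wait; how that wait is
        # carried out at the backend level is what this diff touches.
        dist.all_reduce(t, op=dist.ReduceOp.SUM)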
@ -96,7 +96,7 @@ try:
import torchvision

HAS_TORCHVISION = True
except Exception: # Covering both ImportError and RuntimeError
except ImportError:
HAS_TORCHVISION = False

if sys.platform == "win32":
@ -8310,14 +8310,50 @@ class DistributedTest:
def test_compute_bucket_assignment_by_size_sparse_error_with_logger(self):
self._test_compute_bucket_assignment_by_size(use_logger=True)

def _determine_expected_error_verify_model_across_rank(
self, group_to_use, diff_num_params=False
):
# When running with NCCL backend, we don't expect an error on rank 0,
# rather, it will be taken down by TORCH_NCCL_ASYNC_ERROR_HANDLING. When
# running with Gloo or with debug mode wrapper, we expect the error
# to be caught inline.
# All ranks report same error when there is a # of parameter
# mismatch since we use allgather in the impl.
if diff_num_params:
expected_err = "DDP expects same model across all ranks"
ctx = self.assertRaisesRegex(RuntimeError, expected_err)
return ctx, expected_err

is_detail_dbg_mode = dist.get_debug_level() == dist.DebugLevel.DETAIL
if self.rank == 0:
if (
dist.get_backend(group_to_use) == dist.Backend.NCCL
and not is_detail_dbg_mode
):
expected_err = "caught collective operation timeout"
ctx = self.assertRaisesRegex(RuntimeError, expected_err)
else:
expected_err = None
ctx = self.assertRaises(RuntimeError)
else:
expected_err = "appears not to match"
ctx = self.assertRaisesRegex(RuntimeError, expected_err)
return ctx, expected_err

def _test_verify_model_across_rank(self, use_logger):
group_gloo = dist.new_group(
timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
)
# Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
# determinism.
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
group_to_use = dist.new_group(
backend=dist.get_backend(), timeout=timedelta(seconds=5)
)
torch.cuda.set_device(self.rank)
ctx, expected_err = self._determine_expected_error_verify_model_across_rank(
group_to_use
)

# Create a valid model. The constructor initializes the logger that we use later.
net = EmbeddingNetDifferentParams(0)
@ -8335,8 +8371,7 @@ class DistributedTest:
net.module.lin = nn.Linear(100 if self.rank == 0 else 10, 1)

# if we pass a logger we can verify that it was logged
caught = 0
try:
with ctx:
if use_logger:
_verify_param_shape_across_processes(
net.process_group, list(net.parameters()), net.logger
@ -8345,13 +8380,18 @@ class DistributedTest:
_verify_param_shape_across_processes(
net.process_group, list(net.parameters())
)
except Exception:
caught = 1
# Should only be run by rank 0, and blocking_wait catches and
# reports exception.
dist.barrier(group_to_use)

# As long as there is one rank catching the exception
t = torch.Tensor([caught])
dist.all_reduce(t, group=group_gloo)
self.assertGreater(t, 0)
# We don't check when self.rank != 0 because the logger doesn't log
# the error "Caught collective operation" as that is not thrown in the reducer.
if use_logger and self.rank != 0:
verify_ddp_error_logged(net, expected_err)

# Perform gloo-based barrier to ensure one rank doesn't exit test
# early which causes failure with Barrier.sync.
dist.barrier(group_gloo)

@require_backend_is_available(DistTestCases.backend_feature["gpu"])
@skip_but_pass_in_sandcastle_if(
@ -8369,19 +8409,20 @@ class DistributedTest:
def test_verify_model_across_rank_without_logger(self):
self._test_verify_model_across_rank(use_logger=False)

def _run_test_ddp_model_with_diff_params(self, net, ddp_group, group_gloo):
caught = 0
try:
def _run_test_ddp_model_with_diff_params(self, ctx, net, ddp_group, group_gloo):
with ctx:
net = torch.nn.parallel.DistributedDataParallel(
net.to(self.rank), device_ids=[self.rank], process_group=ddp_group
)
except Exception:
caught = 1
# Should only be run by rank 0, and blocking_wait catches and
# reports exception.
dist.barrier(ddp_group)

# As long as there is one rank catching the exception
t = torch.Tensor([caught])
dist.all_reduce(t, group=group_gloo)
self.assertGreater(t, 0)
# can't use verify_ddp_error_logged here because net was never properly constructed

# Perform gloo-based barrier to ensure one rank doesn't exit test
# early which causes failure with Barrier.sync.
dist.barrier(group_gloo)

@require_backend_is_available(DistTestCases.backend_feature["gpu"])
@skip_but_pass_in_sandcastle_if(
@ -8392,15 +8433,21 @@ class DistributedTest:
group_gloo = dist.new_group(
timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
)
# Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
# determinism.
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
group_to_use = dist.new_group(
backend=dist.get_backend(), timeout=timedelta(seconds=10)
)
torch.cuda.set_device(self.rank)
ctx, _expected_err = self._determine_expected_error_verify_model_across_rank(
group_to_use
)
# Creates network with different sized embedding table on different
# ranks. This should throw an error during DDP init.
net = EmbeddingNetDifferentParams(self.rank)
self._run_test_ddp_model_with_diff_params(
net, group_to_use, group_gloo
ctx, net, group_to_use, group_gloo
)

@require_backend_is_available(DistTestCases.backend_feature["gpu"])
@ -8412,10 +8459,16 @@ class DistributedTest:
group_gloo = dist.new_group(
timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
)
# Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
# determinism.
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
group_to_use = dist.new_group(
backend=dist.get_backend(), timeout=timedelta(seconds=10)
)
torch.cuda.set_device(self.rank)
ctx, _expected_err = self._determine_expected_error_verify_model_across_rank(
group_to_use, diff_num_params=True
)

# Creates network with diff # of param across ranks, reducer should
# recognize this and throw appropriate error.
@ -8424,6 +8477,7 @@ class DistributedTest:
)

self._run_test_ddp_model_with_diff_params(
ctx,
net,
group_to_use,
group_gloo,