From afa1eda901b81fa6ce1afe651d3c2da53fa92440 Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Mon, 17 Mar 2025 22:43:15 +0000
Subject: [PATCH] Revert "[PGNCCL] Launch kernel on current stream & remove `record_stream` entirely (#148590)"

This reverts commit ef6296e7f20d744a0cfed81cab573d60204e7626.

Reverted https://github.com/pytorch/pytorch/pull/148590 on behalf of https://github.com/izaitsevfb due to reverted internally, see D71292427 ([comment](https://github.com/pytorch/pytorch/pull/148590#issuecomment-2731114626))
---
 test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp  |   3 -
 .../check_forward_backward_compatibility.py   |   3 -
 torch/_C/_distributed_c10d.pyi                |   8 +-
 torch/csrc/distributed/c10d/Ops.cpp           | 130 +++----
 torch/csrc/distributed/c10d/ProcessGroup.hpp  |  43 +--
 .../distributed/c10d/ProcessGroupNCCL.cpp     | 341 ++++++++++--------
 .../distributed/c10d/ProcessGroupNCCL.hpp     |  26 +-
 torch/csrc/distributed/c10d/Types.hpp         |   5 -
 torch/csrc/distributed/c10d/init.cpp          |  18 +-
 torch/distributed/distributed_c10d.py         | 104 ++----
 .../_internal/distributed/distributed_test.py |  92 ++++-
 11 files changed, 362 insertions(+), 411 deletions(-)

diff --git a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp
index 533c50a43fe8..a2fa2b467c52 100644
--- a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp
+++ b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp
@@ -363,9 +363,6 @@ class TestDebugInfoWriter : public c10d::DebugInfoWriter {
 };
 
 TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
-  // Note (kwen2501) 03/07/2025
-  // TODO: re-enable
-  GTEST_SKIP() << "Skipping test as the trace write seems unstable.";
   int heartBeatIntervalInSec = 2;
   std::string timeInterval = std::to_string(heartBeatIntervalInSec);
   ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "0", 1) == 0);
diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
index bfd255c50111..03b065a3691a 100644
--- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py
+++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
@@ -126,9 +126,6 @@ ALLOW_LIST = [
     ("aten::reduce_scatter_tensor", datetime.date(9999, 1, 30)),
     ("aten::all_gather_into_tensor", datetime.date(9999, 1, 30)),
     ("aten::all_reduce", datetime.date(9999, 1, 30)),
-    # These ops are defined in torch/csrc/distributed/c10d/Ops.cpp
-    # TODO: add back restriction when c10d ops can be exported
-    ("c10d::.*", datetime.date(9999, 1, 1)),
 ]
 
 ALLOW_LIST_COMPILED = [
diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi
index f94c8b4fcedd..e4b5a116fdbd 100644
--- a/torch/_C/_distributed_c10d.pyi
+++ b/torch/_C/_distributed_c10d.pyi
@@ -2,7 +2,7 @@
 # mypy: disable-error-code="type-arg"
 from datetime import timedelta
 from enum import Enum
-from typing import Any, Optional, overload
+from typing import Any, overload
 
 import torch
 from torch import Tensor
@@ -139,8 +139,6 @@ class BroadcastOptions:
 
 class AllreduceOptions:
     reduceOp: ReduceOp
     timeout: timedelta
-    asyncOp: bool
-    sparseIndices: Optional[Tensor]
 
 class AllreduceCoalescedOptions(AllreduceOptions): ...
@@ -149,7 +147,6 @@ class ReduceOptions: rootRank: int rootTensor: int timeout: timedelta - asyncOp: bool class AllgatherOptions: timeout: timedelta @@ -158,7 +155,6 @@ class AllgatherOptions: class GatherOptions: rootRank: int timeout: timedelta - asyncOp: bool class ScatterOptions: rootRank: int @@ -174,11 +170,9 @@ class BarrierOptions: device_ids: list[int] device: torch.device timeout: timedelta - asyncOp: bool class AllToAllOptions: timeout: timedelta - asyncOp: bool class Store: def set(self, key: str, value: str): ... diff --git a/torch/csrc/distributed/c10d/Ops.cpp b/torch/csrc/distributed/c10d/Ops.cpp index 0480f1b9191d..6251bfa1817d 100644 --- a/torch/csrc/distributed/c10d/Ops.cpp +++ b/torch/csrc/distributed/c10d/Ops.cpp @@ -17,37 +17,37 @@ TORCH_LIBRARY(c10d, m) { .def("wait", [](const c10::intrusive_ptr& self) { self->wait(); }); m.class_("ReduceOp").def(torch::init<>()); m.def( - "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); + "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); m.def( - "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); + "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); m.def( - "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"); + "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work"); m.def( - "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor[][], __torch__.torch.classes.c10d.Work)"); + "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[][], __torch__.torch.classes.c10d.Work)"); m.def( - "_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)"); + "_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)"); m.def( - "allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) -> __torch__.torch.classes.c10d.Work"); + "allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work"); m.def( - "allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, bool 
async_op=True) -> __torch__.torch.classes.c10d.Work"); + "allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work"); m.def( - "reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); + "reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); m.def( - "_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)"); + "_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)"); m.def( - "reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"); + "reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work"); m.def( - "reduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"); + "reduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int root_rank, int root_tensor, int timeout) -> __torch__.torch.classes.c10d.Work"); m.def( - "gather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"); + "gather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int timeout) -> __torch__.torch.classes.c10d.Work"); m.def( - "scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); + "scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); m.def( - "alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); + "alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); m.def( - "alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup 
process_group, int[] output_split_sizes, int[] input_split_sizes, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"); + "alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] output_split_sizes, int[] input_split_sizes, int timeout) -> __torch__.torch.classes.c10d.Work"); m.def( - "barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"); + "barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout) -> __torch__.torch.classes.c10d.Work"); m.def( "monitored_barrier_(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout, bool wait_all_ranks) -> ()"); m.def( @@ -118,7 +118,6 @@ IMPL_RECV_ANY_SOURCE(PrivateUse1) const c10::intrusive_ptr& reduce_op, \ int64_t root_rank, \ int64_t root_tensor, \ - bool asyncOp, \ int64_t timeout) { \ auto tensor_vec = tensors.vec(); \ return process_group->getBackend(c10::DeviceType::DEV) \ @@ -128,8 +127,7 @@ IMPL_RECV_ANY_SOURCE(PrivateUse1) *reduce_op.get(), \ root_rank, \ root_tensor, \ - std::chrono::milliseconds(timeout), \ - asyncOp}); \ + std::chrono::milliseconds(timeout)}); \ } IMPL_REDUCE(CPU) @@ -171,13 +169,12 @@ IMPL_BROADCAST(PrivateUse1) const c10::intrusive_ptr& process_group, \ const c10::intrusive_ptr& reduce_op, \ const std::optional& sparse_indices, \ - bool asyncOp, \ int64_t timeout) { \ auto tensor_vec = tensors.vec(); \ auto work = process_group->getBackend(c10::DeviceType::DEV) -> allreduce( \ tensor_vec, \ AllreduceOptions{ \ - *reduce_op.get(), std::chrono::milliseconds(timeout), asyncOp}); \ + *reduce_op.get(), std::chrono::milliseconds(timeout)}); \ return std::tuple, c10::intrusive_ptr>( \ std::move(tensor_vec), work); \ } @@ -191,13 +188,11 @@ IMPL_ALLREDUCE(PrivateUse1) at::TensorList tensors, \ const c10::intrusive_ptr& process_group, \ const c10::intrusive_ptr& reduce_op, \ - bool asyncOp, \ int64_t timeout) { \ auto tensor_vec = tensors.vec(); \ AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; \ opts.reduceOp = *reduce_op.get(); \ opts.timeout = std::chrono::milliseconds(timeout); \ - opts.asyncOp = asyncOp; \ return process_group->getBackend(c10::DeviceType::DEV) \ ->allreduce_coalesced(tensor_vec, opts); \ } @@ -214,13 +209,12 @@ IMPL_ALLREDUCE_COALESCED(PrivateUse1) const std::vector>& output_tensors, \ at::TensorList input_tensors, \ const c10::intrusive_ptr& process_group, \ - bool asyncOp, \ int64_t timeout) { \ auto input_tensors_vec = input_tensors.vec(); \ auto work = process_group->getBackend(c10::DeviceType::DEV) -> allgather( \ const_cast>&>(output_tensors), \ input_tensors_vec, \ - AllgatherOptions{std::chrono::milliseconds(timeout), asyncOp}); \ + AllgatherOptions{std::chrono::milliseconds(timeout)}); \ return std:: \ tuple>, c10::intrusive_ptr>( \ output_tensors, work); \ @@ -255,16 +249,12 @@ IMPL__ALLGATHER_BASE(PrivateUse1) c10::intrusive_ptr allgather_coalesced_##DEV( \ const std::vector>& output_lists, \ const at::TensorList& input_list, \ - const c10::intrusive_ptr& process_group, \ - bool asyncOp) { \ + const c10::intrusive_ptr& process_group) { \ auto input_list_vec = input_list.vec(); \ - auto opts = AllgatherOptions{}; \ - opts.asyncOp = asyncOp; \ return process_group->getBackend(c10::DeviceType::DEV) \ ->allgather_coalesced( \ const_cast>&>(output_lists), \ - input_list_vec, \ - opts); \ + 
input_list_vec); \ } IMPL_ALLGATHER_COALESCED(CPU) @@ -275,14 +265,11 @@ IMPL_ALLGATHER_COALESCED(PrivateUse1) c10::intrusive_ptr allgather_into_tensor_coalesced_##DEV( \ at::TensorList outputs, \ at::TensorList inputs, \ - const c10::intrusive_ptr& process_group, \ - bool asyncOp) { \ + const c10::intrusive_ptr& process_group) { \ auto output_vec = outputs.vec(); \ auto input_vec = inputs.vec(); \ - auto opts = AllgatherOptions{}; \ - opts.asyncOp = asyncOp; \ return process_group->getBackend(c10::DeviceType::DEV) \ - ->allgather_into_tensor_coalesced(output_vec, input_vec, opts); \ + ->allgather_into_tensor_coalesced(output_vec, input_vec); \ } IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CPU) @@ -296,7 +283,6 @@ IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1) const std::vector>& input_tensors, \ const c10::intrusive_ptr& process_group, \ const c10::intrusive_ptr& reduce_op, \ - bool asyncOp, \ int64_t timeout) { \ auto output_tensors_vec = output_tensors.vec(); \ auto work = \ @@ -304,9 +290,7 @@ IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1) output_tensors_vec, \ const_cast>&>(input_tensors), \ ReduceScatterOptions{ \ - *reduce_op.get(), \ - std::chrono::milliseconds(timeout), \ - asyncOp}); \ + *reduce_op.get(), std::chrono::milliseconds(timeout)}); \ return std::tuple, c10::intrusive_ptr>( \ output_tensors_vec, work); \ } @@ -345,7 +329,6 @@ IMPL__REDUCE_SCATTER_BASE(PrivateUse1) at::TensorList inputs, \ const c10::intrusive_ptr& process_group, \ const c10::intrusive_ptr& reduce_op, \ - bool asyncOp, \ int64_t timeout) { \ auto output_vec = outputs.vec(); \ auto input_vec = inputs.vec(); \ @@ -354,9 +337,7 @@ IMPL__REDUCE_SCATTER_BASE(PrivateUse1) output_vec, \ input_vec, \ ReduceScatterOptions{ \ - *reduce_op.get(), \ - std::chrono::milliseconds(timeout), \ - asyncOp}); \ + *reduce_op.get(), std::chrono::milliseconds(timeout)}); \ } IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CPU) @@ -369,15 +350,13 @@ IMPL_REDUCE_SCATTER_TENSOR_COALESCED(PrivateUse1) const at::TensorList& input_tensors, \ const c10::intrusive_ptr& process_group, \ int64_t root_rank, \ - bool asyncOp, \ int64_t timeout) { \ auto input_tensors_vec = input_tensors.vec(); \ return process_group->getBackend(c10::DeviceType::DEV) \ ->gather( \ const_cast>&>(output_tensors), \ input_tensors_vec, \ - GatherOptions{ \ - root_rank, std::chrono::milliseconds(timeout), asyncOp}); \ + GatherOptions{root_rank, std::chrono::milliseconds(timeout)}); \ } IMPL_GATHER(CPU) @@ -412,14 +391,13 @@ IMPL_SCATTER(PrivateUse1) const at::TensorList& output_tensors, \ const at::TensorList& input_tensors, \ const c10::intrusive_ptr& process_group, \ - bool asyncOp, \ int64_t timeout) { \ auto output_tensors_vec = output_tensors.vec(); \ auto input_tensors_vec = input_tensors.vec(); \ auto work = process_group->getBackend(c10::DeviceType::DEV) -> alltoall( \ output_tensors_vec, \ input_tensors_vec, \ - AllToAllOptions{std::chrono::milliseconds(timeout), asyncOp}); \ + AllToAllOptions{std::chrono::milliseconds(timeout)}); \ return std::tuple, c10::intrusive_ptr>( \ std::move(output_tensors_vec), work); \ } @@ -428,22 +406,21 @@ IMPL_ALLTOALL(CPU) IMPL_ALLTOALL(CUDA) IMPL_ALLTOALL(PrivateUse1) -#define IMPL_ALLTOALL_BASE(DEV) \ - c10::intrusive_ptr alltoall_base_##DEV( \ - at::Tensor& output, \ - at::Tensor& input, \ - const c10::intrusive_ptr& process_group, \ - std::vector output_split_sizes, \ - std::vector input_split_sizes, \ - bool asyncOp, \ - int64_t timeout) { \ - return process_group->getBackend(c10::DeviceType::DEV) \ - ->alltoall_base( \ - 
output, \ - input, \ - output_split_sizes, \ - input_split_sizes, \ - AllToAllOptions{std::chrono::milliseconds(timeout), asyncOp}); \ +#define IMPL_ALLTOALL_BASE(DEV) \ + c10::intrusive_ptr alltoall_base_##DEV( \ + at::Tensor& output, \ + at::Tensor& input, \ + const c10::intrusive_ptr& process_group, \ + std::vector output_split_sizes, \ + std::vector input_split_sizes, \ + int64_t timeout) { \ + return process_group->getBackend(c10::DeviceType::DEV) \ + ->alltoall_base( \ + output, \ + input, \ + output_split_sizes, \ + input_split_sizes, \ + AllToAllOptions{std::chrono::milliseconds(timeout)}); \ } IMPL_ALLTOALL_BASE(CPU) @@ -451,18 +428,15 @@ IMPL_ALLTOALL_BASE(CUDA) IMPL_ALLTOALL_BASE(PrivateUse1) // NOLINTBEGIN(performance-unnecessary-value-param) -#define IMPL_BARRIER(DEV) \ - c10::intrusive_ptr barrier##DEV( \ - at::Tensor /* unused */, \ - const c10::intrusive_ptr& process_group, \ - const std::vector& device_ids, \ - bool asyncOp, \ - int64_t timeout) { \ - auto opts = BarrierOptions{}; \ - opts.device_ids = device_ids; \ - opts.timeout = std::chrono::milliseconds(timeout); \ - opts.asyncOp = asyncOp; \ - return process_group->getBackend(c10::DeviceType::DEV)->barrier(opts); \ +#define IMPL_BARRIER(DEV) \ + c10::intrusive_ptr barrier##DEV( \ + at::Tensor /* unused */, \ + const c10::intrusive_ptr& process_group, \ + const std::vector& device_ids, \ + int64_t timeout) { \ + return process_group->getBackend(c10::DeviceType::DEV) \ + ->barrier( \ + BarrierOptions{device_ids, std::chrono::milliseconds(timeout)}); \ } IMPL_BARRIER(CPU) @@ -490,7 +464,6 @@ allreduce_sparse_cuda_( const c10::intrusive_ptr& process_group, const c10::intrusive_ptr& reduce_op, const std::optional& sparse_indices, - bool asyncOp, int64_t timeout) { auto tensor_vec = tensors.vec(); auto work = process_group->getBackend(c10::DeviceType::CUDA) @@ -499,7 +472,6 @@ allreduce_sparse_cuda_( AllreduceOptions{ *reduce_op, std::chrono::milliseconds(timeout), - asyncOp, sparse_indices}); return std::tuple, c10::intrusive_ptr>( diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index 4ce67c9f5798..b3f3d9bdd72d 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -224,7 +224,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const c10::intrusive_ptr<::c10d::ProcessGroup>&, const c10::intrusive_ptr<::c10d::ReduceOp>&, const std::optional& sparse_indices, - bool, int64_t)>(); auto work = std::get<1>(op.call( @@ -232,7 +231,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), c10::make_intrusive(opts.reduceOp), opts.sparseIndices, - opts.asyncOp, opts.timeout.count())); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -252,14 +250,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { at::TensorList, const c10::intrusive_ptr<::c10d::ProcessGroup>&, const c10::intrusive_ptr<::c10d::ReduceOp>&, - bool, int64_t)>(); auto work = op.call( tensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), c10::make_intrusive(opts.reduceOp), - opts.asyncOp, opts.timeout.count()); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -281,7 +277,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const c10::intrusive_ptr<::c10d::ReduceOp>&, int64_t, int64_t, - bool, int64_t)>(); auto work = op.call( tensors, @@ -289,7 +284,6 @@ class TORCH_API ProcessGroup : public 
torch::CustomClassHolder { c10::make_intrusive(opts.reduceOp), opts.rootRank, opts.rootTensor, - opts.asyncOp, opts.timeout.count()); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -312,14 +306,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const std::vector>&, at::TensorList, const c10::intrusive_ptr<::c10d::ProcessGroup>&, - bool, int64_t)>(); auto work = std::get<1>(op.call( outputTensors, inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), - opts.asyncOp, opts.timeout.count())); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -371,19 +363,18 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) { - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::allgather_coalesced_", "") - .typed( - const std::vector>&, - const at::TensorList&, - const c10::intrusive_ptr<::c10d::ProcessGroup>&, - bool)>(); + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::allgather_coalesced_", "") + .typed( + const std::vector>&, + const at::TensorList&, + const c10::intrusive_ptr<::c10d::ProcessGroup>&)>(); auto work = op.call( outputTensorLists, inputTensors, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), - opts.asyncOp); + c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this)); if (c10d::allow_inflight_collective_as_graph_input()) { for (const auto& tensor_list : outputTensorLists) { @@ -408,14 +399,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { .typed( const at::TensorList, const at::TensorList, - const c10::intrusive_ptr<::c10d::ProcessGroup>&, - bool)>(); + const c10::intrusive_ptr<::c10d::ProcessGroup>&)>(); auto work = op.call( outputTensors, inputTensors, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), - opts.asyncOp); + c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this)); if (c10d::allow_inflight_collective_as_graph_input()) { for (const auto& tensor : outputTensors) { @@ -436,14 +425,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const at::TensorList&, const c10::intrusive_ptr<::c10d::ProcessGroup>&, int64_t, - bool, int64_t)>(); auto work = op.call( outputTensors, inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), opts.rootRank, - opts.asyncOp, opts.timeout.count()); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -500,14 +487,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const std::vector>&, const c10::intrusive_ptr<::c10d::ProcessGroup>&, const c10::intrusive_ptr<::c10d::ReduceOp>&, - bool, int64_t)>(); auto work = std::get<1>(op.call( outputTensors, inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp), - opts.asyncOp, opts.timeout.count())); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -561,7 +546,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const at::TensorList, const c10::intrusive_ptr<::c10d::ProcessGroup>&, const c10::intrusive_ptr<::c10d::ReduceOp>&, - bool, int64_t)>(); auto work = op.call( @@ -569,7 +553,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp), - opts.asyncOp, opts.timeout.count()); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -594,7 +577,6 @@ class 
TORCH_API ProcessGroup : public torch::CustomClassHolder { const c10::intrusive_ptr<::c10d::ProcessGroup>&, std::vector, std::vector, - bool, int64_t)>(); auto work = op.call( outputBuffer, @@ -602,7 +584,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), outputSplitSizes, inputSplitSizes, - opts.asyncOp, opts.timeout.count()); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -623,13 +604,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const at::TensorList&, const at::TensorList&, const c10::intrusive_ptr<::c10d::ProcessGroup>&, - bool, int64_t)>(); auto work = std::get<1>(op.call( outputTensors, inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), - opts.asyncOp, opts.timeout.count())); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -799,14 +778,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { at::Tensor, const c10::intrusive_ptr<::c10d::ProcessGroup>&, const std::vector&, - bool, int64_t)>(); auto work = op.call( tensor, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), opts.device_ids, - opts.asyncOp, opts.timeout.count()); if (c10d::allow_inflight_collective_as_graph_input()) { c10d::register_work(tensor, work); diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index b8175ec28379..23433e58003a 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -496,8 +496,6 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL( } futureWorkResult_ = c10::make_intrusive(c10::AnyEnumType::get()); - // other functions expect an initialized ptr - stashed_for_allocator_safety_ = std::make_shared>(); } ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w) @@ -519,11 +517,6 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w) numelIn_(w.numelIn_), numelOut_(w.numelOut_), store_(w.store_), - // Note: the `work` returned to user and the `work` enqueued to watchdog - // share the pointer to the tensor stash. At least one of them should - // clean the tensor stash, the earlier the better, i.e. user calling - // `work.wait` than watchdog detecting work completion. 
- stashed_for_allocator_safety_(w.stashed_for_allocator_safety_), futureWorkResult_(w.futureWorkResult_), timingEnabled_(w.timingEnabled_), trace_id_(w.trace_id_), @@ -717,25 +710,14 @@ void ProcessGroupNCCL::WorkNCCL::synchronize() { } } -void ProcessGroupNCCL::WorkNCCL::stashTensors( - std::vector& tensors) { - std::lock_guard lock(stashMutex_); - stashed_for_allocator_safety_->insert( - stashed_for_allocator_safety_->end(), tensors.begin(), tensors.end()); -} - -void ProcessGroupNCCL::WorkNCCL::unstashTensors() { - std::lock_guard lock(stashMutex_); - stashed_for_allocator_safety_->clear(); -} - void ProcessGroupNCCL::WorkNCCL::synchronizeStream() { auto currentStream = at::cuda::getCurrentCUDAStream(device_.index()); // Block the current stream on the NCCL stream ncclEndEvent_->block(currentStream); - // Unstage the stashed tensors so that CachingAllocator can recycle them - // THIS MUST HAPPEN AFTER THE BLOCKING CALL ABOVE - unstashTensors(); + + if (avoidRecordStreams_) { + stashed_for_allocator_safety_->clear(); + } } // Same as calling synchronize() when blockingWait_ is false @@ -951,10 +933,7 @@ ProcessGroupNCCL::ProcessGroupNCCL( enableTiming_.store( getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_); #endif // ENABLE_NCCL_ERROR_CHECKING - if (getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false)) { - TORCH_WARN_ONCE( - "TORCH_NCCL_AVOID_RECORD_STREAMS is the default now, this environment variable is thus deprecated."); - } + avoidRecordStreams_ = getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false); #ifdef NCCL_HAS_COMM_REGISTER useTensorRegisterAllocatorHook_ = getCvarBool(TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK, false); @@ -2344,12 +2323,6 @@ void ProcessGroupNCCL::watchdogHandler() { // Clean up completed work if (work.isCompleted()) { - // In case user didn't call `work.wait()` with async collectives, - // watchdog would unstage the stashed tensors when detecting completion - // of the collective, to prevent ProcessGroupNCCL from holding reference - // to those tensors forever. - work.unstashTensors(); - // Work status logging for desync debug desyncDebugger_.logWorkEnd(work); @@ -3084,7 +3057,6 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( enableTiming_.load(), cudaEventCacheEnabled_.load(), dist_debug_level_); - if (record) { bool isP2P = isP2POp(opType); // Ideally record every work that we enqueue, rather than every work we @@ -3242,6 +3214,7 @@ c10::intrusive_ptr ProcessGroupNCCL::endCoalescing(OpType optype) { enqueue); work->ncclComm_ = comm; work->blockingWait_ = blockingWait_; + work->avoidRecordStreams_ = avoidRecordStreams_; work->store_ = store_; assignTimeoutToWork(work, options_); @@ -3260,16 +3233,19 @@ c10::intrusive_ptr ProcessGroupNCCL::endCoalescing(OpType optype) { // TODO(eqy): is this still necessary if avoidRecordStreams_ is set? work->ncclEndEvent_->record(ncclStream); + if (avoidRecordStreams_) { + // other functions expect an initialized ptr if avoidRecordStreams_ is set + work->stashed_for_allocator_safety_ = + std::make_shared>(); + } + if (enqueue) { workEnqueue(work); } - // Reset coalescing state coalescing_state_ = 0; coalescedComm_ = nullptr; - // If in async mode, return work; otherwise, kernel is enqueued on current - // stream, no need to return work - return coalescedAsync_ ? 
work : nullptr; + return work; } c10::intrusive_ptr ProcessGroupNCCL::endCoalescing() { @@ -3285,10 +3261,11 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( PreProcess pre, PostProcess post, OpType opType, - bool asyncOp, const char* profilingTitle, + bool avoidRecordStreams, bool nanCheck) { // Environment setting by the user may add onto collective call's option + avoidRecordStreams |= avoidRecordStreams_; nanCheck &= enableNanCheck_; auto device = getDevice(inputs[0]); @@ -3329,17 +3306,13 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( } else { TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG); } - coalescedAsync_ = asyncOp; } - // in asyncOp=false [default] mode, we use currentStream as ncclStream - // otherwise, we use separate ncclStream and let it sync on currentStream - auto ncclStream = asyncOp ? ncclStreams_.at(key) - : at::cuda::getCurrentCUDAStream(device.index()); - if (asyncOp) { - // First let NCCL streams wait for input tensors allocation streams - syncStream(device, ncclEvents_[key], ncclStream); - } + // Used many times below, so we stash the unordered_map lookup + auto ncclStream = ncclStreams_.at(key); + + // First let NCCL streams wait for input tensors allocation streams + syncStream(device, ncclEvents_[key], ncclStream); bool enqueue = !coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None; @@ -3349,12 +3322,9 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( // Store references to outputs to be used by WorkNCCL::result and operator<<. work->outputs_ = std::make_shared>(outputs); - // If we are performing sync operations, i.e. equeuing kernel onto "current" - // stream, we don't need to do anything for tensor lifetime management. - // Otherwise, we need to stage the tensors will `work.wait()`. - if (asyncOp) { - work->stashTensors(inputs); - work->stashTensors(outputs); + if (avoidRecordStreams) { + work->stashed_for_allocator_safety_ = + std::make_shared>(inputs); } if (nanCheck) { @@ -3380,6 +3350,21 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( // operations where `inputs' and `outputs' are not the same. // // See [Sync Streams]. + if (!avoidRecordStreams) { + for (const auto& input : inputs) { + if (!input.is_sparse()) { + c10::cuda::CUDACachingAllocator::recordStream( + input.storage().data_ptr(), ncclStream); + } else { + // for sparse input case record streams on both index and value + // tensors + c10::cuda::CUDACachingAllocator::recordStream( + input.values().storage().data_ptr(), ncclStream); + c10::cuda::CUDACachingAllocator::recordStream( + input.indices().storage().data_ptr(), ncclStream); + } + } + } // Not all collectives have the same signature, e.g, all-reduce take in a Tensor // as the input and output while all-to-all take in a vector of Tensors as input @@ -3431,6 +3416,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( // Set appropriate work parameters. work->blockingWait_ = blockingWait_; + work->avoidRecordStreams_ = avoidRecordStreams; work->store_ = store_; assignTimeoutToWork(work, options_); // Record size info for debug. We only record the size on the first device as @@ -3448,7 +3434,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( workEnqueue(work); } - return asyncOp ? 
work : nullptr; + return work; } template @@ -3457,8 +3443,11 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( std::vector& outputs, Fn fn, OpType opType, - bool asyncOp, - const char* profilingTitle) { + const char* profilingTitle, + bool avoidRecordStreams) { + // Environment setting by the user may add onto collective call's option + avoidRecordStreams |= avoidRecordStreams_; + // Currently, the API permits one scenario where inputs.size() and // outputs.size() are > 0. // 1. If the call was a _coalesced call, all inputs must be on the same @@ -3504,17 +3493,13 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( } else { TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG); } - coalescedAsync_ = asyncOp; } - // in asyncOp=false [default] mode, we use currentStream as ncclStream - // otherwise, we use separate ncclStream and let it sync on currentStream - auto ncclStream = asyncOp ? ncclStreams_.at(key) - : at::cuda::getCurrentCUDAStream(device.index()); - if (asyncOp) { - // First let NCCL streams wait for input tensors allocation streams - syncStream(device, ncclEvents_[key], ncclStream); - } + // Used many times below, so we stash the unordered_map lookup + auto ncclStream = ncclStreams_.at(key); + + // First let NCCL streams wait for input tensors allocation streams + syncStream(device, ncclEvents_[key], ncclStream); auto work = initWork( device, @@ -3529,12 +3514,9 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( // Store references to outputs to be used by WorkNCCL::result and operator<<. work->outputs_ = std::make_shared>(outputs); - // If we are performing sync operations, i.e. equeuing kernel onto "current" - // stream, we don't need to do anything for tensor lifetime management. - // Otherwise, we need to stage the tensors will `work.wait()`. - if (asyncOp) { - work->stashTensors(inputs); - work->stashTensors(outputs); + if (avoidRecordStreams) { + work->stashed_for_allocator_safety_ = + std::make_shared>(inputs); } // Start event should only be recorded before the ncclGroupStart() (which @@ -3560,6 +3542,27 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( { torch::cuda::nccl::AutoNcclGroup nccl_group_guard(comm, useNonblocking()); for (const auto i : c10::irange(inputs.size())) { + // Both `inputs' and `outputs' are created on a worker stream and used in + // different ncclStreams. Hence, both must record the ncclStream to + // prevent being freed before the collective finishes. + // + // We only record `inputs' here, and leave recording `outputs' to `fn' for + // operations where `inputs' and `outputs' are not the same. + // + // See [Sync Streams]. + if (!avoidRecordStreams) { + if (!inputs[i].is_sparse()) { + c10::cuda::CUDACachingAllocator::recordStream( + inputs[i].storage().data_ptr(), ncclStream); + } else { + // for sparse input case record streams on both index and value + // tensors + c10::cuda::CUDACachingAllocator::recordStream( + inputs[i].values().storage().data_ptr(), ncclStream); + c10::cuda::CUDACachingAllocator::recordStream( + inputs[i].indices().storage().data_ptr(), ncclStream); + } + } #ifndef NCCL_HAS_COMM_NONBLOCKING C10D_NCCL_CHECK( fn(inputs[i], outputs[i], comm, ncclStream), @@ -3600,6 +3603,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( // Set appropriate work parameters. work->blockingWait_ = blockingWait_; + work->avoidRecordStreams_ = avoidRecordStreams; work->store_ = store_; assignTimeoutToWork(work, options_); // Record size info for debug. 
We only record the size on the first device as @@ -3630,7 +3634,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( // it, since interactions with it by usercode won't behave normally - they // won't observe work completion, for instance. Will this lead to silent // problems during capture? - return asyncOp ? work : nullptr; + return work; } template @@ -3648,8 +3652,13 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // to wait() on the returned handle, so ProcessGroupNCCL can't know // when it's safe to release the input back to the allocator, // and the present call has no way to know it's not an isend. - // Therefore, we warn and fall back to the typical recordStream logic. - // TODO( kwen2501 ): revisit this when we have a better solution. + // Therefore, we warn and fall back to the typical recordStream logic: + if (avoidRecordStreams_) { + TORCH_WARN_ONCE( + "TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point " + "collectives."); + } + auto device = getDevice(tensor); at::cuda::OptionalCUDAGuard gpuGuard(device); @@ -3704,8 +3713,6 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( } else { TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG); } - // For now, P2P ops are always put on internal stream - coalescedAsync_ = true; } // Used many times below, so we stash the unordered_map lookup @@ -3877,8 +3884,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( PreProcess pre, PostProcess post, OpType opType, - bool asyncOp, const char* profilingTitle, + bool avoidRecordStreams, bool nanCheck) { auto inputs = std::vector{input}; auto outputs = std::vector{output}; @@ -3889,8 +3896,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( pre, post, opType, - asyncOp, profilingTitle, + avoidRecordStreams, nanCheck); } @@ -3900,8 +3907,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( at::Tensor& output, Fn fn, OpType opType, - bool asyncOp, const char* profilingTitle, + bool avoidRecordStreams, bool nanCheck) { auto inputs = std::vector{input}; auto outputs = std::vector{output}; @@ -3914,8 +3921,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( [](at::cuda::CUDAStream&, c10::intrusive_ptr& work) {}, opType, - asyncOp, profilingTitle, + avoidRecordStreams, nanCheck); } @@ -3967,8 +3974,6 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce_sparse( auto recvIndices = indices[0] * colSize; // prevent output and recvIndices from being freed - // TODO: not changing the lifetime management of outputs this time, - // revisit later c10::cuda::CUDACachingAllocator::recordStream( output.storage().data_ptr(), stream); c10::cuda::CUDACachingAllocator::recordStream( @@ -4000,7 +4005,6 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce_sparse( } }, OpType::_ALLREDUCE_SPARSE, - opts.asyncOp, "nccl:all_reduce_sparse"); return work; #else @@ -4035,7 +4039,6 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce_impl( stream.stream()); }, OpType::ALLREDUCE, - opts.asyncOp, profilingTitle); } @@ -4136,7 +4139,6 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce_coalesced( stream.stream()); }, OpType::COALESCED, - opts.asyncOp, "nccl:allreduce_coalesced"); } @@ -4168,10 +4170,12 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( globalRankStride_, // globalRankStride_ this->getSize()); // worldSize + // avoidRecordStreams_ note: collective() will stash tensors. 
+ bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); + const auto root = opts.rootRank + opts.rootTensor; bool nanCheck = (root == rank_); - // avoidRecordStreams_ note: collective() will stash tensors. return collective( tensor, tensor, @@ -4188,8 +4192,8 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( stream.stream()); }, OpType::BROADCAST, - opts.asyncOp, "nccl:broadcast", + avoidRecordStreams, nanCheck); } @@ -4228,8 +4232,8 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( stream.stream()); }, OpType::BROADCAST, - opts.asyncOp, "nccl:_broadcast_oop", + /*avoidRecordStreams=*/false, nanCheck); } @@ -4288,7 +4292,6 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce( stream.stream()); }, OpType::REDUCE, - opts.asyncOp, "nccl:reduce"); } @@ -4330,7 +4333,6 @@ c10::intrusive_ptr ProcessGroupNCCL::_reduce_oop( stream.stream()); }, OpType::REDUCE, - opts.asyncOp, "nccl:_reduce_oop"); } @@ -4374,7 +4376,10 @@ c10::intrusive_ptr ProcessGroupNCCL::allgather( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - // See [We actually don't need to stash anything here]. + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } return ncclAllGather( input.data_ptr(), output.data_ptr(), @@ -4390,27 +4395,27 @@ c10::intrusive_ptr ProcessGroupNCCL::allgather( // - inputTensors is stashed onto work->stashed_for_allocator_safety_ // in collective(). // - outputFlattened is stashed onto work->outputs_ in collective(). + // - User-facing outputTensors should be held by the user until after + // waiting on work_, or the call makes no sense. + // So all participating tensors are accounted for, and won't be + // released back to their allocation streams until after work_ is + // waited on. }, [&](at::cuda::CUDAStream& ncclStream, c10::intrusive_ptr& work) { - // User-facing outputTensors should be held by the user until after - // waiting on work_, or the call makes no sense. We do a stashing here - // in case user doesn't hold the outputTensors in downstream code, - // which can cause an early recyle by the CachingAllocator, which can - // lead to segfault or data corruption. - if (opts.asyncOp) { - work->stashTensors(outputTensors_); - } // Copy the flattened output tensors to the outputs. at::cuda::CUDAStreamGuard guard(ncclStream); for (const auto j : c10::irange(outputTensors_.size())) { - // See [We actually don't need to stash anything here]. + // See [Sync Streams]. + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + outputTensors_[j].storage().data_ptr(), ncclStream); + } outputTensors_[j].copy_( outputFlattened[static_cast(j)], true); } }, OpType::ALLGATHER, - opts.asyncOp, "nccl:all_gather"); } else { const auto num_reduces = outputTensors_.size(); @@ -4418,8 +4423,7 @@ c10::intrusive_ptr ProcessGroupNCCL::allgather( for (const int64_t i : c10::irange(static_cast(num_reduces))) { auto& output = outputTensors_[i]; auto& input = (i == rank_) ? 
inputTensor : output; - auto broadcastOpts = - BroadcastOptions{i, int64_t(0), opts.timeout, opts.asyncOp}; + auto broadcastOpts = BroadcastOptions{i, int64_t(0), opts.timeout}; _broadcast_oop(output, input, broadcastOpts); } auto work = endCoalescing(OpType::ALLGATHER); @@ -4475,7 +4479,6 @@ c10::intrusive_ptr ProcessGroupNCCL::allgather_into_tensor_coalesced( stream.stream()); }, OpType::COALESCED, - opts.asyncOp, "nccl:all_gather_into_tensor_coalesced"); } @@ -4521,6 +4524,10 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } const auto ncclDataType = getNcclDataType(input.scalar_type()); const auto ncclReduceOp = getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); @@ -4535,18 +4542,27 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( }, [&](at::cuda::CUDAStream& ncclStream, c10::intrusive_ptr& work) { - // We only need to stash inputTensors. - // - inputFlattened is stashed onto - // work->stashed_for_allocator_safety_ in collective(). - // - User-facing outputTensors is stashed onto work->outputs_ in - // collective(), and should also be held by the user until after - // waiting on work_. - if (opts.asyncOp) { - work->stashTensors(inputTensors_); + if (avoidRecordStreams_) { + // We only need to stash inputTensors. + // - inputFlattened is stashed onto + // work->stashed_for_allocator_safety_ + // in collective(). + // - User-facing outputTensors is stashed onto work->outputs_ in + // collective(), + // and should also be held by the user until after waiting on + // work_. + auto& v = work->stashed_for_allocator_safety_; + v->insert(v->end(), inputTensors_.begin(), inputTensors_.end()); } + // Copy the input tensors to the flattened inputs. at::cuda::CUDAStreamGuard guard(ncclStream); for (const auto j : c10::irange(inputTensors_.size())) { + // See [Sync Streams]. + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + inputTensors_[j].storage().data_ptr(), ncclStream); + } inputFlattened[static_cast(j)].copy_( inputTensors_[j], true); } @@ -4554,7 +4570,6 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( [&](at::cuda::CUDAStream&, c10::intrusive_ptr& work) {}, OpType::REDUCE_SCATTER, - opts.asyncOp, "nccl:reduce_scatter"); } else { const auto num_reduces = inputTensors_.size(); @@ -4566,8 +4581,7 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( opts.reduceOp, static_cast(i), static_cast(0), - opts.timeout, - opts.asyncOp}; + opts.timeout}; _reduce_oop(output, input, reduceOpts); } auto work = endCoalescing(OpType::REDUCE_SCATTER); @@ -4621,6 +4635,7 @@ c10::intrusive_ptr ProcessGroupNCCL::_reduce_scatter_base( // stream so that the caching allocator can reuse memory pool for this stream // in a clever way. This setting is added for libraries like FSDP which uses // `reduce_scatter_tensor`. 
+ bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); return collective( inputTensor, @@ -4629,6 +4644,10 @@ c10::intrusive_ptr ProcessGroupNCCL::_reduce_scatter_base( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { + if (!avoidRecordStreams) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } auto ncclDataType = getNcclDataType(input.scalar_type()); auto ncclReduceOp = getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); @@ -4642,8 +4661,8 @@ c10::intrusive_ptr ProcessGroupNCCL::_reduce_scatter_base( stream.stream()); }, OpType::_REDUCE_SCATTER_BASE, - opts.asyncOp, - "nccl:_reduce_scatter_base"); + "nccl:_reduce_scatter_base", + avoidRecordStreams); } c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter_tensor_coalesced( @@ -4680,6 +4699,10 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter_tensor_coalesced( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } auto ncclDataType = getNcclDataType(input.scalar_type()); auto ncclReduceOp = getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); @@ -4693,7 +4716,6 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter_tensor_coalesced( stream.stream()); }, OpType::COALESCED, - opts.asyncOp, "nccl:reduce_scatter_tensor_coalesced"); } @@ -4772,28 +4794,13 @@ c10::intrusive_ptr ProcessGroupNCCL::barrier(const BarrierOptions& opts) { at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat)); // All reduce to achieve the barrier - AllreduceOptions arOpts = AllreduceOptions(); - arOpts.asyncOp = opts.asyncOp; - auto work = allreduce_impl(barrierTensor, "nccl:all_reduce_barrier", arOpts); + auto work = allreduce_impl(barrierTensor, "nccl:all_reduce_barrier"); - if (opts.asyncOp) { - // Work will take over barrierTensors - auto ncclWork = dynamic_cast(work.get()); - // If user specified async, the work should not be nullptr - TORCH_CHECK(ncclWork); - // Put a marker here so that `work.wait()` issue by users does - // barrier-specific thing: CPU sync - ncclWork->isBarrierOp_ = true; - return work; - } - - // Otherwise, we are in sync mode, we directly wait here. - // (It is a CPU wait for barrier) - auto currentStream = at::cuda::getCurrentCUDAStream(barDevIdx); - // CUDAStream wrapper will correctly use a DeviceGuard here - currentStream.synchronize(); - // No work to return - return nullptr; + // Work will take over barrierTensors + auto ncclWork = dynamic_cast(work.get()); + TORCH_CHECK(ncclWork); + ncclWork->isBarrierOp_ = true; + return work; } c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( @@ -4801,7 +4808,7 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( at::Tensor& inputTensor, std::vector& outputSplitSizes, std::vector& inputSplitSizes, - const AllToAllOptions& opts) { + const AllToAllOptions& /* unused */) { check_gpu_single_tensor(outputTensor); check_gpu_single_tensor(inputTensor); if (outputSplitSizes.empty() && inputSplitSizes.empty()) { @@ -4832,12 +4839,16 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { + // See [Sync Streams]. 
+ if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } torch::cuda::nccl::all2all_single_equal_split( input, output, this->getSize(), comm, stream); return ncclSuccess; }, OpType::ALLTOALL_BASE, - opts.asyncOp, "nccl:all_to_all"); } else { c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); @@ -4879,6 +4890,10 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( c10d::computeLengthsAndOffsets( outputSplitSizes, output, &recv_lengths, &recv_offsets); // See [Sync Streams]. + if (!avoidRecordStreams_) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } torch::cuda::nccl::all2all_single_unequal_split( input.data_ptr(), send_lengths.data(), @@ -4893,7 +4908,6 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( return ncclSuccess; }, OpType::ALLTOALL_BASE, - opts.asyncOp, "nccl:all_to_all"); } } @@ -4901,7 +4915,7 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( c10::intrusive_ptr ProcessGroupNCCL::alltoall( std::vector& outputTensors, std::vector& inputTensors, - const AllToAllOptions& opts) { + const AllToAllOptions& /* unused */) { std::vector inSplitSizes; std::vector outSplitSizes; int64_t total_numel = 0; @@ -4948,11 +4962,18 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall( return ncclSuccess; }, [&](at::cuda::CUDAStream&, - c10::intrusive_ptr& work) {}, + c10::intrusive_ptr& work) { + if (avoidRecordStreams_) { + // inputTensor0 and outputTensor0 are stashed redundantly by + // collective(), but that's ok. + auto& v = work->stashed_for_allocator_safety_; + v->insert(v->end(), inputTensors.begin(), inputTensors.end()); + v->insert(v->end(), outputTensors.begin(), outputTensors.end()); + } + }, [](at::cuda::CUDAStream&, c10::intrusive_ptr& work) {}, OpType::ALLTOALL, - opts.asyncOp, "nccl:all_to_all"); } @@ -5150,6 +5171,14 @@ c10::intrusive_ptr ProcessGroupNCCL::gather( ncclComm_t comm, at::cuda::CUDAStream& stream) { const auto root = opts.rootRank; + if (getRank() == root) { + if (!avoidRecordStreams_) { + for (auto const& output : outputs) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + } + } torch::cuda::nccl::gather( inputTensor, outputs, comm, stream, static_cast(root)); return ncclSuccess; @@ -5159,7 +5188,6 @@ c10::intrusive_ptr ProcessGroupNCCL::gather( [](at::cuda::CUDAStream&, c10::intrusive_ptr& work) {}, OpType::GATHER, - opts.asyncOp, "nccl:gather"); } @@ -5228,6 +5256,8 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( // avoidRecordStreams_ note: collective() will stash outputTensors and // inputs, which == inputTensors[0] on the root rank where it matters. 
+ bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); + const auto root = opts.rootRank; bool nanCheck = (rank_ == root); @@ -5239,6 +5269,14 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( at::Tensor& /* unused */, ncclComm_t comm, at::cuda::CUDAStream& stream) { + if (getRank() == root) { + if (!avoidRecordStreams) { + for (auto const& input : inputs) { + c10::cuda::CUDACachingAllocator::recordStream( + input.storage().data_ptr(), stream); + } + } + } torch::cuda::nccl::scatter( inputs, outputTensor, comm, stream, static_cast(root)); return ncclSuccess; @@ -5248,8 +5286,8 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( [](at::cuda::CUDAStream&, c10::intrusive_ptr& work) {}, OpType::SCATTER, - opts.asyncOp, "nccl:scatter", + avoidRecordStreams, nanCheck); } @@ -5305,6 +5343,7 @@ c10::intrusive_ptr ProcessGroupNCCL::_allgather_base( // stream so that the caching allocator can reuse memory pool for this stream // in a clever way. This setting is added for libraries like FSDP which uses // `all_gather_into_tensor`. + bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); return collective( input_tensor, @@ -5313,6 +5352,10 @@ c10::intrusive_ptr ProcessGroupNCCL::_allgather_base( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { + if (!avoidRecordStreams) { + c10::cuda::CUDACachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } return ncclAllGather( input.data_ptr(), output.data_ptr(), @@ -5322,8 +5365,8 @@ c10::intrusive_ptr ProcessGroupNCCL::_allgather_base( stream.stream()); }, OpType::_ALLGATHER_BASE, - opts.asyncOp, - "nccl:_all_gather_base"); + "nccl:_all_gather_base", + avoidRecordStreams); } // Create a memory allocator for NCCL. This allocator is used to allocate memory diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 09b2137475b4..af57c3f294d2 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -382,6 +382,9 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Clone of blockingWait_ from ProcessGroupNCCL. bool blockingWait_{false}; + // Clone of avoidRecordStreams_ from ProcessGroupNCCL. + bool avoidRecordStreams_{false}; + // Clone of opTimeout_ from ProcessGroupNCCL. std::chrono::milliseconds opTimeout_{}; @@ -428,13 +431,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { // exception_ptr. bool finishedGPUExecutionInternal() const; - // Stash tensors so that CachingAllocator cannot recycle them prematurely. - // Used in case of async ops. - void stashTensors(std::vector& tensors); - - // Unstage the stashed tensors so that CachingAllocator can recycle them - void unstashTensors(); - // Reference to the store so that we can write aborted communicators // to the store. c10::intrusive_ptr store_; @@ -454,9 +450,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { // For in-place collectives, some refs stashed here may alias outputs_, // but that doesn't do any harm. std::shared_ptr> stashed_for_allocator_safety_; - // Need a mutex to protect stashed_for_allocator_safety_ because it can be - // accessed from both main thread and watchdog thread. - std::mutex stashMutex_; // The future returned by getFuture. 
c10::intrusive_ptr future_; @@ -885,8 +878,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { at::Tensor& output, Fn fn, OpType opType, - bool asyncOp, const char* profilingTitle = nullptr, + bool avoidRecordStreams = false, bool nanCheck = true); template @@ -897,8 +890,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { PreProcess pre, PostProcess post, OpType opType, - bool asyncOp, const char* profilingTitle = nullptr, + bool avoidRecordStreams = false, bool nanCheck = true); template @@ -909,8 +902,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { PreProcess pre, PostProcess post, OpType opType, - bool asyncOp, const char* profilingTitle = nullptr, + bool avoidRecordStreams = false, bool nanCheck = true); template @@ -919,8 +912,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::vector& output, Fn fn, OpType opType, - bool asyncOp, - const char* profilingTitle = nullptr); + const char* profilingTitle = nullptr, + bool avoidRecordStreams = false); // Helper that encapsulates work shared across point-to-point communication // primitives. It is the same structure as the helper used for collective @@ -1229,9 +1222,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Stores communicators for all collectives run inside a coalescing block std::shared_ptr coalescedComm_ = nullptr; - // Whether the coalesced calls are sync or async. - bool coalescedAsync_; - // Whether or not wait() and synchronize() are blocking operations that wait // for the operation to complete. bool blockingWait_ = false; diff --git a/torch/csrc/distributed/c10d/Types.hpp b/torch/csrc/distributed/c10d/Types.hpp index 8fec5dd0e9e2..5d15708c953e 100644 --- a/torch/csrc/distributed/c10d/Types.hpp +++ b/torch/csrc/distributed/c10d/Types.hpp @@ -122,7 +122,6 @@ struct BroadcastOptions { struct AllreduceOptions { ReduceOp reduceOp = ReduceOp::SUM; std::chrono::milliseconds timeout = kUnsetTimeout; - bool asyncOp = true; std::optional sparseIndices = std::nullopt; }; @@ -133,7 +132,6 @@ struct ReduceOptions { int64_t rootRank = 0; int64_t rootTensor = 0; std::chrono::milliseconds timeout = kUnsetTimeout; - bool asyncOp = true; }; struct AllgatherOptions { @@ -144,7 +142,6 @@ struct AllgatherOptions { struct GatherOptions { int64_t rootRank = 0; std::chrono::milliseconds timeout = kUnsetTimeout; - bool asyncOp = true; }; struct ScatterOptions { @@ -161,14 +158,12 @@ struct ReduceScatterOptions { struct AllToAllOptions { std::chrono::milliseconds timeout = kUnsetTimeout; - bool asyncOp = true; }; struct BarrierOptions { std::vector device_ids; std::chrono::milliseconds timeout = kUnsetTimeout; std::optional device; - bool asyncOp = true; }; struct DistributedBackendOptions { diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 4033fcc0ad68..50683f9e29a4 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -999,23 +999,20 @@ This class does not support ``__members__`` property.)"); py::class_<::c10d::AllreduceOptions>(module, "AllreduceOptions") .def(py::init<>()) .def_readwrite("reduceOp", &::c10d::AllreduceOptions::reduceOp) - .def_readwrite("timeout", &::c10d::AllreduceOptions::timeout) - .def_readwrite("asyncOp", &::c10d::AllreduceOptions::asyncOp); + .def_readwrite("timeout", &::c10d::AllreduceOptions::timeout); py::class_<::c10d::AllreduceCoalescedOptions>( module, "AllreduceCoalescedOptions") .def(py::init<>()) .def_readwrite("reduceOp", &::c10d::AllreduceCoalescedOptions::reduceOp) - 
.def_readwrite("timeout", &::c10d::AllreduceCoalescedOptions::timeout) - .def_readwrite("asyncOp", &::c10d::AllreduceCoalescedOptions::asyncOp); + .def_readwrite("timeout", &::c10d::AllreduceCoalescedOptions::timeout); py::class_<::c10d::ReduceOptions>(module, "ReduceOptions") .def(py::init<>()) .def_readwrite("reduceOp", &::c10d::ReduceOptions::reduceOp) .def_readwrite("rootRank", &::c10d::ReduceOptions::rootRank) .def_readwrite("rootTensor", &::c10d::ReduceOptions::rootTensor) - .def_readwrite("timeout", &::c10d::ReduceOptions::timeout) - .def_readwrite("asyncOp", &::c10d::ReduceOptions::asyncOp); + .def_readwrite("timeout", &::c10d::ReduceOptions::timeout); py::class_<::c10d::AllgatherOptions>(module, "AllgatherOptions") .def(py::init<>()) @@ -1025,8 +1022,7 @@ This class does not support ``__members__`` property.)"); py::class_<::c10d::GatherOptions>(module, "GatherOptions") .def(py::init<>()) .def_readwrite("rootRank", &::c10d::GatherOptions::rootRank) - .def_readwrite("timeout", &::c10d::GatherOptions::timeout) - .def_readwrite("asyncOp", &::c10d::GatherOptions::asyncOp); + .def_readwrite("timeout", &::c10d::GatherOptions::timeout); py::class_<::c10d::ScatterOptions>(module, "ScatterOptions") .def(py::init<>()) @@ -1044,13 +1040,11 @@ This class does not support ``__members__`` property.)"); .def(py::init<>()) .def_readwrite("device_ids", &::c10d::BarrierOptions::device_ids) .def_readwrite("timeout", &::c10d::BarrierOptions::timeout) - .def_readwrite("device", &::c10d::BarrierOptions::device) - .def_readwrite("asyncOp", &::c10d::BarrierOptions::asyncOp); + .def_readwrite("device", &::c10d::BarrierOptions::device); py::class_<::c10d::AllToAllOptions>(module, "AllToAllOptions") .def(py::init<>()) - .def_readwrite("timeout", &::c10d::AllToAllOptions::timeout) - .def_readwrite("asyncOp", &::c10d::AllToAllOptions::asyncOp); + .def_readwrite("timeout", &::c10d::AllToAllOptions::timeout); py::class_<::c10d::DistributedBackendOptions>( module, "_DistributedBackendOptions") diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 3477250ef02b..2eed6d6aa021 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -2500,7 +2500,7 @@ class _CoalescingManager: def __init__(self) -> None: self.works: list[Work] = [] - def append(self, work: Optional[Work] = None): + def append(self, work: Work): if work: self.works.append(work) @@ -2513,7 +2513,7 @@ class _CoalescingManager: def _coalescing_manager( group: Optional[ProcessGroup] = None, device: Optional[torch.device] = None, - async_ops: bool = False, + async_ops: Optional[bool] = False, ): """ Context manager used to coalesce collectives or P2P operations when possible. @@ -2552,7 +2552,6 @@ def _coalescing_manager( group._start_coalescing(device) cm = _CoalescingManager() yield cm - work = None op_list = _world.pg_coalesce_state.pop(group) if op_list: # Collectives supporting "Fast Path" coalescing are captured. 
@@ -2566,7 +2565,6 @@ def _coalescing_manager(
             tensors = [op.tensor for op in op_list]
             all_reduce_opts = AllreduceCoalescedOptions()
             all_reduce_opts.reduceOp = not_none(op_list[0].redop)
-            all_reduce_opts.asyncOp = async_ops
             work = group.allreduce_coalesced(tensors, all_reduce_opts)
         elif op0 == all_gather_into_tensor:
             inputs = []
@@ -2574,8 +2572,6 @@ def _coalescing_manager(
             for op in op_list:
                 inputs.append(op.tensor)
                 outputs.append(not_none(op.dst_tensor))
-            all_gather_opts = AllgatherOptions()
-            all_gather_opts.asyncOp = async_ops
             work = group.allgather_into_tensor_coalesced(outputs, inputs)
         elif op0 == reduce_scatter_tensor:
             inputs = []
@@ -2585,7 +2581,6 @@ def _coalescing_manager(
                 outputs.append(not_none(op.dst_tensor))
             reduce_opts = ReduceScatterOptions()
             reduce_opts.reduceOp = not_none(op_list[0].redop)
-            reduce_opts.asyncOp = async_ops
             work = group.reduce_scatter_tensor_coalesced(outputs, inputs, reduce_opts)
         else:
             raise AssertionError(
@@ -2598,12 +2593,9 @@ def _coalescing_manager(
         work = group._end_coalescing(device)
 
     if async_ops:
-        cm.append(work)
-    elif (
-        work is not None
-    ):  # Backward compatible with backends that don't sync at CPP level
-        work.wait()
-        # Otherwise, the backend has sync'ed at CPP level
+        cm.append(work)  # type: ignore[possibly-undefined]
+    else:
+        work.wait()  # type: ignore[possibly-undefined]
 
 
 def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]:
@@ -2728,11 +2720,8 @@ def broadcast(
     work = group.broadcast([tensor], opts)
     if async_op:
         return work
-    elif (
-        work is not None
-    ):  # Backward compatible with backends that don't sync at CPP level
+    else:
         work.wait()
-        # Otherwise, the backend has sync'ed at CPP level
 
 
 @_exception_logger
@@ -2812,7 +2801,6 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
 
     opts = AllreduceOptions()
     opts.reduceOp = op
-    opts.asyncOp = async_op
     if group is None:
         group = _get_default_group()
 
@@ -2829,11 +2817,8 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
 
     if async_op:
         return work
-    elif (
-        work is not None
-    ):  # Backward compatible with backends that don't sync at CPP level
+    else:
         work.wait()
-        # Otherwise, the backend has sync'ed at CPP level
 
 
 @_exception_logger
@@ -2892,17 +2877,13 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):
 
     opts = AllreduceCoalescedOptions()
     opts.reduceOp = op
-    opts.asyncOp = async_op
     group = group or _get_default_group()
     work = group.allreduce_coalesced(tensors, opts)
 
     if async_op:
         return work.get_future()
-    elif (
-        work is not None
-    ):  # Backward compatible with backends that don't sync at CPP level
+    else:
         work.wait()
-        # Otherwise, the backend has sync'ed at CPP level
 
 
 @_exception_logger
@@ -2947,15 +2928,11 @@ def reduce(
     opts = ReduceOptions()
     opts.reduceOp = op
     opts.rootRank = group_dst
-    opts.asyncOp = async_op
     work = group.reduce([tensor], opts)
     if async_op:
         return work
-    elif (
-        work is not None
-    ):  # Backward compatible with backends that don't sync at CPP level
+    else:
         work.wait()
-        # Otherwise, the backend has sync'ed at CPP level
 
 
 def _object_to_tensor(obj, device, group):
@@ -3754,17 +3731,12 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):
     tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)
 
     group = group or _get_default_group()
-    opts = AllgatherOptions()
-    opts.asyncOp = async_op
-    work = group.allgather([tensor_list], [tensor], opts)
+    work = group.allgather([tensor_list], [tensor])
 
     if async_op:
         return work
-    elif (
-        work is not None
-    ):  # Backward compatible with backends that don't sync at CPP level
+    else:
         work.wait()
-        # Otherwise, the backend has sync'ed at CPP level
 
 
 @_exception_logger
@@ -3871,11 +3843,8 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal
 
     if async_op:
         return work
-    elif (
-        work is not None
-    ):  # Backward compatible with backends that don't sync at CPP level
+    else:
         work.wait()
-        # Otherwise, the backend has sync'ed at CPP level
 
 
 @_exception_logger
@@ -3985,17 +3954,12 @@ def all_gather_coalesced(
     ]
 
     group = group or _get_default_group()
-    opts = AllgatherOptions()
-    opts.asyncOp = async_op
-    work = group.allgather_coalesced(output_tensor_lists, input_tensor_list, opts)
+    work = group.allgather_coalesced(output_tensor_lists, input_tensor_list)
 
     if async_op:
         return work.get_future()
-    elif (
-        work is not None
-    ):  # Backward compatible with backends that don't sync at CPP level
+    else:
         work.wait()
-        # Otherwise, the backend has sync'ed at CPP level
 
 
 def _validate_output_list_for_rank(my_rank, dst, gather_list):
@@ -4082,16 +4046,12 @@ def gather(
 
     opts = GatherOptions()
     opts.rootRank = group_dst
-    opts.asyncOp = async_op
     work = group.gather(output_tensors, input_tensors, opts)
     if async_op:
         return work
-    elif (
-        work is not None
-    ):  # Backward compatible with backends that don't sync at CPP level
+    else:
         work.wait()
-        # Otherwise, the backend has sync'ed at CPP level
 
 
 @_exception_logger
@@ -4193,11 +4153,8 @@ def scatter(
 
     if async_op:
         return work
-    elif (
-        work is not None
-    ):  # Backward compatible with backends that don't sync at CPP level
+    else:
         work.wait()
-        # Otherwise, the backend has sync'ed at CPP level
 
 
 @_exception_logger
@@ -4229,18 +4186,14 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=Fal
 
     opts = ReduceScatterOptions()
     opts.reduceOp = op
-    opts.asyncOp = async_op
 
     group = group or _get_default_group()
     work = group.reduce_scatter([output], [input_list], opts)
 
     if async_op:
         return work
-    elif (
-        work is not None
-    ):  # Backward compatible with backends that don't sync at CPP level
+    else:
         work.wait()
-        # Otherwise, the backend has sync'ed at CPP level
 
 
 @_exception_logger
@@ -4340,11 +4293,8 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F
 
     if async_op:
         return work
-    elif (
-        work is not None
-    ):  # Backward compatible with backends that don't sync at CPP level
+    else:
         work.wait()
-        # Otherwise, the backend has sync'ed at CPP level
 
 
 @deprecated(
@@ -4497,7 +4447,6 @@ def all_to_all_single(
         return
 
     opts = AllToAllOptions()
-    opts.asyncOp = async_op
     _check_single_tensor(output, "output")
     _check_single_tensor(input, "input")
     _ensure_all_tensors_same_dtype(output, input)
@@ -4517,11 +4466,8 @@ def all_to_all_single(
 
     if async_op:
         return work
-    elif (
-        work is not None
-    ):  # Backward compatible with backends that don't sync at CPP level
+    else:
         work.wait()
-        # Otherwise, the backend has sync'ed at CPP level
 
 
 @_exception_logger
@@ -4622,7 +4568,6 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
         return
 
     opts = AllToAllOptions()
-    opts.asyncOp = async_op
     _check_tensor_list(output_tensor_list, "output_tensor_list")
     _check_tensor_list(input_tensor_list, "input_tensor_list")
     _ensure_all_tensors_same_dtype(output_tensor_list, input_tensor_list)
@@ -4639,11 +4584,8 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
 
     if async_op:
         return work
-    elif (
-        work is not None
-    ):  # Backward compatible with backends that don't sync at CPP level
+    else:
         work.wait()
-        # Otherwise, the backend has sync'ed at CPP level
 
 
 @_exception_logger
@@ -4674,7 +4616,6 @@ def barrier(
 
     opts = BarrierOptions()
     opts.device = torch.device(_get_object_coll_device(group))
-    opts.asyncOp = async_op
     if device_ids is not None:
         if isinstance(device_ids, list):
             opts.device_ids = device_ids
@@ -4688,11 +4629,8 @@ def barrier(
 
     if async_op:
         return work
-    elif (
-        work is not None
-    ):  # Backward compatible with backends that don't sync at CPP level
+    else:
         work.wait()
-        # Otherwise, the backend has sync'ed at CPP level
 
 
 def monitored_barrier(
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index db9f9e70dee1..3f4a24a1ffb1 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -96,7 +96,7 @@ try:
     import torchvision
 
     HAS_TORCHVISION = True
-except Exception:  # Covering both ImportError and RuntimeError
+except ImportError:
     HAS_TORCHVISION = False
 
 if sys.platform == "win32":
@@ -8310,14 +8310,50 @@ class DistributedTest:
         def test_compute_bucket_assignment_by_size_sparse_error_with_logger(self):
             self._test_compute_bucket_assignment_by_size(use_logger=True)
 
+        def _determine_expected_error_verify_model_across_rank(
+            self, group_to_use, diff_num_params=False
+        ):
+            # When running with NCCL backend, we don't expect an error on rank 0,
+            # rather, it will be taken down by TORCH_NCCL_ASYNC_ERROR_HANDLING. When
+            # running with Gloo or with debug mode wrapper, we expect the error
+            # to be caught inline.
+            # All ranks report same error when there is a # of parameter
+            # mismatch since we use allgather in the impl.
+            if diff_num_params:
+                expected_err = "DDP expects same model across all ranks"
+                ctx = self.assertRaisesRegex(RuntimeError, expected_err)
+                return ctx, expected_err
+
+            is_detail_dbg_mode = dist.get_debug_level() == dist.DebugLevel.DETAIL
+            if self.rank == 0:
+                if (
+                    dist.get_backend(group_to_use) == dist.Backend.NCCL
+                    and not is_detail_dbg_mode
+                ):
+                    expected_err = "caught collective operation timeout"
+                    ctx = self.assertRaisesRegex(RuntimeError, expected_err)
+                else:
+                    expected_err = None
+                    ctx = self.assertRaises(RuntimeError)
+            else:
+                expected_err = "appears not to match"
+                ctx = self.assertRaisesRegex(RuntimeError, expected_err)
+            return ctx, expected_err
+
         def _test_verify_model_across_rank(self, use_logger):
             group_gloo = dist.new_group(
                 timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
             )
+            # Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
+            # determinism.
+            os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
             group_to_use = dist.new_group(
                 backend=dist.get_backend(), timeout=timedelta(seconds=5)
             )
             torch.cuda.set_device(self.rank)
+            ctx, expected_err = self._determine_expected_error_verify_model_across_rank(
+                group_to_use
+            )
 
             # Create a valid model. The constructor initializes the logger that we use later.
             net = EmbeddingNetDifferentParams(0)
@@ -8335,8 +8371,7 @@ class DistributedTest:
             net.module.lin = nn.Linear(100 if self.rank == 0 else 10, 1)
 
             # if we pass a logger we can verify that it was logged
-            caught = 0
-            try:
+            with ctx:
                 if use_logger:
                     _verify_param_shape_across_processes(
                         net.process_group, list(net.parameters()), net.logger
@@ -8345,13 +8380,18 @@ class DistributedTest:
                     _verify_param_shape_across_processes(
                         net.process_group, list(net.parameters())
                     )
-            except Exception:
-                caught = 1
+                # Should only be run by rank 0, and blocking_wait catches and
+                # reports exception.
+                dist.barrier(group_to_use)
 
-            # As long as there is one rank catching the exception
-            t = torch.Tensor([caught])
-            dist.all_reduce(t, group=group_gloo)
-            self.assertGreater(t, 0)
+            # We don't check when self.rank != 0 because the logger doesn't log
+            # the error "Caught collective operation" as that is not thrown in the reducer.
+            if use_logger and self.rank != 0:
+                verify_ddp_error_logged(net, expected_err)
+
+            # Perform gloo-based barrier to ensure one rank doesn't exit test
+            # early which causes failure with Barrier.sync.
+            dist.barrier(group_gloo)
 
         @require_backend_is_available(DistTestCases.backend_feature["gpu"])
         @skip_but_pass_in_sandcastle_if(
@@ -8369,19 +8409,20 @@ class DistributedTest:
         def test_verify_model_across_rank_without_logger(self):
             self._test_verify_model_across_rank(use_logger=False)
 
-        def _run_test_ddp_model_with_diff_params(self, net, ddp_group, group_gloo):
-            caught = 0
-            try:
+        def _run_test_ddp_model_with_diff_params(self, ctx, net, ddp_group, group_gloo):
+            with ctx:
                 net = torch.nn.parallel.DistributedDataParallel(
                     net.to(self.rank), device_ids=[self.rank], process_group=ddp_group
                 )
-            except Exception:
-                caught = 1
+                # Should only be run by rank 0, and blocking_wait catches and
+                # reports exception.
+                dist.barrier(ddp_group)
 
-            # As long as there is one rank catching the exception
-            t = torch.Tensor([caught])
-            dist.all_reduce(t, group=group_gloo)
-            self.assertGreater(t, 0)
+            # can't use verify_ddp_error_logged here because net was never properly constructed
+
+            # Perform gloo-based barrier to ensure one rank doesn't exit test
+            # early which causes failure with Barrier.sync.
+            dist.barrier(group_gloo)
 
         @require_backend_is_available(DistTestCases.backend_feature["gpu"])
         @skip_but_pass_in_sandcastle_if(
@@ -8392,15 +8433,21 @@ class DistributedTest:
             group_gloo = dist.new_group(
                 timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
             )
+            # Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
+            # determinism.
+            os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
             group_to_use = dist.new_group(
                 backend=dist.get_backend(), timeout=timedelta(seconds=10)
             )
             torch.cuda.set_device(self.rank)
+            ctx, _expected_err = self._determine_expected_error_verify_model_across_rank(
+                group_to_use
+            )
 
             # Creates network with different sized embedding table on different
             # ranks. This should throw an error during DDP init.
             net = EmbeddingNetDifferentParams(self.rank)
             self._run_test_ddp_model_with_diff_params(
-                net, group_to_use, group_gloo
+                ctx, net, group_to_use, group_gloo
             )
 
             @require_backend_is_available(DistTestCases.backend_feature["gpu"])
@@ -8412,10 +8459,16 @@ class DistributedTest:
             group_gloo = dist.new_group(
                 timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
             )
+            # Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
+            # determinism.
+            os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
             group_to_use = dist.new_group(
                 backend=dist.get_backend(), timeout=timedelta(seconds=10)
             )
             torch.cuda.set_device(self.rank)
+            ctx, _expected_err = self._determine_expected_error_verify_model_across_rank(
+                group_to_use, diff_num_params=True
+            )
 
             # Creates network with diff # of param across ranks, reducer should
             # recognize this and throw appropriate error.
@@ -8424,6 +8477,7 @@ class DistributedTest:
             )
 
             self._run_test_ddp_model_with_diff_params(
+                ctx,
                 net,
                 group_to_use,
                 group_gloo,