#include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp>

#ifdef USE_C10D_MPI

#include <iostream>
#include <map>

#include <c10/core/DeviceGuard.h>
#include <c10/util/irange.h>

#if defined(OPEN_MPI) && OPEN_MPI
#include <mpi-ext.h> // Needed for CUDA-aware check
#endif

namespace c10d {

#define MPI_CHECK(cmd)                                                   \
  do {                                                                   \
    int mpiStatus = cmd;                                                 \
    if (mpiStatus != MPI_SUCCESS) {                                      \
      std::string err = "MPI error in: " + std::string(__FILE__) + ":" + \
          std::to_string(__LINE__) +                                     \
          ", with error code: " + std::to_string(mpiStatus);             \
      TORCH_CHECK(false, err);                                           \
    }                                                                    \
  } while (0)

namespace {

// Op mapping
std::map<ReduceOp::RedOpType, MPI_Op> mpiOp = {
    {ReduceOp::MIN, MPI_MIN},
    {ReduceOp::MAX, MPI_MAX},
    {ReduceOp::SUM, MPI_SUM},
    {ReduceOp::PRODUCT, MPI_PROD},
};

// Type mapping
std::map<at::ScalarType, MPI_Datatype> mpiDatatype = {
    {at::kByte, MPI_UNSIGNED_CHAR},
    {at::kChar, MPI_CHAR},
    {at::kDouble, MPI_DOUBLE},
    {at::kFloat, MPI_FLOAT},
    {at::kInt, MPI_INT},
    {at::kLong, MPI_LONG},
    {at::kShort, MPI_SHORT},
};

// Checking CUDA-aware MPI support; currently we only support CUDA-aware
// MPI ops through Open MPI
bool cudaAwareMpiCheck() {
// Run time check
#if defined(MPIX_CUDA_AWARE_SUPPORT)
  if (MPIX_Query_cuda_support() == 1) {
    return true;
  } else {
    return false;
  }
#else // !defined(MPIX_CUDA_AWARE_SUPPORT)
  return false;
#endif // MPIX_CUDA_AWARE_SUPPORT
}

// Checking the input tensor's validity
void checkSingleTensorHelper(const at::Tensor& tensor) {
  if (!tensor.is_contiguous()) {
    TORCH_CHECK(false, "input tensor has to be contiguous");
  }
  if (tensor.is_sparse()) {
    TORCH_CHECK(false, "input tensor has to be dense");
  }
  if (tensor.is_cuda() && !cudaAwareMpiCheck()) {
    TORCH_CHECK(
        false,
        "CUDA tensor detected and the MPI used doesn't "
        "have CUDA-aware MPI support");
  }
}

void checkSingleTensor(const std::vector<at::Tensor>& tensors) {
  if (tensors.size() != 1) {
    TORCH_CHECK(
        false, "MPI process group does not support multi-GPU collectives");
  }
  checkSingleTensorHelper(tensors[0]);
}

void checkSameSizeAndType(
    const at::Tensor& t_in,
    const std::vector<at::Tensor>& tensors) {
  for (const auto& tensor : tensors) {
    if ((tensor.numel() != t_in.numel()) ||
        (tensor.scalar_type() != t_in.scalar_type())) {
      TORCH_CHECK(false, "Tensors are not equal in size or data type");
    }
    checkSingleTensorHelper(tensor);
  }
}

} // namespace

std::vector<at::Tensor> ProcessGroupMPI::WorkMPI::result() {
  return outputTensors_;
}

c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupMPI::WorkMPI::getFuture() {
  return future_;
}

void ProcessGroupMPI::WorkMPI::finishWorkMPIError(
    const std::exception_ptr& eptr) {
  future_->setError(eptr);
  finish(eptr);
}

void ProcessGroupMPI::WorkMPI::finishWorkMPI() {
  future_->markCompleted(at::IValue(outputTensors_));
  finish();
}

ProcessGroupMPI::AsyncWork::AsyncWork(
    MPI_Request request,
    std::vector<at::Tensor> outputTensors,
    const char* profilingTitle,
    const std::optional<std::vector<at::Tensor>>& inputTensors)
    : Work(-1, OpType::UNKNOWN, profilingTitle, inputTensors),
      outputTensors_(std::move(outputTensors)),
      request_(request) {
  memset(&status_, 0, sizeof(status_));
}

ProcessGroupMPI::AsyncWork::~AsyncWork() {
  if (request_ != MPI_REQUEST_NULL) {
    std::cerr
        << "Attempted destruction of AsyncWork before work has completed, "
        << "terminating the program." << '\n';
    std::terminate();
  }
}
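// A minimal usage sketch of the AsyncWork interface returned by the
// point-to-point ops further below (illustrative only, not part of the
// implementation; `pg` is an assumed ProcessGroupMPI instance and `tensors`
// an assumed single-tensor vector):
//
//   auto work = pg->send(tensors, /*dstRank=*/1, /*tag=*/0);
//   while (!work->isCompleted()) {
//     // overlap communication with computation
//   }
//   work->wait(); // rethrows the stored exception if the MPI request failed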
bool ProcessGroupMPI::AsyncWork::isCompleted() {
  if (request_ == MPI_REQUEST_NULL) {
    return true;
  }

  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  int flag = 0;
  MPI_CHECK(MPI_Test(&request_, &flag, &status_));
  if (request_ != MPI_REQUEST_NULL) {
    return false;
  }

  // request_ == MPI_REQUEST_NULL; the work has completed
  // Populate exception if request was not successful
  if (status_.MPI_ERROR != MPI_SUCCESS) {
    populateException();
  }

  return true;
}

bool ProcessGroupMPI::AsyncWork::isSuccess() const {
  if (request_ != MPI_REQUEST_NULL) {
    TORCH_CHECK(
        false,
        "Invalid call to AsyncWork::isSuccess before work has completed");
  }

  return status_.MPI_ERROR == MPI_SUCCESS;
}

int ProcessGroupMPI::AsyncWork::sourceRank() const {
  return status_.MPI_SOURCE;
}

bool ProcessGroupMPI::AsyncWork::wait(std::chrono::milliseconds /* unused */) {
  if (request_ == MPI_REQUEST_NULL) {
    // AsyncWork needs to manually call profiling end callbacks if they are
    // set, since it does not call ProcessGroup::finish().
    if (Work::recordFunctionEndCallback_) {
      Work::recordFunctionEndCallback_();
      Work::recordFunctionEndCallback_ = nullptr;
    }
    return true;
  }

  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  MPI_CHECK(MPI_Wait(&request_, &status_));
  auto ok = (status_.MPI_ERROR == MPI_SUCCESS);

  // AsyncWork needs to manually call profiling end callbacks if they are set,
  // since it does not call ProcessGroup::finish().
  if (Work::recordFunctionEndCallback_) {
    Work::recordFunctionEndCallback_();
    Work::recordFunctionEndCallback_ = nullptr;
  }

  if (!ok) {
    populateException();
    std::rethrow_exception(exception_);
  }
  // Always return true, because abort API is not implemented.
  return true;
}

void ProcessGroupMPI::AsyncWork::abort() {
  TORCH_CHECK(false, "ProcessGroupMPI::AsyncWork::abort not implemented.");
}

std::vector<at::Tensor> ProcessGroupMPI::AsyncWork::result() {
  return outputTensors_;
}

void ProcessGroupMPI::AsyncWork::populateException() {
  std::array<char, MPI_MAX_ERROR_STRING> buf{};
  int len = buf.size();
  MPI_CHECK(MPI_Error_string(status_.MPI_ERROR, buf.data(), &len));
  exception_ = std::make_exception_ptr(
      std::runtime_error(std::string(buf.data(), len)));
}

// Static global states
int ProcessGroupMPI::mpiThreadSupport_ = 0;
std::mutex ProcessGroupMPI::pgGlobalMutex_;
// We only want to initialize once
c10::once_flag ProcessGroupMPI::onceFlagInitMPI;

void ProcessGroupMPI::mpiExit() {
  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  MPI_CHECK(MPI_Finalize());
}
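// Note on threading (illustrative, not part of the upstream implementation):
// all MPI calls issued by this backend are funneled through a single worker
// thread and serialized via pgGlobalMutex_, so the MPI library only needs
// MPI_THREAD_SERIALIZED. If the embedding application initializes MPI itself
// before creating the process group, it should request at least that level,
// e.g. (sketch):
//
//   int provided = 0;
//   MPI_Init_thread(&argc, &argv, MPI_THREAD_SERIALIZED, &provided);
//   // initMPIOnce() below will then only warn that MPI was already
//   // initialized instead of calling MPI_Init_thread itself.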
void ProcessGroupMPI::initMPIOnce() {
  // Initialize MPI environment
  c10::call_once(onceFlagInitMPI, []() {
    int mpi_was_initialized = 0;
    MPI_CHECK(MPI_Initialized(&mpi_was_initialized));
    if (mpi_was_initialized == 0) {
      MPI_CHECK(MPI_Init_thread(
          nullptr, nullptr, MPI_THREAD_SERIALIZED, &mpiThreadSupport_));
      if (mpiThreadSupport_ < MPI_THREAD_SERIALIZED) {
        TORCH_CHECK(
            false,
            "Used MPI implementation doesn't have the "
            "minimum level of threading support: "
            "MPI_THREAD_SERIALIZED. This is required by "
            "c10d package");
      }
      if (std::atexit(ProcessGroupMPI::mpiExit)) {
        TORCH_CHECK(false, "Fail to register the MPI exit handler");
      }
    } else {
      TORCH_WARN_ONCE("MPI was previously initialized.");
    }
  });
}

c10::intrusive_ptr<ProcessGroupMPI> ProcessGroupMPI::createProcessGroupMPI(
    std::vector<int> ranks) {
  // Once initialization
  initMPIOnce();

  MPI_Comm groupComm = MPI_COMM_WORLD;
  int rank = -1;
  int size = -1;

  {
    std::lock_guard<std::mutex> globalLock(pgGlobalMutex_);

    // If no ranks are specified, assume we're creating the root group
    if (!ranks.empty()) {
      MPI_Group worldGroup{};
      MPI_Group ranksGroup{};
      MPI_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
      MPI_CHECK(
          MPI_Group_incl(worldGroup, ranks.size(), ranks.data(), &ranksGroup));

      // `MPI_Comm_create` can be flaky in certain cases.
      // See: https://github.com/pytorch/pytorch/issues/53899
      constexpr int kMaxNumRetries = 3;
      bool groupComm_updated = false;
      MPI_Barrier(MPI_COMM_WORLD);
      for (const auto i : c10::irange(kMaxNumRetries)) {
        (void)i;
        if (MPI_Comm_create(MPI_COMM_WORLD, ranksGroup, &groupComm)) {
          groupComm_updated = true;
          break;
        }
      }
      MPI_CHECK(groupComm_updated);
      MPI_CHECK(MPI_Group_free(&worldGroup));
      MPI_CHECK(MPI_Group_free(&ranksGroup));
    }

    // Fetch rank and world size for this group (MPI_COMM_WORLD or new)
    if (groupComm != MPI_COMM_NULL) {
      MPI_CHECK(MPI_Comm_rank(groupComm, &rank));
      MPI_CHECK(MPI_Comm_size(groupComm, &size));

      if (rank < 0 || size < 0) {
        TORCH_CHECK(false, "Failed to get the world_size / rank");
      }
    }
  }

  // If this process is not part of the group, we don't construct a
  // process group instance. This is in line with the semantics of the
  // other process group types.
  if (groupComm == MPI_COMM_NULL) {
    return c10::intrusive_ptr<ProcessGroupMPI>();
  }

  return c10::make_intrusive<ProcessGroupMPI>(rank, size, groupComm);
}

ProcessGroupMPI::ProcessGroupMPI(int rank, int size, MPI_Comm pgComm)
    : Backend(rank, size), stop_(false), pgComm_(pgComm) {
  if (pgComm_ == MPI_COMM_NULL) {
    TORCH_CHECK(false, "pgComm_ must not be MPI_COMM_NULL");
  }

  // Start the worker thread accepting MPI calls
  workerThread_ = std::thread(&ProcessGroupMPI::runLoop, this);

  init();
}

ProcessGroupMPI::~ProcessGroupMPI() {
  destroy();
}

void ProcessGroupMPI::destroy() {
  std::unique_lock<std::mutex> lock(pgMutex_);
  queueConsumeCV_.wait(lock, [&] { return queue_.empty(); });

  // Queue is empty, signal stop
  stop_ = true;

  // Release lock to allow threads to terminate
  lock.unlock();
  queueProduceCV_.notify_all();

  // Join the single worker thread
  workerThread_.join();
}

void ProcessGroupMPI::abort() {
  destroy();
  MPI_Abort(pgComm_, EXIT_FAILURE);
}

void ProcessGroupMPI::runLoop() {
  std::unique_lock<std::mutex> lock(pgMutex_);

  while (!stop_) {
    if (queue_.empty()) {
      queueProduceCV_.wait(lock);
      continue;
    }

    auto workTuple = std::move(queue_.front());
    queue_.pop_front();

    auto& workEntry = std::get<0>(workTuple);
    auto& work = std::get<1>(workTuple);

    lock.unlock();
    queueConsumeCV_.notify_one();

    try {
      workEntry->run(workEntry);
      work->finishWorkMPI();
    } catch (...) {
      work->finishWorkMPIError(std::current_exception());
    }

    lock.lock();
  }
}
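// How a collective flows through the queue above (illustrative sketch only;
// default options are assumed, and ranks {0, 1} are an arbitrary example):
//
//   auto pg = c10d::ProcessGroupMPI::createProcessGroupMPI({0, 1});
//   if (pg) {                        // null on ranks outside the group
//     std::vector<at::Tensor> ts = {at::ones({4})};
//     auto work = pg->broadcast(ts); // enqueue()s a WorkEntry and returns
//                                    // a WorkMPI handle immediately
//     work->wait();                  // runLoop() picks the entry up, runs
//                                    // MPI_Bcast under pgGlobalMutex_, and
//                                    // completes the associated future
//   }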
c10::intrusive_ptr<Work> ProcessGroupMPI::enqueue(
    std::unique_ptr<WorkEntry> entry,
    const char* profilingTitle,
    const std::optional<std::vector<at::Tensor>>& inputTensors) {
  auto work =
      c10::make_intrusive<WorkMPI>(entry->dst, profilingTitle, inputTensors);
  std::unique_lock<std::mutex> lock(pgMutex_);
  queue_.emplace_back(std::move(entry), work);
  lock.unlock();
  queueProduceCV_.notify_one();
  return work;
}

c10::intrusive_ptr<Work> ProcessGroupMPI::broadcast(
    std::vector<at::Tensor>& tensors,
    const BroadcastOptions& opts) {
  checkSingleTensor(tensors);
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Bcast(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            opts.rootRank,
            pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:broadcast",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::allreduce(
    std::vector<at::Tensor>& tensors,
    const AllreduceOptions& opts) {
  checkSingleTensor(tensors);

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Allreduce(
            MPI_IN_PLACE,
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            mpiOp.at(opts.reduceOp),
            pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:all_reduce",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::allreduce_coalesced(
    std::vector<at::Tensor>& tensors,
    const AllreduceCoalescedOptions& opts) {
  TORCH_CHECK(false, "allreduce_coalesced is currently not supported with MPI");
}

c10::intrusive_ptr<Work> ProcessGroupMPI::reduce(
    std::vector<at::Tensor>& tensors,
    const ReduceOptions& opts) {
  checkSingleTensor(tensors);

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        auto dataPtr = (entry->src)[0].data_ptr();
        void* sendbuf = (rank_ == opts.rootRank) ? MPI_IN_PLACE : dataPtr;
        void* recvbuf = (rank_ == opts.rootRank) ? dataPtr : nullptr;

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Reduce(
            sendbuf,
            recvbuf,
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            mpiOp.at(opts.reduceOp),
            opts.rootRank,
            pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:reduce",
      std::optional<std::vector<at::Tensor>>(tensors));
}
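// In-place allreduce example matching the implementation above (sketch only;
// `pg` is an assumed ProcessGroupMPI instance and ReduceOp::SUM is assumed as
// the default reduction):
//
//   std::vector<at::Tensor> ts = {at::ones({2, 2})};
//   pg->allreduce(ts)->wait();
//   // ts[0] now holds the elementwise sum across all ranks, reduced in
//   // place via MPI_Allreduce with MPI_IN_PLACE.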
c10::intrusive_ptr<Work> ProcessGroupMPI::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& opts) {
  checkSingleTensor(inputTensors);
  if (outputTensors.size() != 1) {
    TORCH_CHECK(
        false,
        "MPI process group only supports a single "
        "tensor op");
  }
  if (static_cast<size_t>(size_) != outputTensors[0].size()) {
    TORCH_CHECK(
        false,
        "All gather: number of output tensors should equal "
        "to the world size");
  }

  checkSameSizeAndType(inputTensors[0], outputTensors[0]);

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        std::vector<at::Tensor> outputDataVec = entry->dst;
        auto flatOutputTensor = newLikeFlat(outputDataVec);

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Allgather(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            flatOutputTensor.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            pgComm_));

        for (const auto i : c10::irange(outputDataVec.size())) {
          outputDataVec[i].copy_(flatOutputTensor[static_cast<int64_t>(i)]);
        }
      };
  auto entry = std::make_unique<WorkEntry>(
      &inputTensors, &outputTensors[0], std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:all_gather",
      std::optional<std::vector<at::Tensor>>(inputTensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::allgather_coalesced(
    std::vector<std::vector<at::Tensor>>& /* unused */,
    std::vector<at::Tensor>& /* unused */,
    const AllgatherOptions& /* unused */) {
  TORCH_CHECK(false, "ProcessGroupMPI does not support allgather_coalesced");
}

c10::intrusive_ptr<Work> ProcessGroupMPI::gather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const GatherOptions& opts) {
  checkSingleTensor(inputTensors);

  if (rank_ != opts.rootRank) {
    if (!outputTensors.empty()) {
      TORCH_CHECK(
          false,
          "Gather: number of output tensors should be 0 "
          "for non-root");
    }
  } else {
    if (outputTensors.size() != 1) {
      TORCH_CHECK(false, "Gather: multi-GPU collective is not supported");
    }
    if (static_cast<size_t>(size_) != outputTensors[0].size()) {
      TORCH_CHECK(
          false,
          "Gather: number of output tensors should equal "
          "to the world size");
    }
    checkSameSizeAndType(inputTensors[0], outputTensors[0]);
  }

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        void* recvbuf = nullptr;
        at::Tensor flatOutputTensor;

        std::vector<at::Tensor> dstdata = entry->dst;
        if (rank_ == opts.rootRank) {
          flatOutputTensor = newLikeFlat(dstdata);
          recvbuf = flatOutputTensor.data_ptr();
        }

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Gather(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            recvbuf,
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            opts.rootRank,
            pgComm_));

        if (rank_ == opts.rootRank) {
          const std::vector<at::Tensor>& outputDataVec = entry->dst;
          // copy the flattened output tensors to the outputs
          for (const auto i : c10::irange(outputDataVec.size())) {
            outputDataVec.at(i).copy_(
                flatOutputTensor[static_cast<int64_t>(i)]);
          }
        }
      };

  if (rank_ == opts.rootRank) {
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors, &outputTensors[0], std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:gather",
        std::optional<std::vector<at::Tensor>>(inputTensors));
  } else {
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors, nullptr, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:gather",
        std::optional<std::vector<at::Tensor>>(inputTensors));
  }
}
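// Gather usage sketch matching the shape checks above (illustrative only;
// `pg`, `rank`, and `size` are assumed to come from the caller, and the
// root rank is assumed to default to 0):
//
//   std::vector<at::Tensor> input = {at::full({2}, rank)};
//   std::vector<std::vector<at::Tensor>> output;   // stays empty on non-root
//   if (rank == 0) {
//     output.emplace_back();
//     for (int i = 0; i < size; ++i) {
//       output[0].push_back(at::empty({2}, input[0].options()));
//     }
//   }
//   pg->gather(output, input)->wait();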
c10::intrusive_ptr<Work> ProcessGroupMPI::scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ScatterOptions& opts) {
  checkSingleTensor(outputTensors);

  if (rank_ != opts.rootRank) {
    if (!inputTensors.empty()) {
      TORCH_CHECK(
          false,
          "Scatter: number of input tensors should be 0 "
          "for non-root");
    }
  } else {
    if (inputTensors.size() != 1) {
      TORCH_CHECK(false, "Scatter: multi-GPU collective is not supported");
    }
    if (static_cast<size_t>(size_) != inputTensors[0].size()) {
      TORCH_CHECK(
          false,
          "Scatter: number of input tensors should equal "
          "to the world size");
    }
    checkSameSizeAndType(outputTensors[0], inputTensors[0]);
  }

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->dst)[0];
        void* sendbuf = nullptr;
        at::Tensor flatInputTensor;

        if (rank_ == opts.rootRank) {
          std::vector<at::Tensor>& inputDataVec = entry->src;
          flatInputTensor = newLikeFlat(inputDataVec);
          sendbuf = flatInputTensor.data_ptr();

          // copy the input tensors to the flatten large send buffer
          for (const auto i : c10::irange(inputDataVec.size())) {
            flatInputTensor[static_cast<int64_t>(i)].copy_(inputDataVec.at(i));
          }
        }

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Scatter(
            sendbuf,
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            opts.rootRank,
            pgComm_));
      };

  if (rank_ == opts.rootRank) {
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors[0], &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:scatter",
        !inputTensors.empty()
            ? std::optional<std::vector<at::Tensor>>(inputTensors[0])
            : std::nullopt);
  } else {
    auto entry = std::make_unique<WorkEntry>(
        nullptr, &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:scatter",
        !inputTensors.empty()
            ? std::optional<std::vector<at::Tensor>>(inputTensors[0])
            : std::nullopt);
  }
}
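// Scatter mirrors the gather sketch above with the roles reversed
// (illustrative only; `pg`, `rank`, and `size` assumed, root rank 0 assumed):
//
//   std::vector<at::Tensor> output = {at::empty({2})};
//   std::vector<std::vector<at::Tensor>> input;    // stays empty on non-root
//   if (rank == 0) {
//     input.emplace_back();
//     for (int i = 0; i < size; ++i) {
//       input[0].push_back(at::full({2}, i));
//     }
//   }
//   pg->scatter(output, input)->wait();  // output[0] now holds this rank's chunk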
c10::intrusive_ptr<Work> ProcessGroupMPI::reduce_scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ReduceScatterOptions& opts) {
  TORCH_CHECK(false, "ProcessGroupMPI does not support reduce_scatter");
}

c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    std::vector<int64_t>& outputSplitSizes,
    std::vector<int64_t>& inputSplitSizes,
    const AllToAllOptions& opts) {
  checkSingleTensorHelper(inputTensor);
  checkSingleTensorHelper(outputTensor);

  if (outputSplitSizes.empty() && inputSplitSizes.empty()) {
    // We can use alltoall
    TORCH_CHECK(
        outputTensor.numel() == inputTensor.numel() &&
            outputTensor.type() == inputTensor.type(),
        "Tensors are not equal in size or data type");
    TORCH_CHECK(
        outputTensor.size(0) % size_ == 0,
        "Tensor's dim 0 does not divide equally across group size");

    std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
        [this](std::unique_ptr<WorkEntry>& entry) {
          auto srcdata = (entry->src)[0];
          auto dstdata = (entry->dst)[0];
          c10::DeviceGuard guard(srcdata.device());
          std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
          MPI_CHECK(MPI_Alltoall(
              srcdata.data_ptr(),
              srcdata.numel() / size_,
              mpiDatatype.at(srcdata.scalar_type()),
              dstdata.data_ptr(),
              dstdata.numel() / size_,
              mpiDatatype.at(dstdata.scalar_type()),
              pgComm_));
        };
    std::vector<at::Tensor> inputTensors = {inputTensor};
    std::vector<at::Tensor> outputTensors = {outputTensor};
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors, &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:all_to_all",
        std::optional<std::vector<at::Tensor>>(inputTensors));
  } else {
    // Need alltoallv
    c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
    c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_);

    std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
        [this, inputSplitSizes, outputSplitSizes](
            std::unique_ptr<WorkEntry>& entry) {
          auto srcdata = (entry->src)[0];
          auto dstdata = (entry->dst)[0];
          std::vector<int> send_lengths(size_);
          std::vector<int> recv_lengths(size_);
          std::vector<int> send_offsets(size_);
          std::vector<int> recv_offsets(size_);
          c10d::computeLengthsAndOffsets(
              inputSplitSizes, srcdata, &send_lengths, &send_offsets);
          c10d::computeLengthsAndOffsets(
              outputSplitSizes, dstdata, &recv_lengths, &recv_offsets);

          c10::DeviceGuard guard(srcdata.device());
          std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
          MPI_CHECK(MPI_Alltoallv(
              srcdata.data_ptr(),
              send_lengths.data(),
              send_offsets.data(),
              mpiDatatype.at(srcdata.scalar_type()),
              dstdata.data_ptr(),
              recv_lengths.data(),
              recv_offsets.data(),
              mpiDatatype.at(dstdata.scalar_type()),
              pgComm_));
        };
    std::vector<at::Tensor> inputTensors = {inputTensor};
    std::vector<at::Tensor> outputTensors = {outputTensor};
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors, &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:all_to_all",
        std::optional<std::vector<at::Tensor>>(inputTensors));
  }
}
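// alltoall_base usage sketch for the evenly-split branch above (illustrative
// only; `pg` and `size` are assumed, default options are assumed, and empty
// split-size vectors request equal splits along dim 0):
//
//   auto input = at::arange(4 * size, at::kFloat);
//   auto output = at::empty_like(input);
//   std::vector<int64_t> noSplit;
//   pg->alltoall_base(output, input, noSplit, noSplit)->wait();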
c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllToAllOptions& opts) {
  TORCH_CHECK(
      inputTensors.size() == static_cast<size_t>(size_),
      "Number of input tensors are not equal to group size");
  TORCH_CHECK(
      outputTensors.size() == static_cast<size_t>(size_),
      "Number of output tensors are not equal to group size");

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        std::vector<int> send_lengths(size_);
        std::vector<int> recv_lengths(size_);
        std::vector<int> send_offsets(size_);
        std::vector<int> recv_offsets(size_);
        auto srcdata = entry->src;
        auto dstdata = entry->dst;
        auto src_len = c10d::computeLengthsAndOffsets(
            srcdata, &send_lengths, &send_offsets);
        auto dst_len = c10d::computeLengthsAndOffsets(
            dstdata, &recv_lengths, &recv_offsets);
        std::vector<int64_t> send_lengthsL(
            send_lengths.begin(), send_lengths.end());
        std::vector<int64_t> recv_lengthsL(
            recv_lengths.begin(), recv_lengths.end());
        at::Tensor srcFlatData =
            at::empty({static_cast<int64_t>(src_len)}, srcdata[0].options());
        at::Tensor dstFlatData =
            at::empty({static_cast<int64_t>(dst_len)}, dstdata[0].options());
        auto srcFlatDataSplits =
            srcFlatData.split_with_sizes(c10::IntArrayRef(send_lengthsL), 0);
        for (const auto i : c10::irange(size_)) {
          srcFlatDataSplits[i].copy_(srcdata[i].view({-1}));
        }
        c10::DeviceGuard guard1(srcdata[0].device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Alltoallv(
            srcFlatData.data_ptr(),
            send_lengths.data(),
            send_offsets.data(),
            mpiDatatype.at(srcdata[0].scalar_type()),
            dstFlatData.data_ptr(),
            recv_lengths.data(),
            recv_offsets.data(),
            mpiDatatype.at(dstdata[0].scalar_type()),
            pgComm_));

        auto dstFlatDataSplits =
            dstFlatData.split_with_sizes(c10::IntArrayRef(recv_lengthsL), 0);
        for (const auto i : c10::irange(size_)) {
          dstdata[i].view({-1}).copy_(dstFlatDataSplits[i]);
        }
      };
  auto entry = std::make_unique<WorkEntry>(
      &inputTensors, &outputTensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:all_to_all",
      std::optional<std::vector<at::Tensor>>(inputTensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::send(
    std::vector<at::Tensor>& tensors,
    int dstRank,
    int tag) {
  checkSingleTensor(tensors);

  auto& tensor = tensors[0];
  MPI_Request request = MPI_REQUEST_NULL;

  {
    c10::DeviceGuard guard(tensor.device());
    std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
    MPI_CHECK(MPI_Isend(
        tensor.data_ptr(),
        tensor.numel(),
        mpiDatatype.at(tensor.scalar_type()),
        dstRank,
        tag,
        pgComm_,
        &request));
  }

  return c10::make_intrusive<AsyncWork>(
      request,
      std::vector<at::Tensor>(),
      "mpi:send",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::recv(
    std::vector<at::Tensor>& tensors,
    int srcRank,
    int tag) {
  checkSingleTensor(tensors);

  auto& tensor = tensors[0];
  MPI_Request request = MPI_REQUEST_NULL;

  {
    c10::DeviceGuard guard(tensor.device());
    std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
    MPI_CHECK(MPI_Irecv(
        tensor.data_ptr(),
        tensor.numel(),
        mpiDatatype.at(tensor.scalar_type()),
        srcRank,
        tag,
        pgComm_,
        &request));
  }

  return c10::make_intrusive<AsyncWork>(
      request,
      tensors,
      "mpi:recv",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::recvAnysource(
    std::vector<at::Tensor>& tensors,
    int tag) {
  checkSingleTensor(tensors);

  auto& tensor = tensors[0];
  MPI_Request request = MPI_REQUEST_NULL;

  {
    c10::DeviceGuard guard(tensor.device());
    std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
    MPI_CHECK(MPI_Irecv(
        tensor.data_ptr(),
        tensor.numel(),
        mpiDatatype.at(tensor.scalar_type()),
        MPI_ANY_SOURCE,
        tag,
        pgComm_,
        &request));
  }

  return c10::make_intrusive<AsyncWork>(
      request,
      tensors,
      "mpi:recvAnySource",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::barrier(const BarrierOptions& opts) {
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Barrier(pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(nullptr, nullptr, std::move(runFunc));
  return enqueue(std::move(entry), "mpi:barrier", std::nullopt);
}

c10::intrusive_ptr<Work> ProcessGroupMPI::_allgather_base(
    at::Tensor& /*unused */,
    at::Tensor& /*unused */,
    const AllgatherOptions& /*unused */) {
  TORCH_CHECK(false, "no support for _allgather_base in MPI process group");
}

} // namespace c10d

#endif // USE_C10D_MPI