pytorch/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp
Will Feng 4ee514144b [c10d][Partial-Graph Overlap] Support calling .wait_tensor() on output tensor of eager async_op=True collective if under allow_inflight_collective_as_graph_input_ctx() context manager (#137763)
This PR aims to support the following use case:
```python
def all_reduce_eager(x):
    y = x * x
    req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
    assert isinstance(req, torch.distributed.Work)
    return y

@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
    torch.ops.c10d_functional.wait_tensor(y)
    return y * y

x = torch.ones(1280, 1280, device="cuda") + self.rank
with allow_inflight_collective_as_graph_input_ctx():
    y = all_reduce_eager(x)
    z = all_reduce_wait_compiled(y)
```
where the collective is issued in eager mode (with `async_op=True`) but waited on in a compiled region.

This is important for internal use cases such as TorchRec, where we issue the SparseArch all_to_all collectives in eager mode but want to wait on them in the compiled region at the beginning of the OverArch, so that the all_to_all can be overlapped with the DenseArch compute that runs in parallel.
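Conceptually (a simplified sketch only, not the actual c10d C++ implementation), the context manager makes c10d record each async collective's `Work` against its output tensor, so that a later `wait_tensor()` call, whether eager or compiled, can look the work up and wait on it:

```python
import torch
import torch.distributed as dist

# Hypothetical Python stand-ins for the c10d register_work()/unregister_work()
# calls mentioned below; the real bookkeeping happens in C++.
_inflight: dict[torch.Tensor, dist.Work] = {}

def register_work(tensor: torch.Tensor, work: dist.Work) -> None:
    # Called when an async collective is issued under the context manager.
    _inflight[tensor] = work

def wait_tensor(tensor: torch.Tensor) -> torch.Tensor:
    # What torch.ops.c10d_functional.wait_tensor(tensor) conceptually does:
    # find the in-flight work for this tensor, wait on it, then unregister it.
    work = _inflight.pop(tensor, None)
    if work is not None:
        work.wait()
    return tensor
```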

----

**Update**: Two changes were made to prevent regressions in existing use cases:

1. Added a memory-stressed test case (`test_unwaited` in test_c10d_nccl.py) to cover the existing use case of not calling `work.wait()` on a non-functional collective.
2. Gated all new `register_work()` / `unregister_work()` calls behind the `c10d::allow_inflight_collective_as_graph_input()` check, which only returns true inside the new `allow_inflight_collective_as_graph_input_ctx()` context manager. The feature requires explicit user enablement (i.e. it is not on by default), so existing users should be unaffected; see the sketch below.

The risk of this version of the PR causing a regression should be very low.
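To illustrate the opt-in gating (a hedged sketch; imports and process group setup are elided, as in the example above):

```python
x = torch.ones(4, 4, device="cuda")

# Default path: behavior unchanged. Nothing is registered, and the caller
# remains responsible for calling req.wait() (or not, as the memory-stressed
# test_unwaited case exercises).
req = dist.all_reduce(x, op=dist.ReduceOp.SUM, async_op=True)
req.wait()

# Opt-in path: inside the context manager, the in-flight work is registered,
# so wait_tensor(x) can wait on it, even from a compiled region.
with allow_inflight_collective_as_graph_input_ctx():
    req = dist.all_reduce(x, op=dist.ReduceOp.SUM, async_op=True)
    torch.ops.c10d_functional.wait_tensor(x)
```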

------

Test commands:
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_eager_async_allreduce_inductor_wait`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives_no_overload`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_unwaited`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_unwaited`
- `pytest -rA test/distributed/_tensor/test_tensor_ops.py::DistTensorOpsTest::test_equal`
- `pytest -rA test/distributed/_tensor/test_random_ops.py::DistTensorRandomOpTest::test_manual_seed`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_asymmetric_compilation`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_scalar`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_speculation_divergence`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_tensor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_dim_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_graph_break_empty_graph_still_collective`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_scalar_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_type_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/_tensor/test_experimental_ops.py::DistOtherOpsTest::test_bernoulli`
- `pytest -rA test/distributed/_tensor/test_dtensor_compile.py::TestDTensorCompileE2E::test_tp_compile_fullgraph_is_seq_parallel_True`
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_allreduce_inductor_cudagraph_trees`
- `python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --explain --inductor --device cuda --inference --bfloat16 --total-partitions 2 --partition-id 1 --output inference_torchbench.csv --only moco`

------

Differential Revision: [D65023311](https://our.internmc.facebook.com/intern/diff/D65023311)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137763
Approved by: https://github.com/yifuwang
2024-10-29 03:31:19 +00:00

#include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp>

#ifdef USE_C10D_MPI

#include <iostream>
#include <map>

#include <c10/core/DeviceGuard.h>
#include <c10/util/irange.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

#if defined(OPEN_MPI) && OPEN_MPI
#include <mpi-ext.h> // Needed for CUDA-aware check
#endif

namespace c10d {

#define MPI_CHECK(cmd)                                                   \
  do {                                                                   \
    int mpiStatus = cmd;                                                 \
    if (mpiStatus != MPI_SUCCESS) {                                      \
      std::string err = "MPI error in: " + std::string(__FILE__) + ":" + \
          std::to_string(__LINE__) +                                     \
          ", with error code: " + std::to_string(mpiStatus);             \
      TORCH_CHECK(false, err);                                           \
    }                                                                    \
  } while (0)

namespace {

// Op mapping
std::map<ReduceOp::RedOpType, MPI_Op> mpiOp = {
    {ReduceOp::MIN, MPI_MIN},
    {ReduceOp::MAX, MPI_MAX},
    {ReduceOp::SUM, MPI_SUM},
    {ReduceOp::PRODUCT, MPI_PROD},
};

// Type mapping
std::map<at::ScalarType, MPI_Datatype> mpiDatatype = {
    {at::kByte, MPI_UNSIGNED_CHAR},
    {at::kChar, MPI_CHAR},
    {at::kDouble, MPI_DOUBLE},
    {at::kFloat, MPI_FLOAT},
    {at::kInt, MPI_INT},
    {at::kLong, MPI_LONG},
    {at::kShort, MPI_SHORT},
};

// Checking CUDA-aware MPI support, currently we only support CUDA aware
// MPI ops through Open MPI
bool cudaAwareMpiCheck() {
// Run time check
#if defined(MPIX_CUDA_AWARE_SUPPORT)
  if (MPIX_Query_cuda_support() == 1) {
    return true;
  } else {
    return false;
  }
#else // !defined(MPIX_CUDA_AWARE_SUPPORT)
  return false;
#endif // MPIX_CUDA_AWARE_SUPPORT
}

// Checking the input tensor's validity
void checkSingleTensorHelper(const at::Tensor& tensor) {
  if (!tensor.is_contiguous()) {
    TORCH_CHECK(false, "input tensor has to be contiguous");
  }
  if (tensor.is_sparse()) {
    TORCH_CHECK(false, "input tensor has to be dense");
  }
  if (tensor.is_cuda() && !cudaAwareMpiCheck()) {
    TORCH_CHECK(
        false,
        "CUDA tensor detected and the MPI used doesn't "
        "have CUDA-aware MPI support");
  }
}

void checkSingleTensor(const std::vector<at::Tensor>& tensors) {
  if (tensors.size() != 1) {
    TORCH_CHECK(
        false, "MPI process group does not support multi-GPU collectives");
  }
  checkSingleTensorHelper(tensors[0]);
}

void checkSameSizeAndType(
    const at::Tensor& t_in,
    const std::vector<at::Tensor>& tensors) {
  for (const auto& tensor : tensors) {
    if ((tensor.numel() != t_in.numel()) ||
        (tensor.scalar_type() != t_in.scalar_type())) {
      TORCH_CHECK(false, "Tensors are not equal in size or data type");
    }
    checkSingleTensorHelper(tensor);
  }
}

} // namespace
std::vector<at::Tensor> ProcessGroupMPI::WorkMPI::result() {
  return outputTensors_;
}

c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupMPI::WorkMPI::getFuture() {
  return future_;
}

void ProcessGroupMPI::WorkMPI::finishWorkMPIError(
    const std::exception_ptr& eptr) {
  future_->setError(eptr);
  finish(eptr);
}

void ProcessGroupMPI::WorkMPI::finishWorkMPI() {
  future_->markCompleted(at::IValue(outputTensors_));
  finish();
}

ProcessGroupMPI::AsyncWork::AsyncWork(
    MPI_Request request,
    std::vector<at::Tensor> outputTensors,
    const char* profilingTitle,
    const std::optional<std::vector<at::Tensor>>& inputTensors)
    : Work(-1, OpType::UNKNOWN, profilingTitle, inputTensors),
      outputTensors_(std::move(outputTensors)),
      request_(request) {
  memset(&status_, 0, sizeof(status_));
}

ProcessGroupMPI::AsyncWork::~AsyncWork() {
  if (request_ != MPI_REQUEST_NULL) {
    std::cerr
        << "Attempted destruction of AsyncWork before work has completed, "
        << "terminating the program." << '\n';
    std::terminate();
  }
}

bool ProcessGroupMPI::AsyncWork::isCompleted() {
  if (request_ == MPI_REQUEST_NULL) {
    return true;
  }
  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  int flag = 0;
  MPI_CHECK(MPI_Test(&request_, &flag, &status_));
  if (request_ != MPI_REQUEST_NULL) {
    return false;
  }
  // request_ == MPI_REQUEST_NULL; the work has completed
  // Populate exception if request was not successful
  if (status_.MPI_ERROR != MPI_SUCCESS) {
    populateException();
  }
  return true;
}

bool ProcessGroupMPI::AsyncWork::isSuccess() const {
  if (request_ != MPI_REQUEST_NULL) {
    TORCH_CHECK(
        false,
        "Invalid call to AsyncWork::isSuccess before work has completed");
  }
  return status_.MPI_ERROR == MPI_SUCCESS;
}

int ProcessGroupMPI::AsyncWork::sourceRank() const {
  return status_.MPI_SOURCE;
}

bool ProcessGroupMPI::AsyncWork::wait(std::chrono::milliseconds /* unused */) {
  if (request_ == MPI_REQUEST_NULL) {
    // AsyncWork needs to manually call profiling end callbacks if they are set,
    // since it does not call ProcessGroup::finish().
    if (Work::recordFunctionEndCallback_) {
      Work::recordFunctionEndCallback_();
      Work::recordFunctionEndCallback_ = nullptr;
    }
    return true;
  }
  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  MPI_CHECK(MPI_Wait(&request_, &status_));
  auto ok = (status_.MPI_ERROR == MPI_SUCCESS);
  // AsyncWork needs to manually call profiling end callbacks if they are set,
  // since it does not call ProcessGroup::finish().
  if (Work::recordFunctionEndCallback_) {
    Work::recordFunctionEndCallback_();
    Work::recordFunctionEndCallback_ = nullptr;
  }
  if (!ok) {
    populateException();
    std::rethrow_exception(exception_);
  }
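  // When allow_inflight_collective_as_graph_input() is enabled (i.e. inside
  // the allow_inflight_collective_as_graph_input_ctx() context manager), every
  // in-flight work is registered against its output tensor so that
  // wait_tensor() can find it. Once this work has been waited on, drop that
  // registration again.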
  if (c10d::allow_inflight_collective_as_graph_input()) {
    c10d::unregister_work(
        c10::intrusive_ptr<
            ProcessGroupMPI::AsyncWork>::unsafe_reclaim_from_nonowning(this));
  }
  // Always return true, because abort API is not implemented.
  return true;
}

void ProcessGroupMPI::AsyncWork::abort() {
  TORCH_CHECK(false, "ProcessGroupMPI::AsyncWork::abort not implemented.");
}

std::vector<at::Tensor> ProcessGroupMPI::AsyncWork::result() {
  return outputTensors_;
}

void ProcessGroupMPI::AsyncWork::populateException() {
  std::array<char, MPI_MAX_ERROR_STRING> buf{};
  int len = buf.size();
  MPI_CHECK(MPI_Error_string(status_.MPI_ERROR, buf.data(), &len));
  exception_ = std::make_exception_ptr(
      std::runtime_error(std::string(buf.data(), len)));
}

// Static global states
int ProcessGroupMPI::mpiThreadSupport_ = 0;
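// All MPI calls are serialized through pgGlobalMutex_: MPI is initialized
// with MPI_THREAD_SERIALIZED (see initMPIOnce), so calls from the worker
// thread and from AsyncWork::isCompleted()/wait() must not run concurrently.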
std::mutex ProcessGroupMPI::pgGlobalMutex_;
// We only want to initialize once
c10::once_flag ProcessGroupMPI::onceFlagInitMPI;

void ProcessGroupMPI::mpiExit() {
  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  MPI_CHECK(MPI_Finalize());
}

void ProcessGroupMPI::initMPIOnce() {
  // Initialize MPI environment
  c10::call_once(onceFlagInitMPI, []() {
    int mpi_was_initialized = 0;
    MPI_CHECK(MPI_Initialized(&mpi_was_initialized));
    if (mpi_was_initialized == 0) {
      MPI_CHECK(MPI_Init_thread(
          nullptr, nullptr, MPI_THREAD_SERIALIZED, &mpiThreadSupport_));
      if (mpiThreadSupport_ < MPI_THREAD_SERIALIZED) {
        TORCH_CHECK(
            false,
            "Used MPI implementation doesn't have the "
            "minimum level of threading support: "
            "MPI_THREAD_SERIALIZED. This is required by "
            "c10d package");
      }
      if (std::atexit(ProcessGroupMPI::mpiExit)) {
        TORCH_CHECK(false, "Fail to register the MPI exit handler");
      }
    } else {
      TORCH_WARN_ONCE("MPI was previously initialized.");
    }
  });
}

c10::intrusive_ptr<ProcessGroupMPI> ProcessGroupMPI::createProcessGroupMPI(
    std::vector<int> ranks) {
  // Once initialization
  initMPIOnce();

  MPI_Comm groupComm = MPI_COMM_WORLD;
  int rank = -1;
  int size = -1;

  {
    std::lock_guard<std::mutex> globalLock(pgGlobalMutex_);

    // If no ranks are specified, assume we're creating the root group
    if (!ranks.empty()) {
      MPI_Group worldGroup{};
      MPI_Group ranksGroup{};
      MPI_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
      MPI_CHECK(
          MPI_Group_incl(worldGroup, ranks.size(), ranks.data(), &ranksGroup));

      // `MPI_Comm_create` can be flaky in certain cases.
      // See: https://github.com/pytorch/pytorch/issues/53899
      constexpr int kMaxNumRetries = 3;
      bool groupComm_updated = false;
      MPI_Barrier(MPI_COMM_WORLD);
      for (const auto i : c10::irange(kMaxNumRetries)) {
        (void)i;
        if (MPI_Comm_create(MPI_COMM_WORLD, ranksGroup, &groupComm)) {
          groupComm_updated = true;
          break;
        }
      }
      MPI_CHECK(groupComm_updated);
      MPI_CHECK(MPI_Group_free(&worldGroup));
      MPI_CHECK(MPI_Group_free(&ranksGroup));
    }

    // Fetch rank and world size for this group (MPI_COMM_WORLD or new)
    if (groupComm != MPI_COMM_NULL) {
      MPI_CHECK(MPI_Comm_rank(groupComm, &rank));
      MPI_CHECK(MPI_Comm_size(groupComm, &size));
      if (rank < 0 || size < 0) {
        TORCH_CHECK(false, "Failed to get the world_size / rank");
      }
    }
  }

  // If this process is not part of the group, we don't construct a
  // process group instance. This is in line with the semantics of the
  // other process group types.
  if (groupComm == MPI_COMM_NULL) {
    return c10::intrusive_ptr<ProcessGroupMPI>();
  }

  return c10::make_intrusive<ProcessGroupMPI>(rank, size, groupComm);
}

ProcessGroupMPI::ProcessGroupMPI(int rank, int size, MPI_Comm pgComm)
    : Backend(rank, size), stop_(false), pgComm_(pgComm) {
  if (pgComm_ == MPI_COMM_NULL) {
    TORCH_CHECK(false, "pgComm_ must not be MPI_COMM_NULL");
  }

  // Start the worker thread accepting MPI calls
  workerThread_ = std::thread(&ProcessGroupMPI::runLoop, this);

  init();
}

ProcessGroupMPI::~ProcessGroupMPI() {
  destroy();
}

void ProcessGroupMPI::destroy() {
  std::unique_lock<std::mutex> lock(pgMutex_);
  queueConsumeCV_.wait(lock, [&] { return queue_.empty(); });

  // Queue is empty, signal stop
  stop_ = true;

  // Release lock to allow threads to terminate
  lock.unlock();
  queueProduceCV_.notify_all();

  // Join the single worker thread
  workerThread_.join();
}

void ProcessGroupMPI::abort() {
  destroy();
  MPI_Abort(pgComm_, EXIT_FAILURE);
}
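// Worker loop: consumes WorkEntry items enqueued by the collective methods
// below, runs each entry's bound MPI call under the global MPI mutex, and
// marks the corresponding WorkMPI (and its Future) as completed or errored.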
void ProcessGroupMPI::runLoop() {
  std::unique_lock<std::mutex> lock(pgMutex_);

  while (!stop_) {
    if (queue_.empty()) {
      queueProduceCV_.wait(lock);
      continue;
    }

    auto workTuple = std::move(queue_.front());
    queue_.pop_front();

    auto& workEntry = std::get<0>(workTuple);
    auto& work = std::get<1>(workTuple);

    lock.unlock();
    queueConsumeCV_.notify_one();

    try {
      workEntry->run(workEntry);
      work->finishWorkMPI();
    } catch (...) {
      work->finishWorkMPIError(std::current_exception());
    }

    lock.lock();
  }
}

c10::intrusive_ptr<Work> ProcessGroupMPI::enqueue(
    std::unique_ptr<WorkEntry> entry,
    const char* profilingTitle,
    const std::optional<std::vector<at::Tensor>>& inputTensors) {
  auto work =
      c10::make_intrusive<WorkMPI>(entry->dst, profilingTitle, inputTensors);
  std::unique_lock<std::mutex> lock(pgMutex_);
  queue_.emplace_back(std::move(entry), work);
  lock.unlock();
  queueProduceCV_.notify_one();
  return work;
}

c10::intrusive_ptr<Work> ProcessGroupMPI::broadcast(
    std::vector<at::Tensor>& tensors,
    const BroadcastOptions& opts) {
  checkSingleTensor(tensors);
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Bcast(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            opts.rootRank,
            pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:broadcast",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::allreduce(
    std::vector<at::Tensor>& tensors,
    const AllreduceOptions& opts) {
  checkSingleTensor(tensors);
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Allreduce(
            MPI_IN_PLACE,
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            mpiOp.at(opts.reduceOp),
            pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:all_reduce",
      std::optional<std::vector<at::Tensor>>(tensors));
}
c10::intrusive_ptr<Work> ProcessGroupMPI::allreduce_coalesced(
    std::vector<at::Tensor>& tensors,
    const AllreduceCoalescedOptions& opts) {
  TORCH_CHECK(false, "allreduce_coalesced is currently not supported with MPI");
}

c10::intrusive_ptr<Work> ProcessGroupMPI::reduce(
    std::vector<at::Tensor>& tensors,
    const ReduceOptions& opts) {
  checkSingleTensor(tensors);
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        auto dataPtr = (entry->src)[0].data_ptr();
        void* sendbuf = (rank_ == opts.rootRank) ? MPI_IN_PLACE : dataPtr;
        void* recvbuf = (rank_ == opts.rootRank) ? dataPtr : nullptr;
        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Reduce(
            sendbuf,
            recvbuf,
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            mpiOp.at(opts.reduceOp),
            opts.rootRank,
            pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:reduce",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& opts) {
  checkSingleTensor(inputTensors);
  if (outputTensors.size() != 1) {
    TORCH_CHECK(
        false,
        "MPI process group only supports a single "
        "tensor op");
  }
  if (static_cast<size_t>(size_) != outputTensors[0].size()) {
    TORCH_CHECK(
        false,
        "All gather: number of output tensors should equal "
        "to the world size");
  }
  checkSameSizeAndType(inputTensors[0], outputTensors[0]);

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        std::vector<at::Tensor> outputDataVec = entry->dst;
        auto flatOutputTensor = newLikeFlat(outputDataVec);

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Allgather(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            flatOutputTensor.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            pgComm_));

        for (const auto i : c10::irange(outputDataVec.size())) {
          outputDataVec[i].copy_(flatOutputTensor[static_cast<int64_t>(i)]);
        }
      };
  auto entry = std::make_unique<WorkEntry>(
      &inputTensors, &outputTensors[0], std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:all_gather",
      std::optional<std::vector<at::Tensor>>(inputTensors));
}
c10::intrusive_ptr<Work> ProcessGroupMPI::allgather_coalesced(
    std::vector<std::vector<at::Tensor>>& /* unused */,
    std::vector<at::Tensor>& /* unused */,
    const AllgatherOptions& /* unused */) {
  TORCH_CHECK(false, "ProcessGroupMPI does not support allgather_coalesced");
}

c10::intrusive_ptr<Work> ProcessGroupMPI::gather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const GatherOptions& opts) {
  checkSingleTensor(inputTensors);

  if (rank_ != opts.rootRank) {
    if (!outputTensors.empty()) {
      TORCH_CHECK(
          false,
          "Gather: number of output tensors should be 0 "
          "for non-root");
    }
  } else {
    if (outputTensors.size() != 1) {
      TORCH_CHECK(false, "Gather: multi-GPU collective is not supported");
    }
    if (static_cast<size_t>(size_) != outputTensors[0].size()) {
      TORCH_CHECK(
          false,
          "Gather: number of output tensors should equal "
          "to the world size");
    }
    checkSameSizeAndType(inputTensors[0], outputTensors[0]);
  }

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        void* recvbuf = nullptr;
        at::Tensor flatOutputTensor;

        std::vector<at::Tensor> dstdata = entry->dst;
        if (rank_ == opts.rootRank) {
          flatOutputTensor = newLikeFlat(dstdata);
          recvbuf = flatOutputTensor.data_ptr();
        }

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Gather(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            recvbuf,
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            opts.rootRank,
            pgComm_));

        if (rank_ == opts.rootRank) {
          const std::vector<at::Tensor>& outputDataVec = entry->dst;
          // copy the flattened output tensors to the outputs
          for (const auto i : c10::irange(outputDataVec.size())) {
            outputDataVec.at(i).copy_(
                flatOutputTensor[static_cast<int64_t>(i)]);
          }
        }
      };

  if (rank_ == opts.rootRank) {
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors, &outputTensors[0], std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:gather",
        std::optional<std::vector<at::Tensor>>(inputTensors));
  } else {
    auto entry =
        std::make_unique<WorkEntry>(&inputTensors, nullptr, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:gather",
        std::optional<std::vector<at::Tensor>>(inputTensors));
  }
}
c10::intrusive_ptr<Work> ProcessGroupMPI::scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ScatterOptions& opts) {
  checkSingleTensor(outputTensors);

  if (rank_ != opts.rootRank) {
    if (!inputTensors.empty()) {
      TORCH_CHECK(
          false,
          "Scatter: number of input tensors should be 0 "
          "for non-root");
    }
  } else {
    if (inputTensors.size() != 1) {
      TORCH_CHECK(false, "Scatter: multi-GPU collective is not supported");
    }
    if (static_cast<size_t>(size_) != inputTensors[0].size()) {
      TORCH_CHECK(
          false,
          "Scatter: number of input tensors should equal "
          "to the world size");
    }
    checkSameSizeAndType(outputTensors[0], inputTensors[0]);
  }

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->dst)[0];
        void* sendbuf = nullptr;
        at::Tensor flatInputTensor;

        if (rank_ == opts.rootRank) {
          std::vector<at::Tensor>& inputDataVec = entry->src;
          flatInputTensor = newLikeFlat(inputDataVec);
          sendbuf = flatInputTensor.data_ptr();

          // copy the input tensors to the flatten large send buffer
          for (const auto i : c10::irange(inputDataVec.size())) {
            flatInputTensor[static_cast<int64_t>(i)].copy_(inputDataVec.at(i));
          }
        }

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Scatter(
            sendbuf,
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            opts.rootRank,
            pgComm_));
      };

  if (rank_ == opts.rootRank) {
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors[0], &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:scatter",
        !inputTensors.empty()
            ? std::optional<std::vector<at::Tensor>>(inputTensors[0])
            : std::nullopt);
  } else {
    auto entry = std::make_unique<WorkEntry>(
        nullptr, &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:scatter",
        !inputTensors.empty()
            ? std::optional<std::vector<at::Tensor>>(inputTensors[0])
            : std::nullopt);
  }
}

c10::intrusive_ptr<Work> ProcessGroupMPI::reduce_scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ReduceScatterOptions& opts) {
  TORCH_CHECK(false, "ProcessGroupMPI does not support reduce_scatter");
}
c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    std::vector<int64_t>& outputSplitSizes,
    std::vector<int64_t>& inputSplitSizes,
    const AllToAllOptions& opts) {
  checkSingleTensorHelper(inputTensor);
  checkSingleTensorHelper(outputTensor);

  if (outputSplitSizes.empty() && inputSplitSizes.empty()) {
    // We can use alltoall
    TORCH_CHECK(
        outputTensor.numel() == inputTensor.numel() &&
            outputTensor.type() == inputTensor.type(),
        "Tensors are not equal in size or data type");
    TORCH_CHECK(
        outputTensor.size(0) % size_ == 0,
        "Tensor's dim 0 does not divide equally across group size");

    std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
        [this](std::unique_ptr<WorkEntry>& entry) {
          auto srcdata = (entry->src)[0];
          auto dstdata = (entry->dst)[0];
          c10::DeviceGuard guard(srcdata.device());
          std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
          MPI_CHECK(MPI_Alltoall(
              srcdata.data_ptr(),
              srcdata.numel() / size_,
              mpiDatatype.at(srcdata.scalar_type()),
              dstdata.data_ptr(),
              dstdata.numel() / size_,
              mpiDatatype.at(dstdata.scalar_type()),
              pgComm_));
        };
    std::vector<at::Tensor> inputTensors = {inputTensor};
    std::vector<at::Tensor> outputTensors = {outputTensor};
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors, &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:all_to_all",
        std::optional<std::vector<at::Tensor>>(inputTensors));
  } else {
    // Need alltoallv
    c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
    c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_);

    std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
        [this, inputSplitSizes, outputSplitSizes](
            std::unique_ptr<WorkEntry>& entry) {
          auto srcdata = (entry->src)[0];
          auto dstdata = (entry->dst)[0];
          std::vector<int> send_lengths(size_);
          std::vector<int> recv_lengths(size_);
          std::vector<int> send_offsets(size_);
          std::vector<int> recv_offsets(size_);
          c10d::computeLengthsAndOffsets(
              inputSplitSizes, srcdata, &send_lengths, &send_offsets);
          c10d::computeLengthsAndOffsets(
              outputSplitSizes, dstdata, &recv_lengths, &recv_offsets);

          c10::DeviceGuard guard(srcdata.device());
          std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
          MPI_CHECK(MPI_Alltoallv(
              srcdata.data_ptr(),
              send_lengths.data(),
              send_offsets.data(),
              mpiDatatype.at(srcdata.scalar_type()),
              dstdata.data_ptr(),
              recv_lengths.data(),
              recv_offsets.data(),
              mpiDatatype.at(dstdata.scalar_type()),
              pgComm_));
        };
    std::vector<at::Tensor> inputTensors = {inputTensor};
    std::vector<at::Tensor> outputTensors = {outputTensor};
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors, &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:all_to_all",
        std::optional<std::vector<at::Tensor>>(inputTensors));
  }
}
c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllToAllOptions& opts) {
  TORCH_CHECK(
      inputTensors.size() == static_cast<size_t>(size_),
      "Number of input tensors are not equal to group size");
  TORCH_CHECK(
      outputTensors.size() == static_cast<size_t>(size_),
      "Number of output tensors are not equal to group size");

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        std::vector<int> send_lengths(size_);
        std::vector<int> recv_lengths(size_);
        std::vector<int> send_offsets(size_);
        std::vector<int> recv_offsets(size_);
        auto srcdata = entry->src;
        auto dstdata = entry->dst;
        auto src_len = c10d::computeLengthsAndOffsets(
            srcdata, &send_lengths, &send_offsets);
        auto dst_len = c10d::computeLengthsAndOffsets(
            dstdata, &recv_lengths, &recv_offsets);
        std::vector<int64_t> send_lengthsL(
            send_lengths.begin(), send_lengths.end());
        std::vector<int64_t> recv_lengthsL(
            recv_lengths.begin(), recv_lengths.end());
        at::Tensor srcFlatData =
            at::empty({static_cast<int64_t>(src_len)}, srcdata[0].options());
        at::Tensor dstFlatData =
            at::empty({static_cast<int64_t>(dst_len)}, dstdata[0].options());
        auto srcFlatDataSplits =
            srcFlatData.split_with_sizes(c10::IntArrayRef(send_lengthsL), 0);
        for (const auto i : c10::irange(size_)) {
          srcFlatDataSplits[i].copy_(srcdata[i].view({-1}));
        }
        c10::DeviceGuard guard1(srcdata[0].device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Alltoallv(
            srcFlatData.data_ptr(),
            send_lengths.data(),
            send_offsets.data(),
            mpiDatatype.at(srcdata[0].scalar_type()),
            dstFlatData.data_ptr(),
            recv_lengths.data(),
            recv_offsets.data(),
            mpiDatatype.at(dstdata[0].scalar_type()),
            pgComm_));

        auto dstFlatDataSplits =
            dstFlatData.split_with_sizes(c10::IntArrayRef(recv_lengthsL), 0);
        for (const auto i : c10::irange(size_)) {
          dstdata[i].view({-1}).copy_(dstFlatDataSplits[i]);
        }
      };
  auto entry = std::make_unique<WorkEntry>(
      &inputTensors, &outputTensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:all_to_all",
      std::optional<std::vector<at::Tensor>>(inputTensors));
}
c10::intrusive_ptr<Work> ProcessGroupMPI::send(
    std::vector<at::Tensor>& tensors,
    int dstRank,
    int tag) {
  checkSingleTensor(tensors);

  auto& tensor = tensors[0];
  MPI_Request request = MPI_REQUEST_NULL;

  {
    c10::DeviceGuard guard(tensor.device());
    std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
    MPI_CHECK(MPI_Isend(
        tensor.data_ptr(),
        tensor.numel(),
        mpiDatatype.at(tensor.scalar_type()),
        dstRank,
        tag,
        pgComm_,
        &request));
  }

  return c10::make_intrusive<AsyncWork>(
      request,
      std::vector<at::Tensor>(),
      "mpi:send",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::recv(
    std::vector<at::Tensor>& tensors,
    int srcRank,
    int tag) {
  checkSingleTensor(tensors);

  auto& tensor = tensors[0];
  MPI_Request request = MPI_REQUEST_NULL;

  {
    c10::DeviceGuard guard(tensor.device());
    std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
    MPI_CHECK(MPI_Irecv(
        tensor.data_ptr(),
        tensor.numel(),
        mpiDatatype.at(tensor.scalar_type()),
        srcRank,
        tag,
        pgComm_,
        &request));
  }

  return c10::make_intrusive<AsyncWork>(
      request,
      tensors,
      "mpi:recv",
      std::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::recvAnysource(
    std::vector<at::Tensor>& tensors,
    int tag) {
  checkSingleTensor(tensors);

  auto& tensor = tensors[0];
  MPI_Request request = MPI_REQUEST_NULL;

  {
    c10::DeviceGuard guard(tensor.device());
    std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
    MPI_CHECK(MPI_Irecv(
        tensor.data_ptr(),
        tensor.numel(),
        mpiDatatype.at(tensor.scalar_type()),
        MPI_ANY_SOURCE,
        tag,
        pgComm_,
        &request));
  }

  return c10::make_intrusive<AsyncWork>(
      request,
      tensors,
      "mpi:recvAnySource",
      std::optional<std::vector<at::Tensor>>(tensors));
}
c10::intrusive_ptr<Work> ProcessGroupMPI::barrier(const BarrierOptions& opts) {
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Barrier(pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(nullptr, nullptr, std::move(runFunc));
  return enqueue(std::move(entry), "mpi:barrier", std::nullopt);
}

c10::intrusive_ptr<Work> ProcessGroupMPI::_allgather_base(
    at::Tensor& /*unused */,
    at::Tensor& /*unused */,
    const AllgatherOptions& /*unused */) {
  TORCH_CHECK(false, "no support for _allgather_base in MPI process group");
}

} // namespace c10d

#endif // USE_C10D_MPI