Files
pytorch/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp
Will Feng 4ee514144b [c10d][Partial-Graph Overlap] Support calling .wait_tensor() on output tensor of eager async_op=True collective if under allow_inflight_collective_as_graph_input_ctx() context manager (#137763)
This PR aims to support the following use case:
```python
def all_reduce_eager(x):
    y = x * x
    req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
    assert isinstance(req, torch.distributed.Work)
    return y

@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
    torch.ops.c10d_functional.wait_tensor(y)
    return y * y

x = torch.ones(1280, 1280, device="cuda") + self.rank
with allow_inflight_collective_as_graph_input_ctx():
    y = all_reduce_eager(x)
    z = all_reduce_wait_compiled(y)
```
where the collective is issued in eager mode (with `async_op=True`) but waited on in the compiled region.

This is important for internal use cases such as TorchRec, where we issue the SparseArch all_to_all collectives in eager mode but want to wait on them in the compiled region at the beginning of OverArch, so that the all_to_all can be overlapped with the DenseArch compute that runs in parallel.

----

**Update**: Made two changes to prevent regressions in existing use cases:

1. Added a memory-stressed test case to `test_unwaited` in test_c10d_nccl.py to cover existing users' "not calling work.wait() for a non-functional collective" use case.
2. Gated all new `register_work()` / `unregister_work()` calls behind the `c10d::allow_inflight_collective_as_graph_input()` check, which is controlled by a new context manager that requires explicit user enablement (i.e. it is off by default, so existing users are unaffected); see the sketch after this list.
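
As a rough illustration of item 2, a backend's `wait()` only touches the in-flight-work registry when the flag is set, mirroring the pattern in `WorkUCC::wait()` in the file below. This is a minimal sketch; `SomeBackendWork` is a hypothetical `Work` subclass used only for illustration:

```cpp
// Sketch only: how the new register/unregister calls are guarded by the
// opt-in flag controlled by allow_inflight_collective_as_graph_input_ctx().
bool SomeBackendWork::wait(std::chrono::milliseconds /* timeout */) {
  // ... existing completion / synchronization logic ...
  if (c10d::allow_inflight_collective_as_graph_input()) {
    // Only when the user opted in do we remove this work from the
    // in-flight-collective registry that wait_tensor() consults.
    c10d::unregister_work(
        c10::intrusive_ptr<SomeBackendWork>::unsafe_reclaim_from_nonowning(
            this));
  }
  return true;
}
```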

The risk of this version of the PR causing regressions should be very low.

------

Test commands:
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_eager_async_allreduce_inductor_wait`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives_no_overload`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_unwaited`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_unwaited`
- `pytest -rA test/distributed/_tensor/test_tensor_ops.py::DistTensorOpsTest::test_equal`
- `pytest -rA test/distributed/_tensor/test_random_ops.py::DistTensorRandomOpTest::test_manual_seed`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_asymmetric_compilation`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_scalar`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_speculation_divergence`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_tensor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_dim_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_graph_break_empty_graph_still_collective`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_scalar_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_type_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/_tensor/test_experimental_ops.py::DistOtherOpsTest::test_bernoulli`
- `pytest -rA test/distributed/_tensor/test_dtensor_compile.py::TestDTensorCompileE2E::test_tp_compile_fullgraph_is_seq_parallel_True`
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_allreduce_inductor_cudagraph_trees`
- `python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --explain --inductor --device cuda --inference --bfloat16 --total-partitions 2 --partition-id 1 --output inference_torchbench.csv --only moco`

------

Differential Revision: [D65023311](https://our.internmc.facebook.com/intern/diff/D65023311)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137763
Approved by: https://github.com/yifuwang
2024-10-29 03:31:19 +00:00


#ifdef USE_C10D_UCC
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <c10/util/env.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp>
#include <torch/csrc/distributed/c10d/UCCTracing.hpp>
#include <torch/csrc/distributed/c10d/UCCUtils.hpp>
#include <list>
#include <memory>
#include <unordered_map>
#include <unordered_set>
namespace c10d {
namespace {
const std::map<c10::DeviceType, ucc_memory_type_t> ucc_mtype_map = {
{c10::kCPU, UCC_MEMORY_TYPE_HOST},
{c10::kCUDA, UCC_MEMORY_TYPE_CUDA},
};
ucc_memory_type_t to_ucc_memType(c10::DeviceType _c10_type) {
if (ucc_mtype_map.find(_c10_type) != ucc_mtype_map.end())
return ucc_mtype_map.at(_c10_type);
else
return UCC_MEMORY_TYPE_UNKNOWN;
}
const std::map<at::ScalarType, ucc_datatype_t> ucc_dtype_map = {
{at::kByte, UCC_DT_UINT8},
{at::kChar, UCC_DT_INT8},
{at::kHalf, UCC_DT_FLOAT16},
{at::kBFloat16, UCC_DT_BFLOAT16},
{at::kDouble, UCC_DT_FLOAT64},
{at::kFloat, UCC_DT_FLOAT32},
{at::kInt, UCC_DT_INT32},
{at::kLong, UCC_DT_INT64},
{at::kBool, UCC_DT_UINT8},
};
ucc_datatype_t to_ucc_dType(at::Tensor _tensor) {
if (_tensor.scalar_type() == at::kBool && _tensor.element_size() != 1) {
TORCH_CHECK(
false, "Size of Boolean type larger than 1 is not supported in UCC");
}
try {
return ucc_dtype_map.at(_tensor.scalar_type());
} catch (const std::out_of_range&) {
TORCH_CHECK(false, "Not supported data type for UCC");
}
}
const std::map<ReduceOp, ucc_reduction_op_t> ucc_op_map = {
{ReduceOp::SUM, UCC_OP_SUM},
{ReduceOp::PRODUCT, UCC_OP_PROD},
{ReduceOp::MIN, UCC_OP_MIN},
{ReduceOp::MAX, UCC_OP_MAX},
{ReduceOp::BAND, UCC_OP_BAND},
{ReduceOp::BOR, UCC_OP_BOR},
{ReduceOp::BXOR, UCC_OP_BXOR},
{ReduceOp::AVG, UCC_OP_AVG},
};
ucc_reduction_op_t to_ucc_reduceOp(
const ReduceOp _op,
const at::ScalarType _dt) {
if (_dt == at::kBool) {
if (_op == ReduceOp::SUM) {
// bitwise or
return UCC_OP_MAX;
} else if (_op == ReduceOp::PRODUCT) {
// bitwise and
return UCC_OP_MIN;
} else if (_op == ReduceOp::AVG) {
TORCH_CHECK(false, "Cannot use ReduceOp.AVG with boolean inputs");
}
}
try {
return ucc_op_map.at(_op);
} catch (const std::out_of_range&) {
TORCH_CHECK(false, "Not supported ReduceOp for UCC");
}
}
struct torch_ucc_config_t {
c10::once_flag flag;
std::array<bool, 32> blocking_wait;
bool enable_comms_logger;
bool use_future;
// Sharing UCC communicator among multiple PGs to save resources.
bool shared_comm;
// Using allgatherv to achieve allgather, without flattening the list of
// (potentially non-contiguous) tensors.
bool use_allgatherv;
bool enable_health_check;
} torch_ucc_config;
std::unordered_map<std::string, std::string> torch_ucc_envs_map = {
// TORCH_UCC_BLOCKING_WAIT allowed syntax:
// - TORCH_UCC_BLOCKING_WAIT=none --> blocking wait completely disabled
// - TORCH_UCC_BLOCKING_WAIT=all --> blocking wait completely enabled
// - TORCH_UCC_BLOCKING_WAIT=allreduce,send,recv --> blocking wait enabled
// on selected operations
// Supported operations:
// [allgather,allgather_base,allreduce,alltoall,broadcast,
// gather,reduce,reduce_scatter,scatter,send,recv]
{"TORCH_UCC_BLOCKING_WAIT", "none"},
{"TORCH_UCC_USE_FUTURE", "1"},
{"TORCH_UCC_PROFILING_ENABLE", "0"},
{"TORCH_UCC_SHARED_COMM", "1"},
{"TORCH_UCC_USE_ALLGATHERV", "0"},
{"TORCH_UCC_ENABLE_HEALTH_CHECK", "0"},
{"TORCH_UCC_ENABLE_COMMS_LOGGER", "0"},
};
std::vector<OpType> parse_blocking_wait(std::string op_list_string) {
const static std::unordered_map<std::string, OpType> str2op = {
{"allgather", OpType::ALLGATHER},
{"allgather_base", OpType::_ALLGATHER_BASE},
{"allreduce", OpType::ALLREDUCE},
{"alltoall_base", OpType::ALLTOALL_BASE},
{"broadcast", OpType::BROADCAST},
{"gather", OpType::GATHER},
{"reduce", OpType::REDUCE},
{"reduce_scatter", OpType::REDUCE_SCATTER},
{"scatter", OpType::SCATTER},
{"send", OpType::SEND},
{"recv", OpType::RECV},
};
auto op_list = parse_list(op_list_string);
if (op_list == std::vector<std::string>{"none"}) {
return {};
}
std::vector<OpType> result;
if (op_list == std::vector<std::string>{"all"}) {
for (auto entry : str2op) {
result.push_back(entry.second);
}
} else {
for (auto op_string : op_list) {
result.push_back(str2op.at(op_string));
}
}
return result;
}
} // namespace
void read_config() {
// default configuration
torch_ucc_config.blocking_wait.fill(false);
torch_ucc_config.use_future = true;
torch_ucc_config.shared_comm = false;
torch_ucc_config.use_allgatherv = false;
torch_ucc_config.enable_health_check = false;
torch_ucc_config.enable_comms_logger = false;
// read all torch_ucc env. variables and update the map
for (auto& [env_name, value] : torch_ucc_envs_map) {
auto env = c10::utils::get_env(env_name.c_str());
if (env.has_value()) {
value = std::move(env.value());
}
}
auto blocking_wait_str = torch_ucc_envs_map.at("TORCH_UCC_BLOCKING_WAIT");
for (auto op : parse_blocking_wait(blocking_wait_str)) {
torch_ucc_config.blocking_wait[(std::uint8_t)op] = true;
}
// barrier is always blocking
torch_ucc_config.blocking_wait[(std::uint8_t)OpType::BARRIER] = true;
torch_ucc_config.use_future =
std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_FUTURE"));
torch_ucc_config.shared_comm =
std::stoi(torch_ucc_envs_map.at("TORCH_UCC_SHARED_COMM"));
torch_ucc_config.use_allgatherv =
std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_ALLGATHERV"));
torch_ucc_config.enable_health_check =
std::stoi(torch_ucc_envs_map.at("TORCH_UCC_ENABLE_HEALTH_CHECK"));
torch_ucc_config.enable_comms_logger =
std::stoi(torch_ucc_envs_map.at("TORCH_UCC_ENABLE_COMMS_LOGGER"));
}
void check_device(c10::Device dev1, c10::Device dev2) {
if (dev1.is_cuda() && dev2.is_cuda() && dev1 != dev2) {
throw std::invalid_argument("ProcessGroupUCC multidevice is not supported");
}
}
void check_tensor(const std::vector<at::Tensor>& tensors) {
if (tensors.size() != 1) {
throw std::invalid_argument(
"ProcessGroupUCC takes 1 tensor. Got " +
std::to_string(tensors.size()) + ". ");
}
if (!tensors[0].is_contiguous()) {
throw std::invalid_argument(
"ProcessGroupUCC input tensor has to be contiguous");
}
if (tensors[0].is_sparse()) {
throw std::invalid_argument("ProcessGroupUCC input tensor has to be dense");
}
// TODO: check cuda case
}
ProcessGroupUCC::WorkUCC::~WorkUCC() {
#ifdef USE_CUDA
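// Return the pooled CUDA event to the owning process group's event pool so it
// can be reused by later collectives.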
if (fence && ep) {
std::lock_guard<std::mutex> lock(ep->event_pool_mutex);
ep->event_pool.push(std::move(fence));
}
#endif
}
void ProcessGroupUCC::WorkUCC::setException() {
if (exception() || !entry_) {
return;
}
exception_ = entry_->eptr_;
}
void ProcessGroupUCC::WorkUCC::setAndThrowException() {
setException();
if (exception()) {
std::rethrow_exception(exception());
}
}
bool ProcessGroupUCC::WorkUCC::isCompleted() {
if (!entry_) {
return true;
}
setException();
// status_ <= 0 to avoid listing all possible status codes. The main thread
// needs to be unblocked when UCC (in progress thread) returns success (== 0)
// or any error code (< 0).
return exception() || entry_->status_ <= 0;
}
bool ProcessGroupUCC::WorkUCC::isSuccess() const {
if (!entry_) {
return true;
}
return !exception() && entry_->status_ == 0;
}
bool ProcessGroupUCC::WorkUCC::wait(std::chrono::milliseconds /* unused */) {
if (torch_ucc_config.enable_comms_logger && logger_) {
logger_->trace_generator->recordComms("wait", (uintptr_t)this, rank_);
}
#ifdef USE_CUDA
if (fence && !torch_ucc_config.blocking_wait[(int)opType_]) {
// block user stream
setAndThrowException();
fence->block(at::cuda::getCurrentCUDAStream());
return true;
}
#endif
// wait for completion. In the blocking case, the main thread will be blocked in
// this loop until the progress thread changes the status of this request.
// If timeout occurs, UCC will return UCC_ERR_TIMEOUT as the status. The
// main thread will throw out the exception then. There is no "abort"
// function in UCC currently.
while (!isCompleted())
;
setAndThrowException();
// manually call profiling end callbacks if they are set,
// since progress thread does not own WorkUCC
if (Work::recordFunctionEndCallback_) {
Work::recordFunctionEndCallback_();
Work::recordFunctionEndCallback_ = nullptr;
}
if (c10d::allow_inflight_collective_as_graph_input()) {
c10d::unregister_work(
c10::intrusive_ptr<
ProcessGroupUCC::WorkUCC>::unsafe_reclaim_from_nonowning(this));
}
return true;
}
c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupUCC::WorkUCC::getFuture() {
return future_;
}
int ProcessGroupUCC::WorkUCC::sourceRank() const {
if (opType_ != OpType::RECV && opType_ != OpType::RECVANYSOURCE) {
// Throw an error
return Work::sourceRank();
}
return sourceRank_;
}
std::vector<at::Tensor> ProcessGroupUCC::WorkUCC::result() {
return *outputs_;
}
void ProcessGroupUCC::ProgressEntry::finalize(std::exception_ptr eptr) {
ucc_status_t status = UCC_OK;
if (request_ != nullptr) {
status = request_->status;
comm_->free_request(request_);
}
if (eptr) {
eptr_ = eptr;
} else {
status_ = status;
}
if (future_) {
if (eptr) {
future_->setError(eptr);
} else {
future_->markCompleted(
c10::IValue(data ? data->dst : std::vector<at::Tensor>()));
}
}
}
Comm::Comm(
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger_,
std::shared_ptr<torch_ucc_oob_coll_info_t> oob_,
c10::Device dev,
bool is_health_check)
: logger(logger_),
oob(oob_),
ucc_comm(oob, logger),
finalize_phase(
is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_FINALIZE),
cuda_device_index(TORCH_UCC_DEVICE_NOT_SET) {
if (dev.is_cuda()) {
cuda_device_index = dev.index();
}
stop_progress_loop = false;
collective_inprogress = false;
progress_thread = std::thread(&Comm::progress_loop, this);
#ifdef _GNU_SOURCE
pthread_setname_np(progress_thread.native_handle(), "ucc-progress");
#endif
}
Comm::~Comm() {
std::unique_lock<std::mutex> lock(mutex);
queue_consume_cv.wait(
lock, [&] { return progress_queue.empty() && !collective_inprogress; });
stop_progress_loop = true;
lock.unlock();
queue_produce_cv.notify_all();
progress_thread.join();
}
std::shared_ptr<Comm> Comm::get_comm(
uint32_t& id,
c10::Device dev,
std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger,
bool is_health_check) {
static std::mutex m;
static std::weak_ptr<Comm> comm;
static uint32_t comm_id;
std::lock_guard<std::mutex> lock(m);
id = comm_id;
std::string group_id = "group_id";
if (is_health_check) {
group_id = c10::str(dev.type()) + "/" + group_id;
}
std::vector<uint8_t> remote_comm_id;
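// Agree on a communicator id through the store: each non-zero rank publishes
// its local id, rank 0 takes the maximum of all published ids and stores the
// result under key 0, and every rank then adopts that value as oob->comm_id.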
oob->store->deleteKey(group_id + std::to_string(0));
if (oob->rank != 0) {
std::vector<uint8_t> val = std::vector<uint8_t>(
reinterpret_cast<uint8_t*>(&id),
reinterpret_cast<uint8_t*>(&id) + sizeof(id));
oob->store->set(group_id + std::to_string(oob->rank), val);
} else {
for (int i = 1; i < oob->size; i++) {
remote_comm_id = oob->store->get(group_id + std::to_string(i));
oob->store->deleteKey(group_id + std::to_string(i));
// Find the highest id.
id = std::max(id, *(reinterpret_cast<uint32_t*>(remote_comm_id.data())));
}
std::vector<uint8_t> val = std::vector<uint8_t>(
reinterpret_cast<uint8_t*>(&id),
reinterpret_cast<uint8_t*>(&id) + sizeof(id));
oob->store->set(group_id + std::to_string(oob->rank), val);
}
remote_comm_id = oob->store->get(group_id + std::to_string(0));
oob->comm_id = *(reinterpret_cast<uint32_t*>(remote_comm_id.data()));
// Advance comm_id (the static variable) to the next id.
comm_id = oob->comm_id + 1;
if (torch_ucc_config.shared_comm) {
std::shared_ptr<Comm> shared_comm = comm.lock();
if (!shared_comm) {
shared_comm = std::make_shared<Comm>(logger, oob, dev, is_health_check);
comm = shared_comm;
} else {
if (dev.is_cuda() && !is_health_check) {
if ((shared_comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) &&
(shared_comm->cuda_device_index != dev.index())) {
TORCH_UCC_LOG_ERROR(
is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_INIT,
"ucc communicator was initialized with different cuda device,"
"multi device is not supported");
throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
}
shared_comm->cuda_device_index = dev.index();
}
}
return shared_comm;
} else {
return std::make_shared<Comm>(logger, oob, dev, is_health_check);
}
}
void Comm::ucc_create_team(
ucc_team_h& team,
std::shared_ptr<torch_ucc_oob_coll_info_t> oob) {
ucc_status_t st;
ucc_team_params_t team_params;
team_params.mask = UCC_TEAM_PARAM_FIELD_EP | UCC_TEAM_PARAM_FIELD_EP_RANGE |
UCC_TEAM_PARAM_FIELD_OOB;
team_params.oob.allgather = oob_allgather;
team_params.oob.req_test = oob_allgather_test;
team_params.oob.req_free = oob_allgather_free;
team_params.oob.coll_info = oob.get();
team_params.oob.n_oob_eps = oob->size;
team_params.oob.oob_ep = oob->rank;
team_params.ep = oob->rank;
team_params.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG;
TORCH_UCC_CHECK(
ucc_team_create_post(&ucc_comm.context, 1, &team_params, &team),
"failed to post team create");
do {
st = ucc_team_create_test(team);
ucc_context_progress(ucc_comm.context);
} while (st == UCC_INPROGRESS);
TORCH_UCC_CHECK(st, "failed to create UCC team");
}
void Comm::ucc_destroy_team(ucc_team_h& team) {
std::unique_lock<std::mutex> lock(mutex);
queue_consume_cv.wait(
lock, [&] { return progress_queue.empty() && !collective_inprogress; });
ucc_status_t status;
while (UCC_INPROGRESS == (status = ucc_team_destroy(team))) {
if (UCC_OK != status) {
TORCH_UCC_LOG_ERROR(
finalize_phase,
c10::str("ucc team destroy error: ", ucc_status_string(status)));
break;
}
}
lock.unlock();
}
void Comm::enqueue_collective(
std::unique_ptr<ProcessGroupUCC::WorkData> data,
c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
ucc_coll_args_t& coll,
ucc_team_h team) {
ucc_coll_req_h request;
TORCH_UCC_CHECK(
ucc_collective_init(&coll, &request, team), "failed to init collective");
TORCH_UCC_CHECK_REQUEST(
request, ucc_collective_post(request), "failed to post collective");
auto entry =
std::make_shared<ProcessGroupUCC::ProgressEntry>(&ucc_comm, request);
entry->data = std::move(data);
entry->future_ = work->getFuture();
work->entry_ = entry;
std::unique_lock<std::mutex> lock(mutex);
progress_queue.push_back(entry);
lock.unlock();
queue_produce_cv.notify_one();
}
#ifdef USE_CUDA
void Comm::enqueue_cuda_collective(
std::unique_ptr<ProcessGroupUCC::WorkData> data,
c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
ucc_coll_args_t& coll,
ucc_team_h team,
ucc_ee_h ee) {
ucc_coll_req_h request;
TORCH_UCC_CHECK(
ucc_collective_init(&coll, &request, team),
"failed to init cuda collective");
ucc_ev_t comp_ev, *post_ev;
comp_ev.ev_type = UCC_EVENT_COMPUTE_COMPLETE;
comp_ev.ev_context = nullptr;
comp_ev.ev_context_size = 0;
comp_ev.req = request;
TORCH_UCC_CHECK_REQUEST(
request,
ucc_collective_triggered_post(ee, &comp_ev),
"failed to post triggered collective");
ucc_status_t st = ucc_ee_get_event(ee, &post_ev);
TORCH_CHECK(st == UCC_OK && post_ev->ev_type == UCC_EVENT_COLLECTIVE_POST);
ucc_ee_ack_event(ee, post_ev);
auto entry =
std::make_shared<ProcessGroupUCC::ProgressEntry>(&ucc_comm, request);
entry->data = std::move(data);
work->entry_ = entry;
std::unique_lock<std::mutex> lock(mutex);
progress_queue.push_back(entry);
lock.unlock();
queue_produce_cv.notify_one();
}
#endif
void Comm::progress_loop() {
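// Progress thread: pops posted collectives off the queue and drives
// ucc_context_progress() until each request completes or fails, then
// finalizes the corresponding entry and notifies waiters.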
std::unique_lock<std::mutex> lock(mutex);
#ifdef USE_CUDA
bool device_set = false;
#endif
while (!stop_progress_loop) {
if (progress_queue.empty()) {
queue_produce_cv.wait(lock);
continue;
}
collective_inprogress = true;
auto work = progress_queue.front();
progress_queue.pop_front();
lock.unlock();
#ifdef USE_CUDA
if ((!device_set) && (cuda_device_index != TORCH_UCC_DEVICE_NOT_SET)) {
c10::cuda::set_device(cuda_device_index);
CUcontext pctx = nullptr;
at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx);
if (C10_UNLIKELY(!pctx)) {
at::globalContext().getNVRTC().cuDevicePrimaryCtxRetain(
&pctx, cuda_device_index);
at::globalContext().getNVRTC().cuCtxSetCurrent(pctx);
}
device_set = true;
}
#endif
std::exception_ptr eptr;
try {
while (work->request_->status > 0) {
ucc_comm.progress();
}
if (work->request_->status < 0) {
eptr = std::make_exception_ptr(
std::runtime_error(ucc_status_string(work->request_->status)));
std::string err_log = c10::str(
"Failed to progress communication", // TODO: report exact op type or
// id?
ucc_status_string(work->request_->status));
TORCH_UCC_LOG_ERROR(TORCH_UCC_COLL_PROGRESS, err_log);
}
} catch (...) {
eptr = std::current_exception();
}
work->finalize(eptr);
work = nullptr;
collective_inprogress = false;
queue_consume_cv.notify_one();
lock.lock();
}
}
ProcessGroupUCC::ProcessGroupUCC(
const c10::intrusive_ptr<Store>& store,
int rank,
int size,
std::chrono::duration<float> timeout)
: Backend(rank, size), timeout_(timeout) {
c10::call_once(torch_ucc_config.flag, read_config);
oob = std::make_shared<torch_ucc_oob_coll_info_t>();
oob->rank = rank;
oob->size = size;
oob->store = store;
comm = nullptr;
cuda_ee = nullptr;
static uint32_t id = 0;
uint32_t pg_id = id++;
logger = c10::make_intrusive<ProcessGroupUCCLogger>(
c10::str("[Rank ", rank_, "]", "[ProcessGroupUCC-", pg_id, "]"),
TORCH_UCC_INIT);
TORCH_UCC_LOG_INFO(
TORCH_UCC_INIT,
c10::str(
"Created ProcessGroupUCC with ",
size,
" ranks, with timeout ",
timeout_.count(),
" secs"));
std::string envs = "";
for (auto& torch_ucc_env : torch_ucc_envs_map) {
envs += ("\n\t" + torch_ucc_env.first + "=" + torch_ucc_env.second);
}
TORCH_UCC_LOG_INFO(
TORCH_UCC_INIT,
c10::str(
"Successfully read and set ProcessGroupUCC env. variables as followings",
envs));
if (torch_ucc_config.enable_health_check) {
// Perform health check by initializing dummy communicators and destroying
// them. This will help indicate any UCC/UCX-related issues prior to the
// first collective. Run it in a separate thread and wait on CV to handle
// timeouts so that if there are hangs, the main thread can still run
// correctly.
runHealthCheck();
}
if (torch_ucc_config.enable_comms_logger) {
logger->initCommsTracer();
}
}
ProcessGroupUCC::~ProcessGroupUCC() {
if (torch_ucc_config.enable_comms_logger) {
logger->flushComms(this->getRank(), this->getSize());
}
if (comm) {
logger->setPhase(TORCH_UCC_FINALIZE);
comm->ucc_destroy_team(team);
TORCH_UCC_LOG_INFO(
TORCH_UCC_FINALIZE, "Successfully destroyed UCC library");
try {
if (cuda_ee) {
ucc_ee_destroy(cuda_ee);
ucc_ee_destroy(cuda_ee_p2p[0]);
ucc_ee_destroy(cuda_ee_p2p[1]);
}
} catch (std::exception& ex) {
TORCH_UCC_LOG_INFO(
TORCH_UCC_FINALIZE,
c10::str(
"(~ProcessGroupUCC) Caught error in Store Operation .. ",
"[",
ex.what(),
"]"));
}
comm = nullptr;
}
}
#ifdef USE_CUDA
// Return CUDA device with ordinal given by input rank.
c10::Device getCUDADeviceForRank(int rank) {
TORCH_CHECK(rank >= 0, "Invalid rank ", rank);
auto numGPUs = at::cuda::getNumGPUs();
auto deviceIdx = static_cast<c10::DeviceIndex>(rank % numGPUs);
return c10::Device(c10::DeviceType::CUDA, deviceIdx);
}
#endif
void ProcessGroupUCC::runHealthCheck() {
// Run health check in a separate thread and wait on CV to handle timeouts.
// This design allows us to handle hangs.
// When size_ is 1, there is no need to do any communication at all.
if (size_ == 1)
return;
struct HealthCheckData {
std::mutex healthCheckMutex;
std::condition_variable healthCheckCv;
bool uccHealthCheckSuccess = false;
std::exception_ptr healthCheckException;
} healthCheckData;
auto t = std::thread([&healthCheckData, this]() {
std::list<c10::Device> devices{c10::kCPU};
#ifdef USE_CUDA
c10::cuda::OptionalCUDAGuard gpuGuard;
if (at::cuda::is_available()) {
devices.emplace_front(getCUDADeviceForRank(rank_));
}
#endif
for (auto device : devices) {
bool is_last_device = (device == devices.back());
try {
auto oob = std::make_shared<torch_ucc_oob_coll_info_t>();
oob->rank = this->oob->rank;
oob->size = this->oob->size;
oob->store = this->oob->store;
ucc_team_h team = nullptr;
uint32_t comm_id;
#ifdef USE_CUDA
if (device.is_cuda()) {
gpuGuard.set_index(device.index());
}
#endif
auto comm = Comm::get_comm(comm_id, device, oob, logger, true);
comm->ucc_create_team(team, oob);
comm->ucc_destroy_team(team);
TORCH_UCC_LOG_INFO(
TORCH_UCC_HEALTH_CHECK,
c10::str(
"UCC library health check succeed for device ",
c10::DeviceTypeName(device.type())));
// Mark ucc health check as complete.
if (is_last_device) {
std::lock_guard<std::mutex> lk(healthCheckData.healthCheckMutex);
healthCheckData.uccHealthCheckSuccess = true;
}
comm = nullptr;
oob = nullptr;
// Notify main thread the health check is complete.
if (is_last_device) {
healthCheckData.healthCheckCv.notify_one();
}
} catch (const std::exception&) {
// Populate exception ptr.
healthCheckData.healthCheckException = std::current_exception();
// Unblock waiting main thread which will report exception.
healthCheckData.healthCheckCv.notify_one();
} // Unknown exceptions will just cause the program to terminate.
}
});
// We don't need to join the thread, just need to verify health check via the
// CV. Hence we detach the thread here.
t.detach(); // NOLINT
TORCH_UCC_LOG_INFO(
TORCH_UCC_HEALTH_CHECK,
c10::str(
"will wait up to ",
timeout_.count(),
" msec for UCC health check to complete."));
std::unique_lock<std::mutex> lock(healthCheckData.healthCheckMutex);
healthCheckData.healthCheckCv.wait_for(lock, timeout_, [&healthCheckData]() {
return healthCheckData.uccHealthCheckSuccess;
});
if (healthCheckData.healthCheckException) {
std::rethrow_exception(healthCheckData.healthCheckException);
}
// If there is no exception, the likely culprit is a timeout/hang
TORCH_CHECK(
healthCheckData.uccHealthCheckSuccess,
"ProcessGroupUCC: Health check failure: Failed to initialize UCC on rank ",
rank_);
}
void ProcessGroupUCC::set_timeout(ucc_coll_args_t& args) {
args.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
args.flags |= UCC_COLL_ARGS_FLAG_TIMEOUT;
args.timeout = timeout_.count();
}
#ifdef USE_CUDA
std::unique_ptr<at::cuda::CUDAEvent> ProcessGroupUCC::getPooledEvent() {
std::unique_ptr<at::cuda::CUDAEvent> ev;
std::lock_guard<std::mutex> lock(ep.event_pool_mutex);
if (ep.event_pool.empty()) {
ev = std::make_unique<at::cuda::CUDAEvent>();
} else {
ev = std::move(ep.event_pool.front());
ep.event_pool.pop();
}
return ev;
}
#endif
template <typename PreProcess, typename PostProcess>
c10::intrusive_ptr<Work> ProcessGroupUCC::collective_post(
OpType opType,
PreProcess preproc,
PostProcess postproc,
ucc_coll_args_t& coll,
std::unique_ptr<ProcessGroupUCC::WorkData> data,
c10::Device dev,
std::vector<at::Tensor>& inputTensors,
std::vector<at::Tensor>& outputTensors,
const char* prof_title) {
seq_++;
set_timeout(coll);
auto work = c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
opType, seq_, prof_title, inputTensors, logger);
if (opType == OpType::RECV) {
work->sourceRank_ = coll.root;
}
RECORD_COMMS_TRACE(
logger->trace_generator,
work,
opType,
this->getRank(),
this->getSize(),
inputTensors,
outputTensors);
// Store references to outputs to be used by result
work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputTensors);
switch (dev.type()) {
case c10::DeviceType::CPU: {
if (torch_ucc_config.use_future) {
work->future_ = c10::make_intrusive<at::ivalue::Future>(
c10::ListType::create(c10::TensorType::get()));
}
preproc();
comm->enqueue_collective(std::move(data), work, coll, team);
postproc();
return work;
}
#ifdef USE_CUDA
case c10::DeviceType::CUDA: {
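// CUDA path: record an event on the caller's current stream, make the UCC
// operation stream wait on it, post the collective on that stream, then
// record the event again so WorkUCC::wait() can fence the user stream on it.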
auto cuda_ev = getPooledEvent();
at::cuda::CUDAStream* op_stream;
ucc_ee_h* op_ee;
if (opType == OpType::SEND) {
op_stream = stream_p2p[0].get();
op_ee = &cuda_ee_p2p[0];
} else if (opType == OpType::RECV) {
op_stream = stream_p2p[1].get();
op_ee = &cuda_ee_p2p[1];
} else {
op_stream = stream.get();
op_ee = &cuda_ee;
}
cuda_ev->record(at::cuda::getCurrentCUDAStream(dev.index()));
cuda_ev->block(*op_stream);
at::cuda::CUDAStreamGuard guard(*op_stream);
preproc();
comm->enqueue_cuda_collective(std::move(data), work, coll, team, *op_ee);
postproc();
cuda_ev->record(*op_stream);
work->fence = std::move(cuda_ev);
work->ep = &ep;
if (torch_ucc_config.use_future) {
c10::cuda::CUDAMultiStreamGuard streamGuard(*op_stream);
std::vector<c10::Device> devList{dev};
work->future_ = c10::make_intrusive<at::ivalue::Future>(
c10::ListType::create(c10::TensorType::get()), devList);
// Add a callback that runs profiling end callbacks
if (work->recordFunctionEndCallback_) {
work->future_->addCallback([work](at::ivalue::Future& /* unused */) {
work->recordFunctionEndCallback_();
});
}
work->future_->markCompleted(c10::IValue(outputTensors));
}
return work;
}
#endif // #ifdef USE_CUDA
default: {
TORCH_UCC_LOG_ERROR(
TORCH_UCC_COLL_POST, c10::str("unsupported device type ", dev.str()));
throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
}
}
}
c10::intrusive_ptr<Work> ProcessGroupUCC::allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& /* unused */) {
auto& tensor = inputTensors[0];
check_device(tensor.device(), outputTensors[0][0].device());
initComm(tensor.device());
if (tensor.device().is_cpu() || torch_ucc_config.use_allgatherv) {
AllgathervWorkData* data = new AllgathervWorkData(size_);
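// Allgatherv path: store each output tensor's address directly in the
// displacements array (base buffer stays nullptr), with byte counts and
// UCC_DT_UINT8, so the possibly non-contiguous output list is never flattened.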
for (int i = 0; i < size_; i++) {
data->recv_lengths[i] = tensor.element_size() * tensor.numel();
data->recv_offsets[i] = (uint64_t)outputTensors[0][i].data_ptr();
}
ucc_coll_args_t coll;
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags =
UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
coll.coll_type = UCC_COLL_TYPE_ALLGATHERV;
coll.src.info.buffer = tensor.data_ptr();
coll.src.info.count = tensor.element_size() * tensor.numel();
coll.src.info.datatype = UCC_DT_UINT8;
coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
coll.dst.info_v.buffer = nullptr;
coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
coll.dst.info_v.datatype = UCC_DT_UINT8;
coll.dst.info_v.mem_type =
to_ucc_memType(outputTensors[0][0].device().type());
SAVE_TENSORS(inputTensors, data->src);
SAVE_TENSORS(outputTensors[0], data->dst);
return collective_post(
OpType::ALLGATHER,
[]() {},
[]() {},
coll,
std::unique_ptr<WorkData>(data),
tensor.device(),
inputTensors,
outputTensors[0],
"ucc:all_gather");
} else {
WorkData* data = new WorkData();
std::vector<at::Tensor> flat_output(outputTensors.size());
for (size_t i = 0; i < outputTensors.size(); i++) {
TORCH_CHECK(
outputTensors[i].size() == outputTensors.size() * size_,
"Tensor output list is not valid for the number of participants");
flat_output[i] = c10d::newLikeFlat(outputTensors, i);
}
SAVE_TENSORS(flat_output, data->flat);
ucc_coll_args_t coll;
coll.mask = 0;
coll.flags = 0;
coll.coll_type = UCC_COLL_TYPE_ALLGATHER;
coll.src.info.buffer = tensor.data_ptr();
coll.src.info.count = tensor.numel();
coll.src.info.datatype = to_ucc_dType(tensor);
coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
coll.dst.info.buffer = flat_output[0].data_ptr();
coll.dst.info.count = flat_output[0].numel();
coll.dst.info.datatype = to_ucc_dType(flat_output[0]);
coll.dst.info.mem_type =
to_ucc_memType(outputTensors[0][0].device().type());
auto copy_from_flat = [&] {
bool asyncCopy = false;
#ifdef USE_CUDA
bool isCuda = outputTensors[0][0].device().is_cuda();
#endif
for (size_t i = 0; i < outputTensors.size(); i++) {
auto inumel = inputTensors[i].numel();
for (size_t j = 0; j < outputTensors[i].size(); j++) {
TORCH_CHECK(
(outputTensors[i][j].numel() == inumel),
"Tensor operand counts must be same");
#ifdef USE_CUDA
if (isCuda) {
c10::cuda::CUDACachingAllocator::recordStream(
outputTensors[i][j].storage().data_ptr(), (*stream));
asyncCopy = true;
}
#endif
outputTensors[i][j].copy_(flat_output[i][j], asyncCopy);
}
}
};
return collective_post(
OpType::ALLGATHER,
[]() {},
copy_from_flat,
coll,
std::unique_ptr<WorkData>(data),
tensor.device(),
inputTensors,
outputTensors[0],
"ucc:all_gather");
}
}
c10::intrusive_ptr<Work> ProcessGroupUCC::_allgather_base(
at::Tensor& outputTensor,
at::Tensor& inputTensor,
const AllgatherOptions& opts) {
check_tensor({outputTensor});
check_tensor({inputTensor});
initComm(outputTensor.device());
WorkData* data = new WorkData();
ucc_coll_args_t coll;
coll.mask = 0;
coll.flags = 0;
coll.coll_type = UCC_COLL_TYPE_ALLGATHER;
coll.src.info.buffer = inputTensor.data_ptr();
coll.src.info.count = inputTensor.numel();
coll.src.info.datatype = ucc_dtype_map.at(inputTensor.scalar_type());
coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type());
coll.dst.info.buffer = outputTensor.data_ptr();
coll.dst.info.count = outputTensor.numel();
coll.dst.info.datatype = ucc_dtype_map.at(outputTensor.scalar_type());
coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type());
std::vector<at::Tensor> inputTensors = {inputTensor};
std::vector<at::Tensor> outputTensors = {outputTensor};
SAVE_TENSORS(inputTensors, data->src);
SAVE_TENSORS(outputTensors, data->dst);
return collective_post(
OpType::_ALLGATHER_BASE,
[]() {},
[]() {},
coll,
std::unique_ptr<WorkData>(data),
outputTensor.device(),
inputTensors,
outputTensors,
"ucc:allgather_base");
}
c10::intrusive_ptr<Work> ProcessGroupUCC::allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts) {
check_tensor(tensors);
auto& tensor = tensors[0];
initComm(tensor.device());
WorkData* data = new WorkData();
ucc_coll_args_t coll;
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
coll.coll_type = UCC_COLL_TYPE_ALLREDUCE;
coll.op = to_ucc_reduceOp(opts.reduceOp, tensor.scalar_type());
coll.src.info.buffer = nullptr;
coll.src.info.count = tensor.numel();
coll.src.info.datatype = to_ucc_dType(tensor);
coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
coll.dst.info.buffer = tensor.data_ptr();
coll.dst.info.count = tensor.numel();
coll.dst.info.datatype = to_ucc_dType(tensor);
coll.dst.info.mem_type = to_ucc_memType(tensor.device().type());
SAVE_TENSORS(tensors, data->dst);
return collective_post(
OpType::ALLREDUCE,
[]() {},
[]() {},
coll,
std::unique_ptr<WorkData>(data),
tensor.device(),
tensors,
tensors,
"ucc:all_reduce");
}
c10::intrusive_ptr<Work> ProcessGroupUCC::allreduce_coalesced(
std::vector<at::Tensor>& /* unused */,
const AllreduceCoalescedOptions& /* unused */) {
throw std::invalid_argument(
"ProcessGroupUCC does not support allreduce_coalesced");
}
c10::intrusive_ptr<Work> ProcessGroupUCC::alltoall(
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllToAllOptions& /* unused */) {
auto device = outputTensors[0].device();
for (const auto r : c10::irange(outputTensors.size())) {
TORCH_CHECK(
device == outputTensors[r].device() &&
device == inputTensors[r].device(),
"Tensors must be on the same device")
}
initComm(device);
ucc_coll_args_t coll;
AlltoallWorkData* data;
data = new AlltoallWorkData(size_);
/* To avoid flattening the tensors, we use alltoallv to implement Alltoall as
follows:
1. store the address of each tensor directly in the displacements, keeping
the buffer at nullptr, i.e., 0
2. convert the datatype to UINT8, which is always 1 byte, to avoid wrong size
calculations in the UCC layer
3. post the Alltoallv
*/
for (const auto i : c10::irange(size_)) {
data->send_lengths[i] =
(uint64_t)(inputTensors[i].element_size() * inputTensors[i].numel());
data->send_offsets[i] = (uint64_t)inputTensors[i].data_ptr();
data->recv_lengths[i] =
(uint64_t)(outputTensors[i].element_size() * outputTensors[i].numel());
data->recv_offsets[i] = (uint64_t)outputTensors[i].data_ptr();
}
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags =
UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
coll.coll_type = UCC_COLL_TYPE_ALLTOALLV;
coll.src.info_v.buffer = 0;
coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data();
coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data();
coll.src.info_v.datatype = UCC_DT_UINT8;
coll.src.info_v.mem_type = to_ucc_memType(inputTensors[0].device().type());
coll.dst.info_v.buffer = 0;
coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
coll.dst.info_v.datatype = UCC_DT_UINT8;
coll.dst.info_v.mem_type = to_ucc_memType(outputTensors[0].device().type());
SAVE_TENSORS(inputTensors, data->src);
SAVE_TENSORS(outputTensors, data->dst);
return collective_post(
OpType::ALLTOALL,
[]() {},
[]() {},
coll,
std::unique_ptr<WorkData>(data),
device,
inputTensors,
outputTensors,
"ucc:alltoall");
}
c10::intrusive_ptr<Work> ProcessGroupUCC::alltoall_base(
at::Tensor& outputTensor,
at::Tensor& inputTensor,
std::vector<int64_t>& outputSplitSizes,
std::vector<int64_t>& inputSplitSizes,
const AllToAllOptions& /* unused */) {
check_device(inputTensor.device(), outputTensor.device());
initComm(inputTensor.device());
ucc_coll_args_t coll;
AlltoallWorkData* data;
if ((outputSplitSizes.size() == 0) && (inputSplitSizes.size() == 0)) {
data = new AlltoallWorkData(0);
TORCH_CHECK(
(outputTensor.size(0) % size_ == 0) &&
(inputTensor.size(0) % size_ == 0),
"Tensor's dim 0 does not divide equally across group size");
coll.mask = 0;
coll.flags = 0;
coll.coll_type = UCC_COLL_TYPE_ALLTOALL;
coll.src.info.buffer = inputTensor.data_ptr();
coll.src.info.count = inputTensor.element_size() * inputTensor.numel();
coll.src.info.datatype = UCC_DT_UINT8;
coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type());
coll.dst.info.buffer = outputTensor.data_ptr();
coll.dst.info.count = outputTensor.element_size() * outputTensor.numel();
coll.dst.info.datatype = UCC_DT_UINT8;
coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type());
coll.flags = 0;
} else {
data = new AlltoallWorkData(size_);
c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_);
computeLengthsAndOffsets(
outputSplitSizes,
outputTensor,
&data->recv_lengths,
&data->recv_offsets);
computeLengthsAndOffsets(
inputSplitSizes, inputTensor, &data->send_lengths, &data->send_offsets);
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll.coll_type = UCC_COLL_TYPE_ALLTOALLV;
coll.src.info_v.buffer = inputTensor.data_ptr();
coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data();
coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data();
coll.src.info_v.datatype = to_ucc_dType(inputTensor);
coll.src.info_v.mem_type = to_ucc_memType(inputTensor.device().type());
coll.dst.info_v.buffer = outputTensor.data_ptr();
coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
coll.dst.info_v.datatype = to_ucc_dType(outputTensor);
coll.dst.info_v.mem_type = to_ucc_memType(outputTensor.device().type());
coll.flags = UCC_COLL_ARGS_FLAG_CONTIG_SRC_BUFFER |
UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER | UCC_COLL_ARGS_FLAG_COUNT_64BIT |
UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
if (torch_ucc_config.enable_comms_logger) {
logger->trace_generator->recordOptionalInfo(
outputSplitSizes, inputSplitSizes);
}
}
std::vector<at::Tensor> inputTensors = {inputTensor};
std::vector<at::Tensor> outputTensors = {outputTensor};
SAVE_TENSORS(inputTensors, data->src);
SAVE_TENSORS(outputTensors, data->dst);
return collective_post(
OpType::ALLTOALL_BASE,
[]() {},
[]() {},
coll,
std::unique_ptr<WorkData>(data),
inputTensor.device(),
inputTensors,
outputTensors,
"ucc:alltoall");
}
c10::intrusive_ptr<Work> ProcessGroupUCC::barrier(const BarrierOptions& opts) {
c10::Device device = c10::Device(c10::DeviceType::CPU);
#ifdef USE_CUDA
auto numGPUs = c10::cuda::device_count();
if (!opts.device_ids.empty()) {
device = c10::Device(c10::DeviceType::CUDA, opts.device_ids.front());
} else if (comm && comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) {
device = c10::Device(c10::DeviceType::CUDA, comm->cuda_device_index);
} else if (numGPUs > 0) {
int8_t deviceIdx = static_cast<int8_t>(c10::cuda::current_device());
// if current device is 0, likely the device is not set, use the best guess
if (0 == (int)deviceIdx) {
deviceIdx = static_cast<int8_t>(this->getRank() % numGPUs);
}
TORCH_UCC_LOG_INFO(
TORCH_UCC_COLL_POST,
c10::str(
"post barrier before specifying any GPU while there are ",
numGPUs,
" GPUs available. ",
"Not clear if GPU barrier is required, using GPU ",
(int)deviceIdx,
" to perform barrier. ",
"Specify device_ids option in barrier() to force ",
"use of a particular device"));
device = c10::Device(c10::DeviceType::CUDA, deviceIdx);
}
#endif
initComm(device);
ucc_coll_args_t coll;
coll.mask = 0;
coll.flags = 0;
coll.coll_type = UCC_COLL_TYPE_BARRIER;
auto dummy_tensor = std::vector<at::Tensor>();
return collective_post(
OpType::BARRIER,
[]() {},
[]() {},
coll,
nullptr,
device,
dummy_tensor,
dummy_tensor,
"ucc:barrier");
}
c10::intrusive_ptr<Work> ProcessGroupUCC::broadcast(
std::vector<at::Tensor>& tensors,
const BroadcastOptions& opts) {
check_tensor(tensors);
auto& tensor = tensors[0];
initComm(tensor.device());
WorkData* data = new WorkData();
ucc_coll_args_t coll;
coll.mask = 0;
coll.flags = 0;
coll.coll_type = UCC_COLL_TYPE_BCAST;
coll.src.info.buffer = tensor.data_ptr();
coll.src.info.count = tensor.numel();
coll.src.info.datatype = to_ucc_dType(tensor);
coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
coll.root = opts.rootRank;
SAVE_TENSORS(tensors, data->dst);
if (torch_ucc_config.enable_comms_logger) {
logger->trace_generator->recordOptionalInfo(opts.rootRank);
}
return collective_post(
OpType::BROADCAST,
[]() {},
[]() {},
coll,
std::unique_ptr<WorkData>(data),
tensor.device(),
tensors,
tensors,
"ucc:broadcast");
}
c10::intrusive_ptr<Work> ProcessGroupUCC::gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const GatherOptions& opts) {
std::vector<at::Tensor> outputs;
auto& input = inputTensors[0];
initComm(input.device());
AllgathervWorkData* data = new AllgathervWorkData(size_);
ucc_coll_args_t coll;
coll.root = opts.rootRank;
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags =
UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
coll.coll_type = UCC_COLL_TYPE_GATHERV;
/* for non-root ranks, only src is valid */
coll.src.info.buffer = input.data_ptr();
coll.src.info.count = (uint64_t)(input.element_size() * input.numel());
coll.src.info.datatype = UCC_DT_UINT8;
coll.src.info.mem_type = to_ucc_memType(input.device().type());
if (getRank() == opts.rootRank) {
if (outputTensors.size() != 1) {
TORCH_UCC_LOG_ERROR(
TORCH_UCC_COLL_POST,
c10::str(
"gather requires a single-element output list containing a list with ",
getSize(),
" tensors."));
} else if (outputTensors[0].size() != static_cast<size_t>(getSize())) {
TORCH_UCC_LOG_ERROR(
TORCH_UCC_COLL_POST,
c10::str(
"Incorrect output list size ",
outputTensors[0].size(),
". Output list size should be ",
getSize(),
", same as size of the process group."));
}
outputs = outputTensors[0];
for (int i = 0; i < size_; i++) {
data->recv_lengths[i] =
(uint64_t)(outputs[i].element_size() * outputs[i].numel());
data->recv_offsets[i] = (uint64_t)outputs[i].data_ptr();
}
/* use gatherv and store non-contiguous addresses in displacements to avoid
* flattening outputTensors */
coll.dst.info_v.buffer = nullptr;
coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
coll.dst.info_v.datatype = UCC_DT_UINT8;
coll.dst.info_v.mem_type = to_ucc_memType(outputs[0].device().type());
SAVE_TENSORS(outputs, data->dst);
} else {
// for non-root ranks, outputTensors should be an empty list
if (outputTensors.size() != 0) {
TORCH_UCC_LOG_ERROR(
TORCH_UCC_COLL_POST, "requires empty output on non-root");
}
outputs = {};
// append an empty tensor to the list to be used by the future mark
outputs.emplace_back();
}
SAVE_TENSORS(inputTensors, data->src);
return collective_post(
OpType::GATHER,
[]() {},
[]() {},
coll,
std::unique_ptr<WorkData>(data),
input.device(),
inputTensors,
outputs,
"ucc:gather");
}
c10::intrusive_ptr<Work> ProcessGroupUCC::reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts) {
check_tensor(tensors);
auto& tensor = tensors[0];
initComm(tensor.device());
WorkData* data = new WorkData();
ucc_coll_args_t coll;
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
coll.coll_type = UCC_COLL_TYPE_REDUCE;
coll.op = ucc_op_map.at(opts.reduceOp);
coll.root = opts.rootRank;
coll.src.info.buffer = tensor.data_ptr();
coll.src.info.count = tensor.numel();
coll.src.info.datatype = ucc_dtype_map.at(tensor.scalar_type());
coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
coll.dst.info.buffer = tensor.data_ptr();
coll.dst.info.count = tensor.numel();
coll.dst.info.datatype = ucc_dtype_map.at(tensor.scalar_type());
coll.dst.info.mem_type = to_ucc_memType(tensor.device().type());
SAVE_TENSORS(tensors, data->dst);
return collective_post(
OpType::REDUCE,
[]() {},
[]() {},
coll,
std::unique_ptr<WorkData>(data),
tensor.device(),
tensors,
tensors,
"ucc:reduce");
}
c10::intrusive_ptr<Work> ProcessGroupUCC::reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts) {
TORCH_CHECK(
(outputTensors.size() == inputTensors.size()),
"Tensor input/output list for reduce_scatter must have same size");
check_tensor(outputTensors);
check_device(inputTensors[0][0].device(), outputTensors[0].device());
initComm(inputTensors[0][0].device());
auto data = std::make_unique<WorkData>();
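// reduce_scatter flattens the (possibly non-contiguous) per-rank input lists
// into contiguous buffers; copy_to_flat runs as the pre-processing step before
// the UCC reduce_scatter is posted.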
std::vector<at::Tensor> flat_input(inputTensors.size());
for (size_t i = 0; i < inputTensors.size(); i++) {
TORCH_CHECK(
inputTensors[i].size() == inputTensors.size() * size_,
"Tensor input list is not valid for the number of participants");
flat_input[i] = c10d::newLikeFlat(inputTensors, i);
}
SAVE_TENSORS(flat_input, data->flat);
check_tensor(flat_input);
ucc_coll_args_t coll;
coll.mask = 0;
coll.flags = 0;
coll.coll_type = UCC_COLL_TYPE_REDUCE_SCATTER;
coll.op = to_ucc_reduceOp(opts.reduceOp, flat_input[0].scalar_type());
coll.src.info.buffer = flat_input[0].data_ptr();
coll.src.info.count = flat_input[0].numel();
coll.src.info.datatype = to_ucc_dType(flat_input[0]);
coll.src.info.mem_type = to_ucc_memType(flat_input[0].device().type());
coll.dst.info.buffer = outputTensors[0].data_ptr();
coll.dst.info.count = outputTensors[0].numel();
coll.dst.info.datatype = to_ucc_dType(outputTensors[0]);
coll.dst.info.mem_type = to_ucc_memType(outputTensors[0].device().type());
SAVE_TENSORS(inputTensors[0], data->src);
SAVE_TENSORS(outputTensors, data->dst);
auto copy_to_flat = [&] {
bool asyncCopy = false;
auto isize = inputTensors.size();
#ifdef USE_CUDA
bool isCuda = inputTensors[0][0].device().is_cuda();
#endif
for (size_t i = 0; i < isize; i++) {
auto onumel = outputTensors[i].numel();
for (size_t j = 0; j < inputTensors[i].size(); j++) {
TORCH_CHECK(
(inputTensors[i][j].numel() == onumel),
"Tensor operand counts must be same");
#ifdef USE_CUDA
if (isCuda) {
c10::cuda::CUDACachingAllocator::recordStream(
inputTensors[i][j].storage().data_ptr(), (*stream));
asyncCopy = true;
}
#endif
flat_input[i][j].copy_(inputTensors[i][j], asyncCopy);
}
}
};
return collective_post(
OpType::REDUCE_SCATTER,
copy_to_flat,
[]() {},
coll,
std::move(data),
inputTensors[0][0].device(),
inputTensors[0],
outputTensors,
"ucc:reduce_scatter");
}
c10::intrusive_ptr<Work> ProcessGroupUCC::scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts) {
auto& tensor = outputTensors[0];
initComm(tensor.device());
ScattervWorkData* data = new ScattervWorkData(size_);
ucc_coll_args_t coll;
coll.root = opts.rootRank;
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags =
UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
coll.coll_type = UCC_COLL_TYPE_SCATTERV;
if (getRank() == opts.rootRank) {
/* src is only valid at the root rank */
if (inputTensors.size() != 1) {
TORCH_UCC_LOG_ERROR(
TORCH_UCC_COLL_POST,
c10::str(
"gather requires a single-element output list containing a list with ",
getSize(),
" tensors."));
} else if (inputTensors[0].size() != static_cast<size_t>(getSize())) {
TORCH_UCC_LOG_ERROR(
TORCH_UCC_COLL_POST,
c10::str(
"Incorrect output list size ",
inputTensors[0].size(),
". Output list size should be ",
getSize(),
", same as size of the process group."));
}
for (int i = 0; i < size_; i++) {
data->send_lengths[i] = (uint64_t)tensor.element_size() * tensor.numel();
data->send_offsets[i] = (uint64_t)inputTensors[0][i].data_ptr();
}
/* use scatter and store non-contiguous addresses in displacements to avoid
* flattening inputTensors */
coll.src.info_v.buffer = nullptr;
coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data();
coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data();
coll.src.info_v.datatype = UCC_DT_UINT8;
coll.src.info_v.mem_type =
to_ucc_memType(inputTensors[0][0].device().type());
SAVE_TENSORS(inputTensors[0], data->src);
} else {
// for non-root ranks, inputTensors should be an empty list
if (inputTensors.size() != 0) {
TORCH_UCC_LOG_ERROR(
TORCH_UCC_COLL_POST, "requires empty output on non-root");
}
}
coll.dst.info.buffer = tensor.data_ptr();
coll.dst.info.count = (uint64_t)tensor.element_size() * tensor.numel();
coll.dst.info.datatype = UCC_DT_UINT8;
coll.dst.info.mem_type = to_ucc_memType(tensor.device().type());
SAVE_TENSORS(outputTensors, data->dst);
return collective_post(
OpType::SCATTER,
[]() {},
[]() {},
coll,
std::unique_ptr<WorkData>(data),
tensor.device(),
(getRank() == opts.rootRank) ? inputTensors[0] : outputTensors,
outputTensors,
"ucc:scatter");
}
c10::intrusive_ptr<Work> ProcessGroupUCC::send(
std::vector<at::Tensor>& tensors,
int dstRank,
int tag) {
check_tensor(tensors);
auto& tensor = tensors[0];
initComm(tensor.device());
WorkData* data = new WorkData();
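// Point-to-point send is implemented as a two-rank broadcast rooted at the
// sender, with the active set restricted to {this rank, dstRank}.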
ucc_coll_args_t coll;
coll.tag = tag;
coll.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_TAG;
coll.flags = 0;
coll.coll_type = UCC_COLL_TYPE_BCAST;
coll.src.info.buffer = tensor.data_ptr();
coll.src.info.count = tensor.numel();
coll.src.info.datatype = to_ucc_dType(tensor);
coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
coll.root = getRank();
coll.active_set.size = 2;
coll.active_set.start = getRank();
coll.active_set.stride = dstRank - getRank();
SAVE_TENSORS(tensors, data->dst);
return collective_post(
OpType::SEND,
[]() {},
[]() {},
coll,
std::unique_ptr<WorkData>(data),
tensor.device(),
tensors,
tensors,
"ucc:send");
}
c10::intrusive_ptr<Work> ProcessGroupUCC::recv(
std::vector<at::Tensor>& tensors,
int srcRank,
int tag) {
check_tensor(tensors);
auto& tensor = tensors[0];
initComm(tensor.device());
WorkData* data = new WorkData();
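// Matching receive: a two-rank broadcast rooted at srcRank, with the active
// set restricted to {srcRank, this rank}.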
ucc_coll_args_t coll;
coll.tag = tag;
coll.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_TAG;
coll.flags = 0;
coll.coll_type = UCC_COLL_TYPE_BCAST;
coll.src.info.buffer = tensor.data_ptr();
coll.src.info.count = tensor.numel();
coll.src.info.datatype = to_ucc_dType(tensor);
coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
coll.root = srcRank;
coll.active_set.size = 2;
coll.active_set.start = srcRank;
coll.active_set.stride = getRank() - srcRank;
SAVE_TENSORS(tensors, data->dst);
return collective_post(
OpType::RECV,
[]() {},
[]() {},
coll,
std::unique_ptr<WorkData>(data),
tensor.device(),
tensors,
tensors,
"ucc:recv");
}
void ProcessGroupUCC::setSequenceNumberForGroup() {}
uint64_t ProcessGroupUCC::getSequenceNumberForGroup() {
return seq_;
}
c10::intrusive_ptr<Backend> ProcessGroupUCC::createProcessGroupUCC(
const c10::intrusive_ptr<::c10d::Store>& store,
int rank,
int size,
const std::chrono::duration<float>& timeout) {
return c10::make_intrusive<ProcessGroupUCC>(store, rank, size, timeout);
}
void ProcessGroupUCC::initComm(c10::Device dev) {
if (!comm) {
#ifdef USE_CUDA
if (dev.is_cuda()) {
c10::cuda::set_device(dev.index());
}
#endif
comm = Comm::get_comm(comm_id, dev, oob, logger);
TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCX library");
comm->ucc_create_team(team, oob);
TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCC library");
logger->setPhase(TORCH_UCC_READY);
} else {
if (dev.is_cuda()) {
if ((comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) &&
(comm->cuda_device_index != dev.index())) {
TORCH_UCC_LOG_ERROR(
TORCH_UCC_INIT,
"ucc communicator was initialized with different cuda device,"
"multi device is not supported");
throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
}
comm->cuda_device_index = dev.index();
}
}
#ifdef USE_CUDA
// Create UCC execution engine.
if (!cuda_ee && dev.is_cuda()) {
stream = std::make_unique<at::cuda::CUDAStream>(
at::cuda::getStreamFromPool(true, dev.index()));
ucc_ee_params_t params;
params.ee_type = UCC_EE_CUDA_STREAM;
params.ee_context = (void*)stream->stream();
params.ee_context_size = sizeof(cudaStream_t);
TORCH_UCC_CHECK(
ucc_ee_create(team, &params, &cuda_ee),
"failed to create UCC execution engine");
for (int i = 0; i < 2; i++) {
stream_p2p[i] = std::make_unique<at::cuda::CUDAStream>(
at::cuda::getStreamFromPool(true, dev.index()));
ucc_ee_params_t params;
params.ee_type = UCC_EE_CUDA_STREAM;
params.ee_context = (void*)stream_p2p[i]->stream();
params.ee_context_size = sizeof(cudaStream_t);
TORCH_UCC_CHECK(
ucc_ee_create(team, &params, &cuda_ee_p2p[i]),
"failed to create UCC P2P execution engine");
}
}
#endif
}
} // namespace c10d
#endif // USE_C10D_UCC