mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/139861 Approved by: https://github.com/ezyang
188 lines
6.2 KiB
C++
188 lines
6.2 KiB
C++
#pragma once
|
|
|
|
#ifdef USE_C10D_UCC
|
|
|
|
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
|
|
#include <torch/csrc/distributed/c10d/Store.hpp>
|
|
#include <ucc/api/ucc.h>
|
|
|
|
namespace c10d {
|
|
|
|
// Formats a diagnostic for a non-successful UCC status and assigns it to
// `_err`.
//
//   _err       - assignable destination for the message (e.g. a std::string).
//   _error_msg - caller-supplied description, forwarded to c10::str unchanged.
//   _result    - the UCC status to report; evaluated exactly once.
//
// Resulting text:
//   [<file>:<line>] <log prefix><_error_msg>, error code <status>:
//   <ucc_status_string(status)>, system error code <errno>
//
// The expansion site must have a `logger` in scope that supports
// `logger->getLogPrefix()`.
//
// `_result` is bound to a local first, and errno is snapshotted right after,
// so that neither a side-effecting status expression nor the allocating /
// logging calls made while building the message can change what is reported.
#define TORCH_UCC_GET_ERROR_MSG(_err, _error_msg, _result) \
  do {                                                     \
    const auto torch_ucc_status_ = (_result);              \
    const int torch_ucc_saved_errno_ = errno;              \
    (_err) = c10::str(                                     \
        "[",                                               \
        std::string(__FILE__),                             \
        ":",                                               \
        std::to_string(__LINE__),                          \
        "] ",                                              \
        logger->getLogPrefix(),                            \
        _error_msg,                                        \
        ", error code ",                                   \
        torch_ucc_status_,                                 \
        ": ",                                              \
        ucc_status_string(torch_ucc_status_),              \
        ", system error code ",                            \
        torch_ucc_saved_errno_);                           \
  } while (0)
|
|
|
|
// Evaluates `_cmd` once and, if it does not yield UCC_OK, raises through
// TORCH_CHECK with a message built by TORCH_UCC_GET_ERROR_MSG.
//
//   _cmd       - expression producing a ucc_status_t.
//   _error_msg - description forwarded to TORCH_UCC_GET_ERROR_MSG.
//
// The temporaries use deliberately unusual names: with short names such as
// `result` or `err`, a caller expression mentioning an identically named
// variable would silently bind to the macro's own (still uninitialized or
// empty) local instead of the caller's.
#define TORCH_UCC_CHECK(_cmd, _error_msg)                             \
  do {                                                                \
    const ucc_status_t torch_ucc_check_status_ = (_cmd);              \
    if (torch_ucc_check_status_ != UCC_OK) {                          \
      std::string torch_ucc_check_err_;                               \
      TORCH_UCC_GET_ERROR_MSG(                                        \
          torch_ucc_check_err_, _error_msg, torch_ucc_check_status_); \
      TORCH_CHECK(false, torch_ucc_check_err_);                       \
    }                                                                 \
  } while (0)
|
|
|
|
// Like TORCH_UCC_CHECK, but also finalizes a collective request before
// throwing.
//
//   _request   - request handle; when it compares unequal to nullptr it is
//                passed to ucc_collective_finalize() on the failure path
//                (the return value of that call is ignored). The argument is
//                expanded twice on that path, so pass a plain variable.
//   _cmd       - expression producing a ucc_status_t; evaluated once.
//   _error_msg - description forwarded to TORCH_UCC_GET_ERROR_MSG.
//
// The error text is built before the request is finalized, matching the
// order in which the failure is observed. Temporaries use unusual names so
// they cannot capture identifiers appearing in the caller's expressions.
#define TORCH_UCC_CHECK_REQUEST(_request, _cmd, _error_msg)           \
  do {                                                                \
    const ucc_status_t torch_ucc_check_status_ = (_cmd);              \
    if (torch_ucc_check_status_ != UCC_OK) {                          \
      std::string torch_ucc_check_err_;                               \
      TORCH_UCC_GET_ERROR_MSG(                                        \
          torch_ucc_check_err_, _error_msg, torch_ucc_check_status_); \
      if ((_request) != nullptr) {                                    \
        ucc_collective_finalize(_request);                            \
      }                                                               \
      TORCH_CHECK(false, torch_ucc_check_err_);                       \
    }                                                                 \
  } while (0)
|
|
|
|
// Logging helpers sharing one line layout:
//   <logger->getLogPrefix(_phase)>[LEVEL] <_msg>
//
// ERROR and INFO go through LOG(...); DEBUG goes through VLOG(1).
// A `logger` supporting `logger->getLogPrefix(phase)` must be in scope where
// the macro is expanded.
//
// NOTE: each definition already ends in ';', and `_msg` is streamed without
// parentheses, so an argument of the form `a << b` keeps chaining onto the
// same stream.
#define TORCH_UCC_LOG_ERROR(_phase, _msg) \
  LOG(ERROR) << logger->getLogPrefix(_phase) << "[ERROR] " << _msg;

#define TORCH_UCC_LOG_INFO(_phase, _msg) \
  LOG(INFO) << logger->getLogPrefix(_phase) << "[INFO] " << _msg;

#define TORCH_UCC_LOG_DEBUG(_phase, _msg) \
  VLOG(1) << logger->getLogPrefix(_phase) << "[DEBUG] " << _msg;
|
|
|
|
// Phase tag accepted by ProcessGroupUCCLogger (setPhase / getLogPrefix) and
// used as the key type of ucc_phase_map. Values are spelled out explicitly;
// they are the same ones implicit numbering from UNKNOWN = -1 would assign.
enum torch_ucc_phase_t {
  TORCH_UCC_UNKNOWN = -1,
  TORCH_UCC_INIT = 0,
  TORCH_UCC_HEALTH_CHECK = 1,
  TORCH_UCC_READY = 2,
  TORCH_UCC_COLL_POST = 3,
  TORCH_UCC_COLL_PROGRESS = 4,
  TORCH_UCC_FINALIZE = 5,
};
|
|
|
|
const std::map<torch_ucc_phase_t, std::string> ucc_phase_map = {
|
|
{TORCH_UCC_UNKNOWN, "UNKNOWN"},
|
|
{TORCH_UCC_INIT, "INIT"},
|
|
{TORCH_UCC_HEALTH_CHECK, "HEALTH_CHECK"},
|
|
{TORCH_UCC_READY, "READY"},
|
|
{TORCH_UCC_COLL_POST, "COLL_POST"},
|
|
{TORCH_UCC_COLL_PROGRESS, "COLL_PROGRESS"},
|
|
{TORCH_UCC_FINALIZE, "FINALIZE"},
|
|
};
|
|
|
|
class CommTraceLogger;
|
|
|
|
class TORCH_API ProcessGroupUCCLogger : public torch::CustomClassHolder {
|
|
public:
|
|
ProcessGroupUCCLogger();
|
|
ProcessGroupUCCLogger(std::string log_prefix, torch_ucc_phase_t phase);
|
|
|
|
std::string getLogPrefix(torch_ucc_phase_t phase = TORCH_UCC_UNKNOWN);
|
|
void setLogPrefix(std::string log_prefix);
|
|
inline void setPhase(torch_ucc_phase_t phase) {
|
|
local_phase = phase;
|
|
}
|
|
|
|
void initCommsTracer();
|
|
void flushComms(int rank, int world_size);
|
|
std::shared_ptr<CommTraceLogger> trace_generator = nullptr;
|
|
|
|
protected:
|
|
std::string log_prefix;
|
|
torch_ucc_phase_t local_phase = TORCH_UCC_UNKNOWN;
|
|
bool initialized_CommTraceLogger = false;
|
|
};
|
|
|
|
struct torch_ucc_oob_coll_info_t {
|
|
c10::intrusive_ptr<Store> store;
|
|
uint32_t comm_id;
|
|
int rank;
|
|
int size;
|
|
void* rbuf;
|
|
size_t msglen;
|
|
std::string getKey(std::string key) {
|
|
return std::to_string(comm_id) + key;
|
|
}
|
|
};
|
|
|
|
class CommBase {
|
|
public:
|
|
CommBase(const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger_)
|
|
: logger(logger_) {}
|
|
virtual void progress() = 0;
|
|
virtual void free_request(ucc_coll_req_h request) = 0;
|
|
virtual ~CommBase() {}
|
|
c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
|
|
};
|
|
class CommUCC : public CommBase {
|
|
public:
|
|
ucc_lib_h lib{nullptr};
|
|
ucc_context_h context{nullptr};
|
|
|
|
public:
|
|
void progress() override;
|
|
CommUCC(
|
|
std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
|
|
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger);
|
|
void free_request(ucc_coll_req_h request) override;
|
|
~CommUCC();
|
|
};
|
|
|
|
// Out-of-band allgather entry points; all three are defined outside this
// header and report a ucc_status_t.
// NOTE(review): the shapes look like the post / test / free callback triple
// that UCC expects for its OOB collective, with `coll_info` presumably
// pointing at a torch_ucc_oob_coll_info_t - confirm against the definitions.

// Takes a send buffer, a receive buffer, a message length, the type-erased
// collective info, and an out-parameter for a request handle.
ucc_status_t oob_allgather(
    void* sbuf,
    void* rbuf,
    size_t msglen,
    void* coll_info,
    void** req);

// Takes a request handle (presumably one produced by oob_allgather).
ucc_status_t oob_allgather_test(void* req);

// Takes a request handle (presumably one produced by oob_allgather).
ucc_status_t oob_allgather_free(void* req);
|
|
|
|
// trim: returns the sub-view of `s` with leading and trailing whitespace (as
// classified by std::isspace) removed. The result aliases the storage behind
// `s`, except that an empty or all-whitespace input yields a view of "".
// Approach borrowed from https://stackoverflow.com/a/17976541
//
// Characters are converted to unsigned char before reaching std::isspace:
// handing it a negative value (any byte >= 0x80 where plain char is signed)
// is undefined behaviour.
inline std::string_view trim(std::string_view s) {
  const auto is_space = [](unsigned char c) { return std::isspace(c) != 0; };

  const auto wsfront = std::find_if_not(s.begin(), s.end(), is_space);
  // .base() turns the reverse iterator into the forward iterator one past the
  // last non-space character.
  const auto wsback = std::find_if_not(s.rbegin(), s.rend(), is_space).base();

  if (wsback <= wsfront) {
    return "";
  }
  return s.substr(wsfront - s.begin(), wsback - wsfront);
}
|
|
|
|
// Returns a copy of `s` with every character mapped through std::tolower.
//
// Each character goes through unsigned char on the way in (std::tolower has
// undefined behaviour for negative arguments other than EOF) and is cast
// back to char explicitly on the way out instead of relying on an implicit
// int -> char narrowing.
inline std::string tolower(std::string_view s) {
  std::string result(s);
  for (char& ch : result) {
    ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
  }
  return result;
}
|
|
|
|
// Splits a comma-separated list into lower-cased, whitespace-trimmed tokens.
//
// The whole input is trimmed and lower-cased first, then cut at each ','.
// Empty tokens between separators are kept ("a,,b" -> {"a", "", "b"}), while
// a trailing separator does not add an extra empty token ("a," -> {"a"}).
// An empty or all-whitespace input yields an empty vector.
//
// Tokens are carved out of a string_view over the normalized string, which
// stays alive for the whole loop. Calling trim() on the temporary returned
// by std::string::substr would leave the resulting view pointing into a
// destroyed string; the view-based walk also avoids re-copying the remainder
// on every iteration.
inline std::vector<std::string> parse_list(std::string list) {
  std::vector<std::string> result;
  const std::string normalized = tolower(trim(list));
  std::string_view remaining = normalized;
  while (!remaining.empty()) {
    const auto end_pos = remaining.find(',');
    result.emplace_back(trim(remaining.substr(0, end_pos)));
    if (end_pos == std::string_view::npos) {
      break;
    }
    remaining.remove_prefix(end_pos + 1);
  }
  return result;
}
|
|
|
|
} // namespace c10d
|
|
|
|
#endif // USE_C10D_UCC
|