Include rank of default PG in C++ log messages (#110623)

I tested this by adding some warning logs in C++, running a distributed program, and confirming that the messages now had a `[rank0]:` prefix. There is no existing test infra for C++ logging, so I couldn't easily add a unit test.

The implementation strategy is to set up a global variable in C++, and then poke it when we initialize a process group. This was the simplest thing I could think of that would work.

This PR only works for non-glog logging. We probably need to come up with some other strategy for glog, e.g., a custom prefix, but we need to make sure this doesn't conflict with fbcode. I can't easily test this from OSS, so I will leave it as follow-up work.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110623
Approved by: https://github.com/voznesenskym, https://github.com/wanchaol, https://github.com/fduwjj
This commit is contained in:
Edward Z. Yang
2023-10-09 15:19:41 -04:00
committed by PyTorch MergeBot
parent 0341deb1c7
commit de3ae93e9b
4 changed files with 31 additions and 0 deletions

View File

@ -139,6 +139,16 @@ void SetPyTorchDDPUsageLogger(
*GetDDPUsageLogger() = std::move(logger);
}
// Rank most recently stored through SetGlobalRank(). The initial value -1 is
// the "no rank" sentinel.
// NOTE(review): this is a plain, non-atomic integer; presumably it is written
// before any concurrent logging begins -- confirm against callers.
static int64_t GLOBAL_RANK{-1};

// Returns the currently stored rank, or -1 if none has been stored.
int64_t GetGlobalRank() {
  const int64_t current_rank = GLOBAL_RANK;
  return current_rank;
}

// Overwrites the stored rank with `rank`. Storing -1 restores the "no rank"
// sentinel.
void SetGlobalRank(int64_t rank) {
  GLOBAL_RANK = rank;
}
void LogAPIUsage(const std::string& event) try {
if (auto logger = GetAPIUsageLogger())
(*logger)(event);
@ -352,6 +362,9 @@ MessageLogger::MessageLogger(const char* file, int line, int severity)
std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::high_resolution_clock::now().time_since_epoch());
*/
// Emit a "[rank<N>]:" tag ahead of the severity prefix when a rank is stored.
// GLOBAL_RANK == -1 is the "no rank" sentinel, in which case nothing is added.
if (GLOBAL_RANK != -1) {
stream_ << "[rank" << GLOBAL_RANK << "]:";
}
stream_ << "["
<< CAFFE2_SEVERITY_PREFIX[std::min(4, GLOG_FATAL - severity_)]
//<< (timeinfo->tm_mon + 1) * 100 + timeinfo->tm_mday

View File

@ -367,6 +367,9 @@ C10_API bool LogAPIUsageFakeReturn(const std::string& context);
// Initializes the c10 logger.
C10_API void initLogging();
// Sets the rank, which will be included in log messages as a "[rank<N>]:"
// prefix. Passing -1 clears it: messages are then written without a rank
// prefix.
// NOTE(review): per the change that introduced this, only the non-glog
// logger honors the rank -- confirm before relying on it under glog.
C10_API void SetGlobalRank(int64_t rank);
} // namespace c10
#endif // C10_UTIL_LOGGING_H_

View File

@ -2621,6 +2621,19 @@ Example::
module.attr("_DEFAULT_PG_TIMEOUT") = py::cast(kProcessGroupDefaultTimeout);
module.attr("_DEFAULT_NO_TIMEOUT") = py::cast(kNoTimeout);
// Python hook for c10::SetGlobalRank, bound directly by function pointer.
module.def(
    "_set_global_rank",
    &c10::SetGlobalRank,
    py::arg("rank"),
    R"(
Arguments:
rank(int): The rank of the default process group
Informs the C++ runtime what the default process group (a strictly Python
notion) is. This mostly ensures that C++ log messages are prefixed with
rank information. This is not meant to be called manually; it is
called by _update_default_pg.
)");
module.def(
"_create_work_from_future",
[](std::shared_ptr<jit::PythonFutureWrapper> future) {

View File

@ -965,6 +965,8 @@ def _get_default_store():
def _update_default_pg(pg):
    """Install ``pg`` as the default process group and publish its rank to C++.

    The rank passed to ``_set_global_rank`` is ``pg.rank()`` when ``pg`` is a
    real group, and ``-1`` when ``pg`` is ``None`` or the
    ``GroupMember.NON_GROUP_MEMBER`` placeholder.
    """
    _world.default_pg = pg

    # Start from the "no rank" sentinel and only query real groups.
    rank = -1
    if pg is not None and pg != GroupMember.NON_GROUP_MEMBER:
        rank = pg.rank()

    torch._C._distributed_c10d._set_global_rank(rank)
def get_backend_config(group: Optional[ProcessGroup] = None) -> str:
if group is None: