diff --git a/c10/util/Logging.cpp b/c10/util/Logging.cpp index 8dcf13ab9baa..68d288ada22c 100644 --- a/c10/util/Logging.cpp +++ b/c10/util/Logging.cpp @@ -139,6 +139,16 @@ void SetPyTorchDDPUsageLogger( *GetDDPUsageLogger() = std::move(logger); } +static int64_t GLOBAL_RANK = -1; + +int64_t GetGlobalRank() { + return GLOBAL_RANK; +} + +void SetGlobalRank(int64_t rank) { + GLOBAL_RANK = rank; +} + void LogAPIUsage(const std::string& event) try { if (auto logger = GetAPIUsageLogger()) (*logger)(event); @@ -352,6 +362,9 @@ MessageLogger::MessageLogger(const char* file, int line, int severity) std::chrono::duration_cast( std::chrono::high_resolution_clock::now().time_since_epoch()); */ + if (GLOBAL_RANK != -1) { + stream_ << "[rank" << GLOBAL_RANK << "]:"; + } stream_ << "[" << CAFFE2_SEVERITY_PREFIX[std::min(4, GLOG_FATAL - severity_)] //<< (timeinfo->tm_mon + 1) * 100 + timeinfo->tm_mday diff --git a/c10/util/Logging.h b/c10/util/Logging.h index 7d0ba5861874..9fa0465827e8 100644 --- a/c10/util/Logging.h +++ b/c10/util/Logging.h @@ -367,6 +367,9 @@ C10_API bool LogAPIUsageFakeReturn(const std::string& context); // Initializes the c10 logger. C10_API void initLogging(); +// Sets the rank, which will be included in log messages +C10_API void SetGlobalRank(int64_t rank); + } // namespace c10 #endif // C10_UTIL_LOGGING_H_ diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 1896f158d422..bc1d6a50f298 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -2621,6 +2621,19 @@ Example:: module.attr("_DEFAULT_PG_TIMEOUT") = py::cast(kProcessGroupDefaultTimeout); module.attr("_DEFAULT_NO_TIMEOUT") = py::cast(kNoTimeout); + module.def( + "_set_global_rank", + [](int64_t rank) { c10::SetGlobalRank(rank); }, + py::arg("rank"), + R"( + Arguments: + rank(int): The rank of the default process group + Informs the C++ runtime what the default process group (a strictly Python + notion) is. This mostly ensures that C++ log messages are prefixed with + rank information. This is not meant to be called manually; it is + called by _update_default_pg. + )"); + module.def( "_create_work_from_future", [](std::shared_ptr future) { diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 4ea68a5241fe..232fb42653d7 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -965,6 +965,8 @@ def _get_default_store(): def _update_default_pg(pg): _world.default_pg = pg + rank = pg.rank() if pg is not None and pg != GroupMember.NON_GROUP_MEMBER else -1 + torch._C._distributed_c10d._set_global_rank(rank) def get_backend_config(group: Optional[ProcessGroup] = None) -> str: if group is None: