mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[PTD][c10d] Include PG status into flight recorder (#131268)
We are considering consolidating data source for logging and flight recorder so that we don't build multiple paths for debugging information. Before we do any merging, we want to first ensure that the PG status is also included in flight recorder. Also, we can leverage this information to validate our FR dump as well. Because the dump is not synced so we might potentially see some variants in the dump. Pull Request resolved: https://github.com/pytorch/pytorch/pull/131268 Approved by: https://github.com/shuqiangzhang
This commit is contained in:
@ -185,8 +185,9 @@ DEFINE_CONSTANT(version_key, "version");
|
||||
// (minor when adding fields, major when changing existing fields)
|
||||
// Also update both JSON and Pickle dumps to make use of the newly defined
|
||||
// field(s).
|
||||
DEFINE_CONSTANT(version_val, "2.2");
|
||||
DEFINE_CONSTANT(version_val, "2.3");
|
||||
DEFINE_CONSTANT(pg_config_key, "pg_config");
|
||||
DEFINE_CONSTANT(pg_status_key, "pg_status");
|
||||
DEFINE_CONSTANT(record_id_key, "record_id");
|
||||
DEFINE_CONSTANT(pg_id_key, "pg_id");
|
||||
DEFINE_CONSTANT(pg_name_key, "process_group");
|
||||
@ -644,6 +645,7 @@ struct NCCLTraceBuffer {
|
||||
size_t max_entries_ = 0;
|
||||
size_t next_ = 0;
|
||||
size_t id_ = 0;
|
||||
std::map<size_t, std::shared_ptr<ProcessGroupStatus>> all_pg_status_ = {};
|
||||
std::map<std::tuple<std::string, std::string>, std::vector<uint64_t>>
|
||||
pg_name_to_ranks_ = {};
|
||||
|
||||
@ -659,10 +661,15 @@ struct NCCLTraceBuffer {
|
||||
Event* start,
|
||||
Event* end,
|
||||
std::chrono::milliseconds timeout_ms,
|
||||
std::shared_ptr<ProcessGroupStatus> pg_status,
|
||||
bool isP2P) {
|
||||
if (!enabled_) {
|
||||
return std::nullopt;
|
||||
}
|
||||
if (all_pg_status_.find(pg_id) == all_pg_status_.end()) {
|
||||
// Current pg_status is not in FR.
|
||||
all_pg_status_[pg_id] = pg_status;
|
||||
}
|
||||
auto traceback =
|
||||
torch::CapturedTraceback::gather(true, true, capture_cpp_stack_);
|
||||
std::lock_guard<std::mutex> guard(mutex_);
|
||||
@ -1014,6 +1021,35 @@ struct NCCLTraceBuffer {
|
||||
return result;
|
||||
}
|
||||
|
||||
// dump pg_status
|
||||
const c10::Dict<c10::IValue, c10::IValue> getPgStatus() {
|
||||
auto all_pg_status = new_dict();
|
||||
for (const auto& [pg_id, status] : all_pg_status_) {
|
||||
auto pg_status = new_dict();
|
||||
pg_status.insert("last_enqueued_collective", status->lastEnqueuedSeq);
|
||||
pg_status.insert("last_started_collective", status->lastStartedSeq);
|
||||
pg_status.insert("last_completed_collective", status->lastCompletedSeq);
|
||||
all_pg_status.insert(std::to_string(pg_id), pg_status);
|
||||
}
|
||||
return all_pg_status;
|
||||
}
|
||||
|
||||
const std::map<std::string, std::map<std::string, std::string>>
|
||||
getPgStatusJson() {
|
||||
std::map<std::string, std::map<std::string, std::string>> result;
|
||||
for (const auto& [pg_id, status] : all_pg_status_) {
|
||||
auto pg_status = std::map<std::string, std::string>();
|
||||
pg_status["last_enqueued_collective"] =
|
||||
std::to_string(status->lastEnqueuedSeq);
|
||||
pg_status["last_started_collective"] =
|
||||
std::to_string(status->lastStartedSeq);
|
||||
pg_status["last_completed_collective"] =
|
||||
std::to_string(status->lastCompletedSeq);
|
||||
result[std::to_string(pg_id)] = pg_status;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string dump_json(
|
||||
const std::optional<std::unordered_map<
|
||||
std::string,
|
||||
@ -1023,6 +1059,7 @@ struct NCCLTraceBuffer {
|
||||
json result;
|
||||
result[version_key_str] = version_val_str;
|
||||
result[pg_config_key_str] = getPgConfigJson();
|
||||
result[pg_status_key_str] = getPgStatusJson();
|
||||
|
||||
// collective trace
|
||||
if (includeCollectives) {
|
||||
@ -1051,6 +1088,7 @@ struct NCCLTraceBuffer {
|
||||
// common values
|
||||
result.insert(version_key, version_val);
|
||||
result.insert(pg_config_key, getPgConfig());
|
||||
result.insert(pg_status_key, getPgStatus());
|
||||
|
||||
// collective trace
|
||||
if (includeCollectives) {
|
||||
|
Reference in New Issue
Block a user