[FR] Add stack_id and an optional print of stack_id to stack_trace mapping (#160119)

To better help users debug with Flight Recorder (FR), we add a stack_id to each entry and optionally print a mapping from stack_id to stack_trace (enabled by the new `--print_stack_trace` flag when entries are printed with `-j`/`--just_print_entries`).
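
As a rough illustration (not part of this PR), the optional output is simply a second tabulated view keyed by stack_id. Below is a minimal sketch with made-up frame stacks, using the same `tabulate` helper the analyzer already uses:

```python
# Toy sketch of the optional stack_id -> stack_trace table; the frame strings
# below are invented, and the real tool emits this table via logger.info().
from tabulate import tabulate

# Hypothetical mapping: stringified frame stacks -> their assigned stack_id.
stack_id_trace_map = {
    "['all_reduce', 'train_step', 'main']": 0,
    "['broadcast', 'load_state', 'main']": 1,
}

rows = [
    [str(stack_id), frames]
    for frames, stack_id in sorted(stack_id_trace_map.items(), key=lambda kv: kv[1])
]
print(tabulate(rows, headers=["stack_id", "frame_stack"]))
```

Per the check added in `just_print_entries`, the table is only printed when `--print_stack_trace` is set and the map is non-empty.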

Screenshots:

<img width="1029" height="529" alt="image" src="https://github.com/user-attachments/assets/8404a1d3-cc33-4f5f-971b-29609ec316c1" />

<img width="1620" height="358" alt="image" src="https://github.com/user-attachments/assets/3dd29c8c-ff68-41a2-acfd-e770036cfeb1" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160119
Approved by: https://github.com/H-Huang, https://github.com/wconstab
Author: fduwjj
Date: 2025-08-08 16:39:15 -07:00
Committed by: PyTorch MergeBot
Parent: ff0d56d035
Commit: 1c2cba17ea
4 changed files with 43 additions and 1 deletion

View File

@@ -24,6 +24,7 @@ from tools.flight_recorder.components.types import (
     Traceback,
 )
 from tools.flight_recorder.components.utils import (
+    add_stack_id_in_entries,
     align_trace_from_beginning,
     check_current_entry_match,
     check_no_missing_dump_files,
@@ -391,6 +392,9 @@ def build_db(
     # Ensure version is consistent across all ranks.
     check_version(version_by_ranks, version)
     entries = align_trace_from_beginning(entries)
+    stack_id_trace_map: dict[str, int] = {}
+    if args.just_print_entries:
+        entries, stack_id_trace_map = add_stack_id_in_entries(entries)

     # flattened database
     groups, _groups, memberships, _memberships, _pg_guids = build_groups_memberships(
@@ -402,7 +406,9 @@ def build_db(
     check_no_missing_dump_files(entries, memberships)

     if args.just_print_entries:
-        just_print_entries(entries, _groups, _memberships, _pg_guids, args)
+        just_print_entries(
+            entries, _groups, _memberships, _pg_guids, args, stack_id_trace_map
+        )
         sys.exit(0)

     tracebacks, collectives, nccl_calls = build_collectives(

View File

@@ -67,6 +67,7 @@ class JobConfig:
         )
         self.parser.add_argument("-j", "--just_print_entries", action="store_true")
         self.parser.add_argument("-v", "--verbose", action="store_true")
+        self.parser.add_argument("--print_stack_trace", action="store_true")

     def parse_args(
         self: "JobConfig", args: Optional[Sequence[str]]

View File

@@ -417,6 +417,7 @@ class Op:
         else:
             self.input_sizes, self.output_sizes = None, None
         self.collective_seq_id = event["collective_seq_id"]
+        self.stack_id = event.get("stack_id", -1)
         self.p2p_seq_id = event["p2p_seq_id"]
         self.input_dtypes = event["input_dtypes"]
         self.output_dtypes = event["output_dtypes"]
@@ -456,6 +457,7 @@ class Op:
             f"pg_name={self.pg_name}",
             f"pg_description={self.pg_desc}",
             f"pg_size={self.pg_size}",
+            f"stack_id={self.stack_id}",
             f"state={self.state}",
         )
         return f"{self.type}(%s)" % ", ".join(s for s in verbose_info if s)

View File

@@ -616,6 +616,7 @@ def just_print_entries(
     _memberships: dict[str, set[Any]],
     _pg_guids: dict[tuple[str, int], str],
     args: argparse.Namespace,
+    stack_id_trace_map: dict[str, int],
 ) -> None:
     rows = []
     ranks = sorted(all_entries.keys())
@@ -650,6 +651,17 @@ def just_print_entries(
     logger.info(tabulate(rows, headers=headers))

+    if stack_id_trace_map and args.print_stack_trace:
+        headers = ["stack_id", "frame_stack"]
+        rows = []
+        for frame, stack_id in sorted(
+            stack_id_trace_map.items(), key=lambda item: item[1]
+        ):
+            rows.append([str(stack_id), frame])
+        logger.info(tabulate(rows, headers=headers))
+

 def check_no_missing_dump_files(
     entries: dict[int, Any], memberships: list[Membership]
@@ -677,6 +689,27 @@ def get_version_detail(version: str) -> tuple[int, int]:
     return major, minor


+def add_stack_id_in_entries(
+    entries: dict[int, list[dict[str, Any]]],
+) -> tuple[dict[int, list[dict[str, Any]]], dict[str, int]]:
+    stack_id = 0
+    stack_id_trace_map = {}
+    for rank in entries:
+        for dump in entries[rank]:
+            if dump.get("frames", []):
+                frames = str(dump["frames"])
+                if frames not in stack_id_trace_map:
+                    stack_id_trace_map[frames] = stack_id
+                    dump["stack_id"] = stack_id
+                    stack_id += 1
+                else:
+                    dump["stack_id"] = stack_id_trace_map[frames]
+            else:
+                dump["stack_id"] = -1
+    return entries, stack_id_trace_map
+

 def align_trace_from_beginning(
     entries: dict[int, list[dict[str, Any]]],
 ) -> dict[int, list[dict[str, Any]]]:
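
For readers skimming the diff, here is a small standalone sketch (toy data, helper name mine) of the deduplication behavior `add_stack_id_in_entries` introduces: identical frame stacks across ranks share one stack_id, and entries that recorded no frames get the -1 sentinel.

```python
# Standalone sketch of the stack_id assignment added in this PR (toy data).
# `assign_stack_ids` mirrors add_stack_id_in_entries from the diff above; it is
# a copy for illustration, not the real implementation.
from typing import Any


def assign_stack_ids(
    entries: dict[int, list[dict[str, Any]]],
) -> tuple[dict[int, list[dict[str, Any]]], dict[str, int]]:
    stack_id = 0
    stack_id_trace_map: dict[str, int] = {}
    for rank in entries:
        for dump in entries[rank]:
            if dump.get("frames", []):
                frames = str(dump["frames"])
                if frames not in stack_id_trace_map:
                    # First time this frame stack is seen: mint a new id.
                    stack_id_trace_map[frames] = stack_id
                    dump["stack_id"] = stack_id
                    stack_id += 1
                else:
                    # Same frame stack seen before (possibly on another rank).
                    dump["stack_id"] = stack_id_trace_map[frames]
            else:
                # No frames recorded for this entry.
                dump["stack_id"] = -1
    return entries, stack_id_trace_map


# Two ranks with the same frame stack for their first entry, plus one
# frameless entry on rank 0 (all made-up data).
toy_entries = {
    0: [{"frames": ["all_reduce", "train_step"]}, {"frames": []}],
    1: [{"frames": ["all_reduce", "train_step"]}],
}
toy_entries, mapping = assign_stack_ids(toy_entries)
print(mapping)                        # {"['all_reduce', 'train_step']": 0}
print(toy_entries[0][1]["stack_id"])  # -1
print(toy_entries[1][0]["stack_id"])  # 0
```

Keying the map on `str(frames)` means any difference in the recorded frames yields a new id, while repeated collectives from the same call site collapse to a single row in the optional table.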