From 1c2cba17eab2b09d87142883da2bdbdbcf018613 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Fri, 8 Aug 2025 16:39:15 -0700 Subject: [PATCH] [FR] Add stack_id and an optional print of stack_id to stack_trace mapping (#160119) To better help users debug with FR, we want to add stack_id and print a map between stack_id and stack_trace (optional) Screenshot: image image Pull Request resolved: https://github.com/pytorch/pytorch/pull/160119 Approved by: https://github.com/H-Huang, https://github.com/wconstab --- tools/flight_recorder/components/builder.py | 8 ++++- .../components/config_manager.py | 1 + tools/flight_recorder/components/types.py | 2 ++ tools/flight_recorder/components/utils.py | 33 +++++++++++++++++++ 4 files changed, 43 insertions(+), 1 deletion(-) diff --git a/tools/flight_recorder/components/builder.py b/tools/flight_recorder/components/builder.py index 2a9cee36f7bc..4bc268022e28 100644 --- a/tools/flight_recorder/components/builder.py +++ b/tools/flight_recorder/components/builder.py @@ -24,6 +24,7 @@ from tools.flight_recorder.components.types import ( Traceback, ) from tools.flight_recorder.components.utils import ( + add_stack_id_in_entries, align_trace_from_beginning, check_current_entry_match, check_no_missing_dump_files, @@ -391,6 +392,9 @@ def build_db( # Ensure version is consistent across all ranks. check_version(version_by_ranks, version) entries = align_trace_from_beginning(entries) + stack_id_trace_map: dict[str, int] = {} + if args.just_print_entries: + entries, stack_id_trace_map = add_stack_id_in_entries(entries) # flattened database groups, _groups, memberships, _memberships, _pg_guids = build_groups_memberships( @@ -402,7 +406,9 @@ def build_db( check_no_missing_dump_files(entries, memberships) if args.just_print_entries: - just_print_entries(entries, _groups, _memberships, _pg_guids, args) + just_print_entries( + entries, _groups, _memberships, _pg_guids, args, stack_id_trace_map + ) sys.exit(0) tracebacks, collectives, nccl_calls = build_collectives( diff --git a/tools/flight_recorder/components/config_manager.py b/tools/flight_recorder/components/config_manager.py index ea9b0cf3918c..abd7f5372133 100644 --- a/tools/flight_recorder/components/config_manager.py +++ b/tools/flight_recorder/components/config_manager.py @@ -67,6 +67,7 @@ class JobConfig: ) self.parser.add_argument("-j", "--just_print_entries", action="store_true") self.parser.add_argument("-v", "--verbose", action="store_true") + self.parser.add_argument("--print_stack_trace", action="store_true") def parse_args( self: "JobConfig", args: Optional[Sequence[str]] diff --git a/tools/flight_recorder/components/types.py b/tools/flight_recorder/components/types.py index 597ee8e3ceda..ded30fb077cd 100644 --- a/tools/flight_recorder/components/types.py +++ b/tools/flight_recorder/components/types.py @@ -417,6 +417,7 @@ class Op: else: self.input_sizes, self.output_sizes = None, None self.collective_seq_id = event["collective_seq_id"] + self.stack_id = event.get("stack_id", -1) self.p2p_seq_id = event["p2p_seq_id"] self.input_dtypes = event["input_dtypes"] self.output_dtypes = event["output_dtypes"] @@ -456,6 +457,7 @@ class Op: f"pg_name={self.pg_name}", f"pg_description={self.pg_desc}", f"pg_size={self.pg_size}", + f"stack_id={self.stack_id}", f"state={self.state}", ) return f"{self.type}(%s)" % ", ".join(s for s in verbose_info if s) diff --git a/tools/flight_recorder/components/utils.py b/tools/flight_recorder/components/utils.py index 73ec2a13d3be..b68266c79b2c 100644 --- a/tools/flight_recorder/components/utils.py +++ b/tools/flight_recorder/components/utils.py @@ -616,6 +616,7 @@ def just_print_entries( _memberships: dict[str, set[Any]], _pg_guids: dict[tuple[str, int], str], args: argparse.Namespace, + stack_id_trace_map: dict[str, int], ) -> None: rows = [] ranks = sorted(all_entries.keys()) @@ -650,6 +651,17 @@ def just_print_entries( logger.info(tabulate(rows, headers=headers)) + if stack_id_trace_map and args.print_stack_trace: + headers = ["stack_id", "frame_stack"] + rows = [] + + for frame, stack_id in sorted( + stack_id_trace_map.items(), key=lambda item: item[1] + ): + rows.append([str(stack_id), frame]) + + logger.info(tabulate(rows, headers=headers)) + def check_no_missing_dump_files( entries: dict[int, Any], memberships: list[Membership] @@ -677,6 +689,27 @@ def get_version_detail(version: str) -> tuple[int, int]: return major, minor +def add_stack_id_in_entries( + entries: dict[int, list[dict[str, Any]]], +) -> tuple[dict[int, list[dict[str, Any]]], dict[str, int]]: + stack_id = 0 + stack_id_trace_map = {} + for rank in entries: + for dump in entries[rank]: + if dump.get("frames", []): + frames = str(dump["frames"]) + if frames not in stack_id_trace_map: + stack_id_trace_map[frames] = stack_id + dump["stack_id"] = stack_id + stack_id += 1 + else: + dump["stack_id"] = stack_id_trace_map[frames] + else: + dump["stack_id"] = -1 + + return entries, stack_id_trace_map + + def align_trace_from_beginning( entries: dict[int, list[dict[str, Any]]], ) -> dict[int, list[dict[str, Any]]]: