[Flight Recorder][WP] Added mismatch tail as an arg (#162991)

Summary:
Mismatch tail is used as a fixed variable and there are cases that there are more than 10 mismatches FR gives up producing results (e.g. https://fburl.com/ai_infra/7gjl5ucb). This diff added the mismatch tail in the parsed args so make this configuarble.

Also tho the variable name is `mismatch_tail`(last 10) it is used as `mismatch_head` (the first 10). Updated it to be `num_mismatch_to_print`

Test Plan:
`buck2 run @//mode/opt //caffe2/fb/flight_recorder:fr_trace -- --mast_job_id aps-ctx_fm_pipeline_change-1c8ea38a94 --mast_job_version 0 --mast_job_attempt 2 --bucket tlcm_log_blob --world_size 128 --dump_file_name_offset 0 --allow-incomplete-ranks --num_mismatch_to_print 20 1>out 2>err`
Confirm no error and output 20 mismatches.

Rollback Plan:

Differential Revision: D82335995

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162991
Approved by: https://github.com/fduwjj
This commit is contained in:
Phillip Liu
2025-09-16 04:46:05 +00:00
committed by PyTorch MergeBot
parent 6c0fd747af
commit 2c45628813
2 changed files with 9 additions and 3 deletions

View File

@ -134,6 +134,7 @@ def build_collectives(
_memberships: dict[str, set[Any]],
_pg_guids: dict[tuple[str, int], str],
version: str,
mismatch_cap: int = 10,
) -> tuple[list[Traceback], list[Collective], list[NCCLCall]]:
"""
groups, memberships are the non-flat dicts that are indexable
@ -171,7 +172,6 @@ def build_collectives(
# once we find one mismatch, we stop pairing up collectives since the pairing is possibly incorrect
# instead, just record the remaining ops as NCCLCalls
mismatch = {_groups[g].id: 0 for g in _groups}
MISMATCH_TAIL = 10
# For best effort partial analysis.
dumps_ranks = {int(key) for key in all_entries.keys()}
@ -365,7 +365,7 @@ def build_collectives(
)
)
if mismatch[pg_name] > MISMATCH_TAIL:
if mismatch[pg_name] > mismatch_cap:
logger.error(
"Too many mismatches for process_group %s: %s aborting", pg_name, desc
)
@ -412,7 +412,7 @@ def build_db(
check_no_missing_dump_files(entries, memberships)
tracebacks, collectives, nccl_calls = build_collectives(
entries, _groups, _memberships, _pg_guids, version
entries, _groups, _memberships, _pg_guids, version, args.mismatch_cap
)
logger.debug("built collectives, nccl_calls")
if args.verbose:

View File

@ -68,6 +68,12 @@ class JobConfig:
self.parser.add_argument("-j", "--just_print_entries", action="store_true")
self.parser.add_argument("-v", "--verbose", action="store_true")
self.parser.add_argument("--print_stack_trace", action="store_true")
self.parser.add_argument(
"--mismatch_cap",
type=int,
default=10,
help="Maximum number of mismatches we print (from earliest).",
)
def parse_args(
self: "JobConfig", args: Optional[Sequence[str]]