mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Flight Recorder][WP] Added mismatch tail as an arg (#162991)
Summary: Mismatch tail is used as a fixed variable and there are cases that there are more than 10 mismatches FR gives up producing results (e.g. https://fburl.com/ai_infra/7gjl5ucb). This diff added the mismatch tail in the parsed args so make this configuarble. Also tho the variable name is `mismatch_tail`(last 10) it is used as `mismatch_head` (the first 10). Updated it to be `num_mismatch_to_print` Test Plan: `buck2 run @//mode/opt //caffe2/fb/flight_recorder:fr_trace -- --mast_job_id aps-ctx_fm_pipeline_change-1c8ea38a94 --mast_job_version 0 --mast_job_attempt 2 --bucket tlcm_log_blob --world_size 128 --dump_file_name_offset 0 --allow-incomplete-ranks --num_mismatch_to_print 20 1>out 2>err` Confirm no error and output 20 mismatches. Rollback Plan: Differential Revision: D82335995 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162991 Approved by: https://github.com/fduwjj
This commit is contained in:
committed by
PyTorch MergeBot
parent
6c0fd747af
commit
2c45628813
@ -134,6 +134,7 @@ def build_collectives(
|
||||
_memberships: dict[str, set[Any]],
|
||||
_pg_guids: dict[tuple[str, int], str],
|
||||
version: str,
|
||||
mismatch_cap: int = 10,
|
||||
) -> tuple[list[Traceback], list[Collective], list[NCCLCall]]:
|
||||
"""
|
||||
groups, memberships are the non-flat dicts that are indexable
|
||||
@ -171,7 +172,6 @@ def build_collectives(
|
||||
# once we find one mismatch, we stop pairing up collectives since the pairing is possibly incorrect
|
||||
# instead, just record the remaining ops as NCCLCalls
|
||||
mismatch = {_groups[g].id: 0 for g in _groups}
|
||||
MISMATCH_TAIL = 10
|
||||
|
||||
# For best effort partial analysis.
|
||||
dumps_ranks = {int(key) for key in all_entries.keys()}
|
||||
@ -365,7 +365,7 @@ def build_collectives(
|
||||
)
|
||||
)
|
||||
|
||||
if mismatch[pg_name] > MISMATCH_TAIL:
|
||||
if mismatch[pg_name] > mismatch_cap:
|
||||
logger.error(
|
||||
"Too many mismatches for process_group %s: %s aborting", pg_name, desc
|
||||
)
|
||||
@ -412,7 +412,7 @@ def build_db(
|
||||
check_no_missing_dump_files(entries, memberships)
|
||||
|
||||
tracebacks, collectives, nccl_calls = build_collectives(
|
||||
entries, _groups, _memberships, _pg_guids, version
|
||||
entries, _groups, _memberships, _pg_guids, version, args.mismatch_cap
|
||||
)
|
||||
logger.debug("built collectives, nccl_calls")
|
||||
if args.verbose:
|
||||
|
@ -68,6 +68,12 @@ class JobConfig:
|
||||
self.parser.add_argument("-j", "--just_print_entries", action="store_true")
|
||||
self.parser.add_argument("-v", "--verbose", action="store_true")
|
||||
self.parser.add_argument("--print_stack_trace", action="store_true")
|
||||
self.parser.add_argument(
|
||||
"--mismatch_cap",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Maximum number of mismatches we print (from earliest).",
|
||||
)
|
||||
|
||||
def parse_args(
|
||||
self: "JobConfig", args: Optional[Sequence[str]]
|
||||
|
Reference in New Issue
Block a user