pytorch/tools/flight_recorder/components/builder.py
fduwjj bf5c7bf06d [FR] Fix the bug in FR script (e.g., checking all ranks dump check) (#134383)
We somehow convert the rank to a string, which makes the ranks check fail. This fix now converts them all to int.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134383
Approved by: https://github.com/c-p-i-o
2024-08-26 08:21:14 +00:00


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import ast
import sys
from typing import Any, Dict, List, Set, Tuple # type: ignore[attr-defined]
from tools.flight_recorder.components.types import (
Collective,
Database,
Group,
MatchState,
Membership,
NCCLCall,
Op,
Traceback,
)
from tools.flight_recorder.components.utils import (
check_no_missing_dump_files,
check_size_alltoall,
check_trace_from_beginning,
check_version,
find_coalesced_group,
just_print_entries,
match_coalesced_groups,
match_one_event,
)
try:
from tabulate import tabulate
except ModuleNotFoundError:
print("tabulate is not installed. Proceeding without it.")
# Define a no-op tabulate function
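    # (it simply returns the rows unformatted, so the verbose table printing in build_db still works)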
def tabulate(data: Any, headers: Any = None) -> Any: # type: ignore[misc]
return data
"""
Flat DB builder
"""
def build_groups_memberships(
pg_config: Any,
) -> Tuple[List[Group], Dict[Any, Group], List[Membership], Dict[str, Set[Any]]]:
"""
pg_config: {
global_rank: {
(pg_guid, desc, ranks)
}
}
`pg_guid` is a system generated id, but depending on the mode of PG creation it could be a globally incrementing int
or a hash of the ranks. See `_process_group_name` in distributed_c10d.py.
`desc` is provided by the user (optionally) and should be 'meaningful' (e.g. TP/PP/DP group)
`ranks` is a list of the 'global ranks' that are members of the PG.
(pg_guid, desc, ranks) tuples are appended lazily to the flight buffer when `getNCCLComm` is called on a PG and
the `enabled_` flag is true for that PG.
- the order of calling (init_process_group, new_group, etc) does not affect the order of the tuples in the list
Returns:
`groups`: a groups table where each row is a Group namedtuple.
`_groups`: a dict that is indexed by pg_guid with Group namedtuple as value.
`memberships`: a membership table where each row is a Membership namedtuple.
`_memberships`: a dict that is indexed by pg_guid with set of ranks (int) as value.
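
    Illustrative example (hypothetical ids/descs, shown only to make the shape concrete):

        pg_config = {
            0: {"0": {"desc": "default_pg", "ranks": "[0, 1]"}},
            1: {"0": {"desc": "default_pg", "ranks": "[0, 1]"}},
        }

    would yield one Group(id="0", desc="default_pg", size=2), one Membership row per
    (group, rank) pair, and the corresponding `_groups` / `_memberships` lookup dicts.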
"""
# flat lists for return
groups = []
memberships = []
# dicts for faster cross-rank validation
_groups = {}
_memberships = {}
for global_rank in pg_config:
for pg_guid in pg_config[global_rank]:
desc = pg_config[global_rank][pg_guid]["desc"]
ranks = ast.literal_eval(pg_config[global_rank][pg_guid]["ranks"])
if isinstance(ranks, str):
# TODO Bug in FR data format? ranks is '[0, 1,...]'
                ranks = ast.literal_eval(ranks)
if pg_guid not in _groups:
groups.append(Group(id=pg_guid, desc=desc, size=len(ranks)))
for rank in ranks:
memberships.append(Membership(group_id=pg_guid, global_rank=rank))
_groups[pg_guid] = groups[-1]
_memberships[pg_guid] = set(ranks)
else:
# validation across ranks
assert (
_groups[pg_guid].desc == desc
), f"mismatch in desc {_groups[pg_guid].desc} vs {desc}"
assert _memberships[pg_guid] == set(
ranks
), f"mismatch in membership {_memberships[pg_guid]} vs {set(ranks)}"
return groups, _groups, memberships, _memberships
def build_nccl_call(
entry: Dict[Any, Any],
id: int,
collective_id: Any,
group_id: int,
global_rank: Any,
) -> NCCLCall:
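    """Flatten one raw trace entry into an NCCLCall row; `traceback_id` is currently always 0."""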
return NCCLCall(
id=id,
collective_id=collective_id,
group_id=group_id, # type: ignore[arg-type]
global_rank=global_rank,
traceback_id=0, # type: ignore[arg-type]
collective_type=entry["profiling_name"],
sizes=entry["input_sizes"],
)
def build_collectives(
all_entries: Dict[str, List[Dict[str, Any]]],
_groups: Dict[str, Group],
_memberships: Dict[str, Set[Any]],
) -> Tuple[List[Traceback], List[Collective], List[NCCLCall]]:
"""
groups, memberships are the non-flat dicts that are indexable
all_entries is a raw dict from the original dumps:
all_entries: {
global_rank: [
{
record_id: ordered id of the event in the trace buffer
pg_id: ProcessGroupNCCL::uid_
*note: `pg_id` corresponds to nothing in groups table
process_group: (pg_name, desc)
                    *note: `pg_name`, `desc` correspond to `id`, `desc` in the groups table
collective_seq_id: ordered id for collective operations and coalesced group operations
p2p_seq_id: ordered id for point-to-point operations
op_id: ordered id including individual ops inside coalescing group
profiling_name: descriptive name of the operation
'time_created_ns',
'input_sizes',
'output_sizes',
'state',
'time_discovered_started_ns',
'time_discovered_completed_ns',
'retired',
'frames',
}
]
}
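
    An illustrative (hypothetical) record, trimmed to a few of the fields used below:

        {
            "process_group": ("0", "default_pg"),
            "collective_seq_id": 1,
            "profiling_name": "nccl:all_reduce",
            "input_sizes": [[1024]],
            "output_sizes": [[1024]],
            "state": "completed",
        }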
"""
tracebacks: List[Traceback] = []
collectives: List[Collective] = []
nccl_calls: List[NCCLCall] = []
# once we find one mismatch, we stop pairing up collectives since the pairing is possibly incorrect
# instead, just record the remaining ops as NCCLCalls
mismatch = {_groups[g].id: 0 for g in _groups}
MISMATCH_TAIL = 10
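    # once a process group accumulates more than MISMATCH_TAIL mismatches we give up and
    # abort (see the check at the bottom of the loop)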
"""
- it doesn't matter what order I put collectives/ncclops into their table. we can later on re-sort it by start time
- there could be multiple options for the "first" collective to pair up (rank 0,1 might do a bcast while rank 2,3 do a bcast)
- within a group, the first collective must be the same on all ranks in the group, then it can be marked as a
collective and removed
"""
while all_entries:
# we greedily match collectives, starting arbitrarily with the trace from the first rank
# later, if we exhaust the first rank, we continue with the next 'first rank'
rank_iter = iter(all_entries)
first_rank = next(rank_iter)
other_ranks = list(rank_iter)
if len(all_entries[first_rank]) == 0:
all_entries.pop(first_rank)
continue
# lets match the first collective! we need to know which ranks are involved, and ensure that this same
# collective is also the first one on those ranks within that group
entries = all_entries[first_rank]
pg_name, desc = entries[0]["process_group"]
profiling_name = entries[0]["profiling_name"]
collective_seq_id = entries[0]["collective_seq_id"]
expected_ranks = set(_memberships[pg_name])
candidate_ranks = {first_rank}
candidate_idx = {}
found_ranks = set()
found_idx = {}
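        # If the first entry is part of a coalesced group (e.g. batch_isend_irecv or the
        # coalescing manager), gather the whole coalesced group from every involved rank;
        # send/recv peers discovered along the way are added to the set of ranks to visit.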
if find_coalesced_group(pg_name, entries):
expected_ranks.add(first_rank)
done_ranks = set()
all_coalesced_entries = {}
while expected_ranks:
curr = expected_ranks.pop()
done_ranks.add(curr)
grp = (
find_coalesced_group(pg_name, all_entries[curr]) # type: ignore[index]
if curr in all_entries # type: ignore[comparison-overlap]
else []
)
all_coalesced_entries[curr] = grp
for index, entry in grp:
op = Op(entry, _memberships)
peer = None
if op.type == "send":
assert op._src_g == curr, (op._src_g, curr)
peer = op._dst_g
elif op.type == "recv":
assert op._dst_g == curr, (op._dst_g, curr)
peer = op._src_g
if peer and peer not in done_ranks:
expected_ranks.add(peer)
match = match_coalesced_groups(
all_coalesced_entries,
group_size=_groups[pg_name].size,
groups=_groups,
memberships=_memberships,
)
if match and mismatch[pg_name] == 0:
collectives.append(Collective(id=len(collectives), group_id=pg_name))
else:
mismatch[pg_name] += 1
for r in all_coalesced_entries:
reversed_calls = []
for i, _ in reversed(all_coalesced_entries[r]):
reversed_calls.append(
build_nccl_call(
all_entries[r].pop(i), # type: ignore[index]
id=len(nccl_calls),
collective_id=collectives[-1].id if match else None,
group_id=pg_name,
global_rank=r,
)
)
nccl_calls.extend(reversed(reversed_calls))
else:
has_undecided_case = False
errors = set()
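            # `has_undecided_case` is set when match_one_event returns UNDECIDED (e.g. all_to_all,
            # where per-rank size checks are deferred to check_size_alltoall below).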
for o in expected_ranks.intersection(set(other_ranks)):
for i, e in enumerate(all_entries[o]): # type: ignore[index]
# step over ops from other PGs
# only check match state when seq_id matches
if (
e["process_group"] == (pg_name, desc)
and e["collective_seq_id"] == collective_seq_id
):
match_state = match_one_event(entries[0], e, _memberships)
if (
match_state
in [MatchState.FULLY_MATCHED, MatchState.UNDECIDED]
and mismatch[pg_name] == 0
):
found_ranks.add(o)
found_idx[o] = i
has_undecided_case = match_state == MatchState.UNDECIDED
else:
candidate_ranks.add(o)
candidate_idx[o] = i
errors.add(match_state)
break
            # case one: not every rank joined the collective, or some ranks are missing from the flight recorder dump.
if (candidate_ranks | found_ranks) != expected_ranks:
mismatch[pg_name] += 1
print(
f"Not all ranks joining collective for group {pg_name}:{desc} collective {profiling_name}",
f"Missing ranks are {expected_ranks - (candidate_ranks | found_ranks)}",
)
elif len(candidate_ranks) == 1:
# case two: alltoall or alltoall_base case.
if has_undecided_case:
alltoall_cases = [entries[0]] + [
all_entries[o][found_idx[o]] for o in found_ranks
]
check_result, input_numel, output_numel = check_size_alltoall(
alltoall_cases
)
if not check_result:
mismatch[pg_name] += 1
print(
f"Input/output mismatch in the collective for group {pg_name}:{desc} collective {profiling_name}",
f"input_numel {input_numel} output_numel{output_numel}",
)
candidate_ranks.update(found_ranks)
candidate_idx.update(found_idx)
found_idx.clear()
found_ranks.clear()
else:
found_ranks.update(candidate_ranks)
found_idx.update(candidate_idx)
candidate_idx.clear()
candidate_ranks.clear()
# case three: all joined and everything matches on all ranks.
else:
found_ranks.update(candidate_ranks)
found_idx.update(candidate_idx)
candidate_idx.clear()
candidate_ranks.clear()
            # case four: mismatches due to differing op type, size, or state.
else:
error_msg = ", ".join(error.name for error in errors)
print(
f"Collective errors for group {pg_name}:{desc} collective {profiling_name}",
f"Found errors: {error_msg}",
)
candidate_ranks.update(found_ranks)
candidate_idx.update(found_idx)
found_idx.clear()
found_ranks.clear()
# at this point there are 3 possibilities
# 1. we found a match on all the ranks that are members of the group
# -> we create a Collective and remove the individual entries from their original lists
if found_ranks == expected_ranks and mismatch[pg_name] == 0:
collectives.append(Collective(id=len(collectives), group_id=pg_name))
for r in found_ranks:
i = found_idx[r] if r != first_rank else 0
nccl_calls.append(
build_nccl_call(
all_entries[r].pop(i), # type: ignore[index]
id=len(nccl_calls),
collective_id=collectives[-1].id,
group_id=pg_name,
global_rank=r,
)
)
# 2. we found a partial match but some ranks are missing
# 3. we found no match
        # -> since it's not a complete collective, no entry goes into collectives, but we still record an nccl call
# TODO should there be a way to mark 'mismatches'?
else:
print("appending a non-matching collective")
            # TODO: figure out a better way to report mismatches.
            # Also, should we record the seq id as well?
for r in candidate_ranks:
i = candidate_idx[r] if r != first_rank else 0
nccl_calls.append(
build_nccl_call(
all_entries[r].pop(i), # type: ignore[index]
id=len(nccl_calls),
collective_id=None,
group_id=pg_name,
global_rank=r,
)
)
if mismatch[pg_name] > MISMATCH_TAIL:
print(f"Too many mismatches for process_group {pg_name}:{desc}, aborting")
sys.exit(-1)
return tracebacks, collectives, nccl_calls
def build_db(details: Dict[str, Dict[str, Any]], args: argparse.Namespace) -> Database:
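    """
    Build the flattened Database from per-rank flight recorder dumps.

    `details` maps a dump name to that rank's parsed dump; each value must provide at least
    "rank", "entries", "version" and "pg_config". `args` must carry the `just_print_entries`
    and `verbose` flags from the command-line parser.
    """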
# temporary state used for building database
entries = {}
pg_config = {}
version = {}
for dump in details.values():
rank = dump["rank"]
entries[rank] = dump["entries"]
version[rank] = dump["version"]
pg_config[rank] = dump["pg_config"]
check_version(version)
check_trace_from_beginning(entries)
# flattened database
groups, _groups, memberships, _memberships = build_groups_memberships(pg_config)
print("built groups, memberships")
check_no_missing_dump_files(entries, memberships)
if args.just_print_entries:
just_print_entries(entries, _groups, _memberships)
sys.exit(0)
tracebacks, collectives, nccl_calls = build_collectives(
entries, _groups, _memberships
)
print("built collectives, nccl_calls")
if args.verbose:
print("Groups\n", tabulate(groups, headers=Group._fields))
print("Memberships\n", tabulate(memberships, headers=Membership._fields))
print("Collectives\n", tabulate(collectives, headers=Collective._fields))
print("NCCLCalls\n", tabulate(nccl_calls, headers=NCCLCall._fields))
db = Database(
tracebacks=tracebacks,
collectives=collectives,
ncclcalls=nccl_calls,
groups=groups,
memberships=memberships,
)
return db
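
# A minimal driver sketch (hypothetical: `load_dumps` stands in for whatever front-end code
# parses the per-rank dump files; only `build_db` comes from this module):
#
#   import argparse
#   from tools.flight_recorder.components.builder import build_db
#
#   details = load_dumps(...)  # {name: {"rank": ..., "entries": [...], "version": ..., "pg_config": {...}}}
#   args = argparse.Namespace(just_print_entries=False, verbose=True)
#   db = build_db(details, args)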