# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import math
from typing import Any

from tools.flight_recorder.components.fr_logger import FlightRecorderLogger
from tools.flight_recorder.components.types import (
    Collective,
    EntryState,
    Group,
    MatchInfo,
    MatchState,
    MatchStateRecord,
    Membership,
    Op,
    P2P,
)


logger: FlightRecorderLogger = FlightRecorderLogger()


try:
    from tabulate import tabulate
except ModuleNotFoundError:
    logger.debug("tabulate is not installed. Proceeding without it.")


def format_frame(frame: dict[str, str]) -> str:
    name = frame["name"]
    filename = frame["filename"]
    line = frame["line"]
    return f"{name} at {filename}:{line}"


def format_frames(frames: list[dict[str, str]]) -> str:
    formatted_frames = []
    for frame in frames:
        formatted_frames.append(format_frame(frame))
    return "\n".join(formatted_frames)


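# Illustrative sketch (not part of the tool): the frame dicts below are
# hypothetical, but they show the shape `format_frames` expects and the
# newline-joined string it produces.
#
# >>> format_frames(
# ...     [
# ...         {"name": "all_reduce", "filename": "distributed_c10d.py", "line": "2077"},
# ...         {"name": "forward", "filename": "train.py", "line": "123"},
# ...     ]
# ... )
# 'all_reduce at distributed_c10d.py:2077\nforward at train.py:123'

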
def match_one_event(
    event_a: dict[Any, Any],
    event_b: dict[Any, Any],
    memberships: dict[str, set[Any]],
    pg_name: str,
) -> MatchInfo:
    op_a = Op(event_a, memberships, pg_name)
    op_b = Op(event_b, memberships, pg_name)
    return op_a.match(op_b)


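# Illustrative sketch (not part of the tool): `match_one_event` wraps two raw FR
# event dicts in `Op` and delegates to `Op.match`. `entry_rank0` and
# `entry_rank1` below are hypothetical event dicts taken from two ranks' dumps
# for the same collective.
#
# >>> info = match_one_event(entry_rank0, entry_rank1, memberships, "pg0")
# >>> info.state == MatchState.FULLY_MATCHED
# True

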
def match_coalesced_groups(
    all_rank_events: dict[Any, Any],
    group_size: int,
    groups: dict[str, Group],
    memberships: dict[str, set[Any]],
    _pg_guids: dict[tuple[str, int], str],
) -> bool:
    """
    all_rank_events: {
        rank: [
            (idx, event_dict)
        ]
    }

    Note: it is possible for event dicts in a coalesced group to be asymmetric.
        e.g. the following event lists form a valid coalescing group
             events0 [send:1]
             events1 [recv:0, send:2]
             events2 [recv:1]

    Rule 1: all ops should find a match.
    Rule 2: the relative ordering of sends and recvs in one event list can be arbitrary.
        e.g.
        events1 [recv:0, send:2] -> okay
        events1 [send:2, recv:0] -> also okay
    Rule 3: sends to the same dest (or recvs from the same src) should be in a consistent order.
        e.g.
        rank0 [send:1 (100B), send:1 (1000B)]
        rank1 [recv:0 (1000B), recv:0 (100B)] -> not okay
    """
    all_ops = {
        rank: [
            Op(e, memberships, _pg_guids[(e["process_group"][0], rank)])
            for _, e in all_rank_events[rank]
        ]
        for rank in all_rank_events
    }

    def visualize_ops(
        match: bool,
        _pg_guids: dict[tuple[str, int], str],
    ) -> None:
        all_ops = {
            rank: [
                Op(e, memberships, _pg_guids[(e["process_group"][0], rank)])
                for _, e in all_rank_events[rank]
            ]
            for rank in all_rank_events
        }

        i = 0
        row = []
        progress = True
        table = []
        while progress:
            progress = False
            for r in all_ops:
                if len(all_ops[r]) > i:
                    # Entries are (idx, event_dict) tuples; the rank is the key `r`.
                    _, event = all_rank_events[r][i]
                    row.append(
                        Op(
                            event,
                            memberships,
                            _pg_guids[(event["process_group"][0], r)],
                        )
                    )
                    progress = True
                else:
                    row.append(None)  # type: ignore[arg-type]
            table.append(row)
            row = []
            i += 1
        title = "Match" if match else "MISMATCH"
        logger.info("%s\n", title)
        logger.info("%s", tabulate(table))  # type: ignore[operator]

    # TODO: we can't verify seq_id because there might have been valid seq deltas
    # between ranks even within a pg.
    for op_list in all_ops.values():
        if not op_list:
            # TODO: not sure if it's valid for only some ranks in a PG to
            # participate in a coalesced op.
            return False
        assert op_list[-1].type == "coalesced"
        op_list.pop(-1)

    while all_ops:
        first_rank = next(iter(all_ops))
        my_ops = all_ops[first_rank]

        if len(all_ops[first_rank]) == 0:
            all_ops.pop(first_rank)
            continue

        # Match the first collective: we need to know which ranks are involved,
        # and ensure that this same collective is also the first one on those
        # ranks within that group.
        op = my_ops[0]
        match_idx = -1
        if op.type in P2P:
            dst_global_rank = sorted(memberships[op.pg_name])[op.dst]
            peer_ops = all_ops[dst_global_rank]
            for i, other in enumerate(peer_ops):
                if op.match(other).state == MatchState.FULLY_MATCHED:
                    match_idx = i
                    break
                elif op.dst == other.src:
                    # Rule 3
                    break
                else:
                    # Rule 1
                    continue
        else:
            raise NotImplementedError("coalesced collective ops")
        if match_idx >= 0:
            my_ops.pop(0)
            peer_ops.pop(match_idx)
        else:
            visualize_ops(False, _pg_guids)
            return False

    visualize_ops(True, _pg_guids)
    return True


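# Illustrative sketch (not part of the tool): a self-contained toy version of the
# ordered P2P matching loop above. `_ToyOp` and `_toy_p2p_match` are hypothetical
# stand-ins; the real `Op.match` also compares peers, dtypes, sizes, and state.
#
# from typing import NamedTuple
#
# class _ToyOp(NamedTuple):
#     kind: str   # "send" or "recv"
#     peer: int   # destination rank for sends, source rank for recvs
#     nbytes: int
#
# def _toy_p2p_match(op: _ToyOp, peer_ops: list[_ToyOp]) -> int:
#     """Return the index of the matching peer op, or -1 on a Rule 1/3 failure."""
#     for i, other in enumerate(peer_ops):
#         if other.kind == "recv" and other.nbytes == op.nbytes:
#             return i  # fully matched
#         if other.kind == "recv":
#             return -1  # Rule 3: same peer pairing, wrong order
#     return -1  # Rule 1: no match at all
#
# # rank0 sends 100B first, but rank1 expects 1000B first -> Rule 3 violation (-1).
# _toy_p2p_match(_ToyOp("send", 1, 100),
#                [_ToyOp("recv", 0, 1000), _ToyOp("recv", 0, 100)])

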
# We enabled creating FR entries for non-P2P slow-path collective ops in v2.7.
def match_coalesced_groups_with_non_p2p(
    all_rank_events: dict[Any, Any],
    pg_info: tuple[str, str],
    memberships: dict[str, set[Any]],
    _pg_guids: dict[tuple[str, int], str],
    mismatch: dict[str, int],
    dumps_ranks: set[int],
    version: str,
    collectives: list[Collective],
    match_record: MatchStateRecord,
) -> bool:
    """
    all_rank_events: {
        rank: [
            (idx, event_dict)
        ]
    }

    Note: it is possible for event dicts in a coalesced group to be asymmetric.
        e.g. the following event lists form a valid coalescing group
             events0 [send:1]
             events1 [recv:0, send:2]
             events2 [recv:1]

    Rule 1: all ops should find a match.
    Rule 2: the relative ordering of sends and recvs in one event list can be arbitrary.
        e.g.
        events1 [recv:0, send:2] -> okay
        events1 [send:2, recv:0] -> also okay
    Rule 3: sends to the same dest (or recvs from the same src) should be in a consistent order.
        e.g.
        rank0 [send:1 (100B), send:1 (1000B)]
        rank1 [recv:0 (1000B), recv:0 (100B)] -> not okay
    """
    all_ops = {
        rank: [
            Op(e, memberships, _pg_guids[(e["process_group"][0], rank)])
            for _, e in all_rank_events[rank]
        ]
        for rank in all_rank_events
    }
    is_p2p = any(op.type in P2P for ops in all_ops.values() for op in ops)
    pg_name = pg_info[0]

    def visualize_ops(
        match: bool,
        _pg_guids: dict[tuple[str, int], str],
    ) -> None:
        all_ops = {
            rank: [
                Op(e, memberships, _pg_guids[(e["process_group"][0], rank)])
                for _, e in all_rank_events[rank]
            ]
            for rank in all_rank_events
        }

        i = 0
        row = []
        progress = True
        table = []
        while progress:
            progress = False
            for r in all_ops:
                if len(all_ops[r]) > i:
                    # Entries are (idx, event_dict) tuples; the rank is the key `r`.
                    _, event = all_rank_events[r][i]
                    row.append(
                        Op(
                            event,
                            memberships,
                            _pg_guids[(event["process_group"][0], r)],
                        )
                    )
                    progress = True
                else:
                    row.append(None)  # type: ignore[arg-type]
            table.append(row)
            row = []
            i += 1
        title = "Match" if match else "MISMATCH"
        logger.info("%s\n", title)
        logger.info("%s", tabulate(table))  # type: ignore[operator]

    # TODO: we still need to verify that there are no seq_id deltas for P2P ops.
    for rank, op_list in all_ops.items():
        if not op_list:
            logger.error("Rank %s has an empty op list.", rank)
            continue
        if op_list[-1].type == "coalesced" and is_p2p:
            op_list.pop(-1)

    while all_ops:
        first_rank = next(iter(all_ops))
        my_ops = all_ops[first_rank]

        if len(all_ops[first_rank]) == 0:
            all_ops.pop(first_rank)
            continue

        # Match the first collective: we need to know which ranks are involved,
        # and ensure that this same collective is also the first one on those
        # ranks within that group.
        op = my_ops[0]
        match_idx = -1
        if is_p2p:
            dst_global_rank = sorted(memberships[op.pg_name])[op.dst]
            peer_ops = all_ops[dst_global_rank]
            for i, other in enumerate(peer_ops):
                if op.match(other).state == MatchState.FULLY_MATCHED:
                    match_idx = i
                    break
                elif op.dst == other.src:
                    # Rule 3
                    break
                else:
                    # Rule 1
                    continue
            if match_idx >= 0:
                my_ops.pop(0)
                peer_ops.pop(match_idx)
            else:
                visualize_ops(False, _pg_guids)
                return False
        else:
            all_coalesced_entries = {
                rank: [e for _, e in all_rank_events[rank]] for rank in all_rank_events
            }
            current_entry = all_coalesced_entries[first_rank][0]
            my_ops.pop(0)

            match_record.reset_for_coalesced(
                EntryState(current_entry, match_record.expected_ranks),
                {first_rank},
            )

            # Iterate through all the ranks and check whether there is a mismatch
            # for the current entry.
            check_current_entry_match(
                all_coalesced_entries,
                _pg_guids,
                pg_info,
                current_entry,
                memberships,
                mismatch,
                match_record,
            )

            # Use heuristics to decide which types of errors and error messages
            # we should print.
            error_analysis(
                all_coalesced_entries,
                match_record,
                dumps_ranks,
                first_rank,
                current_entry,
                mismatch,
                get_version_detail(version),
                pg_info[0],
            )

            # TODO: for now, we only check the correctness of each individual
            # collective within a coalesced one in this script. We still need to
            # merge them (e.g., input/output sizes) for downstream consumers.

            # At this point there are three possibilities:
            # 1. We found a match on all the ranks that are members of the group
            #    -> we create a Collective and remove the individual entries from
            #       their original lists.
            if (
                match_record.found_ranks == match_record.expected_ranks
                and mismatch[pg_name] == 0
            ):
                # Just pop out this collective.
                idx_map = {
                    r: match_record.found_idx[r] if r != first_rank else 0
                    for r in match_record.found_ranks
                }
                for i, k in idx_map.items():
                    all_rank_events[i].pop(k)
                for r in match_record.found_ranks:
                    if r != first_rank:
                        all_ops[r].pop(0)

            # 2. We found a partial match, but some ranks are missing.
            # 3. We found no match.
            #    -> since it's not a complete collective, no entry goes into
            #       `collectives`, but we still record an NCCL call.
            else:
                logger.debug("Non-matching collective inside coalesced group")
                idx_map = {
                    r: match_record.candidate_idx[r] if r != first_rank else 0
                    for r in match_record.candidate_ranks
                }
                collectives.append(
                    match_record.entry_state.to_collective(
                        len(collectives),
                        errors=match_record.errors,
                        idx_map=idx_map,
                        all_entries=all_coalesced_entries,
                    )
                )
                return False

    if is_p2p:
        visualize_ops(True, _pg_guids)
    return True


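# Illustrative sketch (not part of the tool): one non-P2P iteration of the loop
# above, with hypothetical data. Suppose both ranks coalesced the same two
# all-reduces:
#
#   all_rank_events = {0: [(3, ar_a), (4, ar_b)], 1: [(7, ar_a), (8, ar_b)]}
#
# The first iteration takes rank 0's `ar_a` as current_entry and resets
# match_record; check_current_entry_match then finds `ar_a` on rank 1, so
# found_ranks == expected_ranks. idx_map maps each rank to the position of its
# matched entry within the coalesced list (0 for first_rank), those entries are
# popped from all_rank_events, and the next iteration handles `ar_b`.

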
def check_size_alltoall(alltoall_cases: list[dict[str, Any]]) -> tuple[bool, int, int]:
    input_numel = 0
    output_numel = 0
    for e in alltoall_cases:
        input_numel += math.prod(e["input_sizes"][0])
        output_numel += math.prod(e["output_sizes"][0])
    return input_numel != output_numel, input_numel, output_numel


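# Illustrative sketch (not part of the tool): hypothetical alltoall entries where
# ranks disagree on payload size, so the summed input numel (4 + 4) differs from
# the summed output numel (4 + 8) and the check reports a mismatch.
#
# >>> check_size_alltoall(
# ...     [
# ...         {"input_sizes": [[2, 2]], "output_sizes": [[2, 2]]},
# ...         {"input_sizes": [[2, 2]], "output_sizes": [[2, 4]]},
# ...     ]
# ... )
# (True, 8, 12)

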
def check_current_entry_match(
    all_entries: dict[int, list[dict[str, Any]]],
    _pg_guids: dict[tuple[str, int], str],
    pg_info: tuple[str, str],
    current_entry: dict[str, Any],
    _memberships: dict[str, set[Any]],
    mismatch: dict[str, int],
    match_record: MatchStateRecord,
) -> None:
    pg_name, desc = pg_info[0], pg_info[1]
    for o in match_record.expected_ranks.intersection(set(match_record.other_ranks)):
        for i, e in enumerate(all_entries[o]):  # type: ignore[index]
            # Step over ops from other PGs; only check the match state when the
            # seq_id matches.
            if (
                _pg_guids[(e["process_group"][0], o)] == pg_name
                and e["process_group"][1] == desc
                and e["collective_seq_id"] == match_record.entry_state.collective_seq_id
            ):
                match_info = match_one_event(current_entry, e, _memberships, pg_name)
                if (
                    match_info.state in [MatchState.FULLY_MATCHED, MatchState.UNDECIDED]
                    and mismatch[pg_name] == 0
                ):
                    match_record.found_ranks.add(o)
                    match_record.found_idx[o] = i
                    match_record.has_undecided_case = (
                        match_info.state == MatchState.UNDECIDED
                    )
                else:
                    match_record.candidate_ranks.add(o)
                    match_record.candidate_idx[o] = i
                    if match_info.state not in [
                        MatchState.FULLY_MATCHED,
                        MatchState.UNDECIDED,
                    ]:
                        # Here we assume the current rank is not the source of the
                        # error. But it's possible that the current rank is the
                        # culprit; in that case users will see lots of normal
                        # ranks reported as culprits.
                        # TODO: we need to figure out a better way to handle the
                        # case mentioned above.
                        match_record.errors.add((o, match_info))
                break


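# Illustrative sketch (not part of the tool): the gating condition above,
# extracted into a hypothetical standalone predicate. Only entries from the same
# process group (by resolved guid and desc) with the same collective_seq_id are
# ever passed to match_one_event; everything else is stepped over.
#
# def _same_collective(e, o, pg_name, desc, seq_id, _pg_guids):
#     return (
#         _pg_guids[(e["process_group"][0], o)] == pg_name
#         and e["process_group"][1] == desc
#         and e["collective_seq_id"] == seq_id
#     )

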
def error_analysis(
    all_entries: dict[int, list[dict[str, Any]]],
    match_record: MatchStateRecord,
    dumps_ranks: set[int],
    first_rank: int,
    current_entry: dict[str, Any],
    mismatch: dict[str, int],
    version: tuple[int, int],
    pg_name: str,
) -> None:
    major_v, minor_v = version[0], version[1]
    # Case one: not every rank joined the collective, or some are missing from
    # the flight recorder dumps.
    if (
        match_record.candidate_ranks | match_record.found_ranks
    ) != match_record.expected_ranks and match_record.expected_ranks - (
        match_record.candidate_ranks | match_record.found_ranks
    ) <= dumps_ranks:
        mismatch[pg_name] += 1
        logger_msg = "Not all ranks joining collective, sequence number: %s"
        missing_ranks = match_record.expected_ranks - (
            match_record.candidate_ranks | match_record.found_ranks
        )
        match_record.entry_state.log(
            logger, logger_msg, format_frames, missing_ranks=missing_ranks
        )
        match_record.candidate_ranks.update(match_record.found_ranks)
        match_record.candidate_idx.update(match_record.found_idx)
        match_record.found_idx.clear()
        match_record.found_ranks.clear()
    elif (
        len(match_record.candidate_ranks) == 1
        and dumps_ranks == match_record.expected_ranks
    ):
        # Case two: alltoall or alltoall_base case.
        if match_record.has_undecided_case:
            alltoall_cases = [current_entry] + [
                all_entries[o][match_record.found_idx[o]]
                for o in match_record.found_ranks
            ]
            fail_check, total_input_numel, total_output_numel = check_size_alltoall(
                alltoall_cases
            )
            if major_v <= 2 and minor_v <= 3:
                # We don't log the input/output sizes for alltoall before v2.4,
                # so we don't consider the size mismatch as an error for now.
                fail_check = False
            if fail_check:
                # When we see errors in all_to_all, it's hard to tell which rank
                # is the source of the error.
                mismatch[pg_name] += 1
                logger_msg = (
                    "Input/output mismatch in the collective sequence number: %s"
                )
                match_record.entry_state.log(
                    logger,
                    logger_msg,
                    format_frames,
                    total_numel=(total_input_numel, total_output_numel),
                )
                match_record.candidate_ranks.update(match_record.found_ranks)
                match_record.candidate_idx.update(match_record.found_idx)
                match_record.found_idx.clear()
                match_record.found_ranks.clear()
                match_record.errors.add(
                    (first_rank, MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH))
                )
            else:
                match_record.found_ranks.update(match_record.candidate_ranks)
                match_record.found_idx.update(match_record.candidate_idx)
                match_record.candidate_idx.clear()
                match_record.candidate_ranks.clear()
        # Case three: all ranks joined and everything matches on all ranks.
        else:
            match_record.found_ranks.update(match_record.candidate_ranks)
            match_record.found_idx.update(match_record.candidate_idx)
            match_record.candidate_idx.clear()
            match_record.candidate_ranks.clear()
    # Case four: mismatch due to a different op type, a size mismatch, or a
    # state mismatch.
    elif len(match_record.errors) > 0:
        mismatch[pg_name] += 1
        logger_msg = "Collective sequence number: %s has errors"
        match_record.entry_state.log(
            logger, logger_msg, format_frames, errors=match_record.errors
        )
        match_record.candidate_ranks.update(match_record.found_ranks)
        match_record.candidate_idx.update(match_record.found_idx)
        match_record.found_idx.clear()
        match_record.found_ranks.clear()
    # Partial-analysis case: we cannot decide what's wrong with this collective
    # entry.
    else:
        match_record.candidate_ranks.update(match_record.found_ranks)
        match_record.candidate_idx.update(match_record.found_idx)
        match_record.found_idx.clear()
        match_record.found_ranks.clear()
        if match_record.expected_ranks - dumps_ranks:
            mismatch[pg_name] += 1
            logger.info(
                "We cannot decide what's wrong with this collective entry "
                "because we are missing FR dumps from ranks (%s), so we don't "
                "have enough information. If you want to debug further, use -j "
                "to dump all raw traces",
                str(match_record.expected_ranks - dumps_ranks),
            )
        else:
            logger.info(
                "No errors found for this collective entry. There could be some "
                "other reason why we see a collective timeout."
            )


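# Illustrative summary (not part of the tool) of how error_analysis classifies a
# collective entry, using the hypothetical shorthand C = candidate_ranks,
# F = found_ranks, E = expected_ranks:
#
#   (C | F) != E and missing ranks have dumps -> "not all ranks joined" mismatch
#   len(C) == 1 and dumps cover all of E:
#       undecided (alltoall) -> numel check decides match vs mismatch
#       otherwise            -> full match
#   errors recorded          -> type/size/state mismatch
#   anything else            -> partial analysis (missing dumps, or no error found)

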
def find_coalesced_group(
    pg_name: str,
    entries: list[dict[str, Any]],
    _pg_guids: dict[tuple[str, int], str],
    rank: int,
) -> list[tuple[int, dict[str, Any]]]:
    """Given a list of entries, if the collective_seq_id of the first entry matches
    that of subsequent ones, build and return a list of entries terminating in a
    'coalesced' op entry, all sharing a collective_seq_id.
    """
    found = []
    collective_seq_id = None
    for i, e in enumerate(entries):
        if _pg_guids[(e["process_group"][0], rank)] != pg_name:
            continue
        elif collective_seq_id is None:
            collective_seq_id = (
                e["p2p_seq_id"] if e["is_p2p"] else e["collective_seq_id"]
            )
            found.append((i, e))
        elif not e["is_p2p"] and e["collective_seq_id"] == collective_seq_id:
            found.append((i, e))
        elif e["is_p2p"] and e["p2p_seq_id"] == collective_seq_id:
            found.append((i, e))
        else:
            break

    if len(found) > 1:
        assert found[-1][1]["profiling_name"] == "nccl:coalesced"
        return found
    return []


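# Illustrative sketch (not part of the tool): hypothetical entries for rank 0,
# all on the process group whose guid resolves to "pg0". Two sends share
# p2p_seq_id 1 and are closed by an "nccl:coalesced" entry, so all three come
# back as one group:
#
# >>> entries = [
# ...     {"process_group": ("0", "default"), "is_p2p": True, "p2p_seq_id": 1,
# ...      "profiling_name": "nccl:send"},
# ...     {"process_group": ("0", "default"), "is_p2p": True, "p2p_seq_id": 1,
# ...      "profiling_name": "nccl:send"},
# ...     {"process_group": ("0", "default"), "is_p2p": True, "p2p_seq_id": 1,
# ...      "profiling_name": "nccl:coalesced"},
# ... ]
# >>> [i for i, _ in find_coalesced_group("pg0", entries, {("0", 0): "pg0"}, 0)]
# [0, 1, 2]

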
# We enabled creating FR entries for non-P2P slow-path collective ops in v2.7.
def find_coalesced_group_with_non_p2p(
    pg_name: str,
    entries: list[dict[str, Any]],
    _pg_guids: dict[tuple[str, int], str],
    rank: int,
) -> list[tuple[int, dict[str, Any]]]:
    """Given a list of entries, if the collective_seq_id of the first entry matches
    that of subsequent ones, build and return a list of entries terminating in a
    'coalesced' op entry, all sharing a collective_seq_id.
    """
    found = []
    collective_seq_id = None
    for i, e in enumerate(entries):
        if _pg_guids[(e["process_group"][0], rank)] != pg_name:
            continue
        elif collective_seq_id is None:
            collective_seq_id = (
                e["p2p_seq_id"] if e["is_p2p"] else e["collective_seq_id"]
            )
            found.append((i, e))
        elif not e["is_p2p"] and e["collective_seq_id"] == collective_seq_id:
            found.append((i, e))
        elif e["is_p2p"] and e["p2p_seq_id"] == collective_seq_id:
            found.append((i, e))
        else:
            break

    if len(found) > 1:
        name = found[-1][1]["profiling_name"]
        if name.startswith("nccl:") and not name.endswith("_coalesced"):
            logger.error("Rank %s does not have a coalesced end.", rank)
        return found
    return []


def just_print_entries(
    all_entries: dict[int, list[dict[str, Any]]],
    _groups: dict[str, Group],
    _memberships: dict[str, set[Any]],
    _pg_guids: dict[tuple[str, int], str],
    args: argparse.Namespace,
) -> None:
    rows = []
    ranks = sorted(all_entries.keys())
    headers = [
        f"Rank {rank}"
        for rank in ranks
        if args.selected_ranks is None or rank in args.selected_ranks
    ]
    progress = True
    while progress:
        progress = False
        row = []
        for rank in ranks:
            if args.selected_ranks is not None and rank not in args.selected_ranks:
                continue
            if len(all_entries[rank]) == 0:
                row.append("")
            else:
                entry = all_entries[rank].pop(0)
                pg_name = _pg_guids[(entry["process_group"][0], rank)]
                if (
                    args.pg_filters is None
                    or entry["process_group"][1] in args.pg_filters
                    or entry["process_group"][0] in args.pg_filters
                ):
                    row.append(str(Op(entry, _memberships, pg_name)))
                else:
                    row.append("")
                progress = True
        if progress:
            rows.append(row)

    logger.info(tabulate(rows, headers=headers))


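# Illustrative sketch (not part of the tool): a hypothetical rendering. Each row
# holds the next entry popped from every selected rank, so traces line up side
# by side and gaps show where a rank has fewer (or filtered-out) entries:
#
#   Rank 0                Rank 1
#   --------------------  --------------------
#   all_reduce(pg0, ...)  all_reduce(pg0, ...)
#   send(pg1, ...)

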
def check_no_missing_dump_files(
    entries: dict[int, Any], memberships: list[Membership]
) -> None:
    all_ranks = set()
    for membership in memberships:
        all_ranks.add(int(membership.global_rank))
    dumps_ranks = {int(key) for key in entries.keys()}
    assert dumps_ranks == all_ranks, (
        f"Missing dump files from ranks {all_ranks - dumps_ranks}"
    )


def check_version(version_by_ranks: dict[str, str], version: str) -> None:
    for rank, v in version_by_ranks.items():
        assert v == version, (
            f"Rank {rank} has a different version {v} from the given version {version}"
        )


def get_version_detail(version: str) -> tuple[int, int]:
    version_parts = version.split(".")
    assert len(version_parts) == 2, f"Invalid version {version}"
    major, minor = map(int, version_parts)
    return major, minor


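# Illustrative usage: version strings are expected to be "major.minor".
#
# >>> get_version_detail("2.7")
# (2, 7)

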
def align_trace_from_beginning(
    entries: dict[int, list[dict[str, Any]]],
) -> dict[int, list[dict[str, Any]]]:
    """
    Align the trace entries of all ranks by record ID.

    This function takes a dictionary mapping ranks to lists of trace entries as input.
    Each trace entry is a dictionary containing information about a collective operation,
    including its unique identifier (`record_id`, which increases monotonically as we
    write into the ring buffer). The function finds the largest starting point across
    all ranks by taking the maximum `record_id` of the first entry in each rank, then
    filters out any entries with `record_id` values less than that maximum starting
    point. It returns the updated dictionary of sorted and filtered trace entries.

    Args:
        entries (Dict[int, List[Dict[str, Any]]]): A dictionary mapping ranks to lists of trace entries.

    Returns:
        entries (Dict[int, List[Dict[str, Any]]]): Entries sorted by record ID and filtered by the maximum starting point.
    """

    maximum_starting_record_id = 0
    for rank in entries:
        # Although this is a ring buffer, the entries are already sorted by `record_id`
        # when dumping, so we just need to find the largest starting point. For example,
        # if the buffer has the following entries:
        # Rank 0: [0, 1, 2, 3, 4, 5, 6]
        # Rank 1: [1, 2, 3, 4, 5, 6, 7]
        # Rank 2: [2, 3, 4, 5, 6, 7, 8]
        # Rank 3: [0, 1, 2, 3, 4, 5, None]
        # Then we should start from record 2, not 0, because for any earlier record
        # we don't have complete entries from all ranks, so we need to ignore them.
        first_record_id = entries[rank][0]["record_id"]
        maximum_starting_record_id = max(maximum_starting_record_id, first_record_id)

    for rank in entries:
        entries[rank] = [
            entry
            for entry in entries[rank]
            if entry["record_id"] >= maximum_starting_record_id
        ]

    return entries
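

# Illustrative sketch (not part of the tool): hypothetical two-rank traces where
# rank 1's ring buffer already dropped record 0, so record 0 is trimmed from
# rank 0 as well and both traces start at the shared record_id 1.
#
# >>> align_trace_from_beginning(
# ...     {
# ...         0: [{"record_id": 0}, {"record_id": 1}, {"record_id": 2}],
# ...         1: [{"record_id": 1}, {"record_id": 2}],
# ...     }
# ... )
# {0: [{'record_id': 1}, {'record_id': 2}], 1: [{'record_id': 1}, {'record_id': 2}]}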