mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Flight Recorder] Add more basic analysis to the script (#133412)
This is the first step toward ensuring we have a basically functioning analyzer for FR (Flight Recorder) in production. - We want to use this script to find abnormalities in collectives and report them to users. - We also fixed some type errors. - [Ongoing] We will add more unit tests to this script and modularize it so that we can maintain it better. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133412 Approved by: https://github.com/c-p-i-o
This commit is contained in:
82
test/distributed/flight_recorder/test_fr_analysis.py
Normal file
82
test/distributed/flight_recorder/test_fr_analysis.py
Normal file
@ -0,0 +1,82 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
|
||||
from tools.flight_recorder.fr_trace import match_one_event, MatchState
|
||||
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
def create_one_event(
    collective_name,
    pg_info,
    input_sizes,
    output_sizes,
    state="scheduled",
    collective_seq_id=0,
    p2p_seq_id=0,
):
    """Build a minimal flight-recorder event dict for testing.

    Fix: the first parameter was misspelled ``collectcive_name``; every call
    site in this file passes it positionally, so renaming it is safe.

    Args:
        collective_name: NCCL op name; stored as ``f"nccl:{collective_name}"``.
        pg_info: ``(pg_name, pg_desc)`` tuple stored under ``"process_group"``.
        input_sizes: list of input tensor shapes, e.g. ``[[4, 4]]``.
        output_sizes: list of output tensor shapes.
        state: op progress state (default ``"scheduled"``).
        collective_seq_id: stored stringified, matching the FR dump format.
        p2p_seq_id: stored stringified, matching the FR dump format.

    Returns:
        A dict shaped like one flight-recorder trace entry.
    """
    return {
        "profiling_name": f"nccl:{collective_name}",
        "state": state,
        "process_group": pg_info,
        "input_sizes": input_sizes,
        "output_sizes": output_sizes,
        # The FR dump format stores sequence ids as strings.
        "collective_seq_id": str(collective_seq_id),
        "p2p_seq_id": str(p2p_seq_id),
    }
|
||||
|
||||
|
||||
class FlightRecorderEventTest(TestCase):
    """Unit tests for `match_one_event` from the flight-recorder analyzer."""

    def test_match_one_event(self):
        # Baseline event; comparing it against itself must fully match.
        e1 = create_one_event(
            "all_reduce", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
        )
        # pg name -> set of member global ranks.
        membership = {"0": {0, 1}}
        self.assertEqual(match_one_event(e1, e1, membership), MatchState.FULLY_MATCHED)

        # Different collective type (all_gather vs all_reduce).
        e2 = create_one_event(
            "all_gather", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
        )
        self.assertEqual(
            match_one_event(e1, e2, membership), MatchState.COLLECTIVE_TYPE_MISMATCH
        )

        # alltoall cannot be decided from two ranks alone -> UNDECIDED.
        e3 = create_one_event(
            "alltoall", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
        )
        e4 = create_one_event(
            "alltoall", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
        )
        self.assertEqual(match_one_event(e3, e4, membership), MatchState.UNDECIDED)

        # Input sizes differ from e1's -> size mismatch.
        e5 = create_one_event(
            "all_reduce", ("0", "default"), [[5, 4]], [[4, 4]], "scheduled", 1, 1
        )
        self.assertEqual(
            match_one_event(e1, e5, membership), MatchState.SIZE_OR_SYNTAX_MISMATCH
        )

        # Output sizes differ from e1's -> size mismatch.
        e6 = create_one_event(
            "all_reduce", ("0", "default"), [[4, 4]], [[5, 4]], "scheduled", 1, 2
        )
        self.assertEqual(
            match_one_event(e1, e6, membership), MatchState.SIZE_OR_SYNTAX_MISMATCH
        )

        # all_reduce with input != output sizes is self-inconsistent (syntax).
        e7 = create_one_event(
            "all_reduce", ("0", "default"), [[4, 4]], [[5, 4]], "scheduled", 2
        )
        self.assertEqual(
            match_one_event(e7, e7, membership), MatchState.SIZE_OR_SYNTAX_MISMATCH
        )

        # Same op but different progress state (completed vs scheduled).
        e9 = create_one_event(
            "all_reduce", ("0", "default"), [[4, 4]], [[4, 4]], "completed", 1
        )
        self.assertEqual(
            match_one_event(e1, e9, membership), MatchState.COLLECTIVE_STATE_MISMATCH
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Standard PyTorch test entry point; allows running this file directly.
    run_tests()
|
@ -28,9 +28,12 @@ python fr_trace.py -d <dump dir containing trace files> [-o <output file>]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
from enum import Enum
|
||||
from typing import ( # type: ignore[attr-defined]
|
||||
_eval_type,
|
||||
Any,
|
||||
@ -38,6 +41,7 @@ from typing import ( # type: ignore[attr-defined]
|
||||
Generic,
|
||||
List,
|
||||
NamedTuple,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
@ -151,6 +155,9 @@ TODO
|
||||
|
||||
"""
|
||||
Collective Matching logic
|
||||
|
||||
NOTE: For now, these collectives need to be supported by NCCL,
|
||||
https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/overview.html.
|
||||
"""
|
||||
COLLECTIVES = {
|
||||
"broadcast",
|
||||
@ -162,6 +169,8 @@ COLLECTIVES = {
|
||||
"_reduce_scatter_base",
|
||||
"gather",
|
||||
"scatter",
|
||||
"alltoall_base",
|
||||
"alltoall",
|
||||
}
|
||||
|
||||
P2P = {
|
||||
@ -170,6 +179,50 @@ P2P = {
|
||||
}
|
||||
|
||||
|
||||
class MatchState(Enum):
    """Outcome of comparing two collective/p2p trace entries.

    Members:
        FULLY_MATCHED: every checked aspect of the two operations agrees.
        COLLECTIVE_TYPE_MISMATCH: the two entries are different collectives.
        SIZE_OR_SYNTAX_MISMATCH: input/output sizes disagree, or the
            collective's own size invariants are violated.
        COLLECTIVE_STATE_MISMATCH: progress states differ, e.g. one rank
            completed while another is still scheduled.
        UNDECIDED: cannot be decided from this pair alone, e.g. alltoall
            requires inspecting all ranks.
    """

    FULLY_MATCHED = 1
    COLLECTIVE_TYPE_MISMATCH = 2
    SIZE_OR_SYNTAX_MISMATCH = 3
    COLLECTIVE_STATE_MISMATCH = 4
    UNDECIDED = 5
|
||||
|
||||
|
||||
def check_size_even_expand(list1: List[Any], list2: List[Any], size: int) -> bool:
    """Check that ``list1`` is ``list2`` evenly expanded by ``size``.

    Walks the two lists in lockstep. Elements whose ratio is exactly 1 are
    ignored; at most ONE element may differ, and its ratio must equal
    ``size``. Any other non-unit ratio (wrong factor, or a second expanded
    element) fails the check.

    Returns True when the expansion pattern holds (including for empty
    inputs), False otherwise.
    """
    expanded_ratio = None
    for left, right in zip(list1, list2):
        ratio = int(left) / int(right)
        if ratio == 1:
            continue  # dimension unchanged
        if ratio != size:
            return False  # expanded by the wrong factor
        if expanded_ratio is not None:
            return False  # more than one expanded dimension
        expanded_ratio = ratio
    return True
|
||||
|
||||
|
||||
def check_size_alltoall(alltoall_cases: List[Dict[str, Any]]) -> Tuple[bool, int, int]:
    """Verify that total input numel equals total output numel for alltoall.

    Each entry contributes the numel of its first input shape and first
    output shape (``e["input_sizes"][0]`` / ``e["output_sizes"][0]``).

    Returns:
        ``(matches, input_numel, output_numel)`` where ``matches`` is True
        iff the two totals are equal (trivially True for an empty list).
    """
    total_input = sum(math.prod(e["input_sizes"][0]) for e in alltoall_cases)
    total_output = sum(math.prod(e["output_sizes"][0]) for e in alltoall_cases)
    return total_input == total_output, total_input, total_output
|
||||
|
||||
|
||||
class Op:
|
||||
"""Parses relevant info about operation out of 'event' dict
|
||||
|
||||
@ -179,7 +232,7 @@ class Op:
|
||||
nccl:recv 3<-0
|
||||
"""
|
||||
|
||||
def __init__(self, event: Dict[Any, Any], memberships: Dict[str, List[Membership]]):
|
||||
def __init__(self, event: Dict[Any, Any], memberships: Dict[str, Set[Any]]):
|
||||
profiling_name = event["profiling_name"]
|
||||
nccl, name = profiling_name.split(":")
|
||||
assert nccl == "nccl", f"name formatting error? {nccl} != 'nccl'"
|
||||
@ -204,6 +257,7 @@ class Op:
|
||||
self._src, self._dst = -1, -1
|
||||
pg_name, pg_desc = event["process_group"]
|
||||
self._init_global_src_dst(memberships[pg_name])
|
||||
self.pg_size = len(memberships[pg_name])
|
||||
|
||||
if type in P2P | COLLECTIVES:
|
||||
self.input_sizes = event["input_sizes"]
|
||||
@ -213,7 +267,7 @@ class Op:
|
||||
self.collective_seq_id = event["collective_seq_id"]
|
||||
self.p2p_seq_id = event["p2p_seq_id"]
|
||||
|
||||
def _init_global_src_dst(self, pg_ranks: List[Membership]) -> None:
|
||||
def _init_global_src_dst(self, pg_ranks: Set[Any]) -> None:
|
||||
pg_ranks = sorted(pg_ranks)
|
||||
self._src_g = pg_ranks[self._src] if self._src is not None else None
|
||||
self._dst_g = pg_ranks[self._dst] if self._dst is not None else None
|
||||
@ -235,42 +289,79 @@ class Op:
|
||||
)
|
||||
return f"{self.type}(input_sizes={self.input_sizes}, {self.state})"
|
||||
|
||||
def match(self, other) -> bool: # type: ignore[no-untyped-def]
|
||||
def match(self, other: "Op") -> MatchState:
|
||||
# TODO: I think this can validly not match,
|
||||
# e.g. if one PG was used for p2p ops between only some of the peers?
|
||||
# if self.seq_id != other.seq_id:
|
||||
# return False
|
||||
|
||||
if self.type == "send":
|
||||
return bool(
|
||||
other.type == "recv"
|
||||
and self.src == other.src
|
||||
and self.dst == other.dst
|
||||
and self.input_sizes == other.output_sizes
|
||||
# TODO: We need more states for p2p ops.
|
||||
return (
|
||||
MatchState.FULLY_MATCHED
|
||||
if (
|
||||
other.type == "recv"
|
||||
and self.src == other.src
|
||||
and self.dst == other.dst
|
||||
and self.input_sizes == other.output_sizes
|
||||
)
|
||||
else MatchState.SIZE_OR_SYNTAX_MISMATCH
|
||||
)
|
||||
elif self.type == "recv":
|
||||
return bool(
|
||||
other.type == "send"
|
||||
and self.src == other.src
|
||||
and self.dst == other.dst
|
||||
and self.output_sizes == other.input_sizes
|
||||
return (
|
||||
MatchState.FULLY_MATCHED
|
||||
if (
|
||||
other.type == "send"
|
||||
and self.src == other.src
|
||||
and self.dst == other.dst
|
||||
and self.output_sizes == other.input_sizes
|
||||
)
|
||||
else MatchState.SIZE_OR_SYNTAX_MISMATCH
|
||||
)
|
||||
elif self.type in COLLECTIVES:
|
||||
return bool(
|
||||
self.type == other.type and self.input_sizes == other.input_sizes
|
||||
)
|
||||
# TODO(whc) - output sizes dont have to match for e.g. gather, not sure if they ever have to match?
|
||||
# and self.output_sizes == other.output_sizes)
|
||||
if self.type != other.type:
|
||||
return MatchState.COLLECTIVE_TYPE_MISMATCH
|
||||
if self.type in ["alltoall", "alltoall_base"]:
|
||||
return MatchState.UNDECIDED
|
||||
if self.type != "scatter" and self.input_sizes != other.input_sizes:
|
||||
return MatchState.SIZE_OR_SYNTAX_MISMATCH
|
||||
if self.type != "gather" and self.output_sizes != other.output_sizes:
|
||||
return MatchState.SIZE_OR_SYNTAX_MISMATCH
|
||||
if self.type == "all_reduce" and self.input_sizes != other.output_sizes:
|
||||
return MatchState.SIZE_OR_SYNTAX_MISMATCH
|
||||
# TODO: need to consider uneven sharding for all-gather.
|
||||
# TODO: need to consider all_gather_into_tensor_coalesced (coalesced related)
|
||||
if self.type in [
|
||||
"all_gather",
|
||||
"all_gather_base",
|
||||
] and not check_size_even_expand(
|
||||
other.output_sizes, self.input_sizes, self.pg_size
|
||||
):
|
||||
return MatchState.SIZE_OR_SYNTAX_MISMATCH
|
||||
if self.type in [
|
||||
"reduce_scatter",
|
||||
"_reduce_scatter_base",
|
||||
] and not check_size_even_expand(
|
||||
other.input_sizes, self.output_sizes, self.pg_size
|
||||
):
|
||||
return MatchState.SIZE_OR_SYNTAX_MISMATCH
|
||||
# TODO: need to add more checks for gather and scatter.
|
||||
if self.state != other.state:
|
||||
return MatchState.COLLECTIVE_STATE_MISMATCH
|
||||
elif self.type == "coalesced":
|
||||
return bool(other.type == "coalesced")
|
||||
return False
|
||||
return (
|
||||
MatchState.FULLY_MATCHED
|
||||
if (other.type == "coalesced")
|
||||
else MatchState.SIZE_OR_SYNTAX_MISMATCH
|
||||
)
|
||||
return MatchState.FULLY_MATCHED
|
||||
|
||||
|
||||
def match_one_event(
|
||||
event_a: Dict[Any, Any],
|
||||
event_b: Dict[Any, Any],
|
||||
memberships: Dict[str, List[Membership]],
|
||||
) -> bool:
|
||||
memberships: Dict[str, Set[Any]],
|
||||
) -> MatchState:
|
||||
op_a = Op(event_a, memberships)
|
||||
op_b = Op(event_b, memberships)
|
||||
return op_a.match(op_b)
|
||||
@ -280,7 +371,7 @@ def match_coalesced_groups(
|
||||
all_rank_events: Dict[Any, Any],
|
||||
group_size: int,
|
||||
groups: Dict[str, Group],
|
||||
memberships: Dict[str, List[Membership]],
|
||||
memberships: Dict[str, Set[Any]],
|
||||
) -> bool:
|
||||
"""
|
||||
all_rank_events: {
|
||||
@ -359,7 +450,7 @@ def match_coalesced_groups(
|
||||
dst_global_rank = sorted(memberships[op.pg_name])[op.dst]
|
||||
peer_ops = all_ops[dst_global_rank]
|
||||
for i, other in enumerate(peer_ops):
|
||||
if op.match(other):
|
||||
if op.match(other) == MatchState.FULLY_MATCHED:
|
||||
match_idx = i
|
||||
break
|
||||
elif op.dst == other.src:
|
||||
@ -388,7 +479,7 @@ Flat DB builder
|
||||
|
||||
def build_groups_memberships(
|
||||
pg_config: Any,
|
||||
) -> Tuple[List[Group], Dict[Any, Group], List[Membership], Dict[str, Any]]:
|
||||
) -> Tuple[List[Group], Dict[Any, Group], List[Membership], Dict[str, Set[Any]]]:
|
||||
"""
|
||||
pg_config: {
|
||||
global_rank: {
|
||||
@ -417,7 +508,7 @@ def build_groups_memberships(
|
||||
for global_rank in pg_config:
|
||||
for pg_id in pg_config[global_rank]:
|
||||
desc = pg_config[global_rank][pg_id]["desc"]
|
||||
ranks = pg_config[global_rank][pg_id]["ranks"]
|
||||
ranks = ast.literal_eval(pg_config[global_rank][pg_id]["ranks"])
|
||||
if isinstance(ranks, str):
|
||||
# TODO Bug in FR data format? ranks is '[0, 1,...]'
|
||||
ranks = eval(ranks)
|
||||
@ -427,6 +518,7 @@ def build_groups_memberships(
|
||||
for rank in ranks:
|
||||
memberships.append(Membership(group_id=pg_id, global_rank=rank))
|
||||
_groups[pg_id] = groups[-1]
|
||||
# TODO: make ranks int no matter what input (because it can be json or pickled string)
|
||||
_memberships[pg_id] = set(ranks)
|
||||
else:
|
||||
# validation across ranks
|
||||
@ -486,7 +578,7 @@ def find_coalesced_group(
|
||||
def just_print_entries(
|
||||
all_entries: Dict[int, List[Dict[str, Any]]],
|
||||
_groups: Dict[str, Group],
|
||||
_memberships: Dict[str, List[Membership]],
|
||||
_memberships: Dict[str, Set[Any]],
|
||||
) -> None:
|
||||
rows = []
|
||||
ranks = sorted(all_entries.keys())
|
||||
@ -509,9 +601,9 @@ def just_print_entries(
|
||||
|
||||
|
||||
def build_collectives(
|
||||
all_entries: Dict[int, List[Dict[str, Any]]],
|
||||
all_entries: Dict[str, List[Dict[str, Any]]],
|
||||
_groups: Dict[str, Group],
|
||||
_memberships: Dict[str, List[Membership]],
|
||||
_memberships: Dict[str, Set[Any]],
|
||||
) -> Tuple[List[Traceback], List[Collective], List[NCCLCall]]:
|
||||
"""
|
||||
groups, memberships are the non-flat dicts that are indexable
|
||||
@ -572,8 +664,11 @@ def build_collectives(
|
||||
entries = all_entries[first_rank]
|
||||
pg_name, desc = entries[0]["process_group"]
|
||||
profiling_name = entries[0]["profiling_name"]
|
||||
collective_seq_id = entries[0]["collective_seq_id"]
|
||||
expected_ranks = set(_memberships[pg_name])
|
||||
found_ranks = {first_rank}
|
||||
candidate_ranks = {first_rank}
|
||||
candidate_idx = {}
|
||||
found_ranks = set()
|
||||
found_idx = {}
|
||||
|
||||
if find_coalesced_group(pg_name, entries):
|
||||
@ -626,26 +721,81 @@ def build_collectives(
|
||||
)
|
||||
)
|
||||
nccl_calls.extend(reversed(reversed_calls))
|
||||
|
||||
else:
|
||||
has_undecided_case = False
|
||||
errors = Set()
|
||||
for o in expected_ranks.intersection(set(other_ranks)):
|
||||
for i, e in enumerate(all_entries[o]): # type: ignore[index]
|
||||
# step over ops from other PGs
|
||||
if e["process_group"] == (pg_name, desc):
|
||||
# only check match state when seq_id matches
|
||||
if (
|
||||
e["process_group"] == (pg_name, desc)
|
||||
and e["collective_seq_id"] == collective_seq_id
|
||||
):
|
||||
match_state = match_one_event(entries[0], e, _memberships)
|
||||
if (
|
||||
match_one_event(entries[0], e, _memberships)
|
||||
match_state
|
||||
in [MatchState.FULLY_MATCHED, MatchState.UNDECIDED]
|
||||
and mismatch[pg_name] == 0
|
||||
):
|
||||
found_ranks.add(o)
|
||||
found_idx[o] = i
|
||||
has_undecided_case = match_state == MatchState.UNDECIDED
|
||||
else:
|
||||
# we found a mismatch. what do we do with that?
|
||||
mismatch[pg_name] += 1
|
||||
print(
|
||||
f"Mismatched collective on rank {o} for group {pg_name}:{desc} collective {profiling_name}"
|
||||
)
|
||||
candidate_ranks.add(o)
|
||||
candidate_idx[o] = i
|
||||
errors.add(match_state)
|
||||
break
|
||||
|
||||
# case one: not every rank join the collective or in the flight recorder.
|
||||
if (candidate_ranks | found_ranks) != expected_ranks:
|
||||
mismatch[pg_name] += 1
|
||||
print(
|
||||
f"Not all ranks joining collective for group {pg_name}:{desc} collective {profiling_name}",
|
||||
f"Missing ranks are {expected_ranks - (candidate_ranks | found_ranks)}",
|
||||
)
|
||||
elif len(candidate_ranks) == 1:
|
||||
# case two: alltoall or alltoall_base case.
|
||||
if has_undecided_case:
|
||||
alltoall_cases = [entries[0]] + [
|
||||
all_entries[o][found_idx[o]] for o in found_ranks
|
||||
]
|
||||
check_result, input_numel, output_numel = check_size_alltoall(
|
||||
alltoall_cases
|
||||
)
|
||||
if not check_result:
|
||||
mismatch[pg_name] += 1
|
||||
print(
|
||||
f"Input/output mismatch in the collective for group {pg_name}:{desc} collective {profiling_name}",
|
||||
f"input_numel {input_numel} output_numel{output_numel}",
|
||||
)
|
||||
candidate_ranks.update(found_ranks)
|
||||
candidate_idx.update(found_idx)
|
||||
found_idx.clear()
|
||||
found_ranks.clear()
|
||||
else:
|
||||
found_ranks.update(candidate_ranks)
|
||||
found_idx.update(candidate_idx)
|
||||
candidate_idx.clear()
|
||||
candidate_ranks.clear()
|
||||
# case three: all joined and everything matches on all ranks.
|
||||
else:
|
||||
found_ranks.update(candidate_ranks)
|
||||
found_idx.update(candidate_idx)
|
||||
candidate_idx.clear()
|
||||
candidate_ranks.clear()
|
||||
# case four: mismatch cases due to not same type, size mismatch or state mismatch.
|
||||
else:
|
||||
error_msg = ", ".join(error.name for error in errors)
|
||||
print(
|
||||
f"Collective errors for group {pg_name}:{desc} collective {profiling_name}",
|
||||
f"Found errors: {error_msg}",
|
||||
)
|
||||
candidate_ranks.update(found_ranks)
|
||||
candidate_idx.update(found_idx)
|
||||
found_idx.clear()
|
||||
found_ranks.clear()
|
||||
|
||||
# at this point there are 3 possibilities
|
||||
# 1. we found a match on all the ranks that are members of the group
|
||||
# -> we create a Collective and remove the individual entries from their original lists
|
||||
@ -669,15 +819,19 @@ def build_collectives(
|
||||
# TODO should there be a way to mark 'mismatches'?
|
||||
else:
|
||||
print("appending a non-matching collective")
|
||||
nccl_calls.append(
|
||||
build_nccl_call(
|
||||
all_entries[first_rank].pop(0),
|
||||
id=len(nccl_calls),
|
||||
collective_id=None,
|
||||
group_id=pg_name,
|
||||
global_rank=r,
|
||||
# TODO: figure out a better for mismatch.
|
||||
# Also, shall we add seq Id as well?
|
||||
for r in candidate_ranks:
|
||||
i = candidate_idx[r] if r != first_rank else 0
|
||||
nccl_calls.append(
|
||||
build_nccl_call(
|
||||
all_entries[r].pop(i), # type: ignore[index]
|
||||
id=len(nccl_calls),
|
||||
collective_id=None,
|
||||
group_id=pg_name,
|
||||
global_rank=r,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if mismatch[pg_name] > MISMATCH_TAIL:
|
||||
print(f"Too many mismatches for process_group {pg_name}:{desc}, aborting")
|
||||
|
Reference in New Issue
Block a user