Revert "[Flight Recorder] Add more basic analysis to the script (#133412)"

This reverts commit fcc2fc1a70c35628939611b496b209fa0a1d19bf.

Reverted https://github.com/pytorch/pytorch/pull/133412 on behalf of https://github.com/atalman due to New test: distributed/flight_recorder/test_fr_analysis is constantly failing ([comment](https://github.com/pytorch/pytorch/pull/133412#issuecomment-2293506539))
PyTorch MergeBot
2024-08-16 13:26:25 +00:00
parent b444343087
commit e1b9b89d94
2 changed files with 46 additions and 282 deletions

View File

@@ -1,82 +0,0 @@
# Owner(s): ["oncall: distributed"]
from tools.flight_recorder.fr_trace import match_one_event, MatchState
from torch.testing._internal.common_utils import run_tests, TestCase
def create_one_event(
collective_name,
pg_info,
input_sizes,
output_sizes,
state="scheduled",
collective_seq_id=0,
p2p_seq_id=0,
):
return {
"profiling_name": f"nccl:{collectcive_name}",
"state": state,
"process_group": pg_info,
"input_sizes": input_sizes,
"output_sizes": output_sizes,
"collective_seq_id": str(collective_seq_id),
"p2p_seq_id": str(p2p_seq_id),
}
class FlightRecorderEventTest(TestCase):
def test_match_one_event(self):
e1 = create_one_event(
"all_reduce", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
)
membership = {"0": {0, 1}}
self.assertEqual(match_one_event(e1, e1, membership), MatchState.FULLY_MATCHED)
e2 = create_one_event(
"all_gather", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
)
self.assertEqual(
match_one_event(e1, e2, membership), MatchState.COLLECTIVE_TYPE_MISMATCH
)
e3 = create_one_event(
"alltoall", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
)
e4 = create_one_event(
"alltoall", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
)
self.assertEqual(match_one_event(e3, e4, membership), MatchState.UNDECIDED)
e5 = create_one_event(
"all_reduce", ("0", "default"), [[5, 4]], [[4, 4]], "scheduled", 1, 1
)
self.assertEqual(
match_one_event(e1, e5, membership), MatchState.SIZE_OR_SYNTAX_MISMATCH
)
e6 = create_one_event(
"all_reduce", ("0", "default"), [[4, 4]], [[5, 4]], "scheduled", 1, 2
)
self.assertEqual(
match_one_event(e1, e6, membership), MatchState.SIZE_OR_SYNTAX_MISMATCH
)
e7 = create_one_event(
"all_reduce", ("0", "default"), [[4, 4]], [[5, 4]], "scheduled", 2
)
self.assertEqual(
match_one_event(e7, e7, membership), MatchState.SIZE_OR_SYNTAX_MISMATCH
)
e9 = create_one_event(
"all_reduce", ("0", "default"), [[4, 4]], [[4, 4]], "completed", 1
)
self.assertEqual(
match_one_event(e1, e9, membership), MatchState.COLLECTIVE_STATE_MISMATCH
)
if __name__ == "__main__":
run_tests()

View File

@@ -28,12 +28,9 @@ python fr_trace.py -d <dump dir containing trace files> [-o <output file>]
"""
import argparse
import ast
import math
import os
import pickle
import sys
from enum import Enum
from typing import ( # type: ignore[attr-defined]
_eval_type,
Any,
@@ -41,7 +38,6 @@ from typing import ( # type: ignore[attr-defined]
Generic,
List,
NamedTuple,
Set,
Tuple,
Type,
TypeVar,
@@ -155,9 +151,6 @@ TODO
"""
Collective Matching logic
NOTE: For now, these collectives need to be supported by NCCL,
https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/overview.html.
"""
COLLECTIVES = {
"broadcast",
@@ -169,8 +162,6 @@ COLLECTIVES = {
"_reduce_scatter_base",
"gather",
"scatter",
"alltoall_base",
"alltoall",
}
P2P = {
@@ -179,50 +170,6 @@ P2P = {
}
class MatchState(Enum):
"""
Enum representing the possible states of matching for collective operations.
- FULLY_MATCHED: Indicates that all aspects of the collective operations match.
- COLLECTIVE_TYPE_MISMATCH: The types of the collective operations differ.
- SIZE_OR_SYNTAX_MISMATCH: There is a mismatch in input/output sizes or violation of collective syntax.
- COLLECTIVE_STATE_MISMATCH:
The states of the collective are not the same, e.g. one rank has finished while another is still scheduled or just started.
- UNDECIDED:
The match status is ambiguous or cannot be determined, e.g., we might need to check all ranks for alltoall_base.
"""
FULLY_MATCHED = 1
COLLECTIVE_TYPE_MISMATCH = 2
SIZE_OR_SYNTAX_MISMATCH = 3
COLLECTIVE_STATE_MISMATCH = 4
UNDECIDED = 5
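# Illustrative standalone sketch (assumes the pre-revert module layout used by
# the unit test above): a caller can branch on the returned state, mirroring how
# build_collectives below treats UNDECIDED as a tentative match that still needs
# an all-rank size check.
from tools.flight_recorder.fr_trace import match_one_event, MatchState

ev = {
    "profiling_name": "nccl:all_reduce",
    "state": "scheduled",
    "process_group": ("0", "default"),
    "input_sizes": [[4, 4]],
    "output_sizes": [[4, 4]],
    "collective_seq_id": "1",
    "p2p_seq_id": "0",
}
state = match_one_event(ev, ev, {"0": {0, 1}})
if state in (MatchState.FULLY_MATCHED, MatchState.UNDECIDED):
    pass  # tentatively matched; UNDECIDED (alltoall) still needs check_size_alltoall across ranks
else:
    pass  # record a mismatch for this process group and report state.name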
def check_size_even_expand(list1: List[Any], list2: List[Any], size: int) -> bool:
ratio = None
for a, b in zip(list1, list2):
current_ratio = int(a) / int(b)
if current_ratio == 1:
continue
if current_ratio != size:
return False
elif ratio is None:
ratio = current_ratio
else:
return False
return True
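# Illustrative check (assumes flat shape lists, since the helper divides the
# paired entries directly): at most one dimension may differ between the two
# lists, and only by a factor of `size` (the process-group size).
assert check_size_even_expand([8, 4], [4, 4], 2)       # one dim scaled by the group size
assert not check_size_even_expand([8, 8], [4, 4], 2)   # more than one dim scaled
assert not check_size_even_expand([12, 4], [4, 4], 2)  # scaled by the wrong factor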
def check_size_alltoall(alltoall_cases: List[Dict[str, Any]]) -> Tuple[bool, int, int]:
input_numel = 0
output_numel = 0
for e in alltoall_cases:
input_numel += math.prod(e["input_sizes"][0])
output_numel += math.prod(e["output_sizes"][0])
return input_numel == output_numel, input_numel, output_numel
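# Illustrative check (hypothetical two-rank alltoall, event dicts shaped as in
# the unit test above): the per-rank splits may be uneven, but the total input
# and output element counts must agree.
alltoall_events = [
    {"input_sizes": [[4, 4]], "output_sizes": [[2, 4]]},
    {"input_sizes": [[4, 4]], "output_sizes": [[6, 4]]},
]
assert check_size_alltoall(alltoall_events) == (True, 32, 32)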
class Op:
"""Parses relevant info about operation out of 'event' dict
@@ -232,7 +179,7 @@ class Op:
nccl:recv 3<-0
"""
def __init__(self, event: Dict[Any, Any], memberships: Dict[str, Set[Any]]):
def __init__(self, event: Dict[Any, Any], memberships: Dict[str, List[Membership]]):
profiling_name = event["profiling_name"]
nccl, name = profiling_name.split(":")
assert nccl == "nccl", f"name formatting error? {nccl} != 'nccl'"
@@ -257,7 +204,6 @@ class Op:
self._src, self._dst = -1, -1
pg_name, pg_desc = event["process_group"]
self._init_global_src_dst(memberships[pg_name])
self.pg_size = len(memberships[pg_name])
if type in P2P | COLLECTIVES:
self.input_sizes = event["input_sizes"]
@@ -267,7 +213,7 @@ class Op:
self.collective_seq_id = event["collective_seq_id"]
self.p2p_seq_id = event["p2p_seq_id"]
def _init_global_src_dst(self, pg_ranks: Set[Any]) -> None:
def _init_global_src_dst(self, pg_ranks: List[Membership]) -> None:
pg_ranks = sorted(pg_ranks)
self._src_g = pg_ranks[self._src] if self._src is not None else None
self._dst_g = pg_ranks[self._dst] if self._dst is not None else None
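# Illustrative mapping (hypothetical group membership): the local src/dst
# recorded for a p2p op are indices into the sorted global ranks of its group.
pg_ranks = sorted({2, 5, 7, 9})   # global ranks of one process group
local_src, local_dst = 0, 3       # e.g. as parsed from "nccl:recv 3<-0"
assert (pg_ranks[local_src], pg_ranks[local_dst]) == (2, 9)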
@@ -289,79 +235,42 @@ class Op:
)
return f"{self.type}(input_sizes={self.input_sizes}, {self.state})"
def match(self, other: "Op") -> MatchState:
def match(self, other) -> bool: # type: ignore[no-untyped-def]
# TODO: I think this can validly not match,
# e.g. if one PG was used for p2p ops between only some of the peers?
# if self.seq_id != other.seq_id:
# return False
if self.type == "send":
# TODO: We need more states for p2p ops.
return (
MatchState.FULLY_MATCHED
if (
other.type == "recv"
and self.src == other.src
and self.dst == other.dst
and self.input_sizes == other.output_sizes
)
else MatchState.SIZE_OR_SYNTAX_MISMATCH
return bool(
other.type == "recv"
and self.src == other.src
and self.dst == other.dst
and self.input_sizes == other.output_sizes
)
elif self.type == "recv":
return (
MatchState.FULLY_MATCHED
if (
other.type == "send"
and self.src == other.src
and self.dst == other.dst
and self.output_sizes == other.input_sizes
)
else MatchState.SIZE_OR_SYNTAX_MISMATCH
return bool(
other.type == "send"
and self.src == other.src
and self.dst == other.dst
and self.output_sizes == other.input_sizes
)
elif self.type in COLLECTIVES:
if self.type != other.type:
return MatchState.COLLECTIVE_TYPE_MISMATCH
if self.type in ["alltoall", "alltoall_base"]:
return MatchState.UNDECIDED
if self.type != "scatter" and self.input_sizes != other.input_sizes:
return MatchState.SIZE_OR_SYNTAX_MISMATCH
if self.type != "gather" and self.output_sizes != other.output_sizes:
return MatchState.SIZE_OR_SYNTAX_MISMATCH
if self.type == "all_reduce" and self.input_sizes != other.output_sizes:
return MatchState.SIZE_OR_SYNTAX_MISMATCH
# TODO: need to consider uneven sharding for all-gather.
# TODO: need to consider all_gather_into_tensor_coalesced (coalesced related)
if self.type in [
"all_gather",
"all_gather_base",
] and not check_size_even_expand(
other.output_sizes, self.input_sizes, self.pg_size
):
return MatchState.SIZE_OR_SYNTAX_MISMATCH
if self.type in [
"reduce_scatter",
"_reduce_scatter_base",
] and not check_size_even_expand(
other.input_sizes, self.output_sizes, self.pg_size
):
return MatchState.SIZE_OR_SYNTAX_MISMATCH
# TODO: need to add more checks for gather and scatter.
if self.state != other.state:
return MatchState.COLLECTIVE_STATE_MISMATCH
elif self.type == "coalesced":
return (
MatchState.FULLY_MATCHED
if (other.type == "coalesced")
else MatchState.SIZE_OR_SYNTAX_MISMATCH
return bool(
self.type == other.type and self.input_sizes == other.input_sizes
)
return MatchState.FULLY_MATCHED
# TODO(whc) - output sizes don't have to match for e.g. gather; not sure if they ever have to match?
# and self.output_sizes == other.output_sizes)
elif self.type == "coalesced":
return bool(other.type == "coalesced")
return False
def match_one_event(
event_a: Dict[Any, Any],
event_b: Dict[Any, Any],
memberships: Dict[str, Set[Any]],
) -> MatchState:
memberships: Dict[str, List[Membership]],
) -> bool:
op_a = Op(event_a, memberships)
op_b = Op(event_b, memberships)
return op_a.match(op_b)
@@ -371,7 +280,7 @@ def match_coalesced_groups(
all_rank_events: Dict[Any, Any],
group_size: int,
groups: Dict[str, Group],
memberships: Dict[str, Set[Any]],
memberships: Dict[str, List[Membership]],
) -> bool:
"""
all_rank_events: {
@@ -450,7 +359,7 @@ def match_coalesced_groups(
dst_global_rank = sorted(memberships[op.pg_name])[op.dst]
peer_ops = all_ops[dst_global_rank]
for i, other in enumerate(peer_ops):
if op.match(other) == MatchState.FULLY_MATCHED:
if op.match(other):
match_idx = i
break
elif op.dst == other.src:
@@ -479,7 +388,7 @@ Flat DB builder
def build_groups_memberships(
pg_config: Any,
) -> Tuple[List[Group], Dict[Any, Group], List[Membership], Dict[str, Set[Any]]]:
) -> Tuple[List[Group], Dict[Any, Group], List[Membership], Dict[str, Any]]:
"""
pg_config: {
global_rank: {
@@ -508,7 +417,7 @@ def build_groups_memberships(
for global_rank in pg_config:
for pg_id in pg_config[global_rank]:
desc = pg_config[global_rank][pg_id]["desc"]
ranks = ast.literal_eval(pg_config[global_rank][pg_id]["ranks"])
ranks = pg_config[global_rank][pg_id]["ranks"]
if isinstance(ranks, str):
# TODO Bug in FR data format? ranks is '[0, 1,...]'
ranks = eval(ranks)
@@ -518,7 +427,6 @@ def build_groups_memberships(
for rank in ranks:
memberships.append(Membership(group_id=pg_id, global_rank=rank))
_groups[pg_id] = groups[-1]
# TODO: make ranks ints regardless of the input format (it can be JSON or a pickled string)
_memberships[pg_id] = set(ranks)
else:
# validation across ranks
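# Illustrative input (hypothetical two-rank job): every rank reports the same
# process-group config, with "ranks" possibly serialized as a string (see the
# TODO above), so build_groups_memberships would record group "0" with
# membership {0, 1} and validate that the remaining ranks agree.
pg_config = {
    0: {"0": {"desc": "default_pg", "ranks": "[0, 1]"}},
    1: {"0": {"desc": "default_pg", "ranks": "[0, 1]"}},
}
# groups, _groups, memberships, _memberships = build_groups_memberships(pg_config)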
@@ -578,7 +486,7 @@ def find_coalesced_group(
def just_print_entries(
all_entries: Dict[int, List[Dict[str, Any]]],
_groups: Dict[str, Group],
_memberships: Dict[str, Set[Any]],
_memberships: Dict[str, List[Membership]],
) -> None:
rows = []
ranks = sorted(all_entries.keys())
@@ -601,9 +509,9 @@ def just_print_entries(
def build_collectives(
all_entries: Dict[str, List[Dict[str, Any]]],
all_entries: Dict[int, List[Dict[str, Any]]],
_groups: Dict[str, Group],
_memberships: Dict[str, Set[Any]],
_memberships: Dict[str, List[Membership]],
) -> Tuple[List[Traceback], List[Collective], List[NCCLCall]]:
"""
groups, memberships are the non-flat dicts that are indexable
@@ -664,11 +572,8 @@ def build_collectives(
entries = all_entries[first_rank]
pg_name, desc = entries[0]["process_group"]
profiling_name = entries[0]["profiling_name"]
collective_seq_id = entries[0]["collective_seq_id"]
expected_ranks = set(_memberships[pg_name])
candidate_ranks = {first_rank}
candidate_idx = {}
found_ranks = set()
found_ranks = {first_rank}
found_idx = {}
if find_coalesced_group(pg_name, entries):
@@ -721,81 +626,26 @@ def build_collectives(
)
)
nccl_calls.extend(reversed(reversed_calls))
else:
has_undecided_case = False
errors = set()
for o in expected_ranks.intersection(set(other_ranks)):
for i, e in enumerate(all_entries[o]): # type: ignore[index]
# step over ops from other PGs
# only check match state when seq_id matches
if (
e["process_group"] == (pg_name, desc)
and e["collective_seq_id"] == collective_seq_id
):
match_state = match_one_event(entries[0], e, _memberships)
if e["process_group"] == (pg_name, desc):
if (
match_state
in [MatchState.FULLY_MATCHED, MatchState.UNDECIDED]
match_one_event(entries[0], e, _memberships)
and mismatch[pg_name] == 0
):
found_ranks.add(o)
found_idx[o] = i
has_undecided_case = match_state == MatchState.UNDECIDED
else:
candidate_ranks.add(o)
candidate_idx[o] = i
errors.add(match_state)
# we found a mismatch. what do we do with that?
mismatch[pg_name] += 1
print(
f"Mismatched collective on rank {o} for group {pg_name}:{desc} collective {profiling_name}"
)
break
# case one: not every rank joined the collective or appears in the flight recorder.
if (candidate_ranks | found_ranks) != expected_ranks:
mismatch[pg_name] += 1
print(
f"Not all ranks joining collective for group {pg_name}:{desc} collective {profiling_name}",
f"Missing ranks are {expected_ranks - (candidate_ranks | found_ranks)}",
)
elif len(candidate_ranks) == 1:
# case two: alltoall or alltoall_base case.
if has_undecided_case:
alltoall_cases = [entries[0]] + [
all_entries[o][found_idx[o]] for o in found_ranks
]
check_result, input_numel, output_numel = check_size_alltoall(
alltoall_cases
)
if not check_result:
mismatch[pg_name] += 1
print(
f"Input/output mismatch in the collective for group {pg_name}:{desc} collective {profiling_name}",
f"input_numel {input_numel} output_numel{output_numel}",
)
candidate_ranks.update(found_ranks)
candidate_idx.update(found_idx)
found_idx.clear()
found_ranks.clear()
else:
found_ranks.update(candidate_ranks)
found_idx.update(candidate_idx)
candidate_idx.clear()
candidate_ranks.clear()
# case three: all joined and everything matches on all ranks.
else:
found_ranks.update(candidate_ranks)
found_idx.update(candidate_idx)
candidate_idx.clear()
candidate_ranks.clear()
# case four: mismatches due to differing collective type, sizes, or state.
else:
error_msg = ", ".join(error.name for error in errors)
print(
f"Collective errors for group {pg_name}:{desc} collective {profiling_name}",
f"Found errors: {error_msg}",
)
candidate_ranks.update(found_ranks)
candidate_idx.update(found_idx)
found_idx.clear()
found_ranks.clear()
# at this point there are 3 possibilities
# 1. we found a match on all the ranks that are members of the group
# -> we create a Collective and remove the individual entries from their original lists
@@ -819,19 +669,15 @@ def build_collectives(
# TODO should there be a way to mark 'mismatches'?
else:
print("appending a non-matching collective")
# TODO: figure out a better way to handle mismatches.
# Also, shall we add the seq id as well?
for r in candidate_ranks:
i = candidate_idx[r] if r != first_rank else 0
nccl_calls.append(
build_nccl_call(
all_entries[r].pop(i), # type: ignore[index]
id=len(nccl_calls),
collective_id=None,
group_id=pg_name,
global_rank=r,
)
nccl_calls.append(
build_nccl_call(
all_entries[first_rank].pop(0),
id=len(nccl_calls),
collective_id=None,
group_id=pg_name,
global_rank=r,
)
)
if mismatch[pg_name] > MISMATCH_TAIL:
print(f"Too many mismatches for process_group {pg_name}:{desc}, aborting")