[FR] Fix the bug in the FR script (e.g., the all-ranks dump check) (#134383)

We were somehow converting the rank to a string, which made the ranks check fail. This fix now converts them all to int.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134383
Approved by: https://github.com/c-p-i-o
This commit is contained in:
fduwjj
2024-08-25 22:50:11 -07:00
committed by PyTorch MergeBot
parent 92c4771853
commit bf5c7bf06d
4 changed files with 12 additions and 11 deletions

View File

@ -53,10 +53,10 @@ class FlightRecorderEventTest(TestCase):
) )
e3 = create_one_event( e3 = create_one_event(
"alltoall", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1 "all_to_all", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
) )
e4 = create_one_event( e4 = create_one_event(
"alltoall", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1 "all_to_all", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
) )
self.assertEqual(match_one_event(e3, e4, membership), MatchState.UNDECIDED) self.assertEqual(match_one_event(e3, e4, membership), MatchState.UNDECIDED)

View File

@ -244,7 +244,7 @@ def build_collectives(
nccl_calls.extend(reversed(reversed_calls)) nccl_calls.extend(reversed(reversed_calls))
else: else:
has_undecided_case = False has_undecided_case = False
errors = Set() errors = set()
for o in expected_ranks.intersection(set(other_ranks)): for o in expected_ranks.intersection(set(other_ranks)):
for i, e in enumerate(all_entries[o]): # type: ignore[index] for i, e in enumerate(all_entries[o]): # type: ignore[index]
# step over ops from other PGs # step over ops from other PGs

View File

@ -138,8 +138,7 @@ COLLECTIVES = {
"_reduce_scatter_base", "_reduce_scatter_base",
"gather", "gather",
"scatter", "scatter",
"alltoall_base", "all_to_all",
"alltoall",
} }
P2P = { P2P = {
@ -158,7 +157,7 @@ class MatchState(Enum):
- COLLECTIVE_STATE_MISMATCH: - COLLECTIVE_STATE_MISMATCH:
The states of the collective not same, such as one finished while another just started or scheduled. The states of the collective not same, such as one finished while another just started or scheduled.
- UNDECIDED: - UNDECIDED:
The match status is ambiguous or cannot be determined, e.g., we might need to check all ranks for alltoall_base. The match status is ambiguous or cannot be determined, e.g., we might need to check all ranks for all_to_all.
""" """
FULLY_MATCHED = 1 FULLY_MATCHED = 1
@ -171,6 +170,8 @@ class MatchState(Enum):
def check_size_evenly_broadcasting( def check_size_evenly_broadcasting(
list1: List[Any], list2: List[Any], size: int list1: List[Any], list2: List[Any], size: int
) -> bool: ) -> bool:
if len(list1) != len(list2):
return False
ratio = None ratio = None
for a, b in zip(list1, list2): for a, b in zip(list1, list2):
current_ratio = int(a) / int(b) current_ratio = int(a) / int(b)
@ -283,7 +284,7 @@ class Op:
elif self.type in COLLECTIVES: elif self.type in COLLECTIVES:
if self.type != other.type: if self.type != other.type:
return MatchState.COLLECTIVE_TYPE_MISMATCH return MatchState.COLLECTIVE_TYPE_MISMATCH
if self.type in ["alltoall", "alltoall_base"]: if self.type == "all_to_all":
return MatchState.UNDECIDED return MatchState.UNDECIDED
if self.type != "scatter" and self.input_sizes != other.input_sizes: if self.type != "scatter" and self.input_sizes != other.input_sizes:
return MatchState.SIZE_OR_SYNTAX_MISMATCH return MatchState.SIZE_OR_SYNTAX_MISMATCH
@ -297,14 +298,14 @@ class Op:
"all_gather", "all_gather",
"all_gather_base", "all_gather_base",
] and not check_size_evenly_broadcasting( ] and not check_size_evenly_broadcasting(
other.output_sizes, self.input_sizes, self.pg_size other.output_sizes[0], self.input_sizes[0], self.pg_size
): ):
return MatchState.SIZE_OR_SYNTAX_MISMATCH return MatchState.SIZE_OR_SYNTAX_MISMATCH
if self.type in [ if self.type in [
"reduce_scatter", "reduce_scatter",
"_reduce_scatter_base", "_reduce_scatter_base",
] and not check_size_evenly_broadcasting( ] and not check_size_evenly_broadcasting(
other.input_sizes, self.output_sizes, self.pg_size other.input_sizes[0], self.output_sizes[0], self.pg_size
): ):
return MatchState.SIZE_OR_SYNTAX_MISMATCH return MatchState.SIZE_OR_SYNTAX_MISMATCH
# TODO: need to add more checks for gather and scatter. # TODO: need to add more checks for gather and scatter.

View File

@ -202,8 +202,8 @@ def check_no_missing_dump_files(
) -> None: ) -> None:
all_ranks = set() all_ranks = set()
for membership in memberships: for membership in memberships:
all_ranks.add(str(membership.global_rank)) all_ranks.add(int(membership.global_rank))
dumps_ranks = set(entries.keys()) dumps_ranks = {int(key) for key in entries.keys()}
assert ( assert (
dumps_ranks == all_ranks dumps_ranks == all_ranks
), f"Missing dump files from ranks {all_ranks - dumps_ranks}" ), f"Missing dump files from ranks {all_ranks - dumps_ranks}"