[FR] Fix bugs in the FR script (e.g., the check that all ranks have dump files) (#134383)

We were converting the rank to a string, which made the all-ranks dump check fail. This fix converts the ranks to int on both sides of the comparison.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134383
Approved by: https://github.com/c-p-i-o
This commit is contained in:
fduwjj
2024-08-25 22:50:11 -07:00
committed by PyTorch MergeBot
parent 92c4771853
commit bf5c7bf06d
4 changed files with 12 additions and 11 deletions

View File

@ -53,10 +53,10 @@ class FlightRecorderEventTest(TestCase):
)
e3 = create_one_event(
"alltoall", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
"all_to_all", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
)
e4 = create_one_event(
"alltoall", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
"all_to_all", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
)
self.assertEqual(match_one_event(e3, e4, membership), MatchState.UNDECIDED)

View File

@ -244,7 +244,7 @@ def build_collectives(
nccl_calls.extend(reversed(reversed_calls))
else:
has_undecided_case = False
errors = Set()
errors = set()
for o in expected_ranks.intersection(set(other_ranks)):
for i, e in enumerate(all_entries[o]): # type: ignore[index]
# step over ops from other PGs

View File

@ -138,8 +138,7 @@ COLLECTIVES = {
"_reduce_scatter_base",
"gather",
"scatter",
"alltoall_base",
"alltoall",
"all_to_all",
}
P2P = {
@ -158,7 +157,7 @@ class MatchState(Enum):
- COLLECTIVE_STATE_MISMATCH:
The states of the collective not same, such as one finished while another just started or scheduled.
- UNDECIDED:
The match status is ambiguous or cannot be determined, e.g., we might need to check all ranks for alltoall_base.
The match status is ambiguous or cannot be determined, e.g., we might need to check all ranks for all_to_all.
"""
FULLY_MATCHED = 1
@ -171,6 +170,8 @@ class MatchState(Enum):
def check_size_evenly_broadcasting(
list1: List[Any], list2: List[Any], size: int
) -> bool:
if len(list1) != len(list2):
return False
ratio = None
for a, b in zip(list1, list2):
current_ratio = int(a) / int(b)
@ -283,7 +284,7 @@ class Op:
elif self.type in COLLECTIVES:
if self.type != other.type:
return MatchState.COLLECTIVE_TYPE_MISMATCH
if self.type in ["alltoall", "alltoall_base"]:
if self.type == "all_to_all":
return MatchState.UNDECIDED
if self.type != "scatter" and self.input_sizes != other.input_sizes:
return MatchState.SIZE_OR_SYNTAX_MISMATCH
@ -297,14 +298,14 @@ class Op:
"all_gather",
"all_gather_base",
] and not check_size_evenly_broadcasting(
other.output_sizes, self.input_sizes, self.pg_size
other.output_sizes[0], self.input_sizes[0], self.pg_size
):
return MatchState.SIZE_OR_SYNTAX_MISMATCH
if self.type in [
"reduce_scatter",
"_reduce_scatter_base",
] and not check_size_evenly_broadcasting(
other.input_sizes, self.output_sizes, self.pg_size
other.input_sizes[0], self.output_sizes[0], self.pg_size
):
return MatchState.SIZE_OR_SYNTAX_MISMATCH
# TODO: need to add more checks for gather and scatter.

View File

@ -202,8 +202,8 @@ def check_no_missing_dump_files(
) -> None:
all_ranks = set()
for membership in memberships:
all_ranks.add(str(membership.global_rank))
dumps_ranks = set(entries.keys())
all_ranks.add(int(membership.global_rank))
dumps_ranks = {int(key) for key in entries.keys()}
assert (
dumps_ranks == all_ranks
), f"Missing dump files from ranks {all_ranks - dumps_ranks}"