mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[FR] Fix the bug in FR script (e.g., checking all ranks dump check) (#134383)
The ranks were somehow being converted to strings, which made the ranks check fail. This fix now converts them all to int. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134383 Approved by: https://github.com/c-p-i-o
This commit is contained in:
@ -53,10 +53,10 @@ class FlightRecorderEventTest(TestCase):
|
||||
)
|
||||
|
||||
e3 = create_one_event(
|
||||
"alltoall", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
|
||||
"all_to_all", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
|
||||
)
|
||||
e4 = create_one_event(
|
||||
"alltoall", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
|
||||
"all_to_all", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
|
||||
)
|
||||
self.assertEqual(match_one_event(e3, e4, membership), MatchState.UNDECIDED)
|
||||
|
||||
|
@ -244,7 +244,7 @@ def build_collectives(
|
||||
nccl_calls.extend(reversed(reversed_calls))
|
||||
else:
|
||||
has_undecided_case = False
|
||||
errors = Set()
|
||||
errors = set()
|
||||
for o in expected_ranks.intersection(set(other_ranks)):
|
||||
for i, e in enumerate(all_entries[o]): # type: ignore[index]
|
||||
# step over ops from other PGs
|
||||
|
@ -138,8 +138,7 @@ COLLECTIVES = {
|
||||
"_reduce_scatter_base",
|
||||
"gather",
|
||||
"scatter",
|
||||
"alltoall_base",
|
||||
"alltoall",
|
||||
"all_to_all",
|
||||
}
|
||||
|
||||
P2P = {
|
||||
@ -158,7 +157,7 @@ class MatchState(Enum):
|
||||
- COLLECTIVE_STATE_MISMATCH:
|
||||
The states of the collectives are not the same, e.g., one has finished while another has just started or been scheduled.
|
||||
- UNDECIDED:
|
||||
The match status is ambiguous or cannot be determined, e.g., we might need to check all ranks for alltoall_base.
|
||||
The match status is ambiguous or cannot be determined, e.g., we might need to check all ranks for all_to_all.
|
||||
"""
|
||||
|
||||
FULLY_MATCHED = 1
|
||||
@ -171,6 +170,8 @@ class MatchState(Enum):
|
||||
def check_size_evenly_broadcasting(
|
||||
list1: List[Any], list2: List[Any], size: int
|
||||
) -> bool:
|
||||
if len(list1) != len(list2):
|
||||
return False
|
||||
ratio = None
|
||||
for a, b in zip(list1, list2):
|
||||
current_ratio = int(a) / int(b)
|
||||
@ -283,7 +284,7 @@ class Op:
|
||||
elif self.type in COLLECTIVES:
|
||||
if self.type != other.type:
|
||||
return MatchState.COLLECTIVE_TYPE_MISMATCH
|
||||
if self.type in ["alltoall", "alltoall_base"]:
|
||||
if self.type == "all_to_all":
|
||||
return MatchState.UNDECIDED
|
||||
if self.type != "scatter" and self.input_sizes != other.input_sizes:
|
||||
return MatchState.SIZE_OR_SYNTAX_MISMATCH
|
||||
@ -297,14 +298,14 @@ class Op:
|
||||
"all_gather",
|
||||
"all_gather_base",
|
||||
] and not check_size_evenly_broadcasting(
|
||||
other.output_sizes, self.input_sizes, self.pg_size
|
||||
other.output_sizes[0], self.input_sizes[0], self.pg_size
|
||||
):
|
||||
return MatchState.SIZE_OR_SYNTAX_MISMATCH
|
||||
if self.type in [
|
||||
"reduce_scatter",
|
||||
"_reduce_scatter_base",
|
||||
] and not check_size_evenly_broadcasting(
|
||||
other.input_sizes, self.output_sizes, self.pg_size
|
||||
other.input_sizes[0], self.output_sizes[0], self.pg_size
|
||||
):
|
||||
return MatchState.SIZE_OR_SYNTAX_MISMATCH
|
||||
# TODO: need to add more checks for gather and scatter.
|
||||
|
@ -202,8 +202,8 @@ def check_no_missing_dump_files(
|
||||
) -> None:
|
||||
all_ranks = set()
|
||||
for membership in memberships:
|
||||
all_ranks.add(str(membership.global_rank))
|
||||
dumps_ranks = set(entries.keys())
|
||||
all_ranks.add(int(membership.global_rank))
|
||||
dumps_ranks = {int(key) for key in entries.keys()}
|
||||
assert (
|
||||
dumps_ranks == all_ranks
|
||||
), f"Missing dump files from ranks {all_ranks - dumps_ranks}"
|
||||
|
Reference in New Issue
Block a user