mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[FR] Fix the bug in FR script (e.g., checking all ranks dump check) (#134383)
We somehow convert the rank to string which makes the ranks check fail. This fix now convert them all to int. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134383 Approved by: https://github.com/c-p-i-o
This commit is contained in:
@ -53,10 +53,10 @@ class FlightRecorderEventTest(TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
e3 = create_one_event(
|
e3 = create_one_event(
|
||||||
"alltoall", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
|
"all_to_all", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
|
||||||
)
|
)
|
||||||
e4 = create_one_event(
|
e4 = create_one_event(
|
||||||
"alltoall", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
|
"all_to_all", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
|
||||||
)
|
)
|
||||||
self.assertEqual(match_one_event(e3, e4, membership), MatchState.UNDECIDED)
|
self.assertEqual(match_one_event(e3, e4, membership), MatchState.UNDECIDED)
|
||||||
|
|
||||||
|
@ -244,7 +244,7 @@ def build_collectives(
|
|||||||
nccl_calls.extend(reversed(reversed_calls))
|
nccl_calls.extend(reversed(reversed_calls))
|
||||||
else:
|
else:
|
||||||
has_undecided_case = False
|
has_undecided_case = False
|
||||||
errors = Set()
|
errors = set()
|
||||||
for o in expected_ranks.intersection(set(other_ranks)):
|
for o in expected_ranks.intersection(set(other_ranks)):
|
||||||
for i, e in enumerate(all_entries[o]): # type: ignore[index]
|
for i, e in enumerate(all_entries[o]): # type: ignore[index]
|
||||||
# step over ops from other PGs
|
# step over ops from other PGs
|
||||||
|
@ -138,8 +138,7 @@ COLLECTIVES = {
|
|||||||
"_reduce_scatter_base",
|
"_reduce_scatter_base",
|
||||||
"gather",
|
"gather",
|
||||||
"scatter",
|
"scatter",
|
||||||
"alltoall_base",
|
"all_to_all",
|
||||||
"alltoall",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
P2P = {
|
P2P = {
|
||||||
@ -158,7 +157,7 @@ class MatchState(Enum):
|
|||||||
- COLLECTIVE_STATE_MISMATCH:
|
- COLLECTIVE_STATE_MISMATCH:
|
||||||
The states of the collective not same, such as one finished while another just started or scheduled.
|
The states of the collective not same, such as one finished while another just started or scheduled.
|
||||||
- UNDECIDED:
|
- UNDECIDED:
|
||||||
The match status is ambiguous or cannot be determined, e.g., we might need to check all ranks for alltoall_base.
|
The match status is ambiguous or cannot be determined, e.g., we might need to check all ranks for all_to_all.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
FULLY_MATCHED = 1
|
FULLY_MATCHED = 1
|
||||||
@ -171,6 +170,8 @@ class MatchState(Enum):
|
|||||||
def check_size_evenly_broadcasting(
|
def check_size_evenly_broadcasting(
|
||||||
list1: List[Any], list2: List[Any], size: int
|
list1: List[Any], list2: List[Any], size: int
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
if len(list1) != len(list2):
|
||||||
|
return False
|
||||||
ratio = None
|
ratio = None
|
||||||
for a, b in zip(list1, list2):
|
for a, b in zip(list1, list2):
|
||||||
current_ratio = int(a) / int(b)
|
current_ratio = int(a) / int(b)
|
||||||
@ -283,7 +284,7 @@ class Op:
|
|||||||
elif self.type in COLLECTIVES:
|
elif self.type in COLLECTIVES:
|
||||||
if self.type != other.type:
|
if self.type != other.type:
|
||||||
return MatchState.COLLECTIVE_TYPE_MISMATCH
|
return MatchState.COLLECTIVE_TYPE_MISMATCH
|
||||||
if self.type in ["alltoall", "alltoall_base"]:
|
if self.type == "all_to_all":
|
||||||
return MatchState.UNDECIDED
|
return MatchState.UNDECIDED
|
||||||
if self.type != "scatter" and self.input_sizes != other.input_sizes:
|
if self.type != "scatter" and self.input_sizes != other.input_sizes:
|
||||||
return MatchState.SIZE_OR_SYNTAX_MISMATCH
|
return MatchState.SIZE_OR_SYNTAX_MISMATCH
|
||||||
@ -297,14 +298,14 @@ class Op:
|
|||||||
"all_gather",
|
"all_gather",
|
||||||
"all_gather_base",
|
"all_gather_base",
|
||||||
] and not check_size_evenly_broadcasting(
|
] and not check_size_evenly_broadcasting(
|
||||||
other.output_sizes, self.input_sizes, self.pg_size
|
other.output_sizes[0], self.input_sizes[0], self.pg_size
|
||||||
):
|
):
|
||||||
return MatchState.SIZE_OR_SYNTAX_MISMATCH
|
return MatchState.SIZE_OR_SYNTAX_MISMATCH
|
||||||
if self.type in [
|
if self.type in [
|
||||||
"reduce_scatter",
|
"reduce_scatter",
|
||||||
"_reduce_scatter_base",
|
"_reduce_scatter_base",
|
||||||
] and not check_size_evenly_broadcasting(
|
] and not check_size_evenly_broadcasting(
|
||||||
other.input_sizes, self.output_sizes, self.pg_size
|
other.input_sizes[0], self.output_sizes[0], self.pg_size
|
||||||
):
|
):
|
||||||
return MatchState.SIZE_OR_SYNTAX_MISMATCH
|
return MatchState.SIZE_OR_SYNTAX_MISMATCH
|
||||||
# TODO: need to add more checks for gather and scatter.
|
# TODO: need to add more checks for gather and scatter.
|
||||||
|
@ -202,8 +202,8 @@ def check_no_missing_dump_files(
|
|||||||
) -> None:
|
) -> None:
|
||||||
all_ranks = set()
|
all_ranks = set()
|
||||||
for membership in memberships:
|
for membership in memberships:
|
||||||
all_ranks.add(str(membership.global_rank))
|
all_ranks.add(int(membership.global_rank))
|
||||||
dumps_ranks = set(entries.keys())
|
dumps_ranks = {int(key) for key in entries.keys()}
|
||||||
assert (
|
assert (
|
||||||
dumps_ranks == all_ranks
|
dumps_ranks == all_ranks
|
||||||
), f"Missing dump files from ranks {all_ranks - dumps_ranks}"
|
), f"Missing dump files from ranks {all_ranks - dumps_ranks}"
|
||||||
|
Reference in New Issue
Block a user