[fr] Skip the dtype check for some one-to-all or all-to-one collectives (#163839)

As the title says: in practice we found that the output dtype of gather sometimes does not match across ranks, which is undefined behavior. The same applies to broadcast and scatter. Because these collectives have all completed, we should not treat such mismatches as errors, so we skip the dtype check for them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163839
Approved by: https://github.com/VieEeEw
This commit is contained in:
fduwjj
2025-09-24 21:37:31 -07:00
committed by PyTorch MergeBot
parent e8f5f1b1a2
commit c8e75c48b9
2 changed files with 38 additions and 17 deletions

View File

@ -143,6 +143,19 @@ class FlightRecorderEventTest(TestCase):
match_one_event(e11, e12, membership, "0").state,
MatchState.FULLY_MATCHED,
)
e13 = create_one_event(
"gather",
("0", "default"),
[[4, 4]],
[[4, 4]],
"completed",
1,
output_dtypes="",
)
self.assertEqual(
match_one_event(e11, e13, membership, "0").state,
MatchState.FULLY_MATCHED,
)
def test_all_events(self):
for collective in sorted(COLLECTIVES):

View File

@ -469,6 +469,30 @@ class Op:
f"{p2p_info}, " if p2p_info else ""
)
def dtype_mismatch(self, other: "Op") -> bool:
    """Report whether this op's dtypes conflict with its own or ``other``'s.

    Three comparisons are made. Each compares the *set* of dtype entries
    and only counts when both sides involved have a non-empty first size
    entry:

    1. this op's input dtypes vs. its own output dtypes -- skipped for
       ``scatter``, ``gather`` and ``broadcast``;
    2. this op's input dtypes vs. ``other``'s input dtypes -- skipped for
       ``scatter`` and ``broadcast``;
    3. this op's output dtypes vs. ``other``'s output dtypes -- skipped
       for ``gather``.

    The skipped combinations are the ones where the recorded dtypes have
    been seen to differ between ranks even though the collective
    completed, so they are not reported as errors.

    Args:
        other: The op recorded for the same collective on another rank.

    Returns:
        True if any of the applicable comparisons finds a difference,
        False otherwise.
    """
    # NOTE(review): only ``self.type`` is consulted; presumably the caller
    # has already established that both ops have the same type -- confirm.
    op_type = self.type

    def has_tensor(sizes) -> bool:
        # A comparison is only meaningful when the first size entry is
        # non-empty. An empty sizes list is treated the same way instead
        # of raising IndexError.
        return bool(sizes) and bool(sizes[0])

    # 1) Input vs. output dtypes within this op.
    if (
        op_type not in ("scatter", "gather", "broadcast")
        and set(self.input_dtypes) != set(self.output_dtypes)
        and has_tensor(self.input_sizes)
        and has_tensor(self.output_sizes)
    ):
        return True

    # 2) Input dtypes of this op vs. input dtypes of the other op.
    if (
        op_type not in ("scatter", "broadcast")
        and set(self.input_dtypes) != set(other.input_dtypes)
        and has_tensor(self.input_sizes)
        and has_tensor(other.input_sizes)
    ):
        return True

    # 3) Output dtypes of this op vs. output dtypes of the other op.
    return (
        op_type != "gather"
        and set(self.output_dtypes) != set(other.output_dtypes)
        and has_tensor(self.output_sizes)
        and has_tensor(other.output_sizes)
    )
def match(self, other: "Op") -> MatchInfo:
# TODO: I think this can validly not match,
# e.g. if one PG was used for p2p ops between only some of the peers?
@ -510,23 +534,7 @@ class Op:
MatchState.COLLECTIVE_STATE_MISMATCH,
f"Expected state: '{self.state}' does not match found state: '{other.state}'",
)
if (
(
set(self.input_dtypes) != set(self.output_dtypes)
and self.input_sizes[0]
and self.output_sizes[0]
)
or (
set(self.input_dtypes) != set(other.input_dtypes)
and self.input_sizes[0]
and other.input_sizes[0]
)
or (
set(self.input_dtypes) != set(other.output_dtypes)
and self.input_sizes[0]
and other.output_sizes[0]
)
):
if self.dtype_mismatch(other):
return MatchInfo(
MatchState.COLLECTIVE_DTYPE_MISMATCH,
f"Expected dtypes: '{set(self.input_dtypes)}' does not "