mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[fr] Skip the dtype check for some one to all or all to one collective (#163839)
As title, in practice we found that sometimes, the dtype of gather does not match when it comes to output among all ranks, which is a undefined behavior. Same with broadcast and scatter. And they are all completed, so we should not think they are errors, we can skip it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163839 Approved by: https://github.com/VieEeEw
This commit is contained in:
@ -143,6 +143,19 @@ class FlightRecorderEventTest(TestCase):
|
||||
match_one_event(e11, e12, membership, "0").state,
|
||||
MatchState.FULLY_MATCHED,
|
||||
)
|
||||
e13 = create_one_event(
|
||||
"gather",
|
||||
("0", "default"),
|
||||
[[4, 4]],
|
||||
[[4, 4]],
|
||||
"completed",
|
||||
1,
|
||||
output_dtypes="",
|
||||
)
|
||||
self.assertEqual(
|
||||
match_one_event(e11, e13, membership, "0").state,
|
||||
MatchState.FULLY_MATCHED,
|
||||
)
|
||||
|
||||
def test_all_events(self):
|
||||
for collective in sorted(COLLECTIVES):
|
||||
|
@ -469,6 +469,30 @@ class Op:
|
||||
f"{p2p_info}, " if p2p_info else ""
|
||||
)
|
||||
|
||||
def dtype_mismatch(self, other: "Op") -> bool:
|
||||
if (
|
||||
(
|
||||
self.type not in ["scatter", "gather", "broadcast"]
|
||||
and set(self.input_dtypes) != set(self.output_dtypes)
|
||||
and self.input_sizes[0]
|
||||
and self.output_sizes[0]
|
||||
)
|
||||
or (
|
||||
self.type not in ["scatter", "broadcast"]
|
||||
and set(self.input_dtypes) != set(other.input_dtypes)
|
||||
and self.input_sizes[0]
|
||||
and other.input_sizes[0]
|
||||
)
|
||||
or (
|
||||
self.type not in ["gather"]
|
||||
and set(self.output_dtypes) != set(other.output_dtypes)
|
||||
and self.output_sizes[0]
|
||||
and other.output_sizes[0]
|
||||
)
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def match(self, other: "Op") -> MatchInfo:
|
||||
# TODO: I think this can validly not match,
|
||||
# e.g. if one PG was used for p2p ops between only some of the peers?
|
||||
@ -510,23 +534,7 @@ class Op:
|
||||
MatchState.COLLECTIVE_STATE_MISMATCH,
|
||||
f"Expected state: '{self.state}' does not match found state: '{other.state}'",
|
||||
)
|
||||
if (
|
||||
(
|
||||
set(self.input_dtypes) != set(self.output_dtypes)
|
||||
and self.input_sizes[0]
|
||||
and self.output_sizes[0]
|
||||
)
|
||||
or (
|
||||
set(self.input_dtypes) != set(other.input_dtypes)
|
||||
and self.input_sizes[0]
|
||||
and other.input_sizes[0]
|
||||
)
|
||||
or (
|
||||
set(self.input_dtypes) != set(other.output_dtypes)
|
||||
and self.input_sizes[0]
|
||||
and other.output_sizes[0]
|
||||
)
|
||||
):
|
||||
if self.dtype_mismatch(other):
|
||||
return MatchInfo(
|
||||
MatchState.COLLECTIVE_DTYPE_MISMATCH,
|
||||
f"Expected dtypes: '{set(self.input_dtypes)}' does not "
|
||||
|
Reference in New Issue
Block a user