[fr] Re-order mismatch check in fr analysis script (#164606)

In practice, we found that the current order of the mismatch checks does not match the actual error distribution, so we reorder the checks as follows:
1. Collective type check first
2. Then size check (excluding all_to_all)
3. Then dtype check
4. Finally state check

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164606
Approved by: https://github.com/VieEeEw
This commit is contained in:
fduwjj
2025-10-03 14:55:52 -07:00
committed by PyTorch MergeBot
parent f3afbcf340
commit 86c789849e

View File

@ -528,28 +528,19 @@ class Op:
MatchState.COLLECTIVE_TYPE_MISMATCH,
f"Expected collective type: '{self.type}' does not match found collective type: '{other.type}'",
)
if self.state != other.state:
# MatchState()
return MatchInfo(
MatchState.COLLECTIVE_STATE_MISMATCH,
f"Expected state: '{self.state}' does not match found state: '{other.state}'",
)
if self.dtype_mismatch(other):
return MatchInfo(
MatchState.COLLECTIVE_DTYPE_MISMATCH,
f"Expected dtypes: '{set(self.input_dtypes)}' does not "
f"match found dtype: '{set(self.output_dtypes)}/"
f"{set(other.input_dtypes)}/{set(other.output_dtypes)}'",
)
if self.type == "all_to_all":
return MatchInfo(MatchState.UNDECIDED)
if self.type != "scatter" and self.input_sizes != other.input_sizes:
if (
self.type not in ["all_to_all", "scatter"]
and self.input_sizes != other.input_sizes
):
return MatchInfo(
MatchState.SIZE_OR_SYNTAX_MISMATCH,
f"Expected input sizes: '{self.input_sizes}' does not match found input sizes: "
f"'{other.input_sizes}'",
)
if self.type != "gather" and self.output_sizes != other.output_sizes:
if (
self.type not in ["all_to_all", "gather"]
and self.output_sizes != other.output_sizes
):
return MatchInfo(
MatchState.SIZE_OR_SYNTAX_MISMATCH,
f"Expected output sizes: '{self.output_sizes}' does not match found output sizes: "
@ -589,6 +580,21 @@ class Op:
f"Found input numel '{math.prod(other.input_sizes[0])}' does not match output numel "
f"'{math.prod(other.output_sizes[0])} * pg size {self.pg_size}'",
)
if self.dtype_mismatch(other):
return MatchInfo(
MatchState.COLLECTIVE_DTYPE_MISMATCH,
f"Expected dtypes: '{set(self.input_dtypes)}' does not "
f"match found dtype: '{set(self.output_dtypes)}/"
f"{set(other.input_dtypes)}/{set(other.output_dtypes)}'",
)
if self.state != other.state:
# MatchState()
return MatchInfo(
MatchState.COLLECTIVE_STATE_MISMATCH,
f"Expected state: '{self.state}' does not match found state: '{other.state}'",
)
if self.type == "all_to_all":
return MatchInfo(MatchState.UNDECIDED)
elif self.type in [
"coalesced",
"ALLGATHER_coalesced",