mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[fr] Re-order mismatch check in fr analysis script (#164606)
In reality we found the current mismatch order does not match the actual error distribution, so we reorder it a bit as following: 1. We do collective type check first 2. Then size check (excluding all2all) 3. dtype check 4. state check Pull Request resolved: https://github.com/pytorch/pytorch/pull/164606 Approved by: https://github.com/VieEeEw
This commit is contained in:
@ -528,28 +528,19 @@ class Op:
|
||||
MatchState.COLLECTIVE_TYPE_MISMATCH,
|
||||
f"Expected collective type: '{self.type}' does not match found collective type: '{other.type}'",
|
||||
)
|
||||
if self.state != other.state:
|
||||
# MatchState()
|
||||
return MatchInfo(
|
||||
MatchState.COLLECTIVE_STATE_MISMATCH,
|
||||
f"Expected state: '{self.state}' does not match found state: '{other.state}'",
|
||||
)
|
||||
if self.dtype_mismatch(other):
|
||||
return MatchInfo(
|
||||
MatchState.COLLECTIVE_DTYPE_MISMATCH,
|
||||
f"Expected dtypes: '{set(self.input_dtypes)}' does not "
|
||||
f"match found dtype: '{set(self.output_dtypes)}/"
|
||||
f"{set(other.input_dtypes)}/{set(other.output_dtypes)}'",
|
||||
)
|
||||
if self.type == "all_to_all":
|
||||
return MatchInfo(MatchState.UNDECIDED)
|
||||
if self.type != "scatter" and self.input_sizes != other.input_sizes:
|
||||
if (
|
||||
self.type not in ["all_to_all", "scatter"]
|
||||
and self.input_sizes != other.input_sizes
|
||||
):
|
||||
return MatchInfo(
|
||||
MatchState.SIZE_OR_SYNTAX_MISMATCH,
|
||||
f"Expected input sizes: '{self.input_sizes}' does not match found input sizes: "
|
||||
f"'{other.input_sizes}'",
|
||||
)
|
||||
if self.type != "gather" and self.output_sizes != other.output_sizes:
|
||||
if (
|
||||
self.type not in ["all_to_all", "gather"]
|
||||
and self.output_sizes != other.output_sizes
|
||||
):
|
||||
return MatchInfo(
|
||||
MatchState.SIZE_OR_SYNTAX_MISMATCH,
|
||||
f"Expected output sizes: '{self.output_sizes}' does not match found output sizes: "
|
||||
@ -589,6 +580,21 @@ class Op:
|
||||
f"Found input numel '{math.prod(other.input_sizes[0])}' does not match output numel "
|
||||
f"'{math.prod(other.output_sizes[0])} * pg size {self.pg_size}'",
|
||||
)
|
||||
if self.dtype_mismatch(other):
|
||||
return MatchInfo(
|
||||
MatchState.COLLECTIVE_DTYPE_MISMATCH,
|
||||
f"Expected dtypes: '{set(self.input_dtypes)}' does not "
|
||||
f"match found dtype: '{set(self.output_dtypes)}/"
|
||||
f"{set(other.input_dtypes)}/{set(other.output_dtypes)}'",
|
||||
)
|
||||
if self.state != other.state:
|
||||
# MatchState()
|
||||
return MatchInfo(
|
||||
MatchState.COLLECTIVE_STATE_MISMATCH,
|
||||
f"Expected state: '{self.state}' does not match found state: '{other.state}'",
|
||||
)
|
||||
if self.type == "all_to_all":
|
||||
return MatchInfo(MatchState.UNDECIDED)
|
||||
elif self.type in [
|
||||
"coalesced",
|
||||
"ALLGATHER_coalesced",
|
||||
|
Reference in New Issue
Block a user