[FR] Polish the log message for dtype mismatch and don't exit when too many mismatch (#140451)

Summary:
1. We don't want to exit the process (via `sys.exit(-1)`) when there are too many mismatches. We should just break and return.
2. Polish the dtype mismatch message. The input/output dtypes are actually lists, not strings, so we don't want to show a list such as ['double'] in the output message.

Test Plan:
Tested on a case where we see too many collective dtype mismatches.

 {F1958467224}

Differential Revision: D65841830

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140451
Approved by: https://github.com/c-p-i-o
This commit is contained in:
Junjie Wang (PyTorch)
2024-11-13 07:24:53 +00:00
committed by PyTorch MergeBot
parent cb71bcc542
commit c61ccaf10e
3 changed files with 13 additions and 13 deletions

View File

@ -38,6 +38,7 @@ def create_one_event(
"collective_seq_id": str(collective_seq_id),
"p2p_seq_id": str(p2p_seq_id),
"time_created_ns": 0,
"frames": [],
}

View File

@ -412,7 +412,7 @@ def build_collectives(
logger.error(
"Too many mismatches for process_group %s: %s aborting", pg_name, desc
)
sys.exit(-1)
break
return tracebacks, collectives, nccl_calls

View File

@ -371,8 +371,8 @@ class Op:
def __init__(
self, event: Dict[Any, Any], memberships: Dict[str, Set[Any]], pg_name: str
):
profiling_name = event["profiling_name"]
nccl, name = profiling_name.split(":")
self.profiling_name = event["profiling_name"]
nccl, name = self.profiling_name.split(":")
assert nccl == "nccl", f"name formatting error? {nccl} != 'nccl'"
parts = name.split(" ")
type = parts[0]
@ -403,6 +403,7 @@ class Op:
self.input_dtypes = event["input_dtypes"]
self.output_dtypes = event["output_dtypes"]
self.time_created_ns = event["time_created_ns"]
self.collective_frames = event["frames"]
self.is_verbose = os.getenv("FR_TRACE_VERBOSE_OUTPUT", "0") == "1"
def _init_global_src_dst(self, pg_ranks: Set[Any]) -> None:
@ -440,11 +441,8 @@ class Op:
f"state={self.state}",
)
return f"{self.type}(%s)" % ", ".join(s for s in verbose_info if s)
return (
f"{self.type}(%sinput_sizes={self.input_sizes}, state={self.state})"
% f"{p2p_info}, "
if p2p_info
else ""
return f"{self.type}(%sinput_sizes={self.input_sizes}, state={self.state})" % (
f"{p2p_info}, " if p2p_info else ""
)
def match(self, other: "Op") -> MatchState:
@ -487,13 +485,14 @@ class Op:
f"Expected state: '{self.state}' does not match found state: '{other.state}'"
)
if (
other.input_dtypes != other.output_dtypes
or self.input_dtypes != other.input_dtypes
or self.output_dtypes != other.output_dtypes
set(self.input_dtypes) != set(self.output_dtypes)
or set(self.input_dtypes) != set(other.input_dtypes)
or set(self.input_dtypes) != set(other.output_dtypes)
):
return MatchState.COLLECTIVE_DTYPE_MISMATCH(
f"Expected dtypes: '{self.input_dtypes}/{other.input_dtypes}' does not "
f"match found dtype: '{self.output_dtypes}/{other.output_dtypes}'",
f"Expected dtypes: '{set(self.input_dtypes)}' does not "
f"match found dtype: '{set(self.output_dtypes)}/"
f"{set(other.input_dtypes)}/{set(other.output_dtypes)}'",
)
if self.type == "all_to_all":
return MatchState.UNDECIDED