mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[FR] Polish the log message for dtype mismatch and don't exit when too many mismatch (#140451)
Summary: 1. We don't want to exit with exceptions when there are so many mismatches. We should just break and return. 2. Polish the message of dtype mismatch. This is because dtype of input/output is actually a list not a string. So we don't want to show a list of ['double'] in the output message. Test Plan: Testing on the case when we see too many collective dtype mismatch {F1958467224} Differential Revision: D65841830 Pull Request resolved: https://github.com/pytorch/pytorch/pull/140451 Approved by: https://github.com/c-p-i-o
This commit is contained in:
committed by
PyTorch MergeBot
parent
cb71bcc542
commit
c61ccaf10e
@ -38,6 +38,7 @@ def create_one_event(
|
||||
"collective_seq_id": str(collective_seq_id),
|
||||
"p2p_seq_id": str(p2p_seq_id),
|
||||
"time_created_ns": 0,
|
||||
"frames": [],
|
||||
}
|
||||
|
||||
|
||||
|
@ -412,7 +412,7 @@ def build_collectives(
|
||||
logger.error(
|
||||
"Too many mismatches for process_group %s: %s aborting", pg_name, desc
|
||||
)
|
||||
sys.exit(-1)
|
||||
break
|
||||
|
||||
return tracebacks, collectives, nccl_calls
|
||||
|
||||
|
@ -371,8 +371,8 @@ class Op:
|
||||
def __init__(
|
||||
self, event: Dict[Any, Any], memberships: Dict[str, Set[Any]], pg_name: str
|
||||
):
|
||||
profiling_name = event["profiling_name"]
|
||||
nccl, name = profiling_name.split(":")
|
||||
self.profiling_name = event["profiling_name"]
|
||||
nccl, name = self.profiling_name.split(":")
|
||||
assert nccl == "nccl", f"name formatting error? {nccl} != 'nccl'"
|
||||
parts = name.split(" ")
|
||||
type = parts[0]
|
||||
@ -403,6 +403,7 @@ class Op:
|
||||
self.input_dtypes = event["input_dtypes"]
|
||||
self.output_dtypes = event["output_dtypes"]
|
||||
self.time_created_ns = event["time_created_ns"]
|
||||
self.collective_frames = event["frames"]
|
||||
self.is_verbose = os.getenv("FR_TRACE_VERBOSE_OUTPUT", "0") == "1"
|
||||
|
||||
def _init_global_src_dst(self, pg_ranks: Set[Any]) -> None:
|
||||
@ -440,11 +441,8 @@ class Op:
|
||||
f"state={self.state}",
|
||||
)
|
||||
return f"{self.type}(%s)" % ", ".join(s for s in verbose_info if s)
|
||||
return (
|
||||
f"{self.type}(%sinput_sizes={self.input_sizes}, state={self.state})"
|
||||
% f"{p2p_info}, "
|
||||
if p2p_info
|
||||
else ""
|
||||
return f"{self.type}(%sinput_sizes={self.input_sizes}, state={self.state})" % (
|
||||
f"{p2p_info}, " if p2p_info else ""
|
||||
)
|
||||
|
||||
def match(self, other: "Op") -> MatchState:
|
||||
@ -487,13 +485,14 @@ class Op:
|
||||
f"Expected state: '{self.state}' does not match found state: '{other.state}'"
|
||||
)
|
||||
if (
|
||||
other.input_dtypes != other.output_dtypes
|
||||
or self.input_dtypes != other.input_dtypes
|
||||
or self.output_dtypes != other.output_dtypes
|
||||
set(self.input_dtypes) != set(self.output_dtypes)
|
||||
or set(self.input_dtypes) != set(other.input_dtypes)
|
||||
or set(self.input_dtypes) != set(other.output_dtypes)
|
||||
):
|
||||
return MatchState.COLLECTIVE_DTYPE_MISMATCH(
|
||||
f"Expected dtypes: '{self.input_dtypes}/{other.input_dtypes}' does not "
|
||||
f"match found dtype: '{self.output_dtypes}/{other.output_dtypes}'",
|
||||
f"Expected dtypes: '{set(self.input_dtypes)}' does not "
|
||||
f"match found dtype: '{set(self.output_dtypes)}/"
|
||||
f"{set(other.input_dtypes)}/{set(other.output_dtypes)}'",
|
||||
)
|
||||
if self.type == "all_to_all":
|
||||
return MatchState.UNDECIDED
|
||||
|
Reference in New Issue
Block a user