mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Context in RFC: https://github.com/pytorch/pytorch/issues/86620 (the .rst file will be finalized in subsequent PRs). Pull Request resolved: https://github.com/pytorch/pytorch/pull/88698 Approved by: https://github.com/wanchaol
50 lines
1.5 KiB
Python
50 lines
1.5 KiB
Python
from typing import Dict, Tuple, Any
|
|
import traceback as tb
|
|
|
|
# Public API of this module: only the exception type is exported.
__all__ = ["CheckpointException"]

# Type alias for an exception paired with the stack summary extracted from
# its traceback, i.e. ``(exception, traceback.StackSummary)``.
WRAPPED_EXCEPTION = Tuple[BaseException, tb.StackSummary]
|
|
|
|
|
|
def _wrap_exception(exc: BaseException) -> WRAPPED_EXCEPTION:
    """Bundle ``exc`` with a summary of the frames in its traceback.

    Args:
        exc: The exception to wrap.

    Returns:
        A two-item tuple ``(exc, summary)`` where ``summary`` is the
        ``traceback.StackSummary`` extracted from ``exc.__traceback__``.
    """
    summary = tb.extract_tb(exc.__traceback__)
    wrapped = (exc, summary)
    return wrapped
|
|
|
|
|
|
def _is_wrapped_exception(obj: Any) -> bool:
    """Report whether ``obj`` has the shape of an ``(exception, summary)`` pair.

    Only a ``tuple`` (including tuple subclasses) of exactly two items
    qualifies, and only when the first item is a ``BaseException`` and the
    second is a ``traceback.StackSummary``.

    Args:
        obj: Any value.

    Returns:
        ``True`` if ``obj`` matches the shape described above, else ``False``.
    """
    looks_like_pair = isinstance(obj, tuple) and len(obj) == 2
    if not looks_like_pair:
        return False
    # Evaluate the second check only when the first passes, as before.
    first_ok = isinstance(obj[0], BaseException)
    return first_ok and isinstance(obj[1], tb.StackSummary)
|
|
|
|
|
|
class CheckpointException(BaseException):
    """
    Exception raised if failure was detected as part of a checkpoint load or save.

    It carries the failures that were collected, keyed by rank.

    Note that this class derives from ``BaseException`` rather than
    ``Exception``, so an ``except Exception:`` handler will not catch it.

    Args:
        msg: Human-readable description of the failure.
        failures: Mapping from rank to an ``(exception, stack_summary)`` pair,
            where ``stack_summary`` is a ``traceback.StackSummary``.
            ``__str__`` also tolerates ``None`` in the summary slot and then
            prints only the exception line for that rank.
    """

    def __init__(self, msg: str, failures: Dict[int, WRAPPED_EXCEPTION]):
        # Forward both values to BaseException so they are stored in
        # ``self.args``; this keeps the exception reconstructible from its
        # args (e.g. by pickle, which calls ``cls(*args)``).
        super().__init__(msg, failures)
        self._failures = failures

    @property
    def failures(self) -> Dict[int, WRAPPED_EXCEPTION]:
        """
        Returns:
            Dict of failed nodes and their associated failure.
            Keys are node ranks; values are ``(exception, stack_summary)``
            pairs (not bare exceptions).
        """
        return self._failures

    def __str__(self) -> str:
        """Render a header listing the failed ranks, then one
        traceback-style section per rank."""
        # Collect pieces and join once, rather than repeatedly concatenating
        # onto a local variable that shadowed the builtin ``str``.
        parts = [f"CheckpointException ranks:{self._failures.keys()}\n"]
        for rank, (exc, trace) in self._failures.items():
            parts.append(f"Traceback (most recent call last): (RANK {rank})\n")
            if trace is not None:
                parts.extend(tb.format_list(trace))
            parts.extend(tb.format_exception_only(type(exc), value=exc))
        return "".join(parts)
|