PEP585 update - torch/distributed/elastic torch/distributed/checkpoint (#145163)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145163
Approved by: https://github.com/Skylion007
This commit is contained in:
Aaron Orenstein
2025-01-18 14:58:05 -08:00
committed by PyTorch MergeBot
parent c64e657632
commit 316808e4e9
47 changed files with 311 additions and 344 deletions

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import functools
import time
from typing import Any, Callable, Dict, List, TypeVar
from typing import Any, Callable, TypeVar
from typing_extensions import ParamSpec
from uuid import uuid4
@ -9,7 +9,7 @@ import torch.distributed.c10d_logger as c10d_logger
from torch.distributed.checkpoint.logging_handlers import DCP_LOGGER_NAME
__all__: List[str] = []
__all__: list[str] = []
global _dcp_logger
_dcp_logger = c10d_logger._get_or_create_logger(DCP_LOGGER_NAME)
@ -18,7 +18,7 @@ _T = TypeVar("_T")
_P = ParamSpec("_P")
def _msg_dict_from_dcp_method_args(*args, **kwargs) -> Dict[str, Any]:
def _msg_dict_from_dcp_method_args(*args, **kwargs) -> dict[str, Any]:
"""
Extracts log data from dcp method args
"""
@ -52,7 +52,7 @@ def _msg_dict_from_dcp_method_args(*args, **kwargs) -> Dict[str, Any]:
return msg_dict
def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]:
def _get_msg_dict(func_name, *args, **kwargs) -> dict[str, Any]:
msg_dict = _msg_dict_from_dcp_method_args(*args, **kwargs)
msg_dict.update(c10d_logger._get_msg_dict(func_name, *args, **kwargs))