PEP585 update - torch/distributed (#145164)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145164
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Aaron Orenstein
2025-01-20 14:50:01 -08:00
committed by PyTorch MergeBot
parent c6986ca2e1
commit 00ffeca1b1
79 changed files with 805 additions and 860 deletions

View File

@ -398,7 +398,7 @@ import sys
import uuid
from argparse import ArgumentParser, REMAINDER
from importlib import metadata
from typing import Callable, List, Optional, Set, Type, Union
from typing import Callable, Optional, Union
import torch
from torch.distributed.argparse_util import check_env, env
@ -736,7 +736,7 @@ def get_use_env(args) -> bool:
return args.use_env
def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]:
def _get_logs_specs_class(logs_specs_name: Optional[str]) -> type[LogsSpecs]:
"""
Attemps to load `torchrun.logs_spec` entrypoint with key of `logs_specs_name` param.
Provides plugin mechanism to provide custom implementation of LogsSpecs.
@ -770,7 +770,7 @@ def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]:
return logs_specs_cls
def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], List[str]]:
def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str]]:
# If ``args`` not passed, defaults to ``sys.argv[:1]``
min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
assert 0 < min_nodes <= max_nodes
@ -810,7 +810,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], List[str
rdzv_endpoint = get_rdzv_endpoint(args)
ranks: Optional[Set[int]] = None
ranks: Optional[set[int]] = None
if args.local_ranks_filter:
try:
ranks = set(map(int, args.local_ranks_filter.split(",")))
@ -820,7 +820,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], List[str
"--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2"
) from e
logs_specs_cls: Type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
logs_specs_cls: type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
logs_specs = logs_specs_cls(
log_dir=args.log_dir,
redirects=Std.from_str(args.redirects),