mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
PEP585 update - torch/distributed (#145164)
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145164 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
c6986ca2e1
commit
00ffeca1b1
@ -398,7 +398,7 @@ import sys
|
||||
import uuid
|
||||
from argparse import ArgumentParser, REMAINDER
|
||||
from importlib import metadata
|
||||
from typing import Callable, List, Optional, Set, Type, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed.argparse_util import check_env, env
|
||||
@ -736,7 +736,7 @@ def get_use_env(args) -> bool:
|
||||
return args.use_env
|
||||
|
||||
|
||||
def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]:
|
||||
def _get_logs_specs_class(logs_specs_name: Optional[str]) -> type[LogsSpecs]:
|
||||
"""
|
||||
Attemps to load `torchrun.logs_spec` entrypoint with key of `logs_specs_name` param.
|
||||
Provides plugin mechanism to provide custom implementation of LogsSpecs.
|
||||
@ -770,7 +770,7 @@ def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]:
|
||||
return logs_specs_cls
|
||||
|
||||
|
||||
def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], List[str]]:
|
||||
def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str]]:
|
||||
# If ``args`` not passed, defaults to ``sys.argv[:1]``
|
||||
min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
|
||||
assert 0 < min_nodes <= max_nodes
|
||||
@ -810,7 +810,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], List[str
|
||||
|
||||
rdzv_endpoint = get_rdzv_endpoint(args)
|
||||
|
||||
ranks: Optional[Set[int]] = None
|
||||
ranks: Optional[set[int]] = None
|
||||
if args.local_ranks_filter:
|
||||
try:
|
||||
ranks = set(map(int, args.local_ranks_filter.split(",")))
|
||||
@ -820,7 +820,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], List[str
|
||||
"--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2"
|
||||
) from e
|
||||
|
||||
logs_specs_cls: Type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
|
||||
logs_specs_cls: type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
|
||||
logs_specs = logs_specs_cls(
|
||||
log_dir=args.log_dir,
|
||||
redirects=Std.from_str(args.redirects),
|
||||
|
Reference in New Issue
Block a user