[Torchelastic][Logging] Pluggable LogsSpecs using Python entrypoints and an option to specify one by name. (#120942)

Summary:
Exposes an option that lets users specify the name of the LogsSpecs implementation to use:
- It has to be defined as an entrypoint under the `torchrun.logs_specs` group.
- It must implement the `LogsSpecs` interface defined in the prior PR/diff.
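
For illustration, a minimal sketch of how a plugin package could register a custom implementation under that group (the package, module, and class names below are hypothetical, not part of this PR):

    # setup.py of a hypothetical plugin package
    from setuptools import setup

    setup(
        name="my-torchrun-plugin",
        version="0.1.0",
        packages=["my_torchrun_plugin"],
        entry_points={
            # torchrun resolves --logs-specs=<name> against this group
            "torchrun.logs_specs": [
                "my_specs = my_torchrun_plugin.specs:MyLogsSpecs",
            ],
        },
    )

Once the plugin is installed, the registration can be checked with importlib.metadata (Python >= 3.10 API shown):

    import importlib.metadata as metadata

    for ep in metadata.entry_points().select(group="torchrun.logs_specs"):
        print(ep.name, "->", ep.value)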

Test Plan: unit tests + local tests

Reviewed By: ezyang

Differential Revision: D54180838

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120942
Approved by: https://github.com/ezyang
Author: Kurman Karabukaev
Date: 2024-03-02 08:07:52 +00:00
Committed by: PyTorch MergeBot
Parent: f351a71dbb
Commit: b0cfa96e82

6 changed files with 116 additions and 17 deletions


torch/distributed/run.py

@@ -375,12 +375,13 @@ import logging
 import os
 import sys
 import uuid
+import importlib.metadata as metadata
 from argparse import REMAINDER, ArgumentParser
-from typing import Callable, List, Tuple, Union, Optional, Set
+from typing import Callable, List, Tuple, Type, Union, Optional, Set

 import torch
 from torch.distributed.argparse_util import check_env, env
-from torch.distributed.elastic.multiprocessing import Std, DefaultLogsSpecs
+from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, Std
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config
 from torch.distributed.elastic.utils import macros
@@ -602,6 +603,15 @@ def get_args_parser() -> ArgumentParser:
         "machine's FQDN.",
     )
+    parser.add_argument(
+        "--logs-specs",
+        "--logs_specs",
+        default=None,
+        type=str,
+        help="torchrun.logs_specs group entrypoint name, value must be a type of LogsSpecs. "
+        "Can be used to override default logging behavior.",
+    )
     #
     # Positional arguments.
     #
@@ -699,6 +709,36 @@ def get_use_env(args) -> bool:
     return args.use_env


+def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]:
+    """
+    Attempts to load the `torchrun.logs_specs` entrypoint with the key given by the `logs_specs_name` param.
+    Provides a plugin mechanism to supply custom implementations of LogsSpecs.
+
+    Returns `DefaultLogsSpecs` when `logs_specs_name` is None.
+    Raises ValueError when no entrypoint for `logs_specs_name` can be found.
+    """
+    logs_specs_cls = None
+    if logs_specs_name is not None:
+        eps = metadata.entry_points()
+        if hasattr(eps, "select"):  # >= 3.10
+            group = eps.select(group="torchrun.logs_specs")
+            if group.select(name=logs_specs_name):
+                logs_specs_cls = group[logs_specs_name].load()
+        elif specs := eps.get("torchrun.logs_specs"):  # < 3.10
+            if entrypoint_list := [ep for ep in specs if ep.name == logs_specs_name]:
+                logs_specs_cls = entrypoint_list[0].load()
+
+        if logs_specs_cls is None:
+            raise ValueError(f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key")
+
+        logging.info("Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls))
+    else:
+        logs_specs_cls = DefaultLogsSpecs
+
+    return logs_specs_cls
+
+
 def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]:
     # If ``args`` not passed, defaults to ``sys.argv[:1]``
     min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
@@ -745,7 +785,8 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]:
            "--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2"
        ) from e

-    logs_specs = DefaultLogsSpecs(
+    logs_specs_cls: Type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
+    logs_specs = logs_specs_cls(
         log_dir=args.log_dir,
         redirects=Std.from_str(args.redirects),
         tee=Std.from_str(args.tee),
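
As a companion to the diff above, a minimal sketch of what a pluggable implementation might look like (module and class names are hypothetical; it reuses DefaultLogsSpecs rather than implementing the LogsSpecs interface from scratch):

    # my_torchrun_plugin/specs.py (hypothetical module)
    from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, Std


    class MyLogsSpecs(DefaultLogsSpecs):
        """Custom LogsSpecs that tees stdout/stderr of every rank to the console."""

        def __init__(self, *args, **kwargs):
            # Force tee for all local ranks, overriding the --tee CLI value
            # that config_from_args passes through as a keyword argument.
            kwargs["tee"] = Std.ALL
            super().__init__(*args, **kwargs)

With the plugin installed, it would be selected by the entrypoint name registered earlier (train.py standing in for any training script):

    torchrun --logs-specs=my_specs --nproc-per-node=2 train.py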