[Torchelastic][Logging] Pluggable LogsSpecs using Python entrypoints and an option to select one by name. (#120942)

Summary:
Expose an option for users to specify, by name, which LogsSpecs implementation torchrun should use.
- Has to be registered as an entrypoint under the `torchrun.logs_specs` group (a sketch of a hypothetical plugin follows below).
- Must implement the `LogsSpecs` interface defined in the prior PR/diff.
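As a sketch of what this enables: a third-party package can register its own implementation and users can select it by name. Everything below except the `torchrun.logs_specs` group name and the `DefaultLogsSpecs`/`Std` imports is hypothetical (package `my_logs_plugin`, class `QuietLogsSpecs`, entrypoint name `quiet`):

# my_logs_plugin/specs.py -- hypothetical plugin module
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, Std


class QuietLogsSpecs(DefaultLogsSpecs):
    """A LogsSpecs that force-redirects every worker stream to log files."""

    def __init__(self, **kwargs):
        kwargs["redirects"] = Std.ALL  # override whatever the CLI asked for
        super().__init__(**kwargs)


# setup.py of the hypothetical plugin package
from setuptools import setup

setup(
    name="my-logs-plugin",
    packages=["my_logs_plugin"],
    entry_points={
        "torchrun.logs_specs": [
            "quiet = my_logs_plugin.specs:QuietLogsSpecs",
        ],
    },
)

Once the plugin is installed, `torchrun --logs-specs=quiet ...` resolves the name through this entrypoint group.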

Test Plan: unit tests + local tests

Reviewed By: ezyang

Differential Revision: D54180838

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120942
Approved by: https://github.com/ezyang
Author: Kurman Karabukaev
Committed by: PyTorch MergeBot
Date: 2024-03-02 08:07:52 +00:00
Parent: f351a71dbb
Commit: b0cfa96e82

6 changed files with 116 additions and 17 deletions

View File: setup.py

@@ -1055,7 +1055,10 @@ def configure_extension_build():
             "convert-caffe2-to-onnx = caffe2.python.onnx.bin.conversion:caffe2_to_onnx",
             "convert-onnx-to-caffe2 = caffe2.python.onnx.bin.conversion:onnx_to_caffe2",
             "torchrun = torch.distributed.run:main",
-        ]
+        ],
+        "torchrun.logs_specs": [
+            "default = torch.distributed.elastic.multiprocessing:DefaultLogsSpecs",
+        ],
     }
     return extensions, cmdclass, packages, entry_points, extra_install_requires
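With this entry registered, the group is discoverable through the standard `importlib.metadata` API. A minimal sanity check (assuming torch is installed from a build that includes this change) might look like:

import importlib.metadata as metadata

eps = metadata.entry_points()
if hasattr(eps, "select"):  # Python >= 3.10 returns an EntryPoints object
    group = eps.select(group="torchrun.logs_specs")
else:  # Python < 3.10 returns a dict keyed by group name
    group = eps.get("torchrun.logs_specs", [])

print(sorted(ep.name for ep in group))  # expect at least ['default']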

View File: test/distributed/launcher/run_test.py

@@ -17,17 +17,18 @@ import unittest
 import uuid
 from contextlib import closing
 from unittest import mock
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch

 import torch.distributed.run as launch
 from torch.distributed.elastic.agent.server.api import RunResult, WorkerState
+from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs
 from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
 from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
 from torch.distributed.elastic.utils import get_socket_with_port
 from torch.distributed.elastic.utils.distributed import get_free_port
 from torch.testing._internal.common_utils import (
-    TEST_WITH_DEV_DBG_ASAN,
     skip_but_pass_in_sandcastle_if,
+    TEST_WITH_DEV_DBG_ASAN,
 )
@@ -504,6 +505,55 @@ class ElasticLaunchTest(unittest.TestCase):
             is_torchelastic_launched = fp.readline()
         self.assertEqual("True", is_torchelastic_launched)

+    @patch("torch.distributed.run.metadata")
+    @skip_but_pass_in_sandcastle_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
+    def test_is_torchelastic_launched_with_logs_spec_defined(self, metadata_mock):
+        # mock the entrypoint API to avoid version issues
+        entrypoints = MagicMock()
+        metadata_mock.entry_points.return_value = entrypoints
+
+        group = MagicMock()
+        entrypoints.select.return_value = group
+
+        ep = MagicMock()
+        ep.load.return_value = DefaultLogsSpecs
+
+        group.select.return_value = ep
+        group.__getitem__.return_value = ep
+
+        out_file = f"{os.path.join(self.test_dir, 'out')}"
+        if os.path.exists(out_file):
+            os.remove(out_file)
+
+        launch.main(
+            [
+                "--run-path",
+                "--nnodes=1",
+                "--nproc-per-node=1",
+                "--monitor-interval=1",
+                "--logs_specs=default",
+                path("bin/test_script_is_torchelastic_launched.py"),
+                f"--out-file={out_file}",
+            ]
+        )
+
+        with open(out_file) as fp:
+            is_torchelastic_launched = fp.readline()
+            self.assertEqual("True", is_torchelastic_launched)
+
+    @skip_but_pass_in_sandcastle_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
+    def test_logs_logs_spec_entrypoint_must_be_defined(self):
+        with self.assertRaises(ValueError):
+            launch.main(
+                [
+                    "--run-path",
+                    "--nnodes=1",
+                    "--nproc-per-node=1",
+                    "--monitor-interval=1",
+                    "--logs_specs=DOESNOT_EXIST",
+                    path("bin/test_script_is_torchelastic_launched.py"),
+                ]
+            )
+
     def test_is_not_torchelastic_launched(self):
         # launch test script without torchelastic and validate that
         # torch.distributed.is_torchelastic_launched() returns False

View File: torch/distributed/elastic/multiprocessing/__init__.py

@@ -68,6 +68,7 @@ from typing import Callable, Dict, Optional, Tuple, Union, Set
 from torch.distributed.elastic.multiprocessing.api import (  # noqa: F401
     _validate_full_rank,
     DefaultLogsSpecs,
+    LogsDest,
     LogsSpecs,
     MultiprocessContext,
     PContext,
@@ -88,6 +89,7 @@ __all__ = [
     "RunProcsResult",
     "SignalException",
     "Std",
+    "LogsDest",
     "LogsSpecs",
     "DefaultLogsSpecs",
     "SubprocessContext",

View File: torch/distributed/elastic/multiprocessing/api.py

@@ -184,10 +184,18 @@ class LogsDest:
 class LogsSpecs(ABC):
     """
     Defines logs processing and redirection for each worker process.

     Args:
-        log_dir: base directory where logs will be written
-        redirects: specifies which streams to redirect to files.
-        tee: specifies which streams to duplicate to stdout/stderr
+        log_dir:
+            Base directory where logs will be written.
+        redirects:
+            Streams to redirect to files. Pass a single ``Std``
+            enum to redirect for all workers, or a mapping keyed
+            by local_rank to selectively redirect.
+        tee:
+            Streams to duplicate to stdout/stderr.
+            Pass a single ``Std`` enum to duplicate streams for all workers,
+            or a mapping keyed by local_rank to selectively duplicate.
     """

     def __init__(
@@ -220,7 +228,8 @@ class LogsSpecs(ABC):
 class DefaultLogsSpecs(LogsSpecs):
     """
     Default LogsSpecs implementation:
-     - `log_dir` will be created if it doesn't exist and it is not set to os.devnull
+     - `log_dir` will be created if it doesn't exist and it is not set to `os.devnull`
+     - Generates nested folders for each attempt and rank.
     """

     def __init__(
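A usage sketch of the parameters documented above (the directory path is illustrative):

from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, Std

# Redirect both streams of every worker to files, but tee only
# local_rank 0's stderr back to the console.
logs_specs = DefaultLogsSpecs(
    log_dir="/tmp/torchelastic_logs",  # created if missing, per the docstring above
    redirects=Std.ALL,                 # a single Std applies to all workers
    tee={0: Std.ERR},                  # a mapping keyed by local_rank is selective
)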

View File: torch/distributed/launcher/api.py

@@ -54,12 +54,6 @@ class LaunchConfig:
         as a period of monitoring workers.
         start_method: The method is used by the elastic agent to start the
             workers (spawn, fork, forkserver).
-        log_dir: base log directory where log files are written. If not set,
-            one is created in a tmp dir but NOT removed on exit.
-        redirects: configuration to redirect stdout/stderr to log files.
-            Pass a single ``Std`` enum to redirect all workers,
-            or a mapping keyed by local_rank to selectively redirect.
-        tee: configuration to "tee" stdout/stderr to console + log file.
         metrics_cfg: configuration to initialize metrics.
         local_addr: address of the local node if any. If not set, a lookup on the local
             machine's FQDN will be performed.
@@ -248,9 +242,9 @@ def launch_agent(
     agent = LocalElasticAgent(
         spec=spec,
+        logs_specs=config.logs_specs,  # type: ignore[arg-type]
         start_method=config.start_method,
         log_line_prefix_template=config.log_line_prefix_template,
-        logs_specs=config.logs_specs,  # type: ignore[arg-type]
     )

     shutdown_rdzv = True
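For the programmatic path, a minimal sketch of how a caller would now supply logs handling to `LaunchConfig` (the trainer function is hypothetical and the rendezvous settings are illustrative):

from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs
from torch.distributed.launcher.api import LaunchConfig, elastic_launch


def trainer():
    print("hello from a worker")


config = LaunchConfig(
    min_nodes=1,
    max_nodes=1,
    nproc_per_node=2,
    run_id="demo_run",
    rdzv_backend="c10d",
    # log_dir/redirects/tee moved off LaunchConfig onto the LogsSpecs object:
    logs_specs=DefaultLogsSpecs(log_dir="/tmp/torchelastic_logs"),
)

if __name__ == "__main__":
    elastic_launch(config, trainer)()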

View File: torch/distributed/run.py

@@ -375,12 +375,13 @@ import logging
 import os
 import sys
 import uuid
+import importlib.metadata as metadata
 from argparse import REMAINDER, ArgumentParser
-from typing import Callable, List, Tuple, Union, Optional, Set
+from typing import Callable, List, Tuple, Type, Union, Optional, Set

 import torch
 from torch.distributed.argparse_util import check_env, env
-from torch.distributed.elastic.multiprocessing import Std, DefaultLogsSpecs
+from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, Std
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config
 from torch.distributed.elastic.utils import macros
@@ -602,6 +603,15 @@ def get_args_parser() -> ArgumentParser:
         "machine's FQDN.",
     )

+    parser.add_argument(
+        "--logs-specs",
+        "--logs_specs",
+        default=None,
+        type=str,
+        help="torchrun.logs_specs group entrypoint name, value must be type of LogsSpecs. "
+        "Can be used to override custom logging behavior.",
+    )
+
     #
     # Positional arguments.
     #
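With the flag in place, the built-in spec can be selected explicitly, e.g. `torchrun --nnodes=1 --nproc-per-node=1 --logs-specs=default train.py` (the script name is illustrative); this is the same path the unit tests above exercise via `--logs_specs=default`.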
@@ -699,6 +709,36 @@ def get_use_env(args) -> bool:
     return args.use_env


+def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]:
+    """
+    Attempts to load the `torchrun.logs_specs` entrypoint with the key given by the `logs_specs_name` param.
+    Provides a plugin mechanism for custom implementations of LogsSpecs.
+
+    Returns `DefaultLogsSpecs` when `logs_specs_name` is None.
+    Raises ValueError when the entrypoint for `logs_specs_name` can't be found in entrypoints.
+    """
+
+    logs_specs_cls = None
+    if logs_specs_name is not None:
+        eps = metadata.entry_points()
+        if hasattr(eps, "select"):  # >= 3.10
+            group = eps.select(group="torchrun.logs_specs")
+            if group.select(name=logs_specs_name):
+                logs_specs_cls = group[logs_specs_name].load()
+        elif specs := eps.get("torchrun.logs_specs"):  # < 3.10
+            if entrypoint_list := [ep for ep in specs if ep.name == logs_specs_name]:
+                logs_specs_cls = entrypoint_list[0].load()
+
+        if logs_specs_cls is None:
+            raise ValueError(f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key")
+
+        logging.info("Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls))
+    else:
+        logs_specs_cls = DefaultLogsSpecs
+
+    return logs_specs_cls
+
+
 def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]:
     # If ``args`` not passed, defaults to ``sys.argv[:1]``
     min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
@@ -745,7 +785,8 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]:
             "--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2"
         ) from e

-    logs_specs = DefaultLogsSpecs(
+    logs_specs_cls: Type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
+    logs_specs = logs_specs_cls(
         log_dir=args.log_dir,
         redirects=Std.from_str(args.redirects),
         tee=Std.from_str(args.tee),
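Putting it together, the resolution behavior mirrors the unit tests above. A rough sketch (assumes torch is installed from a build that registers the new entrypoints; note that `_get_logs_specs_class` is an internal helper):

from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs
from torch.distributed.run import _get_logs_specs_class

assert _get_logs_specs_class(None) is DefaultLogsSpecs        # no name -> built-in default
assert _get_logs_specs_class("default") is DefaultLogsSpecs   # resolved via the setup.py entrypoint

try:
    _get_logs_specs_class("DOESNOT_EXIST")
except ValueError as err:
    print(err)  # Could not find entrypoint under 'torchrun.logs_specs[DOESNOT_EXIST]' key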