[Torchelastic][Logging] Pluggable LogsSpecs using Python entrypoints and an option to specify one by name. (#120942)
Summary:
Expose an option for users to specify, by name, the LogsSpecs implementation to use.
- It has to be defined in entrypoints under the `torchrun.logs_specs` group.
- It must implement the LogsSpecs interface defined in the prior PR/diff.

Test Plan: unit test + local tests

Reviewed By: ezyang

Differential Revision: D54180838

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120942
Approved by: https://github.com/ezyang
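For illustration, a third-party package could plug in its own implementation by registering it under the `torchrun.logs_specs` entry-point group. A minimal sketch, assuming a hypothetical package `my_pkg` whose `my_pkg.logs.CustomLogsSpecs` subclasses `LogsSpecs` (all names here are placeholders, not part of this PR):

    # setup.py of the hypothetical plugin package
    from setuptools import setup

    setup(
        name="my-torchrun-logs",
        packages=["my_pkg"],
        entry_points={
            "torchrun.logs_specs": [
                # torchrun resolves this name when --logs-specs=custom is passed
                "custom = my_pkg.logs:CustomLogsSpecs",
            ],
        },
    )

After installation, `torchrun --logs-specs=custom ...` would load `CustomLogsSpecs` in place of the default implementation.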
commit b0cfa96e82 (parent f351a71dbb), committed by PyTorch MergeBot
setup.py
@@ -1055,7 +1055,10 @@ def configure_extension_build():
             "convert-caffe2-to-onnx = caffe2.python.onnx.bin.conversion:caffe2_to_onnx",
             "convert-onnx-to-caffe2 = caffe2.python.onnx.bin.conversion:onnx_to_caffe2",
             "torchrun = torch.distributed.run:main",
-        ]
+        ],
+        "torchrun.logs_specs": [
+            "default = torch.distributed.elastic.multiprocessing:DefaultLogsSpecs",
+        ],
     }

     return extensions, cmdclass, packages, entry_points, extra_install_requires
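The new `torchrun.logs_specs` group makes the default spec discoverable like any other entry point. As a quick sanity check (a sketch, assuming Python >= 3.10 and an installed torch build containing this change), the registration can be listed with importlib.metadata:

    import importlib.metadata as metadata

    # EntryPoints.select() exists on Python >= 3.10; torchrun's loader
    # (_get_logs_specs_class, below) also handles the older dict-style API.
    eps = metadata.entry_points().select(group="torchrun.logs_specs")
    print(sorted(ep.name for ep in eps))  # expect ["default"] from this setup.py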
test/distributed/launcher/run_test.py
@@ -17,17 +17,18 @@ import unittest
 import uuid
 from contextlib import closing
 from unittest import mock
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch

 import torch.distributed.run as launch
 from torch.distributed.elastic.agent.server.api import RunResult, WorkerState
+from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs
 from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
 from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
 from torch.distributed.elastic.utils import get_socket_with_port
 from torch.distributed.elastic.utils.distributed import get_free_port
 from torch.testing._internal.common_utils import (
-    TEST_WITH_DEV_DBG_ASAN,
     skip_but_pass_in_sandcastle_if,
+    TEST_WITH_DEV_DBG_ASAN,
 )
@@ -504,6 +505,55 @@ class ElasticLaunchTest(unittest.TestCase):
             is_torchelastic_launched = fp.readline()
         self.assertEqual("True", is_torchelastic_launched)

+    @patch("torch.distributed.run.metadata")
+    @skip_but_pass_in_sandcastle_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
+    def test_is_torchelastic_launched_with_logs_spec_defined(self, metadata_mock):
+        # mock the entrypoint API to avoid version issues.
+        entrypoints = MagicMock()
+        metadata_mock.entry_points.return_value = entrypoints
+
+        group = MagicMock()
+        entrypoints.select.return_value = group
+
+        ep = MagicMock()
+        ep.load.return_value = DefaultLogsSpecs
+
+        group.select.return_value = ep
+        group.__getitem__.return_value = ep
+
+        out_file = f"{os.path.join(self.test_dir, 'out')}"
+        if os.path.exists(out_file):
+            os.remove(out_file)
+        launch.main(
+            [
+                "--run-path",
+                "--nnodes=1",
+                "--nproc-per-node=1",
+                "--monitor-interval=1",
+                "--logs_specs=default",
+                path("bin/test_script_is_torchelastic_launched.py"),
+                f"--out-file={out_file}",
+            ]
+        )
+
+        with open(out_file) as fp:
+            is_torchelastic_launched = fp.readline()
+        self.assertEqual("True", is_torchelastic_launched)
+
+    @skip_but_pass_in_sandcastle_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
+    def test_logs_logs_spec_entrypoint_must_be_defined(self):
+        with self.assertRaises(ValueError):
+            launch.main(
+                [
+                    "--run-path",
+                    "--nnodes=1",
+                    "--nproc-per-node=1",
+                    "--monitor-interval=1",
+                    "--logs_specs=DOESNOT_EXIST",
+                    path("bin/test_script_is_torchelastic_launched.py"),
+                ]
+            )
+
     def test_is_not_torchelastic_launched(self):
         # launch test script without torchelastic and validate that
         # torch.distributed.is_torchelastic_launched() returns False
torch/distributed/elastic/multiprocessing/__init__.py
@@ -68,6 +68,7 @@ from typing import Callable, Dict, Optional, Tuple, Union, Set
 from torch.distributed.elastic.multiprocessing.api import (  # noqa: F401
     _validate_full_rank,
+    DefaultLogsSpecs,
     LogsDest,
     LogsSpecs,
     MultiprocessContext,
     PContext,
@@ -88,6 +89,7 @@ __all__ = [
     "RunProcsResult",
     "SignalException",
     "Std",
     "LogsDest",
     "LogsSpecs",
+    "DefaultLogsSpecs",
     "SubprocessContext",
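With the re-export above, the default implementation becomes importable from the package root rather than the private api module. A two-line check:

    # Both names now resolve through the public package:
    from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs
    assert issubclass(DefaultLogsSpecs, LogsSpecs)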
torch/distributed/elastic/multiprocessing/api.py
@@ -184,10 +184,18 @@ class LogsDest:
 class LogsSpecs(ABC):
     """
     Defines logs processing and redirection for each worker process.

     Args:
-        log_dir: base directory where logs will be written
-        redirects: specifies which streams to redirect to files.
-        tee: specifies which streams to duplicate to stdout/stderr
+        log_dir:
+            Base directory where logs will be written.
+        redirects:
+            Streams to redirect to files. Pass a single ``Std``
+            enum to redirect for all workers, or a mapping keyed
+            by local_rank to selectively redirect.
+        tee:
+            Streams to duplicate to stdout/stderr.
+            Pass a single ``Std`` enum to duplicate streams for all workers,
+            or a mapping keyed by local_rank to selectively duplicate.
     """

     def __init__(
@@ -220,7 +228,8 @@ class LogsSpecs(ABC):
 class DefaultLogsSpecs(LogsSpecs):
     """
     Default LogsSpecs implementation:
-    - `log_dir` will be created if it doesn't exist and it is not set to os.devnull
+
+    - `log_dir` will be created if it doesn't exist and it is not set to `os.devnull`
+    - Generates nested folders for each attempt and rank.
     """
     def __init__(
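The redirects/tee contract documented above can be exercised directly. A minimal sketch (the log directory is a placeholder):

    from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, Std

    # Redirect every worker's stdout/stderr to files under log_dir:
    all_to_files = DefaultLogsSpecs(log_dir="/tmp/elastic_logs", redirects=Std.ALL)

    # Redirect all workers to files, but also tee local_rank 0's stdout
    # back to the console, using a mapping keyed by local_rank:
    rank0_teed = DefaultLogsSpecs(
        log_dir="/tmp/elastic_logs",
        redirects=Std.ALL,
        tee={0: Std.OUT},
    )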
torch/distributed/launcher/api.py
@@ -54,12 +54,6 @@ class LaunchConfig:
             as a period of monitoring workers.
         start_method: The method is used by the elastic agent to start the
             workers (spawn, fork, forkserver).
-        log_dir: base log directory where log files are written. If not set,
-            one is created in a tmp dir but NOT removed on exit.
-        redirects: configuration to redirect stdout/stderr to log files.
-            Pass a single ``Std`` enum to redirect all workers,
-            or a mapping keyed by local_rank to selectively redirect.
-        tee: configuration to "tee" stdout/stderr to console + log file.
         metrics_cfg: configuration to initialize metrics.
         local_addr: address of the local node if any. If not set, a lookup on the local
             machine's FQDN will be performed.
@@ -248,9 +242,9 @@ def launch_agent(

     agent = LocalElasticAgent(
         spec=spec,
+        logs_specs=config.logs_specs,  # type: ignore[arg-type]
         start_method=config.start_method,
         log_line_prefix_template=config.log_line_prefix_template,
-        logs_specs=config.logs_specs,  # type: ignore[arg-type]
     )

     shutdown_rdzv = True
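Since the LogsSpecs instance now travels on the config object, programmatic launches pass it the same way. A sketch, assuming a `train_fn` defined at module scope:

    from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs
    from torch.distributed.launcher.api import LaunchConfig, elastic_launch

    config = LaunchConfig(
        min_nodes=1,
        max_nodes=1,
        nproc_per_node=2,
        logs_specs=DefaultLogsSpecs(log_dir="/tmp/elastic_logs"),
    )
    # elastic_launch wires config.logs_specs into LocalElasticAgent as shown above.
    results = elastic_launch(config, train_fn)()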
torch/distributed/run.py
@@ -375,12 +375,13 @@ import logging
 import os
 import sys
 import uuid
+import importlib.metadata as metadata
 from argparse import REMAINDER, ArgumentParser
-from typing import Callable, List, Tuple, Union, Optional, Set
+from typing import Callable, List, Tuple, Type, Union, Optional, Set

 import torch
 from torch.distributed.argparse_util import check_env, env
-from torch.distributed.elastic.multiprocessing import Std, DefaultLogsSpecs
+from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, Std
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config
 from torch.distributed.elastic.utils import macros
@@ -602,6 +603,15 @@ def get_args_parser() -> ArgumentParser:
         "machine's FQDN.",
     )

+    parser.add_argument(
+        "--logs-specs",
+        "--logs_specs",
+        default=None,
+        type=str,
+        help="torchrun.logs_specs group entrypoint name; the value must resolve to a "
+        "LogsSpecs subclass. Can be used to provide custom logging behavior.",
+    )
+
     #
     # Positional arguments.
     #
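Both flag spellings parse to `args.logs_specs`. For illustration, the flag can be driven through the same programmatic entry the tests above use (the script path is a placeholder):

    import torch.distributed.run as run

    # Equivalent to: torchrun --logs-specs=default --nnodes=1 --nproc-per-node=1 train.py
    run.main(
        [
            "--logs-specs=default",
            "--nnodes=1",
            "--nproc-per-node=1",
            "train.py",
        ]
    )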
@@ -699,6 +709,36 @@ def get_use_env(args) -> bool:
     return args.use_env


+def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]:
+    """
+    Attempts to load the `torchrun.logs_specs` entrypoint registered under the
+    `logs_specs_name` key. Provides a plugin mechanism for custom LogsSpecs implementations.
+
+    Returns `DefaultLogsSpecs` when `logs_specs_name` is None.
+    Raises ValueError when no entrypoint for `logs_specs_name` can be found.
+    """
+    logs_specs_cls = None
+    if logs_specs_name is not None:
+        eps = metadata.entry_points()
+        if hasattr(eps, "select"):  # >= 3.10
+            group = eps.select(group="torchrun.logs_specs")
+            if group.select(name=logs_specs_name):
+                logs_specs_cls = group[logs_specs_name].load()
+
+        elif specs := eps.get("torchrun.logs_specs"):  # < 3.10
+            if entrypoint_list := [ep for ep in specs if ep.name == logs_specs_name]:
+                logs_specs_cls = entrypoint_list[0].load()
+
+        if logs_specs_cls is None:
+            raise ValueError(f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key")
+
+        logging.info("Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls))
+    else:
+        logs_specs_cls = DefaultLogsSpecs
+
+    return logs_specs_cls
+
+
 def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]:
     # If ``args`` not passed, defaults to ``sys.argv[:1]``
     min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
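The resolution behavior, summarized (illustrative only; the "custom" name assumes a plugin registered as in the setup.py sketch near the top):

    _get_logs_specs_class(None)       # returns DefaultLogsSpecs without touching entry points
    _get_logs_specs_class("default")  # loads torch.distributed.elastic.multiprocessing:DefaultLogsSpecs
    _get_logs_specs_class("custom")   # loads the plugin's class, if registered
    _get_logs_specs_class("missing")  # raises ValueError: no such entrypoint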
@@ -745,7 +785,8 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]:
             "--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2"
         ) from e

-    logs_specs = DefaultLogsSpecs(
+    logs_specs_cls: Type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
+    logs_specs = logs_specs_cls(
         log_dir=args.log_dir,
         redirects=Std.from_str(args.redirects),
         tee=Std.from_str(args.tee),