mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Preferring dash over underscore in command-line options. Add `--command-arg-name` to the argument parser. The old arguments with underscores `--command_arg_name` are kept for backward compatibility.
Both dashes and underscores are used in the PyTorch codebase. Some argument parsers only have dashes or only have underscores in arguments. For example, the `torchrun` utility for distributed training only accepts underscore arguments (e.g., `--master_port`). The dashes are more common in other command-line tools. And it looks to be the default choice in the Python standard library:
`argparse.BooleanOptionalAction`: 4a9dff0e5a/Lib/argparse.py (L893-L895)
```python
class BooleanOptionalAction(Action):
def __init__(...):
if option_string.startswith('--'):
option_string = '--no-' + option_string[2:]
_option_strings.append(option_string)
```
It adds `--no-argname`, not `--no_argname`. Also, typing `_` requires pressing the Shift (or Caps Lock) key, whereas typing `-` does not.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94505
Approved by: https://github.com/ezyang, https://github.com/seemethere
271 lines
10 KiB
Python
271 lines
10 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
import sys
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
|
|
import torch.distributed.elastic.rendezvous.registry as rdzv_registry
|
|
from torch.distributed.elastic import events, metrics
|
|
from torch.distributed.elastic.agent.server.api import WorkerSpec
|
|
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
|
|
from torch.distributed.elastic.multiprocessing import SignalException, Std
|
|
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
|
|
from torch.distributed.elastic.rendezvous import RendezvousParameters
|
|
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
|
|
from torch.distributed.elastic.utils.logging import get_logger
|
|
|
|
__all__ = ['LaunchConfig', 'elastic_launch', 'launch_agent']
|
|
|
|
logger = get_logger()
|
|
|
|
|
|
@dataclass
class LaunchConfig:
    """
    Settings consumed by ``launch_agent`` to start an elastic agent and its workers.

    Args:
        min_nodes: Minimum number of nodes the user function is launched on.
                   The elastic agent starts the user function only once this
                   many nodes have entered the rendezvous.
        max_nodes: Maximum number of nodes the user function is launched on.
        nproc_per_node: Number of workers the elastic agent launches on each
                        node to execute the user-defined function.
        run_id: Unique id of the job. When left empty, ``launch_agent``
                generates a random one.
        role: User-defined role of the workers (defaults to ``"default_role"``).
        rdzv_endpoint: Endpoint of the rendezvous synchronization storage.
        rdzv_backend: Name of the rendezvous backend to use
                      (defaults to ``"etcd"``).
        rdzv_configs: Key/value pairs holding rendezvous-specific settings.
                      A ``"timeout"`` entry is always present after
                      construction (see the note below).
        rdzv_timeout: Legacy way of setting the rendezvous timeout; ``-1``
                      means "not given". See the note below.
        max_restarts: Maximum number of worker restarts the elastic agent
                      performs before giving up.
        monitor_interval: Interval, in seconds, at which the elastic agent
                          monitors the workers.
        start_method: Method the elastic agent uses to start the workers
                      (spawn, fork, forkserver).
        log_dir: Base directory for log files. If not set, one is created in
                 a tmp dir but NOT removed on exit.
        redirects: Redirection of stdout/stderr to log files. Pass a single
                   ``Std`` enum to redirect all workers, or a mapping keyed by
                   local_rank to redirect selectively.
        tee: Configuration to "tee" stdout/stderr to console + log file.
        metrics_cfg: Configuration used to initialize metrics.
        local_addr: Address of the local node, if any. If not set, a lookup on
                    the local machine's FQDN will be performed.

    .. note::
        ``rdzv_timeout`` is a legacy argument that will be removed in the
        future; prefer ``rdzv_configs['timeout']``. When ``rdzv_timeout`` is
        given it overrides ``rdzv_configs['timeout']``; when neither is given
        the timeout defaults to 900 seconds.
    """

    min_nodes: int
    max_nodes: int
    nproc_per_node: int
    run_id: str = ""
    role: str = "default_role"
    rdzv_endpoint: str = ""
    rdzv_backend: str = "etcd"
    rdzv_configs: Dict[str, Any] = field(default_factory=dict)
    rdzv_timeout: int = -1
    max_restarts: int = 3
    monitor_interval: float = 30
    start_method: str = "spawn"
    log_dir: Optional[str] = None
    redirects: Union[Std, Dict[int, Std]] = Std.NONE
    tee: Union[Std, Dict[int, Std]] = Std.NONE
    metrics_cfg: Dict[str, str] = field(default_factory=dict)
    local_addr: Optional[str] = None

    def __post_init__(self):
        """Make sure ``rdzv_configs`` always carries a ``"timeout"`` entry."""
        if self.rdzv_timeout != -1:
            # The legacy argument wins over whatever rdzv_configs holds.
            self.rdzv_configs["timeout"] = self.rdzv_timeout
            return
        # Keep a caller-supplied timeout; otherwise use 900 seconds.
        self.rdzv_configs.setdefault("timeout", 900)
|
|
|
|
|
|
class elastic_launch:
    """
    Callable that launches a torchelastic agent on the container that invoked
    the entrypoint.

    Constructing the object only stores the launch config and the entrypoint;
    calling it forwards both, together with the positional call arguments, to
    ``launch_agent`` and returns its result.

    1. Pass the ``entrypoint`` arguments positionally (no named parameters).
       ``entrypoint`` can be a function or a command.
    2. The return value is a map of each worker's output, keyed by the
       worker's global rank.

    Usage

    ::

        def worker_fn(foo):
            # ...

        def main():
            # entrypoint is a function.
            outputs = elastic_launch(LaunchConfig, worker_fn)(foo)
            # return rank 0's output
            return outputs[0]

        # entrypoint is a command and ``script.py`` is the python module.
        outputs = elastic_launch(LaunchConfig, "script.py")(args)
        outputs = elastic_launch(LaunchConfig, "python")("script.py")
    """

    def __init__(
        self,
        config: LaunchConfig,
        entrypoint: Union[Callable, str, None],
    ):
        # Nothing is launched here; both values are only remembered.
        self._entrypoint = entrypoint
        self._config = config

    def __call__(self, *args):
        # launch_agent expects the worker arguments as a list.
        worker_args = list(args)
        return launch_agent(self._config, self._entrypoint, worker_args)
|
|
|
|
|
|
def _get_entrypoint_name(
    entrypoint: Union[Callable, str, None], args: List[Any]
) -> str:
    """Retrieve a display name for the entrypoint.

    Rules:
        1. If ``entrypoint`` is callable, use ``entrypoint.__name__``. Callables
           that carry no ``__name__`` (for example ``functools.partial`` objects
           or instances defining ``__call__``) fall back to the name of their type.
        2. If ``entrypoint`` is a string, check its value:
            2.1 if it equals ``sys.executable`` (like "python"), use the first
                non-empty element of ``args`` that does not start with a hyphen
                (for example, "-u" is skipped); empty string if there is none.
            2.2 otherwise, use the ``entrypoint`` value itself.
        3. Otherwise, return an empty string.

    Args:
        entrypoint: The function or command being launched, or ``None``.
        args: Arguments passed to the entrypoint. Only inspected in case 2.1.

    Returns:
        The name derived by the rules above.
    """
    if callable(entrypoint):
        name = getattr(entrypoint, "__name__", None)
        if name is None:
            # Not every callable has __name__; do not fail the launch over a label.
            return type(entrypoint).__name__
        return name

    if not isinstance(entrypoint, str):
        return ""

    if entrypoint != sys.executable:
        return entrypoint

    for arg in args:
        # str() keeps non-string arguments from raising; the emptiness check
        # replaces the original ``arg[0]`` indexing, which failed on "".
        candidate = str(arg)
        if candidate and not candidate.startswith("-"):
            return candidate
    return ""
|
|
|
|
|
|
def _get_addr_and_port(
    rdzv_parameters: RendezvousParameters,
) -> Tuple[Optional[str], Optional[int]]:
    """Extract the master address and port from the rendezvous parameters.

    Only the ``"static"`` backend yields an address/port pair; for any other
    backend ``(None, None)`` is returned.

    Args:
        rdzv_parameters: Rendezvous parameters; ``backend`` and ``endpoint``
            are the only attributes read here.

    Returns:
        ``(master_addr, master_port)`` for the static backend, otherwise
        ``(None, None)``.

    Raises:
        ValueError: If the backend is static and the endpoint is blank, or the
            parsed port equals the ``-1`` placeholder passed as default.
    """
    uses_static_backend = rdzv_parameters.backend == "static"
    if not uses_static_backend:
        return (None, None)

    stripped_endpoint = rdzv_parameters.endpoint.strip()
    if not stripped_endpoint:
        raise ValueError(
            "Endpoint is missing in endpoint. Try to add --master-addr and --master-port"
        )

    # -1 is passed as the default port so that a missing port can be detected.
    host, port = parse_rendezvous_endpoint(stripped_endpoint, default_port=-1)
    if port == -1:
        raise ValueError(
            f"port is missing in endpoint: {stripped_endpoint}. Try to specify --master-port"
        )

    return (host, port)
|
|
|
|
|
|
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    """Run a local elastic agent for ``entrypoint`` and return the workers' results.

    Steps, in order: fill in a random ``run_id`` on ``config`` when it has
    none, log the launch settings, build the rendezvous parameters and the
    worker spec, run a ``LocalElasticAgent``, and finally shut the rendezvous
    handler down unless the agent was interrupted by a signal.

    Args:
        config: Launch settings. ``config.run_id`` is assigned in place when
            it is empty.
        entrypoint: Function or command executed by the workers.
        args: Arguments handed to the entrypoint.

    Returns:
        ``return_values`` of the agent's run result.

    Raises:
        ChildFailedError: If the run result reports failure.
        SignalException: Re-raised when it interrupts the agent; the
            rendezvous handler is left open in that case.
        Exception: Any other error is re-raised after a failure event has
            been recorded.
    """
    if not config.run_id:
        generated_id = str(uuid.uuid4().int)
        logger.warning(f"config has no run_id, generated a random run_id: {generated_id}")
        config.run_id = generated_id

    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    launch_summary = (
        f"Starting elastic_operator with launch configs:\n"
        f" entrypoint : {entrypoint_name}\n"
        f" min_nodes : {config.min_nodes}\n"
        f" max_nodes : {config.max_nodes}\n"
        f" nproc_per_node : {config.nproc_per_node}\n"
        f" run_id : {config.run_id}\n"
        f" rdzv_backend : {config.rdzv_backend}\n"
        f" rdzv_endpoint : {config.rdzv_endpoint}\n"
        f" rdzv_configs : {config.rdzv_configs}\n"
        f" max_restarts : {config.max_restarts}\n"
        f" monitor_interval : {config.monitor_interval}\n"
        f" log_dir : {config.log_dir}\n"
        f" metrics_cfg : {config.metrics_cfg}\n"
    )
    logger.info(launch_summary)

    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        local_addr=config.local_addr,
        **config.rdzv_configs,
    )

    # The endpoint is validated before the rendezvous handler is created.
    master_addr, master_port = _get_addr_and_port(rdzv_parameters)
    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)

    worker_spec = WorkerSpec(
        role=config.role,
        local_world_size=config.nproc_per_node,
        entrypoint=entrypoint,
        args=tuple(args),
        rdzv_handler=rdzv_handler,
        max_restarts=config.max_restarts,
        monitor_interval=config.monitor_interval,
        redirects=config.redirects,
        tee=config.tee,
        master_addr=master_addr,
        master_port=master_port,
        local_addr=config.local_addr,
    )

    elastic_agent = LocalElasticAgent(
        spec=worker_spec,
        start_method=config.start_method,
        log_dir=config.log_dir,
    )

    close_rendezvous = True
    try:
        metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))

        run_result = elastic_agent.run()
        # records that agent.run() has succeeded NOT that workers have succeeded
        events.record(elastic_agent.get_event_succeeded())

        if not run_result.is_failed():
            return run_result.return_values

        # ChildFailedError is treated specially by @record
        # if the error files for the failed children exist
        # @record will copy the first error (root cause)
        # to the error file of the launcher process.
        raise ChildFailedError(
            name=entrypoint_name,
            failures=run_result.failures,
        )
    except ChildFailedError:
        # Worker failures propagate without recording an agent failure event.
        raise
    except SignalException:
        # when the agent dies with a signal do NOT shutdown the rdzv_handler
        # since this closes the rendezvous on this rdzv_id permanently and
        # prevents any additional scaling events
        close_rendezvous = False
        events.record(elastic_agent.get_event_failed())
        raise
    except Exception:
        events.record(elastic_agent.get_event_failed())
        raise
    finally:
        if close_rendezvous:
            worker_spec.rdzv_handler.shutdown()
|