pytorch/torch/distributed/rendezvous.py
Kiuk Chung 9d95d48567 (torch.distributed) Add torch.distributed.is_torchelastic_launched() util method + make init_method=tcp:// compatible with torchelastic (#63910)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63910

Addresses the issue that `init_method=tcp://` is not compatible with `torch.distributed.run` and `torch.distributed.launch`. When a training script that initializes the process group with `init_method=tcp://localhost:$port` is launched like so:

```
$ python -u -m torch.distributed.run --max_restarts 0 --nproc_per_node 1 --nnodes 1 --master_addr $(hostname) --master_port 6000 ~/tmp/test.py
```

An `Address already in use` error is raised: the training script tries to start a TCPStore server on port 6000, but the elastic agent is already running its own TCPStore server on that port.
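The contents of `~/tmp/test.py` are not included here; a minimal sketch of a script that reproduces the conflict (assuming its endpoint matches the launcher's `--master_addr`/`--master_port`) could look like:

```
# Hypothetical reproduction script; the tcp:// endpoint is assumed to match
# the launcher's --master_addr/--master_port (port 6000 in the command above).
import os
import torch.distributed as dist

dist.init_process_group(
    backend="gloo",
    init_method="tcp://localhost:6000",  # rank 0 tries to start a TCPStore server here
    rank=int(os.environ["RANK"]),  # set by torch.distributed.run
    world_size=int(os.environ["WORLD_SIZE"]),
)
```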

For details see: https://github.com/pytorch/pytorch/issues/63874.

This change does the following:

1. Adds an `is_torchelastic_launched()` check function that users can call in their training scripts to determine whether the script was launched via torchelastic (see the sketch after this list).
1. Updates the `torch.distributed` docs page to include the new `is_torchelastic_launched()` function.
1. Makes `init_method=tcp://` torchelastic-compatible by modifying `_tcp_rendezvous_handler` in `torch.distributed.rendezvous` (this is NOT the elastic rendezvous; it is the old rendezvous module, which is slated for deprecation in future releases) to check `is_torchelastic_launched()` AND `torchelastic_use_agent_store()`, and if both are true, to only create TCPStore clients (no daemons, not even for rank 0).
1. Adds unittests to cover the different code paths.
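For illustration, a sketch of how a training script might use the new check (the tcp:// fallback endpoint below is a made-up example):

```
import torch.distributed as dist

if dist.is_torchelastic_launched():
    # Launched via torch.distributed.run/torchelastic; the launcher provides
    # MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE in the environment.
    dist.init_process_group(backend="gloo", init_method="env://")
else:
    # Standalone launch: pass the endpoint and topology explicitly.
    # tcp://localhost:29500 is a hypothetical endpoint for this sketch.
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://localhost:29500",
        rank=0,
        world_size=1,
    )
```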

NOTE: the issue mentions that we should fail fast with an assertion on `init_method!=env://` when `is_torchelastic_launched()` is `True`. There are three registered init_methods in pytorch: env://, tcp://, and file://. Since this diff makes tcp:// compatible with torchelastic, and I've validated that file:// is compatible as well, there is no need to add assertions. I did update the docs to point out that env:// is the RECOMMENDED init_method. We should probably deprecate the other init_methods in the future, but that is out of scope for this issue.

Test Plan: Unittests.

Reviewed By: cbalioglu

Differential Revision: D30529984

fbshipit-source-id: 267aea6d4dad73eb14a2680ac921f210ff547cc5
2021-08-25 22:57:43 -07:00


try:
    from urllib.parse import urlparse, urlunparse
except ImportError:
    raise ImportError(
        "urllib cannot be found, urlparse from python2 is no longer supported."
    )

import numbers
import os
import sys
from datetime import timedelta
from typing import Dict, Optional, Union

import torch._six as six
from torch.distributed import FileStore, PrefixStore, Store, TCPStore

from .constants import default_pg_timeout

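# Maps a URL scheme (e.g. "tcp") to its rendezvous handler generator;
# populated via register_rendezvous_handler() below.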
_rendezvous_handlers = {}


def register_rendezvous_handler(scheme, handler):
    """Registers a new rendezvous handler.

    Before we can run collective algorithms, participating processes
    need to find each other and exchange information to be able to
    communicate. We call this process rendezvous.

    The outcome of the rendezvous process is a triplet containing a
    shared key/value store, the rank of the process, and the total
    number of participating processes.

    If none of the bundled rendezvous methods apply to your execution
    environment you can opt to register your own rendezvous handler.
    Pick a unique name and use the URL scheme to identify it when
    calling the `rendezvous()` function.

    Args:
        scheme (str): URL scheme to identify your rendezvous handler.
        handler (function): Handler that is invoked when the
            `rendezvous()` function is called with a URL that uses
            the corresponding scheme. It must be a generator function
            that yields the triplet.
    """
    global _rendezvous_handlers
    if scheme in _rendezvous_handlers:
        raise RuntimeError(
            "Rendezvous handler for {}:// already registered".format(scheme)
        )
    _rendezvous_handlers[scheme] = handler

def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
    if not isinstance(url, six.string_classes):
        raise RuntimeError("`url` must be a string. {}: {}".format(type(url), url))
    if not isinstance(rank, numbers.Integral):
        raise RuntimeError("`rank` must be an integer. {}".format(rank))
    if not isinstance(world_size, numbers.Integral):
        raise RuntimeError("`world_size` must be an integer. {}".format(world_size))

    # Append node-specific arguments.
    result = urlparse(url)
    if rank != -1 or world_size != -1:
        query_dict: Dict[str, Union[int, str]] = dict(
            pair.split("=") for pair in filter(None, result.query.split("&"))
        )
        assert (
            "rank" not in query_dict and "world_size" not in query_dict
        ), "The url: {url} has node-specific arguments (rank, world_size) already.".format(
            url=url
        )
        if rank != -1:
            query_dict["rank"] = rank
        if world_size != -1:
            query_dict["world_size"] = world_size
        result = result._replace(
            query="{}".format(
                "&".join(["{}={}".format(k, v) for k, v in query_dict.items()])
            )
        )
        url = urlunparse(result)

    if result.scheme not in _rendezvous_handlers:
        raise RuntimeError("No rendezvous handler for {}://".format(result.scheme))
    return _rendezvous_handlers[result.scheme](url, **kwargs)

def _rendezvous_error(msg):
    return ValueError("Error initializing torch.distributed using " + msg)
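
# The file:// handler rendezvouses through a FileStore backed by a file visible
# to all participants; rank and world_size must arrive as URL query parameters
# (rendezvous() appends them when they are given as arguments).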
def _file_rendezvous_handler(url: str, **kwargs):
    def _error(msg):
        return _rendezvous_error("file:// rendezvous: " + msg)

    result = urlparse(url)
    path = result.path
    if sys.platform == "win32":
        import urllib.request

        full_path = result.netloc + result.path
        path = urllib.request.url2pathname(full_path)
        if path:
            # Normalizing an empty string produces ".", which is not expected.
            path = os.path.normpath(path)

    if not path:
        raise _error("path missing")
    query: Dict[str, str]
    # mypy doesn't allow dict() to accept List of values (#257)
    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))  # type: ignore[misc, arg-type]
    if "rank" not in query:
        raise _error("rank parameter missing")
    if "world_size" not in query:
        raise _error("world size parameter missing")

    rank = int(query["rank"])
    world_size = int(query["world_size"])
    store = FileStore(path, world_size)

    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform re-rendezvous using file:// method")
def _torchelastic_use_agent_store() -> bool:
    return os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None) == str(True)

def _create_c10d_store(hostname, port, rank, world_size, timeout) -> Store:
    """
    Smartly creates a c10d Store object on ``rank`` based on whether
    we need to re-use agent store. The TCPStore server is assumed to be hosted
    on ``hostname:port``.

    If ``torchelastic_use_agent_store()`` is ``True``, then it is assumed that
    the agent leader (node rank 0) hosts the TCPStore server (for which the
    endpoint is specified by the given ``hostname:port``). Hence
    ALL ranks will create and return a TCPStore client (e.g. ``start_daemon=False``).

    If ``torchelastic_use_agent_store()`` is ``False``, then rank 0 will host
    the TCPStore (with multi-tenancy) and it is assumed that rank 0's hostname
    and port are correctly passed via ``hostname`` and ``port``. All
    non-zero ranks will create and return a TCPStore client.
    """
    if _torchelastic_use_agent_store():
        attempt = os.environ["TORCHELASTIC_RESTART_COUNT"]
        tcp_store = TCPStore(hostname, port, world_size, False, timeout)
        return PrefixStore(f"/worker/attempt_{attempt}", tcp_store)
    else:
        start_daemon = rank == 0
        return TCPStore(
            hostname, port, world_size, start_daemon, timeout, multi_tenant=True
        )
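
# Parses tcp://<host>:<port>?rank=<rank>&world_size=<ws> and delegates store
# creation to _create_c10d_store, which makes this handler compatible with
# torchelastic launches (see the summary above).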
def _tcp_rendezvous_handler(
    url: str, timeout: timedelta = default_pg_timeout, **kwargs
):
    def _error(msg):
        return _rendezvous_error("tcp:// rendezvous: " + msg)

    result = urlparse(url)
    if not result.port:
        raise _error("port number missing")
    query: Dict[str, Union[int, str]]
    # mypy doesn't allow dict() to accept List of values (#257)
    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))  # type: ignore[misc, arg-type]
    if "rank" not in query:
        raise _error("rank parameter missing")
    if "world_size" not in query:
        raise _error("world size parameter missing")

    rank = int(query["rank"])
    world_size = int(query["world_size"])
    assert result.hostname is not None

    store = _create_c10d_store(result.hostname, result.port, rank, world_size, timeout)

    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform re-rendezvous using tcp:// method")
def _env_rendezvous_handler(
    url: str, timeout: timedelta = default_pg_timeout, **kwargs
):
    def _error(msg):
        return _rendezvous_error("env:// rendezvous: " + msg)

    def _env_error(var):
        return _error("environment variable %s expected, but not set" % var)

    def _get_env_or_raise(env_var: str) -> str:
        env_val = os.environ.get(env_var, None)
        if not env_val:
            raise _env_error(env_var)
        else:
            return env_val

    result = urlparse(url)
    query: Dict[str, Union[int, str]]
    # mypy doesn't allow dict() to accept List of values (#257)
    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))  # type: ignore[misc, arg-type]

    rank: Optional[Union[str, int]]
    world_size: Optional[Union[str, int]]
    master_port: Optional[Union[str, int]]

    if "rank" in query:
        rank = int(query["rank"])
    else:
        rank = int(_get_env_or_raise("RANK"))

    if "world_size" in query:
        world_size = int(query["world_size"])
    else:
        world_size = int(_get_env_or_raise("WORLD_SIZE"))

    master_addr = _get_env_or_raise("MASTER_ADDR")
    master_port = int(_get_env_or_raise("MASTER_PORT"))

    store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout)

    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform re-rendezvous using env:// method")
register_rendezvous_handler("tcp", _tcp_rendezvous_handler)
register_rendezvous_handler("env", _env_rendezvous_handler)
register_rendezvous_handler("file", _file_rendezvous_handler)