[17/n][torch/elastic] Make torchelastic launcher compatible with caffe2.distributed.launch (#55687)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55687

The diff makes sure that users can pass the following parameters:
* master_addr
* master_port
* node_rank
* use_env
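
For illustration, a minimal sketch of how a launcher can map these arguments onto the environment variables that the env:// rendezvous handler reads (names and defaults below are illustrative, not part of this diff):

    import os
    from argparse import ArgumentParser

    # Hypothetical launcher-side mapping of CLI args to worker env vars.
    parser = ArgumentParser()
    parser.add_argument("--master_addr", default="127.0.0.1")
    parser.add_argument("--master_port", type=int, default=29500)
    parser.add_argument("--node_rank", type=int, default=0)
    parser.add_argument("--use_env", action="store_true")
    args = parser.parse_args()

    worker_env = dict(os.environ)
    worker_env["MASTER_ADDR"] = args.master_addr
    worker_env["MASTER_PORT"] = str(args.master_port)
    worker_env["GROUP_RANK"] = str(args.node_rank)
    # With --use_env, workers read LOCAL_RANK from the environment instead
    # of receiving an explicit --local_rank argument.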

The diff implements a StaticTCPRendezvous that creates a store with a listener on agent rank #0.
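
A minimal sketch of the idea, assuming only the public TCPStore API (the actual StaticTCPRendezvous carries more bookkeeping): the agent on node rank 0 hosts the store listener, all other agents connect as clients.

    from datetime import timedelta
    from torch.distributed import TCPStore

    def create_agent_store(master_addr: str, master_port: int,
                           node_rank: int, world_size: int) -> TCPStore:
        # Only agent rank 0 starts the listener; everyone else connects
        # to it as a client.
        is_host = node_rank == 0
        return TCPStore(master_addr, master_port, world_size, is_host,
                        timedelta(seconds=300))

    # Single-node example: this agent is rank 0, so it hosts the store.
    store = create_agent_store("127.0.0.1", 29500, node_rank=0, world_size=1)
    store.set("agent/status", "up")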

The diff modifies caffe2/rendezvous: if the worker process is launched by the torchelastic agent, the worker processes will create a PrefixStore("/worker") from a TCPStore without a listener.
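
Conceptually, the worker-side store setup looks like the sketch below (made self-contained so it runs standalone; in production the agent, not the worker, hosts the server store):

    from datetime import timedelta
    from torch.distributed import PrefixStore, TCPStore

    host, port, world_size = "127.0.0.1", 29501, 1

    # Stand-in for the agent-hosted store (is_master=True starts the listener).
    agent_store = TCPStore(host, port, world_size, True, timedelta(seconds=30))

    # What a worker does under this diff: connect without a listener
    # (is_master=False) and namespace its keys under "/worker" so they
    # cannot clash with the agent's keys.
    worker_store = PrefixStore(
        "/worker", TCPStore(host, port, world_size, False, timedelta(seconds=30))
    )
    worker_store.set("status", "ready")
    assert worker_store.get("status") == b"ready"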

The diff adds macro-substitution functionality to torch/distributed/elastic/utils that helps resolve the local_rank parameter.
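
A sketch of the macro substitution (assumed shape; the actual helper added under torch/distributed/elastic/utils may differ in name and placeholder syntax):

    from typing import List

    class macros:
        local_rank = "${local_rank}"

        @staticmethod
        def substitute(args: List[str], local_rank: str) -> List[str]:
            # Replace the ${local_rank} placeholder in each command-line
            # argument with the worker's actual local rank.
            return [arg.replace(macros.local_rank, local_rank) for arg in args]

    print(macros.substitute(["--local_rank=${local_rank}", "--foo"], "3"))
    # ['--local_rank=3', '--foo']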

Test Plan: buck test mode/dev-nosan //pytorch/elastic/torchelastic/distributed/test:launch_test

Reviewed By: cbalioglu, wilson100hong

Differential Revision: D27643206

fbshipit-source-id: 540fb26feac322cc3ec0a989fe53324755ccc4ea
Author: Aliaksandr Ivanou
Date: 2021-04-14 19:31:42 -07:00
Committed by: Facebook GitHub Bot
Parent: c5f9e043e9
Commit: 8f663170bd
17 changed files with 487 additions and 55 deletions

File: torch/distributed/rendezvous.py

@@ -9,7 +9,7 @@ import os
 import sys
 from datetime import timedelta
 from typing import Optional, Dict, Union
-from torch._C._distributed_c10d import FileStore, TCPStore
+from torch.distributed import FileStore, TCPStore, PrefixStore
 from .constants import default_pg_timeout

 _rendezvous_handlers = {}
@@ -149,6 +149,13 @@ def _env_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, *
     def _env_error(var):
         return _error("environment variable %s expected, but not set" % var)

+    def _get_env_or_raise(env_var: str) -> str:
+        env_val = os.environ.get(env_var, None)
+        if not env_val:
+            raise _env_error(env_var)
+        else:
+            return env_val
+
     result = urlparse(url)
     query: Dict[str, Union[int, str]]
     # mypy doesn't allow dict() to accept List of values (#257)
@@ -161,34 +168,33 @@ def _env_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, *
     if "rank" in query:
         rank = int(query["rank"])
     else:
-        rank = os.environ.get("RANK", None)
-        if rank is None:
-            raise _env_error("RANK")
+        rank = int(_get_env_or_raise("RANK"))

     if "world_size" in query:
         world_size = int(query["world_size"])
     else:
-        world_size = os.environ.get("WORLD_SIZE", None)
-        if world_size is None:
-            raise _env_error("WORLD_SIZE")
+        world_size = int(_get_env_or_raise("WORLD_SIZE"))

-    master_addr = os.environ.get("MASTER_ADDR", None)
-    if master_addr is None:
-        raise _env_error("MASTER_ADDR")
-
-    master_port = os.environ.get("MASTER_PORT", None)
-    if master_port is None:
-        raise _env_error("MASTER_PORT")
-
-    # Converting before creating the store
-    rank = int(rank)
-    world_size = int(world_size)
-    master_port = int(master_port)
-
-    # Now start the TCP store daemon on the rank 0
-    start_daemon = rank == 0
-    store = TCPStore(master_addr, master_port, world_size, start_daemon, timeout)
-    yield (store, rank, world_size)
+    master_addr = _get_env_or_raise("MASTER_ADDR")
+    master_port = int(_get_env_or_raise("MASTER_PORT"))
+
+    use_torchelastic_store = os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None)
+
+    if use_torchelastic_store == str(True):
+        worker_process_prefix = "/worker"
+        # When TORCHELASTIC_USE_AGENT_STORE is set up, the worker process is assumed
+        # to be invoked by the torchelastic agent. Torchelastic agent creates a tcp daemon thread
+        # on the GROUP_RANK=0, as a result all user worker processes should create store with: daemon=False
+        tcp_store = TCPStore(master_addr, master_port, world_size, False, timeout)
+        # Each if-else condition returns due to: https://github.com/python/mypy/issues/1191
+        yield (PrefixStore(worker_process_prefix, tcp_store), rank, world_size)
+    else:
+        # Start the TCP store daemon on the rank 0
+        start_daemon = rank == 0
+        store = TCPStore(master_addr, master_port, world_size, start_daemon, timeout)
+        # Each if-else condition returns due to: https://github.com/python/mypy/issues/1191
+        yield (store, rank, world_size)

     # If this configuration is invalidated, there is nothing we can do about it
     raise RuntimeError("Unable to perform rerendezvous using env:// method")
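
For context, a minimal worker-side example of the code path this hunk changes (illustrative single-process values; gloo backend assumed available):

    import os
    import torch.distributed as dist

    # These env vars are normally set by the launcher or the elastic agent.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29502"
    os.environ["RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"

    # TORCHELASTIC_USE_AGENT_STORE is unset here, so rank 0 starts its own
    # TCPStore daemon (the else-branch above). Had the agent set it to
    # "True", the handler would connect to the agent's store without a
    # listener and wrap it in PrefixStore("/worker").
    dist.init_process_group(backend="gloo", init_method="env://")
    assert dist.get_rank() == 0
    dist.destroy_process_group()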