mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Dynamic RPC] Allow for optional world_size argument in init_rpc (#73372)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73372

This PR allows for an optional `world_size` argument in init_rpc. It changes rendezvous to allow `NoneType` for world_size and creates a new code path when initializing the TensorPipe agent for init_rpc. The TensorPipe agent is protected by a critical section enforced using the store, so that only one node can create a TPAgent at a time. This PR does not yet enable RPC commands between ranks.

Previously:
```python
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
init_rpc("worker0", world_size=1, rank=0)
```

Now (only rank is needed):
```python
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
init_rpc("worker0", rank=0)
```

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D34621651

Pulled By: H-Huang

fbshipit-source-id: 09dbb511d5a00c219a6ce0a35501ff2e388998b0 (cherry picked from commit 834aedc3256167399c323169ef2f0c9b3cf98dff)
This commit is contained in:
committed by
PyTorch MergeBot
parent
09f32eba7a
commit
f76d1c022e
@ -9,7 +9,7 @@ import numbers
|
||||
import os
|
||||
import sys
|
||||
from datetime import timedelta
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import cast, Dict, Iterable, Optional, Tuple, Union
|
||||
|
||||
import torch._six as six
|
||||
from torch.distributed import FileStore, PrefixStore, Store, TCPStore
|
||||
@ -50,6 +50,13 @@ def register_rendezvous_handler(scheme, handler):
|
||||
)
|
||||
_rendezvous_handlers[scheme] = handler
|
||||
|
||||
# Query will have format "rank=0&world_size=1" and is
|
||||
# converted into {"rank": 0, "world_size": 1}
|
||||
def _query_to_dict(query):
|
||||
query_dict: Dict[str, str] = dict(
|
||||
cast(Tuple[str, str], pair.split("=")) for pair in cast(Iterable[str], filter(None, query.split("&")))
|
||||
)
|
||||
return query_dict
|
||||
|
||||
def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
|
||||
if not isinstance(url, six.string_classes):
|
||||
@ -64,9 +71,7 @@ def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
|
||||
# Append node-specific arguments.
|
||||
result = urlparse(url)
|
||||
if rank != -1 or world_size != -1:
|
||||
query_dict: Dict[str, Union[int, str]] = dict(
|
||||
pair.split("=") for pair in filter(None, result.query.split("&"))
|
||||
)
|
||||
query_dict = _query_to_dict(result.query)
|
||||
assert (
|
||||
"rank" not in query_dict and "world_size" not in query_dict
|
||||
), "The url: {url} has node-specific arguments(rank, world_size) already.".format(
|
||||
@ -88,6 +93,34 @@ def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
|
||||
raise RuntimeError("No rendezvous handler for {}://".format(result.scheme))
|
||||
return _rendezvous_handlers[result.scheme](url, **kwargs)
|
||||
|
||||
def _create_store_from_options(backend_options, rank):
    """Create a store from ``backend_options.init_method`` for the given rank.

    The init-method URL is parsed, ``rank`` and ``world_size`` are written
    into its query string, and the rendezvous handler registered for the
    URL scheme is invoked with the rebuilt URL. Only the store (the first
    element of the tuple the handler yields) is returned.

    Args:
        backend_options: Object exposing an ``init_method`` URL string.
        rank: Rank of this node, or ``-1`` if unknown. For the ``env``
            scheme a ``RANK`` environment variable takes precedence.

    Returns:
        The store produced by the scheme's rendezvous handler.

    Raises:
        RuntimeError: If no handler is registered for the URL scheme.
        ValueError: If ``RANK`` or ``WORLD_SIZE`` is set in the environment
            but is not an integer.
    """
    result = urlparse(backend_options.init_method)

    # If using env initialization, get rank and world_size from env
    world_size = -1
    if result.scheme == "env":
        # Environment values are strings; convert to int so the
        # ``rank != -1`` check below compares like with like (a raw "-1"
        # string would otherwise never equal the integer -1).
        env_rank = os.environ.get("RANK")
        if env_rank is not None:
            rank = int(env_rank)
        # Here, world_size is still the -1 default set above. If the
        # WORLD_SIZE env variable is also not present then it is a dynamic group
        world_size = int(os.environ.get("WORLD_SIZE", world_size))

    query_dict = _query_to_dict(result.query)
    # if rank is -1 then intentionally exclude rank for the query, error will be thrown later
    if rank != -1:
        query_dict["rank"] = str(rank)
    query_dict["world_size"] = str(world_size)

    result = result._replace(
        query="&".join("{}={}".format(k, v) for k, v in query_dict.items())
    )

    url = urlunparse(result)
    if result.scheme not in _rendezvous_handlers:
        raise RuntimeError("No handler for {}://".format(result.scheme))
    # Handlers are generators; take the first (store, rank, world_size)
    # tuple they yield and keep only the store.
    store, _, _ = next(_rendezvous_handlers[result.scheme](url))
    return store
|
||||
|
||||
def _rendezvous_error(msg):
|
||||
return ValueError("Error initializing torch.distributed using " + msg)
|
||||
@ -110,16 +143,14 @@ def _file_rendezvous_handler(url: str, **kwargs):
|
||||
|
||||
if not path:
|
||||
raise _error("path missing")
|
||||
query: Dict[str, str]
|
||||
# mypy doesn't allow dict() to accept List of values (#257)
|
||||
query = dict(pair.split("=") for pair in filter(None, result.query.split("&"))) # type: ignore[misc, arg-type]
|
||||
if "rank" not in query:
|
||||
query_dict = _query_to_dict(result.query)
|
||||
if "rank" not in query_dict:
|
||||
raise _error("rank parameter missing")
|
||||
if "world_size" not in query:
|
||||
if "world_size" not in query_dict:
|
||||
raise _error("world size parameter missing")
|
||||
|
||||
rank = int(query["rank"])
|
||||
world_size = int(query["world_size"])
|
||||
rank = int(query_dict["rank"])
|
||||
world_size = int(query_dict["world_size"])
|
||||
store = FileStore(path, world_size)
|
||||
yield (store, rank, world_size)
|
||||
|
||||
@ -171,16 +202,14 @@ def _tcp_rendezvous_handler(
|
||||
result = urlparse(url)
|
||||
if not result.port:
|
||||
raise _error("port number missing")
|
||||
query: Dict[str, Union[int, str]]
|
||||
# mypy doesn't allow dict() to accept List of values (#257)
|
||||
query = dict(pair.split("=") for pair in filter(None, result.query.split("&"))) # type: ignore[misc, arg-type]
|
||||
if "rank" not in query:
|
||||
query_dict = _query_to_dict(result.query)
|
||||
if "rank" not in query_dict:
|
||||
raise _error("rank parameter missing")
|
||||
if "world_size" not in query:
|
||||
if "world_size" not in query_dict:
|
||||
raise _error("world size parameter missing")
|
||||
|
||||
rank = int(query["rank"])
|
||||
world_size = int(query["world_size"])
|
||||
rank = int(query_dict["rank"])
|
||||
world_size = int(query_dict["world_size"])
|
||||
assert result.hostname is not None
|
||||
|
||||
store = _create_c10d_store(result.hostname, result.port, rank, world_size, timeout)
|
||||
@ -208,21 +237,19 @@ def _env_rendezvous_handler(
|
||||
return env_val
|
||||
|
||||
result = urlparse(url)
|
||||
query: Dict[str, Union[int, str]]
|
||||
# mypy doesn't allow dict() to accept List of values (#257)
|
||||
query = dict(pair.split("=") for pair in filter(None, result.query.split("&"))) # type: ignore[misc, arg-type]
|
||||
query_dict: Dict[str, Union[int, str]] = _query_to_dict(result.query)
|
||||
|
||||
rank: Optional[Union[str, int]]
|
||||
world_size: Optional[Union[str, int]]
|
||||
master_port: Optional[Union[str, int]]
|
||||
|
||||
if "rank" in query:
|
||||
rank = int(query["rank"])
|
||||
if "rank" in query_dict:
|
||||
rank = int(query_dict["rank"])
|
||||
else:
|
||||
rank = int(_get_env_or_raise("RANK"))
|
||||
|
||||
if "world_size" in query:
|
||||
world_size = int(query["world_size"])
|
||||
if "world_size" in query_dict:
|
||||
world_size = int(query_dict["world_size"])
|
||||
else:
|
||||
world_size = int(_get_env_or_raise("WORLD_SIZE"))
|
||||
|
||||
|
Reference in New Issue
Block a user