mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47534
Test Plan: Imported from OSS
Reviewed By: walterddr
Differential Revision: D24952497
Pulled By: xuzhao9
fbshipit-source-id: 063bfd0707198436fcfd9431f72f9a392bc0017e
This commit is contained in:
committed by
Facebook GitHub Bot
parent
7f66fa62ca
commit
49f0e5dfeb
@ -1,12 +1,14 @@
|
||||
try:
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
except ImportError:
|
||||
from urlparse import urlparse, urlunparse
|
||||
raise ImportError("urllib cannot be found, urlparse from python2 is no longer supported.")
|
||||
|
||||
import torch._six as six
|
||||
import numbers
|
||||
import os
|
||||
import sys
|
||||
from datetime import timedelta
|
||||
from typing import Optional, Dict, Union
|
||||
from torch._C._distributed_c10d import FileStore
|
||||
from .constants import default_pg_timeout
|
||||
|
||||
@ -47,7 +49,7 @@ def register_rendezvous_handler(scheme, handler):
|
||||
_rendezvous_handlers[scheme] = handler
|
||||
|
||||
|
||||
def rendezvous(url, rank=-1, world_size=-1, **kwargs):
|
||||
def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
|
||||
if not isinstance(url, six.string_classes):
|
||||
raise RuntimeError("`url` must be a string. {}: {}".format(type(url), url))
|
||||
|
||||
@ -60,8 +62,9 @@ def rendezvous(url, rank=-1, world_size=-1, **kwargs):
|
||||
# Append node-specific arguments.
|
||||
result = urlparse(url)
|
||||
if rank != -1 or world_size != -1:
|
||||
query_dict = dict(
|
||||
pair.split("=") for pair in filter(None, result.query.split("&"))
|
||||
query_dict: Dict[str, Union[int, str]] = dict(
|
||||
# mypy doesn't allow dict() to accept List of values (#257)
|
||||
pair.split("=") for pair in filter(None, result.query.split("&")) # type: ignore[arg-type, misc]
|
||||
)
|
||||
assert (
|
||||
"rank" not in query_dict and "world_size" not in query_dict
|
||||
@ -87,7 +90,7 @@ def _rendezvous_error(msg):
|
||||
return ValueError("Error initializing torch.distributed using " + msg)
|
||||
|
||||
|
||||
def _file_rendezvous_handler(url, **kwargs):
|
||||
def _file_rendezvous_handler(url: str, **kwargs):
|
||||
def _error(msg):
|
||||
return _rendezvous_error("file:// rendezvous: " + msg)
|
||||
|
||||
@ -99,7 +102,9 @@ def _file_rendezvous_handler(url, **kwargs):
|
||||
|
||||
if not path:
|
||||
raise _error("path missing")
|
||||
query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))
|
||||
query: Dict[str, str]
|
||||
# mypy doesn't allow dict() to accept List of values (#257)
|
||||
query = dict(pair.split("=") for pair in filter(None, result.query.split("&"))) # type: ignore[misc, arg-type]
|
||||
if "rank" not in query:
|
||||
raise _error("rank parameter missing")
|
||||
if "world_size" not in query:
|
||||
@ -114,14 +119,16 @@ def _file_rendezvous_handler(url, **kwargs):
|
||||
raise RuntimeError("Unable to perform rerendezvous using file:// method")
|
||||
|
||||
|
||||
def _tcp_rendezvous_handler(url, timeout=default_pg_timeout, **kwargs):
|
||||
def _tcp_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, **kwargs):
|
||||
def _error(msg):
|
||||
return _rendezvous_error("tcp:// rendezvous: " + msg)
|
||||
|
||||
result = urlparse(url)
|
||||
if not result.port:
|
||||
raise _error("port number missing")
|
||||
query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))
|
||||
query: Dict[str, Union[int, str]]
|
||||
# mypy doesn't allow dict() to accept List of values (#257)
|
||||
query = dict(pair.split("=") for pair in filter(None, result.query.split("&"))) # type: ignore[misc, arg-type]
|
||||
if "rank" not in query:
|
||||
raise _error("rank parameter missing")
|
||||
if "world_size" not in query:
|
||||
@ -130,6 +137,7 @@ def _tcp_rendezvous_handler(url, timeout=default_pg_timeout, **kwargs):
|
||||
rank = int(query["rank"])
|
||||
world_size = int(query["world_size"])
|
||||
start_daemon = rank == 0
|
||||
assert result.hostname is not None
|
||||
store = TCPStore(result.hostname, result.port, world_size, start_daemon, timeout)
|
||||
yield (store, rank, world_size)
|
||||
|
||||
@ -137,7 +145,7 @@ def _tcp_rendezvous_handler(url, timeout=default_pg_timeout, **kwargs):
|
||||
raise RuntimeError("Unable to perform rerendezvous using tcp:// method")
|
||||
|
||||
|
||||
def _env_rendezvous_handler(url, timeout=default_pg_timeout, **kwargs):
|
||||
def _env_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, **kwargs):
|
||||
def _error(msg):
|
||||
return _rendezvous_error("env:// rendezvous: " + msg)
|
||||
|
||||
@ -145,7 +153,13 @@ def _env_rendezvous_handler(url, timeout=default_pg_timeout, **kwargs):
|
||||
return _error("environment variable %s expected, but not set" % var)
|
||||
|
||||
result = urlparse(url)
|
||||
query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))
|
||||
query: Dict[str, Union[int, str]]
|
||||
# mypy doesn't allow dict() to accept List of values (#257)
|
||||
query = dict(pair.split("=") for pair in filter(None, result.query.split("&"))) # type: ignore[misc, arg-type]
|
||||
|
||||
rank: Optional[Union[str, int]]
|
||||
world_size: Optional[Union[str, int]]
|
||||
master_port: Optional[Union[str, int]]
|
||||
|
||||
if "rank" in query:
|
||||
rank = int(query["rank"])
|
||||
|
Reference in New Issue
Block a user