mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE][Easy] enable UFMT for torch/distributed/
(#128870)
Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128870 Approved by: https://github.com/fegin ghstack dependencies: #128868, #128869
This commit is contained in:
committed by
PyTorch MergeBot
parent
3b798df853
commit
a0e1e20c41
@ -10,7 +10,7 @@ import numbers
|
||||
import os
|
||||
import sys
|
||||
from datetime import timedelta
|
||||
from typing import Dict, Optional, Callable, Iterator, Tuple
|
||||
from typing import Callable, Dict, Iterator, Optional, Tuple
|
||||
|
||||
from torch.distributed import FileStore, PrefixStore, Store, TCPStore
|
||||
|
||||
@ -21,6 +21,7 @@ _rendezvous_handlers: Dict[str, Callable[..., Iterator[Tuple[Store, int, int]]]]
|
||||
|
||||
__all__ = ["register_rendezvous_handler", "rendezvous"]
|
||||
|
||||
|
||||
def register_rendezvous_handler(scheme, handler):
|
||||
"""
|
||||
Register a new rendezvous handler.
|
||||
@ -47,16 +48,17 @@ def register_rendezvous_handler(scheme, handler):
|
||||
"""
|
||||
global _rendezvous_handlers
|
||||
if scheme in _rendezvous_handlers:
|
||||
raise RuntimeError(
|
||||
f"Rendezvous handler for {scheme}:// already registered"
|
||||
)
|
||||
raise RuntimeError(f"Rendezvous handler for {scheme}:// already registered")
|
||||
_rendezvous_handlers[scheme] = handler
|
||||
|
||||
|
||||
# Query will have format "rank=0&world_size=1" and is
|
||||
# converted into {"rank": 0, "world_size": 1}
|
||||
def _query_to_dict(query: str) -> Dict[str, str]:
|
||||
return {pair[0]: pair[1] for pair in (pair.split("=") for pair in filter(None, query.split("&")))}
|
||||
return {
|
||||
pair[0]: pair[1]
|
||||
for pair in (pair.split("=") for pair in filter(None, query.split("&")))
|
||||
}
|
||||
|
||||
|
||||
def _get_use_libuv_from_query_dict(query_dict: Dict[str, str]) -> bool:
|
||||
@ -152,7 +154,9 @@ def _torchelastic_use_agent_store() -> bool:
|
||||
return os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None) == str(True)
|
||||
|
||||
|
||||
def _create_c10d_store(hostname, port, rank, world_size, timeout, use_libuv=True) -> Store:
|
||||
def _create_c10d_store(
|
||||
hostname, port, rank, world_size, timeout, use_libuv=True
|
||||
) -> Store:
|
||||
"""
|
||||
Smartly creates a c10d Store object on ``rank`` based on whether we need to re-use agent store.
|
||||
|
||||
@ -183,7 +187,13 @@ def _create_c10d_store(hostname, port, rank, world_size, timeout, use_libuv=True
|
||||
else:
|
||||
start_daemon = rank == 0
|
||||
return TCPStore(
|
||||
hostname, port, world_size, start_daemon, timeout, multi_tenant=True, use_libuv=use_libuv
|
||||
hostname,
|
||||
port,
|
||||
world_size,
|
||||
start_daemon,
|
||||
timeout,
|
||||
multi_tenant=True,
|
||||
use_libuv=use_libuv,
|
||||
)
|
||||
|
||||
|
||||
@ -208,7 +218,9 @@ def _tcp_rendezvous_handler(
|
||||
|
||||
assert result.hostname is not None
|
||||
|
||||
store = _create_c10d_store(result.hostname, result.port, rank, world_size, timeout, use_libuv)
|
||||
store = _create_c10d_store(
|
||||
result.hostname, result.port, rank, world_size, timeout, use_libuv
|
||||
)
|
||||
|
||||
yield (store, rank, world_size)
|
||||
|
||||
@ -250,12 +262,13 @@ def _env_rendezvous_handler(
|
||||
else:
|
||||
world_size = int(_get_env_or_raise("WORLD_SIZE"))
|
||||
|
||||
|
||||
master_addr = _get_env_or_raise("MASTER_ADDR")
|
||||
master_port = int(_get_env_or_raise("MASTER_PORT"))
|
||||
use_libuv = _get_use_libuv_from_query_dict(query_dict)
|
||||
|
||||
store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout, use_libuv)
|
||||
store = _create_c10d_store(
|
||||
master_addr, master_port, rank, world_size, timeout, use_libuv
|
||||
)
|
||||
|
||||
yield (store, rank, world_size)
|
||||
|
||||
|
Reference in New Issue
Block a user