[BE][Easy] enable UFMT for torch/distributed/ (#128870)

Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128870
Approved by: https://github.com/fegin
ghstack dependencies: #128868, #128869
This commit is contained in:
Xuehai Pan
2024-06-18 23:21:45 +08:00
committed by PyTorch MergeBot
parent 3b798df853
commit a0e1e20c41
36 changed files with 584 additions and 299 deletions

View File

@ -10,7 +10,7 @@ import numbers
import os
import sys
from datetime import timedelta
from typing import Dict, Optional, Callable, Iterator, Tuple
from typing import Callable, Dict, Iterator, Optional, Tuple
from torch.distributed import FileStore, PrefixStore, Store, TCPStore
@ -21,6 +21,7 @@ _rendezvous_handlers: Dict[str, Callable[..., Iterator[Tuple[Store, int, int]]]]
__all__ = ["register_rendezvous_handler", "rendezvous"]
def register_rendezvous_handler(scheme, handler):
"""
Register a new rendezvous handler.
@ -47,16 +48,17 @@ def register_rendezvous_handler(scheme, handler):
"""
global _rendezvous_handlers
if scheme in _rendezvous_handlers:
raise RuntimeError(
f"Rendezvous handler for {scheme}:// already registered"
)
raise RuntimeError(f"Rendezvous handler for {scheme}:// already registered")
_rendezvous_handlers[scheme] = handler
# Query will have format "rank=0&world_size=1" and is
# converted into {"rank": 0, "world_size": 1}
def _query_to_dict(query: str) -> Dict[str, str]:
return {pair[0]: pair[1] for pair in (pair.split("=") for pair in filter(None, query.split("&")))}
return {
pair[0]: pair[1]
for pair in (pair.split("=") for pair in filter(None, query.split("&")))
}
def _get_use_libuv_from_query_dict(query_dict: Dict[str, str]) -> bool:
@ -152,7 +154,9 @@ def _torchelastic_use_agent_store() -> bool:
return os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None) == str(True)
def _create_c10d_store(hostname, port, rank, world_size, timeout, use_libuv=True) -> Store:
def _create_c10d_store(
hostname, port, rank, world_size, timeout, use_libuv=True
) -> Store:
"""
Smartly creates a c10d Store object on ``rank`` based on whether we need to re-use agent store.
@ -183,7 +187,13 @@ def _create_c10d_store(hostname, port, rank, world_size, timeout, use_libuv=True
else:
start_daemon = rank == 0
return TCPStore(
hostname, port, world_size, start_daemon, timeout, multi_tenant=True, use_libuv=use_libuv
hostname,
port,
world_size,
start_daemon,
timeout,
multi_tenant=True,
use_libuv=use_libuv,
)
@ -208,7 +218,9 @@ def _tcp_rendezvous_handler(
assert result.hostname is not None
store = _create_c10d_store(result.hostname, result.port, rank, world_size, timeout, use_libuv)
store = _create_c10d_store(
result.hostname, result.port, rank, world_size, timeout, use_libuv
)
yield (store, rank, world_size)
@ -250,12 +262,13 @@ def _env_rendezvous_handler(
else:
world_size = int(_get_env_or_raise("WORLD_SIZE"))
master_addr = _get_env_or_raise("MASTER_ADDR")
master_port = int(_get_env_or_raise("MASTER_PORT"))
use_libuv = _get_use_libuv_from_query_dict(query_dict)
store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout, use_libuv)
store = _create_c10d_store(
master_addr, master_port, rank, world_size, timeout, use_libuv
)
yield (store, rank, world_size)