mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
D53335860 and D56435815 added an option to torch elastic that lets users choose which TCPStore backend to use, either by 1) explicitly passing an argument in user code when instantiating `MastRendezvousHandler`, or 2) passing the `--use_libuv` command line argument to `torchrun`. The motivation was to offer a quick way to roll back to the non-libuv TCPStore backend, since we were making libuv the default in `c10d` code. We now think it is better for torch elastic not to be aware of the TCPStore backend type at all, and instead to rely on `c10d`'s mechanism to decide which backend to use for torch elastic as well. This way, the TCPStore backend type used by torch elastic will be identical to the one used in PyTorch. PyTorch TCPStore uses the environment variable `USE_LIBUV` to determine the backend type: when `USE_LIBUV="0"`, the non-libuv backend is used; when `USE_LIBUV="1"`, the libuv backend is used, and this is the default. Differential Revision: [D58259590](https://our.internmc.facebook.com/intern/diff/D58259590/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/134882 Approved by: https://github.com/shuqiangzhang
129 lines
3.6 KiB
Python
129 lines
3.6 KiB
Python
#!/usr/bin/env python3
|
|
# mypy: allow-untyped-defs
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import datetime
|
|
import logging
|
|
from typing import cast, Optional
|
|
|
|
from torch.distributed import PrefixStore, Store, TCPStore
|
|
from torch.distributed.elastic.rendezvous import (
|
|
RendezvousHandler,
|
|
RendezvousInfo,
|
|
RendezvousParameters,
|
|
RendezvousStoreInfo,
|
|
)
|
|
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
|
|
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = ["StaticTCPRendezvous", "create_rdzv_handler"]
|
|
|
|
# Module-level logger; StaticTCPRendezvous.next_rendezvous logs through it.
logger = logging.getLogger(__name__)
|
|
|
|
# Timeout, in seconds, that create_rdzv_handler falls back to when the
# rendezvous config carries no "timeout" entry.
_default_timeout_seconds = 600
|
|
|
|
|
|
class StaticTCPRendezvous(RendezvousHandler):
|
|
"""
|
|
Static rendezvous that is a wrapper around the TCPStore.
|
|
|
|
Creates TCPStore based on the input parameters with the
|
|
listener on the agent with group_rank=0
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
master_addr: str,
|
|
master_port: int,
|
|
rank: int,
|
|
world_size: int,
|
|
run_id: str,
|
|
timeout: int,
|
|
):
|
|
self.master_addr = master_addr
|
|
self.master_port = master_port
|
|
self.rank = rank
|
|
self.world_size = world_size
|
|
self.run_id = run_id
|
|
self.timeout = datetime.timedelta(seconds=timeout)
|
|
self._store: Optional[Store] = None
|
|
|
|
def get_backend(self) -> str:
|
|
return "static"
|
|
|
|
@property
|
|
def use_agent_store(self) -> bool:
|
|
return True
|
|
|
|
def next_rendezvous(self) -> RendezvousInfo:
|
|
logger.info("Creating TCPStore as the c10d::Store implementation")
|
|
is_master = self.rank == 0
|
|
if not self._store:
|
|
self._store = TCPStore( # type: ignore[call-arg]
|
|
self.master_addr,
|
|
self.master_port,
|
|
self.world_size,
|
|
is_master,
|
|
self.timeout,
|
|
multi_tenant=True,
|
|
)
|
|
store = PrefixStore(self.run_id, self._store)
|
|
# TCPStore server instance is used by trainer code
|
|
bootstrap_store_info = RendezvousStoreInfo(self.master_addr, self.master_port)
|
|
return RendezvousInfo(
|
|
store,
|
|
self.rank,
|
|
self.world_size,
|
|
bootstrap_store_info,
|
|
)
|
|
|
|
def is_closed(self):
|
|
return False
|
|
|
|
def set_closed(self):
|
|
pass
|
|
|
|
def num_nodes_waiting(self):
|
|
return 0
|
|
|
|
def get_run_id(self) -> str:
|
|
return self.run_id
|
|
|
|
def shutdown(self) -> bool:
|
|
return True
|
|
|
|
|
|
def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler:
    """
    Build a :class:`StaticTCPRendezvous` from rendezvous parameters.

    Reads ``rank`` (required) and ``timeout`` (optional, in seconds) from
    ``params.config``, the master host and port from ``params.endpoint``,
    the world size from ``params.max_nodes`` and the run id from
    ``params.run_id``.

    Args:
        params: Rendezvous parameters describing the static rendezvous.

    Returns:
        A ``StaticTCPRendezvous`` configured from ``params``.

    Raises:
        ValueError: If ``rank`` is missing from the config, the endpoint is
            empty, the endpoint has no port, or ``rank``/``timeout`` cannot
            be parsed as an integer.
    """
    if "rank" not in params.config:
        raise ValueError(
            "rank is absent in RendezvousParameters. "
            "Try add --node-rank to the cmd request"
        )
    endpoint = params.endpoint.strip()
    if not endpoint:
        raise ValueError(
            "endpoint is absent in RendezvousParameters. "
            "Try add --master-port and --master-addr to the cmd request"
        )
    # -1 is the sentinel returned for the port when the endpoint has none.
    master_addr, master_port = parse_rendezvous_endpoint(endpoint, -1)
    if master_port == -1:
        raise ValueError(
            f"Port is absent in endpoint: {endpoint}. Try launching with --master-port"
        )
    world_size = params.max_nodes
    # Convert rather than cast: a rank supplied as a string (e.g. "0") must
    # still compare equal to 0 when the handler decides who is the master.
    rank = int(params.config["rank"])
    run_id = params.run_id
    timeout = int(params.config.get("timeout", _default_timeout_seconds))

    return StaticTCPRendezvous(
        master_addr, master_port, rank, world_size, run_id, timeout
    )
|