[1/n][torch/elastic][upstream] Move torchelastic/rendezvous to torch/distributed/rendezvous (#53172)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53172

Pull Request resolved: https://github.com/pytorch/elastic/pull/141

Upstreams two modules to torch:

1. `torchelastic.rendezvous`
2. `torchelastic.utils`

These modules were chosen for `[1/n]` since they are the leaf modules in torchelastic.
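
For context, the move changes the public import paths roughly as follows (a sketch; the old paths follow the torchelastic repo layout, the new paths are the ones exercised by the tests in this diff):

```
# Before (pytorch/elastic):
# from torchelastic.rendezvous import RendezvousHandler, RendezvousParameters
# from torchelastic.utils.logging import get_logger

# After this diff (upstreamed into torch):
from torch.distributed.elastic.rendezvous import (
    RendezvousHandler,
    RendezvousParameters,
)
from torch.distributed.elastic.utils.logging import get_logger
```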

==== NOTES: ====
1. I'm disabling the etcd_rendezvous and etcd_server tests in CircleCI for the moment since I need to edit the test Docker images to include the etcd server binary (there are 4-5 test images, one per platform, so it will take some time for me to set up the environments and test) - T85992919.

2. I've fixed all lint errors in the Python files, but some remain in the cpp files in ZeusRendezvous. I took a look at them, and I don't want to fix those linter errors right now for two major reasons:
     1. Some of them are more than formatting changes (e.g. `std::move` vs pass-by-value) and I don't want to bundle such changes with the move.
     2. The old rendezvous code (the one we forked from in caffe2/fb) has the same problems, and I think it's better for us to deal with this when we deprecate caffe2/fb/rendezvous in favor of the one in torchelastic - T86012579.

Test Plan:
```
buck test mode/dev-nosan //caffe2/torch/distributed/elastic/utils/test/...
buck test mode/dev-nosan //caffe2/torch/distributed/elastic/utils/data/test/...
buck test mode/dev-nosan //caffe2/torch/distributed/elastic/rendezvous/test/...
buck test mode/dev-nosan //caffe2/torch/distributed/elastic/rendezvous/fb/...
buck test mode/dev-nosan //pytorch/elastic/torchelastic/...
```
+ Sandcastle

Reviewed By: H-Huang

Differential Revision: D26718746

fbshipit-source-id: 67cc0350c3d847221cb3c3038f98f47915362f51
This commit is contained in:
Kiuk Chung
2021-03-05 11:24:25 -08:00
committed by Facebook GitHub Bot
parent 14fa47631b
commit ba75cedfc5
25 changed files with 2965 additions and 0 deletions

View File

@@ -0,0 +1,7 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,85 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from typing import Tuple
from torch.distributed.elastic.rendezvous import (
RendezvousHandler,
RendezvousHandlerFactory,
RendezvousParameters,
)
def create_mock_rdzv_handler(ignored: RendezvousParameters) -> RendezvousHandler:
return MockRendezvousHandler()
class MockRendezvousHandler(RendezvousHandler):
def next_rendezvous(
self,
# pyre-ignore[11]: Annotation `Store` is not defined as a type.
) -> Tuple["torch.distributed.Store", int, int]: # noqa F821
raise NotImplementedError()
def get_backend(self) -> str:
return "mock"
def is_closed(self) -> bool:
return False
def set_closed(self):
pass
def num_nodes_waiting(self) -> int:
return -1
def get_run_id(self) -> str:
return ""
class RendezvousHandlerFactoryTest(unittest.TestCase):
def test_double_registration(self):
factory = RendezvousHandlerFactory()
factory.register("mock", create_mock_rdzv_handler)
with self.assertRaises(ValueError):
factory.register("mock", create_mock_rdzv_handler)
def test_no_factory_method_found(self):
factory = RendezvousHandlerFactory()
rdzv_params = RendezvousParameters(
backend="mock", endpoint="", run_id="foobar", min_nodes=1, max_nodes=2
)
with self.assertRaises(ValueError):
factory.create_handler(rdzv_params)
def test_create_handler(self):
rdzv_params = RendezvousParameters(
backend="mock", endpoint="", run_id="foobar", min_nodes=1, max_nodes=2
)
factory = RendezvousHandlerFactory()
factory.register("mock", create_mock_rdzv_handler)
mock_rdzv_handler = factory.create_handler(rdzv_params)
self.assertTrue(isinstance(mock_rdzv_handler, MockRendezvousHandler))
class RendezvousParametersTest(unittest.TestCase):
def test_get_or_default(self):
params = RendezvousParameters(
backend="foobar",
endpoint="localhost",
run_id="1234",
min_nodes=1,
max_nodes=1,
timeout1=10,
)
self.assertEqual(10, params.get("timeout1", 20))
self.assertEqual(60, params.get("timeout2", 60))

View File

@@ -0,0 +1,78 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
import uuid
from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.rendezvous.etcd_rendezvous import create_rdzv_handler
from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
@unittest.skipIf(os.getenv("CIRCLECI"), "T85992919 temporarily disabling in circle ci")
class EtcdRendezvousTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
# start a standalone, single process etcd server to use for all tests
cls._etcd_server = EtcdServer()
cls._etcd_server.start()
@classmethod
def tearDownClass(cls):
# stop the standalone etcd server
cls._etcd_server.stop()
def test_etcd_rdzv_basic_params(self):
"""
Check that we can create the handler with a minimum set of
params
"""
rdzv_params = RendezvousParameters(
backend="etcd",
endpoint=f"{self._etcd_server.get_endpoint()}",
run_id=f"{uuid.uuid4()}",
min_nodes=1,
max_nodes=1,
)
etcd_rdzv = create_rdzv_handler(rdzv_params)
self.assertIsNotNone(etcd_rdzv)
def test_etcd_rdzv_additional_params(self):
run_id = str(uuid.uuid4())
rdzv_params = RendezvousParameters(
backend="etcd",
endpoint=f"{self._etcd_server.get_endpoint()}",
run_id=run_id,
min_nodes=1,
max_nodes=1,
timeout=60,
last_call_timeout=30,
protocol="http",
)
etcd_rdzv = create_rdzv_handler(rdzv_params)
self.assertIsNotNone(etcd_rdzv)
self.assertEqual(run_id, etcd_rdzv.get_run_id())
def test_get_backend(self):
run_id = str(uuid.uuid4())
rdzv_params = RendezvousParameters(
backend="etcd",
endpoint=f"{self._etcd_server.get_endpoint()}",
run_id=run_id,
min_nodes=1,
max_nodes=1,
timeout=60,
last_call_timeout=30,
protocol="http",
)
etcd_rdzv = create_rdzv_handler(rdzv_params)
self.assertEqual("etcd", etcd_rdzv.get_backend())

View File

@@ -0,0 +1,55 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
import etcd
from torch.distributed.elastic.rendezvous.etcd_rendezvous import (
EtcdRendezvous,
EtcdRendezvousHandler,
)
from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
@unittest.skipIf(os.getenv("CIRCLECI"), "T85992919 temporarily disabling in circle ci")
class EtcdServerTest(unittest.TestCase):
def test_etcd_server_start_stop(self):
server = EtcdServer()
server.start()
try:
port = server.get_port()
host = server.get_host()
self.assertGreater(port, 0)
self.assertEqual("localhost", host)
self.assertEqual(f"{host}:{port}", server.get_endpoint())
self.assertIsNotNone(server.get_client().version)
finally:
server.stop()
def test_etcd_server_with_rendezvous(self):
server = EtcdServer()
server.start()
client = etcd.Client(server.get_host(), server.get_port())
rdzv = EtcdRendezvous(
client=client,
prefix="test",
run_id=1,
num_min_workers=1,
num_max_workers=1,
timeout=60,
last_call_timeout=30,
)
rdzv_handler = EtcdRendezvousHandler(rdzv)
store, rank, world_size = rdzv_handler.next_rendezvous()
self.assertIsNotNone(store)
self.assertEqual(0, rank)
self.assertEqual(1, world_size)

View File

@@ -0,0 +1,7 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,7 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,46 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from torch.distributed.elastic.utils.data import CyclingIterator
class CyclingIteratorTest(unittest.TestCase):
def generator(self, epoch, stride, max_epochs):
# generate a continuously incrementing list each epoch
# e.g. [0,1,2] [3,4,5] [6,7,8] ...
return iter([stride * epoch + i for i in range(0, stride)])
def test_cycling_iterator(self):
stride = 3
max_epochs = 90
def generator_fn(epoch):
return self.generator(epoch, stride, max_epochs)
it = CyclingIterator(n=max_epochs, generator_fn=generator_fn)
for i in range(0, stride * max_epochs):
self.assertEqual(i, next(it))
with self.assertRaises(StopIteration):
next(it)
def test_cycling_iterator_start_epoch(self):
stride = 3
max_epochs = 2
start_epoch = 1
def generator_fn(epoch):
return self.generator(epoch, stride, max_epochs)
it = CyclingIterator(max_epochs, generator_fn, start_epoch)
for i in range(stride * start_epoch, stride * max_epochs):
self.assertEqual(i, next(it))
with self.assertRaises(StopIteration):
next(it)

View File

@@ -0,0 +1,127 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import multiprocessing as mp
import os
import socket
import unittest
from contextlib import closing
from torch.distributed.elastic.utils.distributed import (
create_c10d_store,
get_free_port,
get_socket_with_port,
)
def _create_c10d_store_mp(is_server, server_addr, port, world_size):
store = create_c10d_store(is_server, server_addr, port, world_size, timeout=2)
if store is None:
raise AssertionError()
store.set(f"test_key/{os.getpid()}", "test_value".encode("UTF-8"))
class DistributedUtilTest(unittest.TestCase):
def test_create_store_single_server(self):
store = create_c10d_store(is_server=True, server_addr=socket.gethostname())
self.assertIsNotNone(store)
def test_create_store_no_port_multi(self):
with self.assertRaises(ValueError):
create_c10d_store(
is_server=True, server_addr=socket.gethostname(), world_size=2
)
def test_create_store_multi(self):
world_size = 3
server_port = get_free_port()
localhost = socket.gethostname()
worker0 = mp.Process(
target=_create_c10d_store_mp,
args=(False, localhost, server_port, world_size),
)
worker1 = mp.Process(
target=_create_c10d_store_mp,
args=(False, localhost, server_port, world_size),
)
worker0.start()
worker1.start()
# start the server on the main process
store = create_c10d_store(
is_server=True,
server_addr=localhost,
server_port=server_port,
world_size=world_size,
timeout=2,
)
worker0.join()
worker1.join()
# check test_key/pid == "test_value"
self.assertEqual(
"test_value", store.get(f"test_key/{worker0.pid}").decode("UTF-8")
)
self.assertEqual(
"test_value", store.get(f"test_key/{worker1.pid}").decode("UTF-8")
)
self.assertEqual(0, worker0.exitcode)
self.assertEqual(0, worker1.exitcode)
def test_create_store_timeout_on_server(self):
with self.assertRaises(TimeoutError):
port = get_free_port()
create_c10d_store(
is_server=True,
server_addr=socket.gethostname(),
server_port=port,
world_size=2,
timeout=1,
)
def test_create_store_timeout_on_worker(self):
with self.assertRaises(TimeoutError):
port = get_free_port()
create_c10d_store(
is_server=False,
server_addr=socket.gethostname(),
server_port=port,
world_size=2,
timeout=1,
)
def test_port_already_in_use_on_server(self):
sock = get_socket_with_port()
with closing(sock):
# try to create a store on the same port without releasing the socket
# should raise an IOError
port = sock.getsockname()[1]
with self.assertRaises(IOError):
create_c10d_store(
is_server=True,
server_addr=socket.gethostname(),
server_port=port,
timeout=1,
)
def test_port_already_in_use_on_worker(self):
sock = get_socket_with_port()
with closing(sock):
port = sock.getsockname()[1]
# on the worker port conflict shouldn't matter, it should just timeout
# since we never created a server
with self.assertRaises(TimeoutError):
create_c10d_store(
is_server=False,
server_addr=socket.gethostname(),
server_port=port,
timeout=1,
)

View File

@@ -0,0 +1,32 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch.distributed.elastic.utils.logging as logging
log = logging.get_logger()
class LoggingTest(unittest.TestCase):
def setUp(self):
self.clazz_log = logging.get_logger()
def test_logger_name(self):
local_log = logging.get_logger()
name_override_log = logging.get_logger("foobar")
self.assertEqual("logging_test", log.name)
self.assertEqual("logging_test", self.clazz_log.name)
self.assertEqual("logging_test", local_log.name)
self.assertEqual("foobar", name_override_log.name)
def test_derive_module_name(self):
module_name = logging._derive_module_name(depth=1)
self.assertEqual("logging_test", module_name)

View File

@@ -0,0 +1,70 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch.distributed.elastic.utils.store as store_util
from torch.distributed.elastic.utils.logging import get_logger
class TestStore:
def get(self, key: str):
return f"retrieved:{key}"
class StoreUtilTest(unittest.TestCase):
def test_get_data(self):
store = TestStore()
data = store_util.get_all(store, "test/store", 10)
for idx in range(0, 10):
self.assertEqual(f"retrieved:test/store{idx}", data[idx])
def test_synchronize(self):
class DummyStore:
def __init__(self):
self._data = {
"torchelastic/test0": "data0".encode(encoding="UTF-8"),
"torchelastic/test1": "data1".encode(encoding="UTF-8"),
"torchelastic/test2": "data2".encode(encoding="UTF-8"),
}
def set(self, key, value):
self._data[key] = value
def get(self, key):
return self._data[key]
def set_timeout(self, timeout):
pass
data = "data0".encode(encoding="UTF-8")
store = DummyStore()
res = store_util.synchronize(store, data, 0, 3, key_prefix="torchelastic/test")
self.assertEqual(3, len(res))
for idx, res_data in enumerate(res):
actual_str = res_data.decode(encoding="UTF-8")
self.assertEqual(f"data{idx}", actual_str)
class UtilTest(unittest.TestCase):
def test_get_logger_different(self):
logger1 = get_logger("name1")
logger2 = get_logger("name2")
self.assertNotEqual(logger1.name, logger2.name)
def test_get_logger(self):
logger1 = get_logger()
self.assertEqual(__name__, logger1.name)
def test_get_logger_none(self):
logger1 = get_logger(None)
self.assertEqual(__name__, logger1.name)
def test_get_logger_custom_name(self):
logger1 = get_logger("test.module")
self.assertEqual("test.module", logger1.name)

View File

@@ -0,0 +1,7 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,112 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
""" Rendezvous
In the context of torchelastic we use the term ``rendezvous`` to refer to
a particular functionality that combines a **distributed
synchronization** primitive with **peer discovery**.
It is used by torchelastic to gather participants of a training job
(i.e. workers) such that they all agree on the same list of participants
and everyone's roles, as well as make a consistent collective decision
on when training can begin/resume.
Torchelastic Rendezvous provides the following critical functionalities:
**Barrier**:
Workers performing rendezvous will all block until the rendezvous is
considered complete - this happens when at least ``min`` total number of
workers have joined the rendezvous barrier (for the same job). This also
implies the barrier is not necessarily of fixed size.
There's an additional small waiting time after reaching ``min`` number
of workers - this is used to ensure the rendezvous is not completed “too
quickly” (which could potentially exclude additional workers attempting
to join at approximately the same time).
If ``max`` number of workers is gathered at the barrier, the rendezvous
is completed immediately.
There's also an overall timeout which causes the rendezvous to fail if
the ``min`` number of workers is never reached; this is meant to be a
simple fail-safe to help release partially allocated job resources in
case there's a problem with the resource manager, and is meant to be
interpreted as non-retryable.
**Exclusivity**:
A simple distributed barrier would not be sufficient, as we also need to
ensure that only one group of workers exists at any given time (for a
given job). In other words, new workers (i.e. joining late) should not
be able to form a parallel independent group of workers for the same
job.
Torchelastic rendezvous ensures that if a group of workers has already
completed a rendezvous (and hence might already be training), then
additional “late” workers attempting to rendezvous will only announce
themselves as waiting, and will have to wait until the (previously
completed) existing rendezvous is destroyed first.
**Consistency**:
When a rendezvous is completed, all its members will agree on the job
membership and everyone's role in it. This role is represented using an
integer, called rank, that is between 0 and world size.
Note that ranks are *not stable*, in the sense that the same worker
process can be assigned a different rank in the next (re-)rendezvous.
**Fault-tolerance**:
Torchelastic rendezvous is designed to tolerate worker failures during
the rendezvous process. Should a process crash (or lose network
connectivity, etc), between joining the rendezvous and it being
completed, then a re-rendezvous with remaining healthy workers will
happen automatically.
A worker can also fail *after* it has completed (or *has been
observed* by other workers to have completed) the rendezvous - this
scenario will be handled by the torchelastic ``train_loop`` instead
(where it will also trigger a re-rendezvous).
**Shared key-value store**:
When the rendezvous is completed, a shared key-value store is created
and returned. This store implements a ``torch.distributed.Store`` API
(see `distributed communication
docs <https://pytorch.org/docs/stable/distributed.html>`__).
This store is only shared by the members of the completed rendezvous. It
is intended to be used by torchelastic to exchange information necessary
to initialize job control and data-planes.
**Waiting workers and rendezvous closing**:
The torchelastic rendezvous handler object provides additional
functionality that is technically not part of the rendezvous
process:
1. Querying how many workers arrived late at the barrier, who
can participate in the *next* rendezvous.
2. Setting the rendezvous *closed* to signal all workers not
to participate in the next rendezvous.
"""
from .api import ( # noqa: F401
RendezvousClosedException,
RendezvousException,
RendezvousHandler,
RendezvousHandlerFactory,
RendezvousNonRetryableError,
RendezvousParameters,
RendezvousTimeoutException,
)
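
To make the barrier/exclusivity/closing semantics described above concrete, here is a hypothetical elastic train loop built only on the interface exported by this module; ``train_one_epoch`` and the handler construction are placeholders, not part of this commit:

```
def elastic_train(rdzv_handler: RendezvousHandler, max_epochs: int) -> None:
    while not rdzv_handler.is_closed():
        # Barrier + peer discovery: blocks until at least `min` nodes joined
        # (or `max` nodes joined, which completes the rendezvous immediately).
        store, rank, world_size = rdzv_handler.next_rendezvous()
        for epoch in range(max_epochs):
            train_one_epoch(store, rank, world_size)  # placeholder
            # Late arrivals only announce themselves as waiting; admit them
            # by breaking out and re-running the rendezvous.
            if rdzv_handler.num_nodes_waiting() > 0:
                break
        else:
            # All epochs done: signal every worker not to re-rendezvous.
            rdzv_handler.set_closed()
    rdzv_handler.shutdown()
```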

View File

@@ -0,0 +1,302 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import abc
from typing import Any, Callable, Dict, Optional, Tuple
from torch.distributed import Store
class RendezvousException(Exception):
"""
Represents the base type for rendezvous exceptions.
"""
pass
class RendezvousClosedException(RendezvousException):
"""
Raised when a rendezvous is closed.
This is used to signal completion to nodes that arrive late.
"""
pass
class RendezvousTimeoutException(RendezvousException):
"""
Raised to signal that a rendezvous did not succeed within the allocated
time.
This is a non-retryable type of failure.
"""
pass
class RendezvousNonRetryableError(RendezvousException):
"""
Raised when a failure occurred that should not be retried within the same
worker process.
"""
pass
class RendezvousHandler(abc.ABC):
"""
Main rendezvous interface.
.. note:: torchelastic users normally **do not** need to implement their
own ``RendezvousHandler``. An implementation based on
`etcd <https://etcd.io/>`__ is already provided, and is recommended
for most users, provided they can deploy it in their environment.
.. warning:: torchelastic is currently considered experimental,
so the APIs may change!
"""
@abc.abstractmethod
def get_backend(self) -> str:
"""
Returns the name of the rendezvous backend (e.g. ``etcd``).
"""
pass
@abc.abstractmethod
def next_rendezvous(
self,
) -> Tuple[Store, int, int]:
"""
Main entry-point into the rendezvous barrier.
Blocks until the rendezvous is complete (and the current
process is included in the formed worker group), or a timeout occurs, or
rendezvous was marked closed.
Returns: a tuple of (``c10d Store``, ``rank``, ``world size``)
Raises:
RendezvousClosedException - if rendezvous for the current
job is closed.
RendezvousTimeoutException - on timeout
"""
pass
@abc.abstractmethod
def is_closed(self) -> bool:
"""
Checks whether the rendezvous for the current job has been closed,
which means all future attempts to re-rendezvous (within the same job)
will fail.
.. note:: ``is_closed`` and ``set_closed`` have semantics of eventual
propagation, and should not be used for synchronization.
The intention here is that if at least one worker decides
the job is finished, it will close the rendezvous, and
other workers will soon observe this and stop
training/rendezvous-ing as well.
"""
pass
@abc.abstractmethod
def set_closed(self):
"""
Used to mark the rendezvous (for current job) as closed.
"""
pass
@abc.abstractmethod
def num_nodes_waiting(self) -> int:
"""
Returns number of workers who *arrived late* at
the rendezvous barrier, hence weren't included in the current worker
group.
Callers should periodically call this method to check whether
new members are waiting to join the job and if so admit them by
calling ``next_rendezvous()`` (re-rendezvous).
"""
pass
@abc.abstractmethod
def get_run_id(self) -> str:
"""
Returns the run_id of this rendezvous handler. The run_id is a user-defined
id that uniquely identifies an instance of a distributed application.
It typically maps to a job id and is used to allow workers to join the
correct distributed application.
"""
pass
def shutdown(self) -> bool:
"""
Closes all resources that were opened for the rendezvous run.
Usage:
::
def main():
rdzv_handler = ...
try:
store, rank, world_size = rdzv_handler.next_rendezvous()
finally:
rdzv_handler.shutdown()
"""
pass
class RendezvousParameters:
"""
The data object holding parameters to construct a ``RendezvousHandler``.
"""
# Default timeout for the rendezvous.
_DEFAULT_TIMEOUT: int = 600 # 10 minutes
# Additional waiting time after reaching the minimum number of nodes
# in case the rendezvous is elastic (min != max).
_DEFAULT_LAST_CALL_TIMEOUT: int = 30 # 30 seconds
def __init__(
self,
backend: str,
endpoint: str,
run_id: str,
min_nodes: int,
max_nodes: int,
**kwargs,
):
"""
Args:
backend: The backend that is used to register the rendezvous.
endpoint: The endpoint of the rendezvous. Usually it is a string in the format
<hostname>:<port>.
run_id: The id of the rendezvous.
min_nodes: The minimum number of nodes required to complete the rendezvous.
max_nodes: The maximum number of nodes that are allowed to join the rendezvous.
**kwargs: Additional parameters for the specified backend.
"""
if backend is None:
raise ValueError("The backend cannot be None.")
if min_nodes < 1:
raise ValueError("The minimum number of nodes must be greater than zero.")
if max_nodes < min_nodes:
raise ValueError(
"The maximum number of nodes must be greater than"
" or equal to the minimum number of nodes."
)
self.backend = backend
self.endpoint = endpoint
self.run_id = run_id
self.min_nodes = min_nodes
self.max_nodes = max_nodes
self.config = kwargs
@property
def timeout(self):
"""
Gets the timeout for the rendezvous.
"""
return self.get_as_int("timeout", self._DEFAULT_TIMEOUT)
@property
def last_call_timeout(self):
"""
Gets additional waiting time after reaching the minimum number of nodes
in case the rendezvous is elastic (min != max).
"""
return self.get_as_int("last_call_timeout", self._DEFAULT_LAST_CALL_TIMEOUT)
def get(self, key: str, default: Any = None) -> Any:
"""
Returns the value for ``key`` if ``key`` exists, else ``default``.
"""
return self.config.get(key, default)
def get_as_bool(self, key: str, default: Optional[bool] = None) -> Optional[bool]:
"""
Returns the value for ``key`` as a ``bool`` if ``key`` exists.
"""
val = self.get(key, default)
if val is None:
return val
if isinstance(val, int) or isinstance(val, bool):
return True if val else False
if isinstance(val, str):
return val.lower() in ["1", "true", "t", "yes", "y"]
raise ValueError(
f"The '{key}' rendezvous config does not represent a valid boolean value."
)
def get_as_int(self, key: str, default: Optional[int] = None) -> Optional[int]:
"""
Returns the value for ``key`` as an ``int`` if ``key`` exists.
"""
val = self.get(key, default)
if val is None:
return val
try:
return int(val)
except ValueError:
raise ValueError(
f"The '{key}' rendezvous config does not represent a valid integer value."
)
RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]
class RendezvousHandlerFactory:
"""
Creates ``RendezvousHandler`` instances for supported rendezvous backends.
"""
def __init__(self):
self._registry: Dict[str, RendezvousHandlerCreator] = {}
def register(self, backend: str, creator: RendezvousHandlerCreator):
"""
Registers a new rendezvous backend.
"""
try:
current_creator = self._registry[backend]
except KeyError:
current_creator = None # type: ignore[assignment]
if current_creator is not None:
raise ValueError(
f"The rendezvous backend '{backend}' cannot be registered with"
f" '{creator.__module__}.{creator.__name__}' as it is already"
f" registered with '{current_creator.__module__}.{current_creator.__name__}'."
)
self._registry[backend] = creator
def create_handler(self, params: RendezvousParameters) -> RendezvousHandler:
"""
Creates a new ``RendezvousHandler`` instance for the specified backend.
"""
try:
creator = self._registry[params.backend]
except KeyError:
raise ValueError(
f"The rendezvous backend '{params.backend}' is not registered. Did you forget to call {self.register.__name__}?"
)
handler = creator(params)
# Do some sanity check.
if handler.get_backend() != params.backend:
raise RuntimeError(
f"The rendezvous handler backend '{handler.get_backend()}' does not match the requested backend '{params.backend}'."
)
return handler
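
As a usage sketch of the parameter helpers above (the backend-specific keys ``protocol`` and ``use_v2`` are illustrative, not a fixed schema):

```
params = RendezvousParameters(
    backend="etcd",
    endpoint="localhost:2379",
    run_id="job-42",
    min_nodes=1,
    max_nodes=4,
    protocol="http",
    use_v2="true",
)
assert params.get("protocol") == "http"          # raw access, default=None
assert params.get_as_bool("use_v2") is True      # "1"/"true"/"yes"/... -> True
assert params.get_as_int("timeout", 600) == 600  # missing key -> default
assert params.timeout == 600                     # _DEFAULT_TIMEOUT (10 min)
```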

File diff suppressed because it is too large

View File

@@ -0,0 +1,249 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import atexit
import logging
import os
import shlex
import shutil
import socket
import subprocess
import tempfile
import time
from typing import Optional
import etcd # type: ignore[import]
log = logging.getLogger(__name__)
def find_free_port():
"""
Finds a free port and binds a temporary socket to it so that
the port can be "reserved" until used.
.. note:: the returned socket must be closed before using the port,
otherwise an ``address already in use`` error will happen.
The socket should be held and closed as close to the
consumer of the port as possible since otherwise there
is a greater chance of a race condition where a different
process may see the port as being free and take it.
Returns: a socket bound to the reserved free port
Usage::
sock = find_free_port()
port = sock.getsockname()[1]
sock.close()
use_port(port)
"""
addrs = socket.getaddrinfo(
host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
)
for addr in addrs:
family, type, proto, _, _ = addr
try:
s = socket.socket(family, type, proto)
s.bind(("localhost", 0))
s.listen(0)
return s
except OSError as e:
s.close()
print(f"Socket creation attempt failed: {e}")
raise RuntimeError("Failed to create a socket")
def stop_etcd(subprocess, data_dir: Optional[str] = None):
if subprocess and subprocess.poll() is None:
log.info("stopping etcd server")
subprocess.terminate()
subprocess.wait()
if data_dir:
log.info(f"deleting etcd data dir: {data_dir}")
shutil.rmtree(data_dir, ignore_errors=True)
class EtcdServer:
"""
.. note:: tested on etcd server v3.4.3
Starts and stops a local standalone etcd server on a random free
port. Useful for single node, multi-worker launches or testing,
where a sidecar etcd server is more convenient than having to
separately set up an etcd server.
This class registers a termination handler to shutdown the etcd
subprocess on exit. This termination handler is NOT a substitute for
calling the ``stop()`` method.
The following fallback mechanism is used to find the etcd binary:
1. Uses env var TORCHELASTIC_ETCD_BINARY_PATH
2. Uses ``<this file root>/bin/etcd`` if one exists
3. Uses ``etcd`` from ``PATH``
Usage
::
server = EtcdServer()
server.start()
client = server.get_client()
# use client
server.stop()
Args:
data_dir: optional path to the etcd data directory; a temporary
directory is created (and removed on ``stop()``) if not specified
"""
def __init__(self, data_dir: Optional[str] = None):
self._port = -1
self._host = "localhost"
root = os.path.dirname(__file__)
default_etcd_bin = os.path.join(root, "bin/etcd")
self._etcd_binary_path = os.environ.get(
"TORCHELASTIC_ETCD_BINARY_PATH", default_etcd_bin
)
if not os.path.isfile(self._etcd_binary_path):
self._etcd_binary_path = "etcd"
self._base_data_dir = (
data_dir if data_dir else tempfile.mkdtemp(prefix="torchelastic_etcd_data")
)
self._etcd_cmd = None
self._etcd_proc: Optional[subprocess.Popen] = None
def _get_etcd_server_process(self) -> subprocess.Popen:
if not self._etcd_proc:
raise RuntimeError(
"No etcd server process started. Call etcd_server.start() first"
)
else:
return self._etcd_proc
def get_port(self) -> int:
"""
Returns:
the port the server is running on.
"""
return self._port
def get_host(self) -> str:
"""
Returns:
the host the server is running on.
"""
return self._host
def get_endpoint(self) -> str:
"""
Returns:
the etcd server endpoint (host:port)
"""
return f"{self._host}:{self._port}"
def start(self, timeout: int = 60, num_retries: int = 3) -> None:
"""
Starts the server, and waits for it to be ready. When this function
returns, the server is ready to take requests.
Args:
timeout: time (in seconds) to wait for the server to be ready
before giving up.
num_retries: number of retries to start the server. Each retry
will wait at most ``timeout`` seconds before being considered failed.
Raises:
TimeoutError: if the server is not ready within the specified timeout
"""
curr_retries = 0
while True:
try:
data_dir = os.path.join(self._base_data_dir, str(curr_retries))
os.makedirs(data_dir, exist_ok=True)
return self._start(data_dir, timeout)
except Exception as e:
curr_retries += 1
stop_etcd(self._etcd_proc)
log.warning(
f"Failed to start etcd server, got error: {str(e)}, retrying"
)
if curr_retries >= num_retries:
shutil.rmtree(self._base_data_dir, ignore_errors=True)
raise
atexit.register(stop_etcd, self._etcd_proc, self._base_data_dir)
def _start(self, data_dir: str, timeout: int = 60) -> None:
sock = find_free_port()
sock_peer = find_free_port()
self._port = sock.getsockname()[1]
peer_port = sock_peer.getsockname()[1]
etcd_cmd = shlex.split(
" ".join(
[
self._etcd_binary_path,
"--enable-v2",
"--data-dir",
data_dir,
"--listen-client-urls",
f"http://{self._host}:{self._port}",
"--advertise-client-urls",
f"http://{self._host}:{self._port}",
"--listen-peer-urls",
f"http://{self._host}:{peer_port}",
]
)
)
log.info(f"Starting etcd server: [{etcd_cmd}]")
sock.close()
sock_peer.close()
self._etcd_proc = subprocess.Popen(etcd_cmd, close_fds=True)
self._wait_for_ready(timeout)
def get_client(self) -> etcd.Client:
"""
Returns:
An etcd client object that can be used to make requests to
this server.
"""
return etcd.Client(
host=self._host, port=self._port, version_prefix="/v2", read_timeout=10
)
def _wait_for_ready(self, timeout: int = 60) -> None:
client = etcd.Client(
host=f"{self._host}", port=self._port, version_prefix="/v2", read_timeout=5
)
max_time = time.time() + timeout
while time.time() < max_time:
if self._get_etcd_server_process().poll() is not None:
# etcd server process finished
exitcode = self._get_etcd_server_process().returncode
raise RuntimeError(
f"Etcd server process exited with the code: {exitcode}"
)
try:
log.info(f"etcd server ready. version: {client.version}")
return
except Exception:
time.sleep(1)
raise TimeoutError("Timed out waiting for etcd server to be ready!")
def stop(self) -> None:
"""
Stops the server and cleans up auto generated resources (e.g. data dir)
"""
log.info("EtcdServer stop method called")
stop_etcd(self._etcd_proc, self._base_data_dir)

View File

@@ -0,0 +1,19 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from . import etcd_rendezvous
from .api import (
RendezvousHandler,
RendezvousHandlerFactory,
RendezvousParameters,
)
_factory = RendezvousHandlerFactory()
_factory.register("etcd", etcd_rendezvous.create_rdzv_handler)
def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler:
return _factory.create_handler(params)
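
A sketch of how this module-level factory is meant to be used (endpoint and run id values are illustrative):

```
params = RendezvousParameters(
    backend="etcd",
    endpoint="localhost:2379",
    run_id="job-42",
    min_nodes=1,
    max_nodes=4,
)
# Dispatches to etcd_rendezvous.create_rdzv_handler registered above.
handler = get_rendezvous_handler(params)
assert handler.get_backend() == "etcd"
```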

View File

@@ -0,0 +1,60 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import re
from typing import Dict, Optional, Tuple
def _parse_rendezvous_config(config_str: str) -> Dict[str, str]:
"""
Extracts key-value pairs from a configuration string that has the format
<key1>=<value1>,...,<keyN>=<valueN>.
"""
config: Dict[str, str] = {}
if not config_str:
return config
key_values = config_str.split(",")
for kv in key_values:
key, *values = kv.split("=", 1)
if not values:
raise ValueError(f"The '{key}' rendezvous config has no value specified.")
config[key] = values[0]
return config
def _parse_hostname_and_port(
endpoint: Optional[str], default_port: int
) -> Tuple[str, int]:
"""
Extracts the hostname and the port number from an endpoint string that has
the format <hostname>:<port>.
If no hostname can be found, defaults to the loopback address 127.0.0.1.
"""
if not endpoint:
return ("127.0.0.1", default_port)
hostname, *rest = endpoint.rsplit(":", 1)
if len(rest) == 1:
if re.match(r"^[0-9]{1,5}$", rest[0]):
port = int(rest[0])
else:
port = 0
if port <= 80 or port >= 2 ** 16:
raise ValueError(
f"The rendezvous endpoint '{endpoint}' has an invalid port number '{rest[0]}'."
)
else:
port = default_port
if not re.match(r"^[\w\.:-]+$", hostname):
raise ValueError(
f"The rendezvous enpoint '{endpoint}' has an invalid hostname '{hostname}'."
)
return hostname, port
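
Expected behavior of these parsing helpers, as a quick sketch (values illustrative):

```
config = _parse_rendezvous_config("protocol=http,timeout=120")
assert config == {"protocol": "http", "timeout": "120"}  # values stay strings

host, port = _parse_hostname_and_port("node0:29400", default_port=29500)
assert (host, port) == ("node0", 29400)

host, port = _parse_hostname_and_port(None, default_port=29500)
assert (host, port) == ("127.0.0.1", 29500)  # loopback fallback
```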

View File

@@ -0,0 +1,9 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .api import get_env_variable_or_raise # noqa F401

View File

@@ -0,0 +1,24 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
def get_env_variable_or_raise(env_name: str) -> str:
r"""
Tries to retrieve the environment variable. Raises ``ValueError``
if the environment variable is not set.
Args:
env_name (str): Name of the env variable
"""
value = os.environ.get(env_name, None)
if value is None:
msg = f"Environment variable {env_name} expected, but not set"
raise ValueError(msg)
return value
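
Illustrative use against env variables commonly set by distributed launchers (the variable names are examples, not mandated by this module):

```
master_addr = get_env_variable_or_raise("MASTER_ADDR")
master_port = int(get_env_variable_or_raise("MASTER_PORT"))
```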

View File

@@ -0,0 +1,10 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .cycling_iterator import CyclingIterator # noqa F401
from .elastic_distributed_sampler import ElasticDistributedSampler # noqa F401

View File

@@ -0,0 +1,43 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
class CyclingIterator:
"""
An iterator decorator that cycles through the
underlying iterator "n" times. Useful to "unroll"
the dataset across multiple training epochs.
The generator function is called as ``generator_fn(epoch)``
to obtain the underlying iterator, where ``epoch`` is a
number less than ``n`` representing the current cycle.
For example, if ``generator_fn`` always returns ``[1,2,3]``
then ``CyclingIterator(n=2, generator_fn)`` will iterate through
``[1,2,3,1,2,3]``.
"""
def __init__(self, n: int, generator_fn, start_epoch=0):
self._n = n
self._epoch = start_epoch
self._generator_fn = generator_fn
self._iter = generator_fn(self._epoch)
def __iter__(self):
return self
def __next__(self):
try:
return next(self._iter)
except StopIteration as eod: # eod == end of data
if self._epoch < self._n - 1:
self._epoch += 1
self._iter = self._generator_fn(self._epoch)
return self.__next__()
else:
raise eod
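
A quick sketch of the unrolling behavior described in the docstring (the lambda ignoring ``epoch`` is illustrative):

```
it = CyclingIterator(n=2, generator_fn=lambda epoch: iter([1, 2, 3]))
assert list(it) == [1, 2, 3, 1, 2, 3]  # two cycles, then StopIteration
```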

View File

@@ -0,0 +1,72 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from torch.utils.data.distributed import DistributedSampler
class ElasticDistributedSampler(DistributedSampler):
"""
Sampler that restricts data loading to a subset of
the dataset for elastic training.
It is especially useful in conjunction with
:class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
process can pass a DistributedSampler instance as a DataLoader sampler,
and load a subset of the original dataset that is exclusive to it.
.. note::
Dataset is assumed to be of constant size.
Args:
dataset: Dataset used for sampling.
num_replicas (optional): Number of processes participating in
distributed training.
rank (optional): Rank of the current process within num_replicas.
start_index (optional): Which index of the dataset to start sampling from
"""
def __init__(self, dataset, num_replicas=None, rank=None, start_index=0):
super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank)
if start_index >= len(dataset):
raise ValueError(
"Start index {} should be less than dataset size {}".format(
start_index, len(dataset)
)
)
self.start_index = start_index
self.num_samples = int(
math.ceil(float(len(self.dataset) - self.start_index) / self.num_replicas) # type: ignore[arg-type]
)
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
indices = (
torch.randperm(len(self.dataset) - self.start_index, generator=g) # type: ignore[arg-type]
.add(self.start_index)
.tolist()
)
# add extra samples to make it evenly divisible
indices += indices[: (self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
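
A hypothetical resume-mid-epoch sketch (assumes the process group is already initialized; ``dataset``, batch size, and epoch bounds are placeholders):

```
from torch.utils.data import DataLoader

sampler = ElasticDistributedSampler(dataset, start_index=1000)  # skip first 1000 samples
loader = DataLoader(dataset, batch_size=32, sampler=sampler)
for epoch in range(start_epoch, max_epochs):
    sampler.set_epoch(epoch)  # deterministic per-epoch shuffle (inherited)
    for batch in loader:
        ...  # training step
```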

View File

@@ -0,0 +1,144 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import socket
from contextlib import closing
import torch.distributed as dist
from torch.distributed.elastic.utils.logging import get_logger
log = get_logger()
_ADDRESS_IN_USE = "Address already in use"
_CONNECT_TIMEOUT = "connect() timed out."
_SOCKET_TIMEOUT = "Socket Timeout"
_MEMBER_CHECKIN = "_tcp_store/num_members"
_LAST_MEMBER_CHECKIN = "_tcp_store/last_member"
def create_c10d_store(
is_server: bool,
server_addr: str,
server_port: int = -1,
world_size: int = 1,
timeout: float = (60 * 10), # 10 min
retries=3,
):
if server_port == -1 and world_size > 1:
raise ValueError(
f"server_port must be specified when world_size > 1, got server_port={server_port}, world_size={world_size}"
)
if server_port != -1:
log.info(f"sever_port: {server_port}, specified, ignoring retries")
# only retry when server_port is NOT static
attempt = retries if server_port == -1 else 1
while True:
if server_port != -1:
port = server_port
else:
port = get_free_port()
log.info(
f"Creating c10d store on {server_addr}:{port}\n"
f" world_size : {world_size}\n"
f" is_server : {is_server}\n"
f" timeout(sec): {timeout}\n"
)
try:
store = dist.TCPStore(
host_name=server_addr,
port=port,
world_size=world_size,
is_master=is_server,
timeout=datetime.timedelta(seconds=timeout),
)
_check_full_rank(store, world_size)
log.info("Successfully created c10d store")
return store
except RuntimeError as e:
# this is brittle, but the underlying exception type is not properly pybinded
# so we parse the error msg for now, interestingly this is how torch itself
# detects timeouts and port conflicts in their own unittests
# see - caffe2/torch/testing/_internal/common_utils.py
# TODO properly map the exceptions in pybind (c10d/init.cpp)
if str(e) == _CONNECT_TIMEOUT and not is_server:
raise TimeoutError(
f"timed out waiting for tcp store's server: {server_addr}:{port}"
) from e
elif str(e) == _ADDRESS_IN_USE: # this will only happen on the server
if attempt < retries:
log.warning(
f"port: {port} already in use, attempt: [{attempt}/{retries}]"
)
attempt += 1
else:
raise IOError(
f"on {server_addr}, port: {port} already in use"
) from e
else:
raise
def _check_full_rank(store, world_size):
idx = store.add(_MEMBER_CHECKIN, 1)
if idx == world_size:
store.set(_LAST_MEMBER_CHECKIN, "<val_ignored>")
try:
store.get(_LAST_MEMBER_CHECKIN)
except RuntimeError as e:
if str(e) == _SOCKET_TIMEOUT:
raise TimeoutError(
f"timed out waiting for all {world_size} members to join"
) from e
else:
raise
def get_free_port():
sock = get_socket_with_port()
with closing(sock):
return sock.getsockname()[1]
def get_socket_with_port() -> socket.socket:
"""
Returns a free port on localhost that is "reserved" by binding a temporary
socket on it. Close the socket before passing the port to the entity
that requires it. Usage example
::
sock = get_socket_with_port()
with closing(sock):
port = sock.getsockname()[1]
sock.close()
# there is still a race-condition that some other process
# may grab this port before func() runs
func(port)
"""
addrs = socket.getaddrinfo(
host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
)
for addr in addrs:
family, type, proto, _, _ = addr
s = socket.socket(family, type, proto)
try:
s.bind(("localhost", 0))
s.listen(0)
return s
except OSError as e:
s.close()
log.info("Socket creation attempt failed.", exc_info=e)
raise RuntimeError("Failed to create a socket")

View File

@@ -0,0 +1,67 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import inspect
import logging
import os
import warnings
from typing import Optional
def get_logger(name: Optional[str] = None):
"""
Util function to set up a simple logger that writes
into stderr. The log level is fetched from the ``LOGLEVEL``
env variable, defaulting to ``INFO``. The function will use the
module name of the caller if no name is provided.
Args:
name: Name of the logger. If no name provided, the name will
be derived from the call stack.
"""
# Derive the name of the caller, if none provided
# Use depth=2 since this function takes up one level in the call stack
return _setup_logger(name or _derive_module_name(depth=2))
def _setup_logger(name: Optional[str] = None):
log = logging.getLogger(name)
log.setLevel(os.environ.get("LOGLEVEL", "INFO"))
return log
def _derive_module_name(depth: int = 1) -> Optional[str]:
"""
Derives the name of the caller module from the stack frames.
Args:
depth: The position of the frame in the stack.
"""
try:
stack = inspect.stack()
assert depth < len(stack)
# FrameInfo is just a named tuple: (frame, filename, lineno, function, code_context, index)
frame_info = stack[depth]
filename = frame_info[1]
module = inspect.getmodule(frame_info[0])
if module:
module_name = module.__name__
else:
# inspect.getmodule(frame_info[0]) does NOT work (returns None) in
# binaries built with @mode/opt
# return the filename (minus the .py extension) as modulename
module_name = os.path.splitext(os.path.basename(filename))[0]
return module_name
except Exception as e:
warnings.warn(
f"Error deriving logger module name, using <None>. Exception: {e}",
RuntimeWarning,
)
return None

View File

@@ -0,0 +1,74 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from datetime import timedelta
from typing import List
def get_all(store, prefix: str, size: int):
r"""
Given a store and a prefix, the method goes through the array of keys
of the following format: ``{prefix}{idx}``, where ``idx`` ranges
from 0 to ``size - 1``, and tries to retrieve the data.
Usage
::
values = get_all(store, 'torchelastic/data', 3)
value1 = values[0] # retrieves the data for key torchelastic/data0
value2 = values[1] # retrieves the data for key torchelastic/data1
value3 = values[2] # retrieves the data for key torchelastic/data2
"""
data_arr = []
for idx in range(size):
data = store.get(f"{prefix}{idx}")
data_arr.append(data)
return data_arr
def synchronize(
store,
data: bytes,
rank: int,
world_size: int,
key_prefix: str,
barrier_timeout: float = 300,
) -> List[bytes]:
"""
Synchronizes ``world_size`` agents between each other using the underlying c10d store.
The ``data`` will be available on each of the agents.
Note: The data is not deleted from the store, so there can be stale data if
you use the same ``key_prefix`` twice.
"""
warnings.warn(
"This is an experimental API and will be changed in future.", FutureWarning
)
store.set_timeout(timedelta(seconds=barrier_timeout))
store.set(f"{key_prefix}{rank}", data)
agent_data = get_all(store, key_prefix, world_size)
return agent_data
def barrier(
store, rank: int, world_size: int, key_prefix: str, barrier_timeout: float = 300
) -> None:
"""
A global lock between agents.
Note: Since the data is not removed from the store, the barrier can be used
only once per unique ``key_prefix``.
"""
warnings.warn(
"This is an experimental API and will be changed in future.", FutureWarning
)
data = f"{rank}".encode(encoding="UTF-8")
synchronize(store, data, rank, world_size, key_prefix, barrier_timeout)
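
As a usage sketch (host/port and payloads are illustrative; ``rank`` is this process's rank, with one process per rank sharing a ``TCPStore``):

```
import torch.distributed as dist

# world_size=2; rank 0 hosts the store
store = dist.TCPStore("localhost", 29500, 2, rank == 0)
payloads = synchronize(
    store, f"host-{rank}".encode("UTF-8"), rank, 2, key_prefix="torchelastic/hosts"
)
# payloads[i] now holds rank i's bytes on every agent
barrier(store, rank, 2, key_prefix="torchelastic/barrier")
```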