Mirror of https://github.com/pytorch/pytorch.git
[1/n][torch/elastic][upstream] Move torchelastic/rendezvous to torch/distributed/rendezvous (#53172)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53172
Pull Request resolved: https://github.com/pytorch/elastic/pull/141

Upstreams two modules to torch:
1. `torchelastic.rendezvous`
2. `torchelastic.utils`

These modules were chosen as `[1/n]` since they are the leaf modules in torchelastic.

==== NOTES: ====
1. I'm disabling the etcd_rendezvous and etcd_server tests in CircleCI for the moment since I need to edit the test dockers to contain the etcd server binary (there are 4-5 test dockers, one for each platform, so it is going to take some time for me to set up the environments and test) - T85992919.
2. I've fixed all lint errors in the Python files, but there are remaining ones in the cpp files of ZeusRendezvous. I took a look at them, and I don't want to fix those linter errors right now for two major reasons:
   1. Some of them are more than formatting changes (e.g. std::move vs pass by value) and I don't want to introduce bundled changes with the move.
   2. The old rendezvous code (the one we forked from in caffe2/fb) has the same problems and I think it's better for us to deal with this when we deprecate caffe2/fb/rendezvous in favor of the one in torchelastic - T86012579.

Test Plan:
```
buck test mode/dev-nosan //caffe2/torch/distributed/elastic/utils/test/...
buck test mode/dev-nosan //caffe2/torch/distributed/elastic/utils/data/test/...
buck test mode/dev-nosan //caffe2/torch/distributed/elastic/rendezvous/test/...
buck test mode/dev-nosan //caffe2/torch/distributed/elastic/rendezvous/fb/...
buck test mode/dev-nosan //pytorch/elastic/torchelastic/...
```
+ Sandcastle

Reviewed By: H-Huang

Differential Revision: D26718746

fbshipit-source-id: 67cc0350c3d847221cb3c3038f98f47915362f51
Committed by Facebook GitHub Bot · parent 14fa47631b · commit ba75cedfc5
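As orientation for the new layout, here is a minimal sketch of the upstreamed import paths (the modules themselves appear in the diffs below); the backend, endpoint and job values are illustrative placeholders:

```
from torch.distributed.elastic.rendezvous import RendezvousParameters
import torch.distributed.elastic.utils.logging as logging

log = logging.get_logger()  # previously torchelastic.utils.logging

# Describe a hypothetical 2-4 node job; extra kwargs are backend-specific.
params = RendezvousParameters(
    backend="etcd",
    endpoint="localhost:2379",   # placeholder etcd endpoint
    run_id="example-job",
    min_nodes=2,
    max_nodes=4,
    timeout=600,
)
log.info(f"rendezvous timeout: {params.timeout}s")
```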
test/distributed/elastic/rendezvous/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
test/distributed/elastic/rendezvous/api_test.py (new file, 85 lines)
@@ -0,0 +1,85 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from typing import Tuple

from torch.distributed.elastic.rendezvous import (
    RendezvousHandler,
    RendezvousHandlerFactory,
    RendezvousParameters,
)


def create_mock_rdzv_handler(ignored: RendezvousParameters) -> RendezvousHandler:
    return MockRendezvousHandler()


class MockRendezvousHandler(RendezvousHandler):
    def next_rendezvous(
        self,
        # pyre-ignore[11]: Annotation `Store` is not defined as a type.
    ) -> Tuple["torch.distributed.Store", int, int]:  # noqa F821
        raise NotImplementedError()

    def get_backend(self) -> str:
        return "mock"

    def is_closed(self) -> bool:
        return False

    def set_closed(self):
        pass

    def num_nodes_waiting(self) -> int:
        return -1

    def get_run_id(self) -> str:
        return ""


class RendezvousHandlerFactoryTest(unittest.TestCase):
    def test_double_registration(self):
        factory = RendezvousHandlerFactory()
        factory.register("mock", create_mock_rdzv_handler)
        with self.assertRaises(ValueError):
            factory.register("mock", create_mock_rdzv_handler)

    def test_no_factory_method_found(self):
        factory = RendezvousHandlerFactory()
        rdzv_params = RendezvousParameters(
            backend="mock", endpoint="", run_id="foobar", min_nodes=1, max_nodes=2
        )

        with self.assertRaises(ValueError):
            factory.create_handler(rdzv_params)

    def test_create_handler(self):
        rdzv_params = RendezvousParameters(
            backend="mock", endpoint="", run_id="foobar", min_nodes=1, max_nodes=2
        )

        factory = RendezvousHandlerFactory()
        factory.register("mock", create_mock_rdzv_handler)
        mock_rdzv_handler = factory.create_handler(rdzv_params)
        self.assertTrue(isinstance(mock_rdzv_handler, MockRendezvousHandler))


class RendezvousParametersTest(unittest.TestCase):
    def test_get_or_default(self):

        params = RendezvousParameters(
            backend="foobar",
            endpoint="localhost",
            run_id="1234",
            min_nodes=1,
            max_nodes=1,
            timeout1=10,
        )

        self.assertEqual(10, params.get("timeout1", 20))
        self.assertEqual(60, params.get("timeout2", 60))
test/distributed/elastic/rendezvous/etcd_rendezvous_test.py (new file, 78 lines)
@@ -0,0 +1,78 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
import uuid

from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.rendezvous.etcd_rendezvous import create_rdzv_handler
from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer


@unittest.skipIf(os.getenv("CIRCLECI"), "T85992919 temporarily disabling in circle ci")
class EtcdRendezvousTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # start a standalone, single process etcd server to use for all tests
        cls._etcd_server = EtcdServer()
        cls._etcd_server.start()

    @classmethod
    def tearDownClass(cls):
        # stop the standalone etcd server
        cls._etcd_server.stop()

    def test_etcd_rdzv_basic_params(self):
        """
        Check that we can create the handler with a minimum set of
        params
        """
        rdzv_params = RendezvousParameters(
            backend="etcd",
            endpoint=f"{self._etcd_server.get_endpoint()}",
            run_id=f"{uuid.uuid4()}",
            min_nodes=1,
            max_nodes=1,
        )
        etcd_rdzv = create_rdzv_handler(rdzv_params)
        self.assertIsNotNone(etcd_rdzv)

    def test_etcd_rdzv_additional_params(self):
        run_id = str(uuid.uuid4())
        rdzv_params = RendezvousParameters(
            backend="etcd",
            endpoint=f"{self._etcd_server.get_endpoint()}",
            run_id=run_id,
            min_nodes=1,
            max_nodes=1,
            timeout=60,
            last_call_timeout=30,
            protocol="http",
        )

        etcd_rdzv = create_rdzv_handler(rdzv_params)

        self.assertIsNotNone(etcd_rdzv)
        self.assertEqual(run_id, etcd_rdzv.get_run_id())

    def test_get_backend(self):
        run_id = str(uuid.uuid4())
        rdzv_params = RendezvousParameters(
            backend="etcd",
            endpoint=f"{self._etcd_server.get_endpoint()}",
            run_id=run_id,
            min_nodes=1,
            max_nodes=1,
            timeout=60,
            last_call_timeout=30,
            protocol="http",
        )

        etcd_rdzv = create_rdzv_handler(rdzv_params)

        self.assertEqual("etcd", etcd_rdzv.get_backend())
test/distributed/elastic/rendezvous/etcd_server_test.py (new file, 55 lines)
@@ -0,0 +1,55 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest

import etcd
from torch.distributed.elastic.rendezvous.etcd_rendezvous import (
    EtcdRendezvous,
    EtcdRendezvousHandler,
)
from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer


@unittest.skipIf(os.getenv("CIRCLECI"), "T85992919 temporarily disabling in circle ci")
class EtcdServerTest(unittest.TestCase):
    def test_etcd_server_start_stop(self):
        server = EtcdServer()
        server.start()

        try:
            port = server.get_port()
            host = server.get_host()

            self.assertGreater(port, 0)
            self.assertEqual("localhost", host)
            self.assertEqual(f"{host}:{port}", server.get_endpoint())
            self.assertIsNotNone(server.get_client().version)
        finally:
            server.stop()

    def test_etcd_server_with_rendezvous(self):
        server = EtcdServer()
        server.start()

        client = etcd.Client(server.get_host(), server.get_port())

        rdzv = EtcdRendezvous(
            client=client,
            prefix="test",
            run_id=1,
            num_min_workers=1,
            num_max_workers=1,
            timeout=60,
            last_call_timeout=30,
        )
        rdzv_handler = EtcdRendezvousHandler(rdzv)
        store, rank, world_size = rdzv_handler.next_rendezvous()
        self.assertIsNotNone(store)
        self.assertEqual(0, rank)
        self.assertEqual(1, world_size)
test/distributed/elastic/utils/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
test/distributed/elastic/utils/data/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
test/distributed/elastic/utils/data/cycling_iterator_test.py (new file, 46 lines)
@@ -0,0 +1,46 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest

from torch.distributed.elastic.utils.data import CyclingIterator


class CyclingIteratorTest(unittest.TestCase):
    def generator(self, epoch, stride, max_epochs):
        # generate a continuously incrementing list each epoch
        # e.g. [0,1,2] [3,4,5] [6,7,8] ...
        return iter([stride * epoch + i for i in range(0, stride)])

    def test_cycling_iterator(self):
        stride = 3
        max_epochs = 90

        def generator_fn(epoch):
            return self.generator(epoch, stride, max_epochs)

        it = CyclingIterator(n=max_epochs, generator_fn=generator_fn)
        for i in range(0, stride * max_epochs):
            self.assertEqual(i, next(it))

        with self.assertRaises(StopIteration):
            next(it)

    def test_cycling_iterator_start_epoch(self):
        stride = 3
        max_epochs = 2
        start_epoch = 1

        def generator_fn(epoch):
            return self.generator(epoch, stride, max_epochs)

        it = CyclingIterator(max_epochs, generator_fn, start_epoch)
        for i in range(stride * start_epoch, stride * max_epochs):
            self.assertEqual(i, next(it))

        with self.assertRaises(StopIteration):
            next(it)
test/distributed/elastic/utils/distributed_test.py (new file, 127 lines)
@@ -0,0 +1,127 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import multiprocessing as mp
import os
import socket
import unittest
from contextlib import closing

from torch.distributed.elastic.utils.distributed import (
    create_c10d_store,
    get_free_port,
    get_socket_with_port,
)


def _create_c10d_store_mp(is_server, server_addr, port, world_size):
    store = create_c10d_store(is_server, server_addr, port, world_size, timeout=2)
    if store is None:
        raise AssertionError()

    store.set(f"test_key/{os.getpid()}", "test_value".encode("UTF-8"))


class DistributedUtilTest(unittest.TestCase):
    def test_create_store_single_server(self):
        store = create_c10d_store(is_server=True, server_addr=socket.gethostname())
        self.assertIsNotNone(store)

    def test_create_store_no_port_multi(self):
        with self.assertRaises(ValueError):
            create_c10d_store(
                is_server=True, server_addr=socket.gethostname(), world_size=2
            )

    def test_create_store_multi(self):
        world_size = 3
        server_port = get_free_port()
        localhost = socket.gethostname()
        worker0 = mp.Process(
            target=_create_c10d_store_mp,
            args=(False, localhost, server_port, world_size),
        )
        worker1 = mp.Process(
            target=_create_c10d_store_mp,
            args=(False, localhost, server_port, world_size),
        )

        worker0.start()
        worker1.start()

        # start the server on the main process
        store = create_c10d_store(
            is_server=True,
            server_addr=localhost,
            server_port=server_port,
            world_size=world_size,
            timeout=2,
        )

        worker0.join()
        worker1.join()

        # check test_key/pid == "test_value"
        self.assertEqual(
            "test_value", store.get(f"test_key/{worker0.pid}").decode("UTF-8")
        )
        self.assertEqual(
            "test_value", store.get(f"test_key/{worker1.pid}").decode("UTF-8")
        )

        self.assertEqual(0, worker0.exitcode)
        self.assertEqual(0, worker1.exitcode)

    def test_create_store_timeout_on_server(self):
        with self.assertRaises(TimeoutError):
            port = get_free_port()
            create_c10d_store(
                is_server=True,
                server_addr=socket.gethostname(),
                server_port=port,
                world_size=2,
                timeout=1,
            )

    def test_create_store_timeout_on_worker(self):
        with self.assertRaises(TimeoutError):
            port = get_free_port()
            create_c10d_store(
                is_server=False,
                server_addr=socket.gethostname(),
                server_port=port,
                world_size=2,
                timeout=1,
            )

    def test_port_already_in_use_on_server(self):
        sock = get_socket_with_port()
        with closing(sock):
            # try to create a store on the same port without releasing the socket
            # should raise an IOError
            port = sock.getsockname()[1]
            with self.assertRaises(IOError):
                create_c10d_store(
                    is_server=True,
                    server_addr=socket.gethostname(),
                    server_port=port,
                    timeout=1,
                )

    def test_port_already_in_use_on_worker(self):
        sock = get_socket_with_port()
        with closing(sock):
            port = sock.getsockname()[1]
            # on the worker port conflict shouldn't matter, it should just timeout
            # since we never created a server
            with self.assertRaises(IOError):
                create_c10d_store(
                    is_server=False,
                    server_addr=socket.gethostname(),
                    server_port=port,
                    timeout=1,
                )
test/distributed/elastic/utils/logging_test.py (new file, 32 lines)
@@ -0,0 +1,32 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch.distributed.elastic.utils.logging as logging


log = logging.get_logger()


class LoggingTest(unittest.TestCase):
    def setUp(self):
        self.clazz_log = logging.get_logger()

    def test_logger_name(self):
        local_log = logging.get_logger()
        name_override_log = logging.get_logger("foobar")

        self.assertEqual("logging_test", log.name)
        self.assertEqual("logging_test", self.clazz_log.name)
        self.assertEqual("logging_test", local_log.name)
        self.assertEqual("foobar", name_override_log.name)

    def test_derive_module_name(self):
        module_name = logging._derive_module_name(depth=1)
        self.assertEqual("logging_test", module_name)
test/distributed/elastic/utils/util_test.py (new file, 70 lines)
@@ -0,0 +1,70 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch.distributed.elastic.utils.store as store_util
from torch.distributed.elastic.utils.logging import get_logger


class TestStore:
    def get(self, key: str):
        return f"retrieved:{key}"


class StoreUtilTest(unittest.TestCase):
    def test_get_data(self):
        store = TestStore()
        data = store_util.get_all(store, "test/store", 10)
        for idx in range(0, 10):
            self.assertEqual(f"retrieved:test/store{idx}", data[idx])

    def test_synchronize(self):
        class DummyStore:
            def __init__(self):
                self._data = {
                    "torchelastic/test0": "data0".encode(encoding="UTF-8"),
                    "torchelastic/test1": "data1".encode(encoding="UTF-8"),
                    "torchelastic/test2": "data2".encode(encoding="UTF-8"),
                }

            def set(self, key, value):
                self._data[key] = value

            def get(self, key):
                return self._data[key]

            def set_timeout(self, timeout):
                pass

        data = "data0".encode(encoding="UTF-8")
        store = DummyStore()
        res = store_util.synchronize(store, data, 0, 3, key_prefix="torchelastic/test")
        self.assertEqual(3, len(res))
        for idx, res_data in enumerate(res):
            actual_str = res_data.decode(encoding="UTF-8")
            self.assertEqual(f"data{idx}", actual_str)


class UtilTest(unittest.TestCase):
    def test_get_logger_different(self):
        logger1 = get_logger("name1")
        logger2 = get_logger("name2")
        self.assertNotEqual(logger1.name, logger2.name)

    def test_get_logger(self):
        logger1 = get_logger()
        self.assertEqual(__name__, logger1.name)

    def test_get_logger_none(self):
        logger1 = get_logger(None)
        self.assertEqual(__name__, logger1.name)

    def test_get_logger_custom_name(self):
        logger1 = get_logger("test.module")
        self.assertEqual("test.module", logger1.name)
torch/distributed/elastic/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
torch/distributed/elastic/rendezvous/__init__.py (new file, 112 lines)
@@ -0,0 +1,112 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

""" Rendezvous

In the context of torchelastic we use the term ``rendezvous`` to refer to
a particular functionality that combines a **distributed
synchronization** primitive with **peer discovery**.

It is used by torchelastic to gather participants of a training job
(i.e. workers) such that they all agree on the same list of participants
and everyone's roles, as well as make a consistent collective decision
on when training can begin/resume.

Torchelastic Rendezvous provides the following critical functionalities:

**Barrier**:

Workers performing rendezvous will all block until the rendezvous is
considered complete - this happens when at least ``min`` total number of
workers have joined the rendezvous barrier (for the same job). This also
implies the barrier is not necessarily of fixed size.

There's an additional small waiting time after reaching ``min`` number
of workers - this is used to ensure the rendezvous is not completed "too
quickly" (which could potentially exclude additional workers attempting
to join at approximately the same time).

If ``max`` number of workers is gathered at the barrier, the rendezvous
is completed immediately.

There's also an overall timeout which causes the rendezvous to fail if
``min`` number of workers is never reached - this is meant to be a
simple fail-safe to help release partially allocated job resources, in
case there's a problem with the resource manager, and is meant to be
interpreted as non-retryable.

**Exclusivity**:

A simple distributed barrier would not be sufficient, as we also need to
ensure that only one group of workers exists at any given time (for a
given job). In other words, new workers (i.e. joining late) should not
be able to form a parallel independent group of workers for the same
job.

Torchelastic rendezvous ensures that if a group of workers has already
completed a rendezvous (and hence might already be training), then
additional "late" workers attempting to rendezvous will only announce
themselves as waiting, and will have to wait until the (previously
completed) existing rendezvous is destroyed first.

**Consistency**:

When a rendezvous is completed, all its members will agree on the job
membership and everyone's role in it. This role is represented using an
integer, called rank, that is between 0 and world size.

Note that ranks are *not stable*, in the sense that the same worker
process can be assigned a different rank in the next (re-)rendezvous.

**Fault-tolerance**:

Torchelastic rendezvous is designed to tolerate worker failures during
the rendezvous process. Should a process crash (or lose network
connectivity, etc), between joining the rendezvous and it being
completed, then a re-rendezvous with the remaining healthy workers will
happen automatically.

A worker can also fail *after* it has completed (or *has been
observed* by other workers to have completed) the rendezvous - this
scenario will be handled by the torchelastic ``train_loop`` instead
(where it will also trigger a re-rendezvous).

**Shared key-value store**:

When the rendezvous is completed, a shared key-value store is created
and returned. This store implements a ``torch.distributed.Store`` API
(see `distributed communication
docs <https://pytorch.org/docs/stable/distributed.html>`__).

This store is only shared by the members of the completed rendezvous. It
is intended to be used by torchelastic to exchange information necessary
to initialize job control and data-planes.

**Waiting workers and rendezvous closing**:

The torchelastic rendezvous handler object provides additional
functionalities, which are technically not part of the rendezvous
process:

1. Querying how many workers arrived late at the barrier, who
   can participate in the *next* rendezvous.

2. Setting the rendezvous *closed* to signal all workers not
   to participate in the next rendezvous.
"""

from .api import (  # noqa: F401
    RendezvousClosedException,
    RendezvousException,
    RendezvousHandler,
    RendezvousHandlerFactory,
    RendezvousNonRetryableError,
    RendezvousParameters,
    RendezvousTimeoutException,
)
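The docstring above describes the intended worker flow (barrier, re-rendezvous on waiting workers, closing). A minimal sketch of that loop, assuming an already-constructed handler and a user-supplied `train_for_a_while(store, rank, world_size)` callback; both are hypothetical illustrations, not APIs introduced by this commit:

```
def elastic_worker_loop(rdzv_handler, train_for_a_while):
    # train_for_a_while(store, rank, world_size) is a user-supplied callback
    # that trains for a bounded amount of time and then returns.
    while not rdzv_handler.is_closed():
        # Barrier: blocks until at least min (and at most max) workers joined.
        store, rank, world_size = rdzv_handler.next_rendezvous()
        train_for_a_while(store, rank, world_size)
        # Late arrivals only announce themselves as waiting; admit them by
        # re-rendezvousing (ranks may change across rendezvous rounds).
        if rdzv_handler.num_nodes_waiting() > 0:
            continue
        break
    rdzv_handler.shutdown()
```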
torch/distributed/elastic/rendezvous/api.py (new file, 302 lines)
@@ -0,0 +1,302 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import abc
from typing import Any, Callable, Dict, Optional, Tuple

from torch.distributed import Store


class RendezvousException(Exception):
    """
    Represents the base type for rendezvous exceptions.
    """

    pass


class RendezvousClosedException(RendezvousException):
    """
    Raised when a rendezvous is closed.

    This is used to signal completion to nodes that arrive late.
    """

    pass


class RendezvousTimeoutException(RendezvousException):
    """
    Raised to signal that a rendezvous did not succeed within the allocated
    time.

    This is a non-retryable type of failure.
    """

    pass


class RendezvousNonRetryableError(RendezvousException):
    """
    Raised when a failure occurred that should not be retried within the same
    worker process.
    """

    pass


class RendezvousHandler(abc.ABC):
    """
    Main rendezvous interface.

    .. note:: torchelastic users normally **do not** need to implement their
              own ``RendezvousHandler``. An implementation based on
              `etcd <https://etcd.io/>`__ is already provided, and is recommended
              for most users, provided they can deploy it in their environment.

    .. warning:: torchelastic is currently considered experimental,
                 so the APIs may change!
    """

    @abc.abstractmethod
    def get_backend(self) -> str:
        """
        Returns the string representation of the rendezvous backend.
        """
        pass

    @abc.abstractmethod
    def next_rendezvous(
        self,
    ) -> Tuple[Store, int, int]:
        """
        Main entry-point into the rendezvous barrier.
        Blocks until the rendezvous is complete (and the current
        process is included in the formed worker group), or a timeout occurs, or
        the rendezvous was marked closed.

        Returns: a tuple of (``c10d Store``, ``rank``, ``world size``)

        Raises:
            RendezvousClosedException - if the rendezvous for the current
               job is closed.
            RendezvousTimeoutException - on timeout
        """
        pass

    @abc.abstractmethod
    def is_closed(self) -> bool:
        """
        Checks whether the rendezvous for the current job has been closed,
        which means all future attempts to re-rendezvous (within the same job)
        will fail.

        .. note:: ``is_closed`` and ``set_closed`` have semantics of eventual
                  propagation, and should not be used for synchronization.
                  The intention here is that if at least one worker decides
                  the job is finished, it will close the rendezvous, and
                  other workers will soon observe this and stop
                  training/rendezvous-ing as well.
        """
        pass

    @abc.abstractmethod
    def set_closed(self):
        """
        Marks the rendezvous (for the current job) as closed.
        """
        pass

    @abc.abstractmethod
    def num_nodes_waiting(self) -> int:
        """
        Returns the number of workers who *arrived late* at
        the rendezvous barrier, hence weren't included in the current worker
        group.

        Callers should periodically call this method to check whether
        new members are waiting to join the job and if so admit them by
        calling ``next_rendezvous()`` (re-rendezvous).
        """
        pass

    @abc.abstractmethod
    def get_run_id(self) -> str:
        """
        Returns the run_id of this rendezvous handler. The run_id is a user-defined
        id that uniquely identifies an instance of a distributed application.
        It typically maps to a job id and is used to allow workers to join the
        correct distributed application.
        """
        pass

    def shutdown(self) -> bool:
        """
        Closes all resources that were opened for the rendezvous run.

        Usage:

        ::

         def main():
             rdzv_handler = ...
             try:
                 store, rank, world_size = rdzv_handler.next_rendezvous()
             finally:
                 rdzv_handler.shutdown()
        """
        pass


class RendezvousParameters:
    """
    The data object holding parameters to construct a ``RendezvousHandler``.
    """

    # Default timeout for the rendezvous.
    _DEFAULT_TIMEOUT: int = 600  # 10 minutes

    # Additional waiting time after reaching the minimum number of nodes
    # in case the rendezvous is elastic (min != max).
    _DEFAULT_LAST_CALL_TIMEOUT: int = 30  # 30 seconds

    def __init__(
        self,
        backend: str,
        endpoint: str,
        run_id: str,
        min_nodes: int,
        max_nodes: int,
        **kwargs,
    ):
        """
        Args:
            backend: The backend that is used to register the rendezvous.
            endpoint: The endpoint of the rendezvous. Usually it is a string in the format
                <hostname>:<port>.
            run_id: The id of the rendezvous.
            min_nodes: The minimum number of nodes required to complete the rendezvous.
            max_nodes: The maximum number of nodes that are allowed to join the rendezvous.
            **kwargs: Additional parameters for the specified backend.
        """
        if backend is None:
            raise ValueError("The backend cannot be None.")

        if min_nodes < 1:
            raise ValueError("The minimum number of nodes must be greater than zero.")
        if max_nodes < min_nodes:
            raise ValueError(
                "The maximum number of nodes must be greater than"
                " or equal to the minimum number of nodes."
            )

        self.backend = backend
        self.endpoint = endpoint
        self.run_id = run_id
        self.min_nodes = min_nodes
        self.max_nodes = max_nodes
        self.config = kwargs

    @property
    def timeout(self):
        """
        Gets the timeout for the rendezvous.
        """
        return self.get_as_int("timeout", self._DEFAULT_TIMEOUT)

    @property
    def last_call_timeout(self):
        """
        Gets the additional waiting time after reaching the minimum number of nodes
        in case the rendezvous is elastic (min != max).
        """
        return self.get_as_int("last_call_timeout", self._DEFAULT_LAST_CALL_TIMEOUT)

    def get(self, key: str, default: Any = None) -> Any:
        """
        Returns the value for ``key`` if ``key`` exists, else ``default``.
        """
        return self.config.get(key, default)

    def get_as_bool(self, key: str, default: Optional[bool] = None) -> Optional[bool]:
        """
        Returns the value for ``key`` as a ``bool`` if ``key`` exists.
        """
        val = self.get(key, default)
        if val is None:
            return val
        if isinstance(val, int) or isinstance(val, bool):
            return True if val else False
        if isinstance(val, str):
            return val.lower() in ["1", "true", "t", "yes", "y"]
        raise ValueError(
            f"The '{key}' rendezvous config does not represent a valid boolean value."
        )

    def get_as_int(self, key: str, default: Optional[int] = None) -> Optional[int]:
        """
        Returns the value for ``key`` as an ``int`` if ``key`` exists.
        """
        val = self.get(key, default)
        if val is None:
            return val
        try:
            return int(val)
        except ValueError:
            raise ValueError(
                f"The '{key}' rendezvous config does not represent a valid integer value."
            )


RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]


class RendezvousHandlerFactory:
    """
    Creates ``RendezvousHandler`` instances for supported rendezvous backends.
    """

    def __init__(self):
        self._registry: Dict[str, RendezvousHandlerCreator] = {}

    def register(self, backend: str, creator: RendezvousHandlerCreator):
        """
        Registers a new rendezvous backend.
        """
        try:
            current_creator = self._registry[backend]
        except KeyError:
            current_creator = None  # type: ignore[assignment]

        if current_creator is not None:
            raise ValueError(
                f"The rendezvous backend '{backend}' cannot be registered with"
                f" '{creator.__module__}.{creator.__name__}' as it is already"
                f" registered with '{current_creator.__module__}.{current_creator.__name__}'."
            )

        self._registry[backend] = creator

    def create_handler(self, params: RendezvousParameters) -> RendezvousHandler:
        """
        Creates a new ``RendezvousHandler`` instance for the specified backend.
        """
        try:
            creator = self._registry[params.backend]
        except KeyError:
            raise ValueError(
                f"The rendezvous backend '{params.backend}' is not registered. Did you forget to call {self.register.__name__}?"
            )

        handler = creator(params)

        # Do a sanity check.
        if handler.get_backend() != params.backend:
            raise RuntimeError(
                f"The rendezvous handler backend '{handler.get_backend()}' does not match the requested backend '{params.backend}'."
            )

        return handler
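To illustrate how the factory above is meant to be used, the sketch below registers a made-up "single" backend whose handler returns a single-node c10d store via `create_c10d_store` (defined in `torch/distributed/elastic/utils/distributed.py` later in this diff). The backend name and handler class are illustrative, not part of this commit:

```
from torch.distributed.elastic.rendezvous import (
    RendezvousHandler,
    RendezvousHandlerFactory,
    RendezvousParameters,
)
from torch.distributed.elastic.utils.distributed import create_c10d_store


class SingleNodeRendezvousHandler(RendezvousHandler):
    """Toy handler for illustration: one node, rank 0, world size 1."""

    def __init__(self, params: RendezvousParameters):
        self._run_id = params.run_id

    def get_backend(self) -> str:
        return "single"

    def next_rendezvous(self):
        # world_size defaults to 1, so this returns immediately.
        store = create_c10d_store(is_server=True, server_addr="localhost")
        return store, 0, 1

    def is_closed(self) -> bool:
        return False

    def set_closed(self):
        pass

    def num_nodes_waiting(self) -> int:
        return 0

    def get_run_id(self) -> str:
        return self._run_id


factory = RendezvousHandlerFactory()
factory.register("single", SingleNodeRendezvousHandler)

params = RendezvousParameters(
    backend="single", endpoint="", run_id="demo", min_nodes=1, max_nodes=1
)
handler = factory.create_handler(params)  # sanity-checks get_backend() == "single"
store, rank, world_size = handler.next_rendezvous()
```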
torch/distributed/elastic/rendezvous/etcd_rendezvous.py (new file, 1259 lines)
File diff suppressed because it is too large.
torch/distributed/elastic/rendezvous/etcd_server.py (new file, 249 lines)
@@ -0,0 +1,249 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import atexit
import logging
import os
import shlex
import shutil
import socket
import subprocess
import tempfile
import time
from typing import Optional

import etcd  # type: ignore[import]


log = logging.getLogger(__name__)


def find_free_port():
    """
    Finds a free port and binds a temporary socket to it so that
    the port can be "reserved" until used.

    .. note:: the returned socket must be closed before using the port,
              otherwise an ``address already in use`` error will happen.
              The socket should be held and closed as close to the
              consumer of the port as possible since otherwise, there
              is a greater chance of a race condition where a different
              process may see the port as being free and take it.

    Returns: a socket bound to the reserved free port

    Usage::

      sock = find_free_port()
      port = sock.getsockname()[1]
      sock.close()
      use_port(port)
    """
    addrs = socket.getaddrinfo(
        host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
    )

    for addr in addrs:
        family, type, proto, _, _ = addr
        try:
            s = socket.socket(family, type, proto)
            s.bind(("localhost", 0))
            s.listen(0)
            return s
        except OSError as e:
            s.close()
            print(f"Socket creation attempt failed: {e}")
    raise RuntimeError("Failed to create a socket")


def stop_etcd(subprocess, data_dir: Optional[str] = None):
    if subprocess and subprocess.poll() is None:
        log.info("stopping etcd server")
        subprocess.terminate()
        subprocess.wait()

    if data_dir:
        log.info(f"deleting etcd data dir: {data_dir}")
        shutil.rmtree(data_dir, ignore_errors=True)


class EtcdServer:
    """
    .. note:: tested on etcd server v3.4.3

    Starts and stops a local standalone etcd server on a random free
    port. Useful for single node, multi-worker launches or testing,
    where a sidecar etcd server is more convenient than having to
    separately set up an etcd server.

    This class registers a termination handler to shutdown the etcd
    subprocess on exit. This termination handler is NOT a substitute for
    calling the ``stop()`` method.

    The following fallback mechanism is used to find the etcd binary:

    1. Uses env var TORCHELASTIC_ETCD_BINARY_PATH
    2. Uses ``<this file root>/bin/etcd`` if one exists
    3. Uses ``etcd`` from ``PATH``

    Usage
    ::

     server = EtcdServer(data_dir="/tmp/default.etcd")
     server.start()
     client = server.get_client()
     # use client
     server.stop()

    Args:
        data_dir: optional data directory for the etcd server; a temporary
            directory is created if none is provided.
    """

    def __init__(self, data_dir: Optional[str] = None):
        self._port = -1
        self._host = "localhost"

        root = os.path.dirname(__file__)
        default_etcd_bin = os.path.join(root, "bin/etcd")
        self._etcd_binary_path = os.environ.get(
            "TORCHELASTIC_ETCD_BINARY_PATH", default_etcd_bin
        )
        if not os.path.isfile(self._etcd_binary_path):
            self._etcd_binary_path = "etcd"

        self._base_data_dir = (
            data_dir if data_dir else tempfile.mkdtemp(prefix="torchelastic_etcd_data")
        )
        self._etcd_cmd = None
        self._etcd_proc: Optional[subprocess.Popen] = None

    def _get_etcd_server_process(self) -> subprocess.Popen:
        if not self._etcd_proc:
            raise RuntimeError(
                "No etcd server process started. Call etcd_server.start() first"
            )
        else:
            return self._etcd_proc

    def get_port(self) -> int:
        """
        Returns:
            the port the server is running on.
        """
        return self._port

    def get_host(self) -> str:
        """
        Returns:
            the host the server is running on.
        """
        return self._host

    def get_endpoint(self) -> str:
        """
        Returns:
            the etcd server endpoint (host:port)
        """
        return f"{self._host}:{self._port}"

    def start(self, timeout: int = 60, num_retries: int = 3) -> None:
        """
        Starts the server, and waits for it to be ready. When this function
        returns the server is ready to take requests.

        Args:
            timeout: time (in seconds) to wait for the server to be ready
                before giving up.
            num_retries: number of retries to start the server. Each retry
                will wait for max ``timeout`` before considering it as failed.

        Raises:
            TimeoutError: if the server is not ready within the specified timeout
        """
        curr_retries = 0
        while True:
            try:
                data_dir = os.path.join(self._base_data_dir, str(curr_retries))
                os.makedirs(data_dir, exist_ok=True)
                return self._start(data_dir, timeout)
            except Exception as e:
                curr_retries += 1
                stop_etcd(self._etcd_proc)
                log.warning(
                    f"Failed to start etcd server, got error: {str(e)}, retrying"
                )
                if curr_retries >= num_retries:
                    shutil.rmtree(self._base_data_dir, ignore_errors=True)
                    raise
        atexit.register(stop_etcd, self._etcd_proc, self._base_data_dir)

    def _start(self, data_dir: str, timeout: int = 60) -> None:
        sock = find_free_port()
        sock_peer = find_free_port()
        self._port = sock.getsockname()[1]
        peer_port = sock_peer.getsockname()[1]

        etcd_cmd = shlex.split(
            " ".join(
                [
                    self._etcd_binary_path,
                    "--enable-v2",
                    "--data-dir",
                    data_dir,
                    "--listen-client-urls",
                    f"http://{self._host}:{self._port}",
                    "--advertise-client-urls",
                    f"http://{self._host}:{self._port}",
                    "--listen-peer-urls",
                    f"http://{self._host}:{peer_port}",
                ]
            )
        )

        log.info(f"Starting etcd server: [{etcd_cmd}]")

        sock.close()
        sock_peer.close()
        self._etcd_proc = subprocess.Popen(etcd_cmd, close_fds=True)
        self._wait_for_ready(timeout)

    def get_client(self) -> etcd.Client:
        """
        Returns:
            An etcd client object that can be used to make requests to
            this server.
        """
        return etcd.Client(
            host=self._host, port=self._port, version_prefix="/v2", read_timeout=10
        )

    def _wait_for_ready(self, timeout: int = 60) -> None:
        client = etcd.Client(
            host=f"{self._host}", port=self._port, version_prefix="/v2", read_timeout=5
        )
        max_time = time.time() + timeout

        while time.time() < max_time:
            if self._get_etcd_server_process().poll() is not None:
                # etcd server process finished
                exitcode = self._get_etcd_server_process().returncode
                raise RuntimeError(
                    f"Etcd server process exited with the code: {exitcode}"
                )
            try:
                log.info(f"etcd server ready. version: {client.version}")
                return
            except Exception:
                time.sleep(1)
        raise TimeoutError("Timed out waiting for etcd server to be ready!")

    def stop(self) -> None:
        """
        Stops the server and cleans up auto generated resources (e.g. data dir)
        """
        log.info("EtcdServer stop method called")
        stop_etcd(self._etcd_proc, self._base_data_dir)
torch/distributed/elastic/rendezvous/registry.py (new file, 19 lines)
@@ -0,0 +1,19 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from . import etcd_rendezvous
from .api import (
    RendezvousHandler,
    RendezvousHandlerFactory,
    RendezvousParameters,
)

_factory = RendezvousHandlerFactory()
_factory.register("etcd", etcd_rendezvous.create_rdzv_handler)


def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler:
    return _factory.create_handler(params)
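A short sketch of the module-level entry point above; it assumes the `etcd` Python package is installed and an etcd server is reachable at the placeholder endpoint:

```
from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler

params = RendezvousParameters(
    backend="etcd",               # the only backend registered above
    endpoint="localhost:2379",    # placeholder; point at a running etcd server
    run_id="example-job",
    min_nodes=1,
    max_nodes=2,
)
handler = get_rendezvous_handler(params)  # delegates to the module-level factory
```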
torch/distributed/elastic/rendezvous/utils.py (new file, 60 lines)
@@ -0,0 +1,60 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import re
from typing import Dict, Optional, Tuple


def _parse_rendezvous_config(config_str: str) -> Dict[str, str]:
    """
    Extracts key-value pairs from a configuration string that has the format
    <key1>=<value1>,...,<keyN>=<valueN>.
    """
    config: Dict[str, str] = {}

    if not config_str:
        return config

    key_values = config_str.split(",")
    for kv in key_values:
        key, *values = kv.split("=", 1)
        if not values:
            raise ValueError(f"The '{key}' rendezvous config has no value specified.")
        config[key] = values[0]
    return config


def _parse_hostname_and_port(
    endpoint: Optional[str], default_port: int
) -> Tuple[str, int]:
    """
    Extracts the hostname and the port number from an endpoint string that has
    the format <hostname>:<port>.

    If no hostname can be found, defaults to the loopback address 127.0.0.1.
    """
    if not endpoint:
        return ("127.0.0.1", default_port)

    hostname, *rest = endpoint.rsplit(":", 1)
    if len(rest) == 1:
        if re.match(r"^[0-9]{1,5}$", rest[0]):
            port = int(rest[0])
        else:
            port = 0
        if port <= 80 or port >= 2 ** 16:
            raise ValueError(
                f"The rendezvous endpoint '{endpoint}' has an invalid port number '{rest[0]}'."
            )
    else:
        port = default_port

    if not re.match(r"^[\w\.:-]+$", hostname):
        raise ValueError(
            f"The rendezvous endpoint '{endpoint}' has an invalid hostname '{hostname}'."
        )

    return hostname, port
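Expected behavior of the two private helpers above, per their docstrings (the endpoint and config values are illustrative):

```
from torch.distributed.elastic.rendezvous.utils import (
    _parse_hostname_and_port,
    _parse_rendezvous_config,
)

assert _parse_rendezvous_config("timeout=600,protocol=http") == {
    "timeout": "600",      # values are returned as strings
    "protocol": "http",
}
assert _parse_hostname_and_port("node0:29400", default_port=29500) == ("node0", 29400)
assert _parse_hostname_and_port(None, default_port=29500) == ("127.0.0.1", 29500)
```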
torch/distributed/elastic/utils/__init__.py (new file, 9 lines)
@@ -0,0 +1,9 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .api import get_env_variable_or_raise  # noqa F401
torch/distributed/elastic/utils/api.py (new file, 24 lines)
@@ -0,0 +1,24 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os


def get_env_variable_or_raise(env_name: str) -> str:
    r"""
    Tries to retrieve an environment variable. Raises ``ValueError``
    if no environment variable is found.

    Args:
        env_name (str): Name of the env variable
    """
    value = os.environ.get(env_name, None)
    if value is None:
        msg = f"Environment variable {env_name} expected, but not set"
        raise ValueError(msg)
    return value
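A minimal illustration of the helper above; the environment variable name is a made-up example:

```
import os

from torch.distributed.elastic.utils import get_env_variable_or_raise

os.environ["EXAMPLE_MASTER_ADDR"] = "localhost"           # hypothetical variable
addr = get_env_variable_or_raise("EXAMPLE_MASTER_ADDR")   # -> "localhost"
```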
torch/distributed/elastic/utils/data/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .cycling_iterator import CyclingIterator  # noqa F401
from .elastic_distributed_sampler import ElasticDistributedSampler  # noqa F401
torch/distributed/elastic/utils/data/cycling_iterator.py (new file, 43 lines)
@@ -0,0 +1,43 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


class CyclingIterator:
    """
    An iterator decorator that cycles through the
    underlying iterator "n" times. Useful to "unroll"
    the dataset across multiple training epochs.

    The generator function is called as ``generator_fn(epoch)``
    to obtain the underlying iterator, where ``epoch`` is a
    number less than or equal to ``n`` representing the ``k``th cycle.

    For example if ``generator_fn`` always returns ``[1,2,3]``
    then ``CyclingIterator(n=2, generator_fn)`` will iterate through
    ``[1,2,3,1,2,3]``.
    """

    def __init__(self, n: int, generator_fn, start_epoch=0):
        self._n = n
        self._epoch = start_epoch
        self._generator_fn = generator_fn
        self._iter = generator_fn(self._epoch)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self._iter)
        except StopIteration as eod:  # eod == end of data
            if self._epoch < self._n - 1:
                self._epoch += 1
                self._iter = self._generator_fn(self._epoch)
                return self.__next__()
            else:
                raise eod
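A small usage sketch of `CyclingIterator`, mirroring the test earlier in this diff:

```
from torch.distributed.elastic.utils.data import CyclingIterator

# Re-create the underlying iterator for each epoch; each epoch yields 3 values.
def generator_fn(epoch):
    return iter([3 * epoch + i for i in range(3)])

it = CyclingIterator(n=2, generator_fn=generator_fn)
assert list(it) == [0, 1, 2, 3, 4, 5]
```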
torch/distributed/elastic/utils/data/elastic_distributed_sampler.py (new file, 72 lines)
@@ -0,0 +1,72 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
from torch.utils.data.distributed import DistributedSampler


class ElasticDistributedSampler(DistributedSampler):
    """
    Sampler that restricts data loading to a subset of
    the dataset for elastic training.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Args:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
        start_index (optional): Which index of the dataset to start sampling from
    """

    def __init__(self, dataset, num_replicas=None, rank=None, start_index=0):
        super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank)
        if start_index >= len(dataset):
            raise ValueError(
                "Start index {} should be less than dataset size {}".format(
                    start_index, len(dataset)
                )
            )

        self.start_index = start_index
        self.num_samples = int(
            math.ceil(float(len(self.dataset) - self.start_index) / self.num_replicas)  # type: ignore[arg-type]
        )
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = (
            torch.randperm(len(self.dataset) - self.start_index, generator=g)  # type: ignore[arg-type]
            .add(self.start_index)
            .tolist()
        )

        # add extra samples to make it evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples
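A sketch of wiring the sampler above into a `DataLoader`; the `num_replicas`/`rank` values are placeholders that would normally come from the rendezvous rather than be hard-coded:

```
import torch
from torch.utils.data import DataLoader, TensorDataset

from torch.distributed.elastic.utils.data import ElasticDistributedSampler

dataset = TensorDataset(torch.arange(100).float())
sampler = ElasticDistributedSampler(
    dataset,
    num_replicas=2,   # placeholder: total number of workers
    rank=0,           # placeholder: this worker's rank from the rendezvous
    start_index=10,   # resume sampling from index 10
)

loader = DataLoader(dataset, batch_size=5, sampler=sampler)
for epoch in range(2):
    sampler.set_epoch(epoch)  # reshuffle deterministically per epoch
    for batch in loader:
        pass  # train on batch
```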
torch/distributed/elastic/utils/distributed.py (new file, 144 lines)
@@ -0,0 +1,144 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import socket
from contextlib import closing

import torch.distributed as dist
from torch.distributed.elastic.utils.logging import get_logger


log = get_logger()

_ADDRESS_IN_USE = "Address already in use"
_CONNECT_TIMEOUT = "connect() timed out."
_SOCKET_TIMEOUT = "Socket Timeout"

_MEMBER_CHECKIN = "_tcp_store/num_members"
_LAST_MEMBER_CHECKIN = "_tcp_store/last_member"


def create_c10d_store(
    is_server: bool,
    server_addr: str,
    server_port: int = -1,
    world_size: int = 1,
    timeout: float = (60 * 10),  # 10 min
    retries=3,
):
    if server_port == -1 and world_size > 1:
        raise ValueError(
            f"server_port must be specified when world_size > 1, got server_port={server_port}, world_size={world_size}"
        )

    if server_port != -1:
        log.info(f"server_port: {server_port} specified, ignoring retries")

    # only retry when server_port is NOT static
    attempt = retries if server_port == -1 else 1
    while True:
        if server_port != -1:
            port = server_port
        else:
            port = get_free_port()

        log.info(
            f"Creating c10d store on {server_addr}:{port}\n"
            f"  world_size  : {world_size}\n"
            f"  is_server   : {is_server}\n"
            f"  timeout(sec): {timeout}\n"
        )

        try:
            store = dist.TCPStore(
                host_name=server_addr,
                port=port,
                world_size=world_size,
                is_master=is_server,
                timeout=datetime.timedelta(seconds=timeout),
            )
            _check_full_rank(store, world_size)
            log.info("Successfully created c10d store")
            return store
        except RuntimeError as e:
            # this is brittle, but the underlying exception type is not properly pybinded
            # so we parse the error msg for now, interestingly this is how torch itself
            # detects timeouts and port conflicts in their own unittests
            # see - caffe2/torch/testing/_internal/common_utils.py
            # TODO properly map the exceptions in pybind (c10d/init.cpp)
            if str(e) == _CONNECT_TIMEOUT and not is_server:
                raise TimeoutError(
                    f"timed out waiting for tcp store's server: {server_addr}:{port}"
                ) from e
            elif str(e) == _ADDRESS_IN_USE:  # this will only happen on the server
                if attempt < retries:
                    log.warning(
                        f"port: {port} already in use, attempt: [{attempt}/{retries}]"
                    )
                    attempt += 1
                else:
                    raise IOError(
                        f"on {server_addr}, port: {port} already in use"
                    ) from e
            else:
                raise


def _check_full_rank(store, world_size):
    idx = store.add(_MEMBER_CHECKIN, 1)
    if idx == world_size:
        store.set(_LAST_MEMBER_CHECKIN, "<val_ignored>")

    try:
        store.get(_LAST_MEMBER_CHECKIN)
    except RuntimeError as e:
        if str(e) == _SOCKET_TIMEOUT:
            raise TimeoutError(
                f"timed out waiting for all {world_size} members to join"
            ) from e
        else:
            raise


def get_free_port():
    sock = get_socket_with_port()
    with closing(sock):
        return sock.getsockname()[1]


def get_socket_with_port() -> socket.socket:
    """
    Returns a free port on localhost that is "reserved" by binding a temporary
    socket on it. Close the socket before passing the port to the entity
    that requires it. Usage example

    ::

     sock = get_socket_with_port()
     with closing(sock):
         port = sock.getsockname()[1]
         sock.close()
         # there is still a race-condition that some other process
         # may grab this port before func() runs
         func(port)
    """

    addrs = socket.getaddrinfo(
        host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
    )
    for addr in addrs:
        family, type, proto, _, _ = addr
        s = socket.socket(family, type, proto)
        try:
            s.bind(("localhost", 0))
            s.listen(0)
            return s
        except OSError as e:
            s.close()
            log.info("Socket creation attempt failed.", exc_info=e)
    raise RuntimeError("Failed to create a socket")
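A minimal sketch of `create_c10d_store` in the single-process case (the multi-node pattern, with one server and several workers sharing a static port, is exercised by `distributed_test.py` earlier in this diff):

```
import socket

from torch.distributed.elastic.utils.distributed import create_c10d_store

# Single-process case (world_size defaults to 1): a free port is picked and the
# call retries on port conflicts since no static server_port is given.
store = create_c10d_store(is_server=True, server_addr=socket.gethostname())
store.set("status", "ready".encode("UTF-8"))

# For world_size > 1 a static server_port is required; workers then call the
# same function with is_server=False and the server's addr/port.
```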
torch/distributed/elastic/utils/logging.py (new file, 67 lines)
@@ -0,0 +1,67 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import inspect
import logging
import os
import warnings
from typing import Optional


def get_logger(name: Optional[str] = None):
    """
    Util function to set up a simple logger that writes
    into stderr. The loglevel is fetched from the LOGLEVEL
    env. variable or INFO as default. The function will use the
    module name of the caller if no name is provided.

    Args:
        name: Name of the logger. If no name is provided, the name will
            be derived from the call stack.
    """

    # Derive the name of the caller, if none provided
    # Use depth=2 since this function takes up one level in the call stack
    return _setup_logger(name or _derive_module_name(depth=2))


def _setup_logger(name: Optional[str] = None):
    log = logging.getLogger(name)
    log.setLevel(os.environ.get("LOGLEVEL", "INFO"))
    return log


def _derive_module_name(depth: int = 1) -> Optional[str]:
    """
    Derives the name of the caller module from the stack frames.

    Args:
        depth: The position of the frame in the stack.
    """
    try:
        stack = inspect.stack()
        assert depth < len(stack)
        # FrameInfo is just a named tuple: (frame, filename, lineno, function, code_context, index)
        frame_info = stack[depth]
        filename = frame_info[1]

        module = inspect.getmodule(frame_info[0])
        if module:
            module_name = module.__name__
        else:
            # inspect.getmodule(frame_info[0]) does NOT work (returns None) in
            # binaries built with @mode/opt
            # return the filename (minus the .py extension) as modulename
            module_name = os.path.splitext(os.path.basename(filename))[0]
        return module_name
    except Exception as e:
        warnings.warn(
            f"Error deriving logger module name, using <None>. Exception: {e}",
            RuntimeWarning,
        )
        return None
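Typical use of the logging helper above; the explicit logger name is an arbitrary example, and the level is taken from the `LOGLEVEL` environment variable (default INFO):

```
import torch.distributed.elastic.utils.logging as logging

log = logging.get_logger()                         # name derived from the calling module
named = logging.get_logger("trainer.checkpoint")   # explicit, hypothetical name

log.info("logger name: %s", log.name)
named.warning("level comes from the LOGLEVEL env variable (default INFO)")
```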
torch/distributed/elastic/utils/store.py (new file, 74 lines)
@@ -0,0 +1,74 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from datetime import timedelta
from typing import List


def get_all(store, prefix: str, size: int):
    r"""
    Given a store and a prefix, the method goes through the array of keys
    of the following format: ``{prefix}{idx}``, where idx is in a range
    from 0 to size, and tries to retrieve the data.

    Usage

    ::

     values = get_all(store, 'torchelastic/data', 3)
     value1 = values[0]  # retrieves the data for key torchelastic/data0
     value2 = values[1]  # retrieves the data for key torchelastic/data1
     value3 = values[2]  # retrieves the data for key torchelastic/data2

    """
    data_arr = []
    for idx in range(size):
        data = store.get(f"{prefix}{idx}")
        data_arr.append(data)
    return data_arr


def synchronize(
    store,
    data: bytes,
    rank: int,
    world_size: int,
    key_prefix: str,
    barrier_timeout: float = 300,
) -> List[bytes]:
    """
    Synchronizes ``world_size`` agents between each other using the underlying c10d store.
    The ``data`` will be available on each of the agents.

    Note: The data on the path is not deleted, as a result there can be stale data if
    you use the same key_prefix twice.
    """
    warnings.warn(
        "This is an experimental API and will be changed in future.", FutureWarning
    )
    store.set_timeout(timedelta(seconds=barrier_timeout))
    store.set(f"{key_prefix}{rank}", data)
    agent_data = get_all(store, key_prefix, world_size)
    return agent_data


def barrier(
    store, rank: int, world_size: int, key_prefix: str, barrier_timeout: float = 300
) -> None:
    """
    A global lock between agents.

    Note: Since the data is not removed from the store, the barrier can be used
    once per unique ``key_prefix``.
    """
    warnings.warn(
        "This is an experimental API and will be changed in future.", FutureWarning
    )
    data = f"{rank}".encode(encoding="UTF-8")
    synchronize(store, data, rank, world_size, key_prefix, barrier_timeout)
|
Reference in New Issue
Block a user