mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[torch/elastic] Refactor rendezvous store initialization logic (#58057)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58057

This PR refactors the store initialization logic and moves it into the `create_backend` function for both the C10d and etcd backends; `create_backend` now returns the store together with the backend.

ghstack-source-id: 128671579

Test Plan: Run the existing and revised tests.

Reviewed By: tierex

Differential Revision: D28356587

fbshipit-source-id: caf9416ab811eefe4834268d8a11a48f2236ed5b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
b58a7c95aa
commit
1d4d9ffca0
@ -4,8 +4,9 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from base64 import b64encode
|
||||
from datetime import timedelta
|
||||
from typing import ClassVar
|
||||
from typing import ClassVar, cast
|
||||
from unittest import TestCase
|
||||
|
||||
from torch.distributed import TCPStore
|
||||
@ -55,18 +56,23 @@ class CreateBackendTest(TestCase):
|
||||
self._expected_read_timeout = timedelta(seconds=10)
|
||||
|
||||
def test_create_backend_returns_backend(self) -> None:
|
||||
backend = create_backend(self._params)
|
||||
|
||||
self.assertIsInstance(backend.store, self._expected_store_type)
|
||||
backend, store = create_backend(self._params)
|
||||
|
||||
self.assertEqual(backend.name, "c10d")
|
||||
self.assertEqual(backend.key, "torch.rendezvous." + self._params.run_id)
|
||||
|
||||
store = backend.store
|
||||
self.assertIsInstance(store, self._expected_store_type)
|
||||
|
||||
self.assertEqual(store.host, self._expected_endpoint_host) # type: ignore[attr-defined]
|
||||
self.assertEqual(store.port, self._expected_endpoint_port) # type: ignore[attr-defined]
|
||||
self.assertEqual(store.timeout, self._expected_read_timeout) # type: ignore[attr-defined]
|
||||
tcp_store = cast(TCPStore, store)
|
||||
|
||||
self.assertEqual(tcp_store.host, self._expected_endpoint_host) # type: ignore[attr-defined]
|
||||
self.assertEqual(tcp_store.port, self._expected_endpoint_port) # type: ignore[attr-defined]
|
||||
self.assertEqual(tcp_store.timeout, self._expected_read_timeout) # type: ignore[attr-defined]
|
||||
|
||||
backend.set_state(b"dummy_state")
|
||||
|
||||
state = store.get("torch.rendezvous." + self._params.run_id)
|
||||
|
||||
self.assertEqual(state, b64encode(b"dummy_state"))
|
||||
|
||||
def test_create_backend_returns_backend_if_is_host_is_false(self) -> None:
|
||||
store = TCPStore( # type: ignore[call-arg] # noqa: F841
|
||||
|
@ -4,7 +4,6 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import codecs
|
||||
import copy
|
||||
import os
|
||||
import pickle
|
||||
@ -12,6 +11,7 @@ import socket
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from base64 import b64encode
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Callable, Optional, Tuple, cast
|
||||
from unittest import TestCase
|
||||
@ -153,7 +153,7 @@ class RendezvousStateTest(TestCase):
|
||||
|
||||
bits = pickle.dumps(state)
|
||||
|
||||
base64_bits = codecs.encode(bits, "base64")
|
||||
base64_bits = b64encode(bits)
|
||||
|
||||
self.assertLessEqual(len(base64_bits), max_byte_size)
|
||||
|
||||
|
@ -5,11 +5,11 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import subprocess
|
||||
|
||||
from typing import ClassVar
|
||||
from base64 import b64encode
|
||||
from typing import ClassVar, cast
|
||||
from unittest import TestCase
|
||||
|
||||
from etcd import EtcdKeyNotFound
|
||||
from etcd import EtcdKeyNotFound # type: ignore[import]
|
||||
|
||||
from torch.distributed.elastic.rendezvous import RendezvousConnectionError, RendezvousParameters
|
||||
from torch.distributed.elastic.rendezvous.etcd_rendezvous_backend import (
|
||||
@ -17,6 +17,7 @@ from torch.distributed.elastic.rendezvous.etcd_rendezvous_backend import (
|
||||
create_backend,
|
||||
)
|
||||
from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
|
||||
from torch.distributed.elastic.rendezvous.etcd_store import EtcdStore
|
||||
|
||||
from rendezvous_backend_test import RendezvousBackendTestMixin
|
||||
|
||||
@ -71,19 +72,33 @@ class CreateBackendTest(TestCase):
|
||||
read_timeout="10",
|
||||
)
|
||||
|
||||
self._expected_protocol = "http"
|
||||
self._expected_read_timeout = 10
|
||||
|
||||
def test_create_backend_returns_backend(self) -> None:
|
||||
backend = create_backend(self._params)
|
||||
backend, store = create_backend(self._params)
|
||||
|
||||
self.assertEqual(backend.name, "etcd-v2")
|
||||
self.assertEqual(backend.key, "/torch/elastic/rendezvous/" + self._params.run_id)
|
||||
self.assertEqual(backend.ttl, 7200)
|
||||
self.assertEqual(backend.client.host, self._server.get_host())
|
||||
self.assertEqual(backend.client.port, self._server.get_port())
|
||||
self.assertEqual(backend.client.protocol, self._expected_protocol)
|
||||
self.assertEqual(backend.client.read_timeout, self._expected_read_timeout)
|
||||
|
||||
self.assertIsInstance(store, EtcdStore)
|
||||
|
||||
etcd_store = cast(EtcdStore, store)
|
||||
|
||||
self.assertEqual(etcd_store.client.read_timeout, self._expected_read_timeout) # type: ignore[attr-defined]
|
||||
|
||||
client = self._server.get_client()
|
||||
|
||||
backend.set_state(b"dummy_state")
|
||||
|
||||
result = client.get("/torch/elastic/rendezvous/" + self._params.run_id)
|
||||
|
||||
self.assertEqual(result.value, b64encode(b"dummy_state").decode())
|
||||
self.assertLessEqual(result.ttl, 7200)
|
||||
|
||||
store.set("dummy_key", "dummy_value")
|
||||
|
||||
result = client.get("/torch/elastic/store/" + b64encode(b"dummy_key").decode())
|
||||
|
||||
self.assertEqual(result.value, b64encode(b"dummy_value").decode())
|
||||
|
||||
def test_create_backend_returns_backend_if_protocol_is_not_specified(self) -> None:
|
||||
del self._params.config["protocol"]
|
||||
|
@ -5,9 +5,9 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import binascii
|
||||
import codecs
|
||||
import logging
|
||||
import os
|
||||
from base64 import b64decode, b64encode
|
||||
from datetime import timedelta
|
||||
from typing import Any, Optional, Tuple, cast
|
||||
|
||||
@ -58,17 +58,6 @@ class C10dRendezvousBackend(RendezvousBackend):
|
||||
"""See base class."""
|
||||
return "c10d"
|
||||
|
||||
@property
|
||||
def store(self) -> Store:
|
||||
"""Gets the :py:class:`torch.distributed.Store` instance used to
|
||||
communicate with the C10d store."""
|
||||
return self._store
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
"""Gets the key under which the rendezvous state is stored."""
|
||||
return self._key
|
||||
|
||||
def get_state(self) -> Optional[Tuple[bytes, Token]]:
|
||||
"""See base class."""
|
||||
base64_state: bytes = self._call_store("get", self._key)
|
||||
@ -79,7 +68,7 @@ class C10dRendezvousBackend(RendezvousBackend):
|
||||
self, state: bytes, token: Optional[Token] = None
|
||||
) -> Optional[Tuple[bytes, Token, bool]]:
|
||||
"""See base class."""
|
||||
base64_state_str: str = codecs.encode(state, "base64").decode()
|
||||
base64_state_str: str = b64encode(state).decode()
|
||||
|
||||
if token:
|
||||
# Shortcut if we know for sure that the token is not valid.
|
||||
@ -122,7 +111,7 @@ class C10dRendezvousBackend(RendezvousBackend):
|
||||
return None
|
||||
|
||||
try:
|
||||
state = codecs.decode(base64_state, "base64")
|
||||
state = b64decode(base64_state)
|
||||
except binascii.Error as exc:
|
||||
raise RendezvousStateError(
|
||||
"The state object is corrupt. See inner exception for details."
|
||||
@ -178,7 +167,7 @@ def _create_tcp_store(params: RendezvousParameters) -> TCPStore:
|
||||
return store
|
||||
|
||||
|
||||
def create_backend(params: RendezvousParameters) -> C10dRendezvousBackend:
|
||||
def create_backend(params: RendezvousParameters) -> Tuple[C10dRendezvousBackend, Store]:
|
||||
"""Creates a new :py:class:`C10dRendezvousBackend` from the specified
|
||||
parameters.
|
||||
|
||||
@ -214,4 +203,4 @@ def create_backend(params: RendezvousParameters) -> C10dRendezvousBackend:
|
||||
|
||||
store = _create_tcp_store(params)
|
||||
|
||||
return C10dRendezvousBackend(store, params.run_id)
|
||||
return C10dRendezvousBackend(store, params.run_id), store
|
||||
|
@ -5,7 +5,7 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import binascii
|
||||
import codecs
|
||||
from base64 import b64decode, b64encode
|
||||
from typing import Optional, Tuple, cast
|
||||
|
||||
import urllib3.exceptions # type: ignore[import]
|
||||
@ -17,9 +17,11 @@ from etcd import (
|
||||
EtcdKeyNotFound,
|
||||
EtcdResult,
|
||||
)
|
||||
from torch.distributed import Store
|
||||
|
||||
from .api import RendezvousConnectionError, RendezvousParameters, RendezvousStateError
|
||||
from .dynamic_rendezvous import RendezvousBackend, Token
|
||||
from .etcd_store import EtcdStore
|
||||
from .utils import parse_rendezvous_endpoint
|
||||
|
||||
|
||||
@ -70,21 +72,6 @@ class EtcdRendezvousBackend(RendezvousBackend):
|
||||
"""See base class."""
|
||||
return "etcd-v2"
|
||||
|
||||
@property
|
||||
def client(self) -> EtcdClient:
|
||||
"""Gets the ``etcd.Client`` instance used to communicate with etcd."""
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
"""Gets the key under which the rendezvous state is stored."""
|
||||
return self._key
|
||||
|
||||
@property
|
||||
def ttl(self) -> int:
|
||||
"""Gets the TTL of the rendezvous state."""
|
||||
return self._ttl
|
||||
|
||||
def get_state(self) -> Optional[Tuple[bytes, Token]]:
|
||||
"""See base class."""
|
||||
try:
|
||||
@ -102,7 +89,7 @@ class EtcdRendezvousBackend(RendezvousBackend):
|
||||
self, state: bytes, token: Optional[Token] = None
|
||||
) -> Optional[Tuple[bytes, Token, bool]]:
|
||||
"""See base class."""
|
||||
base64_state = codecs.encode(state, "base64").decode()
|
||||
base64_state = b64encode(state).decode()
|
||||
|
||||
kwargs = {}
|
||||
|
||||
@ -145,7 +132,7 @@ class EtcdRendezvousBackend(RendezvousBackend):
|
||||
base64_state = result.value.encode()
|
||||
|
||||
try:
|
||||
state = codecs.decode(base64_state, "base64")
|
||||
state = b64decode(base64_state)
|
||||
except binascii.Error as exc:
|
||||
raise RendezvousStateError(
|
||||
"The state object is corrupt. See inner exception for details."
|
||||
@ -195,7 +182,7 @@ def _create_etcd_client(params: RendezvousParameters) -> EtcdClient:
|
||||
) from exc
|
||||
|
||||
|
||||
def create_backend(params: RendezvousParameters) -> EtcdRendezvousBackend:
|
||||
def create_backend(params: RendezvousParameters) -> Tuple[EtcdRendezvousBackend, Store]:
|
||||
"""Creates a new :py:class:`EtcdRendezvousBackend` from the specified
|
||||
parameters.
|
||||
|
||||
@ -220,4 +207,8 @@ def create_backend(params: RendezvousParameters) -> EtcdRendezvousBackend:
|
||||
"""
|
||||
client = _create_etcd_client(params)
|
||||
|
||||
return EtcdRendezvousBackend(client, params.run_id, key_prefix="/torch/elastic/rendezvous")
|
||||
backend = EtcdRendezvousBackend(client, params.run_id, key_prefix="/torch/elastic/rendezvous")
|
||||
|
||||
store = EtcdStore(client, "/torch/elastic/store")
|
||||
|
||||
return backend, store
|
||||
|
@ -23,11 +23,8 @@ def _create_etcd_handler(params: RendezvousParameters) -> RendezvousHandler:
|
||||
|
||||
def _create_etcd_v2_handler(params: RendezvousParameters) -> RendezvousHandler:
|
||||
from .etcd_rendezvous_backend import create_backend
|
||||
from .etcd_store import EtcdStore
|
||||
|
||||
backend = create_backend(params)
|
||||
|
||||
store = EtcdStore(backend.client, "/torch/elastic/store")
|
||||
backend, store = create_backend(params)
|
||||
|
||||
return create_handler(store, backend, params)
|
||||
|
||||
@ -35,9 +32,9 @@ def _create_etcd_v2_handler(params: RendezvousParameters) -> RendezvousHandler:
|
||||
def _create_c10d_handler(params: RendezvousParameters) -> RendezvousHandler:
|
||||
from .c10d_rendezvous_backend import create_backend
|
||||
|
||||
backend = create_backend(params)
|
||||
backend, store = create_backend(params)
|
||||
|
||||
return create_handler(backend.store, backend, params)
|
||||
return create_handler(store, backend, params)
|
||||
|
||||
|
||||
def _register_default_handlers() -> None:
|
||||
|
Reference in New Issue
Block a user