[torch/elastic] Refactor rendezvous store initialization logic (#58057)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58057

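This PR refactors the store initialization logic and moves it into the `create_backend` function of both the C10d and etcd backends. `create_backend` now returns a `(backend, store)` tuple instead of just the backend, so the handler factories unpack the store from the return value rather than building it (`EtcdStore(backend.client, ...)`) or reading it from `backend.store`. As part of this, the `store` and `key` properties of `C10dRendezvousBackend` and the `client`, `key`, and `ttl` properties of `EtcdRendezvousBackend` are removed. The rendezvous state is also now encoded and decoded with `base64.b64encode`/`b64decode` instead of `codecs.encode`/`codecs.decode` with the `"base64"` codec.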
ghstack-source-id: 128671579

Test Plan: Run the existing and revised tests.

Reviewed By: tierex

Differential Revision: D28356587

fbshipit-source-id: caf9416ab811eefe4834268d8a11a48f2236ed5b
This commit is contained in:
Can Balioglu
2021-05-11 13:45:01 -07:00
committed by Facebook GitHub Bot
parent b58a7c95aa
commit 1d4d9ffca0
6 changed files with 62 additions and 64 deletions

View File

@ -4,8 +4,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from base64 import b64encode
from datetime import timedelta
from typing import ClassVar
from typing import ClassVar, cast
from unittest import TestCase
from torch.distributed import TCPStore
@ -55,18 +56,23 @@ class CreateBackendTest(TestCase):
self._expected_read_timeout = timedelta(seconds=10)
def test_create_backend_returns_backend(self) -> None:
backend = create_backend(self._params)
self.assertIsInstance(backend.store, self._expected_store_type)
backend, store = create_backend(self._params)
self.assertEqual(backend.name, "c10d")
self.assertEqual(backend.key, "torch.rendezvous." + self._params.run_id)
store = backend.store
self.assertIsInstance(store, self._expected_store_type)
self.assertEqual(store.host, self._expected_endpoint_host) # type: ignore[attr-defined]
self.assertEqual(store.port, self._expected_endpoint_port) # type: ignore[attr-defined]
self.assertEqual(store.timeout, self._expected_read_timeout) # type: ignore[attr-defined]
tcp_store = cast(TCPStore, store)
self.assertEqual(tcp_store.host, self._expected_endpoint_host) # type: ignore[attr-defined]
self.assertEqual(tcp_store.port, self._expected_endpoint_port) # type: ignore[attr-defined]
self.assertEqual(tcp_store.timeout, self._expected_read_timeout) # type: ignore[attr-defined]
backend.set_state(b"dummy_state")
state = store.get("torch.rendezvous." + self._params.run_id)
self.assertEqual(state, b64encode(b"dummy_state"))
def test_create_backend_returns_backend_if_is_host_is_false(self) -> None:
store = TCPStore( # type: ignore[call-arg] # noqa: F841

View File

@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import codecs
import copy
import os
import pickle
@ -12,6 +11,7 @@ import socket
import threading
import time
from abc import ABC, abstractmethod
from base64 import b64encode
from datetime import datetime, timedelta
from typing import Callable, Optional, Tuple, cast
from unittest import TestCase
@ -153,7 +153,7 @@ class RendezvousStateTest(TestCase):
bits = pickle.dumps(state)
base64_bits = codecs.encode(bits, "base64")
base64_bits = b64encode(bits)
self.assertLessEqual(len(base64_bits), max_byte_size)

View File

@ -5,11 +5,11 @@
# LICENSE file in the root directory of this source tree.
import subprocess
from typing import ClassVar
from base64 import b64encode
from typing import ClassVar, cast
from unittest import TestCase
from etcd import EtcdKeyNotFound
from etcd import EtcdKeyNotFound # type: ignore[import]
from torch.distributed.elastic.rendezvous import RendezvousConnectionError, RendezvousParameters
from torch.distributed.elastic.rendezvous.etcd_rendezvous_backend import (
@ -17,6 +17,7 @@ from torch.distributed.elastic.rendezvous.etcd_rendezvous_backend import (
create_backend,
)
from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
from torch.distributed.elastic.rendezvous.etcd_store import EtcdStore
from rendezvous_backend_test import RendezvousBackendTestMixin
@ -71,19 +72,33 @@ class CreateBackendTest(TestCase):
read_timeout="10",
)
self._expected_protocol = "http"
self._expected_read_timeout = 10
def test_create_backend_returns_backend(self) -> None:
backend = create_backend(self._params)
backend, store = create_backend(self._params)
self.assertEqual(backend.name, "etcd-v2")
self.assertEqual(backend.key, "/torch/elastic/rendezvous/" + self._params.run_id)
self.assertEqual(backend.ttl, 7200)
self.assertEqual(backend.client.host, self._server.get_host())
self.assertEqual(backend.client.port, self._server.get_port())
self.assertEqual(backend.client.protocol, self._expected_protocol)
self.assertEqual(backend.client.read_timeout, self._expected_read_timeout)
self.assertIsInstance(store, EtcdStore)
etcd_store = cast(EtcdStore, store)
self.assertEqual(etcd_store.client.read_timeout, self._expected_read_timeout) # type: ignore[attr-defined]
client = self._server.get_client()
backend.set_state(b"dummy_state")
result = client.get("/torch/elastic/rendezvous/" + self._params.run_id)
self.assertEqual(result.value, b64encode(b"dummy_state").decode())
self.assertLessEqual(result.ttl, 7200)
store.set("dummy_key", "dummy_value")
result = client.get("/torch/elastic/store/" + b64encode(b"dummy_key").decode())
self.assertEqual(result.value, b64encode(b"dummy_value").decode())
def test_create_backend_returns_backend_if_protocol_is_not_specified(self) -> None:
del self._params.config["protocol"]

View File

@ -5,9 +5,9 @@
# LICENSE file in the root directory of this source tree.
import binascii
import codecs
import logging
import os
from base64 import b64decode, b64encode
from datetime import timedelta
from typing import Any, Optional, Tuple, cast
@ -58,17 +58,6 @@ class C10dRendezvousBackend(RendezvousBackend):
"""See base class."""
return "c10d"
@property
def store(self) -> Store:
"""Gets the :py:class:`torch.distributed.Store` instance used to
communicate with the C10d store."""
return self._store
@property
def key(self) -> str:
"""Gets the key under which the rendezvous state is stored."""
return self._key
def get_state(self) -> Optional[Tuple[bytes, Token]]:
"""See base class."""
base64_state: bytes = self._call_store("get", self._key)
@ -79,7 +68,7 @@ class C10dRendezvousBackend(RendezvousBackend):
self, state: bytes, token: Optional[Token] = None
) -> Optional[Tuple[bytes, Token, bool]]:
"""See base class."""
base64_state_str: str = codecs.encode(state, "base64").decode()
base64_state_str: str = b64encode(state).decode()
if token:
# Shortcut if we know for sure that the token is not valid.
@ -122,7 +111,7 @@ class C10dRendezvousBackend(RendezvousBackend):
return None
try:
state = codecs.decode(base64_state, "base64")
state = b64decode(base64_state)
except binascii.Error as exc:
raise RendezvousStateError(
"The state object is corrupt. See inner exception for details."
@ -178,7 +167,7 @@ def _create_tcp_store(params: RendezvousParameters) -> TCPStore:
return store
def create_backend(params: RendezvousParameters) -> C10dRendezvousBackend:
def create_backend(params: RendezvousParameters) -> Tuple[C10dRendezvousBackend, Store]:
"""Creates a new :py:class:`C10dRendezvousBackend` from the specified
parameters.
@ -214,4 +203,4 @@ def create_backend(params: RendezvousParameters) -> C10dRendezvousBackend:
store = _create_tcp_store(params)
return C10dRendezvousBackend(store, params.run_id)
return C10dRendezvousBackend(store, params.run_id), store

View File

@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.
import binascii
import codecs
from base64 import b64decode, b64encode
from typing import Optional, Tuple, cast
import urllib3.exceptions # type: ignore[import]
@ -17,9 +17,11 @@ from etcd import (
EtcdKeyNotFound,
EtcdResult,
)
from torch.distributed import Store
from .api import RendezvousConnectionError, RendezvousParameters, RendezvousStateError
from .dynamic_rendezvous import RendezvousBackend, Token
from .etcd_store import EtcdStore
from .utils import parse_rendezvous_endpoint
@ -70,21 +72,6 @@ class EtcdRendezvousBackend(RendezvousBackend):
"""See base class."""
return "etcd-v2"
@property
def client(self) -> EtcdClient:
"""Gets the ``etcd.Client`` instance used to communicate with etcd."""
return self._client
@property
def key(self) -> str:
"""Gets the key under which the rendezvous state is stored."""
return self._key
@property
def ttl(self) -> int:
"""Gets the TTL of the rendezvous state."""
return self._ttl
def get_state(self) -> Optional[Tuple[bytes, Token]]:
"""See base class."""
try:
@ -102,7 +89,7 @@ class EtcdRendezvousBackend(RendezvousBackend):
self, state: bytes, token: Optional[Token] = None
) -> Optional[Tuple[bytes, Token, bool]]:
"""See base class."""
base64_state = codecs.encode(state, "base64").decode()
base64_state = b64encode(state).decode()
kwargs = {}
@ -145,7 +132,7 @@ class EtcdRendezvousBackend(RendezvousBackend):
base64_state = result.value.encode()
try:
state = codecs.decode(base64_state, "base64")
state = b64decode(base64_state)
except binascii.Error as exc:
raise RendezvousStateError(
"The state object is corrupt. See inner exception for details."
@ -195,7 +182,7 @@ def _create_etcd_client(params: RendezvousParameters) -> EtcdClient:
) from exc
def create_backend(params: RendezvousParameters) -> EtcdRendezvousBackend:
def create_backend(params: RendezvousParameters) -> Tuple[EtcdRendezvousBackend, Store]:
"""Creates a new :py:class:`EtcdRendezvousBackend` from the specified
parameters.
@ -220,4 +207,8 @@ def create_backend(params: RendezvousParameters) -> EtcdRendezvousBackend:
"""
client = _create_etcd_client(params)
return EtcdRendezvousBackend(client, params.run_id, key_prefix="/torch/elastic/rendezvous")
backend = EtcdRendezvousBackend(client, params.run_id, key_prefix="/torch/elastic/rendezvous")
store = EtcdStore(client, "/torch/elastic/store")
return backend, store

View File

@ -23,11 +23,8 @@ def _create_etcd_handler(params: RendezvousParameters) -> RendezvousHandler:
def _create_etcd_v2_handler(params: RendezvousParameters) -> RendezvousHandler:
from .etcd_rendezvous_backend import create_backend
from .etcd_store import EtcdStore
backend = create_backend(params)
store = EtcdStore(backend.client, "/torch/elastic/store")
backend, store = create_backend(params)
return create_handler(store, backend, params)
@ -35,9 +32,9 @@ def _create_etcd_v2_handler(params: RendezvousParameters) -> RendezvousHandler:
def _create_c10d_handler(params: RendezvousParameters) -> RendezvousHandler:
from .c10d_rendezvous_backend import create_backend
backend = create_backend(params)
backend, store = create_backend(params)
return create_handler(backend.store, backend, params)
return create_handler(store, backend, params)
def _register_default_handlers() -> None: