diff --git a/test/distributed/test_store.py b/test/distributed/test_store.py index ea45e54c6b72..6897aa91fd97 100644 --- a/test/distributed/test_store.py +++ b/test/distributed/test_store.py @@ -312,6 +312,30 @@ class TCPStoreTest(TestCase, StoreTestBase): self.assertEqual(store1.libuvBackend, self._use_libuv) self.assertEqual(store2.libuvBackend, self._use_libuv) + def test_repr(self) -> None: + # server + store1 = self._create_store() + self.assertRegex( + repr(store1), + r"TCPStore\(" + r"client=TCPClient\(SocketImpl\(fd=\d+, addr=\[?localhost\]?:\d+, remote=\[?localhost\]?:\d+\)\), " + r"server=TCPServer\(port=\d+\)\)", + ) + + # client + store2 = dist.TCPStore( + store1.host, + store1.port, + world_size=2, + is_master=False, + ) + self.assertRegex( + repr(store2), + r"TCPStore\(" + r"client=TCPClient\(SocketImpl\(fd=\d+, addr=\[?localhost\]?:\d+, remote=\[?localhost\]?:\d+\)\), " + r"server=\)", + ) + @skip_if_win32() @retry_on_connect_failures def test_init_pg_and_rpc_with_same_socket(self): diff --git a/torch/csrc/distributed/c10d/TCPStore.cpp b/torch/csrc/distributed/c10d/TCPStore.cpp index 3e657f48279f..4bd3b28ec1e2 100644 --- a/torch/csrc/distributed/c10d/TCPStore.cpp +++ b/torch/csrc/distributed/c10d/TCPStore.cpp @@ -94,6 +94,10 @@ class TCPServer { std::unique_ptr&& daemon) : port_{port}, daemon_{std::move(daemon)} {} + std::string repr() const { + return fmt::format("TCPServer(port={})", port_); + } + private: std::uint16_t port_; std::unique_ptr daemon_; @@ -157,9 +161,9 @@ class TCPClient { const TCPStoreOptions& opts, std::shared_ptr backoff); - void sendRaw(uint8_t* data, size_t lenght) { + void sendRaw(uint8_t* data, size_t length) { try { - tcputil::sendBytes(socket_.handle(), data, lenght); + tcputil::sendBytes(socket_.handle(), data, length); } catch (const std::exception& e) { C10D_WARNING("sendBytes failed on {}: {}", socket_.repr(), e.what()); throw; @@ -195,6 +199,10 @@ class TCPClient { explicit TCPClient(Socket&& socket) : socket_{std::move(socket)} {} + std::string repr() const { + return fmt::format("TCPClient({})", socket_.repr()); + } + private: Socket socket_; }; @@ -709,4 +717,10 @@ TCPStore::collectClientCounters() const noexcept { return res; } +std::string TCPStore::repr() const { + auto clientRepr = client_ ? client_->repr() : ""; + auto serverRepr = server_ ? server_->repr() : ""; + return fmt::format("TCPStore(client={}, server={})", clientRepr, serverRepr); +} + } // namespace c10d diff --git a/torch/csrc/distributed/c10d/TCPStore.hpp b/torch/csrc/distributed/c10d/TCPStore.hpp index 1bd6e7a0b5e6..015d134e983f 100644 --- a/torch/csrc/distributed/c10d/TCPStore.hpp +++ b/torch/csrc/distributed/c10d/TCPStore.hpp @@ -140,6 +140,8 @@ class TORCH_API TCPStore : public Store { // note(xilunwu): this function is only for internal testing void _splitSet(const std::string& key, const std::vector& data); + std::string repr() const; + private: int64_t incrementValueBy(const std::string& key, int64_t delta); diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 0fc1c6f7d7e4..bcb1adf199cf 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1552,7 +1552,12 @@ Example:: .def_property_readonly( "libuvBackend", &::c10d::TCPStore::isLibUvBackend, - R"(Returns True if it's using the libuv backend.)"); + R"(Returns True if it's using the libuv backend.)") + .def( + "__repr__", + &::c10d::TCPStore::repr, + R"(Returns a string representation of the TCPStore.)", + py::call_guard()); intrusive_ptr_class_<::c10d::PrefixStore>( module, diff --git a/torch/csrc/distributed/c10d/socket.cpp b/torch/csrc/distributed/c10d/socket.cpp index d3b9d1324eee..f155f8252842 100644 --- a/torch/csrc/distributed/c10d/socket.cpp +++ b/torch/csrc/distributed/c10d/socket.cpp @@ -141,10 +141,9 @@ class SocketImpl { static constexpr Handle invalid_socket = -1; #endif - explicit SocketImpl( - Handle hnd, - std::optional<::addrinfo> remote = std::nullopt) noexcept - : hnd_{hnd}, remote_(remote) {} + explicit SocketImpl(Handle hnd) noexcept : hnd_{hnd} {} + + explicit SocketImpl(Handle hnd, const ::addrinfo& remote); SocketImpl(const SocketImpl& other) = delete; @@ -182,7 +181,7 @@ class SocketImpl { return hnd_; } - const std::optional<::addrinfo>& remote() const noexcept { + const std::optional& remote() const noexcept { return remote_; } @@ -192,7 +191,7 @@ class SocketImpl { bool setSocketFlag(int level, int optname, bool value) noexcept; Handle hnd_; - const std::optional<::addrinfo> remote_; + const std::optional remote_; }; } // namespace c10d::detail @@ -278,7 +277,7 @@ struct formatter { addr.ai_addrlen = addr_len; auto remote = socket.remote(); - std::string remoteStr = remote ? fmt::format("{}", *remote) : "none"; + std::string remoteStr = remote ? *remote : "none"; return fmt::format_to( ctx.out(), @@ -293,6 +292,9 @@ struct formatter { namespace c10d::detail { +SocketImpl::SocketImpl(Handle hnd, const ::addrinfo& remote) + : hnd_{hnd}, remote_{fmt::format("{}", remote)} {} + SocketImpl::~SocketImpl() { #ifdef _WIN32 ::closesocket(hnd_);