TCPStore: fix remote address (#131773) (#131913)

Summary:
This fixes corrupted remote-address logs caused by dangling pointers to the addrinfo_storage referenced from inside the stored addrinfo.

This relands the original change, which was reverted because of an internal fmt::format issue.

Original Pull Request: https://github.com/pytorch/pytorch/pull/131773
Approved by: https://github.com/kurman

Test Plan:
Enable debug logs and verify addresses are correct

```
TORCH_CPP_LOG_LEVEL=INFO TORCH_DISABLE_SHARE_RDZV_TCP_STORE=1 TORCH_DISTRIBUTED_DEBUG=DETAIL LOGLEVEL=INFO python test/distributed/test_store.py -v
buck2 test @//mode/dev-nosan //caffe2/test/distributed:store
```

Differential Revision: D60296583

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131913
Approved by: https://github.com/kurman, https://github.com/rsdcastro, https://github.com/Skylion007
This commit is contained in:
Tristan Rice
2024-07-30 17:27:33 +00:00
committed by PyTorch MergeBot
parent 3864a2d834
commit 9027db1ab8
5 changed files with 57 additions and 10 deletions

View File

@ -312,6 +312,30 @@ class TCPStoreTest(TestCase, StoreTestBase):
self.assertEqual(store1.libuvBackend, self._use_libuv)
self.assertEqual(store2.libuvBackend, self._use_libuv)
def test_repr(self) -> None:
    """Check the ``repr()`` text of a TCPStore on the server and client side."""
    # Both stores report a client socket, so this part of the pattern is shared.
    client_pattern = (
        r"TCPStore\("
        r"client=TCPClient\(SocketImpl\(fd=\d+, addr=\[?localhost\]?:\d+, remote=\[?localhost\]?:\d+\)\), "
    )

    # server: the store from _create_store() must also report a TCPServer
    server_store = self._create_store()
    self.assertRegex(
        repr(server_store),
        client_pattern + r"server=TCPServer\(port=\d+\)\)",
    )

    # client: a store created with is_master=False must report no server
    client_store = dist.TCPStore(
        server_store.host,
        server_store.port,
        world_size=2,
        is_master=False,
    )
    self.assertRegex(
        repr(client_store),
        client_pattern + r"server=<nullptr>\)",
    )
@skip_if_win32()
@retry_on_connect_failures
def test_init_pg_and_rpc_with_same_socket(self):

View File

@ -94,6 +94,10 @@ class TCPServer {
std::unique_ptr<BackgroundThread>&& daemon)
: port_{port}, daemon_{std::move(daemon)} {}
// Returns a short textual description of this server, showing its port.
std::string repr() const {
  auto description = fmt::format("TCPServer(port={})", port_);
  return description;
}
private:
std::uint16_t port_;
std::unique_ptr<BackgroundThread> daemon_;
@ -157,9 +161,9 @@ class TCPClient {
const TCPStoreOptions& opts,
std::shared_ptr<Backoff> backoff);
void sendRaw(uint8_t* data, size_t lenght) {
void sendRaw(uint8_t* data, size_t length) {
try {
tcputil::sendBytes(socket_.handle(), data, lenght);
tcputil::sendBytes(socket_.handle(), data, length);
} catch (const std::exception& e) {
C10D_WARNING("sendBytes failed on {}: {}", socket_.repr(), e.what());
throw;
@ -195,6 +199,10 @@ class TCPClient {
explicit TCPClient(Socket&& socket) : socket_{std::move(socket)} {}
std::string repr() const {
return fmt::format("TCPClient({})", socket_.repr());
}
private:
Socket socket_;
};
@ -709,4 +717,10 @@ TCPStore::collectClientCounters() const noexcept {
return res;
}
std::string TCPStore::repr() const {
auto clientRepr = client_ ? client_->repr() : "<nullptr>";
auto serverRepr = server_ ? server_->repr() : "<nullptr>";
return fmt::format("TCPStore(client={}, server={})", clientRepr, serverRepr);
}
} // namespace c10d

View File

@ -140,6 +140,8 @@ class TORCH_API TCPStore : public Store {
// note(xilunwu): this function is only for internal testing
void _splitSet(const std::string& key, const std::vector<uint8_t>& data);
std::string repr() const;
private:
int64_t incrementValueBy(const std::string& key, int64_t delta);

View File

@ -1552,7 +1552,12 @@ Example::
.def_property_readonly(
"libuvBackend",
&::c10d::TCPStore::isLibUvBackend,
R"(Returns True if it's using the libuv backend.)");
R"(Returns True if it's using the libuv backend.)")
.def(
"__repr__",
&::c10d::TCPStore::repr,
R"(Returns a string representation of the TCPStore.)",
py::call_guard<py::gil_scoped_release>());
intrusive_ptr_class_<::c10d::PrefixStore>(
module,

View File

@ -141,10 +141,9 @@ class SocketImpl {
static constexpr Handle invalid_socket = -1;
#endif
explicit SocketImpl(
Handle hnd,
std::optional<::addrinfo> remote = std::nullopt) noexcept
: hnd_{hnd}, remote_(remote) {}
explicit SocketImpl(Handle hnd) noexcept : hnd_{hnd} {}
explicit SocketImpl(Handle hnd, const ::addrinfo& remote);
SocketImpl(const SocketImpl& other) = delete;
@ -182,7 +181,7 @@ class SocketImpl {
return hnd_;
}
const std::optional<::addrinfo>& remote() const noexcept {
const std::optional<std::string>& remote() const noexcept {
return remote_;
}
@ -192,7 +191,7 @@ class SocketImpl {
bool setSocketFlag(int level, int optname, bool value) noexcept;
Handle hnd_;
const std::optional<::addrinfo> remote_;
const std::optional<std::string> remote_;
};
} // namespace c10d::detail
@ -278,7 +277,7 @@ struct formatter<c10d::detail::SocketImpl> {
addr.ai_addrlen = addr_len;
auto remote = socket.remote();
std::string remoteStr = remote ? fmt::format("{}", *remote) : "none";
std::string remoteStr = remote ? *remote : "none";
return fmt::format_to(
ctx.out(),
@ -293,6 +292,9 @@ struct formatter<c10d::detail::SocketImpl> {
namespace c10d::detail {
// Constructs a socket wrapper for `hnd` and records the remote address.
// `remote` is formatted into a string right here and only that string is
// stored in remote_, so the socket keeps no pointer into the caller's addrinfo.
SocketImpl::SocketImpl(Handle hnd, const ::addrinfo& remote)
: hnd_{hnd}, remote_{fmt::format("{}", remote)} {}
SocketImpl::~SocketImpl() {
#ifdef _WIN32
::closesocket(hnd_);