c10d/TCPStore: better logs on remote shutdown (#153586)

This makes it more obvious what is going on when the remote TCPStore server shuts down while a client is waiting on a key, and the warning now also includes the remote address.

Test plan:

```
[W514 18:33:36.536327028 TCPStore.cpp:138] [c10d] recvValueWithTimeout failed on SocketImpl(fd=3, addr=[localhost]:34658, remote=[localhost]:1234): Failed to recv, got 0 bytes. Connection was likely closed. Did the remote server shutdown or crash?
```

```py
import os
import time

from torch import distributed as dist

# Repro for the remote-shutdown warning: the master rank exits while the
# other ranks are still waiting on a key.
rank = int(os.environ["RANK"])
is_server = rank == 0

store = dist.TCPStore(
    host_name="localhost",
    port=1234,
    is_master=is_server,
    wait_for_workers=False,
)

# Pause before the interesting part — presumably so every rank has connected.
time.sleep(1)

print("starting")

if is_server:
    # Rank 0 idles for one more second, then falls through and exits.
    time.sleep(1)
else:
    # "foo" is never set anywhere in this script, so the client waits on the
    # store for a key that will not arrive.
    store.get("foo")

print("done")
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153586
Approved by: https://github.com/XilunWu
This commit is contained in:
Tristan Rice
2025-05-15 20:02:48 +00:00
committed by PyTorch MergeBot
parent 064f4c18f9
commit f7ecc091a0
2 changed files with 20 additions and 4 deletions

View File

@ -128,9 +128,17 @@ class TCPClient {
}
template <typename T>
std::optional<T> receiveValueWithTimeout(std::chrono::milliseconds timeout) {
if (!socket_.waitForInput(timeout))
if (!socket_.waitForInput(timeout)) {
return {};
return tcputil::recvValue<T>(socket_.handle());
}
try {
return tcputil::recvValue<T>(socket_.handle());
} catch (const std::exception& e) {
C10D_WARNING(
"recvValueWithTimeout failed on {}: {}", socket_.repr(), e.what());
throw;
}
}
void setTimeout(std::chrono::milliseconds value);

View File

@ -653,7 +653,11 @@ void sendBytes(
SYSCHECK_ERR_RETURN_NEG1(
bytesSent = ::send(socket, currentBytes, bytesToSend, flags))
if (bytesSent == 0) {
C10_THROW_ERROR(DistNetworkError, "failed to send, sent 0 bytes");
C10_THROW_ERROR(
DistNetworkError,
"Failed to send, sent 0 bytes. "
"Connection was likely closed. "
"Did the remote server shutdown or crash?");
}
bytesToSend -= bytesSent;
@ -675,7 +679,11 @@ void recvBytes(int socket, T* buffer, size_t length) {
SYSCHECK_ERR_RETURN_NEG1(
bytesReceived = recv(socket, currentBytes, bytesToReceive, 0))
if (bytesReceived == 0) {
C10_THROW_ERROR(DistNetworkError, "failed to recv, got 0 bytes");
C10_THROW_ERROR(
DistNetworkError,
"Failed to recv, got 0 bytes. "
"Connection was likely closed. "
"Did the remote server shutdown or crash?");
}
bytesToReceive -= bytesReceived;