TCPStore: improve connect and retry logic (#129261)

We've been facing issues where TCPStore can successfully connect but then fail in the validate() function due to connection resets caused by listen backlog queue overflow — this occurs when TCP reset is enabled and combined with long init times.

This PR does a few things:
* Retry that connect and validate up to the specified timeout.
* Use exponential backoff for the retry logic with jitter instead of a fixed 1s sleep.
* Eliminate the `sleep(std::chrono::milliseconds(numWorkers))` on init which can add significant delays to startup. This is no longer necessary per @XilunWu https://github.com/pytorch/pytorch/pull/116141

Test plan:

```
python test/distributed/test_store.py -v
./build/bin/BackoffTest
```

Will do internal testing with some large scale jobs to ensure TCPStore works correctly.

At 4k scale: 4x improvement

```
tristanr@devvm4382 ~/pt_tests [SIGABRT]> time TORCH_SHOW_CPP_STACKTRACES=1 python tcpstore_large_test.py                                                                                                   (pytorch-3.10)
started 0
init 0
set 0
joined all

________________________________________________________
Executed in    1.98 secs    fish           external
   usr time    0.93 secs   91.00 micros    0.93 secs
   sys time    1.98 secs  954.00 micros    1.97 secs

tristanr@devvm4382 ~/pt_tests> conda activate torchdrive-3.10                                                                                                                                              (pytorch-3.10)
tristanr@devvm4382 ~/pt_tests> time TORCH_SHOW_CPP_STACKTRACES=1 python tcpstore_large_test.py                                                                                                          (torchdrive-3.10)
started 0
init 0
set 0
joined all

________________________________________________________
Executed in    8.20 secs    fish           external
   usr time    2.15 secs    0.00 micros    2.15 secs
   sys time    2.76 secs  843.00 micros    2.76 secs
```

```py
import time
import os
import threading
from multiprocessing import Pool

WORLD_SIZE = 10000

import torch.distributed as dist

def run(rank):
    should_log = rank % (WORLD_SIZE // 10) == 0
    if should_log:
        print(f"started {rank}")
    store = dist.TCPStore(
        host_name="devvm4382.nao0.facebook.com",
        port=29500,
        world_size=WORLD_SIZE,
        is_master=rank == 0,
        use_libuv=True,
    )
    if should_log:
        print(f"init {rank}")
    store.set(f"key{rank}", "1234")
    if should_log:
        print(f"set {rank}")
    del store

def noop(rank):
    pass

print("starting pool")
with Pool(WORLD_SIZE) as pool:
    pool.map(noop, range(WORLD_SIZE), 1)
    print("pool hot")
    start = time.time()
    pool.map(run, range(WORLD_SIZE), 1)
    print("run finished", time.time()-start)
```

```
tristanr@devvm4382 ~/pt_tests> python tcpstore_large_test.py                                                                                                                                (pytorch-3.10)
starting pool
pool hot
started 0
[W624 16:58:09.086081750 TCPStore.cpp:343] [c10d] Starting store with 10000 workers but somaxconn is 4096.This might cause instability during bootstrap, consider increasing it.
started 1000
init 1000
set 1000
started 2000
init 2000
set 2000
started 3000
init 3000
set 3000
started 4000
init 4000
set 4000
started 5000
init 5000
set 5000
started 6000
init 6000
set 6000
started 7000
init 7000
set 7000
started 8000
init 8000
set 8000
started 9000
init 9000
set 9000
init 0
set 0
run finished 0.705092191696167
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129261
Approved by: https://github.com/rsdcastro, https://github.com/wconstab, https://github.com/kurman, https://github.com/XilunWu, https://github.com/c-p-i-o
This commit is contained in:
Tristan Rice
2024-06-25 19:24:22 +00:00
committed by PyTorch MergeBot
parent 816e8a3f21
commit 0298560ca2
9 changed files with 280 additions and 32 deletions

View File

@ -488,6 +488,7 @@ libtorch_core_sources = sorted(
# These files are the only ones that are supported on Windows.
libtorch_distributed_base_sources = [
"torch/csrc/distributed/c10d/Backend.cpp",
"torch/csrc/distributed/c10d/Backoff.cpp",
"torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp",
"torch/csrc/distributed/c10d/FileStore.cpp",
"torch/csrc/distributed/c10d/Functional.cpp",

View File

@ -0,0 +1,68 @@
#include <c10/util/irange.h>
#include "StoreTestCommon.hpp"
#include <iostream>
#include <thread>
#include <torch/csrc/distributed/c10d/Backoff.hpp>
// Defaults must match the documented values on ExponentialBackoffWithJitter.
TEST(BackoffTest, exponentialBackoffDefaults) {
  using namespace std::chrono_literals;

  c10d::ExponentialBackoffWithJitter backoff;
  EXPECT_EQ(backoff.initialInterval, 500ms);
  EXPECT_EQ(backoff.maxInterval, 60000ms);
  EXPECT_EQ(backoff.multiplier, 1.5);
  EXPECT_EQ(backoff.randomizationFactor, 0.5);
}
// With jitter disabled the sequence is deterministic: it doubles from the
// initial interval until it is clamped at maxInterval, and reset() restarts
// it from the beginning.
TEST(BackoffTest, exponentialBackoff) {
  c10d::ExponentialBackoffWithJitter backoff;
  backoff.randomizationFactor = 0.0;
  backoff.multiplier = 2.0;
  backoff.maxInterval = std::chrono::milliseconds(5000);

  const int64_t expectedMs[] = {500, 1000, 2000, 4000, 5000, 5000};
  for (const auto ms : expectedMs) {
    EXPECT_EQ(backoff.nextBackoff(), std::chrono::milliseconds(ms));
  }

  backoff.reset();
  EXPECT_EQ(backoff.nextBackoff(), std::chrono::milliseconds(500));
  EXPECT_EQ(backoff.nextBackoff(), std::chrono::milliseconds(1000));
}
// With multiplier 1.0 the interval never grows, so every sample is drawn from
// [interval * (1 - factor), interval * (1 + factor)] = [500ms, 1500ms].
// NOTE: fixes the misspelled test name ("expontential" -> "exponential").
TEST(BackoffTest, exponentialBackoffRandomization) {
  c10d::ExponentialBackoffWithJitter backoff;
  backoff.initialInterval = std::chrono::milliseconds(1000);
  backoff.randomizationFactor = 0.5;
  backoff.multiplier = 1.0;
  backoff.maxInterval = std::chrono::milliseconds(5000);
  for (int i = 0; i < 100; i++) {
    auto backoffDur = backoff.nextBackoff();
    EXPECT_GE(backoffDur, std::chrono::milliseconds(500));
    EXPECT_LE(backoffDur, std::chrono::milliseconds(1500));
  }
}
// FixedBackoff always returns the configured interval; reset() is a no-op.
TEST(BackoffTest, fixedBackoff) {
  const std::chrono::milliseconds interval{1000};
  c10d::FixedBackoff backoff{interval};

  EXPECT_EQ(backoff.nextBackoff(), interval);
  EXPECT_EQ(backoff.nextBackoff(), interval);

  backoff.reset();
  EXPECT_EQ(backoff.nextBackoff(), interval);
}
// sleepBackoff() must block the calling thread for at least the duration
// reported by nextBackoff().
TEST(BackoffTest, sleep) {
  const std::chrono::milliseconds sleepTime{10};
  c10d::FixedBackoff backoff{sleepTime};

  EXPECT_EQ(backoff.nextBackoff(), sleepTime);

  const auto begin = std::chrono::high_resolution_clock::now();
  backoff.sleepBackoff();
  const auto elapsed = std::chrono::high_resolution_clock::now() - begin;
  EXPECT_GE(elapsed, sleepTime);
}

View File

@ -16,6 +16,7 @@ function(c10d_add_test test_src)
add_test(NAME ${test_name} COMMAND $<TARGET_FILE:${test_name}>)
endfunction()
c10d_add_test(BackoffTest.cpp torch_cpu gtest_main)
c10d_add_test(FileStoreTest.cpp torch_cpu gtest_main)
c10d_add_test(TCPStoreTest.cpp torch_cpu gtest_main)
if(INSTALL_TEST)

View File

@ -0,0 +1,77 @@
#include <torch/csrc/distributed/c10d/Backoff.hpp>
#include <exception>
#include <stdexcept>
namespace c10d {
namespace {
constexpr std::chrono::milliseconds kZeroInterval{0};

// Seeds the PRNG engine from the OS entropy source.
int32_t randSeed() {
  std::random_device rd;
  return rd();
}
} // namespace

ExponentialBackoffWithJitter::ExponentialBackoffWithJitter()
    : gen_(randSeed()) {}

std::chrono::milliseconds ExponentialBackoffWithJitter::nextBackoff() {
  // Validate configuration lazily: the knobs are public and may be adjusted
  // after construction, so check them on every call.
  // Also reject negative intervals — a negative interval would make the
  // sampled range below invalid (min > max), which is undefined behavior for
  // uniform_int_distribution.
  if (initialInterval <= kZeroInterval) {
    throw std::out_of_range(
        "ExponentialBackoffWithJitter requires a positive initial interval");
  }
  if (initialInterval > maxInterval) {
    throw std::out_of_range(
        "ExponentialBackoffWithJitter requires initialInterval <= maxInterval");
  }
  // Valid range is [0, 1): a factor >= 1 could drive the sampled minimum
  // negative. (Message previously claimed "(0,1]", which contradicted the
  // check.)
  if (randomizationFactor >= 1 || randomizationFactor < 0) {
    throw std::out_of_range(
        "ExponentialBackoffWithJitter requires randomization factor in [0,1)");
  }
  if (multiplier < 1.0) {
    throw std::out_of_range(
        "ExponentialBackoffWithJitter requires multiplier >=1");
  }

  // First call after construction or reset(): start from initialInterval.
  if (currentInterval_ == kZeroInterval) {
    currentInterval_ = initialInterval;
  }

  // Sample uniformly from
  // [currentInterval * (1 - factor), currentInterval * (1 + factor)].
  std::chrono::milliseconds randomization{static_cast<int64_t>(
      randomizationFactor * static_cast<double>(currentInterval_.count()))};
  std::chrono::milliseconds minSampleInterval =
      currentInterval_ - randomization;
  std::chrono::milliseconds maxSampleInterval =
      currentInterval_ + randomization;

  // Use the 64-bit rep of std::chrono::milliseconds explicitly instead of
  // the default (int) to avoid narrowing.
  std::uniform_int_distribution<int64_t> dist(
      minSampleInterval.count(), maxSampleInterval.count());
  std::chrono::milliseconds backoffInterval{dist(gen_)};

  // Grow the interval for the next call, clamped to maxInterval.
  currentInterval_ = std::chrono::milliseconds(static_cast<int64_t>(
      static_cast<double>(currentInterval_.count()) * multiplier));
  if (currentInterval_ > maxInterval) {
    currentInterval_ = maxInterval;
  }

  return backoffInterval;
}

void ExponentialBackoffWithJitter::reset() {
  // Zero is the sentinel that makes nextBackoff() restart from
  // initialInterval.
  currentInterval_ = kZeroInterval;
}

FixedBackoff::FixedBackoff(std::chrono::milliseconds interval)
    : interval_(interval) {}

std::chrono::milliseconds FixedBackoff::nextBackoff() {
  return interval_;
}

void FixedBackoff::reset() {}
} // namespace c10d

View File

@ -0,0 +1,52 @@
#pragma once
#include <chrono>
#include <random>
#include <thread>
#include <c10/macros/Macros.h>
namespace c10d {

// Abstract policy that yields successive wait durations for retry loops.
class TORCH_API Backoff {
 public:
  virtual ~Backoff() = default;

  // Returns the next wait duration and advances the internal state.
  virtual std::chrono::milliseconds nextBackoff() = 0;
  // Restores the policy to its initial state.
  virtual void reset() = 0;

  // Blocks the current thread for the duration returned by nextBackoff().
  void sleepBackoff() {
    std::this_thread::sleep_for(nextBackoff());
  }
};

// Backoff whose interval grows by `multiplier` on each call, clamped to
// `maxInterval`, with each returned sample jittered by up to
// +/- `randomizationFactor` of the current interval.
class TORCH_API ExponentialBackoffWithJitter : public Backoff {
 public:
  ExponentialBackoffWithJitter();

  std::chrono::milliseconds nextBackoff() override;
  void reset() override;

 public:
  // Tunable knobs; validated on each nextBackoff() call.
  std::chrono::milliseconds initialInterval{500};
  double randomizationFactor{0.5};
  double multiplier{1.5};
  std::chrono::milliseconds maxInterval{60000};

 private:
  std::mt19937 gen_;
  // Zero is a sentinel meaning "not started"; set to initialInterval on the
  // first nextBackoff() call and restored by reset().
  std::chrono::milliseconds currentInterval_{0};
};

// Backoff that always waits the same fixed interval.
class TORCH_API FixedBackoff : public Backoff {
 public:
  // `explicit` prevents accidental implicit conversion from a bare duration.
  explicit FixedBackoff(std::chrono::milliseconds interval);

  std::chrono::milliseconds nextBackoff() override;
  void reset() override;

 private:
  std::chrono::milliseconds interval_;
};
} // namespace c10d

View File

@ -1,5 +1,6 @@
#include <c10/util/irange.h>
#include <fmt/format.h>
#include <torch/csrc/distributed/c10d/Backoff.hpp>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
#include <torch/csrc/distributed/c10d/TCPStoreBackend.hpp>
#include <torch/csrc/distributed/c10d/logging.h>
@ -152,7 +153,8 @@ class TCPClient {
public:
static std::unique_ptr<TCPClient> connect(
const SocketAddress& addr,
const TCPStoreOptions& opts);
const TCPStoreOptions& opts,
std::shared_ptr<Backoff> backoff);
void sendRaw(uint8_t* data, size_t lenght) {
try {
@ -198,10 +200,14 @@ class TCPClient {
std::unique_ptr<TCPClient> TCPClient::connect(
const SocketAddress& addr,
const TCPStoreOptions& opts) {
const TCPStoreOptions& opts,
std::shared_ptr<Backoff> backoff) {
auto timeout = std::chrono::duration_cast<std::chrono::seconds>(opts.timeout);
Socket socket = Socket::connect(
addr.host, addr.port, SocketOptions{}.connect_timeout(timeout));
addr.host,
addr.port,
SocketOptions{}.connect_timeout(timeout).connect_backoff(
std::move(backoff)));
return std::make_unique<TCPClient>(std::move(socket));
}
@ -350,23 +356,51 @@ TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts)
addr_.port = opts.port;
}
if (numWorkers_.has_value()) {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> distrib(1, *numWorkers_);
// TODO (xilunwu): this wait logic may be removed after fixing read_offset
// stagger connecting to the store when there are too many ranks to
// avoid causing a DDoS
std::this_thread::sleep_for(std::chrono::milliseconds(distrib(gen)));
}
// Try connecting several times -- if the server listen backlog is full it may
// fail on the first send in validate.
auto deadline = std::chrono::steady_clock::now() + opts.timeout;
auto backoff = std::make_shared<ExponentialBackoffWithJitter>();
client_ = detail::TCPClient::connect(addr_, opts);
auto retry = 0;
do {
try {
client_ = detail::TCPClient::connect(addr_, opts, backoff);
// TCP connection established
C10D_DEBUG("TCP client connected to host {}:{}", addr_.host, addr_.port);
// client's first query for validation
validate();
// success
break;
} catch (const c10::DistNetworkError& ex) {
if (deadline < std::chrono::steady_clock::now()) {
C10D_ERROR(
"TCP client failed to connect/validate to host {}:{} - timed out (try={}, timeout={}ms): {}",
addr_.host,
addr_.port,
retry,
opts.timeout.count(),
ex.what());
throw;
}
auto delayDuration = backoff->nextBackoff();
C10D_WARNING(
"TCP client failed to connect/validate to host {}:{} - retrying (try={}, timeout={}ms, delay={}ms): {}",
addr_.host,
addr_.port,
retry,
opts.timeout.count(),
delayDuration.count(),
ex.what());
std::this_thread::sleep_for(delayDuration);
retry += 1;
}
} while (true);
if (opts.waitWorkers) {
waitForWorkers();
}

View File

@ -68,6 +68,8 @@ struct TCPStoreOptions {
class TORCH_API TCPStore : public Store {
public:
static constexpr std::chrono::milliseconds kConnectRetryDelay{1000};
explicit TCPStore(std::string host, const TCPStoreOptions& opts = {});
[[deprecated("Use TCPStore(host, opts) instead.")]] explicit TCPStore(

View File

@ -97,12 +97,14 @@ inline void setSocketError(int val) noexcept {
#endif
// Suspends the current thread for the specified duration.
void delay(std::chrono::seconds d) {
void delay(std::chrono::milliseconds d) {
#ifdef _WIN32
std::this_thread::sleep_for(d);
#else
::timespec req{};
req.tv_sec = d.count();
auto ms = d.count();
req.tv_sec = ms / 1000;
req.tv_nsec = (ms % 1000) * 1000000;
// The C++ Standard does not specify whether `sleep_for()` should be signal-
// aware; therefore, we use the `nanosleep()` syscall.
@ -720,8 +722,6 @@ class SocketConnectOp {
using Duration = std::chrono::steady_clock::duration;
using TimePoint = std::chrono::time_point<std::chrono::steady_clock>;
static const std::chrono::seconds delay_duration_;
enum class ConnectResult : uint8_t { Success, Error, Retry };
public:
@ -759,8 +759,6 @@ class SocketConnectOp {
std::unique_ptr<SocketImpl> socket_{};
};
const std::chrono::seconds SocketConnectOp::delay_duration_{1};
SocketConnectOp::SocketConnectOp(
const std::string& host,
std::uint16_t port,
@ -816,8 +814,6 @@ bool SocketConnectOp::tryConnect(int family) {
deadline_ = Clock::now() + opts_->connect_timeout();
std::size_t retry_attempt = 1;
bool retry; // NOLINT(cppcoreguidelines-init-variables)
do {
retry = false;
@ -859,21 +855,24 @@ bool SocketConnectOp::tryConnect(int family) {
}
if (retry) {
if (Clock::now() < deadline_ - delay_duration_) {
auto connectBackoff = opts_->connect_backoff();
auto delayDuration = connectBackoff->nextBackoff();
if (Clock::now() < deadline_ - delayDuration) {
// Prevent our log output to be too noisy, warn only every 30 seconds.
if (retry_attempt == 30) {
static auto lastLog = std::chrono::steady_clock::now();
auto now = std::chrono::steady_clock::now();
if ((now - lastLog) >= std::chrono::seconds(30)) {
C10D_INFO(
"No socket on ({}, {}) is listening yet, will retry.",
host_,
port_);
retry_attempt = 0;
lastLog = now;
}
// Wait one second to avoid choking the server.
delay(delay_duration_);
retry_attempt++;
// Wait to avoid choking the server.
delay(delayDuration);
} else {
throwTimeoutError();
}

View File

@ -13,6 +13,7 @@
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <torch/csrc/distributed/c10d/Backoff.hpp>
#include <torch/csrc/distributed/c10d/exception.h>
namespace c10d {
@ -40,9 +41,22 @@ class SocketOptions {
return connect_timeout_;
}
// Sets the backoff policy to use for socket connect ops.
SocketOptions& connect_backoff(std::shared_ptr<Backoff> value) noexcept {
connect_backoff_ = std::move(value);
return *this;
}
const std::shared_ptr<Backoff>& connect_backoff() const noexcept {
return connect_backoff_;
}
private:
bool prefer_ipv6_ = true;
std::chrono::seconds connect_timeout_{30};
std::shared_ptr<Backoff> connect_backoff_{
std::make_shared<FixedBackoff>(std::chrono::milliseconds(1000))};
};
class SocketImpl;