mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
TCPStore: improve connect and retry logic (#129261)
We've been facing issues where TCPStore can successfully connect but then fail in the validate() function due to resets from listen backlog queue overflow when combined with reset enabled as well as long init times. This PR does a few things: * Retry that connect and validate up to the specified timeout. * Use exponential backoff for the retry logic with jitter instead of a fixed 1s sleep. * Eliminate the `sleep(std::chrono::milliseconds(numWorkers))` on init which can add significant delays to startup. This is no longer necessary per @XilunWu https://github.com/pytorch/pytorch/pull/116141 Test plan: ``` python test/distributed/test_store.py -v ./build/bin/BackoffTest ``` Will do internal testing with some large scale jobs to ensure TCPStore works correctly. At 4k scale: 4x improvement ``` tristanr@devvm4382 ~/pt_tests [SIGABRT]> time TORCH_SHOW_CPP_STACKTRACES=1 python tcpstore_large_test.py (pytorch-3.10) started 0 init 0 set 0 joined all ________________________________________________________ Executed in 1.98 secs fish external usr time 0.93 secs 91.00 micros 0.93 secs sys time 1.98 secs 954.00 micros 1.97 secs tristanr@devvm4382 ~/pt_tests> conda activate torchdrive-3.10 (pytorch-3.10) tristanr@devvm4382 ~/pt_tests> time TORCH_SHOW_CPP_STACKTRACES=1 python tcpstore_large_test.py (torchdrive-3.10) started 0 init 0 set 0 joined all ________________________________________________________ Executed in 8.20 secs fish external usr time 2.15 secs 0.00 micros 2.15 secs sys time 2.76 secs 843.00 micros 2.76 secs ``` ```py import time import os import threading from multiprocessing import Pool WORLD_SIZE = 10000 import torch.distributed as dist def run(rank): should_log = rank % (WORLD_SIZE // 10) == 0 if should_log: print(f"started {rank}") store = dist.TCPStore( host_name="devvm4382.nao0.facebook.com", port=29500, world_size=WORLD_SIZE, is_master=rank == 0, use_libuv=True, ) if should_log: print(f"init {rank}") store.set(f"key{rank}", "1234") if should_log: print(f"set {rank}") del store def noop(rank): pass print("starting pool") with Pool(WORLD_SIZE) as pool: pool.map(noop, range(WORLD_SIZE), 1) print("pool hot") start = time.time() pool.map(run, range(WORLD_SIZE), 1) print("run finished", time.time()-start) ``` ``` tristanr@devvm4382 ~/pt_tests> python tcpstore_large_test.py (pytorch-3.10) starting pool pool hot started 0 [W624 16:58:09.086081750 TCPStore.cpp:343] [c10d] Starting store with 10000 workers but somaxconn is 4096.This might cause instability during bootstrap, consider increasing it. started 1000 init 1000 set 1000 started 2000 init 2000 set 2000 started 3000 init 3000 set 3000 started 4000 init 4000 set 4000 started 5000 init 5000 set 5000 started 6000 init 6000 set 6000 started 7000 init 7000 set 7000 started 8000 init 8000 set 8000 started 9000 init 9000 set 9000 init 0 set 0 run finished 0.705092191696167 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/129261 Approved by: https://github.com/rsdcastro, https://github.com/wconstab, https://github.com/kurman, https://github.com/XilunWu, https://github.com/c-p-i-o
This commit is contained in:
committed by
PyTorch MergeBot
parent
816e8a3f21
commit
0298560ca2
@ -488,6 +488,7 @@ libtorch_core_sources = sorted(
|
||||
# These files are the only ones that are supported on Windows.
|
||||
libtorch_distributed_base_sources = [
|
||||
"torch/csrc/distributed/c10d/Backend.cpp",
|
||||
"torch/csrc/distributed/c10d/Backoff.cpp",
|
||||
"torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp",
|
||||
"torch/csrc/distributed/c10d/FileStore.cpp",
|
||||
"torch/csrc/distributed/c10d/Functional.cpp",
|
||||
|
68
test/cpp/c10d/BackoffTest.cpp
Normal file
68
test/cpp/c10d/BackoffTest.cpp
Normal file
@ -0,0 +1,68 @@
|
||||
#include <c10/util/irange.h>
|
||||
#include "StoreTestCommon.hpp"
|
||||
|
||||
#include <iostream>
|
||||
#include <thread>
|
||||
|
||||
#include <torch/csrc/distributed/c10d/Backoff.hpp>
|
||||
|
||||
TEST(BackoffTest, exponentialBackoffDefaults) {
|
||||
c10d::ExponentialBackoffWithJitter backoff;
|
||||
EXPECT_EQ(backoff.initialInterval, std::chrono::milliseconds(500));
|
||||
EXPECT_EQ(backoff.maxInterval, std::chrono::milliseconds(60000));
|
||||
EXPECT_EQ(backoff.multiplier, 1.5);
|
||||
EXPECT_EQ(backoff.randomizationFactor, 0.5);
|
||||
}
|
||||
|
||||
TEST(BackoffTest, exponentialBackoff) {
|
||||
c10d::ExponentialBackoffWithJitter backoff;
|
||||
backoff.randomizationFactor = 0.0;
|
||||
backoff.multiplier = 2.0;
|
||||
backoff.maxInterval = std::chrono::milliseconds(5000);
|
||||
|
||||
EXPECT_EQ(backoff.nextBackoff(), std::chrono::milliseconds(500));
|
||||
EXPECT_EQ(backoff.nextBackoff(), std::chrono::milliseconds(1000));
|
||||
EXPECT_EQ(backoff.nextBackoff(), std::chrono::milliseconds(2000));
|
||||
EXPECT_EQ(backoff.nextBackoff(), std::chrono::milliseconds(4000));
|
||||
EXPECT_EQ(backoff.nextBackoff(), std::chrono::milliseconds(5000));
|
||||
EXPECT_EQ(backoff.nextBackoff(), std::chrono::milliseconds(5000));
|
||||
|
||||
backoff.reset();
|
||||
EXPECT_EQ(backoff.nextBackoff(), std::chrono::milliseconds(500));
|
||||
EXPECT_EQ(backoff.nextBackoff(), std::chrono::milliseconds(1000));
|
||||
}
|
||||
|
||||
TEST(BackoffTest, expontentialBackoffRandomization) {
|
||||
c10d::ExponentialBackoffWithJitter backoff;
|
||||
backoff.initialInterval = std::chrono::milliseconds(1000);
|
||||
backoff.randomizationFactor = 0.5;
|
||||
backoff.multiplier = 1.0;
|
||||
backoff.maxInterval = std::chrono::milliseconds(5000);
|
||||
|
||||
for (int i = 0; i < 100; i++) {
|
||||
auto backoffDur = backoff.nextBackoff();
|
||||
EXPECT_GE(backoffDur, std::chrono::milliseconds(500));
|
||||
EXPECT_LE(backoffDur, std::chrono::milliseconds(1500));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BackoffTest, fixedBackoff) {
|
||||
c10d::FixedBackoff backoff{std::chrono::milliseconds(1000)};
|
||||
|
||||
EXPECT_EQ(backoff.nextBackoff(), std::chrono::milliseconds(1000));
|
||||
EXPECT_EQ(backoff.nextBackoff(), std::chrono::milliseconds(1000));
|
||||
backoff.reset();
|
||||
EXPECT_EQ(backoff.nextBackoff(), std::chrono::milliseconds(1000));
|
||||
}
|
||||
|
||||
TEST(BackoffTest, sleep) {
|
||||
std::chrono::milliseconds sleepTime{10};
|
||||
c10d::FixedBackoff backoff{sleepTime};
|
||||
|
||||
EXPECT_EQ(backoff.nextBackoff(), sleepTime);
|
||||
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
backoff.sleepBackoff();
|
||||
auto dur = std::chrono::high_resolution_clock::now() - start;
|
||||
EXPECT_GE(dur, sleepTime);
|
||||
}
|
@ -16,6 +16,7 @@ function(c10d_add_test test_src)
|
||||
add_test(NAME ${test_name} COMMAND $<TARGET_FILE:${test_name}>)
|
||||
endfunction()
|
||||
|
||||
c10d_add_test(BackoffTest.cpp torch_cpu gtest_main)
|
||||
c10d_add_test(FileStoreTest.cpp torch_cpu gtest_main)
|
||||
c10d_add_test(TCPStoreTest.cpp torch_cpu gtest_main)
|
||||
if(INSTALL_TEST)
|
||||
|
77
torch/csrc/distributed/c10d/Backoff.cpp
Normal file
77
torch/csrc/distributed/c10d/Backoff.cpp
Normal file
@ -0,0 +1,77 @@
|
||||
#include <torch/csrc/distributed/c10d/Backoff.hpp>
|
||||
|
||||
#include <exception>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace c10d {
|
||||
namespace {
|
||||
constexpr std::chrono::milliseconds kZeroInterval{0};
|
||||
|
||||
int32_t randSeed() {
|
||||
std::random_device rd;
|
||||
return rd();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
ExponentialBackoffWithJitter::ExponentialBackoffWithJitter()
|
||||
: gen_(randSeed()) {}
|
||||
|
||||
std::chrono::milliseconds ExponentialBackoffWithJitter::nextBackoff() {
|
||||
if (initialInterval == kZeroInterval) {
|
||||
throw std::out_of_range(
|
||||
"ExponentialBackoffWithJitter requires non-zero initial interval");
|
||||
}
|
||||
if (initialInterval > maxInterval) {
|
||||
throw std::out_of_range(
|
||||
"ExponentialBackoffWithJitter requires initialInterval <= maxInterval");
|
||||
}
|
||||
if (randomizationFactor >= 1 || randomizationFactor < 0) {
|
||||
throw std::out_of_range(
|
||||
"ExponentialBackoffWithJitter requires randomization factor (0,1]");
|
||||
}
|
||||
if (multiplier < 1.0) {
|
||||
throw std::out_of_range(
|
||||
"ExponentialBackoffWithJitter requires multiplier >=1");
|
||||
}
|
||||
|
||||
// detect initial setup
|
||||
if (currentInterval_ == kZeroInterval) {
|
||||
currentInterval_ = initialInterval;
|
||||
}
|
||||
|
||||
// sample current interval
|
||||
std::chrono::milliseconds randomization{static_cast<int64_t>(
|
||||
randomizationFactor * static_cast<double>(currentInterval_.count()))};
|
||||
std::chrono::milliseconds minSampleInterval =
|
||||
currentInterval_ - randomization;
|
||||
std::chrono::milliseconds maxSampleInterval =
|
||||
currentInterval_ + randomization;
|
||||
|
||||
std::uniform_int_distribution<> dist(
|
||||
minSampleInterval.count(), maxSampleInterval.count());
|
||||
std::chrono::milliseconds backoffInterval{dist(gen_)};
|
||||
|
||||
// update current interval
|
||||
currentInterval_ = std::chrono::milliseconds(static_cast<int64_t>(
|
||||
static_cast<double>(currentInterval_.count()) * multiplier));
|
||||
|
||||
if (currentInterval_ > maxInterval) {
|
||||
currentInterval_ = maxInterval;
|
||||
}
|
||||
|
||||
return backoffInterval;
|
||||
}
|
||||
|
||||
void ExponentialBackoffWithJitter::reset() {
|
||||
currentInterval_ = kZeroInterval;
|
||||
}
|
||||
|
||||
FixedBackoff::FixedBackoff(std::chrono::milliseconds interval)
|
||||
: interval_(interval) {}
|
||||
|
||||
std::chrono::milliseconds FixedBackoff::nextBackoff() {
|
||||
return interval_;
|
||||
}
|
||||
|
||||
void FixedBackoff::reset() {}
|
||||
} // namespace c10d
|
52
torch/csrc/distributed/c10d/Backoff.hpp
Normal file
52
torch/csrc/distributed/c10d/Backoff.hpp
Normal file
@ -0,0 +1,52 @@
|
||||
#pragma once
|
||||
|
||||
#include <chrono>
|
||||
#include <random>
|
||||
#include <thread>
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
namespace c10d {
|
||||
|
||||
class TORCH_API Backoff {
|
||||
public:
|
||||
virtual ~Backoff() = default;
|
||||
|
||||
virtual std::chrono::milliseconds nextBackoff() = 0;
|
||||
virtual void reset() = 0;
|
||||
|
||||
void sleepBackoff() {
|
||||
std::this_thread::sleep_for(nextBackoff());
|
||||
}
|
||||
};
|
||||
|
||||
class TORCH_API ExponentialBackoffWithJitter : public Backoff {
|
||||
public:
|
||||
ExponentialBackoffWithJitter();
|
||||
|
||||
std::chrono::milliseconds nextBackoff() override;
|
||||
void reset() override;
|
||||
|
||||
public:
|
||||
std::chrono::milliseconds initialInterval{500};
|
||||
double randomizationFactor{0.5};
|
||||
double multiplier{1.5};
|
||||
std::chrono::milliseconds maxInterval{60000};
|
||||
|
||||
private:
|
||||
std::mt19937 gen_;
|
||||
std::chrono::milliseconds currentInterval_{0};
|
||||
};
|
||||
|
||||
class TORCH_API FixedBackoff : public Backoff {
|
||||
public:
|
||||
FixedBackoff(std::chrono::milliseconds interval);
|
||||
|
||||
std::chrono::milliseconds nextBackoff() override;
|
||||
void reset() override;
|
||||
|
||||
private:
|
||||
std::chrono::milliseconds interval_;
|
||||
};
|
||||
|
||||
} // namespace c10d
|
@ -1,5 +1,6 @@
|
||||
#include <c10/util/irange.h>
|
||||
#include <fmt/format.h>
|
||||
#include <torch/csrc/distributed/c10d/Backoff.hpp>
|
||||
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
|
||||
#include <torch/csrc/distributed/c10d/TCPStoreBackend.hpp>
|
||||
#include <torch/csrc/distributed/c10d/logging.h>
|
||||
@ -152,7 +153,8 @@ class TCPClient {
|
||||
public:
|
||||
static std::unique_ptr<TCPClient> connect(
|
||||
const SocketAddress& addr,
|
||||
const TCPStoreOptions& opts);
|
||||
const TCPStoreOptions& opts,
|
||||
std::shared_ptr<Backoff> backoff);
|
||||
|
||||
void sendRaw(uint8_t* data, size_t lenght) {
|
||||
try {
|
||||
@ -198,10 +200,14 @@ class TCPClient {
|
||||
|
||||
std::unique_ptr<TCPClient> TCPClient::connect(
|
||||
const SocketAddress& addr,
|
||||
const TCPStoreOptions& opts) {
|
||||
const TCPStoreOptions& opts,
|
||||
std::shared_ptr<Backoff> backoff) {
|
||||
auto timeout = std::chrono::duration_cast<std::chrono::seconds>(opts.timeout);
|
||||
Socket socket = Socket::connect(
|
||||
addr.host, addr.port, SocketOptions{}.connect_timeout(timeout));
|
||||
addr.host,
|
||||
addr.port,
|
||||
SocketOptions{}.connect_timeout(timeout).connect_backoff(
|
||||
std::move(backoff)));
|
||||
|
||||
return std::make_unique<TCPClient>(std::move(socket));
|
||||
}
|
||||
@ -350,23 +356,51 @@ TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts)
|
||||
addr_.port = opts.port;
|
||||
}
|
||||
|
||||
if (numWorkers_.has_value()) {
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_int_distribution<> distrib(1, *numWorkers_);
|
||||
// TODO (xilunwu): this wait logic may be removed after fixing read_offset
|
||||
// stagger connecting to the store when there are too many ranks to
|
||||
// avoid causing a DDoS
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(distrib(gen)));
|
||||
}
|
||||
// Try connecting several times -- if the server listen backlog is full it may
|
||||
// fail on the first send in validate.
|
||||
auto deadline = std::chrono::steady_clock::now() + opts.timeout;
|
||||
auto backoff = std::make_shared<ExponentialBackoffWithJitter>();
|
||||
|
||||
client_ = detail::TCPClient::connect(addr_, opts);
|
||||
auto retry = 0;
|
||||
do {
|
||||
try {
|
||||
client_ = detail::TCPClient::connect(addr_, opts, backoff);
|
||||
// TCP connection established
|
||||
C10D_DEBUG("TCP client connected to host {}:{}", addr_.host, addr_.port);
|
||||
|
||||
// client's first query for validation
|
||||
validate();
|
||||
|
||||
// success
|
||||
break;
|
||||
} catch (const c10::DistNetworkError& ex) {
|
||||
if (deadline < std::chrono::steady_clock::now()) {
|
||||
C10D_ERROR(
|
||||
"TCP client failed to connect/validate to host {}:{} - timed out (try={}, timeout={}ms): {}",
|
||||
addr_.host,
|
||||
addr_.port,
|
||||
retry,
|
||||
opts.timeout.count(),
|
||||
ex.what());
|
||||
throw;
|
||||
}
|
||||
|
||||
auto delayDuration = backoff->nextBackoff();
|
||||
|
||||
C10D_WARNING(
|
||||
"TCP client failed to connect/validate to host {}:{} - retrying (try={}, timeout={}ms, delay={}ms): {}",
|
||||
addr_.host,
|
||||
addr_.port,
|
||||
retry,
|
||||
opts.timeout.count(),
|
||||
delayDuration.count(),
|
||||
ex.what());
|
||||
|
||||
std::this_thread::sleep_for(delayDuration);
|
||||
retry += 1;
|
||||
}
|
||||
} while (true);
|
||||
|
||||
if (opts.waitWorkers) {
|
||||
waitForWorkers();
|
||||
}
|
||||
|
@ -68,6 +68,8 @@ struct TCPStoreOptions {
|
||||
|
||||
class TORCH_API TCPStore : public Store {
|
||||
public:
|
||||
static constexpr std::chrono::milliseconds kConnectRetryDelay{1000};
|
||||
|
||||
explicit TCPStore(std::string host, const TCPStoreOptions& opts = {});
|
||||
|
||||
[[deprecated("Use TCPStore(host, opts) instead.")]] explicit TCPStore(
|
||||
|
@ -97,12 +97,14 @@ inline void setSocketError(int val) noexcept {
|
||||
#endif
|
||||
|
||||
// Suspends the current thread for the specified duration.
|
||||
void delay(std::chrono::seconds d) {
|
||||
void delay(std::chrono::milliseconds d) {
|
||||
#ifdef _WIN32
|
||||
std::this_thread::sleep_for(d);
|
||||
#else
|
||||
::timespec req{};
|
||||
req.tv_sec = d.count();
|
||||
auto ms = d.count();
|
||||
req.tv_sec = ms / 1000;
|
||||
req.tv_nsec = (ms % 1000) * 1000000;
|
||||
|
||||
// The C++ Standard does not specify whether `sleep_for()` should be signal-
|
||||
// aware; therefore, we use the `nanosleep()` syscall.
|
||||
@ -720,8 +722,6 @@ class SocketConnectOp {
|
||||
using Duration = std::chrono::steady_clock::duration;
|
||||
using TimePoint = std::chrono::time_point<std::chrono::steady_clock>;
|
||||
|
||||
static const std::chrono::seconds delay_duration_;
|
||||
|
||||
enum class ConnectResult : uint8_t { Success, Error, Retry };
|
||||
|
||||
public:
|
||||
@ -759,8 +759,6 @@ class SocketConnectOp {
|
||||
std::unique_ptr<SocketImpl> socket_{};
|
||||
};
|
||||
|
||||
const std::chrono::seconds SocketConnectOp::delay_duration_{1};
|
||||
|
||||
SocketConnectOp::SocketConnectOp(
|
||||
const std::string& host,
|
||||
std::uint16_t port,
|
||||
@ -816,8 +814,6 @@ bool SocketConnectOp::tryConnect(int family) {
|
||||
|
||||
deadline_ = Clock::now() + opts_->connect_timeout();
|
||||
|
||||
std::size_t retry_attempt = 1;
|
||||
|
||||
bool retry; // NOLINT(cppcoreguidelines-init-variables)
|
||||
do {
|
||||
retry = false;
|
||||
@ -859,21 +855,24 @@ bool SocketConnectOp::tryConnect(int family) {
|
||||
}
|
||||
|
||||
if (retry) {
|
||||
if (Clock::now() < deadline_ - delay_duration_) {
|
||||
auto connectBackoff = opts_->connect_backoff();
|
||||
auto delayDuration = connectBackoff->nextBackoff();
|
||||
|
||||
if (Clock::now() < deadline_ - delayDuration) {
|
||||
// Prevent our log output to be too noisy, warn only every 30 seconds.
|
||||
if (retry_attempt == 30) {
|
||||
static auto lastLog = std::chrono::steady_clock::now();
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
if ((now - lastLog) >= std::chrono::seconds(30)) {
|
||||
C10D_INFO(
|
||||
"No socket on ({}, {}) is listening yet, will retry.",
|
||||
host_,
|
||||
port_);
|
||||
|
||||
retry_attempt = 0;
|
||||
lastLog = now;
|
||||
}
|
||||
|
||||
// Wait one second to avoid choking the server.
|
||||
delay(delay_duration_);
|
||||
|
||||
retry_attempt++;
|
||||
// Wait to avoid choking the server.
|
||||
delay(delayDuration);
|
||||
} else {
|
||||
throwTimeoutError();
|
||||
}
|
||||
|
@ -13,6 +13,7 @@
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <torch/csrc/distributed/c10d/Backoff.hpp>
|
||||
#include <torch/csrc/distributed/c10d/exception.h>
|
||||
|
||||
namespace c10d {
|
||||
@ -40,9 +41,22 @@ class SocketOptions {
|
||||
return connect_timeout_;
|
||||
}
|
||||
|
||||
// Sets the backoff policy to use for socket connect ops.
|
||||
SocketOptions& connect_backoff(std::shared_ptr<Backoff> value) noexcept {
|
||||
connect_backoff_ = std::move(value);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
const std::shared_ptr<Backoff>& connect_backoff() const noexcept {
|
||||
return connect_backoff_;
|
||||
}
|
||||
|
||||
private:
|
||||
bool prefer_ipv6_ = true;
|
||||
std::chrono::seconds connect_timeout_{30};
|
||||
std::shared_ptr<Backoff> connect_backoff_{
|
||||
std::make_shared<FixedBackoff>(std::chrono::milliseconds(1000))};
|
||||
};
|
||||
|
||||
class SocketImpl;
|
||||
|
Reference in New Issue
Block a user