mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
We've been facing issues where TCPStore can successfully connect but then fail in the validate() function due to resets from listen backlog queue overflow when combined with reset enabled as well as long init times.

This PR does a few things:

* Retry the connect and validate steps up to the specified timeout.
* Use exponential backoff for the retry logic with jitter instead of a fixed 1s sleep.
* Eliminate the `sleep(std::chrono::milliseconds(numWorkers))` on init which can add significant delays to startup. This is no longer necessary per @XilunWu https://github.com/pytorch/pytorch/pull/116141

Test plan:

```
python test/distributed/test_store.py -v
./build/bin/BackoffTest
```

Will do internal testing with some large scale jobs to ensure TCPStore works correctly.

At 4k scale: 4x improvement

```
tristanr@devvm4382 ~/pt_tests [SIGABRT]> time TORCH_SHOW_CPP_STACKTRACES=1 python tcpstore_large_test.py (pytorch-3.10)
started 0
init 0
set 0
joined all
________________________________________________________
Executed in 1.98 secs fish external
usr time 0.93 secs 91.00 micros 0.93 secs
sys time 1.98 secs 954.00 micros 1.97 secs
tristanr@devvm4382 ~/pt_tests> conda activate torchdrive-3.10 (pytorch-3.10)
tristanr@devvm4382 ~/pt_tests> time TORCH_SHOW_CPP_STACKTRACES=1 python tcpstore_large_test.py (torchdrive-3.10)
started 0
init 0
set 0
joined all
________________________________________________________
Executed in 8.20 secs fish external
usr time 2.15 secs 0.00 micros 2.15 secs
sys time 2.76 secs 843.00 micros 2.76 secs
```

```py
import time
import os
import threading
from multiprocessing import Pool

WORLD_SIZE = 10000

import torch.distributed as dist

def run(rank):
    should_log = rank % (WORLD_SIZE // 10) == 0
    if should_log:
        print(f"started {rank}")
    store = dist.TCPStore(
        host_name="devvm4382.nao0.facebook.com",
        port=29500,
        world_size=WORLD_SIZE,
        is_master=rank == 0,
        use_libuv=True,
    )
    if should_log:
        print(f"init {rank}")
    store.set(f"key{rank}", "1234")
    if should_log:
        print(f"set {rank}")
    del store

def noop(rank):
    pass

print("starting pool")
with Pool(WORLD_SIZE) as pool:
    pool.map(noop, range(WORLD_SIZE), 1)
    print("pool hot")
    start = time.time()
    pool.map(run, range(WORLD_SIZE), 1)
    print("run finished", time.time()-start)
```

```
tristanr@devvm4382 ~/pt_tests> python tcpstore_large_test.py (pytorch-3.10)
starting pool
pool hot
started 0
[W624 16:58:09.086081750 TCPStore.cpp:343] [c10d] Starting store with 10000 workers but somaxconn is 4096.This might cause instability during bootstrap, consider increasing it.
started 1000
init 1000
set 1000
started 2000
init 2000
set 2000
started 3000
init 3000
set 3000
started 4000
init 4000
set 4000
started 5000
init 5000
set 5000
started 6000
init 6000
set 6000
started 7000
init 7000
set 7000
started 8000
init 8000
set 8000
started 9000
init 9000
set 9000
init 0
set 0
run finished 0.705092191696167
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129261
Approved by: https://github.com/rsdcastro, https://github.com/wconstab, https://github.com/kurman, https://github.com/XilunWu, https://github.com/c-p-i-o
167 lines
4.4 KiB
C++
167 lines
4.4 KiB
C++
#pragma once

#include <chrono>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

#include <torch/csrc/distributed/c10d/Store.hpp>

namespace c10d {
namespace detail {

// Forward declarations of the server and client types. They are only named
// here; their definitions live outside this header.
class TCPServer;
class TCPClient;

// Host name and port pair. Both fields are value-initialized by default
// (empty host, port 0).
struct SocketAddress {
  std::string host{};
  std::uint16_t port{};
};

// Accumulates summary statistics (count, mean, variance) over a stream of
// double samples fed in through update().
//
// NOTE(review): the count_/mean_/m2_ state looks like a Welford-style running
// variance, with m2_ being the sum of squared deviations from the mean.
// update() is defined outside this header, so confirm against it.
class Counter {
 public:
  // Folds one sample into the accumulator.
  void update(double val);

  // Returns the accumulator's statistics keyed by name. The exact set of keys
  // is determined by the out-of-line definition.
  std::unordered_map<std::string, double> observe() const;

  // Current value of the running mean (0 until updated).
  double mean() const noexcept {
    return mean_;
  }

  // Current value of the sample counter (0 until updated).
  int64_t count() const noexcept {
    return count_;
  }

  // m2_ divided by count_. There is no guard on the divisor: while count() is
  // 0 this is a floating-point division by zero and the result is not a
  // usable number (NaN under IEEE-754 for the default state).
  double variance() const noexcept {
    return m2_ / static_cast<double>(count_);
  }

  // m2_ divided by (count_ - 1). There is no guard on the divisor: when
  // count() is 1 this is a floating-point division by zero, so callers should
  // only rely on the result once count() >= 2.
  double sample_variance() const noexcept {
    return m2_ / static_cast<double>(count_ - 1);
  }

 private:
  int64_t count_ = 0;
  double mean_ = 0;
  double m2_ = 0;
};

} // namespace detail
struct TCPStoreOptions {
|
|
static constexpr std::uint16_t kDefaultPort = 29500;
|
|
|
|
std::uint16_t port = kDefaultPort;
|
|
bool isServer = false;
|
|
std::optional<std::size_t> numWorkers = c10::nullopt;
|
|
bool waitWorkers = true;
|
|
std::chrono::milliseconds timeout = Store::kDefaultTimeout;
|
|
|
|
// A boolean value indicating whether multiple store instances can be
|
|
// initialized with the same host:port pair.
|
|
bool multiTenant = false;
|
|
|
|
// If specified, and if isServer is true, the underlying TCPServer will take
|
|
// over the bound socket associated to this fd. This option is useful to avoid
|
|
// port assignment races in certain scenarios.
|
|
std::optional<int> masterListenFd = c10::nullopt;
|
|
|
|
// A boolean value indicating whether to use the experimental libUV backend.
|
|
bool useLibUV = true;
|
|
};
// Store implementation declared on top of a TCP client/server pair. An
// instance always holds a client and may additionally hold a server.
class TORCH_API TCPStore : public Store {
 public:
  // Delay constant of 1000 ms between connection retries.
  // NOTE(review): whether the retry logic still uses this value is not
  // visible from this header -- confirm in the implementation.
  static constexpr std::chrono::milliseconds kConnectRetryDelay{1000};

  // Creates a store for `host`, configured by `opts`.
  explicit TCPStore(std::string host, const TCPStoreOptions& opts = {});

  // Legacy positional-argument constructor; prefer TCPStore(host, opts).
  [[deprecated("Use TCPStore(host, opts) instead.")]] explicit TCPStore(
      const std::string& masterAddr,
      std::uint16_t masterPort,
      std::optional<int> numWorkers = std::nullopt,
      bool isServer = false,
      const std::chrono::milliseconds& timeout = kDefaultTimeout,
      bool waitWorkers = true);

  ~TCPStore() override;

  // ---- Store interface overrides ----

  void set(const std::string& key, const std::vector<uint8_t>& value) override;

  std::vector<uint8_t> compareSet(
      const std::string& key,
      const std::vector<uint8_t>& expectedValue,
      const std::vector<uint8_t>& desiredValue) override;

  std::vector<uint8_t> get(const std::string& key) override;

  int64_t add(const std::string& key, int64_t value) override;

  bool deleteKey(const std::string& key) override;

  bool check(const std::vector<std::string>& keys) override;

  int64_t getNumKeys() override;

  void wait(const std::vector<std::string>& keys) override;

  void wait(
      const std::vector<std::string>& keys,
      const std::chrono::milliseconds& timeout) override;

  void append(const std::string& key, const std::vector<uint8_t>& value)
      override;

  std::vector<std::vector<uint8_t>> multiGet(
      const std::vector<std::string>& keys) override;

  void multiSet(
      const std::vector<std::string>& keys,
      const std::vector<std::vector<uint8_t>>& values) override;

  bool hasExtendedApi() const override;

  // ---- TCPStore-specific API ----

  // Waits for all workers to join.
  void waitForWorkers();

  // Returns the hostname used by the TCPStore.
  const std::string& getHost() const noexcept {
    return addr_.host;
  }

  // Returns the port used by the TCPStore.
  std::uint16_t getPort() const noexcept {
    return addr_.port;
  }

  // Returns per-name statistics maps.
  // NOTE(review): presumably one entry per clientCounters_ key, each produced
  // by detail::Counter::observe() -- confirm in the implementation.
  std::unordered_map<std::string, std::unordered_map<std::string, double>>
  collectClientCounters() const noexcept;

  // Returns whether this store reports itself as using the libUV backend.
  bool isLibUvBackend() const noexcept {
    return usingLibUv_;
  }

  // note(xilunwu): this function is only for internal testing
  void _splitSet(const std::string& key, const std::vector<uint8_t>& data);

 private:
  int64_t incrementValueBy(const std::string& key, int64_t delta);

  void validate();

  std::vector<uint8_t> doGet(const std::string& key);

  void doWait(
      c10::ArrayRef<std::string> keys,
      std::chrono::milliseconds timeout);

  // Endpoint reported by getHost() / getPort().
  detail::SocketAddress addr_;
  std::shared_ptr<detail::TCPServer> server_;
  std::unique_ptr<detail::TCPClient> client_;
  std::optional<std::size_t> numWorkers_;

  // NOTE(review): presumably key prefixes applied by the implementation to
  // bootstrap keys and user keys respectively -- confirm in TCPStore.cpp.
  const std::string initKey_ = "init/";
  const std::string keyPrefix_ = "/";

  // NOTE(review): presumably guards in-flight client operations -- confirm.
  std::mutex activeOpLock_;
  std::unordered_map<std::string, detail::Counter> clientCounters_;
  bool usingLibUv_ = true;
};
} // namespace c10d