pytorch/torch/csrc/distributed/c10d/TCPStore.hpp
Xilun Wu 85758fa5ae [c10d][TCPStore] make TCPStore server use libuv by default (#127957)
**Summary**
This PR switches the default TCPStore server backend to a new implementation that utilizes [`libuv`](https://github.com/libuv/libuv) for significantly lower initialization time and better scalability:
<img width="714" alt="image" src="https://github.com/pytorch/pytorch/assets/12968408/18503011-da5d-4104-8ba9-abc456438b02">

We hope this improvement gives users a much shorter startup time in large-scale jobs. Eventually, we hope to fully replace the old TCPStore backend implementation with the libuv one.

**What it changes**
This PR changes the underlying TCPStore server backend to `libuv` unless users explicitly request the old TCPStore server. The change should not be noticeable to users apart from significantly faster TCPStore startup for large-scale jobs.

One thing to note: the libuv backend does not yet support the initialization approach where the user passes in a socket. We plan to support it as a next step, but we chose to disable it until it is fully tested. If you initialize TCPStore this way, see the next section for how to keep using the old TCPStore server.

**Fallback/Remain using the old TCPStore server**
For users who want to stay with the old TCPStore backend, there are 3 ways:

1. If you instantiate a TCPStore object directly, pass the argument `use_libuv=False` to use the old TCPStore server backend, e.g. `store = torch.distributed.TCPStore(..., use_libuv=False)`.
2. Or, specify the TCPStore backend option in `init_method` when calling the default ProcessGroup init, e.g. `torch.distributed.init_process_group(..., init_method="{YOUR_RENDEZVOUS_METHOD}://{YOUR_HOSTNAME}:{YOUR_PORT}?use_libuv=0")`.
3. Or, set the environment variable `USE_LIBUV` to `"0"` when launching.

These 3 approaches are listed in order of precedence. For example, if the user specifies `use_libuv=0` in `init_method` and also sets the environment variable `USE_LIBUV="1"`, the former takes effect and the TCPStore backend instantiated is the old one rather than the libuv one.
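
The precedence rule above can be sketched as a small resolution function. This is only an illustration written in C++ to match the header in this file; it is not the code PyTorch uses. The helper names `parseUseLibUvQuery` and `resolveUseLibUv` are hypothetical, and the sketch assumes that the value `"0"` selects the old backend while any other value selects libuv.

```cpp
#include <optional>
#include <sstream>
#include <string>

// Hypothetical helper: extracts `use_libuv` from the query part of an
// init_method URL such as "tcp://host:29500?use_libuv=0".
// Returns std::nullopt when the option is absent.
std::optional<bool> parseUseLibUvQuery(const std::string& initMethod) {
  const auto qpos = initMethod.find('?');
  if (qpos == std::string::npos) {
    return std::nullopt;
  }
  std::istringstream query(initMethod.substr(qpos + 1));
  const std::string key = "use_libuv=";
  std::string pair;
  while (std::getline(query, pair, '&')) {
    if (pair.compare(0, key.size(), key) == 0) {
      // Assumption: "0" means the old backend, anything else means libuv.
      return pair.substr(key.size()) != "0";
    }
  }
  return std::nullopt;
}

// Hypothetical resolver that applies the three fallbacks in the order
// listed above: explicit argument, then init_method, then USE_LIBUV.
bool resolveUseLibUv(
    std::optional<bool> ctorArg,
    const std::string& initMethod,
    const char* envValue) {
  if (ctorArg.has_value()) {
    return *ctorArg;
  }
  if (const auto fromQuery = parseUseLibUvQuery(initMethod)) {
    return *fromQuery;
  }
  if (envValue != nullptr) {
    return std::string(envValue) != "0";
  }
  return true; // libuv is the default after this PR
}
```

A caller could pass `std::getenv("USE_LIBUV")` as the last argument; with `init_method` ending in `?use_libuv=0` and `USE_LIBUV="1"`, the function returns `false`, matching the example in the paragraph above.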

**Operating Systems Compatibility**
From the CI signals, we believe the new implementation behaves the same as the old TCPStore server on all supported platforms. If you notice any behavior discrepancy, please file an issue with the `oncall: distributed` label.

**Test Plan**
`pytest test/distributed/test_store.py`
<img width="2548" alt="image" src="https://github.com/pytorch/pytorch/assets/12968408/dc0aebeb-6d5a-4daa-b98c-e56bd39aa588">
note: `TestMultiThreadedWait::test_wait` is a broken test that has been there for some time.

`test/distributed/elastic/utils/distributed_test.py`
<img width="2558" alt="image" src="https://github.com/pytorch/pytorch/assets/12968408/a6a3266d-b798-41c4-94d2-152056a034f6">

**TODO**
1. Update the doc at

- https://pytorch.org/docs/stable/distributed.html#distributed-key-value-store
- https://pytorch.org/docs/stable/distributed.html#tcp-initialization

2. Make torch elastic rendezvous use the libuv TCPStore as well. See `torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py`.
3. Test whether the libuv backend works with socket-based initialization. Change `LibUvTCPStoreTest::test_take_over_listen_socket`.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @kurman

Differential Revision: [D58259591](https://our.internmc.facebook.com/intern/diff/D58259591)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127957
Approved by: https://github.com/kurman
ghstack dependencies: #127956
2024-06-07 16:53:01 +00:00

165 lines
4.3 KiB
C++

#pragma once

#include <cstddef>
#include <cstdint>
#include <memory>

#include <torch/csrc/distributed/c10d/Store.hpp>

namespace c10d {
namespace detail {

class TCPServer;

class TCPClient;

struct SocketAddress {
  std::string host{};
  std::uint16_t port{};
};

class Counter {
 public:
  void update(double val);
  std::unordered_map<std::string, double> observe() const;

  double mean() const noexcept {
    return mean_;
  }
  int64_t count() const noexcept {
    return count_;
  }
  double variance() const noexcept {
    return m2_ / static_cast<double>(count_);
  }
  double sample_variance() const noexcept {
    return m2_ / static_cast<double>(count_ - 1);
  }

 private:
  int64_t count_ = 0;
  double mean_ = 0;
  double m2_ = 0;
};

} // namespace detail

struct TCPStoreOptions {
  static constexpr std::uint16_t kDefaultPort = 29500;

  std::uint16_t port = kDefaultPort;
  bool isServer = false;
  std::optional<std::size_t> numWorkers = c10::nullopt;
  bool waitWorkers = true;
  std::chrono::milliseconds timeout = Store::kDefaultTimeout;

  // A boolean value indicating whether multiple store instances can be
  // initialized with the same host:port pair.
  bool multiTenant = false;

  // If specified, and if isServer is true, the underlying TCPServer will take
  // over the bound socket associated to this fd. This option is useful to avoid
  // port assignment races in certain scenarios.
  std::optional<int> masterListenFd = c10::nullopt;

  // A boolean value indicating whether to use the experimental libUV backend.
  bool useLibUV = true;
};

class TORCH_API TCPStore : public Store {
 public:
  explicit TCPStore(std::string host, const TCPStoreOptions& opts = {});

  [[deprecated("Use TCPStore(host, opts) instead.")]] explicit TCPStore(
      const std::string& masterAddr,
      std::uint16_t masterPort,
      std::optional<int> numWorkers = c10::nullopt,
      bool isServer = false,
      const std::chrono::milliseconds& timeout = kDefaultTimeout,
      bool waitWorkers = true);

  ~TCPStore() override;

  void set(const std::string& key, const std::vector<uint8_t>& value) override;

  std::vector<uint8_t> compareSet(
      const std::string& key,
      const std::vector<uint8_t>& expectedValue,
      const std::vector<uint8_t>& desiredValue) override;

  std::vector<uint8_t> get(const std::string& key) override;

  int64_t add(const std::string& key, int64_t value) override;

  bool deleteKey(const std::string& key) override;

  bool check(const std::vector<std::string>& keys) override;

  int64_t getNumKeys() override;

  void wait(const std::vector<std::string>& keys) override;

  void wait(
      const std::vector<std::string>& keys,
      const std::chrono::milliseconds& timeout) override;

  void append(const std::string& key, const std::vector<uint8_t>& value)
      override;

  std::vector<std::vector<uint8_t>> multiGet(
      const std::vector<std::string>& keys) override;

  void multiSet(
      const std::vector<std::string>& keys,
      const std::vector<std::vector<uint8_t>>& values) override;

  bool hasExtendedApi() const override;

  // Waits for all workers to join.
  void waitForWorkers();

  // Returns the hostname used by the TCPStore.
  const std::string& getHost() const noexcept {
    return addr_.host;
  }

  // Returns the port used by the TCPStore.
  std::uint16_t getPort() const noexcept {
    return addr_.port;
  }

  std::unordered_map<std::string, std::unordered_map<std::string, double>>
  collectClientCounters() const noexcept;

  bool isLibUvBackend() const noexcept {
    return usingLibUv_;
  }

  // note(xilunwu): this function is only for internal testing
  void _splitSet(const std::string& key, const std::vector<uint8_t>& data);

 private:
  int64_t incrementValueBy(const std::string& key, int64_t delta);

  void validate();

  std::vector<uint8_t> doGet(const std::string& key);

  void doWait(
      c10::ArrayRef<std::string> keys,
      std::chrono::milliseconds timeout);

  detail::SocketAddress addr_;
  std::shared_ptr<detail::TCPServer> server_;
  std::unique_ptr<detail::TCPClient> client_;
  std::optional<std::size_t> numWorkers_;

  const std::string initKey_ = "init/";
  const std::string keyPrefix_ = "/";
  std::mutex activeOpLock_;
  std::unordered_map<std::string, detail::Counter> clientCounters_;
  bool usingLibUv_ = true;
};

} // namespace c10d
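
The `detail::Counter` class in this header keeps `count_`, `mean_`, and `m2_`, which is the state used by Welford's online algorithm for running mean and variance. The header only declares `update`, so its real body is not visible here. The following is a minimal sketch of how such an update could look, assuming Welford's method. The class name `WelfordCounter` is hypothetical, and the guards against small counts in the variance accessors are additions that the header above does not have.

```cpp
#include <cstdint>

// Hypothetical stand-alone counter; mirrors the members of detail::Counter
// but is not the PyTorch implementation.
class WelfordCounter {
 public:
  // Folds one observation into the running statistics (Welford's update).
  void update(double val) {
    ++count_;
    const double delta = val - mean_;
    mean_ += delta / static_cast<double>(count_);
    const double delta2 = val - mean_; // uses the already-updated mean
    m2_ += delta * delta2;
  }

  double mean() const noexcept {
    return mean_;
  }
  std::int64_t count() const noexcept {
    return count_;
  }
  // Population variance; returns 0 when no samples were seen (added guard).
  double variance() const noexcept {
    return count_ > 0 ? m2_ / static_cast<double>(count_) : 0.0;
  }
  // Sample variance; returns 0 for fewer than two samples (added guard).
  double sample_variance() const noexcept {
    return count_ > 1 ? m2_ / static_cast<double>(count_ - 1) : 0.0;
  }

 private:
  std::int64_t count_ = 0;
  double mean_ = 0;
  double m2_ = 0; // running sum of squared deviations from the mean
};
```

The single-pass form avoids storing every sample, which suits per-client counters that may be updated on every store operation.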