This reverts commit d9c3485146913324ab4b3e211d2a4517e138f4af. Reapplies #125978.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126695
Approved by: https://github.com/c-p-i-o
commit ac51920656 (parent d8f5627a88)
committed by PyTorch MergeBot
@@ -772,7 +772,7 @@ cc_library(
         [
             "torch/*.h",
             "torch/csrc/**/*.h",
-            "torch/csrc/distributed/c10d/*.hpp",
+            "torch/csrc/distributed/c10d/**/*.hpp",
             "torch/lib/libshm/*.h",
         ],
         exclude = [
@@ -487,6 +487,7 @@ libtorch_core_sources = sorted(
 # These files are the only ones that are supported on Windows.
 libtorch_distributed_base_sources = [
     "torch/csrc/distributed/c10d/Backend.cpp",
+    "torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp",
     "torch/csrc/distributed/c10d/FileStore.cpp",
     "torch/csrc/distributed/c10d/Functional.cpp",
     "torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
test/distributed/test_control_collectives.py (new file, 189 lines)
@@ -0,0 +1,189 @@
# Owner(s): ["oncall: distributed"]

from datetime import timedelta
from multiprocessing.pool import ThreadPool

import torch
import torch.distributed as dist
from torch.testing._internal.common_utils import run_tests, TestCase


class TestCollectives(TestCase):
    def test_barrier(self) -> None:
        store = dist.HashStore()

        world_size = 2

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            collectives.barrier("foo", timedelta(seconds=10), True)

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_broadcast(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                collectives.broadcast_send("foo", b"data", timeout)
            else:
                out = collectives.broadcast_recv("foo", timeout)
                self.assertEqual(out, b"data")

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_gather(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                out = collectives.gather_recv("foo", str(rank), timeout)
                self.assertEqual(out, [b"0", b"1", b"2", b"3"])
            else:
                collectives.gather_send("foo", str(rank), timeout)

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_scatter(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                out = collectives.scatter_send(
                    "foo", [str(i) for i in range(world_size)], timeout
                )
            else:
                out = collectives.scatter_recv("foo", timeout)

            self.assertEqual(out, str(rank).encode())

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_all_sum(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            out = collectives.all_sum("foo", rank, timeout)
            self.assertEqual(out, sum(range(world_size)))

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_broadcast_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(Exception, "Wait timeout"):
            collectives.broadcast_recv("foo", timeout)

    def test_gather_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "gather failed -- missing ranks: 0, 2, 3"
        ):
            collectives.gather_recv("foo", "data", timeout)

    def test_scatter_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(Exception, "Wait timeout"):
            collectives.scatter_recv("foo", timeout)

    def test_all_gather_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "all_gather failed -- missing ranks: 0, 2, 3"
        ):
            collectives.all_gather("foo", "data", timeout)

    def test_barrier_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "barrier failed -- missing ranks: 0, 2, 3"
        ):
            collectives.barrier("foo", timeout, True)

    def test_all_sum_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "barrier failed -- missing ranks: 0, 2, 3"
        ):
            collectives.all_sum("foo", 1, timeout)

    def test_unique(self) -> None:
        store = dist.HashStore()

        collectives = dist._StoreCollectives(store, 1, 1)
        collectives.broadcast_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.broadcast_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.broadcast_recv("foo")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.gather_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.gather_recv("foo", "asdf")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.scatter_send("foo", ["asdf"])

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.scatter_recv("foo")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.all_gather("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.all_sum("foo", 2)


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()
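The tests above drive every rank from a ThreadPool over a single in-process HashStore. As a point of reference, here is a hedged sketch of the same API in its intended multi-process setting; the host, port, and world size are placeholders, and any c10d store (TCPStore, FileStore, HashStore) should work as the backing store.

# Hypothetical launcher: run one copy of main() per rank.
from datetime import timedelta

import torch.distributed as dist


def main(rank: int, world_size: int) -> None:
    # Placeholder address; rank 0 hosts the store.
    store = dist.TCPStore("127.0.0.1", 29500, world_size, is_master=(rank == 0))
    collectives = dist._StoreCollectives(store, rank, world_size)

    # Everyone waits here until all ranks have arrived.
    collectives.barrier("startup", timedelta(seconds=60), True)

    # Rank 0 pushes a payload; the rest pick it up as bytes.
    if rank == 0:
        collectives.broadcast_send("config", b"payload", timedelta(seconds=60))
    else:
        config = collectives.broadcast_recv("config", timedelta(seconds=60))
        assert config == b"payload"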
@@ -210,6 +210,20 @@ class PrefixStore(Store):
     @property
     def underlying_store(self) -> Store: ...
 
+class _ControlCollectives:
+    def barrier(self, key: str, timeout: timedelta, blocking: bool) -> None: ...
+    def broadcast_send(self, key: str, data: str, timeout: timedelta) -> None: ...
+    def broadcast_recv(self, key: str, timeout: timedelta) -> str: ...
+    def gather_send(self, key: str, data: str, timeout: timedelta) -> None: ...
+    def gather_recv(self, key: str, timeout: timedelta) -> str: ...
+    def scatter_send(self, key: str, data: str, timeout: timedelta) -> None: ...
+    def scatter_recv(self, key: str, timeout: timedelta) -> str: ...
+    def all_gather(self, key: str, data: str, timeout: timedelta) -> str: ...
+    def all_sum(self, key: str, data: str, timeout: timedelta) -> int: ...
+
+class _StoreCollectives(_ControlCollectives):
+    def __init__(self, store: Store, rank: int, world_size: int) -> None: ...
+
 class _DistributedBackendOptions:
     def __init__(self): ...
     @property
@@ -22,7 +22,7 @@ class TORCH_API HashStore : public Store {
   std::vector<uint8_t> get(const std::string& key) override;
 
   void wait(const std::vector<std::string>& keys) override {
-    wait(keys, Store::kDefaultTimeout);
+    wait(keys, timeout_);
   }
 
   void wait(
@@ -97,4 +97,33 @@ class TORCH_API Store : public torch::CustomClassHolder {
   std::chrono::milliseconds timeout_;
 };
 
+/*
+StoreTimeoutGuard is a RAII guard that sets the store timeout and restores
+the previous value when the guard goes out of scope.
+*/
+class StoreTimeoutGuard {
+ public:
+  explicit StoreTimeoutGuard(
+      Store& store,
+      const std::chrono::milliseconds& timeout)
+      : store_(store) {
+    oldTimeout_ = store.getTimeout();
+    store.setTimeout(timeout);
+  }
+
+  ~StoreTimeoutGuard() {
+    store_.setTimeout(oldTimeout_);
+  }
+
+  /* Disabling copy and move semantics */
+  StoreTimeoutGuard(const StoreTimeoutGuard&) = delete;
+  StoreTimeoutGuard& operator=(const StoreTimeoutGuard&) = delete;
+  StoreTimeoutGuard(StoreTimeoutGuard&&) = delete;
+  StoreTimeoutGuard& operator=(StoreTimeoutGuard&&) = delete;
+
+ private:
+  Store& store_;
+  std::chrono::milliseconds oldTimeout_;
+};
+
 } // namespace c10d
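StoreTimeoutGuard is C++-only. Scripts that drive a Store from Python can approximate the same set-and-restore pattern with a context manager; this is a sketch, not part of the commit, and it assumes the Python Store binding exposes a readable timeout property alongside set_timeout.

from contextlib import contextmanager
from datetime import timedelta


@contextmanager
def store_timeout(store, timeout: timedelta):
    # Assumption: `store.timeout` is a readable property on this build.
    old_timeout = store.timeout
    store.set_timeout(timeout)
    try:
        yield store
    finally:
        # Mirror the guard's destructor: always restore the old timeout.
        store.set_timeout(old_timeout)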
torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp (new file, 59 lines)
@@ -0,0 +1,59 @@
#pragma once

#include <ATen/core/ivalue.h>
#include <chrono>
#include <cstdint>
#include <string>
#include <vector>

#include <c10/macros/Macros.h>
#include <torch/custom_class.h>

namespace c10d {

using namespace std::chrono_literals;

class TORCH_API ControlCollectives : public torch::CustomClassHolder {
 public:
  virtual void barrier(
      const std::string& key,
      std::chrono::milliseconds timeout = 5min,
      bool block = true) = 0;

  virtual void broadcastSend(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) = 0;
  virtual std::vector<uint8_t> broadcastRecv(
      const std::string& key,
      std::chrono::milliseconds timeout = 5min) = 0;

  virtual void gatherSend(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) = 0;
  virtual std::vector<std::vector<uint8_t>> gatherRecv(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) = 0;

  virtual std::vector<uint8_t> scatterSend(
      const std::string& key,
      const std::vector<std::vector<uint8_t>>& data,
      std::chrono::milliseconds timeout = 5min) = 0;
  virtual std::vector<uint8_t> scatterRecv(
      const std::string& key,
      std::chrono::milliseconds timeout = 5min) = 0;

  virtual std::vector<std::vector<uint8_t>> allGather(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) = 0;

  virtual int64_t allSum(
      const std::string& key,
      int64_t data,
      std::chrono::milliseconds timeout = 5min) = 0;
};

} // namespace c10d
torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp (new file, 222 lines)
@@ -0,0 +1,222 @@
#include <c10/util/Exception.h>
#include <fmt/format.h>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp>
#include <chrono>
#include <exception>
#include <vector>

namespace {
std::string getRankKey(const std::string& key, int rank) {
  return fmt::format("{}/{}", key, rank);
}
} // namespace

namespace c10d {

StoreCollectives::StoreCollectives(
    c10::intrusive_ptr<::c10d::Store> store,
    int rank,
    int worldSize)
    : store_(std::move(store)), rank_(rank), worldSize_(worldSize) {}

void StoreCollectives::barrier(
    const std::string& key,
    std::chrono::milliseconds timeout,
    bool blocking) {
  enforceUnique(key);
  StoreTimeoutGuard g{*store_, timeout};

  auto num_members_key = fmt::format("{}/num_members", key);
  auto last_members_key = fmt::format("{}/last_members", key);

  auto idx = store_->add(num_members_key, 1);
  store_->set(getRankKey(key, rank_), "joined");

  if (idx == worldSize_) {
    store_->set(last_members_key, "<val_ignored>");
  } else if (blocking) {
    try {
      store_->wait({last_members_key});
    } catch (const std::exception& e) {
      std::string msg = "barrier failed -- missing ranks: ";
      for (int i = 0; i < worldSize_; i++) {
        if (i == rank_) {
          continue;
        }
        auto rank_key = getRankKey(key, i);
        if (!store_->check({rank_key})) {
          msg += fmt::format("{}, ", i);
        }
      }
      throw std::runtime_error(msg + e.what());
    }
  }
}

void StoreCollectives::broadcastSend(
    const std::string& key,
    const std::vector<uint8_t>& data,
    std::chrono::milliseconds timeout) {
  enforceUnique(key);
  StoreTimeoutGuard g{*store_, timeout};

  store_->set(key, data);
}

std::vector<uint8_t> StoreCollectives::broadcastRecv(
    const std::string& key,
    std::chrono::milliseconds timeout) {
  enforceUnique(key);
  StoreTimeoutGuard g{*store_, timeout};

  return store_->get(key);
}

void StoreCollectives::gatherSend(
    const std::string& key,
    const std::vector<uint8_t>& data,
    std::chrono::milliseconds timeout) {
  enforceUnique(key);
  StoreTimeoutGuard g{*store_, timeout};

  auto rank_key = getRankKey(key, rank_);
  store_->set(rank_key, data);
}

std::vector<std::vector<uint8_t>> StoreCollectives::gatherRecv(
    const std::string& key,
    const std::vector<uint8_t>& data,
    std::chrono::milliseconds timeout) {
  enforceUnique(key);
  StoreTimeoutGuard g{*store_, timeout};

  std::vector<std::string> keys;
  keys.reserve(worldSize_);

  for (int i = 0; i < worldSize_; i++) {
    if (i == rank_) {
      continue;
    }
    auto rank_key = getRankKey(key, i);
    keys.emplace_back(rank_key);
  }

  std::vector<std::vector<uint8_t>> results;
  results.reserve(worldSize_);

  try {
    results = store_->multiGet(keys);
  } catch (const std::exception& e) {
    std::string msg = "gather failed -- missing ranks: ";
    for (int i = 0; i < worldSize_; i++) {
      if (i == rank_) {
        continue;
      }
      auto rank_key = getRankKey(key, i);
      if (!store_->check({rank_key})) {
        msg += fmt::format("{}, ", i);
      }
    }
    throw std::runtime_error(msg + e.what());
  }

  // insert local data
  results.insert(results.begin() + rank_, data);
  return results;
}

std::vector<uint8_t> StoreCollectives::scatterSend(
    const std::string& key,
    const std::vector<std::vector<uint8_t>>& data,
    std::chrono::milliseconds timeout) {
  enforceUnique(key);
  StoreTimeoutGuard g{*store_, timeout};

  std::vector<std::string> keys;
  keys.reserve(worldSize_);
  for (int i = 0; i < worldSize_; i++) {
    if (i == rank_) {
      continue;
    }
    auto rank_key = getRankKey(key, i);
    keys.emplace_back(rank_key);
  }
  auto local = data.at(rank_);

  std::vector<std::vector<uint8_t>> toSend{data};

  toSend.erase(toSend.begin() + rank_);

  store_->multiSet(keys, toSend);

  return local;
}

std::vector<uint8_t> StoreCollectives::scatterRecv(
    const std::string& key,
    std::chrono::milliseconds timeout) {
  enforceUnique(key);
  StoreTimeoutGuard g{*store_, timeout};

  auto rank_key = getRankKey(key, rank_);
  return store_->get(rank_key);
}

std::vector<std::vector<uint8_t>> StoreCollectives::allGather(
    const std::string& key,
    const std::vector<uint8_t>& data,
    std::chrono::milliseconds timeout) {
  enforceUnique(key);
  StoreTimeoutGuard g{*store_, timeout};

  auto localKey = getRankKey(key, rank_);
  store_->set(localKey, data);

  std::vector<std::string> keys;
  keys.reserve(worldSize_);

  for (int i = 0; i < worldSize_; i++) {
    auto rank_key = getRankKey(key, i);
    keys.emplace_back(rank_key);
  }

  try {
    return store_->multiGet(keys);
  } catch (const std::exception& e) {
    std::string msg = "all_gather failed -- missing ranks: ";
    for (int i = 0; i < worldSize_; i++) {
      if (i == rank_) {
        continue;
      }
      auto rank_key = getRankKey(key, i);
      if (!store_->check({rank_key})) {
        msg += fmt::format("{}, ", i);
      }
    }
    throw std::runtime_error(msg + e.what());
  }
}

int64_t StoreCollectives::allSum(
    const std::string& key,
    int64_t value,
    std::chrono::milliseconds timeout) {
  enforceUnique(key);
  StoreTimeoutGuard g{*store_, timeout};

  store_->add(key, value);

  barrier(key + "/barrier", timeout);

  return store_->add(key, 0);
}

void StoreCollectives::enforceUnique(const std::string& key) {
  auto it = seenKeys_.find(key);
  TORCH_INTERNAL_ASSERT(
      it == seenKeys_.end(), "Key ", key, " has already been used.");
  seenKeys_.emplace(key);
}

} // namespace c10d
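The barrier above needs only three Store primitives: an atomic counter (add), a sentinel key (set), and a blocking wait. For readers who want the protocol without the C++, a rough Python transliteration using the public Store API follows; the per-rank error reporting of missing ranks is omitted.

def store_barrier(store, key: str, rank: int, world_size: int) -> None:
    # Every rank atomically bumps a shared counter and records itself.
    idx = store.add(f"{key}/num_members", 1)
    store.set(f"{key}/{rank}", "joined")

    if idx == world_size:
        # The last arrival releases everyone by setting the sentinel key.
        store.set(f"{key}/last_members", "<val_ignored>")
    else:
        # Earlier arrivals block until the sentinel appears (or time out).
        store.wait([f"{key}/last_members"])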
torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp (new file, 68 lines)
@@ -0,0 +1,68 @@
#pragma once

#include <c10/macros/Macros.h>
#include <c10/util/FbcodeMaps.h>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp>

namespace c10d {

class TORCH_API StoreCollectives : public ControlCollectives {
 public:
  explicit StoreCollectives(
      c10::intrusive_ptr<Store> store,
      int rank,
      int worldSize);

  void barrier(
      const std::string& key,
      std::chrono::milliseconds timeout = 5min,
      bool block = true) override;

  void broadcastSend(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) override;
  std::vector<uint8_t> broadcastRecv(
      const std::string& key,
      std::chrono::milliseconds timeout = 5min) override;

  void gatherSend(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) override;
  std::vector<std::vector<uint8_t>> gatherRecv(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) override;

  std::vector<uint8_t> scatterSend(
      const std::string& key,
      const std::vector<std::vector<uint8_t>>& data,
      std::chrono::milliseconds timeout = 5min) override;
  std::vector<uint8_t> scatterRecv(
      const std::string& key,
      std::chrono::milliseconds timeout = 5min) override;

  std::vector<std::vector<uint8_t>> allGather(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) override;

  int64_t allSum(
      const std::string& key,
      int64_t data,
      std::chrono::milliseconds timeout = 5min) override;

 private:
  void enforceUnique(const std::string& key);

 private:
  c10::intrusive_ptr<Store> store_;
  int rank_;
  int worldSize_;

  c10::FastSet<std::string> seenKeys_{};
};

} // namespace c10d
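Because enforceUnique makes every key single-use per StoreCollectives instance, code that runs a collective repeatedly must mint a fresh key each time. A small single-process sketch of one plausible keying scheme:

from datetime import timedelta

import torch.distributed as dist

store = dist.HashStore()
collectives = dist._StoreCollectives(store, 0, 1)

# Namespacing keys by step keeps every key unique, so the
# "Key ... has already been used" check never fires.
for step in range(3):
    total = collectives.all_sum(f"step_{step}/counter", 1, timedelta(seconds=10))
    assert total == 1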
@@ -6,6 +6,9 @@
 #include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
 #include <torch/csrc/distributed/c10d/TCPStore.hpp>
 #include <torch/csrc/distributed/c10d/Utils.hpp>
+#include <torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp>
+#include <torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp>
+#include <vector>
 #ifndef _WIN32
 #include <torch/csrc/distributed/c10d/HashStore.hpp>
 #include <torch/csrc/distributed/c10d/ProcessGroupRoundRobin.hpp>
@@ -136,6 +139,34 @@ namespace torch::distributed::c10d {
 
 namespace {
 
+py::bytes toPyBytes(const std::vector<uint8_t>& data) {
+  return py::bytes(reinterpret_cast<const char*>(data.data()), data.size());
+}
+
+std::vector<py::bytes> toPyBytes(
+    const std::vector<std::vector<uint8_t>>& data) {
+  std::vector<py::bytes> out;
+  out.reserve(data.size());
+  for (const std::vector<uint8_t>& data_ : data) {
+    out.emplace_back(reinterpret_cast<const char*>(data_.data()), data_.size());
+  }
+  return out;
+}
+
+std::vector<uint8_t> toVec8(const std::string& data) {
+  std::vector<uint8_t> out{data.begin(), data.end()};
+  return out;
+}
+
+std::vector<std::vector<uint8_t>> toVec8(const std::vector<std::string>& data) {
+  std::vector<std::vector<uint8_t>> out;
+  out.reserve(data.size());
+  for (auto& data_ : data) {
+    out.emplace_back(toVec8(data_));
+  }
+  return out;
+}
+
 template <typename T>
 using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
 
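These helpers define the Python-facing types: str data arguments are widened into std::vector<uint8_t> by toVec8, and every result travels back through toPyBytes, so receivers always see bytes. A hedged single-process illustration, with two instances standing in for two ranks:

from datetime import timedelta

import torch.distributed as dist

store = dist.HashStore()
sender = dist._StoreCollectives(store, 0, 2)
receiver = dist._StoreCollectives(store, 1, 2)

# A str goes in, but bytes always come out; decode when you need text.
sender.broadcast_send("cfg", "hello", timedelta(seconds=10))
out = receiver.broadcast_recv("cfg", timedelta(seconds=10))
assert out == b"hello" and out.decode() == "hello"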
@@ -166,8 +197,7 @@ class PythonStore : public ::c10d::Store {
         pybind11::get_overload(static_cast<const ::c10d::Store*>(this), "set");
     TORCH_INTERNAL_ASSERT(fn, "Not implemented.");
     // Call function with a py::bytes object for the value.
-    fn(key,
-       py::bytes(reinterpret_cast<const char*>(value.data()), value.size()));
+    fn(key, toPyBytes(value));
   }
 
   // Note: this function manually calls the Python-side overload
@@ -184,7 +214,7 @@ class PythonStore : public ::c10d::Store {
     // std::vector<uint8_t>. There is no API for directly accessing
     // the contents of the py::bytes object.
     std::string str = pybind11::cast<py::bytes>(fn(key));
-    return std::vector<uint8_t>(str.begin(), str.end());
+    return toVec8(str);
   }
 
   // Note: this function manually calls the Python-side overload
@@ -204,14 +234,8 @@ class PythonStore : public ::c10d::Store {
     // std::vector<uint8_t>. There is no API for directly accessing
     // the contents of the py::bytes object.
     std::string str = pybind11::cast<py::bytes>(
-        fn(key,
-           py::bytes(
-               reinterpret_cast<const char*>(expectedValue.data()),
-               expectedValue.size()),
-           py::bytes(
-               reinterpret_cast<const char*>(desiredValue.data()),
-               desiredValue.size())));
-    return std::vector<uint8_t>(str.begin(), str.end());
+        fn(key, toPyBytes(expectedValue), toPyBytes(desiredValue)));
+    return toVec8(str);
   }
 
   int64_t add(const std::string& key, int64_t value) override {
@@ -253,8 +277,7 @@ class PythonStore : public ::c10d::Store {
       return Store::append(key, value);
     }
     // Call function with a py::bytes object for the value.
-    fn(key,
-       py::bytes(reinterpret_cast<const char*>(value.data()), value.size()));
+    fn(key, toPyBytes(value));
   }
 
   std::vector<std::vector<uint8_t>> multiGet(
@@ -287,14 +310,7 @@ class PythonStore : public ::c10d::Store {
       return Store::multiSet(keys, values);
     }
 
-    std::vector<py::bytes> bytes;
-    bytes.reserve(values.size());
-    for (auto& value : values) {
-      bytes.emplace_back(
-          reinterpret_cast<const char*>(value.data()), value.size());
-    }
-
-    fn(keys, bytes);
+    fn(keys, toPyBytes(values));
   }
 
   bool hasExtendedApi() const override {
@@ -973,10 +989,7 @@ and :class:`~torch.distributed.HashStore`).
           "set",
           [](::c10d::Store& store,
              const std::string& key,
-             const std::string& value) {
-            std::vector<uint8_t> value_(value.begin(), value.end());
-            store.set(key, value_);
-          },
+             const std::string& value) { store.set(key, toVec8(value)); },
           py::call_guard<py::gil_scoped_release>(),
           R"(
 Inserts the key-value pair into the store based on the supplied ``key`` and
@@ -1001,14 +1014,9 @@ Example::
              const std::string& key,
              const std::string& expected_value,
              const std::string& desired_value) -> py::bytes {
-            std::vector<uint8_t> expectedValue_(
-                expected_value.begin(), expected_value.end());
-            std::vector<uint8_t> desiredValue_(
-                desired_value.begin(), desired_value.end());
-            auto value =
-                store.compareSet(key, expectedValue_, desiredValue_);
-            return py::bytes(
-                reinterpret_cast<char*>(value.data()), value.size());
+            auto value = store.compareSet(
+                key, toVec8(expected_value), toVec8(desired_value));
+            return toPyBytes(value);
           },
           py::call_guard<py::gil_scoped_release>(),
           R"(
@@ -1040,8 +1048,7 @@ Example::
              py::gil_scoped_release guard;
              return store.get(key);
            }();
-            return py::bytes(
-                reinterpret_cast<char*>(value.data()), value.size());
+            return toPyBytes(value);
           },
           R"(
 Retrieves the value associated with the given ``key`` in the store. If ``key`` is not
@@ -1240,8 +1247,7 @@ Example::
           [](::c10d::Store& store,
              const std::string& key,
              const std::string& value) {
-            std::vector<uint8_t> value_(value.begin(), value.end());
-            store.append(key, value_);
+            store.append(key, toVec8(value));
           },
           py::call_guard<py::gil_scoped_release>(),
           R"(
@@ -1268,14 +1274,7 @@ Example::
              py::gil_scoped_release guard;
              return store.multiGet(keys);
            }();
-            std::vector<py::bytes> res;
-            for (auto& value : values) {
-              auto bytes = py::bytes(
-                  reinterpret_cast<const char*>(value.data()),
-                  value.size());
-              res.push_back(bytes);
-            }
-            return res;
+            return toPyBytes(values);
           },
           R"(
 Retrieve all values in ``keys``. If any key in ``keys`` is not
@@ -1298,12 +1297,7 @@ Example::
           [](::c10d::Store& store,
              const std::vector<std::string>& keys,
              const std::vector<std::string>& values) {
-            std::vector<std::vector<uint8_t>> vals;
-            vals.reserve(values.size());
-            for (auto& value : values) {
-              vals.emplace_back(value.begin(), value.end());
-            }
-            store.multiSet(keys, vals);
+            store.multiSet(keys, toVec8(values));
           },
           py::call_guard<py::gil_scoped_release>(),
           R"(
@@ -1487,6 +1481,212 @@ Arguments:
           &::c10d::PrefixStore::getUnderlyingNonPrefixStore,
           R"(Recursively to get the store before layers of wrapping with PrefixStore.)");
 
+  using namespace std::chrono_literals;
+
+  auto collectives =
+      py::class_<
+          ::c10d::ControlCollectives,
+          c10::intrusive_ptr<::c10d::ControlCollectives>>(
+          module,
+          "_ControlCollectives",
+          R"(
+Base class for all ControlCollectives implementations.
+)")
+          .def(
+              "barrier",
+              &::c10d::ControlCollectives::barrier,
+              py::arg("key"),
+              py::arg("timeout") = 5min,
+              py::arg("block") = true,
+              py::call_guard<py::gil_scoped_release>(),
+              R"(
+Blocks until all workers have entered this function.
+
+Arguments:
+    key (str): The unique key used to identify this operation.
+    timeout (duration): The timeout for this operation.
+    block (bool): whether to block this worker while waiting on the results of the barrier.
+)")
+          .def(
+              "all_sum",
+              &::c10d::ControlCollectives::allSum,
+              py::arg("key"),
+              py::arg("data"),
+              py::arg("timeout") = 5min,
+              py::call_guard<py::gil_scoped_release>(),
+              R"(
+Computes a sum across all workers and returns the final value.
+
+Arguments:
+    key (str): The unique key used to identify this operation.
+    data (int): The data to sum.
+    timeout (duration): The timeout for this operation.
+)")
+          .def(
+              "broadcast_send",
+              [](::c10d::ControlCollectives& collectives,
+                 const std::string& key,
+                 const std::string& data,
+                 std::chrono::milliseconds timeout = 5min) {
+                collectives.broadcastSend(key, toVec8(data), timeout);
+              },
+              py::arg("key"),
+              py::arg("data"),
+              py::arg("timeout") = 5min,
+              py::call_guard<py::gil_scoped_release>(),
+              R"(
+Sends data to all other workers. Must only be called from one worker.
+
+Arguments:
+    key (str): The unique key used to identify this operation.
+    data (str): The data to send.
+    timeout (duration): The timeout for this operation.
+)")
+          .def(
+              "broadcast_recv",
+              [](::c10d::ControlCollectives& collectives,
+                 const std::string& key,
+                 std::chrono::milliseconds timeout = 5min) {
+                auto out = [&]() {
+                  py::gil_scoped_release guard;
+                  return collectives.broadcastRecv(key, timeout);
+                }();
+                return toPyBytes(out);
+              },
+              py::arg("key"),
+              py::arg("timeout") = 5min,
+              R"(
+Receives data broadcast from one worker.
+
+Arguments:
+    key (str): The unique key used to identify this operation.
+    timeout (duration): The timeout for this operation.
+)")
+          .def(
+              "gather_send",
+              [](::c10d::ControlCollectives& collectives,
+                 const std::string& key,
+                 const std::string& data,
+                 std::chrono::milliseconds timeout = 5min) {
+                collectives.gatherSend(key, toVec8(data), timeout);
+              },
+              py::arg("key"),
+              py::arg("data"),
+              py::arg("timeout") = 5min,
+              py::call_guard<py::gil_scoped_release>(),
+              R"(
+Sends data to one other worker.
+
+Arguments:
+    key (str): The unique key used to identify this operation.
+    data (str): The data to send.
+    timeout (duration): The timeout for this operation.
+)")
+          .def(
+              "gather_recv",
+              [](::c10d::ControlCollectives& collectives,
+                 const std::string& key,
+                 const std::string& data,
+                 std::chrono::milliseconds timeout = 5min) {
+                auto out = [&]() {
+                  py::gil_scoped_release guard;
+                  return collectives.gatherRecv(key, toVec8(data), timeout);
+                }();
+                return toPyBytes(out);
+              },
+              py::arg("key"),
+              py::arg("data"),
+              py::arg("timeout") = 5min,
+              R"(
+Receives data sent from all other workers. Must only be called by one worker.
+
+Arguments:
+    key (str): The unique key used to identify this operation.
+    data (str): This worker's own contribution to the gathered result.
+    timeout (duration): The timeout for this operation.
+)")
+
+          .def(
+              "scatter_send",
+              [](::c10d::ControlCollectives& collectives,
+                 const std::string& key,
+                 const std::vector<std::string>& data,
+                 std::chrono::milliseconds timeout = 5min) {
+                auto out = [&]() {
+                  py::gil_scoped_release guard;
+                  return collectives.scatterSend(key, toVec8(data), timeout);
+                }();
+                return toPyBytes(out);
+              },
+              py::arg("key"),
+              py::arg("data"),
+              py::arg("timeout") = 5min,
+              R"(
+Sends rank-specific data to all other workers.
+
+Arguments:
+    key (str): The unique key used to identify this operation.
+    data (list[str]): The per-rank values to send; entry i goes to rank i.
+    timeout (duration): The timeout for this operation.
+)")
+          .def(
+              "scatter_recv",
+              [](::c10d::ControlCollectives& collectives,
+                 const std::string& key,
+                 std::chrono::milliseconds timeout = 5min) {
+                auto out = [&]() {
+                  py::gil_scoped_release guard;
+                  return collectives.scatterRecv(key, timeout);
+                }();
+                return toPyBytes(out);
+              },
+              py::arg("key"),
+              py::arg("timeout") = 5min,
+              R"(
+Receives rank-specific data from one worker.
+
+Arguments:
+    key (str): The unique key used to identify this operation.
+    timeout (duration): The timeout for this operation.
+)")
+
+          .def(
+              "all_gather",
+              [](::c10d::ControlCollectives& collectives,
+                 const std::string& key,
+                 const std::string& data,
+                 std::chrono::milliseconds timeout = 5min) {
+                auto out = [&]() {
+                  py::gil_scoped_release guard;
+                  return collectives.allGather(key, toVec8(data), timeout);
+                }();
+                return toPyBytes(out);
+              },
+              py::arg("key"),
+              py::arg("data"),
+              py::arg("timeout") = 5min,
+              R"(
+Sends data to all workers and receives data from all other workers.
+
+Arguments:
+    key (str): The unique key used to identify this operation.
+    data (str): The data to send.
+    timeout (duration): The timeout for this operation.
+)");
+
+  intrusive_ptr_class_<::c10d::StoreCollectives>(
+      module,
+      "_StoreCollectives",
+      collectives,
+      R"(
+An implementation of ControlCollectives that uses the provided store as the underlying
+communication mechanism.
+)")
+      .def(
+          py::init<c10::intrusive_ptr<::c10d::Store>, int, int>(),
+          py::arg("store"),
+          py::arg("rank"),
+          py::arg("world_size"));
+
   auto processGroup =
       py::class_<
           ::c10d::ProcessGroup,
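Every binding above defaults timeout to five minutes (py::arg("timeout") = 5min), and barrier additionally defaults block to True, so the common path needs only a key and the data. A minimal single-process sketch leaning on those defaults:

import torch.distributed as dist

store = dist.HashStore()
collectives = dist._StoreCollectives(store, 0, 1)

collectives.barrier("setup")             # timeout=5min, block=True implied
total = collectives.all_sum("epoch", 3)  # timeout=5min implied
assert total == 3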
@@ -54,6 +54,8 @@ if is_available():
         set_debug_level,
         set_debug_level_from_env,
         _make_nccl_premul_sum,
+        _ControlCollectives,
+        _StoreCollectives,
     )
 
     class _DistributedPdb(pdb.Pdb):
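With the re-export above, both names import directly from torch.distributed; a quick smoke check of the new surface:

from torch.distributed import _ControlCollectives, _StoreCollectives

# _StoreCollectives is registered with _ControlCollectives as its base class.
assert issubclass(_StoreCollectives, _ControlCollectives)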