Reapply "c10d: add Collectives abstraction (#125978)" (#126695)

This reverts commit d9c3485146913324ab4b3e211d2a4517e138f4af.

Reapplies #125978.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126695
Approved by: https://github.com/c-p-i-o
This commit is contained in:
Tristan Rice
2024-05-21 18:00:09 +00:00
committed by PyTorch MergeBot
parent d8f5627a88
commit ac51920656
11 changed files with 837 additions and 53 deletions

View File

@ -772,7 +772,7 @@ cc_library(
[
"torch/*.h",
"torch/csrc/**/*.h",
"torch/csrc/distributed/c10d/*.hpp",
"torch/csrc/distributed/c10d/**/*.hpp",
"torch/lib/libshm/*.h",
],
exclude = [

View File

@ -487,6 +487,7 @@ libtorch_core_sources = sorted(
# These files are the only ones that are supported on Windows.
libtorch_distributed_base_sources = [
"torch/csrc/distributed/c10d/Backend.cpp",
"torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp",
"torch/csrc/distributed/c10d/FileStore.cpp",
"torch/csrc/distributed/c10d/Functional.cpp",
"torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",

View File

@ -0,0 +1,189 @@
# Owner(s): ["oncall: distributed"]
from datetime import timedelta
from multiprocessing.pool import ThreadPool
import torch
import torch.distributed as dist
from torch.testing._internal.common_utils import run_tests, TestCase
class TestCollectives(TestCase):
def test_barrier(self) -> None:
store = dist.HashStore()
world_size = 2
def f(rank: int) -> None:
collectives = dist._StoreCollectives(store, rank, world_size)
collectives.barrier("foo", timedelta(seconds=10), True)
with ThreadPool(world_size) as pool:
pool.map(f, range(world_size))
def test_broadcast(self) -> None:
store = dist.HashStore()
world_size = 4
timeout = timedelta(seconds=10)
def f(rank: int) -> None:
collectives = dist._StoreCollectives(store, rank, world_size)
if rank == 2:
collectives.broadcast_send("foo", b"data", timeout)
else:
out = collectives.broadcast_recv("foo", timeout)
self.assertEqual(out, b"data")
with ThreadPool(world_size) as pool:
pool.map(f, range(world_size))
def test_gather(self) -> None:
store = dist.HashStore()
world_size = 4
timeout = timedelta(seconds=10)
def f(rank: int) -> None:
collectives = dist._StoreCollectives(store, rank, world_size)
if rank == 2:
out = collectives.gather_recv("foo", str(rank), timeout)
self.assertEqual(out, [b"0", b"1", b"2", b"3"])
else:
collectives.gather_send("foo", str(rank), timeout)
with ThreadPool(world_size) as pool:
pool.map(f, range(world_size))
def test_scatter(self) -> None:
store = dist.HashStore()
world_size = 4
timeout = timedelta(seconds=10)
def f(rank: int) -> None:
collectives = dist._StoreCollectives(store, rank, world_size)
if rank == 2:
out = collectives.scatter_send(
"foo", [str(i) for i in range(world_size)], timeout
)
else:
out = collectives.scatter_recv("foo", timeout)
self.assertEqual(out, str(rank).encode())
with ThreadPool(world_size) as pool:
pool.map(f, range(world_size))
def test_all_sum(self) -> None:
store = dist.HashStore()
world_size = 4
timeout = timedelta(seconds=10)
def f(rank: int) -> None:
collectives = dist._StoreCollectives(store, rank, world_size)
out = collectives.all_sum("foo", rank, timeout)
self.assertEqual(out, sum(range(world_size)))
with ThreadPool(world_size) as pool:
pool.map(f, range(world_size))
def test_broadcast_timeout(self) -> None:
store = dist.HashStore()
world_size = 4
timeout = timedelta(milliseconds=1)
collectives = dist._StoreCollectives(store, 1, world_size)
with self.assertRaisesRegex(Exception, "Wait timeout"):
collectives.broadcast_recv("foo", timeout)
def test_gather_timeout(self) -> None:
store = dist.HashStore()
world_size = 4
timeout = timedelta(milliseconds=1)
collectives = dist._StoreCollectives(store, 1, world_size)
with self.assertRaisesRegex(
Exception, "gather failed -- missing ranks: 0, 2, 3"
):
collectives.gather_recv("foo", "data", timeout)
def test_scatter_timeout(self) -> None:
store = dist.HashStore()
world_size = 4
timeout = timedelta(milliseconds=1)
collectives = dist._StoreCollectives(store, 1, world_size)
with self.assertRaisesRegex(Exception, "Wait timeout"):
collectives.scatter_recv("foo", timeout)
def test_all_gather_timeout(self) -> None:
store = dist.HashStore()
world_size = 4
timeout = timedelta(milliseconds=1)
collectives = dist._StoreCollectives(store, 1, world_size)
with self.assertRaisesRegex(
Exception, "all_gather failed -- missing ranks: 0, 2, 3"
):
collectives.all_gather("foo", "data", timeout)
def test_barrier_timeout(self) -> None:
store = dist.HashStore()
world_size = 4
timeout = timedelta(milliseconds=1)
collectives = dist._StoreCollectives(store, 1, world_size)
with self.assertRaisesRegex(
Exception, "barrier failed -- missing ranks: 0, 2, 3"
):
collectives.barrier("foo", timeout, True)
def test_all_sum_timeout(self) -> None:
store = dist.HashStore()
world_size = 4
timeout = timedelta(milliseconds=1)
collectives = dist._StoreCollectives(store, 1, world_size)
with self.assertRaisesRegex(
Exception, "barrier failed -- missing ranks: 0, 2, 3"
):
collectives.all_sum("foo", 1, timeout)
def test_unique(self) -> None:
store = dist.HashStore()
collectives = dist._StoreCollectives(store, 1, 1)
collectives.broadcast_send("foo", "bar")
with self.assertRaisesRegex(Exception, "Key foo has already been used"):
collectives.broadcast_send("foo", "bar")
with self.assertRaisesRegex(Exception, "Key foo has already been used"):
collectives.broadcast_recv("foo")
with self.assertRaisesRegex(Exception, "Key foo has already been used"):
collectives.gather_send("foo", "bar")
with self.assertRaisesRegex(Exception, "Key foo has already been used"):
collectives.gather_recv("foo", "asdf")
with self.assertRaisesRegex(Exception, "Key foo has already been used"):
collectives.scatter_send("foo", ["asdf"])
with self.assertRaisesRegex(Exception, "Key foo has already been used"):
collectives.scatter_recv("foo")
with self.assertRaisesRegex(Exception, "Key foo has already been used"):
collectives.all_gather("foo", "bar")
with self.assertRaisesRegex(Exception, "Key foo has already been used"):
collectives.all_sum("foo", 2)
if __name__ == "__main__":
    # Refuse to run if CUDA has already been initialized in this process
    # before the test runner starts.
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"
    run_tests()

View File

@ -210,6 +210,20 @@ class PrefixStore(Store):
@property
def underlying_store(self) -> Store: ...
class _ControlCollectives:
    """Base class for all ControlCollectives implementations.

    Signatures mirror the pybind11 bindings: ``timeout`` (and ``block`` for
    ``barrier``) have defaults, receive-style calls return ``bytes``, and
    gather-style calls return one ``bytes`` entry per rank.
    """

    def barrier(self, key: str, timeout: timedelta = ..., block: bool = ...) -> None: ...
    def broadcast_send(self, key: str, data: str, timeout: timedelta = ...) -> None: ...
    def broadcast_recv(self, key: str, timeout: timedelta = ...) -> bytes: ...
    def gather_send(self, key: str, data: str, timeout: timedelta = ...) -> None: ...
    def gather_recv(self, key: str, data: str, timeout: timedelta = ...) -> list[bytes]: ...
    def scatter_send(self, key: str, data: list[str], timeout: timedelta = ...) -> bytes: ...
    def scatter_recv(self, key: str, timeout: timedelta = ...) -> bytes: ...
    def all_gather(self, key: str, data: str, timeout: timedelta = ...) -> list[bytes]: ...
    def all_sum(self, key: str, data: int, timeout: timedelta = ...) -> int: ...
class _StoreCollectives(_ControlCollectives):
    """ControlCollectives implementation that communicates through the given ``Store``."""

    def __init__(self, store: Store, rank: int, world_size: int) -> None: ...
class _DistributedBackendOptions:
def __init__(self): ...
@property

View File

@ -22,7 +22,7 @@ class TORCH_API HashStore : public Store {
std::vector<uint8_t> get(const std::string& key) override;
void wait(const std::vector<std::string>& keys) override {
wait(keys, Store::kDefaultTimeout);
wait(keys, timeout_);
}
void wait(

View File

@ -97,4 +97,33 @@ class TORCH_API Store : public torch::CustomClassHolder {
std::chrono::milliseconds timeout_;
};
/*
StoreTimeoutGuard is a scoped guard for a Store's timeout: the constructor
remembers the store's current timeout and installs the requested one, and the
destructor puts the remembered timeout back.

NOTE(review): the guard mutates the timeout on the Store object itself, so two
guards active on the same store from different threads can observe/restore each
other's values -- confirm callers do not rely on the restored value in that
case.
*/
class StoreTimeoutGuard {
 public:
  explicit StoreTimeoutGuard(
      Store& store,
      const std::chrono::milliseconds& timeout)
      : store_(store), oldTimeout_(store.getTimeout()) {
    store_.setTimeout(timeout);
  }

  ~StoreTimeoutGuard() {
    store_.setTimeout(oldTimeout_);
  }

  /* The guard is tied to one scope: neither copyable nor movable. */
  StoreTimeoutGuard(const StoreTimeoutGuard&) = delete;
  StoreTimeoutGuard& operator=(const StoreTimeoutGuard&) = delete;
  StoreTimeoutGuard(StoreTimeoutGuard&&) = delete;
  StoreTimeoutGuard& operator=(StoreTimeoutGuard&&) = delete;

 private:
  Store& store_; // store whose timeout is temporarily overridden
  std::chrono::milliseconds oldTimeout_; // value restored on destruction
};
} // namespace c10d

View File

@ -0,0 +1,59 @@
#pragma once
#include <ATen/core/ivalue.h>
#include <chrono>
#include <cstdint>
#include <string>
#include <vector>
#include <c10/macros/Macros.h>
#include <torch/custom_class.h>
namespace c10d {
using namespace std::chrono_literals;
// Abstract interface for control-plane collectives. Every operation is
// identified by a caller-supplied `key` and takes a `timeout` that defaults to
// five minutes.
class TORCH_API ControlCollectives : public torch::CustomClassHolder {
 public:
  // Blocks until all workers have entered this function. `block` controls
  // whether this worker waits on the result of the barrier.
  virtual void barrier(
      const std::string& key,
      std::chrono::milliseconds timeout = 5min,
      bool block = true) = 0;

  // Sends `data` to all other workers. Must only be called from one worker.
  virtual void broadcastSend(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) = 0;
  // Receives the data broadcast by one worker.
  virtual std::vector<uint8_t> broadcastRecv(
      const std::string& key,
      std::chrono::milliseconds timeout = 5min) = 0;

  // Sends `data` to the one worker that calls gatherRecv.
  virtual void gatherSend(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) = 0;
  // Receives the data sent by the other workers. Must only be called by one
  // worker; `data` is that worker's own contribution.
  virtual std::vector<std::vector<uint8_t>> gatherRecv(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) = 0;

  // Sends rank-specific data to all other workers.
  virtual std::vector<uint8_t> scatterSend(
      const std::string& key,
      const std::vector<std::vector<uint8_t>>& data,
      std::chrono::milliseconds timeout = 5min) = 0;
  // Receives this worker's rank-specific data from the scattering worker.
  virtual std::vector<uint8_t> scatterRecv(
      const std::string& key,
      std::chrono::milliseconds timeout = 5min) = 0;

  // Sends `data` to all workers and receives the data of all other workers.
  virtual std::vector<std::vector<uint8_t>> allGather(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) = 0;

  // Computes a sum across all workers and returns the final value.
  virtual int64_t allSum(
      const std::string& key,
      int64_t data,
      std::chrono::milliseconds timeout = 5min) = 0;
};
} // namespace c10d

View File

@ -0,0 +1,222 @@
#include <c10/util/Exception.h>
#include <fmt/format.h>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp>
#include <chrono>
#include <exception>
#include <vector>
namespace {
// Returns the store key used for `rank`'s entry under `key`: "<key>/<rank>".
std::string getRankKey(const std::string& key, int rank) {
  return fmt::format("{0}/{1}", key, rank);
}
} // namespace
namespace c10d {
StoreCollectives::StoreCollectives(
c10::intrusive_ptr<::c10d::Store> store,
int rank,
int worldSize)
: store_(std::move(store)), rank_(rank), worldSize_(worldSize) {}
// Barrier built on a shared counter: every caller increments
// "<key>/num_members" and marks its own rank key as joined. The caller that
// brings the counter to worldSize_ publishes "<key>/last_members"; all other
// callers wait for that key when `blocking` is true. If the wait throws, the
// error is rethrown as std::runtime_error listing the ranks whose join marker
// is absent, followed by the original message.
void StoreCollectives::barrier(
    const std::string& key,
    std::chrono::milliseconds timeout,
    bool blocking) {
  enforceUnique(key);
  StoreTimeoutGuard timeoutGuard{*store_, timeout};

  const auto numMembersKey = fmt::format("{}/num_members", key);
  const auto lastMembersKey = fmt::format("{}/last_members", key);

  const auto arrivalCount = store_->add(numMembersKey, 1);
  store_->set(getRankKey(key, rank_), "joined");

  if (arrivalCount == worldSize_) {
    // Last one in: release everybody waiting on the barrier.
    store_->set(lastMembersKey, "<val_ignored>");
    return;
  }
  if (!blocking) {
    return;
  }

  try {
    store_->wait({lastMembersKey});
  } catch (const std::exception& e) {
    std::string msg = "barrier failed -- missing ranks: ";
    for (int peer = 0; peer < worldSize_; ++peer) {
      if (peer != rank_ && !store_->check({getRankKey(key, peer)})) {
        msg += fmt::format("{}, ", peer);
      }
    }
    throw std::runtime_error(msg + e.what());
  }
}
// Publishes `data` under `key` in the store, with `timeout` applied to the
// store for the duration of the call.
void StoreCollectives::broadcastSend(
    const std::string& key,
    const std::vector<uint8_t>& data,
    std::chrono::milliseconds timeout) {
  enforceUnique(key);
  StoreTimeoutGuard timeoutGuard{*store_, timeout};
  store_->set(key, data);
}
// Reads the value stored under `key`, with `timeout` applied to the store for
// the duration of the call.
std::vector<uint8_t> StoreCollectives::broadcastRecv(
    const std::string& key,
    std::chrono::milliseconds timeout) {
  enforceUnique(key);
  StoreTimeoutGuard timeoutGuard{*store_, timeout};
  auto payload = store_->get(key);
  return payload;
}
// Publishes this rank's contribution under "<key>/<rank_>" so the gathering
// rank can collect it.
void StoreCollectives::gatherSend(
    const std::string& key,
    const std::vector<uint8_t>& data,
    std::chrono::milliseconds timeout) {
  enforceUnique(key);
  StoreTimeoutGuard timeoutGuard{*store_, timeout};
  store_->set(getRankKey(key, rank_), data);
}
std::vector<std::vector<uint8_t>> StoreCollectives::gatherRecv(
const std::string& key,
const std::vector<uint8_t>& data,
std::chrono::milliseconds timeout) {
enforceUnique(key);
StoreTimeoutGuard g{*store_, timeout};
std::vector<std::string> keys;
keys.reserve(worldSize_);
for (int i = 0; i < worldSize_; i++) {
if (i == rank_) {
continue;
}
auto rank_key = getRankKey(key, i);
keys.emplace_back(rank_key);
}
std::vector<std::vector<uint8_t>> results;
results.reserve(worldSize_);
try {
results = store_->multiGet(keys);
} catch (const std::exception& e) {
std::string msg = "gather failed -- missing ranks: ";
for (int i = 0; i < worldSize_; i++) {
if (i == rank_) {
continue;
}
auto rank_key = getRankKey(key, i);
if (!store_->check({rank_key})) {
msg += fmt::format("{}, ", i);
}
}
throw std::runtime_error(msg + e.what());
}
// insert local data
results.insert(results.begin() + rank_, data);
return results;
}
std::vector<uint8_t> StoreCollectives::scatterSend(
const std::string& key,
const std::vector<std::vector<uint8_t>>& data,
std::chrono::milliseconds timeout) {
enforceUnique(key);
StoreTimeoutGuard g{*store_, timeout};
std::vector<std::string> keys;
keys.reserve(worldSize_);
for (int i = 0; i < worldSize_; i++) {
if (i == rank_) {
continue;
}
auto rank_key = getRankKey(key, i);
keys.emplace_back(rank_key);
}
auto local = data.at(rank_);
std::vector<std::vector<uint8_t>> toSend{data};
toSend.erase(toSend.begin() + rank_);
store_->multiSet(keys, toSend);
return local;
}
// Fetches the payload stored for this rank under "<key>/<rank_>".
std::vector<uint8_t> StoreCollectives::scatterRecv(
    const std::string& key,
    std::chrono::milliseconds timeout) {
  enforceUnique(key);
  StoreTimeoutGuard timeoutGuard{*store_, timeout};
  return store_->get(getRankKey(key, rank_));
}
// Publishes this rank's `data` under "<key>/<rank_>" and then fetches the
// entries of all ranks (including our own) in rank order. If the fetch throws,
// the error is rethrown as std::runtime_error listing the other ranks whose
// key is absent, followed by the original message.
std::vector<std::vector<uint8_t>> StoreCollectives::allGather(
    const std::string& key,
    const std::vector<uint8_t>& data,
    std::chrono::milliseconds timeout) {
  enforceUnique(key);
  StoreTimeoutGuard timeoutGuard{*store_, timeout};

  store_->set(getRankKey(key, rank_), data);

  std::vector<std::string> allKeys;
  allKeys.reserve(worldSize_);
  for (int peer = 0; peer < worldSize_; ++peer) {
    allKeys.emplace_back(getRankKey(key, peer));
  }

  try {
    return store_->multiGet(allKeys);
  } catch (const std::exception& e) {
    std::string msg = "all_gather failed -- missing ranks: ";
    for (int peer = 0; peer < worldSize_; ++peer) {
      if (peer != rank_ && !store_->check({getRankKey(key, peer)})) {
        msg += fmt::format("{}, ", peer);
      }
    }
    throw std::runtime_error(msg + e.what());
  }
}
int64_t StoreCollectives::allSum(
const std::string& key,
int64_t value,
std::chrono::milliseconds timeout) {
enforceUnique(key);
StoreTimeoutGuard g{*store_, timeout};
store_->add(key, value);
barrier(key + "/barrier", timeout);
return store_->add(key, 0);
}
void StoreCollectives::enforceUnique(const std::string& key) {
auto it = seenKeys_.find(key);
TORCH_INTERNAL_ASSERT(
it == seenKeys_.end(), "Key ", key, " has already been used.");
seenKeys_.emplace(key);
}
} // namespace c10d

View File

@ -0,0 +1,68 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/FbcodeMaps.h>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp>
namespace c10d {
// An implementation of ControlCollectives that uses the provided store as the
// underlying communication mechanism.
class TORCH_API StoreCollectives : public ControlCollectives {
 public:
  // `rank` is this worker's index and `worldSize` the total number of workers.
  explicit StoreCollectives(
      c10::intrusive_ptr<Store> store,
      int rank,
      int worldSize);

  // Blocks until all workers have entered this function (when `block` is
  // true).
  void barrier(
      const std::string& key,
      std::chrono::milliseconds timeout = 5min,
      bool block = true) override;

  // Sends `data` to all other workers. Must only be called from one worker.
  void broadcastSend(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) override;
  // Receives the data broadcast by one worker.
  std::vector<uint8_t> broadcastRecv(
      const std::string& key,
      std::chrono::milliseconds timeout = 5min) override;

  // Sends `data` to the worker that calls gatherRecv.
  void gatherSend(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) override;
  // Receives the data of all other workers; `data` is the local contribution.
  std::vector<std::vector<uint8_t>> gatherRecv(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) override;

  // Sends rank-specific data to all other workers and returns the local entry.
  std::vector<uint8_t> scatterSend(
      const std::string& key,
      const std::vector<std::vector<uint8_t>>& data,
      std::chrono::milliseconds timeout = 5min) override;
  // Receives this worker's rank-specific data.
  std::vector<uint8_t> scatterRecv(
      const std::string& key,
      std::chrono::milliseconds timeout = 5min) override;

  // Sends `data` to all workers and receives the data of all workers.
  std::vector<std::vector<uint8_t>> allGather(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) override;

  // Computes a sum across all workers and returns the final value.
  int64_t allSum(
      const std::string& key,
      int64_t data,
      std::chrono::milliseconds timeout = 5min) override;

 private:
  // Fails if `key` was already used by this instance; otherwise records it.
  void enforceUnique(const std::string& key);

 private:
  c10::intrusive_ptr<Store> store_; // store used for all communication
  int rank_; // index of this worker
  int worldSize_; // total number of workers

  c10::FastSet<std::string> seenKeys_{}; // keys already consumed by a collective
};
} // namespace c10d

View File

@ -6,6 +6,9 @@
#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp>
#include <torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp>
#include <vector>
#ifndef _WIN32
#include <torch/csrc/distributed/c10d/HashStore.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupRoundRobin.hpp>
@ -136,6 +139,34 @@ namespace torch::distributed::c10d {
namespace {
py::bytes toPyBytes(const std::vector<uint8_t>& data) {
return py::bytes(reinterpret_cast<const char*>(data.data()), data.size());
}
std::vector<py::bytes> toPyBytes(
const std::vector<std::vector<uint8_t>>& data) {
std::vector<py::bytes> out;
out.reserve(data.size());
for (const std::vector<uint8_t>& data_ : data) {
out.emplace_back(reinterpret_cast<const char*>(data_.data()), data_.size());
}
return out;
}
// Copies the characters of `data` into a byte vector.
std::vector<uint8_t> toVec8(const std::string& data) {
  return std::vector<uint8_t>(data.begin(), data.end());
}

// Converts every string in `data` into a byte vector, preserving order.
std::vector<std::vector<uint8_t>> toVec8(const std::vector<std::string>& data) {
  std::vector<std::vector<uint8_t>> converted;
  converted.reserve(data.size());
  for (const auto& text : data) {
    converted.push_back(toVec8(text));
  }
  return converted;
}
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
@ -166,8 +197,7 @@ class PythonStore : public ::c10d::Store {
pybind11::get_overload(static_cast<const ::c10d::Store*>(this), "set");
TORCH_INTERNAL_ASSERT(fn, "Not implemented.");
// Call function with a py::bytes object for the value.
fn(key,
py::bytes(reinterpret_cast<const char*>(value.data()), value.size()));
fn(key, toPyBytes(value));
}
// Note: this function manually calls the Python-side overload
@ -184,7 +214,7 @@ class PythonStore : public ::c10d::Store {
// std::vector<uint8_t>. There is no API for directly accessing
// the contents of the py::bytes object.
std::string str = pybind11::cast<py::bytes>(fn(key));
return std::vector<uint8_t>(str.begin(), str.end());
return toVec8(str);
}
// Note: this function manually calls the Python-side overload
@ -204,14 +234,8 @@ class PythonStore : public ::c10d::Store {
// std::vector<uint8_t>. There is no API for directly accessing
// the contents of the py::bytes object.
std::string str = pybind11::cast<py::bytes>(
fn(key,
py::bytes(
reinterpret_cast<const char*>(expectedValue.data()),
expectedValue.size()),
py::bytes(
reinterpret_cast<const char*>(desiredValue.data()),
desiredValue.size())));
return std::vector<uint8_t>(str.begin(), str.end());
fn(key, toPyBytes(expectedValue), toPyBytes(desiredValue)));
return toVec8(str);
}
int64_t add(const std::string& key, int64_t value) override {
@ -253,8 +277,7 @@ class PythonStore : public ::c10d::Store {
return Store::append(key, value);
}
// Call function with a py::bytes object for the value.
fn(key,
py::bytes(reinterpret_cast<const char*>(value.data()), value.size()));
fn(key, toPyBytes(value));
}
std::vector<std::vector<uint8_t>> multiGet(
@ -287,14 +310,7 @@ class PythonStore : public ::c10d::Store {
return Store::multiSet(keys, values);
}
std::vector<py::bytes> bytes;
bytes.reserve(values.size());
for (auto& value : values) {
bytes.emplace_back(
reinterpret_cast<const char*>(value.data()), value.size());
}
fn(keys, bytes);
fn(keys, toPyBytes(values));
}
bool hasExtendedApi() const override {
@ -973,10 +989,7 @@ and :class:`~torch.distributed.HashStore`).
"set",
[](::c10d::Store& store,
const std::string& key,
const std::string& value) {
std::vector<uint8_t> value_(value.begin(), value.end());
store.set(key, value_);
},
const std::string& value) { store.set(key, toVec8(value)); },
py::call_guard<py::gil_scoped_release>(),
R"(
Inserts the key-value pair into the store based on the supplied ``key`` and
@ -1001,14 +1014,9 @@ Example::
const std::string& key,
const std::string& expected_value,
const std::string& desired_value) -> py::bytes {
std::vector<uint8_t> expectedValue_(
expected_value.begin(), expected_value.end());
std::vector<uint8_t> desiredValue_(
desired_value.begin(), desired_value.end());
auto value =
store.compareSet(key, expectedValue_, desiredValue_);
return py::bytes(
reinterpret_cast<char*>(value.data()), value.size());
auto value = store.compareSet(
key, toVec8(expected_value), toVec8(desired_value));
return toPyBytes(value);
},
py::call_guard<py::gil_scoped_release>(),
R"(
@ -1040,8 +1048,7 @@ Example::
py::gil_scoped_release guard;
return store.get(key);
}();
return py::bytes(
reinterpret_cast<char*>(value.data()), value.size());
return toPyBytes(value);
},
R"(
Retrieves the value associated with the given ``key`` in the store. If ``key`` is not
@ -1240,8 +1247,7 @@ Example::
[](::c10d::Store& store,
const std::string& key,
const std::string& value) {
std::vector<uint8_t> value_(value.begin(), value.end());
store.append(key, value_);
store.append(key, toVec8(value));
},
py::call_guard<py::gil_scoped_release>(),
R"(
@ -1268,14 +1274,7 @@ Example::
py::gil_scoped_release guard;
return store.multiGet(keys);
}();
std::vector<py::bytes> res;
for (auto& value : values) {
auto bytes = py::bytes(
reinterpret_cast<const char*>(value.data()),
value.size());
res.push_back(bytes);
}
return res;
return toPyBytes(values);
},
R"(
Retrieve all values in ``keys``. If any key in ``keys`` is not
@ -1298,12 +1297,7 @@ Example::
[](::c10d::Store& store,
const std::vector<std::string>& keys,
const std::vector<std::string>& values) {
std::vector<std::vector<uint8_t>> vals;
vals.reserve(values.size());
for (auto& value : values) {
vals.emplace_back(value.begin(), value.end());
}
store.multiSet(keys, vals);
store.multiSet(keys, toVec8(values));
},
py::call_guard<py::gil_scoped_release>(),
R"(
@ -1487,6 +1481,212 @@ Arguments:
&::c10d::PrefixStore::getUnderlyingNonPrefixStore,
R"(Recursively to get the store before layers of wrapping with PrefixStore.)");
using namespace std::chrono_literals;
auto collectives =
py::class_<
::c10d::ControlCollectives,
c10::intrusive_ptr<::c10d::ControlCollectives>>(
module,
"_ControlCollectives",
R"(
Base class for all ControlCollectives implementations.
)")
.def(
"barrier",
&::c10d::ControlCollectives::barrier,
py::arg("key"),
py::arg("timeout") = 5min,
py::arg("block") = true,
py::call_guard<py::gil_scoped_release>(),
R"(
Blocks until all workers have entered this function.
Arguments:
key (str): The unique key used to identify this operation.
timeout (duration): The timeout for this operation.
block (bool): whether to block this working waiting on the results of the barrier.
)")
.def(
"all_sum",
&::c10d::ControlCollectives::allSum,
py::arg("key"),
py::arg("data"),
py::arg("timeout") = 5min,
py::call_guard<py::gil_scoped_release>(),
R"(
Computes a sum across all workers and returns the final value.
Arguments:
key (str): The unique key used to identify this operation.
data (int): The data to sum.
timeout (duration): The timeout for this operation.
)")
.def(
"broadcast_send",
[](::c10d::ControlCollectives& collectives,
const std::string& key,
const std::string& data,
std::chrono::milliseconds timeout = 5min) {
collectives.broadcastSend(key, toVec8(data), timeout);
},
py::arg("key"),
py::arg("data"),
py::arg("timeout") = 5min,
py::call_guard<py::gil_scoped_release>(),
R"(
Sends data to all other workers. Must be only called from one worker.
Arguments:
key (str): The unique key used to identify this operation.
data (str): The data to send.
timeout (duration): The timeout for this operation.
)")
.def(
"broadcast_recv",
[](::c10d::ControlCollectives& collectives,
const std::string& key,
std::chrono::milliseconds timeout = 5min) {
auto out = [&]() {
py::gil_scoped_release guard;
return collectives.broadcastRecv(key, timeout);
}();
return toPyBytes(out);
},
py::arg("key"),
py::arg("timeout") = 5min,
R"(
Receives data broadcasted from 1 worker.
Arguments:
key (str): The unique key used to identify this operation.
timeout (duration): The timeout for this operation.
)")
.def(
"gather_send",
[](::c10d::ControlCollectives& collectives,
const std::string& key,
const std::string& data,
std::chrono::milliseconds timeout = 5min) {
collectives.gatherSend(key, toVec8(data), timeout);
},
py::arg("key"),
py::arg("data"),
py::arg("timeout") = 5min,
py::call_guard<py::gil_scoped_release>(),
R"(
Sends data to one other worker.
Arguments:
key (str): The unique key used to identify this operation.
data (str): The data to send.
timeout (duration): The timeout for this operation.
)")
.def(
"gather_recv",
[](::c10d::ControlCollectives& collectives,
const std::string& key,
const std::string& data,
std::chrono::milliseconds timeout = 5min) {
auto out = [&]() {
py::gil_scoped_release guard;
return collectives.gatherRecv(key, toVec8(data), timeout);
}();
return toPyBytes(out);
},
py::arg("key"),
py::arg("data"),
py::arg("timeout") = 5min,
R"(
Receives data broadcasted from all workers. Must only be called by one worker.
Arguments:
key (str): The unique key used to identify this operation.
timeout (duration): The timeout for this operation.
)")
.def(
"scatter_send",
[](::c10d::ControlCollectives& collectives,
const std::string& key,
const std::vector<std::string>& data,
std::chrono::milliseconds timeout = 5min) {
auto out = [&]() {
py::gil_scoped_release guard;
return collectives.scatterSend(key, toVec8(data), timeout);
}();
return toPyBytes(out);
},
py::arg("key"),
py::arg("data"),
py::arg("timeout") = 5min,
R"(
Sends rank specific data to all other workers.
Arguments:
key (str): The unique key used to identify this operation.
data (str): The data to send.
timeout (duration): The timeout for this operation.
)")
.def(
"scatter_recv",
[](::c10d::ControlCollectives& collectives,
const std::string& key,
std::chrono::milliseconds timeout = 5min) {
auto out = [&]() {
py::gil_scoped_release guard;
return collectives.scatterRecv(key, timeout);
}();
return toPyBytes(out);
},
py::arg("key"),
py::arg("timeout") = 5min,
R"(
Receives rank specific data from one worker.
Arguments:
key (str): The unique key used to identify this operation.
timeout (duration): The timeout for this operation.
)")
.def(
"all_gather",
[](::c10d::ControlCollectives& collectives,
const std::string& key,
const std::string& data,
std::chrono::milliseconds timeout = 5min) {
auto out = [&]() {
py::gil_scoped_release guard;
return collectives.allGather(key, toVec8(data), timeout);
}();
return toPyBytes(out);
},
py::arg("key"),
py::arg("data"),
py::arg("timeout") = 5min,
R"(
Sends data to all workers and receives data from all other workers.
Arguments:
key (str): The unique key used to identify this operation.
data (str): The data to send.
timeout (duration): The timeout for this operation.
)");
  // Expose StoreCollectives to Python as `_StoreCollectives`, registered with
  // the `_ControlCollectives` binding above as its base so the collective
  // methods are inherited. Only the (store, rank, world_size) constructor is
  // bound here.
  intrusive_ptr_class_<::c10d::StoreCollectives>(
      module,
      "_StoreCollectives",
      collectives,
      R"(
An implementation of ControlCollectives that uses the provided store as the underlying
communication mechanism.
)")
      .def(
          py::init<c10::intrusive_ptr<::c10d::Store>, int, int>(),
          py::arg("store"),
          py::arg("rank"),
          py::arg("world_size"));
auto processGroup =
py::class_<
::c10d::ProcessGroup,

View File

@ -54,6 +54,8 @@ if is_available():
set_debug_level,
set_debug_level_from_env,
_make_nccl_premul_sum,
_ControlCollectives,
_StoreCollectives,
)
class _DistributedPdb(pdb.Pdb):