[C10] PG observability hooks. (#108815)

Expose a set of observability hooks into C10D so that users can
detect collective failures both faster and more easily.

The design is similar to NCCL desync debug in that it minimizes
overhead by doing most of the work off the main thread.

This PR introduces a new module torch.distributed.hooks that exposes the following set of methods:

    register_collective_start_hook
    register_collective_end_hook
    register_process_group_hook

The process group hook exposes PG creation on the member ranks and is called inline from the
PG creation code. This is fine since it happens during initialization and only a limited number of times.

The collective start/end hooks are fired from a single background thread that reads
events from a C++ queue and dispatches them to the registered callbacks.

Queue notification is, somewhat unusually, done over a pipe; this is needed so Python can run the
consumer as a daemon thread and abort it on shutdown, which is not possible with more conventional
choices such as a condition variable.
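
A minimal usage sketch (single rank, gloo), based on the tests added in this PR; the hook APIs and
CollectiveStatus fields come from torch/distributed/hooks.py below, while the FileStore path and the
final sleep are purely illustrative:

    import time
    import torch
    import torch.distributed as dist
    import torch.distributed.hooks as dhooks

    def on_pg_created(pg, pg_name):
        # fired inline from PG creation, only on ranks that are members of the new PG
        print(f"created process group {pg_name}")

    def on_start(status: dhooks.CollectiveStatus):
        # runs on the background hooks thread
        print(f"{status.operation} started on {status.pg_name} seq={status.sequence_number}")

    def on_end(status: dhooks.CollectiveStatus):
        print(f"{status.operation} done in {status.duration} ms, dropped={status.drop_count}")

    dhooks.register_process_group_hook(on_pg_created)
    dhooks.register_collective_start_hook(on_start)
    dhooks.register_collective_end_hook(on_end)

    # illustrative single-rank setup; any init method works
    store = dist.FileStore("/tmp/hooks_example_store", 1)
    dist.init_process_group("gloo", rank=0, world_size=1, store=store)
    dist.all_reduce(torch.ones(2, 3))

    time.sleep(1)  # collective hooks are delivered asynchronously by a daemon thread
    dist.destroy_process_group()
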
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108815
Approved by: https://github.com/wconstab, https://github.com/fduwjj
This commit is contained in:
Rodrigo Kumpera
2023-10-05 14:38:59 -07:00
committed by PyTorch MergeBot
parent 17348b0f51
commit 0c7a877745
17 changed files with 690 additions and 31 deletions

View File

@ -521,6 +521,7 @@ libtorch_distributed_base_sources = [
"torch/csrc/distributed/c10d/Backend.cpp",
"torch/csrc/distributed/c10d/FileStore.cpp",
"torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
"torch/csrc/distributed/c10d/Hooks.cpp",
"torch/csrc/distributed/c10d/Ops.cpp",
"torch/csrc/distributed/c10d/ParamCommsUtils.cpp",
"torch/csrc/distributed/c10d/PrefixStore.cpp",

View File

@ -0,0 +1,270 @@
# Owner(s): ["oncall: distributed"]
import os
import sys
import tempfile
import threading
from functools import partial, wraps
import torch
import torch.distributed as dist
import torch.distributed.hooks as dhooks
if not dist.is_available():
print("torch.distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests, TestCase
class PgHooks(MultiProcessTestCase):
@property
def world_size(self) -> int:
return 4
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
def tearDown(self):
super().tearDown()
try:
os.remove(self.file_name)
except OSError:
pass
def test_pg_hook(self):
pgs = []
def pg_hook(pg, pg_name):
pgs.append((pg, pg_name))
dhooks.register_process_group_hook(pg_hook)
dist.init_process_group(
backend="gloo",
rank=self.rank,
world_size=self.world_size,
store=dist.FileStore(self.file_name, self.world_size),
)
self.assertEqual(len(pgs), 1)
self.assertEqual(pgs[0][0], dist.group.WORLD)
# create two partial world PGs
pg0 = dist.new_group(ranks=[0, 1])
pg1 = dist.new_group(ranks=[2, 3])
# Each rank only observes two PGs being created: the default PG and the one covering its ranks.
# We don't emit events for PG creation if the current rank doesn't belong to it.
# For example, if you're rank 1 you'll get an event for pg0 but not pg1, even though the API contract
# dictates you need to call new_group for both.
self.assertEqual(len(pgs), 2)
self.assertEqual(pgs[1][0], pg0 if self.rank < 2 else pg1)
def with_comms(func=None):
if func is None:
return partial(
with_comms,
)
@wraps(func)
def wrapper(self, *args, **kwargs):
self.init_comms()
func(self, *args, **kwargs)
self.destroy_comms()
return wrapper
class CollectiveHooks:
@property
def world_size(self) -> int:
return 4
def _collective_hooks(self):
# it's ok to access them directly since there's a single bg thread poking at them.
starts = []
ends = []
cv = threading.Condition()
def coll_start(status):
starts.append(status)
print(f"col_start {len(starts)} rank{self.rank}")
def coll_end(status):
ends.append(status)
print(f"col_end {len(ends)} rank{self.rank}")
if len(ends) == 2:
with cv:
cv.notify()
dhooks.register_collective_start_hook(coll_start)
dhooks.register_collective_end_hook(coll_end)
tensor = torch.ones([2, 3]).to(self.device) * self.rank
tensor_list = [torch.empty_like(tensor) for _ in range(self.world_size)]
dist.all_gather(tensor_list, tensor)
tensor2 = torch.ones([2, 3]).to(self.device) * self.rank
dist.all_reduce(tensor2)
with cv:
cv.wait(1)
default_pg_name = dist.group.WORLD.group_name
self.assertEqual(2, len(starts))
self.assertEqual(2, len(ends))
def check_op(idx, coll_name):
self.assertEqual(default_pg_name, starts[idx].pg_name)
self.assertEqual(self.backend_name, starts[idx].backend)
self.assertGreaterEqual(starts[idx].sequence_number, 0)
self.assertGreaterEqual(starts[idx].timestamp, 0)
self.assertEqual(coll_name, starts[idx].operation)
self.assertEqual(default_pg_name, ends[idx].pg_name)
self.assertEqual(self.backend_name, ends[idx].backend)
self.assertEqual(starts[idx].sequence_number, ends[idx].sequence_number)
self.assertLessEqual(starts[idx].timestamp, ends[idx].timestamp)
self.assertEqual(coll_name, ends[idx].operation)
check_op(0, "ALLGATHER")
check_op(1, "ALLREDUCE")
class GlooHooks(MultiProcessTestCase, CollectiveHooks):
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
def tearDown(self):
super().tearDown()
try:
os.remove(self.file_name)
except OSError:
pass
def init_comms(self):
dist.init_process_group(
backend="gloo",
rank=self.rank,
world_size=self.world_size,
store=dist.FileStore(self.file_name, self.world_size),
)
def destroy_comms(self):
dist.destroy_process_group()
@property
def backend_name(self):
return "gloo"
@property
def device(self):
return "cpu"
@with_comms
def test_collective_hooks(self):
self._collective_hooks()
class NcclHooks(MultiProcessTestCase, CollectiveHooks):
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
def tearDown(self):
super().tearDown()
try:
os.remove(self.file_name)
except OSError:
pass
def init_comms(self):
dist.init_process_group(
backend="nccl",
rank=self.rank,
world_size=self.world_size,
store=dist.FileStore(self.file_name, self.world_size),
)
def destroy_comms(self):
dist.destroy_process_group()
@property
def backend_name(self):
return "nccl"
@property
def device(self):
return f"cuda:{self.rank}"
@skip_if_lt_x_gpu(4)
@with_comms
def test_collective_hooks(self):
self._collective_hooks()
class SingleRankTests(TestCase):
def setUp(self) -> None:
super().setUp()
self.rank = 0
self.file_name = tempfile.NamedTemporaryFile(delete=False).name
dist.init_process_group(
backend="gloo",
rank=0,
world_size=1,
store=dist.FileStore(self.file_name, 1),
)
def tearDown(self) -> None:
dist.destroy_process_group()
def test_queue_overflow(self) -> None:
cv_done_colls = threading.Condition()
cv_done_cb = threading.Condition()
colls_done = False
starts = []
status_with_dropped = None
def coll_start(status: dhooks.CollectiveStatus):
starts.append(status)
with cv_done_colls:
while not colls_done:
cv_done_colls.wait()
if status.drop_count > 0:
nonlocal status_with_dropped
status_with_dropped = status
with cv_done_cb:
cv_done_cb.notify()
dhooks.register_collective_start_hook(coll_start)
# native limit is 512
for i in range(600):
dist.all_reduce(torch.ones([2, 3]))
colls_done = True
with cv_done_colls:
cv_done_colls.notify()
with cv_done_cb:
cv_done_cb.wait(10)
self.assertTrue(status_with_dropped is not None)
self.assertTrue(status_with_dropped.drop_count > 0)
if __name__ == "__main__":
assert (
not torch.cuda._initialized
), "test_distributed must not have initialized CUDA context on main process"
run_tests()

View File

@ -11,6 +11,10 @@ _DEFAULT_FIRST_BUCKET_BYTES: int
_DEFAULT_NO_TIMEOUT: timedelta
_DEFAULT_PG_TIMEOUT: timedelta
class EventKind(Enum):
START = ...
END = ...
class BuiltinCommHookType(Enum):
ALLREDUCE = ...
FP16_COMPRESS = ...
@ -20,6 +24,8 @@ def _register_builtin_comm_hook(
reducer: Reducer,
comm_hook_type: BuiltinCommHookType,
): ...
def _dequeue_c10d_event() -> Dict[str, object]: ...
def _enable_event_collection(pipe_fs: int) -> None: ...
class GradBucket:
def index(self) -> int: ...

View File

@ -1,9 +1,26 @@
#include <c10/util/Logging.h>
#include <fmt/format.h>
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Hooks.hpp>
#include <torch/csrc/distributed/c10d/logging.h>
namespace c10d {
namespace {
void commonEventinit(
details::EventInfo& evt,
const Backend& backend,
const Work& work) {
evt.timestamp =
std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
evt.pg_name = backend.getGroupName();
evt.backend = backend.getBackendName();
evt.sequence_number = work.getSequencenumber();
evt.operation = c10d::opTypeToString(work.retrieveOpType());
evt.drop_count = 0;
}
} // namespace
Backend::Backend(int rank, int size)
: rank_(rank), size_(size), dist_debug_level_(debug_level()) {
C10_LOG_API_USAGE_ONCE("c10d.backend");
@ -15,4 +32,21 @@ void Backend::init() {
C10_LOG_API_USAGE_ONCE(fmt::format("c10d.backend_{}", getBackendName()));
}
void Backend::emitCollectiveStart(const Work& work) {
details::EventInfo evt;
commonEventinit(evt, *this, work);
evt.event_kind = ::c10d::EventKind::CollectiveStart;
details::enqueue_c10d_event(std::move(evt));
}
void Backend::emitCollectiveEnd(const Work& work) {
details::EventInfo evt;
commonEventinit(evt, *this, work);
evt.event_kind = ::c10d::EventKind::CollectiveEnd;
evt.duration_ms = work.getDuration();
details::enqueue_c10d_event(std::move(evt));
}
} // namespace c10d

View File

@ -366,6 +366,8 @@ class TORCH_API Backend : public torch::CustomClassHolder {
// Implementations of this interface need to call this to setup
// appropriate logging etc.
void init();
void emitCollectiveStart(const Work& work);
void emitCollectiveEnd(const Work& work);
const int rank_;
const int size_;

View File

@ -0,0 +1,60 @@
#include <atomic>
#include <deque>
#include <memory>
#include <mutex>
#ifndef _WIN32
#include <unistd.h>
#else
#include <io.h>
#endif
#include <torch/csrc/distributed/c10d/Hooks.hpp>
namespace c10d {
namespace {
std::atomic<bool> event_queue_enabled = false;
int sync_pipe;
std::mutex event_queue_lock;
std::deque<details::EventInfo> event_queue;
} // namespace
void enable_event_collection(int pipe) {
sync_pipe = pipe;
event_queue_enabled.store(true);
}
namespace details {
// we start dropping events after this
const size_t MAX_QUEUE_SIZE = 512;
bool dequeue_c10d_event(EventInfo& evt) {
std::unique_lock<std::mutex> lock(event_queue_lock);
if (event_queue.size() == 0) {
return false;
}
evt = event_queue.front();
event_queue.pop_front();
return true;
}
void enqueue_c10d_event(EventInfo&& evt) {
if (!event_queue_enabled.load())
return;
std::unique_lock<std::mutex> lock(event_queue_lock);
if (event_queue.size() >= MAX_QUEUE_SIZE) {
event_queue.back().drop_count++;
} else {
event_queue.push_back(std::move(evt));
char m = 'x';
write(sync_pipe, &m, 1);
}
}
} // namespace details
} // namespace c10d

View File

@ -0,0 +1,30 @@
#pragma once
#include <c10/util/Optional.h>
#include <string>
namespace c10d {
enum class EventKind { CollectiveStart, CollectiveEnd };
TORCH_API void enable_event_collection(int sync_pipe);
namespace details {
struct TORCH_API EventInfo {
EventKind event_kind;
std::string pg_name;
std::string backend;
int64_t sequence_number;
std::string operation;
int64_t timestamp;
c10::optional<float> duration_ms;
int64_t drop_count;
};
// TODO do we want to expose something else here?
TORCH_API bool dequeue_c10d_event(EventInfo& evt);
TORCH_API void enqueue_c10d_event(EventInfo&& evt);
} // namespace details
} // namespace c10d

View File

@ -82,10 +82,14 @@ std::string opTypeToString(OpType opType) {
return "RECVANYSOURCE";
case OpType::BARRIER:
return "BARRIER";
case OpType::UNKNOWN:
return "UNKNOWN";
case OpType::_REDUCE_SCATTER_BASE:
return "_REDUCE_SCATTER_BASE";
case OpType::COALESCED:
return "COALESCED";
case OpType::_ALLREDUCE_SPARSE:
return "_ALLREDUCE_SPARSE";
case OpType::UNKNOWN:
return "UNKNOWN";
default:
TORCH_INTERNAL_ASSERT(false, "Unknown op type!");
}

View File

@ -855,6 +855,10 @@ void ProcessGroupGloo::runLoop(int workerIndex) {
}
void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
emitCollectiveStart(*work.get());
work->getFuture()->addCallback(
[=](auto& f) { this->emitCollectiveEnd(*work.get()); });
std::unique_lock<std::mutex> lock(workMutex_);
workQueue_.push_back(std::move(work));
lock.unlock();

View File

@ -942,27 +942,34 @@ void ProcessGroupNCCL::ncclCommWatchdog() {
}
}
void ProcessGroupNCCL::logWorkStart(WorkNCCL& work) {
if (work.startTraceUpdated_)
void ProcessGroupNCCL::logWorkStart(WorkNCCL& work, bool emitDesyncInfo) {
if (terminateProcessGroup_.load() || work.startTraceUpdated_)
return;
if (terminateProcessGroup_.load() || storeError_)
return;
work.startTraceUpdated_ = true;
emitCollectiveStart(work);
if (!emitDesyncInfo || storeError_)
return;
storeError_ = !c10d::traceUpdate(
store_, traceKeyStart_, work.seq_, opTypeToString(work.opType_));
}
void ProcessGroupNCCL::logWorkEnd(WorkNCCL& work) {
if (terminateProcessGroup_.load() || storeError_)
void ProcessGroupNCCL::logWorkEnd(WorkNCCL& work, bool emitDesyncInfo) {
if (terminateProcessGroup_.load())
return;
// In case the start of the work hasn't been logged
if (!work.startTraceUpdated_) {
logWorkStart(work);
logWorkStart(work, emitDesyncInfo);
}
emitCollectiveEnd(work);
if (!emitDesyncInfo || storeError_)
return;
storeError_ = !c10d::traceUpdate(
store_, traceKeyEnd_, work.seq_, opTypeToString(work.opType_));
}
@ -1014,13 +1021,11 @@ void ProcessGroupNCCL::workCleanupLoop() {
}
// Work status logging for desync debug
if (desyncDebug_) {
if (work.isStarted()) {
logWorkStart(work);
}
if (work.isCompleted()) {
logWorkEnd(work);
}
if (work.isStarted()) {
logWorkStart(work, desyncDebug_);
}
if (work.isCompleted()) {
logWorkEnd(work, desyncDebug_);
}
// Clean up completed work
@ -1077,7 +1082,7 @@ void ProcessGroupNCCL::runHookLoop() {
timeStarted, // timeStarted
std::chrono::system_clock::now(), // timeFinished
std::chrono::duration<float, std::milli>(
work.getDuration()) // activeDuration
work.getDuration().value()) // activeDuration
));
lock.lock();
@ -1580,19 +1585,19 @@ c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupNCCL::WorkNCCL::
return future_;
}
float ProcessGroupNCCL::WorkNCCL::getDuration() const {
TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled")
c10::optional<float> ProcessGroupNCCL::WorkNCCL::getDuration() const {
if (!timingEnabled_ || !((*ncclEndEvents_)[0].query())) {
return c10::optional<float>();
}
TORCH_CHECK(
ncclStartEvents_->size() == 1,
"getDuration only works for single device per ProcessGroup.");
TORCH_CHECK(
ncclEndEvents_->size() == 1,
"getDuration only works for single device per ProcessGroup.");
TORCH_CHECK(
(*ncclEndEvents_)[0].query(),
"getDuration can only be called after work is succeeded.")
return (*ncclStartEvents_)[0].elapsed_time((*ncclEndEvents_)[0]);
}
uint64_t ProcessGroupNCCL::WorkNCCL::getSequencenumber() const {
return seq_;
}

View File

@ -164,7 +164,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Get a Future object that will be marked as completed internally.
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
float getDuration() const override;
c10::optional<float> getDuration() const override;
uint64_t getSequencenumber() const override;
@ -612,10 +612,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
void runHookLoop();
// Desync debug helper
void logWorkStart(WorkNCCL& work);
void logWorkStart(WorkNCCL& work, bool emitDesyncInfo);
// Desync debug helper
void logWorkEnd(WorkNCCL& work);
void logWorkEnd(WorkNCCL& work, bool emitDesyncInfo);
protected:
static const int64_t kWatchdogThreadSleepMillis;

View File

@ -127,8 +127,8 @@ void Work::finishAndThrow(std::exception_ptr exception) {
}
}
float Work::getDuration() const {
TORCH_CHECK(false, "This Backend doesn't support getDuration.");
c10::optional<float> Work::getDuration() const {
return c10::optional<float>();
}
uint64_t Work::getSequencenumber() const {

View File

@ -107,7 +107,7 @@ class TORCH_API Work : public torch::CustomClassHolder {
// work. Only NCCL backend is currently supported.
virtual c10::intrusive_ptr<c10::ivalue::Future> getFuture();
virtual float getDuration() const;
virtual c10::optional<float> getDuration() const;
virtual uint64_t getSequencenumber() const;

View File

@ -32,6 +32,7 @@
#include <fmt/format.h>
#include <pybind11/chrono.h>
#include <torch/csrc/distributed/c10d/Hooks.hpp>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/comm.hpp>
@ -291,6 +292,26 @@ void _register_builtin_comm_hook(
reducer.register_builtin_comm_hook(comm_hook_type);
}
py::object c10d_dequeue_python_event() {
::c10d::details::EventInfo evt;
if (!::c10d::details::dequeue_c10d_event(evt)) {
return py::none();
}
py::dict data;
data["event_kind"] = (int)evt.event_kind;
data["pg_name"] = evt.pg_name;
data["backend"] = evt.backend;
data["sequence_number"] = evt.sequence_number;
data["operation"] = evt.operation;
data["timestamp"] = evt.timestamp;
data["duration"] = evt.duration_ms.value_or(-1);
data["drop_count"] = evt.drop_count;
return std::move(data);
}
// Customize the metaclass of ::c10d::ReduceOp for the backward compatibility.
// https://github.com/pytorch/pytorch/pull/84243 changed ::c10d::ReduceOp to
// struct from enum, sacrificing some of the Python built-in function supports
@ -640,6 +661,11 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
&::c10d::Logger::set_static_graph,
py::call_guard<py::gil_scoped_release>());
py::enum_<::c10d::EventKind>(module, "EventKind", R"(
An enum for collective hooks event types.)")
.value("START", ::c10d::EventKind::CollectiveStart)
.value("END", ::c10d::EventKind::CollectiveEnd);
py::enum_<::c10d::DebugLevel>(module, "DebugLevel", R"(
An enum whose values correspond to different debug levels of the
torch.distributed package. Currently supporting OFF, INFO, and DETAIL,
@ -663,7 +689,16 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
"set_debug_level_from_env",
::c10d::setDebugLevelFromEnvironment,
R"(Sets the debug level of the torch.distributed package from the
``TORCH_DISTRIBUTED_DEBUG`` environment variable.)");
``TORCH_DISTRIBUTED_DEBUG`` environment variable.)")
.def(
"_enable_event_collection",
&::c10d::enable_event_collection,
"(Enables events collection).",
py::call_guard<py::gil_scoped_release>())
.def(
"_dequeue_c10d_event",
&c10d_dequeue_python_event,
"(Blocks until a c10d event is available and return it as a python dictionary).");
// TODO(crcrpar): Hardening `ReduceOp`.
// While keeping most op types as enum value,

View File

@ -421,6 +421,19 @@ _group_count = 0
_tags_to_pg: Dict[str, List[ProcessGroup]] = {}
_pg_to_tag: Dict[ProcessGroup, str] = {}
class _HookState:
def __init__(self):
self.creation_hooks = []
def register_creation_hook(self, hook) -> None:
self.creation_hooks.append(hook)
def fire_creation_hook(self, pg, name) -> None:
for hook in self.creation_hooks:
try:
hook(pg, name)
except Exception as e:
logger.info("hook %s failed with %s", hook, e)
class _World:
"""
@ -434,6 +447,8 @@ class _World:
self._default_pg = None
self._pg_coalesce_state: Dict[ProcessGroup, List[Union[_CollOp, P2POp]]] = {}
self._pg_default_device: Dict[ProcessGroup, torch.device] = {}
self._hook_state = _HookState()
self.enable_collectives_timing = False
@property
def default_pg(self):
@ -543,6 +558,9 @@ class _World:
)
return config_info
@property
def pg_hook_state(self) -> _HookState:
return self._hook_state
_world = _World()
"""Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it"""
@ -1362,6 +1380,9 @@ def _new_process_group_helper(
pg._set_group_name(group_name)
_world.pg_backend_config[pg] = str(backend_config)
if _world.enable_collectives_timing:
pg._enable_collectives_timing()
# "" is the default tag for user PGs
if pg_tag in [None, ""]:
pg_tag = f"ptd:{group_name}"
@ -1371,6 +1392,8 @@ def _new_process_group_helper(
_world.tags_to_pg.setdefault(pg_tag, []).append(pg)
_world.pg_to_tag[pg] = pg_tag
_world.pg_hook_state.fire_creation_hook(pg, group_name)
return pg, prefix_store
def destroy_process_group(group: Optional[ProcessGroup] = None):
@ -4313,3 +4336,11 @@ dynamo_unsupported_distributed_c10d_ops = [
reduce_scatter_tensor,
send,
]
def _register_creation_hook(hook):
_world.pg_hook_state.register_creation_hook(hook)
def _enable_collectives_timing():
_world.enable_collectives_timing = True
for pg in _world.pg_map:
pg._enable_collectives_timing()

torch/distributed/hooks.py (new file, 169 lines)
View File

@ -0,0 +1,169 @@
import logging
import os
import threading
import time
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d
from torch._C._distributed_c10d import (
_dequeue_c10d_event,
_enable_event_collection,
EventKind,
)
__all__ = [
"CollectiveStatus",
"COLLECTIVE_HOOK_TYPE",
"PG_HOOK_TYPE",
"register_collective_start_hook",
"register_collective_end_hook",
"register_process_group_hook",
]
@dataclass
class CollectiveStatus:
pg_name: str = "unknown" # matches the name reported to the process group registration callback
backend: str = "unknown" # name of the backend used
sequence_number: int = -1 # sequence number of the collective within its process group
operation: str = "unknown" # collective name
timestamp: int = 0 # timestamp of the earliest time we noticed this event
duration: Optional[float] = None # time in milliseconds the collective took to execute, if available
drop_count: int = 0 # number of events dropped following this one
COLLECTIVE_HOOK_TYPE = Callable[[CollectiveStatus], None]
PG_HOOK_TYPE = Callable[[dist.ProcessGroup, str], None]
# This controls the number of internal failures we'll tolerate before giving up
_MAX_INTERNAL_FAILURES = 10
logger = logging.getLogger(__name__)
_cb_thread: Optional[threading.Thread] = None
_start_callbacks: List[COLLECTIVE_HOOK_TYPE] = []
_end_callbacks: List[COLLECTIVE_HOOK_TYPE] = []
_pp_r = -1
_pp_w = -1
def _c10d_pg_hooks_loops():
internal_failures = 0
while True:
# we don't care about the result, this is how we implement notification
_ = os.read(_pp_r, 1)
evt: Dict[str, object] = _dequeue_c10d_event()
try:
event_kind = evt.pop("event_kind", None)
if event_kind is None:
logger.warning(
"c10d returned event dictionary %s without 'event_kind' key, cannot dispatch",
evt,
)
internal_failures += 1
if internal_failures >= _MAX_INTERNAL_FAILURES:
logger.warning(
"too many internal c10d failures processing callback loop. stopping"
)
return
time.sleep(1)
continue
if event_kind == int(EventKind.START): # type: ignore[call-overload]
cb_list = _start_callbacks
elif event_kind == int(EventKind.END): # type: ignore[call-overload]
cb_list = _end_callbacks
else:
logger.warning(
"c10d event %s with invalid 'event_kind' with value %d",
evt,
event_kind,
)
internal_failures += 1
if internal_failures >= _MAX_INTERNAL_FAILURES:
logger.warning(
"too many internal c10d failures processing callback loop. stopping"
)
return
time.sleep(1)
continue
status = CollectiveStatus(**evt) # type: ignore[arg-type]
for cb in cb_list:
try:
cb(status)
except Exception as e:
logger.info(
"c10d event callback %s with event %s threw exception %s",
cb,
status,
e,
)
except Exception as e:
# We have to keep processing otherwise the queue will grow infinitely large
logger.warning(
"c10d callback thread when processing event %s raised exception %s.",
evt,
e,
)
internal_failures += 1
if internal_failures >= _MAX_INTERNAL_FAILURES:
logger.warning(
"too many internal c10d failures processing callback loop. stopping"
)
return
# Sleep for a second to avoid hogging the GIL in case of a persistent failure
time.sleep(1)
def _lazy_init():
global _cb_thread
if _cb_thread is not None:
return
global _pp_r
global _pp_w
_pp_r, _pp_w = os.pipe()
_enable_event_collection(_pp_w)
c10d._enable_collectives_timing()
_cb_thread = threading.Thread(target=_c10d_pg_hooks_loops, daemon=True)
_cb_thread.start()
logger.info("c10d::hooks thread enabled")
def register_collective_start_hook(hook: COLLECTIVE_HOOK_TYPE) -> None:
"""
Register a hook that is called every time a collective starts.
The hook is invoked on a background thread.
Exceptions raised by the callback are ignored and non-fatal.
"""
_start_callbacks.append(hook)
_lazy_init()
def register_collective_end_hook(hook: COLLECTIVE_HOOK_TYPE) -> None:
"""
Register a hook that is called every time a collective finishes.
The hook is invoked on a background thread.
Exceptions raised by the callback are ignored and non-fatal.
"""
_end_callbacks.append(hook)
_lazy_init()
def register_process_group_hook(hook: PG_HOOK_TYPE) -> None:
"""
Register a hook that is called every time a process group is created on this rank.
This hook is only invoked if the current rank is part of the PG being created.
The pg_name is unique to the whole cluster and should be treated as an opaque identifier subject to change.
The hook is invoked on a background thread.
Exceptions raised by the callback are ignored and non-fatal.
"""
c10d._register_creation_hook(hook)

View File

@ -401,6 +401,11 @@ class WorldData:
class ThreadLocalWorld:
_world = threading.local()
def __init__(self):
self.enable_collectives_timing = False
self._hook_state = dist.distributed_c10d._HookState()
def _get_world(self) -> WorldData:
if not hasattr(ThreadLocalWorld._world, "world"):
ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {}, {}, {})
@ -454,6 +459,9 @@ class ThreadLocalWorld:
def pg_default_device(self) -> Dict[dist.ProcessGroup, torch.device]:
return self._get_world().pg_default_device
@property
def pg_hook_state(self) -> dist.distributed_c10d._HookState:
return self._hook_state
_old_pg_world = None
_ctx_manager = None