mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Expose a set of observability hooks into C10D such that our users can detect collectives failure both faster and more easily. The design is similar to NCCL desync debug that it minimized the overhead by doing most of the work out of the main thread. This PR introduces a new module torch.distributed.hooks that exposes the following set of methods: register_collective_start_hook register_collective_end_hook register_process_group_hook The process group hook exposes PG creation on the member ranks and call them inline from the the PG creation code. This is fine since this happens during initialization and a limited number of times. The collective start/end hooks are fired from a single background thread. It reads events from a C++ queue and dispatches over. Queue notification is oddly done using a pipe, this is needed so python can abort the thread on shutdown and have it as background thread. This is not possible with more reasonable choices like a condvar. Pull Request resolved: https://github.com/pytorch/pytorch/pull/108815 Approved by: https://github.com/wconstab, https://github.com/fduwjj
170 lines
5.6 KiB
Python
170 lines
5.6 KiB
Python
import logging
|
|
import os
|
|
import threading
|
|
import time
|
|
from dataclasses import dataclass
|
|
from typing import Callable, Dict, List, Optional
|
|
|
|
import torch.distributed as dist
|
|
import torch.distributed.distributed_c10d as c10d
|
|
|
|
from torch._C._distributed_c10d import (
|
|
_dequeue_c10d_event,
|
|
_enable_event_collection,
|
|
EventKind,
|
|
)
|
|
|
|
__all__ = [
|
|
"CollectiveStatus",
|
|
"COLLECTIVE_HOOK_TYPE",
|
|
"PG_HOOK_TYPE",
|
|
"register_collective_start_hook",
|
|
"register_collective_end_hook",
|
|
"register_process_group_hook",
|
|
]
|
|
|
|
|
|
@dataclass
|
|
class CollectiveStatus:
|
|
pg_name: str = "unknown" # This name matches the one informed in the pgreg cb
|
|
backend: str = "unknown" # Name of the backend used
|
|
sequence_number: int = -1 # This name matches the one informed in the pgreg cb
|
|
operation: str = "unknown" # collective name
|
|
timestamp: int = 0 # timestamp to the earliest time we noticed this event
|
|
duration: Optional[float] = None # value in milliseconds it took executing
|
|
drop_count: int = 0 # number of events dropped following this one
|
|
|
|
|
|
COLLECTIVE_HOOK_TYPE = Callable[[CollectiveStatus], None]
|
|
PG_HOOK_TYPE = Callable[[dist.ProcessGroup, str], None]
|
|
|
|
# This controls the number of internal failures we'll tolerate before giving up
|
|
_MAX_INTERNAL_FAILURES = 10
|
|
|
|
logger = logging.getLogger(__name__)
|
|
_cb_thread: Optional[threading.Thread] = None
|
|
_start_callbacks: List[COLLECTIVE_HOOK_TYPE] = []
|
|
_end_callbacks: List[COLLECTIVE_HOOK_TYPE] = []
|
|
_pp_r = -1
|
|
_pp_w = -1
|
|
|
|
|
|
def _c10d_pg_hooks_loops():
|
|
internal_failures = 0
|
|
while True:
|
|
# we don't care about the result, this is how we implement notification
|
|
_ = os.read(_pp_r, 1)
|
|
evt: Dict[str, object] = _dequeue_c10d_event()
|
|
try:
|
|
event_kind = evt.pop("event_kind", None)
|
|
if event_kind is None:
|
|
logger.warning(
|
|
"c10d returned event dictionary %s without 'event_kind' key, cannot dispatch",
|
|
evt,
|
|
)
|
|
internal_failures += 1
|
|
if internal_failures >= _MAX_INTERNAL_FAILURES:
|
|
logger.warning(
|
|
"too many internal c10d failures processing callback loop. stopping"
|
|
)
|
|
return
|
|
time.sleep(1)
|
|
continue
|
|
|
|
if event_kind == int(EventKind.START): # type: ignore[call-overload]
|
|
cb_list = _start_callbacks
|
|
elif event_kind == int(EventKind.END): # type: ignore[call-overload]
|
|
cb_list = _end_callbacks
|
|
else:
|
|
logger.warning(
|
|
"c10d event %s with invalid 'event_kind' with value %d",
|
|
evt,
|
|
event_kind,
|
|
)
|
|
internal_failures += 1
|
|
if internal_failures >= _MAX_INTERNAL_FAILURES:
|
|
logger.warning(
|
|
"too many internal c10d failures processing callback loop. stopping"
|
|
)
|
|
return
|
|
time.sleep(1)
|
|
continue
|
|
|
|
status = CollectiveStatus(**evt) # type: ignore[arg-type]
|
|
for cb in cb_list:
|
|
try:
|
|
cb(status)
|
|
except Exception as e:
|
|
logger.info(
|
|
"c10d event callback %s with event %s threw exception %s",
|
|
cb,
|
|
status,
|
|
e,
|
|
)
|
|
except Exception as e:
|
|
# We have to keep processing otherwise the queue will grown infinitely large
|
|
logger.warning(
|
|
"c10d callback thread when processing event %s raised exception %s.",
|
|
evt,
|
|
e,
|
|
)
|
|
internal_failures += 1
|
|
if internal_failures >= _MAX_INTERNAL_FAILURES:
|
|
logger.warning(
|
|
"too many internal c10d failures processing callback loop. stopping"
|
|
)
|
|
return
|
|
|
|
# Sleep for a second to avoid hogging the GIL in case of a persistent failure
|
|
time.sleep(1)
|
|
|
|
|
|
def _lazy_init():
|
|
global _cb_thread
|
|
if _cb_thread is not None:
|
|
return
|
|
global _pp_r
|
|
global _pp_w
|
|
_pp_r, _pp_w = os.pipe()
|
|
_enable_event_collection(_pp_w)
|
|
c10d._enable_collectives_timing()
|
|
_cb_thread = threading.Thread(target=_c10d_pg_hooks_loops, daemon=True)
|
|
_cb_thread.start()
|
|
logger.info("c10d::hooks thread enabled")
|
|
|
|
|
|
def register_collective_start_hook(hook: COLLECTIVE_HOOK_TYPE) -> None:
|
|
"""
|
|
Register a hook that is called every time a collective starts.
|
|
|
|
The hook is invoked on a background thread.
|
|
Exceptions raised by the callback are ignored and non-fatal.
|
|
"""
|
|
_start_callbacks.append(hook)
|
|
_lazy_init()
|
|
|
|
|
|
def register_collective_end_hook(hook: COLLECTIVE_HOOK_TYPE) -> None:
|
|
"""
|
|
Register a hook that is called every time a collective finishes.
|
|
|
|
|
|
The hook is invoked on a background thread.
|
|
Exceptions raised by the callback are ignored and non-fatal.
|
|
"""
|
|
_end_callbacks.append(hook)
|
|
_lazy_init()
|
|
|
|
|
|
def register_process_group_hook(hook: PG_HOOK_TYPE) -> None:
|
|
"""
|
|
Register a hook that is called every time a process group is created on this rank.
|
|
|
|
This hook is only invoked if the current rank is part of the PG being created.
|
|
The pg_name is unique to the whole cluster and should be treated as an opaque identified subject to change.
|
|
|
|
The hook is invoked on a background thread.
|
|
Exceptions raised by the callback are ignored and non-fatal.
|
|
"""
|
|
c10d._register_creation_hook(hook)
|