Revert "Reland #2 "[C10] PG observability hooks. (#108815, #110907)" (#111072)"

This reverts commit bb1424d46e656dfcdd4c12efe58ada9f1720c4d8.

Reverted https://github.com/pytorch/pytorch/pull/111072 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/111072#issuecomment-1765399829))
PyTorch MergeBot
2023-10-16 23:03:26 +00:00
parent 5a8a89360d
commit 1e70f4d02c
17 changed files with 31 additions and 704 deletions

View File

@ -522,7 +522,6 @@ libtorch_distributed_base_sources = [
"torch/csrc/distributed/c10d/Backend.cpp",
"torch/csrc/distributed/c10d/FileStore.cpp",
"torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
"torch/csrc/distributed/c10d/Hooks.cpp",
"torch/csrc/distributed/c10d/Ops.cpp",
"torch/csrc/distributed/c10d/ParamCommsUtils.cpp",
"torch/csrc/distributed/c10d/PrefixStore.cpp",

View File

@ -1,270 +0,0 @@
# Owner(s): ["oncall: distributed"]
import os
import sys
import tempfile
import threading
from functools import partial, wraps
import torch
import torch.distributed as dist
import torch.distributed._hooks as dhooks
if not dist.is_available():
print("torch.distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests, TestCase
class PgHooks(MultiProcessTestCase):
@property
def world_size(self) -> int:
return 4
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
def tearDown(self):
super().tearDown()
try:
os.remove(self.file_name)
except OSError:
pass
def test_pg_hook(self):
pgs = []
def pg_hook(pg, pg_name):
pgs.append((pg, pg_name))
dhooks.register_process_group_hook(pg_hook)
dist.init_process_group(
backend="gloo",
rank=self.rank,
world_size=self.world_size,
store=dist.FileStore(self.file_name, self.world_size),
)
self.assertEqual(len(pgs), 1)
self.assertEqual(pgs[0][0], dist.group.WORLD)
# create two partial world PGs
pg0 = dist.new_group(ranks=[0, 1])
pg1 = dist.new_group(ranks=[2, 3])
# Each rank only observes two PGs being created: the default PG and the one that includes its rank.
# No creation event is emitted for a PG the current rank doesn't belong to.
# For example, if you're rank 1 you'll get an event for pg0 but not pg1, even though the API contract
# dictates that you must call new_group for both.
self.assertEqual(len(pgs), 2)
self.assertEqual(pgs[1][0], pg0 if self.rank < 2 else pg1)
def with_comms(func=None):
if func is None:
return partial(
with_comms,
)
@wraps(func)
def wrapper(self, *args, **kwargs):
self.init_comms()
func(self, *args, **kwargs)
self.destroy_comms()
return wrapper
class CollectiveHooks:
@property
def world_size(self) -> int:
return 4
def _collective_hooks(self):
# it's ok to access them directly since there's a single bg thread poking at them.
starts = []
ends = []
cv = threading.Condition()
def coll_start(status):
starts.append(status)
print(f"col_start {len(starts)} rank{self.rank}")
def coll_end(status):
ends.append(status)
print(f"col_end {len(ends)} rank{self.rank}")
if len(ends) == 2:
with cv:
cv.notify()
dhooks.register_collective_start_hook(coll_start)
dhooks.register_collective_end_hook(coll_end)
tensor = torch.ones([2, 3]).to(self.device) * self.rank
tensor_list = [torch.empty_like(tensor) for _ in range(self.world_size)]
dist.all_gather(tensor_list, tensor)
tensor2 = torch.ones([2, 3]).to(self.device) * self.rank
dist.all_reduce(tensor2)
with cv:
cv.wait(1)
default_pg_name = dist.group.WORLD.group_name
self.assertEqual(2, len(starts))
self.assertEqual(2, len(ends))
def check_op(idx, coll_name):
self.assertEqual(default_pg_name, starts[idx].pg_name)
self.assertEqual(self.backend_name, starts[idx].backend)
self.assertGreaterEqual(starts[idx].sequence_number, 0)
self.assertGreaterEqual(starts[idx].timestamp, 0)
self.assertEqual(coll_name, starts[idx].operation)
self.assertEqual(default_pg_name, ends[idx].pg_name)
self.assertEqual(self.backend_name, ends[idx].backend)
self.assertEqual(starts[idx].sequence_number, ends[idx].sequence_number)
self.assertLessEqual(starts[idx].timestamp, ends[idx].timestamp)
self.assertEqual(coll_name, ends[idx].operation)
check_op(0, "ALLGATHER")
check_op(1, "ALLREDUCE")
class GlooHooks(MultiProcessTestCase, CollectiveHooks):
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
def tearDown(self):
super().tearDown()
try:
os.remove(self.file_name)
except OSError:
pass
def init_comms(self):
dist.init_process_group(
backend="gloo",
rank=self.rank,
world_size=self.world_size,
store=dist.FileStore(self.file_name, self.world_size),
)
def destroy_comms(self):
dist.destroy_process_group()
@property
def backend_name(self):
return "gloo"
@property
def device(self):
return "cpu"
@with_comms
def test_collective_hooks(self):
self._collective_hooks()
class NcclHooks(MultiProcessTestCase, CollectiveHooks):
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
def tearDown(self):
super().tearDown()
try:
os.remove(self.file_name)
except OSError:
pass
def init_comms(self):
dist.init_process_group(
backend="nccl",
rank=self.rank,
world_size=self.world_size,
store=dist.FileStore(self.file_name, self.world_size),
)
def destroy_comms(self):
dist.destroy_process_group()
@property
def backend_name(self):
return "nccl"
@property
def device(self):
return f"cuda:{self.rank}"
@skip_if_lt_x_gpu(4)
@with_comms
def test_collective_hooks(self):
self._collective_hooks()
class SingleRankTests(TestCase):
def setUp(self) -> None:
super().setUp()
self.rank = 0
self.file_name = tempfile.NamedTemporaryFile(delete=False).name
dist.init_process_group(
backend="gloo",
rank=0,
world_size=1,
store=dist.FileStore(self.file_name, 1),
)
def tearDown(self) -> None:
dist.destroy_process_group()
def test_queue_overflow(self) -> None:
cv_done_colls = threading.Condition()
cv_done_cb = threading.Condition()
colls_done = False
starts = []
status_with_dropped = None
def coll_start(status: dhooks.CollectiveStatus):
starts.append(status)
with cv_done_colls:
while not colls_done:
cv_done_colls.wait()
if status.drop_count > 0:
nonlocal status_with_dropped
status_with_dropped = status
with cv_done_cb:
cv_done_cb.notify()
dhooks.register_collective_start_hook(coll_start)
# native limit is 512
for i in range(600):
dist.all_reduce(torch.ones([2, 3]))
colls_done = True
with cv_done_colls:
cv_done_colls.notify()
with cv_done_cb:
cv_done_cb.wait(10)
self.assertTrue(status_with_dropped is not None)
self.assertTrue(status_with_dropped.drop_count > 0)
if __name__ == "__main__":
assert (
not torch.cuda._initialized
), "test_distributed must not have initialized CUDA context on main process"
run_tests()

View File

@ -11,10 +11,6 @@ _DEFAULT_FIRST_BUCKET_BYTES: int
_DEFAULT_NO_TIMEOUT: timedelta
_DEFAULT_PG_TIMEOUT: timedelta
class EventKind(Enum):
START = ...
END = ...
class BuiltinCommHookType(Enum):
ALLREDUCE = ...
FP16_COMPRESS = ...
@ -24,8 +20,6 @@ def _register_builtin_comm_hook(
reducer: Reducer,
comm_hook_type: BuiltinCommHookType,
): ...
def _dequeue_c10d_event() -> Dict[str, object]: ...
def _enable_event_collection(pipe_fs: int) -> None: ...
class GradBucket:
def index(self) -> int: ...

View File

@ -1,26 +1,9 @@
#include <c10/util/Logging.h>
#include <fmt/format.h>
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Hooks.hpp>
#include <torch/csrc/distributed/c10d/logging.h>
namespace c10d {
namespace {
void commonEventinit(
details::EventInfo& evt,
const Backend& backend,
const Work& work) {
evt.timestamp =
std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
evt.pg_name = backend.getGroupName();
evt.backend = backend.getBackendName();
evt.sequence_number = work.getSequencenumber();
evt.operation = c10d::opTypeToString(work.retrieveOpType());
evt.drop_count = 0;
}
} // namespace
Backend::Backend(int rank, int size)
: rank_(rank), size_(size), dist_debug_level_(debug_level()) {
C10_LOG_API_USAGE_ONCE("c10d.backend");
@ -32,21 +15,4 @@ void Backend::init() {
C10_LOG_API_USAGE_ONCE(fmt::format("c10d.backend_{}", getBackendName()));
}
void Backend::emitCollectiveStart(const Work& work) {
details::EventInfo evt;
commonEventinit(evt, *this, work);
evt.event_kind = ::c10d::EventKind::CollectiveStart;
details::enqueue_c10d_event(std::move(evt));
}
void Backend::emitCollectiveEnd(const Work& work) {
details::EventInfo evt;
commonEventinit(evt, *this, work);
evt.event_kind = ::c10d::EventKind::CollectiveEnd;
evt.duration_ms = work.getDuration();
details::enqueue_c10d_event(std::move(evt));
}
} // namespace c10d

View File

@ -366,8 +366,6 @@ class TORCH_API Backend : public torch::CustomClassHolder {
// Implementations of this interface need to call this to setup
// appropriate logging etc.
void init();
void emitCollectiveStart(const Work& work);
void emitCollectiveEnd(const Work& work);
const int rank_;
const int size_;

View File

@ -1,60 +0,0 @@
#include <atomic>
#include <deque>
#include <memory>
#include <mutex>
#ifndef _WIN32
#include <unistd.h>
#else
#include <io.h>
#endif
#include <torch/csrc/distributed/c10d/Hooks.hpp>
namespace c10d {
namespace {
std::atomic<bool> event_queue_enabled = false;
int sync_pipe;
std::mutex event_queue_lock;
std::deque<details::EventInfo> event_queue;
} // namespace
void enable_event_collection(int pipe) {
sync_pipe = pipe;
event_queue_enabled.store(true);
}
namespace details {
// we start dropping events after this
const size_t MAX_QUEUE_SIZE = 512;
bool dequeue_c10d_event(EventInfo& evt) {
std::unique_lock<std::mutex> lock(event_queue_lock);
if (event_queue.size() == 0) {
return false;
}
evt = event_queue.front();
event_queue.pop_front();
return true;
}
void enqueue_c10d_event(EventInfo&& evt) {
if (!event_queue_enabled.load())
return;
std::unique_lock<std::mutex> lock(event_queue_lock);
if (event_queue.size() >= MAX_QUEUE_SIZE) {
event_queue.back().drop_count++;
} else {
event_queue.push_back(std::move(evt));
char m = 'x';
write(sync_pipe, &m, 1);
}
}
} // namespace details
} // namespace c10d
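
The queue above is bounded at 512 entries; on overflow the producer increments drop_count on the newest retained event instead of enqueueing, and each successful enqueue writes one byte to the notification pipe to wake the consumer thread. A minimal, hedged sketch of how a user-side hook could surface those drops, assuming a build prior to this revert in which torch.distributed._hooks still exists (the hook and field names mirror the Python module removed later in this diff; on_start and dropped_total are illustrative):

import torch.distributed._hooks as dhooks

dropped_total = 0

def on_start(status: dhooks.CollectiveStatus) -> None:
    # A non-zero drop_count means the C++ event queue overflowed and events
    # following this one were coalesced into a counter rather than delivered.
    global dropped_total
    if status.drop_count > 0:
        dropped_total += status.drop_count
        print(f"{status.drop_count} collective events dropped "
              f"(total {dropped_total}); hooks are falling behind")

dhooks.register_collective_start_hook(on_start)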

View File

@ -1,30 +0,0 @@
#pragma once
#include <c10/util/Optional.h>
#include <string>
namespace c10d {
enum class EventKind { CollectiveStart, CollectiveEnd };
TORCH_API void enable_event_collection(int sync_pipe);
namespace details {
struct TORCH_API EventInfo {
EventKind event_kind;
std::string pg_name;
std::string backend;
int64_t sequence_number;
std::string operation;
int64_t timestamp;
c10::optional<float> duration_ms;
int64_t drop_count;
};
// TODO do we want to expose something else here?
TORCH_API bool dequeue_c10d_event(EventInfo& evt);
TORCH_API void enqueue_c10d_event(EventInfo&& evt);
} // namespace details
} // namespace c10d

View File

@ -82,14 +82,10 @@ std::string opTypeToString(OpType opType) {
return "RECVANYSOURCE";
case OpType::BARRIER:
return "BARRIER";
case OpType::_REDUCE_SCATTER_BASE:
return "_REDUCE_SCATTER_BASE";
case OpType::COALESCED:
return "COALESCED";
case OpType::_ALLREDUCE_SPARSE:
return "_ALLREDUCE_SPARSE";
case OpType::UNKNOWN:
return "UNKNOWN";
case OpType::_REDUCE_SCATTER_BASE:
return "_REDUCE_SCATTER_BASE";
default:
TORCH_INTERNAL_ASSERT(false, "Unknown op type!");
}

View File

@ -855,10 +855,6 @@ void ProcessGroupGloo::runLoop(int workerIndex) {
}
void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
emitCollectiveStart(*work.get());
work->getFuture()->addCallback(
[=](auto& f) { this->emitCollectiveEnd(*work.get()); });
std::unique_lock<std::mutex> lock(workMutex_);
workQueue_.push_back(std::move(work));
lock.unlock();

View File

@ -931,34 +931,27 @@ void ProcessGroupNCCL::ncclCommWatchdog() {
}
}
void ProcessGroupNCCL::logWorkStart(WorkNCCL& work, bool emitDesyncInfo) {
if (terminateProcessGroup_.load() || work.startTraceUpdated_)
void ProcessGroupNCCL::logWorkStart(WorkNCCL& work) {
if (work.startTraceUpdated_)
return;
if (terminateProcessGroup_.load() || storeError_)
return;
work.startTraceUpdated_ = true;
emitCollectiveStart(work);
if (!emitDesyncInfo || storeError_)
return;
storeError_ = !c10d::traceUpdate(
store_, traceKeyStart_, work.seq_, opTypeToString(work.opType_));
}
void ProcessGroupNCCL::logWorkEnd(WorkNCCL& work, bool emitDesyncInfo) {
if (terminateProcessGroup_.load())
void ProcessGroupNCCL::logWorkEnd(WorkNCCL& work) {
if (terminateProcessGroup_.load() || storeError_)
return;
// In case the start of the work hasn't been logged
if (!work.startTraceUpdated_) {
logWorkStart(work, emitDesyncInfo);
logWorkStart(work);
}
emitCollectiveEnd(work);
if (!emitDesyncInfo || storeError_)
return;
storeError_ = !c10d::traceUpdate(
store_, traceKeyEnd_, work.seq_, opTypeToString(work.opType_));
}
@ -1010,11 +1003,13 @@ void ProcessGroupNCCL::workCleanupLoop() {
}
// Work status logging for desync debug
if (work.isStarted()) {
logWorkStart(work, desyncDebug_);
}
if (work.isCompleted()) {
logWorkEnd(work, desyncDebug_);
if (desyncDebug_) {
if (work.isStarted()) {
logWorkStart(work);
}
if (work.isCompleted()) {
logWorkEnd(work);
}
}
// Clean up completed work
@ -1071,7 +1066,7 @@ void ProcessGroupNCCL::runHookLoop() {
timeStarted, // timeStarted
std::chrono::system_clock::now(), // timeFinished
std::chrono::duration<float, std::milli>(
work.getDuration().value()) // activeDuration
work.getDuration()) // activeDuration
));
lock.lock();
@ -1584,19 +1579,19 @@ c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupNCCL::WorkNCCL::
return future_;
}
c10::optional<float> ProcessGroupNCCL::WorkNCCL::getDuration() const {
if (!timingEnabled_ || !((*ncclEndEvents_)[0].query())) {
return c10::optional<float>();
}
float ProcessGroupNCCL::WorkNCCL::getDuration() const {
TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled");
TORCH_CHECK(
ncclStartEvents_->size() == 1,
"getDuration only works for single device per ProcessGroup.");
TORCH_CHECK(
ncclEndEvents_->size() == 1,
"getDuration only works for single device per ProcessGroup.");
TORCH_CHECK(
(*ncclEndEvents_)[0].query(),
"getDuration can only be called after work is succeeded.")
return (*ncclStartEvents_)[0].elapsed_time((*ncclEndEvents_)[0]);
}
uint64_t ProcessGroupNCCL::WorkNCCL::getSequencenumber() const {
return seq_;
}

View File

@ -167,7 +167,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Get a Future object that will be marked as completed internally.
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
c10::optional<float> getDuration() const override;
float getDuration() const override;
uint64_t getSequencenumber() const override;
@ -615,10 +615,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
void runHookLoop();
// Desync debug helper
void logWorkStart(WorkNCCL& work, bool emitDesyncInfo);
void logWorkStart(WorkNCCL& work);
// Desync debug helper
void logWorkEnd(WorkNCCL& work, bool emitDesyncInfo);
void logWorkEnd(WorkNCCL& work);
protected:
static const int64_t kWatchdogThreadSleepMillis;

View File

@ -127,8 +127,8 @@ void Work::finishAndThrow(std::exception_ptr exception) {
}
}
c10::optional<float> Work::getDuration() const {
return c10::optional<float>();
float Work::getDuration() const {
TORCH_CHECK(false, "This Backend doesn't support getDuration.");
}
uint64_t Work::getSequencenumber() const {

View File

@ -107,7 +107,7 @@ class TORCH_API Work : public torch::CustomClassHolder {
// work. Only NCCL backend is currently supported.
virtual c10::intrusive_ptr<c10::ivalue::Future> getFuture();
virtual c10::optional<float> getDuration() const;
virtual float getDuration() const;
virtual uint64_t getSequencenumber() const;

View File

@ -32,7 +32,6 @@
#include <fmt/format.h>
#include <pybind11/chrono.h>
#include <torch/csrc/distributed/c10d/Hooks.hpp>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/comm.hpp>
@ -292,26 +291,6 @@ void _register_builtin_comm_hook(
reducer.register_builtin_comm_hook(comm_hook_type);
}
py::object c10d_dequeue_python_event() {
::c10d::details::EventInfo evt;
if (!::c10d::details::dequeue_c10d_event(evt)) {
return py::none();
}
py::dict data;
data["event_kind"] = (int)evt.event_kind;
data["pg_name"] = evt.pg_name;
data["backend"] = evt.backend;
data["sequence_number"] = evt.sequence_number;
data["operation"] = evt.operation;
data["timestamp"] = evt.timestamp;
data["duration"] = evt.duration_ms.value_or(-1);
data["drop_count"] = evt.drop_count;
return std::move(data);
}
// Customize the metaclass of ::c10d::ReduceOp for the backward compatibility.
// https://github.com/pytorch/pytorch/pull/84243 changed ::c10d::ReduceOp to
// struct from enum, sacrificing some of the Python built-in function supports
@ -661,11 +640,6 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
&::c10d::Logger::set_static_graph,
py::call_guard<py::gil_scoped_release>());
py::enum_<::c10d::EventKind>(module, "EventKind", R"(
An enum for collective hooks event types.)")
.value("START", ::c10d::EventKind::CollectiveStart)
.value("END", ::c10d::EventKind::CollectiveEnd);
py::enum_<::c10d::DebugLevel>(module, "DebugLevel", R"(
An enum whose values correspond to different debug levels of the
torch.distributed package. Currently supporting OFF, INFO, and DETAIL,
@ -689,16 +663,7 @@ An enum for collective hooks event types.)")
"set_debug_level_from_env",
::c10d::setDebugLevelFromEnvironment,
R"(Sets the debug level of the torch.distributed package from the
``TORCH_DISTRIBUTED_DEBUG`` environment variable.)")
.def(
"_enable_event_collection",
&::c10d::enable_event_collection,
"(Enables events collection).",
py::call_guard<py::gil_scoped_release>())
.def(
"_dequeue_c10d_event",
&c10d_dequeue_python_event,
"(Blocks until a c10d event is available and return it as a python dictionary).");
``TORCH_DISTRIBUTED_DEBUG`` environment variable.)");
// TODO(crcrpar): Hardening `ReduceOp`.
// While keeping most op types as enum value,

View File

@ -1,183 +0,0 @@
import logging
import os
import threading
import time
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional
import torch.distributed as dist
if dist.is_available():
import torch.distributed.distributed_c10d as c10d
from torch._C._distributed_c10d import (
_dequeue_c10d_event,
_enable_event_collection,
EventKind,
)
__all__ = [
"CollectiveStatus",
"COLLECTIVE_HOOK_TYPE",
"PG_HOOK_TYPE",
"register_collective_start_hook",
"register_collective_end_hook",
"register_process_group_hook",
]
@dataclass
class CollectiveStatus:
r"""Status of a collective operation.
Drop count indicates events have been dropped at the producer side, which means they
were not consumed fast enough by hooks.
"""
pg_name: str = "unknown" # This name matches the one informed in the pgreg cb
backend: str = "unknown" # Name of the backend used
sequence_number: int = -1 # This name matches the one informed in the pgreg cb
operation: str = "unknown" # collective name
timestamp: int = 0 # timestamp to the earliest time we noticed this event
duration: Optional[float] = None # value in milliseconds it took executing
drop_count: int = 0 # number of events dropped following this one
COLLECTIVE_HOOK_TYPE = Callable[[CollectiveStatus], None]
PG_HOOK_TYPE = Callable[[dist.ProcessGroup, str], None]
# This controls the number of internal failures we'll tolerate before giving up
_MAX_INTERNAL_FAILURES = 10
logger = logging.getLogger(__name__)
_cb_thread: Optional[threading.Thread] = None
_start_callbacks: List[COLLECTIVE_HOOK_TYPE] = []
_end_callbacks: List[COLLECTIVE_HOOK_TYPE] = []
_pp_r = -1
_pp_w = -1
def _c10d_pg_hooks_loops():
internal_failures = 0
while True:
# we don't care about the result, this is how we implement notification
_ = os.read(_pp_r, 1)
evt: Dict[str, object] = _dequeue_c10d_event()
try:
event_kind = evt.pop("event_kind", None)
if event_kind is None:
logger.warning(
"c10d returned event dictionary %s without 'event_kind' key, cannot dispatch",
evt,
)
internal_failures += 1
if internal_failures >= _MAX_INTERNAL_FAILURES:
logger.warning(
"too many internal c10d failures processing callback loop. stopping"
)
return
time.sleep(1)
continue
if event_kind == int(EventKind.START): # type: ignore[call-overload]
cb_list = _start_callbacks
elif event_kind == int(EventKind.END): # type: ignore[call-overload]
cb_list = _end_callbacks
else:
logger.warning(
"c10d event %s with invalid 'event_kind' with value %d",
evt,
event_kind,
)
internal_failures += 1
if internal_failures >= _MAX_INTERNAL_FAILURES:
logger.warning(
"too many internal c10d failures processing callback loop. stopping"
)
return
time.sleep(1)
continue
status = CollectiveStatus(**evt) # type: ignore[arg-type]
for cb in cb_list:
try:
cb(status)
except Exception as e:
logger.info(
"c10d event callback %s with event %s threw exception %s",
cb,
status,
e,
)
except Exception as e:
# We have to keep processing, otherwise the queue will grow without bound
logger.warning(
"c10d callback thread, while processing event %s, raised exception %s.",
evt,
e,
)
internal_failures += 1
if internal_failures >= _MAX_INTERNAL_FAILURES:
logger.warning(
"too many internal c10d failures processing callback loop. stopping"
)
return
# Sleep for a second to avoid hogging the GIL in case of a persistent failure
time.sleep(1)
def _lazy_init():
global _cb_thread
if _cb_thread is not None:
return
global _pp_r
global _pp_w
_pp_r, _pp_w = os.pipe()
_enable_event_collection(_pp_w)
c10d._enable_collectives_timing()
_cb_thread = threading.Thread(target=_c10d_pg_hooks_loops, daemon=True)
_cb_thread.start()
logger.info("c10d::hooks thread enabled")
def _check_distributed_available():
if not dist.is_available():
raise RuntimeError(
"torch.distributed is not available, so hooks are not available."
)
def register_collective_start_hook(hook: COLLECTIVE_HOOK_TYPE) -> None:
r"""Register a hook that is called every time a collective starts.
The hook is invoked on a background thread.
Exceptions raised by the callback are ignored and non-fatal.
"""
_check_distributed_available()
_start_callbacks.append(hook)
_lazy_init()
def register_collective_end_hook(hook: COLLECTIVE_HOOK_TYPE) -> None:
r"""Register a hook that is called every time a collective finishes.
The hook is invoked on a background thread.
Exceptions raised by the callback are ignored and non-fatal.
"""
_check_distributed_available()
_end_callbacks.append(hook)
_lazy_init()
def register_process_group_hook(hook: PG_HOOK_TYPE) -> None:
r"""Register a hook that is called every time a process group is created on this rank.
This hook is only invoked if the current rank is part of the PG being created.
The pg_name is unique to the whole cluster and should be treated as an opaque identifier that is subject to change.
The hook is invoked on a background thread.
Exceptions raised by the callback are ignored and non-fatal.
"""
_check_distributed_available()
c10d._register_creation_hook(hook)
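
For reference, a minimal single-process sketch of how the registration API deleted here was exercised (compare the test file earlier in this diff). It assumes a build prior to this revert, in which torch.distributed._hooks is still present; all hook and field names come from the module above, while on_start/on_end are illustrative callback names:

import torch
import torch.distributed as dist
import torch.distributed._hooks as dhooks

def on_start(status: dhooks.CollectiveStatus) -> None:
    # Hooks run on a background thread; exceptions they raise are logged and ignored.
    print(f"start: {status.operation} pg={status.pg_name} seq={status.sequence_number}")

def on_end(status: dhooks.CollectiveStatus) -> None:
    print(f"end:   {status.operation} took {status.duration} ms")

# The PG hook must be registered before the group is created to observe it.
dhooks.register_process_group_hook(lambda pg, name: print(f"created process group {name}"))
dhooks.register_collective_start_hook(on_start)
dhooks.register_collective_end_hook(on_end)

# Single-rank gloo group, just to drive one collective through the hooks.
dist.init_process_group("gloo", rank=0, world_size=1, store=dist.HashStore())
dist.all_reduce(torch.ones(2, 2))
dist.destroy_process_group()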

View File

@ -421,19 +421,6 @@ _group_count = 0
_tags_to_pg: Dict[str, List[ProcessGroup]] = {}
_pg_to_tag: Dict[ProcessGroup, str] = {}
class _HookState:
def __init__(self):
self.creation_hooks = []
def register_creation_hook(self, hook) -> None:
self.creation_hooks.append(hook)
def fire_creation_hook(self, pg, name) -> None:
for hook in self.creation_hooks:
try:
hook(pg, name)
except Exception as e:
logger.info("hook %s failed with %s", hook, e)
class _World:
"""
@ -447,8 +434,6 @@ class _World:
self._default_pg = None
self._pg_coalesce_state: Dict[ProcessGroup, List[Union[_CollOp, P2POp]]] = {}
self._pg_default_device: Dict[ProcessGroup, torch.device] = {}
self._hook_state = _HookState()
self.enable_collectives_timing = False
@property
def default_pg(self):
@ -558,9 +543,6 @@ class _World:
)
return config_info
@property
def pg_hook_state(self) -> _HookState:
return self._hook_state
_world = _World()
"""Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it"""
@ -1382,9 +1364,6 @@ def _new_process_group_helper(
pg._set_group_name(group_name)
_world.pg_backend_config[pg] = str(backend_config)
if _world.enable_collectives_timing:
pg._enable_collectives_timing()
# "" is the default tag for user PGs
if pg_tag in [None, ""]:
pg_tag = f"ptd:{group_name}"
@ -1394,8 +1373,6 @@ def _new_process_group_helper(
_world.tags_to_pg.setdefault(pg_tag, []).append(pg)
_world.pg_to_tag[pg] = pg_tag
_world.pg_hook_state.fire_creation_hook(pg, group_name)
return pg, prefix_store
def destroy_process_group(group: Optional[ProcessGroup] = None):
@ -4338,11 +4315,3 @@ dynamo_unsupported_distributed_c10d_ops = [
reduce_scatter_tensor,
send,
]
def _register_creation_hook(hook):
_world.pg_hook_state.register_creation_hook(hook)
def _enable_collectives_timing():
_world.enable_collectives_timing = True
for pg in _world.pg_map:
pg._enable_collectives_timing()

View File

@ -401,11 +401,6 @@ class WorldData:
class ThreadLocalWorld:
_world = threading.local()
def __init__(self):
self.enable_collectives_timing = False
self._hook_state = dist.distributed_c10d._HookState()
def _get_world(self) -> WorldData:
if not hasattr(ThreadLocalWorld._world, "world"):
ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {}, {}, {})
@ -459,9 +454,6 @@ class ThreadLocalWorld:
def pg_default_device(self) -> Dict[dist.ProcessGroup, torch.device]:
return self._get_world().pg_default_device
@property
def pg_hook_state(self) -> dist.distributed_c10d._HookState:
return self._hook_state
_old_pg_world = None
_ctx_manager = None