Revert "Reland #2 "[C10] PG observability hooks. (#108815, #110907)" (#111072)"

This reverts commit bb1424d46e656dfcdd4c12efe58ada9f1720c4d8.

Reverted https://github.com/pytorch/pytorch/pull/111072 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/111072#issuecomment-1765399829))
PyTorch MergeBot
2023-10-16 23:03:26 +00:00
parent 5a8a89360d
commit 1e70f4d02c
17 changed files with 31 additions and 704 deletions
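
For reference, the hooks surface removed by this revert lived in `torch/distributed/_hooks.py` (deleted below). A minimal usage sketch of that now-removed API, shown only to illustrate what the deleted module exposed; it will not import once this revert lands:

```python
# Sketch of the API removed by this revert; it requires a build that still
# carries PR #108815 / #110907 (torch.distributed._hooks no longer exists after it).
import torch.distributed as dist
import torch.distributed._hooks as dhooks

def on_start(status: "dhooks.CollectiveStatus") -> None:
    # Runs on a background thread for every collective start event.
    print(f"start {status.operation} pg={status.pg_name} seq={status.sequence_number}")

def on_end(status: "dhooks.CollectiveStatus") -> None:
    # duration is Optional[float] in milliseconds; drop_count > 0 means events were lost.
    print(f"end {status.operation} duration={status.duration}ms dropped={status.drop_count}")

dhooks.register_collective_start_hook(on_start)
dhooks.register_collective_end_hook(on_end)
dhooks.register_process_group_hook(lambda pg, name: print(f"new process group: {name}"))
# After dist.init_process_group() and any collectives run, the hooks above fire.
```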

View File

@@ -522,7 +522,6 @@ libtorch_distributed_base_sources = [
     "torch/csrc/distributed/c10d/Backend.cpp",
     "torch/csrc/distributed/c10d/FileStore.cpp",
     "torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
-    "torch/csrc/distributed/c10d/Hooks.cpp",
     "torch/csrc/distributed/c10d/Ops.cpp",
     "torch/csrc/distributed/c10d/ParamCommsUtils.cpp",
     "torch/csrc/distributed/c10d/PrefixStore.cpp",

View File

@@ -1,270 +0,0 @@
# Owner(s): ["oncall: distributed"]

import os
import sys
import tempfile
import threading
from functools import partial, wraps

import torch
import torch.distributed as dist
import torch.distributed._hooks as dhooks

if not dist.is_available():
    print("torch.distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests, TestCase


class PgHooks(MultiProcessTestCase):
    @property
    def world_size(self) -> int:
        return 4

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def test_pg_hook(self):
        pgs = []

        def pg_hook(pg, pg_name):
            pgs.append((pg, pg_name))

        dhooks.register_process_group_hook(pg_hook)
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )
        self.assertEqual(len(pgs), 1)
        self.assertEqual(pgs[0][0], dist.group.WORLD)

        # create two partial world PGs
        pg0 = dist.new_group(ranks=[0, 1])
        pg1 = dist.new_group(ranks=[2, 3])

        # Each rank only observes two PGs being created: the default PG and the one covering its ranks.
        # We don't emit events for PG creation if the current rank doesn't belong to it.
        # For example, rank 1 gets an event for pg0 but not pg1, even though the API contract
        # dictates that new_group must be called for both.
        self.assertEqual(len(pgs), 2)
        self.assertEqual(pgs[1][0], pg0 if self.rank < 2 else pg1)


def with_comms(func=None):
    if func is None:
        return partial(
            with_comms,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        self.init_comms()
        func(self, *args, **kwargs)
        self.destroy_comms()

    return wrapper


class CollectiveHooks:
    @property
    def world_size(self) -> int:
        return 4

    def _collective_hooks(self):
        # it's ok to access them directly since there's a single bg thread poking at them.
        starts = []
        ends = []
        cv = threading.Condition()

        def coll_start(status):
            starts.append(status)
            print(f"col_start {len(starts)} rank{self.rank}")

        def coll_end(status):
            ends.append(status)
            print(f"col_end {len(ends)} rank{self.rank}")
            if len(ends) == 2:
                with cv:
                    cv.notify()

        dhooks.register_collective_start_hook(coll_start)
        dhooks.register_collective_end_hook(coll_end)

        tensor = torch.ones([2, 3]).to(self.device) * self.rank
        tensor_list = [torch.empty_like(tensor) for _ in range(self.world_size)]

        dist.all_gather(tensor_list, tensor)
        tensor2 = torch.ones([2, 3]).to(self.device) * self.rank
        dist.all_reduce(tensor2)
        with cv:
            cv.wait(1)

        default_pg_name = dist.group.WORLD.group_name
        self.assertEqual(2, len(starts))
        self.assertEqual(2, len(ends))

        def check_op(idx, coll_name):
            self.assertEqual(default_pg_name, starts[idx].pg_name)
            self.assertEqual(self.backend_name, starts[idx].backend)
            self.assertGreaterEqual(starts[idx].sequence_number, 0)
            self.assertGreaterEqual(starts[idx].timestamp, 0)
            self.assertEqual(coll_name, starts[idx].operation)

            self.assertEqual(default_pg_name, ends[idx].pg_name)
            self.assertEqual(self.backend_name, ends[idx].backend)
            self.assertEqual(starts[idx].sequence_number, ends[idx].sequence_number)
            self.assertLessEqual(starts[idx].timestamp, ends[idx].timestamp)
            self.assertEqual(coll_name, ends[idx].operation)

        check_op(0, "ALLGATHER")
        check_op(1, "ALLREDUCE")


class GlooHooks(MultiProcessTestCase, CollectiveHooks):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def init_comms(self):
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    def destroy_comms(self):
        dist.destroy_process_group()

    @property
    def backend_name(self):
        return "gloo"

    @property
    def device(self):
        return "cpu"

    @with_comms
    def test_collective_hooks(self):
        self._collective_hooks()


class NcclHooks(MultiProcessTestCase, CollectiveHooks):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def init_comms(self):
        dist.init_process_group(
            backend="nccl",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    def destroy_comms(self):
        dist.destroy_process_group()

    @property
    def backend_name(self):
        return "nccl"

    @property
    def device(self):
        return f"cuda:{self.rank}"

    @skip_if_lt_x_gpu(4)
    @with_comms
    def test_collective_hooks(self):
        self._collective_hooks()


class SingleRankTests(TestCase):
    def setUp(self) -> None:
        super().setUp()
        self.rank = 0
        self.file_name = tempfile.NamedTemporaryFile(delete=False).name
        dist.init_process_group(
            backend="gloo",
            rank=0,
            world_size=1,
            store=dist.FileStore(self.file_name, 1),
        )

    def tearDown(self) -> None:
        dist.destroy_process_group()

    def test_queue_overflow(self) -> None:
        cv_done_colls = threading.Condition()
        cv_done_cb = threading.Condition()
        colls_done = False
        starts = []
        status_with_dropped = None

        def coll_start(status: dhooks.CollectiveStatus):
            starts.append(status)
            with cv_done_colls:
                while not colls_done:
                    cv_done_colls.wait()
            if status.drop_count > 0:
                nonlocal status_with_dropped
                status_with_dropped = status
                with cv_done_cb:
                    cv_done_cb.notify()

        dhooks.register_collective_start_hook(coll_start)
        # native limit is 512
        for i in range(600):
            dist.all_reduce(torch.ones([2, 3]))
        colls_done = True
        with cv_done_colls:
            cv_done_colls.notify()
        with cv_done_cb:
            cv_done_cb.wait(10)

        self.assertTrue(status_with_dropped is not None)
        self.assertTrue(status_with_dropped.drop_count > 0)


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()

View File

@@ -11,10 +11,6 @@ _DEFAULT_FIRST_BUCKET_BYTES: int
 _DEFAULT_NO_TIMEOUT: timedelta
 _DEFAULT_PG_TIMEOUT: timedelta
-class EventKind(Enum):
-    START = ...
-    END = ...
 class BuiltinCommHookType(Enum):
     ALLREDUCE = ...
     FP16_COMPRESS = ...
@@ -24,8 +20,6 @@ def _register_builtin_comm_hook(
     reducer: Reducer,
     comm_hook_type: BuiltinCommHookType,
 ): ...
-def _dequeue_c10d_event() -> Dict[str, object]: ...
-def _enable_event_collection(pipe_fs: int) -> None: ...
 class GradBucket:
     def index(self) -> int: ...

View File

@@ -1,26 +1,9 @@
 #include <c10/util/Logging.h>
 #include <fmt/format.h>
 #include <torch/csrc/distributed/c10d/Backend.hpp>
-#include <torch/csrc/distributed/c10d/Hooks.hpp>
-#include <torch/csrc/distributed/c10d/logging.h>
 
 namespace c10d {
 
-namespace {
-void commonEventinit(
-    details::EventInfo& evt,
-    const Backend& backend,
-    const Work& work) {
-  evt.timestamp =
-      std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
-  evt.pg_name = backend.getGroupName();
-  evt.backend = backend.getBackendName();
-  evt.sequence_number = work.getSequencenumber();
-  evt.operation = c10d::opTypeToString(work.retrieveOpType());
-  evt.drop_count = 0;
-}
-} // namespace
-
 Backend::Backend(int rank, int size)
     : rank_(rank), size_(size), dist_debug_level_(debug_level()) {
   C10_LOG_API_USAGE_ONCE("c10d.backend");
@@ -32,21 +15,4 @@ void Backend::init() {
   C10_LOG_API_USAGE_ONCE(fmt::format("c10d.backend_{}", getBackendName()));
 }
-
-void Backend::emitCollectiveStart(const Work& work) {
-  details::EventInfo evt;
-  commonEventinit(evt, *this, work);
-  evt.event_kind = ::c10d::EventKind::CollectiveStart;
-  details::enqueue_c10d_event(std::move(evt));
-}
-
-void Backend::emitCollectiveEnd(const Work& work) {
-  details::EventInfo evt;
-  commonEventinit(evt, *this, work);
-  evt.event_kind = ::c10d::EventKind::CollectiveEnd;
-  evt.duration_ms = work.getDuration();
-  details::enqueue_c10d_event(std::move(evt));
-}
 } // namespace c10d

View File

@@ -366,8 +366,6 @@ class TORCH_API Backend : public torch::CustomClassHolder {
   // Implementations of this interface need to call this to setup
   // appropriate logging etc.
   void init();
-  void emitCollectiveStart(const Work& work);
-  void emitCollectiveEnd(const Work& work);
   const int rank_;
   const int size_;

View File

@@ -1,60 +0,0 @@
#include <atomic>
#include <deque>
#include <memory>
#include <mutex>

#ifndef _WIN32
#include <unistd.h>
#else
#include <io.h>
#endif

#include <torch/csrc/distributed/c10d/Hooks.hpp>

namespace c10d {

namespace {
std::atomic<bool> event_queue_enabled = false;
int sync_pipe;
std::mutex event_queue_lock;
std::deque<details::EventInfo> event_queue;
} // namespace

void enable_event_collection(int pipe) {
  sync_pipe = pipe;
  event_queue_enabled.store(true);
}

namespace details {

// we start dropping events after this
const size_t MAX_QUEUE_SIZE = 512;

bool dequeue_c10d_event(EventInfo& evt) {
  std::unique_lock<std::mutex> lock(event_queue_lock);
  if (event_queue.size() == 0) {
    return false;
  }
  evt = event_queue.front();
  event_queue.pop_front();
  return true;
}

void enqueue_c10d_event(EventInfo&& evt) {
  if (!event_queue_enabled.load())
    return;

  std::unique_lock<std::mutex> lock(event_queue_lock);
  if (event_queue.size() >= MAX_QUEUE_SIZE) {
    event_queue.back().drop_count++;
  } else {
    event_queue.push_back(std::move(evt));
    char m = 'x';
    write(sync_pipe, &m, 1);
  }
}

} // namespace details
} // namespace c10d

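The removed Hooks.cpp above wakes its consumer by writing one byte to a pipe per enqueued event, while the Python loop in the removed torch/distributed/_hooks.py (later in this diff) blocks on os.read before draining the queue. A standalone sketch of that notification pattern, independent of the PyTorch internals and shown only for illustration:

```python
import os
import threading
import time
from collections import deque

MAX_QUEUE_SIZE = 512  # mirrors the native limit in the removed Hooks.cpp
events = deque()
lock = threading.Lock()
read_fd, write_fd = os.pipe()

def enqueue(event):
    # Producer side: append the event and write one byte to wake the consumer.
    with lock:
        if len(events) >= MAX_QUEUE_SIZE:
            return  # the real code instead bumps drop_count on the newest queued event
        events.append(event)
    os.write(write_fd, b"x")

def consume():
    while True:
        os.read(read_fd, 1)  # block until the producer signals
        with lock:
            event = events.popleft()
        print("got event:", event)

threading.Thread(target=consume, daemon=True).start()
enqueue({"operation": "ALLREDUCE", "sequence_number": 0})
time.sleep(0.1)  # give the background thread a moment to drain
```
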
View File

@@ -1,30 +0,0 @@
#pragma once

#include <c10/util/Optional.h>

#include <string>

namespace c10d {

enum class EventKind { CollectiveStart, CollectiveEnd };

TORCH_API void enable_event_collection(int sync_pipe);

namespace details {

struct TORCH_API EventInfo {
  EventKind event_kind;
  std::string pg_name;
  std::string backend;
  int64_t sequence_number;
  std::string operation;
  int64_t timestamp;
  c10::optional<float> duration_ms;
  int64_t drop_count;
};

// TODO do we want to expose something else here?
TORCH_API bool dequeue_c10d_event(EventInfo& evt);
TORCH_API void enqueue_c10d_event(EventInfo&& evt);

} // namespace details
} // namespace c10d

View File

@@ -82,14 +82,10 @@ std::string opTypeToString(OpType opType) {
       return "RECVANYSOURCE";
     case OpType::BARRIER:
       return "BARRIER";
-    case OpType::_REDUCE_SCATTER_BASE:
-      return "_REDUCE_SCATTER_BASE";
-    case OpType::COALESCED:
-      return "COALESCED";
-    case OpType::_ALLREDUCE_SPARSE:
-      return "_ALLREDUCE_SPARSE";
     case OpType::UNKNOWN:
       return "UNKNOWN";
+    case OpType::_REDUCE_SCATTER_BASE:
+      return "_REDUCE_SCATTER_BASE";
     default:
       TORCH_INTERNAL_ASSERT(false, "Unknown op type!");
   }

View File

@@ -855,10 +855,6 @@ void ProcessGroupGloo::runLoop(int workerIndex) {
 }
 
 void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
-  emitCollectiveStart(*work.get());
-  work->getFuture()->addCallback(
-      [=](auto& f) { this->emitCollectiveEnd(*work.get()); });
   std::unique_lock<std::mutex> lock(workMutex_);
   workQueue_.push_back(std::move(work));
   lock.unlock();

View File

@@ -931,34 +931,27 @@ void ProcessGroupNCCL::ncclCommWatchdog() {
   }
 }
 
-void ProcessGroupNCCL::logWorkStart(WorkNCCL& work, bool emitDesyncInfo) {
-  if (terminateProcessGroup_.load() || work.startTraceUpdated_)
+void ProcessGroupNCCL::logWorkStart(WorkNCCL& work) {
+  if (work.startTraceUpdated_)
     return;
+  if (terminateProcessGroup_.load() || storeError_)
+    return;
   work.startTraceUpdated_ = true;
-  emitCollectiveStart(work);
-  if (!emitDesyncInfo || storeError_)
-    return;
   storeError_ = !c10d::traceUpdate(
       store_, traceKeyStart_, work.seq_, opTypeToString(work.opType_));
 }
 
-void ProcessGroupNCCL::logWorkEnd(WorkNCCL& work, bool emitDesyncInfo) {
-  if (terminateProcessGroup_.load())
+void ProcessGroupNCCL::logWorkEnd(WorkNCCL& work) {
+  if (terminateProcessGroup_.load() || storeError_)
     return;
 
   // In case the start of the work hasn't been logged
   if (!work.startTraceUpdated_) {
-    logWorkStart(work, emitDesyncInfo);
+    logWorkStart(work);
   }
 
-  emitCollectiveEnd(work);
-  if (!emitDesyncInfo || storeError_)
-    return;
   storeError_ = !c10d::traceUpdate(
       store_, traceKeyEnd_, work.seq_, opTypeToString(work.opType_));
 }
@@ -1010,11 +1003,13 @@ void ProcessGroupNCCL::workCleanupLoop() {
       }
 
       // Work status logging for desync debug
-      if (work.isStarted()) {
-        logWorkStart(work, desyncDebug_);
-      }
-      if (work.isCompleted()) {
-        logWorkEnd(work, desyncDebug_);
+      if (desyncDebug_) {
+        if (work.isStarted()) {
+          logWorkStart(work);
+        }
+        if (work.isCompleted()) {
+          logWorkEnd(work);
+        }
       }
 
       // Clean up completed work
@@ -1071,7 +1066,7 @@ void ProcessGroupNCCL::runHookLoop() {
           timeStarted, // timeStarted
          std::chrono::system_clock::now(), // timeFinished
          std::chrono::duration<float, std::milli>(
-              work.getDuration().value()) // activeDuration
+              work.getDuration()) // activeDuration
          ));
 
      lock.lock();
@@ -1584,19 +1579,19 @@ c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupNCCL::WorkNCCL::
   return future_;
 }
 
-c10::optional<float> ProcessGroupNCCL::WorkNCCL::getDuration() const {
-  if (!timingEnabled_ || !((*ncclEndEvents_)[0].query())) {
-    return c10::optional<float>();
-  }
+float ProcessGroupNCCL::WorkNCCL::getDuration() const {
+  TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled")
   TORCH_CHECK(
       ncclStartEvents_->size() == 1,
       "getDuration only works for single device per ProcessGroup.");
   TORCH_CHECK(
       ncclEndEvents_->size() == 1,
       "getDuration only works for single device per ProcessGroup.");
+  TORCH_CHECK(
+      (*ncclEndEvents_)[0].query(),
+      "getDuration can only be called after work is succeeded.")
   return (*ncclStartEvents_)[0].elapsed_time((*ncclEndEvents_)[0]);
 }
 
 uint64_t ProcessGroupNCCL::WorkNCCL::getSequencenumber() const {
   return seq_;
 }

View File

@@ -167,7 +167,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
     // Get a Future object that will be marked as completed internally.
     c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
 
-    c10::optional<float> getDuration() const override;
+    float getDuration() const override;
 
     uint64_t getSequencenumber() const override;
@@ -615,10 +615,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   void runHookLoop();
 
   // Desync debug helper
-  void logWorkStart(WorkNCCL& work, bool emitDesyncInfo);
+  void logWorkStart(WorkNCCL& work);
 
   // Desync debug helper
-  void logWorkEnd(WorkNCCL& work, bool emitDesyncInfo);
+  void logWorkEnd(WorkNCCL& work);
 
  protected:
   static const int64_t kWatchdogThreadSleepMillis;

View File

@@ -127,8 +127,8 @@ void Work::finishAndThrow(std::exception_ptr exception) {
   }
 }
 
-c10::optional<float> Work::getDuration() const {
-  return c10::optional<float>();
+float Work::getDuration() const {
+  TORCH_CHECK(false, "This Backend doesn't support getDuration.");
 }
 
 uint64_t Work::getSequencenumber() const {

View File

@@ -107,7 +107,7 @@ class TORCH_API Work : public torch::CustomClassHolder {
   // work. Only NCCL backend is currently supported.
   virtual c10::intrusive_ptr<c10::ivalue::Future> getFuture();
 
-  virtual c10::optional<float> getDuration() const;
+  virtual float getDuration() const;
 
   virtual uint64_t getSequencenumber() const;

View File

@@ -32,7 +32,6 @@
 #include <fmt/format.h>
 #include <pybind11/chrono.h>
-#include <torch/csrc/distributed/c10d/Hooks.hpp>
 #include <torch/csrc/distributed/c10d/PrefixStore.hpp>
 #include <torch/csrc/distributed/c10d/comm.hpp>
@@ -292,26 +291,6 @@ void _register_builtin_comm_hook(
   reducer.register_builtin_comm_hook(comm_hook_type);
 }
 
-py::object c10d_dequeue_python_event() {
-  ::c10d::details::EventInfo evt;
-  if (!::c10d::details::dequeue_c10d_event(evt)) {
-    return py::none();
-  }
-
-  py::dict data;
-  data["event_kind"] = (int)evt.event_kind;
-  data["pg_name"] = evt.pg_name;
-  data["backend"] = evt.backend;
-  data["sequence_number"] = evt.sequence_number;
-  data["operation"] = evt.operation;
-  data["timestamp"] = evt.timestamp;
-  data["duration"] = evt.duration_ms.value_or(-1);
-  data["drop_count"] = evt.drop_count;
-  return std::move(data);
-}
-
 // Customize the metaclass of ::c10d::ReduceOp for the backward compatibility.
 // https://github.com/pytorch/pytorch/pull/84243 changed ::c10d::ReduceOp to
 // struct from enum, sacrificing some of the Python built-in function supports
@@ -661,11 +640,6 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
           &::c10d::Logger::set_static_graph,
           py::call_guard<py::gil_scoped_release>());
 
-  py::enum_<::c10d::EventKind>(module, "EventKind", R"(
-An enum for collective hooks event types.)")
-      .value("START", ::c10d::EventKind::CollectiveStart)
-      .value("END", ::c10d::EventKind::CollectiveEnd);
-
   py::enum_<::c10d::DebugLevel>(module, "DebugLevel", R"(
 An enum whose values correspond to different debug levels of the
 torch.distributed package. Currently supporting OFF, INFO, and DETAIL,
@@ -689,16 +663,7 @@ An enum for collective hooks event types.)")
           "set_debug_level_from_env",
           ::c10d::setDebugLevelFromEnvironment,
           R"(Sets the debug level of the torch.distributed package from the
-``TORCH_DISTRIBUTED_DEBUG`` environment variable.)")
-      .def(
-          "_enable_event_collection",
-          &::c10d::enable_event_collection,
-          "(Enables events collection).",
-          py::call_guard<py::gil_scoped_release>())
-      .def(
-          "_dequeue_c10d_event",
-          &c10d_dequeue_python_event,
-          "(Blocks until a c10d event is available and return it as a python dictionary).");
+``TORCH_DISTRIBUTED_DEBUG`` environment variable.)");
 
   // TODO(crcrpar): Hardening `ReduceOp`.
   // While keeping most op types as enum value,

View File

@@ -1,183 +0,0 @@
import logging
import os
import threading
import time
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional

import torch.distributed as dist

if dist.is_available():
    import torch.distributed.distributed_c10d as c10d
    from torch._C._distributed_c10d import (
        _dequeue_c10d_event,
        _enable_event_collection,
        EventKind,
    )

__all__ = [
    "CollectiveStatus",
    "COLLECTIVE_HOOK_TYPE",
    "PG_HOOK_TYPE",
    "register_collective_start_hook",
    "register_collective_end_hook",
    "register_process_group_hook",
]


@dataclass
class CollectiveStatus:
    r"""Status of a collective operation.

    Drop count indicates events have been dropped at the producer side, which means they
    were not consumed fast enough by hooks.
    """
    pg_name: str = "unknown"  # This name matches the one informed in the pgreg cb
    backend: str = "unknown"  # Name of the backend used
    sequence_number: int = -1  # This name matches the one informed in the pgreg cb
    operation: str = "unknown"  # collective name
    timestamp: int = 0  # timestamp of the earliest time we noticed this event
    duration: Optional[float] = None  # value in milliseconds it took executing
    drop_count: int = 0  # number of events dropped following this one


COLLECTIVE_HOOK_TYPE = Callable[[CollectiveStatus], None]
PG_HOOK_TYPE = Callable[[dist.ProcessGroup, str], None]

# This controls the number of internal failures we'll tolerate before giving up
_MAX_INTERNAL_FAILURES = 10

logger = logging.getLogger(__name__)

_cb_thread: Optional[threading.Thread] = None
_start_callbacks: List[COLLECTIVE_HOOK_TYPE] = []
_end_callbacks: List[COLLECTIVE_HOOK_TYPE] = []
_pp_r = -1
_pp_w = -1


def _c10d_pg_hooks_loops():
    internal_failures = 0
    while True:
        # we don't care about the result, this is how we implement notification
        _ = os.read(_pp_r, 1)
        evt: Dict[str, object] = _dequeue_c10d_event()
        try:
            event_kind = evt.pop("event_kind", None)
            if event_kind is None:
                logger.warning(
                    "c10d returned event dictionary %s without 'event_kind' key, cannot dispatch",
                    evt,
                )
                internal_failures += 1
                if internal_failures >= _MAX_INTERNAL_FAILURES:
                    logger.warning(
                        "too many internal c10d failures processing callback loop. stopping"
                    )
                    return
                time.sleep(1)
                continue

            if event_kind == int(EventKind.START):  # type: ignore[call-overload]
                cb_list = _start_callbacks
            elif event_kind == int(EventKind.END):  # type: ignore[call-overload]
                cb_list = _end_callbacks
            else:
                logger.warning(
                    "c10d event %s with invalid 'event_kind' with value %d",
                    evt,
                    event_kind,
                )
                internal_failures += 1
                if internal_failures >= _MAX_INTERNAL_FAILURES:
                    logger.warning(
                        "too many internal c10d failures processing callback loop. stopping"
                    )
                    return
                time.sleep(1)
                continue

            status = CollectiveStatus(**evt)  # type: ignore[arg-type]
            for cb in cb_list:
                try:
                    cb(status)
                except Exception as e:
                    logger.info(
                        "c10d event callback %s with event %s threw exception %s",
                        cb,
                        status,
                        e,
                    )
        except Exception as e:
            # We have to keep processing otherwise the queue will grow infinitely large
            logger.warning(
                "c10d callback thread when processing event %s raised exception %s.",
                evt,
                e,
            )
            internal_failures += 1
            if internal_failures >= _MAX_INTERNAL_FAILURES:
                logger.warning(
                    "too many internal c10d failures processing callback loop. stopping"
                )
                return

            # Sleep for a second to avoid hogging the GIL in case of a persistent failure
            time.sleep(1)


def _lazy_init():
    global _cb_thread
    if _cb_thread is not None:
        return
    global _pp_r
    global _pp_w
    _pp_r, _pp_w = os.pipe()
    _enable_event_collection(_pp_w)
    c10d._enable_collectives_timing()
    _cb_thread = threading.Thread(target=_c10d_pg_hooks_loops, daemon=True)
    _cb_thread.start()
    logger.info("c10d::hooks thread enabled")


def _check_distributed_available():
    if not dist.is_available():
        raise RuntimeError(
            "torch.distributed is not available, so hooks are not available."
        )


def register_collective_start_hook(hook: COLLECTIVE_HOOK_TYPE) -> None:
    r"""Register a hook that is called every time a collective starts.

    The hook is invoked on a background thread.
    Exceptions raised by the callback are ignored and non-fatal.
    """
    _check_distributed_available()
    _start_callbacks.append(hook)
    _lazy_init()


def register_collective_end_hook(hook: COLLECTIVE_HOOK_TYPE) -> None:
    r"""Register a hook that is called every time a collective finishes.

    The hook is invoked on a background thread.
    Exceptions raised by the callback are ignored and non-fatal.
    """
    _check_distributed_available()
    _end_callbacks.append(hook)
    _lazy_init()


def register_process_group_hook(hook: PG_HOOK_TYPE) -> None:
    r"""Register a hook that is called every time a process group is created on this rank.

    This hook is only invoked if the current rank is part of the PG being created.
    The pg_name is unique to the whole cluster and should be treated as an opaque identifier subject to change.

    The hook is invoked on a background thread.
    Exceptions raised by the callback are ignored and non-fatal.
    """
    _check_distributed_available()
    c10d._register_creation_hook(hook)

View File

@@ -421,19 +421,6 @@ _group_count = 0
 _tags_to_pg: Dict[str, List[ProcessGroup]] = {}
 _pg_to_tag: Dict[ProcessGroup, str] = {}
 
-class _HookState:
-    def __init__(self):
-        self.creation_hooks = []
-
-    def register_creation_hook(self, hook) -> None:
-        self.creation_hooks.append(hook)
-
-    def fire_creation_hook(self, pg, name) -> None:
-        for hook in self.creation_hooks:
-            try:
-                hook(pg, name)
-            except Exception as e:
-                logger.info("hook %s failed with %s", hook, e)
-
 class _World:
     """
@@ -447,8 +434,6 @@ class _World:
         self._default_pg = None
         self._pg_coalesce_state: Dict[ProcessGroup, List[Union[_CollOp, P2POp]]] = {}
         self._pg_default_device: Dict[ProcessGroup, torch.device] = {}
-        self._hook_state = _HookState()
-        self.enable_collectives_timing = False
 
     @property
     def default_pg(self):
@@ -558,9 +543,6 @@ class _World:
         )
         return config_info
 
-    @property
-    def pg_hook_state(self) -> _HookState:
-        return self._hook_state
 
 _world = _World()
 """Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it"""
@@ -1382,9 +1364,6 @@ def _new_process_group_helper(
     pg._set_group_name(group_name)
     _world.pg_backend_config[pg] = str(backend_config)
 
-    if _world.enable_collectives_timing:
-        pg._enable_collectives_timing()
-
     # "" is the default tag for user PGs
     if pg_tag in [None, ""]:
         pg_tag = f"ptd:{group_name}"
@@ -1394,8 +1373,6 @@ def _new_process_group_helper(
     _world.tags_to_pg.setdefault(pg_tag, []).append(pg)
     _world.pg_to_tag[pg] = pg_tag
 
-    _world.pg_hook_state.fire_creation_hook(pg, group_name)
-
     return pg, prefix_store
 
 def destroy_process_group(group: Optional[ProcessGroup] = None):
@@ -4338,11 +4315,3 @@ dynamo_unsupported_distributed_c10d_ops = [
     reduce_scatter_tensor,
     send,
 ]
-
-def _register_creation_hook(hook):
-    _world.pg_hook_state.register_creation_hook(hook)
-
-def _enable_collectives_timing():
-    _world.enable_collectives_timing = True
-    for pg in _world.pg_map:
-        pg._enable_collectives_timing()

View File

@@ -401,11 +401,6 @@ class WorldData:
 class ThreadLocalWorld:
     _world = threading.local()
 
-    def __init__(self):
-        self.enable_collectives_timing = False
-        self._hook_state = dist.distributed_c10d._HookState()
-
     def _get_world(self) -> WorldData:
         if not hasattr(ThreadLocalWorld._world, "world"):
             ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {}, {}, {})
@@ -459,9 +454,6 @@ class ThreadLocalWorld:
     def pg_default_device(self) -> Dict[dist.ProcessGroup, torch.device]:
         return self._get_world().pg_default_device
 
-    @property
-    def pg_hook_state(self) -> dist.distributed_c10d._HookState:
-        return self._hook_state
 
 _old_pg_world = None
 _ctx_manager = None