[C10] PG observability hooks. (#108815)
Expose a set of observability hooks into C10D so that users can detect collective failures both faster and more easily. The design is similar to NCCL desync debug in that it minimizes overhead by doing most of the work off the main thread.

This PR introduces a new module, torch.distributed.hooks, that exposes the following methods:

- register_collective_start_hook
- register_collective_end_hook
- register_process_group_hook

The process group hook reports PG creation on the member ranks and is called inline from the PG creation code. This is fine since it happens during initialization and only a limited number of times. The collective start/end hooks are fired from a single background thread, which reads events from a C++ queue and dispatches them to the registered callbacks. Queue notification is, somewhat unusually, done with a pipe: this is what lets Python abort the thread on shutdown and keep it a daemon thread, which is not possible with more conventional choices such as a condition variable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108815
Approved by: https://github.com/wconstab, https://github.com/fduwjj
commit 0c7a877745 (parent 17348b0f51), committed by PyTorch MergeBot
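For orientation, here is a minimal single-rank sketch of the API this PR adds, modeled on the usage in test/distributed/test_hooks.py below. It is illustrative, not part of the diff; the short sleep only gives the daemon hook thread time to drain the queue before exit:

import tempfile
import time

import torch
import torch.distributed as dist
import torch.distributed.hooks as dhooks

def on_start(status: dhooks.CollectiveStatus) -> None:
    print(f"start: {status.operation} seq={status.sequence_number} pg={status.pg_name}")

def on_end(status: dhooks.CollectiveStatus) -> None:
    print(f"end: {status.operation} duration={status.duration} ms")

def on_pg_created(pg: dist.ProcessGroup, pg_name: str) -> None:
    # called inline from PG creation, and only on member ranks
    print(f"process group created: {pg_name}")

dhooks.register_process_group_hook(on_pg_created)
dhooks.register_collective_start_hook(on_start)
dhooks.register_collective_end_hook(on_end)

store = dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1)
dist.init_process_group(backend="gloo", rank=0, world_size=1, store=store)
dist.all_reduce(torch.ones(2, 3))  # start/end hooks fire on a background thread
time.sleep(1)
dist.destroy_process_group()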
@ -521,6 +521,7 @@ libtorch_distributed_base_sources = [
    "torch/csrc/distributed/c10d/Backend.cpp",
    "torch/csrc/distributed/c10d/FileStore.cpp",
    "torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
    "torch/csrc/distributed/c10d/Hooks.cpp",
    "torch/csrc/distributed/c10d/Ops.cpp",
    "torch/csrc/distributed/c10d/ParamCommsUtils.cpp",
    "torch/csrc/distributed/c10d/PrefixStore.cpp",
test/distributed/test_hooks.py (new file, 270 lines)
@ -0,0 +1,270 @@
# Owner(s): ["oncall: distributed"]

import os
import sys
import tempfile
import threading
from functools import partial, wraps

import torch
import torch.distributed as dist
import torch.distributed.hooks as dhooks

if not dist.is_available():
    print("torch.distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)


from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    skip_if_lt_x_gpu,
)

from torch.testing._internal.common_utils import run_tests, TestCase


class PgHooks(MultiProcessTestCase):
    @property
    def world_size(self) -> int:
        return 4

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def test_pg_hook(self):
        pgs = []

        def pg_hook(pg, pg_name):
            pgs.append((pg, pg_name))

        dhooks.register_process_group_hook(pg_hook)
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )
        self.assertEqual(len(pgs), 1)
        self.assertEqual(pgs[0][0], dist.group.WORLD)

        # create two partial-world PGs
        pg0 = dist.new_group(ranks=[0, 1])
        pg1 = dist.new_group(ranks=[2, 3])

        # Each rank only observes two PGs being created: the default PG and the one covering its ranks.
        # We don't emit events for PG creation if the current rank doesn't belong to the new PG.
        # For example, if you're rank 1 you get an event for pg0 but not pg1, even though the API contract
        # dictates that you call new_group for both.
        self.assertEqual(len(pgs), 2)
        self.assertEqual(pgs[1][0], pg0 if self.rank < 2 else pg1)


def with_comms(func=None):
    if func is None:
        return partial(
            with_comms,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        self.init_comms()
        func(self, *args, **kwargs)
        self.destroy_comms()

    return wrapper


class CollectiveHooks:
    @property
    def world_size(self) -> int:
        return 4

    def _collective_hooks(self):
        # It's ok to access these lists without locking since only the single
        # background hook thread writes to them.
        starts = []
        ends = []
        cv = threading.Condition()

        def coll_start(status):
            starts.append(status)
            print(f"col_start {len(starts)} rank{self.rank}")

        def coll_end(status):
            ends.append(status)
            print(f"col_end {len(ends)} rank{self.rank}")
            if len(ends) == 2:
                with cv:
                    cv.notify()

        dhooks.register_collective_start_hook(coll_start)
        dhooks.register_collective_end_hook(coll_end)

        tensor = torch.ones([2, 3]).to(self.device) * self.rank
        tensor_list = [torch.empty_like(tensor) for _ in range(self.world_size)]

        dist.all_gather(tensor_list, tensor)

        tensor2 = torch.ones([2, 3]).to(self.device) * self.rank
        dist.all_reduce(tensor2)

        with cv:
            cv.wait(1)

        default_pg_name = dist.group.WORLD.group_name
        self.assertEqual(2, len(starts))
        self.assertEqual(2, len(ends))

        def check_op(idx, coll_name):
            self.assertEqual(default_pg_name, starts[idx].pg_name)
            self.assertEqual(self.backend_name, starts[idx].backend)
            self.assertGreaterEqual(starts[idx].sequence_number, 0)
            self.assertGreaterEqual(starts[idx].timestamp, 0)
            self.assertEqual(coll_name, starts[idx].operation)

            self.assertEqual(default_pg_name, ends[idx].pg_name)
            self.assertEqual(self.backend_name, ends[idx].backend)

            self.assertEqual(starts[idx].sequence_number, ends[idx].sequence_number)
            self.assertLessEqual(starts[idx].timestamp, ends[idx].timestamp)
            self.assertEqual(coll_name, ends[idx].operation)

        check_op(0, "ALLGATHER")
        check_op(1, "ALLREDUCE")


class GlooHooks(MultiProcessTestCase, CollectiveHooks):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def init_comms(self):
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    def destroy_comms(self):
        dist.destroy_process_group()

    @property
    def backend_name(self):
        return "gloo"

    @property
    def device(self):
        return "cpu"

    @with_comms
    def test_collective_hooks(self):
        self._collective_hooks()


class NcclHooks(MultiProcessTestCase, CollectiveHooks):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def init_comms(self):
        dist.init_process_group(
            backend="nccl",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    def destroy_comms(self):
        dist.destroy_process_group()

    @property
    def backend_name(self):
        return "nccl"

    @property
    def device(self):
        return f"cuda:{self.rank}"

    @skip_if_lt_x_gpu(4)
    @with_comms
    def test_collective_hooks(self):
        self._collective_hooks()


class SingleRankTests(TestCase):
    def setUp(self) -> None:
        super().setUp()
        self.rank = 0
        self.file_name = tempfile.NamedTemporaryFile(delete=False).name
        dist.init_process_group(
            backend="gloo",
            rank=0,
            world_size=1,
            store=dist.FileStore(self.file_name, 1),
        )

    def tearDown(self) -> None:
        dist.destroy_process_group()

    def test_queue_overflow(self) -> None:
        cv_done_colls = threading.Condition()
        cv_done_cb = threading.Condition()
        colls_done = False
        starts = []
        status_with_dropped = None

        def coll_start(status: dhooks.CollectiveStatus):
            starts.append(status)
            with cv_done_colls:
                while not colls_done:
                    cv_done_colls.wait()
            if status.drop_count > 0:
                nonlocal status_with_dropped
                status_with_dropped = status
                with cv_done_cb:
                    cv_done_cb.notify()

        dhooks.register_collective_start_hook(coll_start)

        # the native queue limit is 512, so 600 collectives must overflow it
        for i in range(600):
            dist.all_reduce(torch.ones([2, 3]))
        with cv_done_colls:
            colls_done = True
            cv_done_colls.notify()

        with cv_done_cb:
            cv_done_cb.wait(10)

        self.assertIsNotNone(status_with_dropped)
        self.assertGreater(status_with_dropped.drop_count, 0)


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()
@ -11,6 +11,10 @@ _DEFAULT_FIRST_BUCKET_BYTES: int
_DEFAULT_NO_TIMEOUT: timedelta
_DEFAULT_PG_TIMEOUT: timedelta

class EventKind(Enum):
    START = ...
    END = ...

class BuiltinCommHookType(Enum):
    ALLREDUCE = ...
    FP16_COMPRESS = ...
@ -20,6 +24,8 @@ def _register_builtin_comm_hook(
    reducer: Reducer,
    comm_hook_type: BuiltinCommHookType,
): ...
def _dequeue_c10d_event() -> Dict[str, object]: ...
def _enable_event_collection(pipe_fs: int) -> None: ...

class GradBucket:
    def index(self) -> int: ...
@ -1,9 +1,26 @@
#include <c10/util/Logging.h>
#include <fmt/format.h>
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Hooks.hpp>
#include <torch/csrc/distributed/c10d/logging.h>

namespace c10d {

namespace {
void commonEventinit(
    details::EventInfo& evt,
    const Backend& backend,
    const Work& work) {
  evt.timestamp =
      std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
  evt.pg_name = backend.getGroupName();
  evt.backend = backend.getBackendName();
  evt.sequence_number = work.getSequencenumber();
  evt.operation = c10d::opTypeToString(work.retrieveOpType());
  evt.drop_count = 0;
}
} // namespace

Backend::Backend(int rank, int size)
    : rank_(rank), size_(size), dist_debug_level_(debug_level()) {
  C10_LOG_API_USAGE_ONCE("c10d.backend");
@ -15,4 +32,21 @@ void Backend::init() {
  C10_LOG_API_USAGE_ONCE(fmt::format("c10d.backend_{}", getBackendName()));
}

void Backend::emitCollectiveStart(const Work& work) {
  details::EventInfo evt;
  commonEventinit(evt, *this, work);

  evt.event_kind = ::c10d::EventKind::CollectiveStart;
  details::enqueue_c10d_event(std::move(evt));
}

void Backend::emitCollectiveEnd(const Work& work) {
  details::EventInfo evt;
  commonEventinit(evt, *this, work);

  evt.event_kind = ::c10d::EventKind::CollectiveEnd;
  evt.duration_ms = work.getDuration();
  details::enqueue_c10d_event(std::move(evt));
}

} // namespace c10d
@ -366,6 +366,8 @@ class TORCH_API Backend : public torch::CustomClassHolder {
  // Implementations of this interface need to call this to setup
  // appropriate logging etc.
  void init();

  void emitCollectiveStart(const Work& work);
  void emitCollectiveEnd(const Work& work);

  const int rank_;
  const int size_;
torch/csrc/distributed/c10d/Hooks.cpp (new file, 60 lines)
@ -0,0 +1,60 @@
#include <atomic>

#include <deque>
#include <memory>
#include <mutex>

#ifndef _WIN32
#include <unistd.h>
#else
#include <io.h>
#endif

#include <torch/csrc/distributed/c10d/Hooks.hpp>

namespace c10d {

namespace {

std::atomic<bool> event_queue_enabled = false;
int sync_pipe;
std::mutex event_queue_lock;
std::deque<details::EventInfo> event_queue;

} // namespace

void enable_event_collection(int pipe) {
  sync_pipe = pipe;
  event_queue_enabled.store(true);
}

namespace details {

// we start dropping events after this
const size_t MAX_QUEUE_SIZE = 512;

bool dequeue_c10d_event(EventInfo& evt) {
  std::unique_lock<std::mutex> lock(event_queue_lock);
  if (event_queue.size() == 0) {
    return false;
  }
  evt = event_queue.front();
  event_queue.pop_front();
  return true;
}

void enqueue_c10d_event(EventInfo&& evt) {
  if (!event_queue_enabled.load())
    return;

  std::unique_lock<std::mutex> lock(event_queue_lock);
  if (event_queue.size() >= MAX_QUEUE_SIZE) {
    event_queue.back().drop_count++;
  } else {
    event_queue.push_back(std::move(evt));
    char m = 'x';
    write(sync_pipe, &m, 1);
  }
}

} // namespace details
} // namespace c10d
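Hooks.cpp above is the producer half of the design called out in the commit message: events are queued under a mutex and the consumer is woken by writing one byte to a pipe. A pipe rather than a condition variable is what lets the Python side run the consumer as an abortable daemon thread, since closing the write end unblocks a blocked reader, while a thread parked on a C++ condvar could not be interrupted from Python. A minimal standalone sketch of that wakeup pattern (illustrative, not C10D code):

import os
import threading
from collections import deque

queue: deque = deque()
r, w = os.pipe()

def producer(item) -> None:
    queue.append(item)  # stands in for enqueue_c10d_event
    os.write(w, b"x")   # one wakeup byte per event, mirroring Hooks.cpp

def consumer() -> None:
    # os.read returns b"" once the write end is closed, ending the loop
    while os.read(r, 1):
        print("got", queue.popleft())

t = threading.Thread(target=consumer, daemon=True)
t.start()
producer("event-1")
producer("event-2")
os.close(w)  # unblocks the reader; the thread drains and exits cleanly
t.join()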
torch/csrc/distributed/c10d/Hooks.hpp (new file, 30 lines)
@ -0,0 +1,30 @@
#pragma once

#include <c10/util/Optional.h>
#include <string>

namespace c10d {

enum class EventKind { CollectiveStart, CollectiveEnd };

TORCH_API void enable_event_collection(int sync_pipe);

namespace details {

struct TORCH_API EventInfo {
  EventKind event_kind;
  std::string pg_name;
  std::string backend;
  int64_t sequence_number;
  std::string operation;
  int64_t timestamp;
  c10::optional<float> duration_ms;
  int64_t drop_count;
};

// TODO do we want to expose something else here?
TORCH_API bool dequeue_c10d_event(EventInfo& evt);
TORCH_API void enqueue_c10d_event(EventInfo&& evt);

} // namespace details
} // namespace c10d
@ -82,10 +82,14 @@ std::string opTypeToString(OpType opType) {
      return "RECVANYSOURCE";
    case OpType::BARRIER:
      return "BARRIER";
    case OpType::_REDUCE_SCATTER_BASE:
      return "_REDUCE_SCATTER_BASE";
    case OpType::COALESCED:
      return "COALESCED";
    case OpType::_ALLREDUCE_SPARSE:
      return "_ALLREDUCE_SPARSE";
    case OpType::UNKNOWN:
      return "UNKNOWN";
    default:
      TORCH_INTERNAL_ASSERT(false, "Unknown op type!");
  }
@ -855,6 +855,10 @@ void ProcessGroupGloo::runLoop(int workerIndex) {
  }
}

void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
  emitCollectiveStart(*work.get());
  work->getFuture()->addCallback(
      [=](auto& f) { this->emitCollectiveEnd(*work.get()); });

  std::unique_lock<std::mutex> lock(workMutex_);
  workQueue_.push_back(std::move(work));
  lock.unlock();
@ -942,27 +942,34 @@ void ProcessGroupNCCL::ncclCommWatchdog() {
  }
}

void ProcessGroupNCCL::logWorkStart(WorkNCCL& work, bool emitDesyncInfo) {
  if (terminateProcessGroup_.load() || work.startTraceUpdated_)
    return;

  work.startTraceUpdated_ = true;

  emitCollectiveStart(work);

  if (!emitDesyncInfo || storeError_)
    return;

  storeError_ = !c10d::traceUpdate(
      store_, traceKeyStart_, work.seq_, opTypeToString(work.opType_));
}

void ProcessGroupNCCL::logWorkEnd(WorkNCCL& work, bool emitDesyncInfo) {
  if (terminateProcessGroup_.load())
    return;

  // In case the start of the work hasn't been logged
  if (!work.startTraceUpdated_) {
    logWorkStart(work, emitDesyncInfo);
  }

  emitCollectiveEnd(work);

  if (!emitDesyncInfo || storeError_)
    return;

  storeError_ = !c10d::traceUpdate(
      store_, traceKeyEnd_, work.seq_, opTypeToString(work.opType_));
}
@ -1014,13 +1021,11 @@ void ProcessGroupNCCL::workCleanupLoop() {
        }

        // Work status logging for desync debug
        if (work.isStarted()) {
          logWorkStart(work, desyncDebug_);
        }
        if (work.isCompleted()) {
          logWorkEnd(work, desyncDebug_);
        }

        // Clean up completed work
@ -1077,7 +1082,7 @@ void ProcessGroupNCCL::runHookLoop() {
            timeStarted, // timeStarted
            std::chrono::system_clock::now(), // timeFinished
            std::chrono::duration<float, std::milli>(
                work.getDuration().value()) // activeDuration
            ));

        lock.lock();
@ -1580,19 +1585,19 @@ c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupNCCL::WorkNCCL::
  return future_;
}

c10::optional<float> ProcessGroupNCCL::WorkNCCL::getDuration() const {
  if (!timingEnabled_ || !((*ncclEndEvents_)[0].query())) {
    return c10::optional<float>();
  }
  TORCH_CHECK(
      ncclStartEvents_->size() == 1,
      "getDuration only works for single device per ProcessGroup.");
  TORCH_CHECK(
      ncclEndEvents_->size() == 1,
      "getDuration only works for single device per ProcessGroup.");
  return (*ncclStartEvents_)[0].elapsed_time((*ncclEndEvents_)[0]);
}

uint64_t ProcessGroupNCCL::WorkNCCL::getSequencenumber() const {
  return seq_;
}
@ -164,7 +164,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
    // Get a Future object that will be marked as completed internally.
    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;

    c10::optional<float> getDuration() const override;

    uint64_t getSequencenumber() const override;
@ -612,10 +612,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
  void runHookLoop();

  // Desync debug helper
  void logWorkStart(WorkNCCL& work, bool emitDesyncInfo);

  // Desync debug helper
  void logWorkEnd(WorkNCCL& work, bool emitDesyncInfo);

 protected:
  static const int64_t kWatchdogThreadSleepMillis;
@ -127,8 +127,8 @@ void Work::finishAndThrow(std::exception_ptr exception) {
  }
}

c10::optional<float> Work::getDuration() const {
  return c10::optional<float>();
}

uint64_t Work::getSequencenumber() const {
@ -107,7 +107,7 @@ class TORCH_API Work : public torch::CustomClassHolder {
  // work. Only NCCL backend is currently supported.
  virtual c10::intrusive_ptr<c10::ivalue::Future> getFuture();

  virtual c10::optional<float> getDuration() const;

  virtual uint64_t getSequencenumber() const;
@ -32,6 +32,7 @@

#include <fmt/format.h>
#include <pybind11/chrono.h>
#include <torch/csrc/distributed/c10d/Hooks.hpp>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>

#include <torch/csrc/distributed/c10d/comm.hpp>
@ -291,6 +292,26 @@ void _register_builtin_comm_hook(
  reducer.register_builtin_comm_hook(comm_hook_type);
}

py::object c10d_dequeue_python_event() {
  ::c10d::details::EventInfo evt;
  if (!::c10d::details::dequeue_c10d_event(evt)) {
    return py::none();
  }

  py::dict data;
  data["event_kind"] = (int)evt.event_kind;

  data["pg_name"] = evt.pg_name;
  data["backend"] = evt.backend;
  data["sequence_number"] = evt.sequence_number;
  data["operation"] = evt.operation;
  data["timestamp"] = evt.timestamp;
  data["duration"] = evt.duration_ms.value_or(-1);
  data["drop_count"] = evt.drop_count;

  return std::move(data);
}

// Customize the metaclass of ::c10d::ReduceOp for the backward compatibility.
// https://github.com/pytorch/pytorch/pull/84243 changed ::c10d::ReduceOp to
// struct from enum, sacrificing some of the Python built-in function supports
@ -640,6 +661,11 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
      &::c10d::Logger::set_static_graph,
      py::call_guard<py::gil_scoped_release>());

  py::enum_<::c10d::EventKind>(module, "EventKind", R"(
An enum for collective hooks event types.)")
      .value("START", ::c10d::EventKind::CollectiveStart)
      .value("END", ::c10d::EventKind::CollectiveEnd);

  py::enum_<::c10d::DebugLevel>(module, "DebugLevel", R"(
An enum whose values correspond to different debug levels of the
torch.distributed package. Currently supporting OFF, INFO, and DETAIL,
@ -663,7 +689,16 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
          "set_debug_level_from_env",
          ::c10d::setDebugLevelFromEnvironment,
          R"(Sets the debug level of the torch.distributed package from the
``TORCH_DISTRIBUTED_DEBUG`` environment variable.)")
      .def(
          "_enable_event_collection",
          &::c10d::enable_event_collection,
          "(Enables events collection).",
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_dequeue_c10d_event",
          &c10d_dequeue_python_event,
          "(Dequeues a pending c10d event and returns it as a python dictionary, or None if the queue is empty).");

  // TODO(crcrpar): Hardening `ReduceOp`.
  // While keeping most op types as enum value,
@ -421,6 +421,19 @@ _group_count = 0
_tags_to_pg: Dict[str, List[ProcessGroup]] = {}
_pg_to_tag: Dict[ProcessGroup, str] = {}

class _HookState:
    def __init__(self):
        self.creation_hooks = []

    def register_creation_hook(self, hook) -> None:
        self.creation_hooks.append(hook)

    def fire_creation_hook(self, pg, name) -> None:
        for hook in self.creation_hooks:
            try:
                hook(pg, name)
            except Exception as e:
                logger.info("hook %s failed with %s", hook, e)

class _World:
    """
@ -434,6 +447,8 @@ class _World:
        self._default_pg = None
        self._pg_coalesce_state: Dict[ProcessGroup, List[Union[_CollOp, P2POp]]] = {}
        self._pg_default_device: Dict[ProcessGroup, torch.device] = {}
        self._hook_state = _HookState()
        self.enable_collectives_timing = False

    @property
    def default_pg(self):
@ -543,6 +558,9 @@ class _World:
        )
        return config_info

    @property
    def pg_hook_state(self) -> _HookState:
        return self._hook_state

_world = _World()
"""Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it"""
@ -1362,6 +1380,9 @@ def _new_process_group_helper(
    pg._set_group_name(group_name)

    _world.pg_backend_config[pg] = str(backend_config)
    if _world.enable_collectives_timing:
        pg._enable_collectives_timing()

    # "" is the default tag for user PGs
    if pg_tag in [None, ""]:
        pg_tag = f"ptd:{group_name}"
@ -1371,6 +1392,8 @@ def _new_process_group_helper(
    _world.tags_to_pg.setdefault(pg_tag, []).append(pg)
    _world.pg_to_tag[pg] = pg_tag
    _world.pg_hook_state.fire_creation_hook(pg, group_name)

    return pg, prefix_store

def destroy_process_group(group: Optional[ProcessGroup] = None):
@ -4313,3 +4336,11 @@ dynamo_unsupported_distributed_c10d_ops = [
    reduce_scatter_tensor,
    send,
]

def _register_creation_hook(hook):
    _world.pg_hook_state.register_creation_hook(hook)

def _enable_collectives_timing():
    _world.enable_collectives_timing = True
    for pg in _world.pg_map:
        pg._enable_collectives_timing()
torch/distributed/hooks.py (new file, 169 lines)
@ -0,0 +1,169 @@
import logging
import os
import threading
import time
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional

import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d

from torch._C._distributed_c10d import (
    _dequeue_c10d_event,
    _enable_event_collection,
    EventKind,
)

__all__ = [
    "CollectiveStatus",
    "COLLECTIVE_HOOK_TYPE",
    "PG_HOOK_TYPE",
    "register_collective_start_hook",
    "register_collective_end_hook",
    "register_process_group_hook",
]


@dataclass
class CollectiveStatus:
    pg_name: str = "unknown"  # Matches the name passed to the PG creation callback
    backend: str = "unknown"  # Name of the backend used
    sequence_number: int = -1  # Sequence number of this collective within its PG
    operation: str = "unknown"  # Collective name
    timestamp: int = 0  # Timestamp of the earliest time we noticed this event
    duration: Optional[float] = None  # Execution time in milliseconds, if available
    drop_count: int = 0  # Number of events dropped following this one


COLLECTIVE_HOOK_TYPE = Callable[[CollectiveStatus], None]
PG_HOOK_TYPE = Callable[[dist.ProcessGroup, str], None]

# This controls the number of internal failures we'll tolerate before giving up
_MAX_INTERNAL_FAILURES = 10

logger = logging.getLogger(__name__)
_cb_thread: Optional[threading.Thread] = None
_start_callbacks: List[COLLECTIVE_HOOK_TYPE] = []
_end_callbacks: List[COLLECTIVE_HOOK_TYPE] = []
_pp_r = -1
_pp_w = -1


def _c10d_pg_hooks_loops():
    internal_failures = 0
    while True:
        # We don't care about the byte itself; reading the pipe is how the C++ side notifies us
        _ = os.read(_pp_r, 1)
        evt: Dict[str, object] = _dequeue_c10d_event()
        try:
            event_kind = evt.pop("event_kind", None)
            if event_kind is None:
                logger.warning(
                    "c10d returned event dictionary %s without 'event_kind' key, cannot dispatch",
                    evt,
                )
                internal_failures += 1
                if internal_failures >= _MAX_INTERNAL_FAILURES:
                    logger.warning(
                        "too many internal c10d failures processing callback loop. stopping"
                    )
                    return
                time.sleep(1)
                continue

            if event_kind == int(EventKind.START):  # type: ignore[call-overload]
                cb_list = _start_callbacks
            elif event_kind == int(EventKind.END):  # type: ignore[call-overload]
                cb_list = _end_callbacks
            else:
                logger.warning(
                    "c10d event %s with invalid 'event_kind' with value %d",
                    evt,
                    event_kind,
                )
                internal_failures += 1
                if internal_failures >= _MAX_INTERNAL_FAILURES:
                    logger.warning(
                        "too many internal c10d failures processing callback loop. stopping"
                    )
                    return
                time.sleep(1)
                continue

            status = CollectiveStatus(**evt)  # type: ignore[arg-type]
            for cb in cb_list:
                try:
                    cb(status)
                except Exception as e:
                    logger.info(
                        "c10d event callback %s with event %s threw exception %s",
                        cb,
                        status,
                        e,
                    )
        except Exception as e:
            # We have to keep processing, otherwise the queue will grow infinitely large
            logger.warning(
                "c10d callback thread when processing event %s raised exception %s.",
                evt,
                e,
            )
            internal_failures += 1
            if internal_failures >= _MAX_INTERNAL_FAILURES:
                logger.warning(
                    "too many internal c10d failures processing callback loop. stopping"
                )
                return

            # Sleep for a second to avoid hogging the GIL in case of a persistent failure
            time.sleep(1)


def _lazy_init():
    global _cb_thread
    if _cb_thread is not None:
        return
    global _pp_r
    global _pp_w
    _pp_r, _pp_w = os.pipe()
    _enable_event_collection(_pp_w)
    c10d._enable_collectives_timing()
    _cb_thread = threading.Thread(target=_c10d_pg_hooks_loops, daemon=True)
    _cb_thread.start()
    logger.info("c10d::hooks thread enabled")


def register_collective_start_hook(hook: COLLECTIVE_HOOK_TYPE) -> None:
    """
    Register a hook that is called every time a collective starts.

    The hook is invoked on a background thread.
    Exceptions raised by the callback are ignored and non-fatal.
    """
    _start_callbacks.append(hook)
    _lazy_init()


def register_collective_end_hook(hook: COLLECTIVE_HOOK_TYPE) -> None:
    """
    Register a hook that is called every time a collective finishes.

    The hook is invoked on a background thread.
    Exceptions raised by the callback are ignored and non-fatal.
    """
    _end_callbacks.append(hook)
    _lazy_init()


def register_process_group_hook(hook: PG_HOOK_TYPE) -> None:
    """
    Register a hook that is called every time a process group is created on this rank.

    This hook is only invoked if the current rank is part of the PG being created.
    The pg_name is unique to the whole cluster and should be treated as an opaque identifier subject to change.

    Unlike the collective hooks, this hook is invoked inline from the PG creation code.
    Exceptions raised by the callback are ignored and non-fatal.
    """
    c10d._register_creation_hook(hook)
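As a usage note on the module above: the tests pair start and end events through their shared sequence_number, so a consumer can do the same to measure how long a collective was in flight, and can watch drop_count to detect overflow of the 512-entry native queue. A small illustrative sketch, not part of the diff:

import time

import torch.distributed.hooks as dhooks

_inflight = {}  # (pg_name, sequence_number) -> time we saw the start event

def on_start(status: dhooks.CollectiveStatus) -> None:
    _inflight[(status.pg_name, status.sequence_number)] = time.monotonic()

def on_end(status: dhooks.CollectiveStatus) -> None:
    if status.drop_count > 0:
        print(f"warning: {status.drop_count} events were dropped after this one")
    t0 = _inflight.pop((status.pg_name, status.sequence_number), None)
    if t0 is not None:
        print(f"{status.operation} on {status.pg_name}: {time.monotonic() - t0:.3f}s in flight")

# Both callbacks run on the same single background thread, so no locking is needed here.
dhooks.register_collective_start_hook(on_start)
dhooks.register_collective_end_hook(on_end)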
@ -401,6 +401,11 @@ class WorldData:
class ThreadLocalWorld:
    _world = threading.local()

    def __init__(self):
        self.enable_collectives_timing = False
        self._hook_state = dist.distributed_c10d._HookState()

    def _get_world(self) -> WorldData:
        if not hasattr(ThreadLocalWorld._world, "world"):
            ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {}, {}, {})
@ -454,6 +459,9 @@ class ThreadLocalWorld:
    def pg_default_device(self) -> Dict[dist.ProcessGroup, torch.device]:
        return self._get_world().pg_default_device

    @property
    def pg_hook_state(self) -> dist.distributed_c10d._HookState:
        return self._hook_state

_old_pg_world = None
_ctx_manager = None