This reverts commit bb1424d46e656dfcdd4c12efe58ada9f1720c4d8. Reverted https://github.com/pytorch/pytorch/pull/111072 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/111072#issuecomment-1765399829))
@@ -522,7 +522,6 @@ libtorch_distributed_base_sources = [
    "torch/csrc/distributed/c10d/Backend.cpp",
    "torch/csrc/distributed/c10d/FileStore.cpp",
    "torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
    "torch/csrc/distributed/c10d/Hooks.cpp",
    "torch/csrc/distributed/c10d/Ops.cpp",
    "torch/csrc/distributed/c10d/ParamCommsUtils.cpp",
    "torch/csrc/distributed/c10d/PrefixStore.cpp",
@@ -1,270 +0,0 @@
# Owner(s): ["oncall: distributed"]

import os
import sys
import tempfile
import threading
from functools import partial, wraps

import torch
import torch.distributed as dist
import torch.distributed._hooks as dhooks

if not dist.is_available():
    print("torch.distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)


from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    skip_if_lt_x_gpu,
)

from torch.testing._internal.common_utils import run_tests, TestCase


class PgHooks(MultiProcessTestCase):
    @property
    def world_size(self) -> int:
        return 4

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def test_pg_hook(self):
        pgs = []

        def pg_hook(pg, pg_name):
            pgs.append((pg, pg_name))

        dhooks.register_process_group_hook(pg_hook)
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )
        self.assertEqual(len(pgs), 1)
        self.assertEqual(pgs[0][0], dist.group.WORLD)

        # create two partial world PGs
        pg0 = dist.new_group(ranks=[0, 1])
        pg1 = dist.new_group(ranks=[2, 3])

        # Each rank only observes two PGs being created: the default PG and one covering its ranks.
        # We don't emit events for PG creation if the current rank doesn't belong to it.
        # For example, say you're rank 1: you'll get an event for pg0 but not pg1, even though the API contract
        # dictates you need to call new_group for both.
        self.assertEqual(len(pgs), 2)
        self.assertEqual(pgs[1][0], pg0 if self.rank < 2 else pg1)


def with_comms(func=None):
    if func is None:
        return partial(
            with_comms,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        self.init_comms()
        func(self, *args, **kwargs)
        self.destroy_comms()

    return wrapper


class CollectiveHooks:
    @property
    def world_size(self) -> int:
        return 4

    def _collective_hooks(self):
        # It's ok to access these lists directly since there's a single bg thread poking at them.
        starts = []
        ends = []
        cv = threading.Condition()

        def coll_start(status):
            starts.append(status)
            print(f"col_start {len(starts)} rank{self.rank}")

        def coll_end(status):
            ends.append(status)
            print(f"col_end {len(ends)} rank{self.rank}")
            if len(ends) == 2:
                with cv:
                    cv.notify()

        dhooks.register_collective_start_hook(coll_start)
        dhooks.register_collective_end_hook(coll_end)

        tensor = torch.ones([2, 3]).to(self.device) * self.rank
        tensor_list = [torch.empty_like(tensor) for _ in range(self.world_size)]

        dist.all_gather(tensor_list, tensor)

        tensor2 = torch.ones([2, 3]).to(self.device) * self.rank
        dist.all_reduce(tensor2)

        with cv:
            cv.wait(1)

        default_pg_name = dist.group.WORLD.group_name
        self.assertEqual(2, len(starts))
        self.assertEqual(2, len(ends))

        def check_op(idx, coll_name):
            self.assertEqual(default_pg_name, starts[idx].pg_name)
            self.assertEqual(self.backend_name, starts[idx].backend)
            self.assertGreaterEqual(starts[idx].sequence_number, 0)
            self.assertGreaterEqual(starts[idx].timestamp, 0)
            self.assertEqual(coll_name, starts[idx].operation)

            self.assertEqual(default_pg_name, ends[idx].pg_name)
            self.assertEqual(self.backend_name, ends[idx].backend)

            self.assertEqual(starts[idx].sequence_number, ends[idx].sequence_number)
            self.assertLessEqual(starts[idx].timestamp, ends[idx].timestamp)
            self.assertEqual(coll_name, ends[idx].operation)

        check_op(0, "ALLGATHER")
        check_op(1, "ALLREDUCE")


class GlooHooks(MultiProcessTestCase, CollectiveHooks):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def init_comms(self):
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    def destroy_comms(self):
        dist.destroy_process_group()

    @property
    def backend_name(self):
        return "gloo"

    @property
    def device(self):
        return "cpu"

    @with_comms
    def test_collective_hooks(self):
        self._collective_hooks()


class NcclHooks(MultiProcessTestCase, CollectiveHooks):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def init_comms(self):
        dist.init_process_group(
            backend="nccl",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    def destroy_comms(self):
        dist.destroy_process_group()

    @property
    def backend_name(self):
        return "nccl"

    @property
    def device(self):
        return f"cuda:{self.rank}"

    @skip_if_lt_x_gpu(4)
    @with_comms
    def test_collective_hooks(self):
        self._collective_hooks()


class SingleRankTests(TestCase):
    def setUp(self) -> None:
        super().setUp()
        self.rank = 0
        self.file_name = tempfile.NamedTemporaryFile(delete=False).name
        dist.init_process_group(
            backend="gloo",
            rank=0,
            world_size=1,
            store=dist.FileStore(self.file_name, 1),
        )

    def tearDown(self) -> None:
        dist.destroy_process_group()

    def test_queue_overflow(self) -> None:
        cv_done_colls = threading.Condition()
        cv_done_cb = threading.Condition()
        colls_done = False
        starts = []
        status_with_dropped = None

        def coll_start(status: dhooks.CollectiveStatus):
            starts.append(status)
            with cv_done_colls:
                while not colls_done:
                    cv_done_colls.wait()
            if status.drop_count > 0:
                nonlocal status_with_dropped
                status_with_dropped = status
                with cv_done_cb:
                    cv_done_cb.notify()

        dhooks.register_collective_start_hook(coll_start)

        # native limit is 512
        for i in range(600):
            dist.all_reduce(torch.ones([2, 3]))
        colls_done = True
        with cv_done_colls:
            cv_done_colls.notify()

        with cv_done_cb:
            cv_done_cb.wait(10)

        self.assertTrue(status_with_dropped is not None)
        self.assertTrue(status_with_dropped.drop_count > 0)


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()
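SingleRankTests.test_queue_overflow above relies on a deliberately stalled start hook to force producer-side drops. As a standalone illustration of the same API outside the test harness, a minimal single-rank sketch might look like the following; the file-store path, sleep durations, and iteration count are illustrative assumptions, and the script presumes the torch.distributed._hooks module from this reverted change is importable.

# Hypothetical standalone sketch of the hooks API exercised by the tests above.
import tempfile
import time

import torch
import torch.distributed as dist
import torch.distributed._hooks as dhooks

statuses = []

def slow_start_hook(status: dhooks.CollectiveStatus) -> None:
    # A deliberately slow consumer: the producer-side queue (capped at 512
    # events) records misses in drop_count on the last queued event.
    time.sleep(0.01)
    statuses.append(status)

dhooks.register_collective_start_hook(slow_start_hook)

store_file = tempfile.NamedTemporaryFile(delete=False).name
dist.init_process_group(
    backend="gloo",
    rank=0,
    world_size=1,
    store=dist.FileStore(store_file, 1),
)

for _ in range(600):
    dist.all_reduce(torch.ones(2, 3))

time.sleep(5)  # give the background callback thread time to drain the queue
print(f"saw {len(statuses)} start events, {sum(s.drop_count for s in statuses)} dropped")
dist.destroy_process_group()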
@@ -11,10 +11,6 @@ _DEFAULT_FIRST_BUCKET_BYTES: int
_DEFAULT_NO_TIMEOUT: timedelta
_DEFAULT_PG_TIMEOUT: timedelta

class EventKind(Enum):
    START = ...
    END = ...

class BuiltinCommHookType(Enum):
    ALLREDUCE = ...
    FP16_COMPRESS = ...
@@ -24,8 +20,6 @@ def _register_builtin_comm_hook(
    reducer: Reducer,
    comm_hook_type: BuiltinCommHookType,
): ...
def _dequeue_c10d_event() -> Dict[str, object]: ...
def _enable_event_collection(pipe_fs: int) -> None: ...

class GradBucket:
    def index(self) -> int: ...
@@ -1,26 +1,9 @@
#include <c10/util/Logging.h>
#include <fmt/format.h>
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Hooks.hpp>
#include <torch/csrc/distributed/c10d/logging.h>

namespace c10d {

namespace {
void commonEventinit(
    details::EventInfo& evt,
    const Backend& backend,
    const Work& work) {
  evt.timestamp =
      std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
  evt.pg_name = backend.getGroupName();
  evt.backend = backend.getBackendName();
  evt.sequence_number = work.getSequencenumber();
  evt.operation = c10d::opTypeToString(work.retrieveOpType());
  evt.drop_count = 0;
}
} // namespace

Backend::Backend(int rank, int size)
    : rank_(rank), size_(size), dist_debug_level_(debug_level()) {
  C10_LOG_API_USAGE_ONCE("c10d.backend");
@@ -32,21 +15,4 @@ void Backend::init() {
  C10_LOG_API_USAGE_ONCE(fmt::format("c10d.backend_{}", getBackendName()));
}

void Backend::emitCollectiveStart(const Work& work) {
  details::EventInfo evt;
  commonEventinit(evt, *this, work);

  evt.event_kind = ::c10d::EventKind::CollectiveStart;
  details::enqueue_c10d_event(std::move(evt));
}

void Backend::emitCollectiveEnd(const Work& work) {
  details::EventInfo evt;
  commonEventinit(evt, *this, work);

  evt.event_kind = ::c10d::EventKind::CollectiveEnd;
  evt.duration_ms = work.getDuration();
  details::enqueue_c10d_event(std::move(evt));
}

} // namespace c10d
@@ -366,8 +366,6 @@ class TORCH_API Backend : public torch::CustomClassHolder {
  // Implementations of this interface need to call this to setup
  // appropriate logging etc.
  void init();
  void emitCollectiveStart(const Work& work);
  void emitCollectiveEnd(const Work& work);

  const int rank_;
  const int size_;
@@ -1,60 +0,0 @@
#include <atomic>

#include <deque>
#include <memory>
#include <mutex>

#ifndef _WIN32
#include <unistd.h>
#else
#include <io.h>
#endif

#include <torch/csrc/distributed/c10d/Hooks.hpp>
namespace c10d {

namespace {

std::atomic<bool> event_queue_enabled = false;
int sync_pipe;
std::mutex event_queue_lock;
std::deque<details::EventInfo> event_queue;

} // namespace

void enable_event_collection(int pipe) {
  sync_pipe = pipe;
  event_queue_enabled.store(true);
}

namespace details {

// we start dropping events after this
const size_t MAX_QUEUE_SIZE = 512;

bool dequeue_c10d_event(EventInfo& evt) {
  std::unique_lock<std::mutex> lock(event_queue_lock);
  if (event_queue.size() == 0) {
    return false;
  }
  evt = event_queue.front();
  event_queue.pop_front();
  return true;
}

void enqueue_c10d_event(EventInfo&& evt) {
  if (!event_queue_enabled.load())
    return;

  std::unique_lock<std::mutex> lock(event_queue_lock);
  if (event_queue.size() >= MAX_QUEUE_SIZE) {
    event_queue.back().drop_count++;
  } else {
    event_queue.push_back(std::move(evt));
    char m = 'x';
    write(sync_pipe, &m, 1);
  }
}

} // namespace details
} // namespace c10d
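Note the overflow policy in enqueue_c10d_event above: when the queue is full, the new event is discarded and the loss is recorded on the most recently queued event's drop_count, so the collective path never blocks and consumers can still tell how many events they missed. A small Python model of that policy, written only to illustrate the behaviour (names and sizes mirror the C++ above, but nothing here is the real implementation), is:

# Illustrative model of the bounded event queue with drop counting (not the real code).
from collections import deque
from dataclasses import dataclass

MAX_QUEUE_SIZE = 512  # mirrors MAX_QUEUE_SIZE in the C++ above

@dataclass
class Event:
    operation: str
    drop_count: int = 0

queue = deque()

def enqueue(evt: Event) -> None:
    # Never block the caller: when full, coalesce the loss into the
    # drop_count of the last event that did make it into the queue.
    if len(queue) >= MAX_QUEUE_SIZE:
        queue[-1].drop_count += 1
    else:
        queue.append(evt)

def dequeue():
    return queue.popleft() if queue else None

# Overfill by 10: the 512th event ends up carrying drop_count == 10.
for i in range(MAX_QUEUE_SIZE + 10):
    enqueue(Event(operation=f"ALLREDUCE_{i}"))

last = None
while (evt := dequeue()) is not None:
    last = evt
print(last.operation, last.drop_count)  # -> ALLREDUCE_511 10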
@@ -1,30 +0,0 @@
#pragma once

#include <c10/util/Optional.h>
#include <string>

namespace c10d {

enum class EventKind { CollectiveStart, CollectiveEnd };

TORCH_API void enable_event_collection(int sync_pipe);

namespace details {

struct TORCH_API EventInfo {
  EventKind event_kind;
  std::string pg_name;
  std::string backend;
  int64_t sequence_number;
  std::string operation;
  int64_t timestamp;
  c10::optional<float> duration_ms;
  int64_t drop_count;
};

// TODO do we want to expose something else here?
TORCH_API bool dequeue_c10d_event(EventInfo& evt);
TORCH_API void enqueue_c10d_event(EventInfo&& evt);

} // namespace details
} // namespace c10d
@@ -82,14 +82,10 @@ std::string opTypeToString(OpType opType) {
      return "RECVANYSOURCE";
    case OpType::BARRIER:
      return "BARRIER";
    case OpType::_REDUCE_SCATTER_BASE:
      return "_REDUCE_SCATTER_BASE";
    case OpType::COALESCED:
      return "COALESCED";
    case OpType::_ALLREDUCE_SPARSE:
      return "_ALLREDUCE_SPARSE";
    case OpType::UNKNOWN:
      return "UNKNOWN";
    case OpType::_REDUCE_SCATTER_BASE:
      return "_REDUCE_SCATTER_BASE";
    default:
      TORCH_INTERNAL_ASSERT(false, "Unknown op type!");
  }
@@ -855,10 +855,6 @@ void ProcessGroupGloo::runLoop(int workerIndex) {
}

void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
  emitCollectiveStart(*work.get());
  work->getFuture()->addCallback(
      [=](auto& f) { this->emitCollectiveEnd(*work.get()); });

  std::unique_lock<std::mutex> lock(workMutex_);
  workQueue_.push_back(std::move(work));
  lock.unlock();
@@ -931,34 +931,27 @@ void ProcessGroupNCCL::ncclCommWatchdog() {
  }
}

void ProcessGroupNCCL::logWorkStart(WorkNCCL& work, bool emitDesyncInfo) {
  if (terminateProcessGroup_.load() || work.startTraceUpdated_)
void ProcessGroupNCCL::logWorkStart(WorkNCCL& work) {
  if (work.startTraceUpdated_)
    return;

  if (terminateProcessGroup_.load() || storeError_)
    return;

  work.startTraceUpdated_ = true;

  emitCollectiveStart(work);

  if (!emitDesyncInfo || storeError_)
    return;

  storeError_ = !c10d::traceUpdate(
      store_, traceKeyStart_, work.seq_, opTypeToString(work.opType_));
}

void ProcessGroupNCCL::logWorkEnd(WorkNCCL& work, bool emitDesyncInfo) {
  if (terminateProcessGroup_.load())
void ProcessGroupNCCL::logWorkEnd(WorkNCCL& work) {
  if (terminateProcessGroup_.load() || storeError_)
    return;

  // In case the start of the work hasn't been logged
  if (!work.startTraceUpdated_) {
    logWorkStart(work, emitDesyncInfo);
    logWorkStart(work);
  }

  emitCollectiveEnd(work);

  if (!emitDesyncInfo || storeError_)
    return;

  storeError_ = !c10d::traceUpdate(
      store_, traceKeyEnd_, work.seq_, opTypeToString(work.opType_));
}
@@ -1010,11 +1003,13 @@ void ProcessGroupNCCL::workCleanupLoop() {
      }

      // Work status logging for desync debug
      if (work.isStarted()) {
        logWorkStart(work, desyncDebug_);
      }
      if (work.isCompleted()) {
        logWorkEnd(work, desyncDebug_);
      if (desyncDebug_) {
        if (work.isStarted()) {
          logWorkStart(work);
        }
        if (work.isCompleted()) {
          logWorkEnd(work);
        }
      }

      // Clean up completed work
@@ -1071,7 +1066,7 @@ void ProcessGroupNCCL::runHookLoop() {
            timeStarted, // timeStarted
            std::chrono::system_clock::now(), // timeFinished
            std::chrono::duration<float, std::milli>(
                work.getDuration().value()) // activeDuration
                work.getDuration()) // activeDuration
            ));

        lock.lock();
@@ -1584,19 +1579,19 @@ c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupNCCL::WorkNCCL::
  return future_;
}

c10::optional<float> ProcessGroupNCCL::WorkNCCL::getDuration() const {
  if (!timingEnabled_ || !((*ncclEndEvents_)[0].query())) {
    return c10::optional<float>();
  }
float ProcessGroupNCCL::WorkNCCL::getDuration() const {
  TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled")
  TORCH_CHECK(
      ncclStartEvents_->size() == 1,
      "getDuration only works for single device per ProcessGroup.");
  TORCH_CHECK(
      ncclEndEvents_->size() == 1,
      "getDuration only works for single device per ProcessGroup.");
  TORCH_CHECK(
      (*ncclEndEvents_)[0].query(),
      "getDuration can only be called after work is succeeded.")
  return (*ncclStartEvents_)[0].elapsed_time((*ncclEndEvents_)[0]);
}

uint64_t ProcessGroupNCCL::WorkNCCL::getSequencenumber() const {
  return seq_;
}
@@ -167,7 +167,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
    // Get a Future object that will be marked as completed internally.
    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;

    c10::optional<float> getDuration() const override;
    float getDuration() const override;

    uint64_t getSequencenumber() const override;

@@ -615,10 +615,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
  void runHookLoop();

  // Desync debug helper
  void logWorkStart(WorkNCCL& work, bool emitDesyncInfo);
  void logWorkStart(WorkNCCL& work);

  // Desync debug helper
  void logWorkEnd(WorkNCCL& work, bool emitDesyncInfo);
  void logWorkEnd(WorkNCCL& work);

 protected:
  static const int64_t kWatchdogThreadSleepMillis;
@@ -127,8 +127,8 @@ void Work::finishAndThrow(std::exception_ptr exception) {
  }
}

c10::optional<float> Work::getDuration() const {
  return c10::optional<float>();
float Work::getDuration() const {
  TORCH_CHECK(false, "This Backend doesn't support getDuration.");
}

uint64_t Work::getSequencenumber() const {
@@ -107,7 +107,7 @@ class TORCH_API Work : public torch::CustomClassHolder {
  // work. Only NCCL backend is currently supported.
  virtual c10::intrusive_ptr<c10::ivalue::Future> getFuture();

  virtual c10::optional<float> getDuration() const;
  virtual float getDuration() const;

  virtual uint64_t getSequencenumber() const;
@@ -32,7 +32,6 @@

#include <fmt/format.h>
#include <pybind11/chrono.h>
#include <torch/csrc/distributed/c10d/Hooks.hpp>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>

#include <torch/csrc/distributed/c10d/comm.hpp>
@@ -292,26 +291,6 @@ void _register_builtin_comm_hook(
  reducer.register_builtin_comm_hook(comm_hook_type);
}

py::object c10d_dequeue_python_event() {
  ::c10d::details::EventInfo evt;
  if (!::c10d::details::dequeue_c10d_event(evt)) {
    return py::none();
  }

  py::dict data;
  data["event_kind"] = (int)evt.event_kind;

  data["pg_name"] = evt.pg_name;
  data["backend"] = evt.backend;
  data["sequence_number"] = evt.sequence_number;
  data["operation"] = evt.operation;
  data["timestamp"] = evt.timestamp;
  data["duration"] = evt.duration_ms.value_or(-1);
  data["drop_count"] = evt.drop_count;

  return std::move(data);
}

// Customize the metaclass of ::c10d::ReduceOp for the backward compatibility.
// https://github.com/pytorch/pytorch/pull/84243 changed ::c10d::ReduceOp to
// struct from enum, sacrificing some of the Python built-in function supports
@@ -661,11 +640,6 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
          &::c10d::Logger::set_static_graph,
          py::call_guard<py::gil_scoped_release>());

  py::enum_<::c10d::EventKind>(module, "EventKind", R"(
An enum for collective hooks event types.)")
      .value("START", ::c10d::EventKind::CollectiveStart)
      .value("END", ::c10d::EventKind::CollectiveEnd);

  py::enum_<::c10d::DebugLevel>(module, "DebugLevel", R"(
      An enum whose values correspond to different debug levels of the
      torch.distributed package. Currently supporting OFF, INFO, and DETAIL,
@@ -689,16 +663,7 @@ An enum for collective hooks event types.)")
          "set_debug_level_from_env",
          ::c10d::setDebugLevelFromEnvironment,
          R"(Sets the debug level of the torch.distributed package from the
``TORCH_DISTRIBUTED_DEBUG`` environment variable.)")
      .def(
          "_enable_event_collection",
          &::c10d::enable_event_collection,
          "(Enables events collection).",
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_dequeue_c10d_event",
          &c10d_dequeue_python_event,
          "(Blocks until a c10d event is available and return it as a python dictionary).");
``TORCH_DISTRIBUTED_DEBUG`` environment variable.)");

  // TODO(crcrpar): Hardening `ReduceOp`.
  // While keeping most op types as enum value,
@@ -1,183 +0,0 @@
import logging
import os
import threading
import time
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional

import torch.distributed as dist

if dist.is_available():
    import torch.distributed.distributed_c10d as c10d

    from torch._C._distributed_c10d import (
        _dequeue_c10d_event,
        _enable_event_collection,
        EventKind,
    )

__all__ = [
    "CollectiveStatus",
    "COLLECTIVE_HOOK_TYPE",
    "PG_HOOK_TYPE",
    "register_collective_start_hook",
    "register_collective_end_hook",
    "register_process_group_hook",
]


@dataclass
class CollectiveStatus:
    r"""Status of a collective operation.

    Drop count indicates events have been dropped at the producer side, which means they
    were not consumed fast enough by hooks.
    """
    pg_name: str = "unknown"  # This name matches the one reported by the PG registration callback
    backend: str = "unknown"  # Name of the backend used
    sequence_number: int = -1  # Sequence number of the collective
    operation: str = "unknown"  # Collective name
    timestamp: int = 0  # Timestamp of the earliest time we noticed this event
    duration: Optional[float] = None  # Execution time in milliseconds, if available
    drop_count: int = 0  # Number of events dropped following this one


COLLECTIVE_HOOK_TYPE = Callable[[CollectiveStatus], None]
PG_HOOK_TYPE = Callable[[dist.ProcessGroup, str], None]

# This controls the number of internal failures we'll tolerate before giving up
_MAX_INTERNAL_FAILURES = 10

logger = logging.getLogger(__name__)
_cb_thread: Optional[threading.Thread] = None
_start_callbacks: List[COLLECTIVE_HOOK_TYPE] = []
_end_callbacks: List[COLLECTIVE_HOOK_TYPE] = []
_pp_r = -1
_pp_w = -1


def _c10d_pg_hooks_loops():
    internal_failures = 0
    while True:
        # we don't care about the result, this is how we implement notification
        _ = os.read(_pp_r, 1)
        evt: Dict[str, object] = _dequeue_c10d_event()
        try:
            event_kind = evt.pop("event_kind", None)
            if event_kind is None:
                logger.warning(
                    "c10d returned event dictionary %s without 'event_kind' key, cannot dispatch",
                    evt,
                )
                internal_failures += 1
                if internal_failures >= _MAX_INTERNAL_FAILURES:
                    logger.warning(
                        "too many internal c10d failures processing callback loop. stopping"
                    )
                    return
                time.sleep(1)
                continue

            if event_kind == int(EventKind.START):  # type: ignore[call-overload]
                cb_list = _start_callbacks
            elif event_kind == int(EventKind.END):  # type: ignore[call-overload]
                cb_list = _end_callbacks
            else:
                logger.warning(
                    "c10d event %s with invalid 'event_kind' with value %d",
                    evt,
                    event_kind,
                )
                internal_failures += 1
                if internal_failures >= _MAX_INTERNAL_FAILURES:
                    logger.warning(
                        "too many internal c10d failures processing callback loop. stopping"
                    )
                    return
                time.sleep(1)
                continue

            status = CollectiveStatus(**evt)  # type: ignore[arg-type]
            for cb in cb_list:
                try:
                    cb(status)
                except Exception as e:
                    logger.info(
                        "c10d event callback %s with event %s threw exception %s",
                        cb,
                        status,
                        e,
                    )
        except Exception as e:
            # We have to keep processing, otherwise the queue will grow infinitely large
            logger.warning(
                "c10d callback thread when processing event %s raised exception %s.",
                evt,
                e,
            )
            internal_failures += 1
            if internal_failures >= _MAX_INTERNAL_FAILURES:
                logger.warning(
                    "too many internal c10d failures processing callback loop. stopping"
                )
                return

            # Sleep for a second to avoid hogging the GIL in case of a persistent failure
            time.sleep(1)


def _lazy_init():
    global _cb_thread
    if _cb_thread is not None:
        return
    global _pp_r
    global _pp_w
    _pp_r, _pp_w = os.pipe()
    _enable_event_collection(_pp_w)
    c10d._enable_collectives_timing()
    _cb_thread = threading.Thread(target=_c10d_pg_hooks_loops, daemon=True)
    _cb_thread.start()
    logger.info("c10d::hooks thread enabled")


def _check_distributed_available():
    if not dist.is_available():
        raise RuntimeError(
            "torch.distributed is not available, so hooks are not available."
        )


def register_collective_start_hook(hook: COLLECTIVE_HOOK_TYPE) -> None:
    r"""Register a hook that is called every time a collective starts.

    The hook is invoked on a background thread.
    Exceptions raised by the callback are ignored and non-fatal.
    """
    _check_distributed_available()
    _start_callbacks.append(hook)
    _lazy_init()


def register_collective_end_hook(hook: COLLECTIVE_HOOK_TYPE) -> None:
    r"""Register a hook that is called every time a collective finishes.

    The hook is invoked on a background thread.
    Exceptions raised by the callback are ignored and non-fatal.
    """
    _check_distributed_available()
    _end_callbacks.append(hook)
    _lazy_init()


def register_process_group_hook(hook: PG_HOOK_TYPE) -> None:
    r"""Register a hook that is called every time a process group is created on this rank.

    This hook is only invoked if the current rank is part of the PG being created.
    The pg_name is unique to the whole cluster and should be treated as an opaque identifier subject to change.

    The hook is invoked on a background thread.
    Exceptions raised by the callback are ignored and non-fatal.
    """
    _check_distributed_available()
    c10d._register_creation_hook(hook)
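The public surface of this module is the three register_* functions above. A minimal usage sketch, with hypothetical backend and rank/world-size values and assuming the module from this reverted change is importable, could look like:

# Hypothetical usage of the registration API defined above.
import torch.distributed as dist
import torch.distributed._hooks as dhooks

def on_pg_created(pg: dist.ProcessGroup, pg_name: str) -> None:
    # Fired only on ranks that belong to the newly created group;
    # pg_name is an opaque, cluster-unique identifier.
    print(f"process group created: {pg_name}")

def on_collective_start(status: dhooks.CollectiveStatus) -> None:
    print(f"start: {status.operation} seq={status.sequence_number}")

def on_collective_end(status: dhooks.CollectiveStatus) -> None:
    print(f"end:   {status.operation} took {status.duration} ms")

dhooks.register_process_group_hook(on_pg_created)
dhooks.register_collective_start_hook(on_collective_start)
dhooks.register_collective_end_hook(on_collective_end)

# Then initialize the process group and issue collectives as usual, e.g.
# dist.init_process_group("gloo", rank=rank, world_size=world_size); the hooks
# run on a background thread and their exceptions are logged, not raised.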
@@ -421,19 +421,6 @@ _group_count = 0
_tags_to_pg: Dict[str, List[ProcessGroup]] = {}
_pg_to_tag: Dict[ProcessGroup, str] = {}

class _HookState:
    def __init__(self):
        self.creation_hooks = []

    def register_creation_hook(self, hook) -> None:
        self.creation_hooks.append(hook)

    def fire_creation_hook(self, pg, name) -> None:
        for hook in self.creation_hooks:
            try:
                hook(pg, name)
            except Exception as e:
                logger.info("hook %s failed with %s", hook, e)

class _World:
    """
@@ -447,8 +434,6 @@ class _World:
        self._default_pg = None
        self._pg_coalesce_state: Dict[ProcessGroup, List[Union[_CollOp, P2POp]]] = {}
        self._pg_default_device: Dict[ProcessGroup, torch.device] = {}
        self._hook_state = _HookState()
        self.enable_collectives_timing = False

    @property
    def default_pg(self):
@@ -558,9 +543,6 @@ class _World:
        )
        return config_info

    @property
    def pg_hook_state(self) -> _HookState:
        return self._hook_state

_world = _World()
"""Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it"""
@@ -1382,9 +1364,6 @@ def _new_process_group_helper(
            pg._set_group_name(group_name)

        _world.pg_backend_config[pg] = str(backend_config)
        if _world.enable_collectives_timing:
            pg._enable_collectives_timing()

        # "" is the default tag for user PGs
        if pg_tag in [None, ""]:
            pg_tag = f"ptd:{group_name}"
@@ -1394,8 +1373,6 @@

    _world.tags_to_pg.setdefault(pg_tag, []).append(pg)
    _world.pg_to_tag[pg] = pg_tag
    _world.pg_hook_state.fire_creation_hook(pg, group_name)

    return pg, prefix_store

def destroy_process_group(group: Optional[ProcessGroup] = None):
@@ -4338,11 +4315,3 @@ dynamo_unsupported_distributed_c10d_ops = [
    reduce_scatter_tensor,
    send,
]

def _register_creation_hook(hook):
    _world.pg_hook_state.register_creation_hook(hook)

def _enable_collectives_timing():
    _world.enable_collectives_timing = True
    for pg in _world.pg_map:
        pg._enable_collectives_timing()
@@ -401,11 +401,6 @@ class WorldData:
class ThreadLocalWorld:
    _world = threading.local()

    def __init__(self):
        self.enable_collectives_timing = False
        self._hook_state = dist.distributed_c10d._HookState()

    def _get_world(self) -> WorldData:
        if not hasattr(ThreadLocalWorld._world, "world"):
            ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {}, {}, {})
@@ -459,9 +454,6 @@ class ThreadLocalWorld:
    def pg_default_device(self) -> Dict[dist.ProcessGroup, torch.device]:
        return self._get_world().pg_default_device

    @property
    def pg_hook_state(self) -> dist.distributed_c10d._HookState:
        return self._hook_state

_old_pg_world = None
_ctx_manager = None