Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
This reverts commit bb1424d46e656dfcdd4c12efe58ada9f1720c4d8. Reverted https://github.com/pytorch/pytorch/pull/111072 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/111072#issuecomment-1765399829))
@@ -522,7 +522,6 @@ libtorch_distributed_base_sources = [
     "torch/csrc/distributed/c10d/Backend.cpp",
    "torch/csrc/distributed/c10d/FileStore.cpp",
    "torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
-    "torch/csrc/distributed/c10d/Hooks.cpp",
    "torch/csrc/distributed/c10d/Ops.cpp",
    "torch/csrc/distributed/c10d/ParamCommsUtils.cpp",
    "torch/csrc/distributed/c10d/PrefixStore.cpp",
@@ -1,270 +0,0 @@
-# Owner(s): ["oncall: distributed"]
-
-import os
-import sys
-import tempfile
-import threading
-from functools import partial, wraps
-
-import torch
-import torch.distributed as dist
-import torch.distributed._hooks as dhooks
-
-if not dist.is_available():
-    print("torch.distributed not available, skipping tests", file=sys.stderr)
-    sys.exit(0)
-
-
-from torch.testing._internal.common_distributed import (
-    MultiProcessTestCase,
-    skip_if_lt_x_gpu,
-)
-
-from torch.testing._internal.common_utils import run_tests, TestCase
-
-
-class PgHooks(MultiProcessTestCase):
-    @property
-    def world_size(self) -> int:
-        return 4
-
-    def setUp(self) -> None:
-        super().setUp()
-        self._spawn_processes()
-
-    def tearDown(self):
-        super().tearDown()
-        try:
-            os.remove(self.file_name)
-        except OSError:
-            pass
-
-    def test_pg_hook(self):
-        pgs = []
-
-        def pg_hook(pg, pg_name):
-            pgs.append((pg, pg_name))
-
-        dhooks.register_process_group_hook(pg_hook)
-        dist.init_process_group(
-            backend="gloo",
-            rank=self.rank,
-            world_size=self.world_size,
-            store=dist.FileStore(self.file_name, self.world_size),
-        )
-        self.assertEqual(len(pgs), 1)
-        self.assertEqual(pgs[0][0], dist.group.WORLD)
-
-        # create two partial world PGs
-        pg0 = dist.new_group(ranks=[0, 1])
-        pg1 = dist.new_group(ranks=[2, 3])
-
-        # Each rank only observes two PGs being created: the default PG and the one covering its ranks.
-        # We don't emit events for PG creation if the current rank doesn't belong to it.
-        # For example, if you're rank 1 you get an event for pg0 but not pg1, even though the API contract
-        # dictates you need to call new_group for both.
-        self.assertEqual(len(pgs), 2)
-        self.assertEqual(pgs[1][0], pg0 if self.rank < 2 else pg1)
-
-
-def with_comms(func=None):
-    if func is None:
-        return partial(
-            with_comms,
-        )
-
-    @wraps(func)
-    def wrapper(self, *args, **kwargs):
-        self.init_comms()
-        func(self, *args, **kwargs)
-        self.destroy_comms()
-
-    return wrapper
-
-
-class CollectiveHooks:
-    @property
-    def world_size(self) -> int:
-        return 4
-
-    def _collective_hooks(self):
-        # it's ok to access these lists directly since there's a single bg thread poking at them.
-        starts = []
-        ends = []
-        cv = threading.Condition()
-
-        def coll_start(status):
-            starts.append(status)
-            print(f"col_start {len(starts)} rank{self.rank}")
-
-        def coll_end(status):
-            ends.append(status)
-            print(f"col_end {len(ends)} rank{self.rank}")
-            if len(ends) == 2:
-                with cv:
-                    cv.notify()
-
-        dhooks.register_collective_start_hook(coll_start)
-        dhooks.register_collective_end_hook(coll_end)
-
-        tensor = torch.ones([2, 3]).to(self.device) * self.rank
-        tensor_list = [torch.empty_like(tensor) for _ in range(self.world_size)]
-
-        dist.all_gather(tensor_list, tensor)
-
-        tensor2 = torch.ones([2, 3]).to(self.device) * self.rank
-        dist.all_reduce(tensor2)
-
-        with cv:
-            cv.wait(1)
-
-        default_pg_name = dist.group.WORLD.group_name
-        self.assertEqual(2, len(starts))
-        self.assertEqual(2, len(ends))
-
-        def check_op(idx, coll_name):
-            self.assertEqual(default_pg_name, starts[idx].pg_name)
-            self.assertEqual(self.backend_name, starts[idx].backend)
-            self.assertGreaterEqual(starts[idx].sequence_number, 0)
-            self.assertGreaterEqual(starts[idx].timestamp, 0)
-            self.assertEqual(coll_name, starts[idx].operation)
-
-            self.assertEqual(default_pg_name, ends[idx].pg_name)
-            self.assertEqual(self.backend_name, ends[idx].backend)
-
-            self.assertEqual(starts[idx].sequence_number, ends[idx].sequence_number)
-            self.assertLessEqual(starts[idx].timestamp, ends[idx].timestamp)
-            self.assertEqual(coll_name, ends[idx].operation)
-
-        check_op(0, "ALLGATHER")
-        check_op(1, "ALLREDUCE")
-
-
-class GlooHooks(MultiProcessTestCase, CollectiveHooks):
-    def setUp(self) -> None:
-        super().setUp()
-        self._spawn_processes()
-
-    def tearDown(self):
-        super().tearDown()
-        try:
-            os.remove(self.file_name)
-        except OSError:
-            pass
-
-    def init_comms(self):
-        dist.init_process_group(
-            backend="gloo",
-            rank=self.rank,
-            world_size=self.world_size,
-            store=dist.FileStore(self.file_name, self.world_size),
-        )
-
-    def destroy_comms(self):
-        dist.destroy_process_group()
-
-    @property
-    def backend_name(self):
-        return "gloo"
-
-    @property
-    def device(self):
-        return "cpu"
-
-    @with_comms
-    def test_collective_hooks(self):
-        self._collective_hooks()
-
-
-class NcclHooks(MultiProcessTestCase, CollectiveHooks):
-    def setUp(self) -> None:
-        super().setUp()
-        self._spawn_processes()
-
-    def tearDown(self):
-        super().tearDown()
-        try:
-            os.remove(self.file_name)
-        except OSError:
-            pass
-
-    def init_comms(self):
-        dist.init_process_group(
-            backend="nccl",
-            rank=self.rank,
-            world_size=self.world_size,
-            store=dist.FileStore(self.file_name, self.world_size),
-        )
-
-    def destroy_comms(self):
-        dist.destroy_process_group()
-
-    @property
-    def backend_name(self):
-        return "nccl"
-
-    @property
-    def device(self):
-        return f"cuda:{self.rank}"
-
-    @skip_if_lt_x_gpu(4)
-    @with_comms
-    def test_collective_hooks(self):
-        self._collective_hooks()
-
-
-class SingleRankTests(TestCase):
-    def setUp(self) -> None:
-        super().setUp()
-        self.rank = 0
-        self.file_name = tempfile.NamedTemporaryFile(delete=False).name
-        dist.init_process_group(
-            backend="gloo",
-            rank=0,
-            world_size=1,
-            store=dist.FileStore(self.file_name, 1),
-        )
-
-    def tearDown(self) -> None:
-        dist.destroy_process_group()
-
-    def test_queue_overflow(self) -> None:
-        cv_done_colls = threading.Condition()
-        cv_done_cb = threading.Condition()
-        colls_done = False
-        starts = []
-        status_with_dropped = None
-
-        def coll_start(status: dhooks.CollectiveStatus):
-            starts.append(status)
-            with cv_done_colls:
-                while not colls_done:
-                    cv_done_colls.wait()
-            if status.drop_count > 0:
-                nonlocal status_with_dropped
-                status_with_dropped = status
-                with cv_done_cb:
-                    cv_done_cb.notify()
-
-        dhooks.register_collective_start_hook(coll_start)
-
-        # native limit is 512
-        for i in range(600):
-            dist.all_reduce(torch.ones([2, 3]))
-        colls_done = True
-        with cv_done_colls:
-            cv_done_colls.notify()
-
-        with cv_done_cb:
-            cv_done_cb.wait(10)
-
-        self.assertTrue(status_with_dropped is not None)
-        self.assertTrue(status_with_dropped.drop_count > 0)
-
-
-if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_distributed must not have initialized CUDA context on main process"
-
-    run_tests()
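A minimal single-rank sketch of the process-group creation hook that the deleted test above exercises. It assumes a pre-revert build where torch.distributed._hooks still exists; the FileStore path and callback name are illustrative only.

    # Sketch only: requires a build that still ships torch.distributed._hooks.
    import tempfile
    import torch.distributed as dist
    import torch.distributed._hooks as dhooks

    created = []

    def on_pg_created(pg, pg_name):
        # Invoked on a background thread for every PG this rank belongs to.
        created.append((pg, pg_name))

    dhooks.register_process_group_hook(on_pg_created)
    store = dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1)
    dist.init_process_group(backend="gloo", rank=0, world_size=1, store=store)
    # `created` now holds one entry for the default (WORLD) process group.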
@@ -11,10 +11,6 @@ _DEFAULT_FIRST_BUCKET_BYTES: int
 _DEFAULT_NO_TIMEOUT: timedelta
 _DEFAULT_PG_TIMEOUT: timedelta
 
-class EventKind(Enum):
-    START = ...
-    END = ...
-
 class BuiltinCommHookType(Enum):
     ALLREDUCE = ...
     FP16_COMPRESS = ...
@@ -24,8 +20,6 @@ def _register_builtin_comm_hook(
     reducer: Reducer,
     comm_hook_type: BuiltinCommHookType,
 ): ...
-def _dequeue_c10d_event() -> Dict[str, object]: ...
-def _enable_event_collection(pipe_fs: int) -> None: ...
 
 class GradBucket:
     def index(self) -> int: ...
@@ -1,26 +1,9 @@
 #include <c10/util/Logging.h>
 #include <fmt/format.h>
 #include <torch/csrc/distributed/c10d/Backend.hpp>
-#include <torch/csrc/distributed/c10d/Hooks.hpp>
-#include <torch/csrc/distributed/c10d/logging.h>
 
 namespace c10d {
 
-namespace {
-void commonEventinit(
-    details::EventInfo& evt,
-    const Backend& backend,
-    const Work& work) {
-  evt.timestamp =
-      std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
-  evt.pg_name = backend.getGroupName();
-  evt.backend = backend.getBackendName();
-  evt.sequence_number = work.getSequencenumber();
-  evt.operation = c10d::opTypeToString(work.retrieveOpType());
-  evt.drop_count = 0;
-}
-} // namespace
-
 Backend::Backend(int rank, int size)
     : rank_(rank), size_(size), dist_debug_level_(debug_level()) {
   C10_LOG_API_USAGE_ONCE("c10d.backend");
@@ -32,21 +15,4 @@ void Backend::init() {
   C10_LOG_API_USAGE_ONCE(fmt::format("c10d.backend_{}", getBackendName()));
 }
 
-void Backend::emitCollectiveStart(const Work& work) {
-  details::EventInfo evt;
-  commonEventinit(evt, *this, work);
-
-  evt.event_kind = ::c10d::EventKind::CollectiveStart;
-  details::enqueue_c10d_event(std::move(evt));
-}
-
-void Backend::emitCollectiveEnd(const Work& work) {
-  details::EventInfo evt;
-  commonEventinit(evt, *this, work);
-
-  evt.event_kind = ::c10d::EventKind::CollectiveEnd;
-  evt.duration_ms = work.getDuration();
-  details::enqueue_c10d_event(std::move(evt));
-}
-
 } // namespace c10d
@@ -366,8 +366,6 @@ class TORCH_API Backend : public torch::CustomClassHolder {
   // Implementations of this interface need to call this to setup
   // appropriate logging etc.
   void init();
-  void emitCollectiveStart(const Work& work);
-  void emitCollectiveEnd(const Work& work);
 
   const int rank_;
   const int size_;
@@ -1,60 +0,0 @@
-#include <atomic>
-
-#include <deque>
-#include <memory>
-#include <mutex>
-
-#ifndef _WIN32
-#include <unistd.h>
-#else
-#include <io.h>
-#endif
-
-#include <torch/csrc/distributed/c10d/Hooks.hpp>
-namespace c10d {
-
-namespace {
-
-std::atomic<bool> event_queue_enabled = false;
-int sync_pipe;
-std::mutex event_queue_lock;
-std::deque<details::EventInfo> event_queue;
-
-} // namespace
-
-void enable_event_collection(int pipe) {
-  sync_pipe = pipe;
-  event_queue_enabled.store(true);
-}
-
-namespace details {
-
-// we start dropping events after this
-const size_t MAX_QUEUE_SIZE = 512;
-
-bool dequeue_c10d_event(EventInfo& evt) {
-  std::unique_lock<std::mutex> lock(event_queue_lock);
-  if (event_queue.size() == 0) {
-    return false;
-  }
-  evt = event_queue.front();
-  event_queue.pop_front();
-  return true;
-}
-
-void enqueue_c10d_event(EventInfo&& evt) {
-  if (!event_queue_enabled.load())
-    return;
-
-  std::unique_lock<std::mutex> lock(event_queue_lock);
-  if (event_queue.size() >= MAX_QUEUE_SIZE) {
-    event_queue.back().drop_count++;
-  } else {
-    event_queue.push_back(std::move(evt));
-    char m = 'x';
-    write(sync_pipe, &m, 1);
-  }
-}
-
-} // namespace details
-} // namespace c10d
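The deleted file above is a bounded producer queue: collectives enqueue events under a mutex, a full queue increments a drop counter on the newest queued event instead of blocking, and a one-byte pipe write wakes the consumer. A small Python model of that pattern, with hypothetical names, is sketched here; it is not part of the PyTorch API.

    # Illustrative model of the removed C++ event queue; all names are hypothetical.
    import os
    import threading
    from collections import deque

    MAX_QUEUE_SIZE = 512  # mirrors the native limit above

    _lock = threading.Lock()
    _queue = deque()
    _read_fd, _write_fd = os.pipe()  # consumer blocks on the read end

    def enqueue(event: dict) -> None:
        with _lock:
            if len(_queue) >= MAX_QUEUE_SIZE:
                # Don't block the collective; record the drop on the newest event
                # so the consumer can observe how many events it missed.
                _queue[-1]["drop_count"] += 1
            else:
                _queue.append(event)
                os.write(_write_fd, b"x")  # wake the consumer thread

    def dequeue() -> dict | None:
        with _lock:
            return _queue.popleft() if _queue else None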
@@ -1,30 +0,0 @@
-#pragma once
-
-#include <c10/util/Optional.h>
-#include <string>
-
-namespace c10d {
-
-enum class EventKind { CollectiveStart, CollectiveEnd };
-
-TORCH_API void enable_event_collection(int sync_pipe);
-
-namespace details {
-
-struct TORCH_API EventInfo {
-  EventKind event_kind;
-  std::string pg_name;
-  std::string backend;
-  int64_t sequence_number;
-  std::string operation;
-  int64_t timestamp;
-  c10::optional<float> duration_ms;
-  int64_t drop_count;
-};
-
-// TODO do we want to expose something else here?
-TORCH_API bool dequeue_c10d_event(EventInfo& evt);
-TORCH_API void enqueue_c10d_event(EventInfo&& evt);
-
-} // namespace details
-} // namespace c10d
@@ -82,14 +82,10 @@ std::string opTypeToString(OpType opType) {
       return "RECVANYSOURCE";
     case OpType::BARRIER:
       return "BARRIER";
-    case OpType::_REDUCE_SCATTER_BASE:
-      return "_REDUCE_SCATTER_BASE";
-    case OpType::COALESCED:
-      return "COALESCED";
-    case OpType::_ALLREDUCE_SPARSE:
-      return "_ALLREDUCE_SPARSE";
     case OpType::UNKNOWN:
       return "UNKNOWN";
+    case OpType::_REDUCE_SCATTER_BASE:
+      return "_REDUCE_SCATTER_BASE";
     default:
       TORCH_INTERNAL_ASSERT(false, "Unknown op type!");
   }
@@ -855,10 +855,6 @@ void ProcessGroupGloo::runLoop(int workerIndex) {
 }
 
 void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
-  emitCollectiveStart(*work.get());
-  work->getFuture()->addCallback(
-      [=](auto& f) { this->emitCollectiveEnd(*work.get()); });
-
   std::unique_lock<std::mutex> lock(workMutex_);
   workQueue_.push_back(std::move(work));
   lock.unlock();
@@ -931,34 +931,27 @@ void ProcessGroupNCCL::ncclCommWatchdog() {
   }
 }
 
-void ProcessGroupNCCL::logWorkStart(WorkNCCL& work, bool emitDesyncInfo) {
-  if (terminateProcessGroup_.load() || work.startTraceUpdated_)
+void ProcessGroupNCCL::logWorkStart(WorkNCCL& work) {
+  if (work.startTraceUpdated_)
     return;
 
-  work.startTraceUpdated_ = true;
-
-  emitCollectiveStart(work);
-
-  if (!emitDesyncInfo || storeError_)
+  if (terminateProcessGroup_.load() || storeError_)
     return;
 
+  work.startTraceUpdated_ = true;
+
   storeError_ = !c10d::traceUpdate(
       store_, traceKeyStart_, work.seq_, opTypeToString(work.opType_));
 }
 
-void ProcessGroupNCCL::logWorkEnd(WorkNCCL& work, bool emitDesyncInfo) {
-  if (terminateProcessGroup_.load())
+void ProcessGroupNCCL::logWorkEnd(WorkNCCL& work) {
+  if (terminateProcessGroup_.load() || storeError_)
     return;
 
   // In case the start of the work hasn't been logged
   if (!work.startTraceUpdated_) {
-    logWorkStart(work, emitDesyncInfo);
+    logWorkStart(work);
   }
 
-  emitCollectiveEnd(work);
-
-  if (!emitDesyncInfo || storeError_)
-    return;
-
   storeError_ = !c10d::traceUpdate(
       store_, traceKeyEnd_, work.seq_, opTypeToString(work.opType_));
 }
@@ -1010,11 +1003,13 @@ void ProcessGroupNCCL::workCleanupLoop() {
         }
 
         // Work status logging for desync debug
-        if (work.isStarted()) {
-          logWorkStart(work, desyncDebug_);
-        }
-        if (work.isCompleted()) {
-          logWorkEnd(work, desyncDebug_);
+        if (desyncDebug_) {
+          if (work.isStarted()) {
+            logWorkStart(work);
+          }
+          if (work.isCompleted()) {
+            logWorkEnd(work);
+          }
         }
 
         // Clean up completed work
@@ -1071,7 +1066,7 @@ void ProcessGroupNCCL::runHookLoop() {
             timeStarted, // timeStarted
             std::chrono::system_clock::now(), // timeFinished
             std::chrono::duration<float, std::milli>(
-                work.getDuration().value()) // activeDuration
+                work.getDuration()) // activeDuration
             ));
 
       lock.lock();
@@ -1584,19 +1579,19 @@ c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupNCCL::WorkNCCL::
   return future_;
 }
 
-c10::optional<float> ProcessGroupNCCL::WorkNCCL::getDuration() const {
-  if (!timingEnabled_ || !((*ncclEndEvents_)[0].query())) {
-    return c10::optional<float>();
-  }
+float ProcessGroupNCCL::WorkNCCL::getDuration() const {
+  TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled")
   TORCH_CHECK(
       ncclStartEvents_->size() == 1,
       "getDuration only works for single device per ProcessGroup.");
   TORCH_CHECK(
       ncclEndEvents_->size() == 1,
       "getDuration only works for single device per ProcessGroup.");
+  TORCH_CHECK(
+      (*ncclEndEvents_)[0].query(),
+      "getDuration can only be called after work is succeeded.")
   return (*ncclStartEvents_)[0].elapsed_time((*ncclEndEvents_)[0]);
 }
 
 uint64_t ProcessGroupNCCL::WorkNCCL::getSequencenumber() const {
   return seq_;
 }
@@ -167,7 +167,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
     // Get a Future object that will be marked as completed internally.
     c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
 
-    c10::optional<float> getDuration() const override;
+    float getDuration() const override;
 
     uint64_t getSequencenumber() const override;
@@ -615,10 +615,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   void runHookLoop();
 
   // Desync debug helper
-  void logWorkStart(WorkNCCL& work, bool emitDesyncInfo);
+  void logWorkStart(WorkNCCL& work);
 
   // Desync debug helper
-  void logWorkEnd(WorkNCCL& work, bool emitDesyncInfo);
+  void logWorkEnd(WorkNCCL& work);
 
  protected:
   static const int64_t kWatchdogThreadSleepMillis;
@@ -127,8 +127,8 @@ void Work::finishAndThrow(std::exception_ptr exception) {
   }
 }
 
-c10::optional<float> Work::getDuration() const {
-  return c10::optional<float>();
+float Work::getDuration() const {
+  TORCH_CHECK(false, "This Backend doesn't support getDuration.");
 }
 
 uint64_t Work::getSequencenumber() const {
@@ -107,7 +107,7 @@ class TORCH_API Work : public torch::CustomClassHolder {
   // work. Only NCCL backend is currently supported.
   virtual c10::intrusive_ptr<c10::ivalue::Future> getFuture();
 
-  virtual c10::optional<float> getDuration() const;
+  virtual float getDuration() const;
 
   virtual uint64_t getSequencenumber() const;
@@ -32,7 +32,6 @@
 
 #include <fmt/format.h>
 #include <pybind11/chrono.h>
-#include <torch/csrc/distributed/c10d/Hooks.hpp>
 #include <torch/csrc/distributed/c10d/PrefixStore.hpp>
 
 #include <torch/csrc/distributed/c10d/comm.hpp>
@@ -292,26 +291,6 @@ void _register_builtin_comm_hook(
   reducer.register_builtin_comm_hook(comm_hook_type);
 }
 
-py::object c10d_dequeue_python_event() {
-  ::c10d::details::EventInfo evt;
-  if (!::c10d::details::dequeue_c10d_event(evt)) {
-    return py::none();
-  }
-
-  py::dict data;
-  data["event_kind"] = (int)evt.event_kind;
-
-  data["pg_name"] = evt.pg_name;
-  data["backend"] = evt.backend;
-  data["sequence_number"] = evt.sequence_number;
-  data["operation"] = evt.operation;
-  data["timestamp"] = evt.timestamp;
-  data["duration"] = evt.duration_ms.value_or(-1);
-  data["drop_count"] = evt.drop_count;
-
-  return std::move(data);
-}
-
 // Customize the metaclass of ::c10d::ReduceOp for the backward compatibility.
 // https://github.com/pytorch/pytorch/pull/84243 changed ::c10d::ReduceOp to
 // struct from enum, sacrificing some of the Python built-in function supports
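For context, the removed binding above surfaced each queued event as a plain dict with the keys shown (event_kind, pg_name, backend, sequence_number, operation, timestamp, duration, drop_count), or None when the queue was empty. A hedged sketch of how a consumer could wrap that dict in a typed record; the Event class and to_event helper are hypothetical, not part of any PyTorch API.

    # Sketch only; mirrors the dict layout built by the removed binding.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Event:  # hypothetical mirror of the event dict
        event_kind: int
        pg_name: str
        backend: str
        sequence_number: int
        operation: str
        timestamp: int
        duration: float  # -1 when the backend reported no duration
        drop_count: int

    def to_event(evt: Optional[dict]) -> Optional[Event]:
        # None means the queue was empty (the binding returned py::none()).
        return Event(**evt) if evt else None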
@@ -661,11 +640,6 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_COMPRESS``.)")
       &::c10d::Logger::set_static_graph,
       py::call_guard<py::gil_scoped_release>());
 
-  py::enum_<::c10d::EventKind>(module, "EventKind", R"(
-An enum for collective hooks event types.)")
-      .value("START", ::c10d::EventKind::CollectiveStart)
-      .value("END", ::c10d::EventKind::CollectiveEnd);
-
   py::enum_<::c10d::DebugLevel>(module, "DebugLevel", R"(
 An enum whose values correspond to different debug levels of the
 torch.distributed package. Currently supporting OFF, INFO, and DETAIL,
@ -689,16 +663,7 @@ An enum for collective hooks event types.)")
|
|||||||
"set_debug_level_from_env",
|
"set_debug_level_from_env",
|
||||||
::c10d::setDebugLevelFromEnvironment,
|
::c10d::setDebugLevelFromEnvironment,
|
||||||
R"(Sets the debug level of the torch.distributed package from the
|
R"(Sets the debug level of the torch.distributed package from the
|
||||||
``TORCH_DISTRIBUTED_DEBUG`` environment variable.)")
|
``TORCH_DISTRIBUTED_DEBUG`` environment variable.)");
|
||||||
.def(
|
|
||||||
"_enable_event_collection",
|
|
||||||
&::c10d::enable_event_collection,
|
|
||||||
"(Enables events collection).",
|
|
||||||
py::call_guard<py::gil_scoped_release>())
|
|
||||||
.def(
|
|
||||||
"_dequeue_c10d_event",
|
|
||||||
&c10d_dequeue_python_event,
|
|
||||||
"(Blocks until a c10d event is available and return it as a python dictionary).");
|
|
||||||
|
|
||||||
// TODO(crcrpar): Hardening `ReduceOp`.
|
// TODO(crcrpar): Hardening `ReduceOp`.
|
||||||
// While keeping most op types as enum value,
|
// While keeping most op types as enum value,
|
||||||
|
@@ -1,183 +0,0 @@
-import logging
-import os
-import threading
-import time
-from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional
-
-import torch.distributed as dist
-
-if dist.is_available():
-    import torch.distributed.distributed_c10d as c10d
-
-    from torch._C._distributed_c10d import (
-        _dequeue_c10d_event,
-        _enable_event_collection,
-        EventKind,
-    )
-
-__all__ = [
-    "CollectiveStatus",
-    "COLLECTIVE_HOOK_TYPE",
-    "PG_HOOK_TYPE",
-    "register_collective_start_hook",
-    "register_collective_end_hook",
-    "register_process_group_hook",
-]
-
-
-@dataclass
-class CollectiveStatus:
-    r"""Status of a collective operation.
-
-    Drop count indicates events have been dropped at the producer side, which means they
-    were not consumed fast enough by hooks.
-    """
-    pg_name: str = "unknown"  # This name matches the one informed in the pgreg cb
-    backend: str = "unknown"  # Name of the backend used
-    sequence_number: int = -1  # This name matches the one informed in the pgreg cb
-    operation: str = "unknown"  # collective name
-    timestamp: int = 0  # timestamp of the earliest time we noticed this event
-    duration: Optional[float] = None  # value in milliseconds it took executing
-    drop_count: int = 0  # number of events dropped following this one
-
-
-COLLECTIVE_HOOK_TYPE = Callable[[CollectiveStatus], None]
-PG_HOOK_TYPE = Callable[[dist.ProcessGroup, str], None]
-
-# This controls the number of internal failures we'll tolerate before giving up
-_MAX_INTERNAL_FAILURES = 10
-
-logger = logging.getLogger(__name__)
-_cb_thread: Optional[threading.Thread] = None
-_start_callbacks: List[COLLECTIVE_HOOK_TYPE] = []
-_end_callbacks: List[COLLECTIVE_HOOK_TYPE] = []
-_pp_r = -1
-_pp_w = -1
-
-
-def _c10d_pg_hooks_loops():
-    internal_failures = 0
-    while True:
-        # we don't care about the result, this is how we implement notification
-        _ = os.read(_pp_r, 1)
-        evt: Dict[str, object] = _dequeue_c10d_event()
-        try:
-            event_kind = evt.pop("event_kind", None)
-            if event_kind is None:
-                logger.warning(
-                    "c10d returned event dictionary %s without 'event_kind' key, cannot dispatch",
-                    evt,
-                )
-                internal_failures += 1
-                if internal_failures >= _MAX_INTERNAL_FAILURES:
-                    logger.warning(
-                        "too many internal c10d failures processing callback loop. stopping"
-                    )
-                    return
-                time.sleep(1)
-                continue
-
-            if event_kind == int(EventKind.START):  # type: ignore[call-overload]
-                cb_list = _start_callbacks
-            elif event_kind == int(EventKind.END):  # type: ignore[call-overload]
-                cb_list = _end_callbacks
-            else:
-                logger.warning(
-                    "c10d event %s with invalid 'event_kind' with value %d",
-                    evt,
-                    event_kind,
-                )
-                internal_failures += 1
-                if internal_failures >= _MAX_INTERNAL_FAILURES:
-                    logger.warning(
-                        "too many internal c10d failures processing callback loop. stopping"
-                    )
-                    return
-                time.sleep(1)
-                continue
-
-            status = CollectiveStatus(**evt)  # type: ignore[arg-type]
-            for cb in cb_list:
-                try:
-                    cb(status)
-                except Exception as e:
-                    logger.info(
-                        "c10d event callback %s with event %s threw exception %s",
-                        cb,
-                        status,
-                        e,
-                    )
-        except Exception as e:
-            # We have to keep processing otherwise the queue will grow infinitely large
-            logger.warning(
-                "c10d callback thread when processing event %s raised exception %s.",
-                evt,
-                e,
-            )
-            internal_failures += 1
-            if internal_failures >= _MAX_INTERNAL_FAILURES:
-                logger.warning(
-                    "too many internal c10d failures processing callback loop. stopping"
-                )
-                return
-
-            # Sleep for a second to avoid hogging the GIL in case of a persistent failure
-            time.sleep(1)
-
-
-def _lazy_init():
-    global _cb_thread
-    if _cb_thread is not None:
-        return
-    global _pp_r
-    global _pp_w
-    _pp_r, _pp_w = os.pipe()
-    _enable_event_collection(_pp_w)
-    c10d._enable_collectives_timing()
-    _cb_thread = threading.Thread(target=_c10d_pg_hooks_loops, daemon=True)
-    _cb_thread.start()
-    logger.info("c10d::hooks thread enabled")
-
-
-def _check_distributed_available():
-    if not dist.is_available():
-        raise RuntimeError(
-            "torch.distributed is not available, so hooks are not available."
-        )
-
-
-def register_collective_start_hook(hook: COLLECTIVE_HOOK_TYPE) -> None:
-    r"""Register a hook that is called every time a collective starts.
-
-    The hook is invoked on a background thread.
-    Exceptions raised by the callback are ignored and non-fatal.
-    """
-    _check_distributed_available()
-    _start_callbacks.append(hook)
-    _lazy_init()
-
-
-def register_collective_end_hook(hook: COLLECTIVE_HOOK_TYPE) -> None:
-    r"""Register a hook that is called every time a collective finishes.
-
-    The hook is invoked on a background thread.
-    Exceptions raised by the callback are ignored and non-fatal.
-    """
-    _check_distributed_available()
-    _end_callbacks.append(hook)
-    _lazy_init()
-
-
-def register_process_group_hook(hook: PG_HOOK_TYPE) -> None:
-    r"""Register a hook that is called every time a process group is created on this rank.
-
-    This hook is only invoked if the current rank is part of the PG being created.
-    The pg_name is unique to the whole cluster and should be treated as an opaque identifier subject to change.
-
-    The hook is invoked on a background thread.
-    Exceptions raised by the callback are ignored and non-fatal.
-    """
-    _check_distributed_available()
-    c10d._register_creation_hook(hook)
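A minimal sketch of how the collective start/end hooks from the module deleted above were meant to be used. It assumes a pre-revert build where torch.distributed._hooks exists and a process group has already been initialized; the callback names are illustrative.

    # Sketch only: requires a build that still ships torch.distributed._hooks.
    import torch
    import torch.distributed as dist
    import torch.distributed._hooks as dhooks

    def on_start(status: dhooks.CollectiveStatus) -> None:
        # Runs on the hooks background thread, not the caller of the collective.
        print(f"start {status.operation} seq={status.sequence_number} pg={status.pg_name}")

    def on_end(status: dhooks.CollectiveStatus) -> None:
        # duration is reported by the backend when timing is enabled, else None;
        # drop_count > 0 means the producer queue overflowed after this event.
        print(f"end {status.operation} duration={status.duration} dropped={status.drop_count}")

    dhooks.register_collective_start_hook(on_start)
    dhooks.register_collective_end_hook(on_end)

    # After dist.init_process_group(...), any collective now emits events, e.g.:
    # dist.all_reduce(torch.ones(2, 3))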
@@ -421,19 +421,6 @@ _group_count = 0
 _tags_to_pg: Dict[str, List[ProcessGroup]] = {}
 _pg_to_tag: Dict[ProcessGroup, str] = {}
 
-class _HookState:
-    def __init__(self):
-        self.creation_hooks = []
-
-    def register_creation_hook(self, hook) -> None:
-        self.creation_hooks.append(hook)
-
-    def fire_creation_hook(self, pg, name) -> None:
-        for hook in self.creation_hooks:
-            try:
-                hook(pg, name)
-            except Exception as e:
-                logger.info("hook %s failed with %s", hook, e)
-
 class _World:
     """
@@ -447,8 +434,6 @@ class _World:
         self._default_pg = None
         self._pg_coalesce_state: Dict[ProcessGroup, List[Union[_CollOp, P2POp]]] = {}
         self._pg_default_device: Dict[ProcessGroup, torch.device] = {}
-        self._hook_state = _HookState()
-        self.enable_collectives_timing = False
 
     @property
     def default_pg(self):
@@ -558,9 +543,6 @@ class _World:
         )
         return config_info
 
-    @property
-    def pg_hook_state(self) -> _HookState:
-        return self._hook_state
-
 _world = _World()
 """Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it"""
@@ -1382,9 +1364,6 @@ def _new_process_group_helper(
     pg._set_group_name(group_name)
 
     _world.pg_backend_config[pg] = str(backend_config)
-    if _world.enable_collectives_timing:
-        pg._enable_collectives_timing()
-
     # "" is the default tag for user PGs
     if pg_tag in [None, ""]:
         pg_tag = f"ptd:{group_name}"
@@ -1394,8 +1373,6 @@
 
     _world.tags_to_pg.setdefault(pg_tag, []).append(pg)
     _world.pg_to_tag[pg] = pg_tag
-    _world.pg_hook_state.fire_creation_hook(pg, group_name)
-
     return pg, prefix_store
 
 def destroy_process_group(group: Optional[ProcessGroup] = None):
@@ -4338,11 +4315,3 @@ dynamo_unsupported_distributed_c10d_ops = [
     reduce_scatter_tensor,
     send,
 ]
-
-def _register_creation_hook(hook):
-    _world.pg_hook_state.register_creation_hook(hook)
-
-def _enable_collectives_timing():
-    _world.enable_collectives_timing = True
-    for pg in _world.pg_map:
-        pg._enable_collectives_timing()
@@ -401,11 +401,6 @@ class WorldData:
 class ThreadLocalWorld:
     _world = threading.local()
 
-    def __init__(self):
-        self.enable_collectives_timing = False
-        self._hook_state = dist.distributed_c10d._HookState()
-
-
     def _get_world(self) -> WorldData:
         if not hasattr(ThreadLocalWorld._world, "world"):
             ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {}, {}, {})
@@ -459,9 +454,6 @@ class ThreadLocalWorld:
     def pg_default_device(self) -> Dict[dist.ProcessGroup, torch.device]:
         return self._get_world().pg_default_device
 
-    @property
-    def pg_hook_state(self) -> dist.distributed_c10d._HookState:
-        return self._hook_state
-
 _old_pg_world = None
 _ctx_manager = None