This reverts commit 314a502eb04c6382e2cc9af0573533efba54109d.

Changes since the original PR:

Reland 1
* rename torch.distributed.hooks to torch.distributed._hooks

Reland 2
* make _hooks importable even if !distributed.is_available()
* handle an intermittent CUDA driver-exit failure caused by new CUDA API usage in the callback caller (see the previous PR in the stack)

(original PR https://github.com/pytorch/pytorch/pull/108815 description copied below)

Expose a set of observability hooks into C10D so that users can detect collective failures both faster and more easily. The design is similar to NCCL desync debug in that it minimizes overhead by doing most of the work off the main thread.

This PR introduces a new module, torch.distributed.hooks, that exposes the following methods:

* register_collective_start_hook
* register_collective_end_hook
* register_process_group_hook

The process group hook reports PG creation on the member ranks and is called inline from the PG creation code. This is fine since it happens during initialization and only a limited number of times.

The collective start/end hooks are fired from a single background thread, which reads events from a C++ queue and dispatches them to the registered callbacks. Queue notification is done using a pipe; this is needed so Python can abort the thread on shutdown while keeping it a background thread, which is not possible with more conventional choices like a condition variable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111072
Approved by: https://github.com/malfet
ghstack dependencies: #111061
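For orientation, here is a minimal usage sketch of the hook API described above. It is assembled only from the calls and status fields exercised by the test file below (pg_name, backend, operation, sequence_number, timestamp, drop_count); the callback names are illustrative, and anything beyond the fields asserted on in the tests should be treated as an assumption rather than documented API.

import torch.distributed._hooks as dhooks  # private module, per the reland rename


def on_pg_created(pg, pg_name):
    # Called inline from process-group creation, on member ranks only.
    print(f"pg created: {pg_name}")


def on_collective_start(status):
    # Fired from the single background dispatch thread.
    print(f"start: {status.operation} pg={status.pg_name} seq={status.sequence_number}")


def on_collective_end(status):
    # drop_count > 0 means the internal event queue overflowed and events were lost.
    print(f"end: {status.operation} backend={status.backend} dropped={status.drop_count}")


dhooks.register_process_group_hook(on_pg_created)
dhooks.register_collective_start_hook(on_collective_start)
dhooks.register_collective_end_hook(on_collective_end)

# Any subsequent collective (e.g. dist.all_reduce) on an initialized process
# group should then trigger the start/end callbacks asynchronously.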
271 lines
7.3 KiB
Python
# Owner(s): ["oncall: distributed"]

import os
import sys
import tempfile
import threading

from functools import partial, wraps

import torch
import torch.distributed as dist
import torch.distributed._hooks as dhooks

if not dist.is_available():
    print("torch.distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)


from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    skip_if_lt_x_gpu,
)

from torch.testing._internal.common_utils import run_tests, TestCase


class PgHooks(MultiProcessTestCase):
    @property
    def world_size(self) -> int:
        return 4

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def test_pg_hook(self):
        pgs = []

        def pg_hook(pg, pg_name):
            pgs.append((pg, pg_name))

        dhooks.register_process_group_hook(pg_hook)
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )
        self.assertEqual(len(pgs), 1)
        self.assertEqual(pgs[0][0], dist.group.WORLD)

        # Create two partial-world PGs.
        pg0 = dist.new_group(ranks=[0, 1])
        pg1 = dist.new_group(ranks=[2, 3])

        # Each rank only observes two PGs being created: the default PG and the one covering its ranks.
        # We don't emit events for PG creation if the current rank doesn't belong to it.
        # For example, if you're rank 1, you'll get an event for pg0 but not pg1, even though the API contract
        # dictates you need to call new_group for both.
        self.assertEqual(len(pgs), 2)
        self.assertEqual(pgs[1][0], pg0 if self.rank < 2 else pg1)


def with_comms(func=None):
    if func is None:
        return partial(
            with_comms,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        self.init_comms()
        func(self, *args, **kwargs)
        self.destroy_comms()

    return wrapper


class CollectiveHooks:
    @property
    def world_size(self) -> int:
        return 4

    def _collective_hooks(self):
        # It's ok to access these lists directly since a single background thread pokes at them.
        starts = []
        ends = []
        cv = threading.Condition()

        def coll_start(status):
            starts.append(status)
            print(f"col_start {len(starts)} rank{self.rank}")

        def coll_end(status):
            ends.append(status)
            print(f"col_end {len(ends)} rank{self.rank}")
            # Two collectives are issued below; wake the waiter once both have finished.
            if len(ends) == 2:
                with cv:
                    cv.notify()

        dhooks.register_collective_start_hook(coll_start)
        dhooks.register_collective_end_hook(coll_end)

        tensor = torch.ones([2, 3]).to(self.device) * self.rank
        tensor_list = [torch.empty_like(tensor) for _ in range(self.world_size)]

        dist.all_gather(tensor_list, tensor)

        tensor2 = torch.ones([2, 3]).to(self.device) * self.rank
        dist.all_reduce(tensor2)

        with cv:
            cv.wait(1)

        default_pg_name = dist.group.WORLD.group_name
        self.assertEqual(2, len(starts))
        self.assertEqual(2, len(ends))

        def check_op(idx, coll_name):
            self.assertEqual(default_pg_name, starts[idx].pg_name)
            self.assertEqual(self.backend_name, starts[idx].backend)
            self.assertGreaterEqual(starts[idx].sequence_number, 0)
            self.assertGreaterEqual(starts[idx].timestamp, 0)
            self.assertEqual(coll_name, starts[idx].operation)

            self.assertEqual(default_pg_name, ends[idx].pg_name)
            self.assertEqual(self.backend_name, ends[idx].backend)

            self.assertEqual(starts[idx].sequence_number, ends[idx].sequence_number)
            self.assertLessEqual(starts[idx].timestamp, ends[idx].timestamp)
            self.assertEqual(coll_name, ends[idx].operation)

        check_op(0, "ALLGATHER")
        check_op(1, "ALLREDUCE")


class GlooHooks(MultiProcessTestCase, CollectiveHooks):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def init_comms(self):
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    def destroy_comms(self):
        dist.destroy_process_group()

    @property
    def backend_name(self):
        return "gloo"

    @property
    def device(self):
        return "cpu"

    @with_comms
    def test_collective_hooks(self):
        self._collective_hooks()


class NcclHooks(MultiProcessTestCase, CollectiveHooks):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def init_comms(self):
        dist.init_process_group(
            backend="nccl",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    def destroy_comms(self):
        dist.destroy_process_group()

    @property
    def backend_name(self):
        return "nccl"

    @property
    def device(self):
        return f"cuda:{self.rank}"

    @skip_if_lt_x_gpu(4)
    @with_comms
    def test_collective_hooks(self):
        self._collective_hooks()


class SingleRankTests(TestCase):
    def setUp(self) -> None:
        super().setUp()
        self.rank = 0
        self.file_name = tempfile.NamedTemporaryFile(delete=False).name
        dist.init_process_group(
            backend="gloo",
            rank=0,
            world_size=1,
            store=dist.FileStore(self.file_name, 1),
        )

    def tearDown(self) -> None:
        dist.destroy_process_group()

    def test_queue_overflow(self) -> None:
        cv_done_colls = threading.Condition()
        cv_done_cb = threading.Condition()
        colls_done = False
        starts = []
        status_with_dropped = None

        def coll_start(status: dhooks.CollectiveStatus):
            starts.append(status)
            # Block inside the start hook until all collectives have been issued,
            # so the internal event queue fills up and starts dropping events.
            with cv_done_colls:
                while not colls_done:
                    cv_done_colls.wait()
            if status.drop_count > 0:
                nonlocal status_with_dropped
                status_with_dropped = status
                with cv_done_cb:
                    cv_done_cb.notify()

        dhooks.register_collective_start_hook(coll_start)

        # The native queue limit is 512, so 600 collectives guarantee an overflow.
        for i in range(600):
            dist.all_reduce(torch.ones([2, 3]))
        colls_done = True
        with cv_done_colls:
            cv_done_colls.notify()

        with cv_done_cb:
            cv_done_cb.wait(10)

        self.assertTrue(status_with_dropped is not None)
        self.assertTrue(status_with_dropped.drop_count > 0)


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()