mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
dist2: add group context manager (#157988)
This adds a new context-manager-based process-group management API to dist2. It allows managing the active process group in much the same way as a CUDA stream: ```py with dist2.process_group(pg): dist2.current_process_group().allreduce(...).wait() ``` which mirrors ```py with torch.cuda.stream(stream): torch.cuda.current_stream().synchronize() ``` Test plan: ``` pytest test/distributed/test_dist2.py -k context ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/157988 Approved by: https://github.com/fduwjj
This commit is contained in:
committed by
PyTorch MergeBot
parent
fca7013f85
commit
83700b4488
@ -11,7 +11,34 @@ from torch.testing._internal.common_distributed import (
|
||||
requires_nccl,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
class ProcessGroupTest(TestCase):
    def test_context_manager(self):
        """Check that dist2.process_group() installs and restores the
        thread-local current process group, including when nested."""
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"

        # Two independent single-rank gloo groups on CPU so nesting is observable.
        pg1, pg2 = (
            dist2.new_group(
                backend="gloo",
                timeout=timedelta(seconds=60),
                device="cpu",
                pg_options=None,
            )
            for _ in range(2)
        )

        # No group is active before any context is entered.
        self.assertIsNone(dist2.current_process_group())

        with dist2.process_group(pg1):
            self.assertIs(dist2.current_process_group(), pg1)

            with dist2.process_group(pg2):
                self.assertIs(dist2.current_process_group(), pg2)

            # Leaving the inner context restores the outer group...
            self.assertIs(dist2.current_process_group(), pg1)

        # ...and leaving the outer context restores "no group".
        self.assertIsNone(dist2.current_process_group())
|
||||
|
||||
|
||||
class ProcessGroupGlooTest(MultiProcessTestCase):
|
||||
|
@ -801,3 +801,6 @@ class ProcessGroupXCCL(Backend):
|
||||
rank: int,
|
||||
size: int,
|
||||
): ...
|
||||
|
||||
# Thread-local accessors for the active process group; these bind the C++
# c10d setProcessGroup/currentProcessGroup helpers.
def _set_process_group(pg: ProcessGroup) -> None: ...
# NOTE(review): the test suite asserts this surfaces None before any
# _set_process_group call, so the return annotation presumably should be
# Optional[ProcessGroup] — verify against the stub file's imports.
def _current_process_group() -> ProcessGroup: ...
|
||||
|
@ -327,4 +327,13 @@ bool allow_inflight_collective_as_graph_input() {
|
||||
.allow_inflight_collective_as_graph_input();
|
||||
}
|
||||
|
||||
// Returns a mutable reference to the calling thread's "current process
// group" slot. Each thread has its own slot (thread_local) that starts as
// nullptr until setProcessGroup() is called on that thread.
c10::intrusive_ptr<ProcessGroup>& currentProcessGroup() {
  thread_local static c10::intrusive_ptr<ProcessGroup> pg = nullptr;
  return pg;
}
|
||||
|
||||
// Installs `pg` as the calling thread's current process group (taking
// ownership of the reference). Passing nullptr clears the thread's slot.
void setProcessGroup(c10::intrusive_ptr<ProcessGroup> pg) {
  currentProcessGroup() = std::move(pg);
}
|
||||
|
||||
} // namespace c10d
|
||||
|
@ -978,4 +978,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
std::optional<at::Device> bound_device_id_;
|
||||
};
|
||||
|
||||
// Thread local functions for managing the currently active process group.
// Returns a mutable reference to this thread's current-group slot
// (nullptr until a group has been set on this thread).
TORCH_API c10::intrusive_ptr<ProcessGroup>& currentProcessGroup();
// Replaces this thread's current group; a nullptr argument clears it.
TORCH_API void setProcessGroup(c10::intrusive_ptr<ProcessGroup> processGroup);
|
||||
|
||||
} // namespace c10d
|
||||
|
@ -2568,6 +2568,10 @@ Arguments:
|
||||
return ivalue.toCustomClass<::c10d::ProcessGroup>();
|
||||
});
|
||||
|
||||
// Thread local process group manipulation
|
||||
module.def("_set_process_group", &::c10d::setProcessGroup);
|
||||
module.def("_current_process_group", &::c10d::currentProcessGroup);
|
||||
|
||||
py::enum_<::c10d::ProcessGroup::BackendType>(
|
||||
processGroup,
|
||||
"BackendType",
|
||||
|
@ -7,11 +7,19 @@ This is an experimental new API for PyTorch Distributed. This is actively in dev
|
||||
This is intended as a proving ground for more flexible and object oriented distributed APIs.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from datetime import timedelta
|
||||
from typing import Protocol, Union
|
||||
|
||||
import torch
|
||||
from torch._C._distributed_c10d import Backend, ProcessGroup, Store
|
||||
from torch._C._distributed_c10d import (
|
||||
_current_process_group,
|
||||
_set_process_group,
|
||||
Backend,
|
||||
ProcessGroup,
|
||||
Store,
|
||||
)
|
||||
from torch.distributed.rendezvous import rendezvous
|
||||
|
||||
|
||||
@ -134,3 +142,30 @@ def new_group(
|
||||
store.set_timeout(timeout)
|
||||
|
||||
return _BACKENDS[backend](store, rank, world_size, timeout, device, pg_options)
|
||||
|
||||
|
||||
def current_process_group() -> Union[ProcessGroup, None]:
    """
    Get the current process group. Thread local method.

    Returns:
        The process group most recently installed on this thread via the
        ``process_group`` context manager (or ``_set_process_group``), or
        ``None`` if no process group is currently set on this thread.
    """
    # The underlying binding returns the thread-local slot, which is null
    # (surfaced as None) until a group has been set on this thread — the
    # previous ``-> ProcessGroup`` annotation was contradicted by
    # test_context_manager's assertIsNone check.
    return _current_process_group()
|
||||
|
||||
|
||||
@contextmanager
def process_group(pg: ProcessGroup) -> Generator[None, None, None]:
    """
    Context manager for process groups. Thread local method.

    Installs ``pg`` as the calling thread's current process group for the
    duration of the ``with`` block, then restores whatever group was active
    before — even if the body raises.

    Args:
        pg: The process group to use.
    """
    restore_to = current_process_group()
    _set_process_group(pg)
    try:
        yield
    finally:
        _set_process_group(restore_to)
|
||||
|
Reference in New Issue
Block a user