dist2: add group context manager (#157988)

This adds a new context-manager-based process group management API to dist2. It allows managing the active process group in much the same way as a CUDA stream:

```py
with dist2.process_group(pg):
   dist2.current_process_group().allreduce(...).wait()
```

matches

```py
with torch.cuda.stream(stream):
    torch.cuda.current_stream().synchronize()
```

Test plan:

```
pytest test/distributed/test_dist2.py -k context
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157988
Approved by: https://github.com/fduwjj
This commit is contained in:
Tristan Rice
2025-07-10 22:30:15 +00:00
committed by PyTorch MergeBot
parent fca7013f85
commit 83700b4488
6 changed files with 84 additions and 2 deletions

View File

@ -11,7 +11,34 @@ from torch.testing._internal.common_distributed import (
requires_nccl,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_utils import run_tests, TestCase
class ProcessGroupTest(TestCase):
    """Tests for dist2's thread-local process group context management."""

    def test_context_manager(self) -> None:
        # Single-rank gloo configuration so no real multi-process rendezvous
        # is required for this test.
        for key, value in {
            "RANK": "0",
            "WORLD_SIZE": "1",
            "MASTER_ADDR": "127.0.0.1",
            "MASTER_PORT": "29500",
        }.items():
            os.environ[key] = value

        group_a = dist2.new_group(
            backend="gloo",
            timeout=timedelta(seconds=60),
            device="cpu",
            pg_options=None,
        )
        group_b = dist2.new_group(
            backend="gloo",
            timeout=timedelta(seconds=60),
            device="cpu",
            pg_options=None,
        )

        # No group is active before any context is entered.
        self.assertIsNone(dist2.current_process_group())
        with dist2.process_group(group_a):
            self.assertIs(dist2.current_process_group(), group_a)
            with dist2.process_group(group_b):
                self.assertIs(dist2.current_process_group(), group_b)
            # Exiting the inner context restores the enclosing group,
            # mirroring torch.cuda.stream semantics.
            self.assertIs(dist2.current_process_group(), group_a)
        self.assertIsNone(dist2.current_process_group())
class ProcessGroupGlooTest(MultiProcessTestCase):

View File

@ -801,3 +801,6 @@ class ProcessGroupXCCL(Backend):
rank: int,
size: int,
): ...
# Thread-local accessors backing dist2.process_group() / current_process_group().
def _set_process_group(pg: ProcessGroup) -> None: ...
# Returns None until _set_process_group() has been called on the calling
# thread: the underlying C++ thread_local pointer starts as nullptr, and the
# accompanying test asserts a None result before any group is set.
def _current_process_group() -> ProcessGroup | None: ...

View File

@ -327,4 +327,13 @@ bool allow_inflight_collective_as_graph_input() {
.allow_inflight_collective_as_graph_input();
}
// Returns a mutable reference to the calling thread's active process group
// slot. The slot starts out empty (nullptr) on every thread and is only
// populated via setProcessGroup().
c10::intrusive_ptr<ProcessGroup>& currentProcessGroup() {
  static thread_local c10::intrusive_ptr<ProcessGroup> active{nullptr};
  return active;
}
// Installs `pg` as the calling thread's active process group, replacing
// whatever was previously stored in the thread-local slot. Passing nullptr
// clears the slot.
void setProcessGroup(c10::intrusive_ptr<ProcessGroup> pg) {
  auto& slot = currentProcessGroup();
  slot = std::move(pg);
}
} // namespace c10d

View File

@ -978,4 +978,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
std::optional<at::Device> bound_device_id_;
};
// Thread local functions for managing the currently active process group.
TORCH_API c10::intrusive_ptr<ProcessGroup>& currentProcessGroup();
TORCH_API void setProcessGroup(c10::intrusive_ptr<ProcessGroup> processGroup);
} // namespace c10d

View File

@ -2568,6 +2568,10 @@ Arguments:
return ivalue.toCustomClass<::c10d::ProcessGroup>();
});
// Thread local process group manipulation
module.def("_set_process_group", &::c10d::setProcessGroup);
module.def("_current_process_group", &::c10d::currentProcessGroup);
py::enum_<::c10d::ProcessGroup::BackendType>(
processGroup,
"BackendType",

View File

@ -7,11 +7,19 @@ This is an experimental new API for PyTorch Distributed. This is actively in dev
This is intended as a proving ground for more flexible and object oriented distributed APIs.
"""
from collections.abc import Generator
from contextlib import contextmanager
from datetime import timedelta
from typing import Optional, Protocol, Union

import torch
from torch._C._distributed_c10d import (
    _current_process_group,
    _set_process_group,
    Backend,
    ProcessGroup,
    Store,
)
from torch.distributed.rendezvous import rendezvous
@ -134,3 +142,30 @@ def new_group(
store.set_timeout(timeout)
return _BACKENDS[backend](store, rank, world_size, timeout, device, pg_options)
def current_process_group() -> ProcessGroup:
"""
Get the current process group. Thread local method.
Returns:
The current process group.
"""
return _current_process_group()
@contextmanager
def process_group(pg: ProcessGroup) -> Generator[None, None, None]:
    """
    Context manager that makes ``pg`` the active process group for the
    calling thread, restoring the previously active group on exit.
    Thread local method.

    Args:
        pg: The process group to activate within the ``with`` block.
    """
    previous = current_process_group()
    _set_process_group(pg)
    try:
        yield
    finally:
        # Always restore the prior group, even if the body raised.
        _set_process_group(previous)