dist2: add group context manager (#157988)

This adds new context manager based PG management to dist2. This allows for managing the active process group much in the same way as a stream

```py
with dist2.process_group(pg):
   dist2.current_process_group().allreduce(...).wait()
```

matches

```py
with torch.cuda.stream(stream):
    torch.cuda.current_stream().synchronize()
```

Test plan:

```
pytest test/distributed/test_dist2.py -k context
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157988
Approved by: https://github.com/fduwjj
This commit is contained in:
Tristan Rice
2025-07-10 22:30:15 +00:00
committed by PyTorch MergeBot
parent fca7013f85
commit 83700b4488
6 changed files with 84 additions and 2 deletions

View File

@ -978,4 +978,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
std::optional<at::Device> bound_device_id_;
};
// Thread local functions for managing the currently active process group.
TORCH_API c10::intrusive_ptr<ProcessGroup>& currentProcessGroup();
TORCH_API void setProcessGroup(c10::intrusive_ptr<ProcessGroup> processGroup);
} // namespace c10d