PyProcessGroup: support rank, world size, group name/desc overrides (#141529)

This improves `PyProcessGroup` so that the rank, world size, and group name/desc methods can be overridden from Python. These overrides are needed to support resizable process groups in torchft.
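For illustration, here is a minimal sketch of what this enables from Python (the `ResizablePG` name and body are hypothetical; the `DummyAttrProcessGroup` test added in the diff below exercises the same surface):

```python
import torch.distributed as dist


class ResizablePG(dist.ProcessGroup):
    """Hypothetical sketch: rank/size/backend name come from overridable
    Python methods instead of the fixed C++ rank_/size_ fields."""

    def __init__(self, rank, world_size):
        super().__init__(rank, world_size)
        self._rank = rank
        self._world_size = world_size

    def getRank(self):
        return self._rank

    def getSize(self):
        # Re-queried via virtual dispatch on every call, so it can change
        # over the life of the group -- what resizable PGs in torchft need.
        return self._world_size

    def getBackendName(self):
        return "resizable"


pg = ResizablePG(0, 1)
pg._world_size = 4              # e.g. after a resize event
assert pg.size() == 4           # size() now routes through getSize()
assert pg.name() == "resizable" # name() routes through getBackendName()
```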

This also includes some small fixes in `test_c10d_pypg.py` to use threads instead of processes, which speeds up test execution by ~10x.
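For reference, the thread-based pattern looks roughly like this (`ExampleThreadedTest` is a hypothetical name; the actual change is in the diff below):

```python
# Sketch of the thread-based test pattern this PR switches to: every rank
# runs the same TestCase body in a thread of one process, avoiding the
# per-rank interpreter startup cost of MultiProcessTestCase.
from torch.testing._internal.common_distributed import MultiThreadedTestCase
from torch.testing._internal.common_utils import run_tests


class ExampleThreadedTest(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 1

    def setUp(self):
        super().setUp()
        self._spawn_threads()  # replaces MultiProcessTestCase._spawn_processes()

    def test_rank_is_in_range(self):
        # Each thread sees its own self.rank.
        self.assertTrue(0 <= self.rank < self.world_size)


if __name__ == "__main__":
    run_tests()
```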

Test plan:

```
pytest test/distributed/test_c10d_pypg.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141529
Approved by: https://github.com/fegin
Author: Tristan Rice (2024-11-26 20:56:54 +00:00)
Committed by: PyTorch MergeBot
Parent: 5696df439b
Commit: 9f4f061f89
4 changed files with 99 additions and 22 deletions

`test/distributed/test_c10d_common.py`:

```diff
@@ -323,7 +323,7 @@ class CommonDistributedDataParallelTest:
         # Use this hack to remove files for that test
         try:
             os.remove(self.file_name)
-        except OSError:
+        except (OSError, AttributeError):
             pass
 
     @property
```

`test/distributed/test_c10d_pypg.py`:

```diff
@@ -1,6 +1,5 @@
 # Owner(s): ["oncall: distributed"]
 
-import os
 import weakref
 
 import test_c10d_common
@@ -11,8 +10,8 @@ import torch.nn as nn
 from torch._C._distributed_c10d import _create_work_from_future
 from torch.futures import Future
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.testing._internal.common_distributed import MultiProcessTestCase
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_distributed import MultiThreadedTestCase
+from torch.testing._internal.common_utils import run_tests, TestCase
 
 
 def create_work(result):
@@ -80,7 +79,7 @@ class LonelyRankProcessGroup(dist.ProcessGroup):
         self._work.append(res)
         return res
 
-    def size(self):
+    def getSize(self):
         return self._world
 
     def getBackendName(self):
@@ -90,23 +89,39 @@ class LonelyRankProcessGroup(dist.ProcessGroup):
         return f"PLG w:{self._world} r:{self._rank}"
 
 
+class DummyAttrProcessGroup(dist.ProcessGroup):
+    def getRank(self):
+        return 123
+
+    def getSize(self):
+        return 456
+
+    def getBackendName(self):
+        return "dummy-attr"
+
+    def setGroupName(self, name) -> None:
+        self._group_name = "py:" + name
+
+    def getGroupName(self) -> str:
+        return self._group_name
+
+    def setGroupDesc(self, group_desc) -> None:
+        self._group_desc = "py:" + group_desc
+
+    def getGroupDesc(self) -> str:
+        return self._group_desc
+
+
 # We cannot use parametrize as some tests are defined on the base class and use _get_process_group
 class AbstractDDPSingleRank(test_c10d_common.CommonDistributedDataParallelTest):
     def setUp(self):
         super().setUp()
-        self._spawn_processes()
+        self._spawn_threads()
 
     @property
     def world_size(self):
         return 1
 
-    def tearDown(self):
-        super().tearDown()
-        try:
-            os.remove(self.file_name)
-        except OSError:
-            pass
-
     def _get_process_group(self):
         return LonelyRankProcessGroup(self.rank, self.world_size, self.use_wrapper)
@@ -142,17 +157,31 @@ class AbstractDDPSingleRank(test_c10d_common.CommonDistributedDataParallelTest):
         )
 
 
-class TestDDPWithWorkSubclass(AbstractDDPSingleRank, MultiProcessTestCase):
+class TestDDPWithWorkSubclass(AbstractDDPSingleRank, MultiThreadedTestCase):
     @property
     def use_wrapper(self):
         return False
 
 
-class TestDDPWithWorkWrapper(AbstractDDPSingleRank, MultiProcessTestCase):
+class TestDDPWithWorkWrapper(AbstractDDPSingleRank, MultiThreadedTestCase):
     @property
     def use_wrapper(self):
         return True
 
 
+class TestPyProcessGroup(TestCase):
+    def test_attr_overrides(self):
+        pg = DummyAttrProcessGroup(0, 1)
+        self.assertEqual(pg.name(), "dummy-attr")
+        self.assertEqual(pg.rank(), 123)
+        self.assertEqual(pg.size(), 456)
+
+        pg._set_group_name("name")
+        self.assertEqual(pg.group_name, "py:name")
+
+        pg._set_group_desc("desc")
+        self.assertEqual(pg.group_desc, "py:desc")
+
+
 if __name__ == "__main__":
     run_tests()
```

`torch/csrc/distributed/c10d/ProcessGroup.hpp`:

```diff
@@ -125,11 +125,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
       int size);
 
   ~ProcessGroup() override;
 
-  int getRank() const {
+  virtual int getRank() const {
     return rank_;
   }
 
-  int getSize() const {
+  virtual int getSize() const {
     return size_;
   }
@@ -863,10 +863,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
     return getDefaultBackend()->hasHooks();
  }
 
-  const std::string& getGroupName() const;
-  void setGroupName(const std::string& name);
-  const std::string& getGroupDesc() const;
-  void setGroupDesc(const std::string& name);
+  virtual const std::string& getGroupName() const;
+  virtual void setGroupName(const std::string& name);
+  virtual const std::string& getGroupDesc() const;
+  virtual void setGroupDesc(const std::string& name);
 
   void enableCollectivesTiming();
 
   void release_resources() override;
```

`torch/csrc/distributed/c10d/PyProcessGroup.hpp`:

```diff
@@ -59,13 +59,61 @@ class PyProcessGroup : public ProcessGroup {
   using ProcessGroup::ProcessGroup;
 
   const std::string getBackendName() const override {
-    PYBIND11_OVERRIDE_PURE(
+    PYBIND11_OVERRIDE(
         std::string, /* Return type */
         ProcessGroup, /* Parent class */
         getBackendName, /* Name of function in C++ */
     );
   }
 
+  int getRank() const override {
+    PYBIND11_OVERRIDE(
+        int, /* Return type */
+        ProcessGroup, /* Parent class */
+        getRank, /* Name of function in C++ */
+    );
+  }
+
+  int getSize() const override {
+    PYBIND11_OVERRIDE(
+        int, /* Return type */
+        ProcessGroup, /* Parent class */
+        getSize, /* Name of function in C++ */
+    );
+  }
+
+  const std::string& getGroupName() const override {
+    PYBIND11_OVERRIDE(
+        const std::string&, /* Return type */
+        ProcessGroup, /* Parent class */
+        getGroupName, /* Name of function in C++ */
+    );
+  }
+
+  void setGroupName(const std::string& group_name) override {
+    PYBIND11_OVERRIDE(
+        void, /* Return type */
+        ProcessGroup, /* Parent class */
+        setGroupName, /* Name of function in C++ */
+        group_name);
+  }
+
+  const std::string& getGroupDesc() const override {
+    PYBIND11_OVERRIDE(
+        const std::string&, /* Return type */
+        ProcessGroup, /* Parent class */
+        getGroupDesc, /* Name of function in C++ */
+    );
+  }
+
+  void setGroupDesc(const std::string& group_desc) override {
+    PYBIND11_OVERRIDE(
+        void, /* Return type */
+        ProcessGroup, /* Parent class */
+        setGroupDesc, /* Name of function in C++ */
+        group_desc);
+  }
+
   c10::intrusive_ptr<Work> allgather(
       std::vector<std::vector<at::Tensor>>& outputTensors,
       std::vector<at::Tensor>& inputTensors,
```
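One behavioral note on the trampoline above: `PYBIND11_OVERRIDE` falls back to the base C++ implementation when no Python override exists, whereas the previous `PYBIND11_OVERRIDE_PURE` on `getBackendName` would raise. So a Python subclass that overrides nothing still gets the C++ defaults, roughly like this (a sketch; `PlainPG` is an illustrative name):

```python
import torch.distributed as dist


class PlainPG(dist.ProcessGroup):
    pass  # no Python overrides


pg = PlainPG(2, 8)
# With PYBIND11_OVERRIDE, the trampoline falls back to the base C++
# implementations, which return the stored rank_ and size_ fields.
assert pg.rank() == 2
assert pg.size() == 8
```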