mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
PyProcessGroup: support rank, world size, group name/desc overrides (#141529)
This improves `PyProcessGroup` so you can override rank, world size and group name/desc methods from Python. These will be needed to support resizable process groups in torchft. This also has some small fixes in test_c10d_pypg.py to use threads instead of processes which speeds up the test execution by ~10x. Test plan: ``` pytest test/distributed/test_c10d_pypg.py ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/141529 Approved by: https://github.com/fegin
This commit is contained in:
committed by
PyTorch MergeBot
parent
5696df439b
commit
9f4f061f89
@ -323,7 +323,7 @@ class CommonDistributedDataParallelTest:
|
||||
# Use this hack to remove files for that test
|
||||
try:
|
||||
os.remove(self.file_name)
|
||||
except OSError:
|
||||
except (OSError, AttributeError):
|
||||
pass
|
||||
|
||||
@property
|
||||
|
@ -1,6 +1,5 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import os
|
||||
import weakref
|
||||
|
||||
import test_c10d_common
|
||||
@ -11,8 +10,8 @@ import torch.nn as nn
|
||||
from torch._C._distributed_c10d import _create_work_from_future
|
||||
from torch.futures import Future
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.testing._internal.common_distributed import MultiProcessTestCase
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.common_distributed import MultiThreadedTestCase
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
def create_work(result):
|
||||
@ -80,7 +79,7 @@ class LonelyRankProcessGroup(dist.ProcessGroup):
|
||||
self._work.append(res)
|
||||
return res
|
||||
|
||||
def size(self):
|
||||
def getSize(self):
|
||||
return self._world
|
||||
|
||||
def getBackendName(self):
|
||||
@ -90,23 +89,39 @@ class LonelyRankProcessGroup(dist.ProcessGroup):
|
||||
return f"PLG w:{self._world} r:{self._rank}"
|
||||
|
||||
|
||||
class DummyAttrProcessGroup(dist.ProcessGroup):
|
||||
def getRank(self):
|
||||
return 123
|
||||
|
||||
def getSize(self):
|
||||
return 456
|
||||
|
||||
def getBackendName(self):
|
||||
return "dummy-attr"
|
||||
|
||||
def setGroupName(self, name) -> None:
|
||||
self._group_name = "py:" + name
|
||||
|
||||
def getGroupName(self) -> str:
|
||||
return self._group_name
|
||||
|
||||
def setGroupDesc(self, group_desc) -> None:
|
||||
self._group_desc = "py:" + group_desc
|
||||
|
||||
def getGroupDesc(self) -> str:
|
||||
return self._group_desc
|
||||
|
||||
|
||||
# We cannot use parametrize as some tests are defined on the base class and use _get_process_group
|
||||
class AbstractDDPSingleRank(test_c10d_common.CommonDistributedDataParallelTest):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._spawn_processes()
|
||||
self._spawn_threads()
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return 1
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
try:
|
||||
os.remove(self.file_name)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _get_process_group(self):
|
||||
return LonelyRankProcessGroup(self.rank, self.world_size, self.use_wrapper)
|
||||
|
||||
@ -142,17 +157,31 @@ class AbstractDDPSingleRank(test_c10d_common.CommonDistributedDataParallelTest):
|
||||
)
|
||||
|
||||
|
||||
class TestDDPWithWorkSubclass(AbstractDDPSingleRank, MultiProcessTestCase):
|
||||
class TestDDPWithWorkSubclass(AbstractDDPSingleRank, MultiThreadedTestCase):
|
||||
@property
|
||||
def use_wrapper(self):
|
||||
return False
|
||||
|
||||
|
||||
class TestDDPWithWorkWrapper(AbstractDDPSingleRank, MultiProcessTestCase):
|
||||
class TestDDPWithWorkWrapper(AbstractDDPSingleRank, MultiThreadedTestCase):
|
||||
@property
|
||||
def use_wrapper(self):
|
||||
return True
|
||||
|
||||
|
||||
class TestPyProcessGroup(TestCase):
|
||||
def test_attr_overrides(self):
|
||||
pg = DummyAttrProcessGroup(0, 1)
|
||||
self.assertEqual(pg.name(), "dummy-attr")
|
||||
self.assertEqual(pg.rank(), 123)
|
||||
self.assertEqual(pg.size(), 456)
|
||||
|
||||
pg._set_group_name("name")
|
||||
self.assertEqual(pg.group_name, "py:name")
|
||||
|
||||
pg._set_group_desc("desc")
|
||||
self.assertEqual(pg.group_desc, "py:desc")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -125,11 +125,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
int size);
|
||||
~ProcessGroup() override;
|
||||
|
||||
int getRank() const {
|
||||
virtual int getRank() const {
|
||||
return rank_;
|
||||
}
|
||||
|
||||
int getSize() const {
|
||||
virtual int getSize() const {
|
||||
return size_;
|
||||
}
|
||||
|
||||
@ -863,10 +863,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
return getDefaultBackend()->hasHooks();
|
||||
}
|
||||
|
||||
const std::string& getGroupName() const;
|
||||
void setGroupName(const std::string& name);
|
||||
const std::string& getGroupDesc() const;
|
||||
void setGroupDesc(const std::string& name);
|
||||
virtual const std::string& getGroupName() const;
|
||||
virtual void setGroupName(const std::string& name);
|
||||
virtual const std::string& getGroupDesc() const;
|
||||
virtual void setGroupDesc(const std::string& name);
|
||||
void enableCollectivesTiming();
|
||||
|
||||
void release_resources() override;
|
||||
|
@ -59,13 +59,61 @@ class PyProcessGroup : public ProcessGroup {
|
||||
using ProcessGroup::ProcessGroup;
|
||||
|
||||
const std::string getBackendName() const override {
|
||||
PYBIND11_OVERRIDE_PURE(
|
||||
PYBIND11_OVERRIDE(
|
||||
std::string, /* Return type */
|
||||
ProcessGroup, /* Parent class */
|
||||
getBackendName, /* Name of function in C++ */
|
||||
);
|
||||
}
|
||||
|
||||
int getRank() const override {
|
||||
PYBIND11_OVERRIDE(
|
||||
int, /* Return type */
|
||||
ProcessGroup, /* Parent class */
|
||||
getRank, /* Name of function in C++ */
|
||||
);
|
||||
}
|
||||
|
||||
int getSize() const override {
|
||||
PYBIND11_OVERRIDE(
|
||||
int, /* Return type */
|
||||
ProcessGroup, /* Parent class */
|
||||
getSize, /* Name of function in C++ */
|
||||
);
|
||||
}
|
||||
|
||||
const std::string& getGroupName() const override {
|
||||
PYBIND11_OVERRIDE(
|
||||
const std::string&, /* Return type */
|
||||
ProcessGroup, /* Parent class */
|
||||
getGroupName, /* Name of function in C++ */
|
||||
);
|
||||
}
|
||||
|
||||
void setGroupName(const std::string& group_name) override {
|
||||
PYBIND11_OVERRIDE(
|
||||
void, /* Return type */
|
||||
ProcessGroup, /* Parent class */
|
||||
setGroupName, /* Name of function in C++ */
|
||||
group_name);
|
||||
}
|
||||
|
||||
const std::string& getGroupDesc() const override {
|
||||
PYBIND11_OVERRIDE(
|
||||
const std::string&, /* Return type */
|
||||
ProcessGroup, /* Parent class */
|
||||
getGroupDesc, /* Name of function in C++ */
|
||||
);
|
||||
}
|
||||
|
||||
void setGroupDesc(const std::string& group_desc) override {
|
||||
PYBIND11_OVERRIDE(
|
||||
void, /* Return type */
|
||||
ProcessGroup, /* Parent class */
|
||||
setGroupDesc, /* Name of function in C++ */
|
||||
group_desc);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<Work> allgather(
|
||||
std::vector<std::vector<at::Tensor>>& outputTensors,
|
||||
std::vector<at::Tensor>& inputTensors,
|
||||
|
Reference in New Issue
Block a user