Fake process group: direct construction error (#163665)

Fixes #162129. Added validation in `_rank_not_in_group()` to check that a `FakeProcessGroup` has been properly initialized before use, raising a clear error message if `torch.distributed.init_process_group(backend='fake')` hasn't been called first.
This prevents silent failures and ensures proper dispatch-system integration for all distributed operations.

Added a test case, `test_fake_process_group_direct_usage_error()`, that verifies the error is raised for the `all_reduce` and `all_to_all_single` operations.

Please let me know if additional distributed operators should be tested or if any other updates are needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163665
Approved by: https://github.com/ezyang
ankushwahaRH
2025-10-02 22:19:23 +00:00
committed by PyTorch MergeBot
parent a34797e031
commit ece5e0f01b
6 changed files with 75 additions and 12 deletions


@@ -3812,18 +3812,29 @@ such as `dist.all_reduce(tensor, async_op=True)`.
       "error_on_collective",
       &::c10d::FakeProcessGroup::Options::error_on_collective);
   fakeProcessGroup
-      .def(
-          py::init([](int rank,
-                      int size,
-                      c10::intrusive_ptr<::c10d::FakeProcessGroup::Options>
-                          options) {
-            return c10::make_intrusive<::c10d::FakeProcessGroup>(
+      .def_static(
+          "_create_internal",
+          [](int rank,
+             int size,
+             c10::intrusive_ptr<::c10d::FakeProcessGroup::Options> options) {
+            return ::c10d::FakeProcessGroup::_create_internal(
                rank, size, std::move(options));
-          }),
+          },
           py::arg("rank"),
           py::arg("world_size"),
           py::arg("options") =
               c10::make_intrusive<::c10d::FakeProcessGroup::Options>())
+      .def(
+          "__init__",
+          [](const py::object&,
+             const py::args& args,
+             const py::kwargs& kwargs) {
+            TORCH_CHECK(
+                false,
+                "FakeProcessGroup cannot be constructed directly. "
+                "Use torch.distributed.init_process_group(backend='fake') instead to ensure "
+                "proper dispatch system integration.");
+          })
       .def_property_readonly(
           "options", &::c10d::FakeProcessGroup::getBackendOptions);
   auto fakeWork =