[c10d] Fix new_subgroups(group=) bug (#153798)

Summary: The bug, introduced in https://github.com/pytorch/pytorch/pull/152765, was caused by passing the `group` parameter to the `get_rank()` function, which made the function return the process's rank *within the passed subgroup* instead of its global rank; since `ranks_in_subgroup` contains global ranks, the membership check could then fail (or match the wrong rank). The fix removes the `group` argument from the `get_rank()` call so the global rank is compared.

Test Plan: contbuild & OSS CI

Differential Revision: D74964213

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153798
Approved by: https://github.com/Skylion007
This commit is contained in:
Tsung-Hsien Lee
2025-05-19 17:01:10 +00:00
committed by PyTorch MergeBot
parent b0e5402377
commit 6487ea30b3
2 changed files with 35 additions and 1 deletion

View File

@ -5455,7 +5455,7 @@ def new_subgroups(
)
subgroups.append(subgroup)
if rank := get_rank(group=group) in ranks_in_subgroup:
if rank := get_rank() in ranks_in_subgroup:
cur_subgroup = subgroup
logger.info("Rank %s is assigned to subgroup %s", rank, ranks_in_subgroup)

View File

@ -915,6 +915,40 @@ class DistributedTest:
for subgroup in subgroups:
dist.destroy_process_group(subgroup)
@skip_but_pass_in_sandcastle_if(
    BACKEND not in DistTestCases.backend_feature["subgroup"],
    f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
)
@require_world_size(4)
@skip_if_lt_x_gpu(4)
def test_new_subgroups_with_group_param(self):
    """Verify that ``dist.new_subgroups(group=...)`` splits the given
    subgroup (rather than the default/world group) and assigns this rank
    to the correct singleton sub-subgroup.

    Regression test for the ``get_rank(group=...)`` membership bug fixed
    in this commit.
    """
    self._init_global_test()
    # Map one GPU device per rank for this backend.
    init_multigpu_helper(dist.get_world_size(), BACKEND)
    # Partition the 4-rank world into two disjoint subgroups.
    cur_subgroup, subgroups = dist.new_subgroups_by_enumeration(
        ranks_per_subgroup_list=[[0, 2], [1, 3]]
    )
    # Split the subgroup this rank belongs to into size-1 groups.
    cur_sub_subgroup, sub_subgroups = dist.new_subgroups(
        group_size=1, group=cur_subgroup
    )
    # A 2-rank parent split into size-1 groups yields exactly two groups.
    self.assertEqual(len(sub_subgroups), 2)
    # The group handed back to this rank is a singleton...
    self.assertEqual(cur_sub_subgroup.size(), 1)
    # ...and this rank is actually a member of it.
    self.assertFalse(dist._rank_not_in_group(group=cur_sub_subgroup))
    # Tear down every process group created above (sub-subgroups first,
    # then the parent subgroups — same order as creation cleanup).
    for pg in sub_subgroups + subgroups:
        dist.destroy_process_group(pg)
@skip_but_pass_in_sandcastle_if(
BACKEND not in DistTestCases.backend_feature["subgroup"],
f"The {BACKEND} backend does not support creating subgroups on CUDA devices",