Revert "shrink_group implementation to expose ncclCommShrink API (#164518)"

This reverts commit fa0db212e717b6cb225159cb32ea3d83baa52381.

Reverted https://github.com/pytorch/pytorch/pull/164518 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/164518#issuecomment-3419893217))
PyTorch MergeBot
2025-10-19 19:20:44 +00:00
parent fa0db212e7
commit 633a3b7f67
11 changed files with 2 additions and 1503 deletions

View File

@@ -394,10 +394,6 @@ an opaque group handle that can be given as a `group` argument to all collective
.. autofunction:: new_group
```
```{eval-rst}
.. autofunction:: torch.distributed.distributed_c10d.shrink_group
```
```{eval-rst}
.. autofunction:: get_group_rank
```

View File

@@ -1,43 +0,0 @@
import logging
import time
_start_time = time.time()
_logger = logging.getLogger(__name__)
def _ts():
return time.time() - _start_time
def configure(level=logging.INFO, force=False):
try:
logging.basicConfig(
level=level,
format="%(asctime)s %(name)s %(levelname)s: %(message)s",
force=force,
)
except TypeError:
logging.basicConfig(
level=level, format="%(asctime)s %(name)s %(levelname)s: %(message)s"
)
def log_test_info(rank, message):
_logger.info("[%7.3fs][Rank %s] %s", _ts(), rank, message)
def log_test_success(rank, message):
_logger.info("[%7.3fs][Rank %s] ✅ %s", _ts(), rank, message)
def log_test_validation(rank, message):
_logger.info("[%7.3fs][Rank %s] ✓ %s", _ts(), rank, message)
def log_test_warning(rank, message):
_logger.warning("[%7.3fs][Rank %s] ⚠️ %s", _ts(), rank, message)
def log_test_error(rank, message):
_logger.error("[%7.3fs][Rank %s] ✗ %s", _ts(), rank, message)
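For reference, a small usage sketch of the helpers above, assuming the module is importable as `logging_utils` (as it is in the test file later in this diff); the timestamps in the comments are illustrative:

```python
import logging

from logging_utils import configure, log_test_info, log_test_success

configure(level=logging.INFO, force=True)

# Each helper prefixes the message with elapsed seconds and the rank, e.g.:
log_test_info(0, "starting shrink test")      # [  0.001s][Rank 0] starting shrink test
log_test_success(0, "shrink test finished")   # [  0.123s][Rank 0] ✅ shrink test finished
```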

View File

@@ -2,7 +2,6 @@
import copy
import json
import logging
import os
import pickle
import random
@@ -22,7 +21,6 @@ from unittest import mock, SkipTest
import torch
import torch.distributed as c10d
import torch.distributed._functional_collectives as _functional_collectives
from torch.distributed.distributed_c10d import SHRINK_ABORT as NCCL_SHRINK_ABORT
if not c10d.is_available() or not c10d.is_nccl_available():
@@ -49,15 +47,12 @@ from torch._C._distributed_c10d import ErrorType, OpType, WorkResult
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_cuda import _get_torch_rocm_version, TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
get_required_world_size,
get_timeout,
init_multigpu_helper,
MultiProcessTestCase,
requires_multicast_support,
requires_nccl,
requires_nccl_shrink,
requires_nccl_version,
requires_world_size,
skip_if_lt_x_gpu,
skip_if_rocm_multiprocess,
sm_is_or_higher_than,
@@ -92,17 +87,6 @@ BFLOAT16_AVAILABLE = torch.cuda.is_available() and (
torch.version.cuda is not None or torch.version.hip is not None
)
from logging_utils import (
configure as _log_configure,
log_test_info,
log_test_success,
log_test_validation,
log_test_warning,
)
_log_configure(level=logging.INFO, force=True)
class RendezvousEnvTest(TestCase):
@retry_on_connect_failures
@@ -333,7 +317,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
@property
def world_size(self):
- return get_required_world_size(self, 2)
+ return 2
@property
def rank_to_GPU(self):
@@ -1271,628 +1255,6 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
pg_2 = c10d.new_group([0, 1])
self.assertEqual(pg_2.group_desc, "undefined")
@requires_nccl_shrink()
@requires_world_size(2)
def test_shrink_group_basic(self):
"""Test basic shrink_group functionality."""
self._perform_shrink_test([1], "Basic shrink test")
@requires_nccl_shrink()
@requires_world_size(2)
def test_shrink_group_validation(self):
"""Test input validation in shrink_group."""
device, pg = self._setup_shrink_test("validation")
def _test_invalid_input(ranks, description, expected_exception):
"""Helper to test invalid inputs."""
try:
c10d.shrink_group(ranks)
self.fail(f"Expected {expected_exception.__name__} for {description}")
except expected_exception:
log_test_validation(self.rank, f"{description}")
except Exception:
if expected_exception is Exception: # Accept any exception
log_test_validation(self.rank, f"{description}")
else:
raise
# Test cases
_test_invalid_input([], "Empty exclusion list", ValueError)
if self.world_size > 1:
_test_invalid_input([0, 0, 1], "Duplicate ranks", Exception)
_test_invalid_input([self.world_size + 1], "Out of bounds rank", Exception)
log_test_success(self.rank, "All validation tests passed")
dist.destroy_process_group()
@requires_nccl_shrink()
@requires_world_size(2)
def test_shrink_group_backend_properties(self):
"""Test that backend properties are preserved after shrinking."""
test_name = "Backend Properties Test"
ranks_to_exclude = [0]
# Reuse _setup_shrink_test for complete setup (device, environment, and process group)
device, pg = self._setup_shrink_test("backend_properties")
# Follow _perform_shrink_test pattern from here
log_test_info(self.rank, f"{test_name} (world_size={self.world_size})")
is_excluded = self.rank in ranks_to_exclude
log_test_info(
self.rank,
f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
)
# Store original backend property values (not references) before shrinking
original_timeout = None
original_high_priority = None
if not is_excluded:
original_backend = pg._get_backend(device)
original_timeout = original_backend.options._timeout
original_high_priority = original_backend.options.is_high_priority_stream
log_test_info(
self.rank,
f"Storing original backend properties: timeout={original_timeout}, high_priority={original_high_priority}",
)
if is_excluded:
log_test_info(
self.rank,
f"Excluded rank {self.rank} - setup complete, skipping shrink operation",
)
dist.destroy_process_group() # hang without it
return
# Only non-excluded ranks proceed with shrink (same as _perform_shrink_test)
log_test_info(self.rank, "Non-excluded rank calling shrink_group")
shrunk_pg = c10d.shrink_group(ranks_to_exclude)
# Reuse _validate_shrunk_group helper (same as _perform_shrink_test)
expected_size = self.world_size - len(ranks_to_exclude)
_ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name)
# Add custom backend properties validation
new_backend = shrunk_pg._get_backend(device)
log_test_info(self.rank, "Validating backend properties are preserved")
new_timeout = new_backend.options._timeout
new_high_priority = new_backend.options.is_high_priority_stream
log_test_info(
self.rank,
f"Timeout comparison - original: {original_timeout}, new: {new_timeout}",
)
self.assertEqual(
original_timeout, new_timeout, f"{test_name}: timeout not preserved"
)
log_test_info(
self.rank,
f"High priority stream comparison - original: {original_high_priority}, new: {new_high_priority}",
)
self.assertEqual(
original_high_priority,
new_high_priority,
f"{test_name}: high_priority_stream not preserved",
)
log_test_validation(
self.rank, f"{test_name}: Backend properties preserved successfully"
)
log_test_success(
self.rank, f"{test_name} successful (shrink + backend validation)"
)
# Cleanup (same as _perform_shrink_test)
dist.destroy_process_group()
@requires_nccl_shrink()
@requires_world_size(2)
def test_shrink_group_multiple_comms(self):
"""Test shrink_group with multiple communicators and subgroup invalidation."""
device, pg = self._setup_shrink_test("multiple_comms")
# Create subgroup [0, 1] and test shrinking it
subgroup = c10d.new_group([0, 1])
if self.rank <= 1:
# Shrink subgroup: exclude rank 1
if self.rank == 0: # Only rank 0 remains
shrunk_subgroup = c10d.shrink_group([1], group=subgroup)
self.assertEqual(shrunk_subgroup.size(), 1)
# Test communication on shrunk subgroup
tensor = torch.full((1,), self.rank).cuda(device)
c10d.all_reduce(tensor, group=shrunk_subgroup)
self.assertEqual(tensor.item(), 0) # Only rank 0
log_test_success(self.rank, "Subgroup shrinking successful")
dist.barrier() # Sync before default group test
# Shrink default group: exclude last rank
ranks_to_exclude = [self.world_size - 1]
if self.rank not in ranks_to_exclude:
shrunk_default = c10d.shrink_group(ranks_to_exclude)
expected_size = self.world_size - 1
self.assertEqual(shrunk_default.size(), expected_size)
# Test collective on shrunk default group
tensor = torch.full((1,), self.rank).cuda(device)
c10d.all_reduce(tensor, group=shrunk_default)
expected_sum = sum(
range(self.world_size - 1)
) # 0 + 1 + ... + (world_size-2)
self.assertEqual(tensor.item(), expected_sum)
log_test_success(self.rank, "Default group shrinking successful")
# Note: After shrinking default group, the old subgroup is invalid
# due to global rank reassignment
dist.destroy_process_group()
def _test_shrink_group_with_flag(self, shrink_flag, flag_name, rank_to_exclude):
"""Helper method to test shrink_group with a specific flag."""
if self.world_size < 2:
log_test_info(self.rank, f"Skipping (needs ≥2 GPUs, got {self.world_size})")
return
ranks_to_exclude = [rank_to_exclude]
log_test_info(self.rank, f"Using {flag_name} flag (value: {shrink_flag})")
if flag_name == "NCCL_SHRINK_ABORT":
log_test_info(
self.rank,
"ABORT flag will terminate ongoing operations before shrinking",
)
self._perform_shrink_test(
ranks_to_exclude, f"{flag_name} flag test", shrink_flags=shrink_flag
)
@requires_nccl_shrink()
@requires_world_size(2)
def test_shrink_group_flags(self):
"""Test shrink_group with different shrink flags."""
# Test ABORT flags
log_test_info(self.rank, "Testing NCCL_SHRINK_ABORT flag")
self._test_shrink_group_with_flag(NCCL_SHRINK_ABORT, "NCCL_SHRINK_ABORT", 1)
@requires_nccl_shrink()
@requires_world_size(2)
def test_shrink_group_nccl_config(self):
"""Verify that passing NCCL config via pg_options influences the shrunk group's backend options."""
device, pg = self._setup_shrink_test("config")
if self.rank == self.world_size - 1:
# excluded rank should not call shrink_group
dist.destroy_process_group()
return
# Prepare pg_options with NCCL config overrides
# Capture parent's current backend options to ensure we can prove override vs inherit
parent_backend = pg._get_backend(torch.device("cuda"))
parent_hp = parent_backend.options.is_high_priority_stream
parent_blocking = parent_backend.options.config.blocking
# Choose overrides that differ from the parent (flip where possible)
override_hp = not parent_hp
if parent_blocking in (0, 1):
override_blocking = 1 - parent_blocking
else:
# If undefined or unexpected, set to 1 which is a concrete value
override_blocking = 1
opts = c10d.ProcessGroupNCCL.Options()
opts.is_high_priority_stream = override_hp
opts.config.blocking = override_blocking
shrunk_pg = c10d.shrink_group([self.world_size - 1], pg_options=opts)
# Validate backend options propagated
backend = shrunk_pg._get_backend(torch.device("cuda"))
# is_high_priority_stream should exactly match our override and differ from parent
self.assertEqual(backend.options.is_high_priority_stream, override_hp)
self.assertNotEqual(backend.options.is_high_priority_stream, parent_hp)
# config is a struct; check representative field and difference from parent when meaningful
self.assertEqual(backend.options.config.blocking, override_blocking)
if parent_blocking in (0, 1):
self.assertNotEqual(backend.options.config.blocking, parent_blocking)
dist.destroy_process_group()
@requires_nccl_shrink()
@requires_world_size(2)
def test_shrink_group_performance(self):
"""Test shrink_group performance and regression detection."""
import time
ranks_to_exclude = self._get_default_ranks_to_exclude()
is_excluded = self.rank in ranks_to_exclude
if not ranks_to_exclude:
log_test_info(self.rank, "Skipping performance test (world_size=1)")
return
log_test_info(self.rank, f"Performance test with {self.world_size} processes")
device, pg = self._setup_shrink_test("performance")
if not is_excluded:
log_test_info(self.rank, "Measuring shrink_group performance")
start_time = time.time()
shrunk_pg = c10d.shrink_group(ranks_to_exclude)
end_time = time.time()
elapsed_time = end_time - start_time
log_test_info(self.rank, f"shrink_group: {elapsed_time:.3f}s")
# Regression check: should complete within reasonable time
self.assertLess(
elapsed_time,
30.0,
f"shrink_group took {elapsed_time:.3f}s, possible regression",
)
# Test collective performance
expected_size = self.world_size - len(ranks_to_exclude)
self._validate_shrunk_group(shrunk_pg, expected_size, "performance")
collective_start = time.time()
_ = self._test_collective_on_shrunk_group(
shrunk_pg, device, ranks_to_exclude, "performance"
)
collective_time = time.time() - collective_start
log_test_info(self.rank, f"all_reduce: {collective_time:.3f}s")
log_test_success(self.rank, "Performance test passed")
else:
log_test_info(self.rank, "Excluded rank - waiting")
dist.destroy_process_group()
@requires_nccl_shrink()
@requires_world_size(4)
def test_shrink_group_multiple_exclusions(self):
"""Test shrink_group with multiple ranks excluded at once."""
# Scale exclusions with world size
ranks_to_exclude = list(range(2, self.world_size, 2)) # Every other rank from 2
self._perform_shrink_test(ranks_to_exclude, "Multiple exclusions test")
@requires_nccl_shrink()
@requires_world_size(3)
def test_shrink_group_multiple_iterations(self):
"""Test multiple shrink operations in sequence."""
log_test_info(
self.rank,
f"Starting test_shrink_group_multiple_iterations with world_size={self.world_size}",
)
store = c10d.FileStore(self.file_name, self.world_size)
device = torch.device(f"cuda:{self.rank}")
_ = self._create_process_group_nccl(store, self.opts(), device_id=device)
# Track current effective world size throughout shrinking operations
current_world_size = self.world_size
log_test_info(self.rank, f"Initial world_size: {current_world_size}")
# First shrinking: exclude the last rank(s)
first_exclusion = [self.world_size - 1]
if self.world_size >= 6:
first_exclusion.append(
self.world_size - 2
) # Exclude last two ranks for larger sizes
log_test_info(self.rank, f"First shrinking: excluding ranks {first_exclusion}")
if self.rank not in first_exclusion:
# Only non-excluded ranks should call shrink_group
first_pg = c10d.shrink_group(first_exclusion)
self.assertIsNotNone(first_pg)
# IMPORTANT: Update world size after first shrinking
current_world_size = first_pg.size()
expected_first_size = self.world_size - len(first_exclusion)
log_test_info(
self.rank,
f"After first shrinking: world_size {self.world_size} -> {current_world_size}",
)
self.assertEqual(first_pg.size(), expected_first_size)
# Second shrinking: exclude another rank from the remaining group
# Choose a rank that's in the middle range
if current_world_size >= 3:
second_exclusion = [
current_world_size - 1
] # Exclude the new "last" rank
log_test_info(
self.rank,
f"Second shrinking from group of size {current_world_size}: excluding ranks {second_exclusion}",
)
if self.rank not in second_exclusion:
# Only non-excluded ranks should call shrink_group for second iteration
second_pg = c10d.shrink_group(second_exclusion, group=first_pg)
self.assertIsNotNone(second_pg)
# IMPORTANT: Update world size after second shrinking
final_world_size = second_pg.size()
expected_final_size = current_world_size - len(second_exclusion)
log_test_info(
self.rank,
f"After second shrinking: world_size {current_world_size} -> {final_world_size}",
)
self.assertEqual(second_pg.size(), expected_final_size)
# Test collective on final group
tensor = torch.full((1,), self.rank).cuda(device)
log_test_info(
self.rank,
f"Performing all_reduce on final group (size {final_world_size}) with tensor: {tensor.item()}",
)
c10d.all_reduce(tensor, group=second_pg)
log_test_info(
self.rank,
f"Final all_reduce completed, result: {tensor.item()}",
)
# Calculate expected sum of remaining ranks
all_excluded = set(first_exclusion + second_exclusion)
remaining_ranks = [
r for r in range(self.world_size) if r not in all_excluded
]
expected_sum = sum(remaining_ranks)
log_test_info(
self.rank,
f"Remaining ranks: {remaining_ranks}, expected sum: {expected_sum}, actual: {tensor.item()}",
)
self.assertEqual(tensor.item(), expected_sum)
log_test_info(self.rank, "Final verification passed")
else:
log_test_info(
self.rank,
"This rank excluded in second shrinking, not calling shrink_group",
)
else:
log_test_info(
self.rank, "Skipping second shrinking (remaining group too small)"
)
else:
log_test_info(
self.rank,
"This rank excluded in first shrinking, not calling shrink_group",
)
log_test_info(self.rank, "Destroying process group")
dist.destroy_process_group()
log_test_info(self.rank, "test_shrink_group_multiple_iterations completed")
# Helper methods for optimized shrink group tests
def _setup_shrink_test(self, test_suffix, world_size=None, warmup=True):
"""Common setup for shrink group tests."""
os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1"
world_size = world_size or self.world_size
store = c10d.FileStore(self.file_name + f"_{test_suffix}", world_size)
device = torch.device(f"cuda:{self.rank}")
c10d.init_process_group(
"nccl",
world_size=world_size,
rank=self.rank,
store=store,
pg_options=self.opts(),
device_id=device,
)
pg = c10d.distributed_c10d._get_default_group()
if warmup:
c10d.all_reduce(torch.ones(1).cuda(device), group=pg)
return device, pg
def _validate_shrunk_group(self, shrunk_pg, expected_size, test_name=""):
"""Validate properties of a shrunk process group."""
self.assertIsNotNone(shrunk_pg, f"{test_name}: shrunk_pg should not be None")
actual_size = shrunk_pg.size()
self.assertEqual(
actual_size, expected_size, f"{test_name}: group size mismatch"
)
new_rank = shrunk_pg.rank()
self.assertTrue(
0 <= new_rank < expected_size, f"{test_name}: invalid new rank {new_rank}"
)
log_test_info(
self.rank,
f"{test_name}: world_size {self.world_size} -> {actual_size}, rank {self.rank} -> {new_rank}",
)
return new_rank
def _test_collective_on_shrunk_group(
self, shrunk_pg, device, ranks_to_exclude, test_name=""
):
"""Test collective communication on shrunk group and verify correctness."""
test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32)
c10d.all_reduce(test_tensor, group=shrunk_pg)
result = test_tensor.item()
expected_sum = sum(
r for r in range(self.world_size) if r not in ranks_to_exclude
)
self.assertEqual(
result, expected_sum, f"{test_name}: collective result mismatch"
)
log_test_info(
self.rank, f"{test_name}: collective passed ({result} == {expected_sum})"
)
return result
def _perform_shrink_test(
self, ranks_to_exclude, test_name, shrink_flags=0, with_collective=True
):
"""Complete shrink test flow: setup, shrink, validate, test collective, cleanup.
Consistent API: All ranks perform setup to initialize distributed environment.
ONLY non-excluded ranks call shrink_group() for both default and non-default groups.
Excluded ranks perform setup, then exit without calling shrink_group() or waiting.
"""
log_test_info(self.rank, f"{test_name} (world_size={self.world_size})")
is_excluded = self.rank in ranks_to_exclude
log_test_info(
self.rank,
f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
)
# All ranks (including excluded ones) perform setup to initialize distributed environment
device, pg = self._setup_shrink_test(test_name.lower().replace(" ", "_"))
is_default_group = pg == c10d.distributed_c10d._get_default_group()
if is_excluded:
log_test_info(
self.rank,
f"Excluded rank {self.rank} - setup complete, skipping shrink operation",
)
if shrink_flags & NCCL_SHRINK_ABORT:
log_test_info(self.rank, f"Using abort for excluded rank {self.rank}")
pg._get_backend(torch.device(device)).abort()
log_test_info(
self.rank, f"cleanup resources for excluded rank {self.rank}"
)
dist.destroy_process_group()
log_test_info(self.rank, f"Excluded rank {self.rank} - exit")
else:
log_test_info(
self.rank, f"Using regular destroy for excluded rank {self.rank}"
)
dist.destroy_process_group()
return None
# Only non-excluded ranks proceed with shrink
log_test_info(
self.rank,
f"Non-excluded rank calling shrink_group (default_group={is_default_group})",
)
shrunk_pg = c10d.shrink_group(ranks_to_exclude, shrink_flags=shrink_flags)
log_test_info(
self.rank,
f"Non-excluded rank calling shrink_group (default_group={is_default_group}) done",
)
# Non-excluded ranks: validate and test the new group
expected_size = self.world_size - len(ranks_to_exclude)
_ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name)
if with_collective:
_ = self._test_collective_on_shrunk_group(
shrunk_pg, device, ranks_to_exclude, test_name
)
log_test_success(self.rank, f"{test_name} successful (shrink + collective)")
else:
log_test_success(self.rank, f"{test_name} successful (shrink only)")
dist.destroy_process_group()
return shrunk_pg
def _get_default_ranks_to_exclude(self):
"""Get default ranks to exclude based on world size."""
if self.world_size <= 1:
return []
return [self.world_size - 1] # Exclude last rank by default
@requires_nccl_shrink()
@requires_world_size(3)
def test_shrink_group_vs_abort_reinit_performance(self):
"""Compare performance of shrink_group vs traditional abort+reinit (simplified for reliability)."""
log_test_info(self.rank, "=== TEST 1: abort+reinit ===")
device, pg1 = self._setup_shrink_test("_perf_reinit")
torch.cuda.synchronize(device)
# Test 1: Traditional abort + reinit
start_time = time.perf_counter()
dist.destroy_process_group()
device, new_pg = self._setup_shrink_test("perf_shrink_test1")
reinit_time = time.perf_counter() - start_time
# Test collective with original rank values for fair comparison (non-blocking mode)
test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32)
work = c10d.all_reduce(test_tensor, group=new_pg, async_op=True)
work.wait()
torch.cuda.synchronize(device)
# Verify correctness
expected_sum = sum(r for r in range(self.world_size))
self.assertEqual(test_tensor.item(), expected_sum, "Reinit collective failed")
log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s")
dist.destroy_process_group(new_pg)
# Test 2: shrink_group with NCCL_SHRINK_ABORT
log_test_info(self.rank, "=== TEST 2: shrink_group ===")
ranks_to_exclude = [self.world_size - 1]
is_excluded = self.rank in ranks_to_exclude
log_test_info(
self.rank,
f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
)
device, pg1 = self._setup_shrink_test("perf_shrink_test2") # Unique suffix
shrink_time = 0
if not is_excluded:
torch.cuda.synchronize(device) # Ensure accurate timing
start_time = time.perf_counter()
shrunk_pg = c10d.shrink_group(
ranks_to_exclude, shrink_flags=NCCL_SHRINK_ABORT
)
c10d.all_reduce(torch.ones(1).cuda(device), group=shrunk_pg)
shrink_time = time.perf_counter() - start_time
# Test collective communication on shrunk group (non-blocking mode)
test_tensor = torch.full(
(1,), self.rank, device=device, dtype=torch.float32
)
work = c10d.all_reduce(test_tensor, group=shrunk_pg, async_op=True)
work.wait()
# Verify correctness
expected_sum = sum(
r for r in range(self.world_size) if r not in ranks_to_exclude
)
self.assertEqual(
test_tensor.item(),
expected_sum,
"shrink_test: collective result mismatch",
)
torch.cuda.synchronize(device) # Ensure operations complete
log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s")
dist.destroy_process_group()
else:
log_test_info(self.rank, "Excluded from shrink test - exiting immediately")
dist.destroy_process_group()
return
# Performance analysis (only for participating ranks)
if shrink_time > 0 and reinit_time > 0:
speedup = reinit_time / shrink_time
time_saved = reinit_time - shrink_time
log_test_info(self.rank, "=== PERFORMANCE RESULTS ===")
log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s")
log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s")
log_test_info(self.rank, f"time_saved: {time_saved:+.4f}s")
log_test_info(self.rank, f"speedup: {speedup:.2f}x")
if speedup > 1.1:
log_test_success(self.rank, "shrink_group significantly faster")
elif speedup > 0.9:
log_test_info(self.rank, "≈ comparable performance")
else:
log_test_warning(self.rank, "abort+reinit faster")
log_test_info(self.rank, "Performance test completed")
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_deterministic_mode_no_break(self):

View File

@@ -79,23 +79,6 @@ class TORCH_API Backend : public torch::CustomClassHolder {
return false;
}
virtual bool supportsShrinking() const {
return false;
}
// Shrink the backend by excluding specified ranks. Backends that support
// communicator shrinking should override this and return a new backend
// instance representing the shrunken group. Backends may use opts_override
// to supply backend-specific options for the new group.
virtual c10::intrusive_ptr<Backend> shrink(
const std::vector<int64_t>& /*ranks_to_exclude*/,
int /*shrink_flags*/ = 0,
const c10::intrusive_ptr<Options>& /*opts_override*/ = nullptr) {
TORCH_CHECK(
false,
c10::str("Backend ", getBackendName(), " does not support shrink"));
}
virtual void setTimeout(std::chrono::milliseconds timeout) {
TORCH_CHECK(
false,

View File

@@ -259,65 +259,6 @@ std::shared_ptr<NCCLComm> NCCLComm::split(
}
#endif
#ifdef NCCL_HAS_COMM_SHRINK
std::shared_ptr<NCCLComm> NCCLComm::shrink(
NCCLComm* source,
std::vector<int>& ranks_to_exclude,
ncclConfig_t* config,
int shrinkFlags) {
// Preconditions are validated in ProcessGroupNCCL::shrink
LOG(INFO) << "Rank " << source->rank_ << ": shrinking comm " << source->repr()
<< " excluding " << ranks_to_exclude.size() << " ranks";
at::cuda::OptionalCUDAGuard gpuGuard(source->deviceIndex_);
auto comm = std::make_shared<NCCLComm>();
// This call will block until the source communicator is initialized
auto sourceComm = source->getNcclComm();
C10D_NCCL_CHECK_NONBLOCKING(
ncclCommShrink(
sourceComm,
ranks_to_exclude.data(),
ranks_to_exclude.size(),
reinterpret_cast<ncclComm_t*>(&(comm->ncclComm_)),
config,
shrinkFlags),
source->getNcclCommFailureReason());
// Wait for the child communicator to be ready
source->waitReady(true);
comm->initialized_ = true;
// NCCL automatically assigns rank during shrink - query it efficiently
int assigned_rank;
try {
C10D_NCCL_CHECK(
ncclCommUserRank(comm->ncclComm_, &assigned_rank), std::nullopt);
comm->rank_ = assigned_rank;
} catch (const std::exception& e) {
// Fallback: if ncclCommUserRank fails, we can't determine the rank
LOG(ERROR) << "Failed to query NCCL-assigned rank: " << e.what();
throw;
}
// Child comm should be on the same device as parent comm
comm->deviceIndex_ = source->deviceIndex_;
if (config != nullptr) {
comm->nonBlocking_ = config->blocking == 0;
} else {
// Inherit parent behavior if no config provided
comm->nonBlocking_ = source->nonBlocking_;
}
LOG(INFO) << "Rank " << source->rank_ << ": created shrunken comm "
<< comm->repr() << " with NCCL-assigned rank " << assigned_rank;
return comm;
}
#endif
void NCCLComm::finalize() {
LockType lock(mutex_);
if (aborted_) {

View File

@@ -90,10 +90,6 @@ static_assert(
#define NCCL_HAS_NVLS_CTAS
#endif
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0)
#define NCCL_HAS_COMM_SHRINK
#endif
// Macro to throw on a non-successful NCCL return value.
#define C10D_NCCL_CHECK(cmd, failureReason) \
do { \
@@ -298,14 +294,6 @@ class NCCLComm {
ncclConfig_t& config);
#endif // NCCL_HAS_COMM_SPLIT
#ifdef NCCL_HAS_COMM_SHRINK
static std::shared_ptr<NCCLComm> shrink(
NCCLComm* source,
std::vector<int>& ranks_to_exclude,
ncclConfig_t* config,
int shrinkFlags = 0);
#endif // NCCL_HAS_COMM_SHRINK
#if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP)
std::unordered_map<std::string, std::string> ncclCommDump();
#endif

View File

@@ -165,7 +165,7 @@ ncclRedOpRAII getNcclReduceOp(
}
// Get a key string from device
- inline std::string getKeyFromDevice(const at::Device& device) {
+ inline std::string getKeyFromDevice(at::Device& device) {
return std::to_string(device.index());
}
@@ -5838,139 +5838,6 @@ at::Tensor ProcessGroupNCCL::allocateTensor(
return tensor;
}
#ifdef NCCL_HAS_COMM_SHRINK
c10::intrusive_ptr<Backend> ProcessGroupNCCL::shrink(
const std::vector<int64_t>& ranks_to_exclude,
int shrink_flags,
const c10::intrusive_ptr<Backend::Options>& opts_override) {
// Runtime version check with better error message
auto runtime_version = torch::cuda::nccl::version();
TORCH_CHECK(
runtime_version >= NCCL_VERSION(2, 27, 0),
"ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later. "
"Found version: ",
runtime_version);
// Early validation with detailed error messages
TORCH_CHECK_VALUE(
!ranks_to_exclude.empty(), "ranks_to_exclude cannot be empty");
TORCH_CHECK_VALUE(
static_cast<int>(ranks_to_exclude.size()) < size_,
"Cannot exclude all ranks (",
ranks_to_exclude.size(),
" >= ",
size_,
")");
// Validate ranks and convert to int efficiently
std::vector<int> int_ranks_to_exclude;
int_ranks_to_exclude.reserve(ranks_to_exclude.size());
for (int64_t rank : ranks_to_exclude) {
TORCH_CHECK_VALUE(
rank >= 0 && rank < size_,
"Invalid rank ",
rank,
" for group size ",
size_);
int_ranks_to_exclude.push_back(static_cast<int>(rank));
}
// Get primary communicator with better error context
auto primary_device_index = guessDeviceId();
auto primary_device = at::Device(at::kCUDA, primary_device_index);
const auto primary_key = getKeyFromDevice(primary_device);
std::shared_ptr<NCCLComm> primary_comm = getNCCLComm(primary_key);
TORCH_CHECK(
primary_comm,
"Primary NCCL communicator for device ",
primary_device,
" (key: ",
primary_key,
") is not initialized");
// Cache device index before shrink operation
at::DeviceIndex parent_device_index = primary_comm->getDeviceIndex();
ncclConfig_t* config = nullptr;
// Default to inheriting from parent options
bool high_priority_stream = options_->is_high_priority_stream;
if (opts_override) {
auto nccl_opts =
c10::static_intrusive_pointer_cast<ProcessGroupNCCL::Options>(
opts_override);
config = &nccl_opts->config;
// If user provided override options, honor is_high_priority_stream as well
high_priority_stream = nccl_opts->is_high_priority_stream;
}
std::shared_ptr<NCCLComm> shrunk_comm = NCCLComm::shrink(
primary_comm.get(),
int_ranks_to_exclude,
(config != nullptr ? config : &options_->config),
shrink_flags);
// Calculate new size and get NCCL-assigned rank
int new_size = size_ - static_cast<int>(ranks_to_exclude.size());
int new_rank = shrunk_comm->rank_;
// Create new ProcessGroupNCCL with optimized options cloning
auto new_store = store_->clone();
auto new_opts = ProcessGroupNCCL::Options::create(high_priority_stream);
new_opts->timeout = options_->timeout;
if (config != nullptr) {
new_opts->config = *config;
} else {
new_opts->config = options_->config;
}
auto new_pg = c10::make_intrusive<ProcessGroupNCCL>(
new_store, new_rank, new_size, new_opts);
// Set up the new process group with optimized device setup
new_pg->initializeDeviceStateForComm(
at::Device(at::kCUDA, parent_device_index), shrunk_comm);
return c10::static_intrusive_pointer_cast<Backend>(new_pg);
}
#else // !NCCL_HAS_COMM_SHRINK
// Backend interface override: raise consistent error when shrink is
// unsupported.
c10::intrusive_ptr<Backend> ProcessGroupNCCL::shrink(
const std::vector<int64_t>& /*ranks_to_exclude*/,
int /*shrink_flags*/,
const c10::intrusive_ptr<Backend::Options>& /*opts_override*/) {
TORCH_CHECK(
false,
"ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later, "
"but PyTorch was built with an older version or without NCCL shrink support.");
}
#endif // NCCL_HAS_COMM_SHRINK
void ProcessGroupNCCL::initializeDeviceStateForComm(
const at::Device& device,
std::shared_ptr<NCCLComm> comm) {
const auto key = getKeyFromDevice(device);
std::unique_lock<std::mutex> lock(mutex_);
at::cuda::OptionalCUDAGuard gpuGuard(device);
bool force_high = getCvarBool(TORCH_NCCL_HIGH_PRIORITY, false);
auto stream = at::cuda::getStreamFromPool(
options_->is_high_priority_stream || force_high);
devNCCLCommMap_[key] = comm;
ncclStreams_.emplace(key, stream);
ncclEvents_.emplace(key, at::cuda::CUDAEvent(cudaEventDisableTiming));
usedDeviceIdxs_.insert(device.index());
if (shouldAllCommunicatorsRegisterAllTensors()) {
std::lock_guard<std::mutex> map_lock(ncclCommMemPoolMapMutex);
ncclCommMemPoolMap.emplace(std::move(comm), MemPoolSet{});
}
}
} // namespace c10d
#endif // USE_C10D_NCCL

View File

@@ -997,21 +997,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
ErrorType getError() override;
bool supportsShrinking() const override {
#ifdef NCCL_HAS_COMM_SHRINK
return true;
#else
return false;
#endif
}
// Backend-style shrink override that returns a Backend instance.
c10::intrusive_ptr<Backend> shrink(
const std::vector<int64_t>& ranks_to_exclude,
int shrink_flags = 0,
const c10::intrusive_ptr<Backend::Options>& opts_override =
nullptr) override;
std::shared_ptr<c10::Allocator> getMemAllocator() override;
// Allocate tensor from communication-optimized memory pool
@@ -1080,12 +1065,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
int p2pRank = 0,
bool isSendRecvSelf = false);
// Initialize device-specific state (comm, stream, event, bookkeeping) for a
// given communicator on this process group instance.
void initializeDeviceStateForComm(
const at::Device& device,
std::shared_ptr<NCCLComm> comm);
// Wrapper method which can be overridden for tests.
virtual std::exception_ptr checkForNCCLErrors(
std::shared_ptr<NCCLComm>& ncclComm);

View File

@@ -2730,23 +2730,12 @@ Arguments:
"supports_time_estimate",
&::c10d::Backend::supportsTimeEstimation,
"(test whether the backend supports collective time estimation)")
.def_property_readonly(
"supports_shrinking",
&::c10d::Backend::supportsShrinking,
"(test whether the backend supports communicator shrinking)")
.def(
"set_timeout",
&::c10d::Backend::setTimeout,
py::arg("timeout"),
py::call_guard<py::gil_scoped_release>(),
R"(Sets the default timeout for all future operations.)")
.def(
"shrink",
&::c10d::Backend::shrink,
py::arg("ranks_to_exclude"),
py::arg("shrink_flags") = 0,
py::arg("opts_override") = nullptr,
py::call_guard<py::gil_scoped_release>())
.def(
"broadcast",
&::c10d::Backend::broadcast,

View File

@@ -130,7 +130,6 @@ __all__ = [
"reduce_scatter_tensor",
"get_node_local_rank",
"split_group",
"shrink_group",
]
_MPI_AVAILABLE = True
@@ -5714,517 +5713,3 @@ def _get_process_group_name(pg: ProcessGroup) -> str:
def _get_process_group_store(pg: ProcessGroup) -> Store:
return _world.pg_map[pg][1]
# Shrink flags for process group backends
SHRINK_DEFAULT = 0x00
SHRINK_ABORT = 0x01
@_time_logger
def shrink_group(
ranks_to_exclude: list[int],
group: Optional[ProcessGroup] = None,
shrink_flags: int = SHRINK_DEFAULT,
pg_options: Optional[Any] = None,
) -> ProcessGroup:
"""
Shrinks a process group by excluding specified ranks.
Creates and returns a new, smaller process group comprising only the ranks
from the original group that were not in the ``ranks_to_exclude`` list.
Args:
ranks_to_exclude (List[int]): A list of ranks from the original
``group`` to exclude from the new group.
group (ProcessGroup, optional): The process group to shrink. If ``None``,
the default process group is used. Defaults to ``None``.
shrink_flags (int, optional): Flags to control the shrinking behavior.
Can be ``SHRINK_DEFAULT`` (default) or ``SHRINK_ABORT``.
``SHRINK_ABORT`` will attempt to terminate ongoing operations
in the parent communicator before shrinking.
Defaults to ``SHRINK_DEFAULT``.
pg_options (ProcessGroupOptions, optional): Backend-specific options to apply
to the shrunken process group. If provided, the backend will use
these options when creating the new group. If omitted, the new group
inherits defaults from the parent.
Returns:
ProcessGroup: a new group comprised of the remaining ranks. If the
default group was shrunk, the returned group becomes the new default group.
Raises:
TypeError: if the group's backend does not support shrinking.
ValueError: if ``ranks_to_exclude`` is invalid (empty, out of bounds,
duplicates, or excludes all ranks).
RuntimeError: if an excluded rank calls this function or the backend
fails the operation.
Notes:
- Only non-excluded ranks should call this function; excluded ranks
must not participate in the shrink operation.
- Shrinking the default group destroys all other process groups since
rank reassignment makes them inconsistent.
"""
# Step 1: Validate input parameters with comprehensive error checking
_validate_shrink_inputs(ranks_to_exclude, shrink_flags)
# Step 2: Get target group and essential properties
target_group_info = _prepare_shrink_target_group(group)
# Step 3: Validate backend requirements and availability
backend_impl = _validate_shrink_backend_requirements(target_group_info)
# Step 4: Validate ranks against group and check for duplicates
excluded_ranks_set = _validate_and_process_excluded_ranks(
ranks_to_exclude, target_group_info
)
# Step 5: Execute the actual shrink operation (backend-specific)
new_backend = backend_impl.shrink(
sorted(excluded_ranks_set),
shrink_flags,
pg_options if pg_options is not None else None,
)
# Step 6: Handle cleanup and creation of new process group
target_group_info["pg_options_override"] = pg_options
return _finalize_shrunk_group(target_group_info, excluded_ranks_set, new_backend)
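A minimal usage sketch of the `shrink_group` API removed by this revert, pieced together from the docstring above and the tests earlier in this diff; it assumes an initialized NCCL default group with one CUDA device per rank, and that the excluded rank has already stopped participating (excluded ranks must not call `shrink_group`):

```python
import torch
import torch.distributed as dist

from torch.distributed.distributed_c10d import SHRINK_ABORT

# Assumes dist.init_process_group("nccl", ...) has already run on every rank
# and that the last rank is the one being dropped from the group.
failed_rank = dist.get_world_size() - 1

if dist.get_rank() != failed_rank:
    # SHRINK_ABORT aborts outstanding work on the parent communicator before shrinking.
    new_pg = dist.shrink_group([failed_rank], shrink_flags=SHRINK_ABORT)
    # Surviving ranks continue on the smaller group; if the default group was
    # shrunk, the returned group becomes the new default group.
    t = torch.ones(1, device="cuda")
    dist.all_reduce(t, group=new_pg)
```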
def _validate_shrink_inputs(ranks_to_exclude: list[int], shrink_flags: int) -> None:
"""Validate input parameters for shrink_group."""
if not isinstance(ranks_to_exclude, list):
raise TypeError(
f"ranks_to_exclude must be a list, but got {type(ranks_to_exclude).__name__}. "
f"Example: [1, 3, 5] to exclude ranks 1, 3, and 5."
)
if not ranks_to_exclude:
raise ValueError(
"ranks_to_exclude cannot be empty. To shrink a group, you must specify at least "
"one rank to exclude. Example: [failed_rank_id]"
)
# Validate shrink_flags with clear explanation of valid values
valid_flags = [SHRINK_DEFAULT, SHRINK_ABORT]
if not isinstance(shrink_flags, int) or shrink_flags not in valid_flags:
raise ValueError(
f"Invalid shrink_flags value: {shrink_flags}. Must be one of: "
f"SHRINK_DEFAULT ({SHRINK_DEFAULT}) or SHRINK_ABORT ({SHRINK_ABORT}). "
f"Use SHRINK_ABORT to abort ongoing operations before shrinking."
)
def _prepare_shrink_target_group(group: Optional[ProcessGroup]) -> dict:
"""Prepare and validate the target group for shrinking."""
target_pg = group if group is not None else _get_default_group()
# Cache frequently accessed properties to avoid repeated calls
group_size = int(target_pg.size())
group_info = {
"process_group": target_pg,
"is_default_group": (target_pg == _get_default_group()),
"group_size": group_size,
"current_rank": target_pg.rank(),
"group_name": _get_process_group_name(target_pg),
}
# Validate that we have a valid process group
if group_size <= 1:
raise ValueError(
f"Cannot shrink a process group with size {group_size}. "
f"Group must have at least 2 ranks to support shrinking."
)
return group_info
def _validate_shrink_backend_requirements(group_info: dict) -> Any:
"""Return the backend implementation for the target group or raise if unsupported."""
target_pg = group_info["process_group"]
group_name = group_info["group_name"]
# Get the group's backend directly via ProcessGroup API. Prefer a bound device if present,
# otherwise try CUDA then fall back to CPU.
try:
preferred_device = getattr(target_pg, "bound_device_id", None)
if preferred_device is not None:
backend_impl = target_pg._get_backend(preferred_device)
else:
# Try CUDA first if available, else CPU
try:
backend_impl = target_pg._get_backend(torch.device("cuda"))
except Exception:
backend_impl = target_pg._get_backend(torch.device("cpu"))
except RuntimeError as e:
raise RuntimeError(
f"Cannot access device backend for process group '{group_name}'. "
f"Ensure the process group was initialized with a compatible device backend and devices are available."
) from e
try:
supports = bool(backend_impl.supports_shrinking)
except Exception:
supports = False
if not supports:
raise TypeError(
f"Process group backend for '{group_name}' does not support shrinking operations."
)
return backend_impl
def _validate_and_process_excluded_ranks(
ranks_to_exclude: list[int], group_info: dict
) -> set:
"""Validate excluded ranks and convert to set for efficient operations."""
group_size = group_info["group_size"]
current_rank = group_info["current_rank"]
# Use set for O(1) duplicate detection and membership testing
excluded_ranks_set = set()
# Validate each rank with detailed error messages
for i, rank in enumerate(ranks_to_exclude):
if not isinstance(rank, int):
raise TypeError(
f"All elements in ranks_to_exclude must be integers. "
f"Element at index {i} is {type(rank).__name__}: {rank}"
)
if not (0 <= rank < group_size):
raise ValueError(
f"Rank {rank} at index {i} is out of bounds for group size {group_size}. "
f"Valid ranks are in range [0, {group_size - 1}]."
)
if rank in excluded_ranks_set:
raise ValueError(
f"Duplicate rank {rank} found in ranks_to_exclude at index {i}. "
f"Each rank can only be excluded once."
)
excluded_ranks_set.add(rank)
# Ensure we don't exclude all ranks
if len(excluded_ranks_set) >= group_size:
raise ValueError(
f"Cannot exclude all {group_size} ranks from process group. "
f"At least one rank must remain. Excluding {len(excluded_ranks_set)} ranks."
)
# Critical check: current rank should not be in excluded list
if current_rank in excluded_ranks_set:
raise RuntimeError(
f"Current rank {current_rank} is in the exclusion list and should not call shrink_group(). "
f"Only non-excluded ranks should participate in the shrinking operation. "
f"Excluded ranks should terminate their processes instead."
)
return excluded_ranks_set
def _finalize_shrunk_group(
group_info: dict, excluded_ranks_set: set, new_backend
) -> ProcessGroup:
"""Clean up old group and create new shrunk process group."""
target_pg = group_info["process_group"]
is_default_group = group_info["is_default_group"]
# Handle default group dependencies - destroy other groups first
if is_default_group:
_destroy_all_other_groups(exclude_group=target_pg)
# Gather original group metadata before cleanup
original_group_metadata = _extract_group_metadata(target_pg)
# Calculate remaining ranks efficiently
original_ranks = get_process_group_ranks(target_pg)
remaining_ranks = [
rank for rank in original_ranks if rank not in excluded_ranks_set
]
# Clean up the original group
_cleanup_original_group(target_pg, is_default_group)
# Create and configure the new process group
new_pg = _create_shrunk_process_group(
new_backend, remaining_ranks, original_group_metadata, is_default_group
)
# Register the new group in global state
if is_default_group:
_update_default_pg(new_pg)
# Update global state with new group information
rank_mapping = {
global_rank: group_rank
for group_rank, global_rank in enumerate(remaining_ranks)
}
_update_process_group_global_state(
pg=new_pg,
backend_name=original_group_metadata["backend_name"],
store=original_group_metadata["store"],
group_name=original_group_metadata["new_group_name"],
backend_config=original_group_metadata["backend_config"],
rank_mapping=rank_mapping,
)
return new_pg
def _extract_group_metadata(target_pg: ProcessGroup) -> dict:
"""Extract metadata from the original group before cleanup."""
original_backend_name, original_store = _world.pg_map[target_pg]
original_backend_config = _world.pg_backend_config.get(target_pg, "")
original_group_name = _get_process_group_name(target_pg)
# Extract device binding information before cleanup to avoid accessing destroyed group
bound_device_id = None
if hasattr(target_pg, "bound_device_id"):
bound_device_id = target_pg.bound_device_id
# Generate new group name for the shrunk group; hash for uniqueness across backends
remaining_ranks = list(get_process_group_ranks(target_pg))
new_group_name = _process_group_name(remaining_ranks, use_hashed_name=True)
return {
"backend_name": original_backend_name,
"store": original_store,
"backend_config": original_backend_config,
"original_group_name": original_group_name,
"new_group_name": new_group_name,
"bound_device_id": bound_device_id, # Safe to access after cleanup
}
def _cleanup_original_group(target_pg: ProcessGroup, is_default_group: bool) -> None:
"""Clean up the original process group safely."""
try:
destroy_process_group(target_pg)
except Exception as e:
group_type = "default" if is_default_group else "non-default"
logger.warning("Failed to destroy %s group during shrinking: %s", group_type, e)
# Ensure global state cleanup even if destroy_process_group fails
_cleanup_process_group_global_state(target_pg)
def _create_shrunk_process_group(
new_backend, remaining_ranks: list[int], metadata: dict, is_default_group: bool
) -> ProcessGroup:
"""Create and configure the new shrunk process group."""
# Create new group properties
new_group_rank = new_backend.rank()
new_group_size = new_backend.size()
group_name = metadata["new_group_name"]
# Generate descriptive group description
if is_default_group:
group_desc = "default:shrunken"
else:
group_desc = f"{metadata['original_group_name']}:shrunk"
# Create process group with new communicator (clone the parent store like split does)
prefix_store = PrefixStore(f"{group_name}/", metadata["store"].clone())
new_pg = ProcessGroup(prefix_store, new_group_rank, new_group_size)
# Configure backend using the device type of the new backend's bound device if available,
# otherwise derive from the original group's bound device or fall back to CPU.
backend_device = metadata.get("bound_device_id")
if backend_device is None:
# Default to CPU if no bound device is present
backend_device = torch.device("cpu")
# Choose backend enum based on device type
if backend_device.type == "cuda":
backend_type = ProcessGroup.BackendType.NCCL
else:
backend_type = ProcessGroup.BackendType.GLOO
new_pg._register_backend(backend_device, backend_type, new_backend)
new_pg._set_default_backend(backend_type)
# Inherit device binding from original group if it was bound
bound_device_id = metadata.get("bound_device_id")
if bound_device_id is not None:
new_pg.bound_device_id = bound_device_id
# Set group metadata
new_pg._set_group_name(group_name)
new_pg._set_group_desc(group_desc)
# Persist backend configuration overrides (if provided via shrink_group)
backend_config_override = metadata.get("backend_config")
if backend_config_override is not None:
# Store for introspection/debugging and potential backend hooks
_world.pg_backend_config[new_pg] = backend_config_override
return new_pg
def _destroy_all_other_groups(exclude_group: Optional[ProcessGroup] = None) -> None:
"""
Destroy all process groups except the excluded group and clean up all global state.
This is necessary when shrinking the default group because global ranks
are reassigned by NCCL, making all existing process groups inconsistent.
Note: Uses abort for non-collective cleanup since excluded ranks may not
participate in collective operations. Backend cleanup is handled independently per group.
Args:
exclude_group (ProcessGroup, optional): Process group to exclude from destruction.
If None, destroys all process groups.
"""
# Get list of groups to destroy (avoid modifying dict while iterating)
groups_to_destroy = []
for pg in list(_world.pg_group_ranks.keys()):
if exclude_group is not None and pg == exclude_group:
continue
groups_to_destroy.append(pg)
# Warn user about automatic destruction
if groups_to_destroy:
group_names = [_get_process_group_name(pg) for pg in groups_to_destroy]
logger.warning(
"Shrinking default group will destroy %d other process groups: %s. "
"This is necessary because shrinking the default group reassigns global ranks, "
"making existing groups inconsistent.",
len(groups_to_destroy),
", ".join(group_names),
)
# Destroy each group and clean up global state
for pg in groups_to_destroy:
try:
# First call abort_process_group which handles the C++ cleanup non-collectively
_abort_process_group(pg)
except Exception as e:
# Log but don't fail - some groups might already be destroyed
logger.warning(
"Failed to abort process group %s: %s",
_get_process_group_name(pg),
e,
)
# Ensure all global state is cleaned up even if _abort_process_group fails
# or doesn't clean up everything
_cleanup_process_group_global_state(pg)
def _cleanup_process_group_global_state(pg: ProcessGroup) -> None:
"""
Clean up all global state associated with a process group.
This function ensures complete cleanup of process group state from all
global dictionaries and registries, even if destroy_process_group fails
or doesn't clean up everything. This is critical when destroying multiple
groups to prevent inconsistent state.
The cleanup removes the process group from:
- _world.pg_map (backend and store mapping)
- _world.pg_names (group name mapping)
- _world.pg_group_ranks (rank mappings)
- _world.pg_backend_config (backend configuration)
- _world.tags_to_pg and _world.pg_to_tag (tag mappings)
- _world.pg_coalesce_state (coalescing state)
- C++ internal registries via _unregister_process_group
Args:
pg (ProcessGroup): The process group to clean up.
"""
try:
# Clean up main process group mappings
_world.pg_map.pop(pg, None)
_world.pg_group_ranks.pop(pg, None)
_world.pg_backend_config.pop(pg, None)
# Clean up process group name mapping
group_name = _world.pg_names.pop(pg, None)
# Clean up tag mappings
pg_tag = _world.pg_to_tag.pop(pg, None)
if pg_tag is not None and pg_tag in _world.tags_to_pg:
try:
_world.tags_to_pg[pg_tag].remove(pg)
# Remove the tag entry if list is empty
if not _world.tags_to_pg[pg_tag]:
_world.tags_to_pg.pop(pg_tag, None)
except (ValueError, KeyError):
# Process group was already removed from the list
pass
# Clean up any registered process group names using C++ unregister function
if group_name is not None:
try:
_unregister_process_group(group_name)
except Exception:
# Process group name might not be registered or already unregistered
pass
# Clean up coalesce state if present
_world.pg_coalesce_state.pop(pg, None)
except Exception as e:
# Log cleanup failures but don't propagate - we want to continue with other cleanups
logger.warning("Failed to fully clean up global state for process group: %s", e)
def _update_process_group_global_state(
pg: ProcessGroup,
backend_name: str,
store: Store,
group_name: str,
backend_config: str,
rank_mapping: Optional[dict[int, int]] = None,
pg_tag: Optional[str] = None,
user_tag: Optional[str] = None,
) -> None:
"""
Update all global state dictionaries for a process group.
This helper function consolidates the common pattern of updating multiple
global state dictionaries when creating or modifying process groups.
Args:
pg (ProcessGroup): The process group to update state for.
backend_name (str): Backend name for pg_map.
store (Store): Store instance for pg_map.
group_name (str): Group name for pg_names and registration.
backend_config (str): Backend configuration string.
rank_mapping (Dict[int, int], optional): Global rank to group rank mapping.
If None, skips updating pg_group_ranks.
pg_tag (str, optional): Process group tag. If None, defaults to f"ptd:{group_name}".
user_tag (str, optional): User-provided tag for special tag handling.
If provided, creates "user:{user_tag}" tag and also adds to default "".
"""
# Update main process group mappings
_world.pg_map[pg] = (backend_name, store)
_world.pg_names[pg] = group_name
_world.pg_backend_config[pg] = backend_config
# Register the process group name
_register_process_group(group_name, pg)
# Update rank mapping if provided
if rank_mapping is not None:
_world.pg_group_ranks[pg] = rank_mapping
# Handle tag management
if pg_tag is None:
pg_tag = f"ptd:{group_name}"
if user_tag is not None:
# Special handling for user-provided tags
# Add to default "" tag first
_world.tags_to_pg.setdefault("", []).append(pg)
# Then create user-specific tag
user_pg_tag = f"user:{user_tag}"
_world.tags_to_pg.setdefault(user_pg_tag, []).append(pg)
_world.pg_to_tag[pg] = user_pg_tag
else:
# Standard process group tag
_world.tags_to_pg.setdefault(pg_tag, []).append(pg)
_world.pg_to_tag[pg] = pg_tag

View File

@@ -238,47 +238,6 @@ def skip_if_lt_x_gpu(x):
return decorator
def requires_world_size(n: int):
"""
Decorator to request a specific world size for a test. The test harness can
read this attribute to set the number of ranks to spawn. If there are fewer
than `n` CUDA devices available, the test should be skipped by the harness.
Usage:
@requires_world_size(3)
def test_something(self):
...
"""
def decorator(func):
func._required_world_size = n
available = torch.cuda.device_count()
return unittest.skipUnless(
available >= n, f"requires {n} GPUs, found {available}"
)(func)
return decorator
def get_required_world_size(obj: Any, default: int) -> int:
"""
Returns the requested world size for the currently running unittest method on `obj`
if annotated via `@require_world_size(n)`, else returns `default`.
"""
try:
# Try MultiProcessTestCase helper first, then unittest fallback
test_name = (
obj._current_test_name() # type: ignore[attr-defined]
if hasattr(obj, "_current_test_name") and callable(obj._current_test_name)
else obj._testMethodName
)
fn = getattr(obj, test_name)
value = fn._required_world_size
return int(value)
except Exception:
return default
# This decorator helps avoiding initializing cuda while testing other backends
def nccl_skip_if_lt_x_gpu(backend, x):
def decorator(func):
@@ -408,13 +367,6 @@ def requires_nccl_version(version, msg):
)
def requires_nccl_shrink():
"""
Require NCCL shrink support (NCCL available and version >= 2.27).
"""
return requires_nccl_version((2, 27), "Need NCCL 2.27+ for shrink_group")
def requires_nccl():
return skip_but_pass_in_sandcastle_if(
not c10d.is_nccl_available(),