diff --git a/docs/source/distributed.md b/docs/source/distributed.md index 69df7be1fa80..5da02bb8a194 100644 --- a/docs/source/distributed.md +++ b/docs/source/distributed.md @@ -394,10 +394,6 @@ an opaque group handle that can be given as a `group` argument to all collective .. autofunction:: new_group ``` -```{eval-rst} -.. autofunction:: torch.distributed.distributed_c10d.shrink_group -``` - ```{eval-rst} .. autofunction:: get_group_rank ``` diff --git a/test/distributed/logging_utils.py b/test/distributed/logging_utils.py deleted file mode 100644 index 09a0adccfd80..000000000000 --- a/test/distributed/logging_utils.py +++ /dev/null @@ -1,43 +0,0 @@ -import logging -import time - - -_start_time = time.time() -_logger = logging.getLogger(__name__) - - -def _ts(): - return time.time() - _start_time - - -def configure(level=logging.INFO, force=False): - try: - logging.basicConfig( - level=level, - format="%(asctime)s %(name)s %(levelname)s: %(message)s", - force=force, - ) - except TypeError: - logging.basicConfig( - level=level, format="%(asctime)s %(name)s %(levelname)s: %(message)s" - ) - - -def log_test_info(rank, message): - _logger.info("[%7.3fs][Rank %s] %s", _ts(), rank, message) - - -def log_test_success(rank, message): - _logger.info("[%7.3fs][Rank %s] ✅ %s", _ts(), rank, message) - - -def log_test_validation(rank, message): - _logger.info("[%7.3fs][Rank %s] ✓ %s", _ts(), rank, message) - - -def log_test_warning(rank, message): - _logger.warning("[%7.3fs][Rank %s] ⚠️ %s", _ts(), rank, message) - - -def log_test_error(rank, message): - _logger.error("[%7.3fs][Rank %s] ✗ %s", _ts(), rank, message) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 149622e9445c..7410255d27a8 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -2,7 +2,6 @@ import copy import json -import logging import os import pickle import random @@ -22,7 +21,6 @@ from unittest import mock, SkipTest import torch import torch.distributed as c10d import torch.distributed._functional_collectives as _functional_collectives -from torch.distributed.distributed_c10d import SHRINK_ABORT as NCCL_SHRINK_ABORT if not c10d.is_available() or not c10d.is_nccl_available(): @@ -49,15 +47,12 @@ from torch._C._distributed_c10d import ErrorType, OpType, WorkResult from torch.nn.parallel import DistributedDataParallel from torch.testing._internal.common_cuda import _get_torch_rocm_version, TEST_MULTIGPU from torch.testing._internal.common_distributed import ( - get_required_world_size, get_timeout, init_multigpu_helper, MultiProcessTestCase, requires_multicast_support, requires_nccl, - requires_nccl_shrink, requires_nccl_version, - requires_world_size, skip_if_lt_x_gpu, skip_if_rocm_multiprocess, sm_is_or_higher_than, @@ -92,17 +87,6 @@ BFLOAT16_AVAILABLE = torch.cuda.is_available() and ( torch.version.cuda is not None or torch.version.hip is not None ) -from logging_utils import ( - configure as _log_configure, - log_test_info, - log_test_success, - log_test_validation, - log_test_warning, -) - - -_log_configure(level=logging.INFO, force=True) - class RendezvousEnvTest(TestCase): @retry_on_connect_failures @@ -333,7 +317,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase): @property def world_size(self): - return get_required_world_size(self, 2) + return 2 @property def rank_to_GPU(self): @@ -1271,628 +1255,6 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase): pg_2 = c10d.new_group([0, 1]) self.assertEqual(pg_2.group_desc, "undefined") - 
@requires_nccl_shrink() - @requires_world_size(2) - def test_shrink_group_basic(self): - """Test basic shrink_group functionality.""" - self._perform_shrink_test([1], "Basic shrink test") - - @requires_nccl_shrink() - @requires_world_size(2) - def test_shrink_group_validation(self): - """Test input validation in shrink_group.""" - device, pg = self._setup_shrink_test("validation") - - def _test_invalid_input(ranks, description, expected_exception): - """Helper to test invalid inputs.""" - try: - c10d.shrink_group(ranks) - self.fail(f"Expected {expected_exception.__name__} for {description}") - except expected_exception: - log_test_validation(self.rank, f"✓ {description}") - except Exception: - if expected_exception is Exception: # Accept any exception - log_test_validation(self.rank, f"✓ {description}") - else: - raise - - # Test cases - _test_invalid_input([], "Empty exclusion list", ValueError) - if self.world_size > 1: - _test_invalid_input([0, 0, 1], "Duplicate ranks", Exception) - _test_invalid_input([self.world_size + 1], "Out of bounds rank", Exception) - - log_test_success(self.rank, "All validation tests passed") - dist.destroy_process_group() - - @requires_nccl_shrink() - @requires_world_size(2) - def test_shrink_group_backend_properties(self): - """Test that backend properties are preserved after shrinking.""" - - test_name = "Backend Properties Test" - ranks_to_exclude = [0] - - # Reuse _setup_shrink_test for complete setup (device, environment, and process group) - device, pg = self._setup_shrink_test("backend_properties") - - # Follow _perform_shrink_test pattern from here - log_test_info(self.rank, f"{test_name} (world_size={self.world_size})") - - is_excluded = self.rank in ranks_to_exclude - log_test_info( - self.rank, - f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}", - ) - - # Store original backend property values (not references) before shrinking - original_timeout = None - original_high_priority = None - if not is_excluded: - original_backend = pg._get_backend(device) - original_timeout = original_backend.options._timeout - original_high_priority = original_backend.options.is_high_priority_stream - log_test_info( - self.rank, - f"Storing original backend properties: timeout={original_timeout}, high_priority={original_high_priority}", - ) - - if is_excluded: - log_test_info( - self.rank, - f"Excluded rank {self.rank} - setup complete, skipping shrink operation", - ) - dist.destroy_process_group() # hang without it - return - - # Only non-excluded ranks proceed with shrink (same as _perform_shrink_test) - log_test_info(self.rank, "Non-excluded rank calling shrink_group") - shrunk_pg = c10d.shrink_group(ranks_to_exclude) - - # Reuse _validate_shrunk_group helper (same as _perform_shrink_test) - expected_size = self.world_size - len(ranks_to_exclude) - _ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name) - - # Add custom backend properties validation - new_backend = shrunk_pg._get_backend(device) - log_test_info(self.rank, "Validating backend properties are preserved") - - new_timeout = new_backend.options._timeout - new_high_priority = new_backend.options.is_high_priority_stream - - log_test_info( - self.rank, - f"Timeout comparison - original: {original_timeout}, new: {new_timeout}", - ) - self.assertEqual( - original_timeout, new_timeout, f"{test_name}: timeout not preserved" - ) - - log_test_info( - self.rank, - f"High priority stream comparison - original: {original_high_priority}, new: {new_high_priority}", - ) - self.assertEqual( 
- original_high_priority, - new_high_priority, - f"{test_name}: high_priority_stream not preserved", - ) - - log_test_validation( - self.rank, f"{test_name}: Backend properties preserved successfully" - ) - log_test_success( - self.rank, f"{test_name} successful (shrink + backend validation)" - ) - - # Cleanup (same as _perform_shrink_test) - dist.destroy_process_group() - - @requires_nccl_shrink() - @requires_world_size(2) - def test_shrink_group_multiple_comms(self): - """Test shrink_group with multiple communicators and subgroup invalidation.""" - - device, pg = self._setup_shrink_test("multiple_comms") - - # Create subgroup [0, 1] and test shrinking it - subgroup = c10d.new_group([0, 1]) - if self.rank <= 1: - # Shrink subgroup: exclude rank 1 - if self.rank == 0: # Only rank 0 remains - shrunk_subgroup = c10d.shrink_group([1], group=subgroup) - self.assertEqual(shrunk_subgroup.size(), 1) - # Test communication on shrunk subgroup - tensor = torch.full((1,), self.rank).cuda(device) - c10d.all_reduce(tensor, group=shrunk_subgroup) - self.assertEqual(tensor.item(), 0) # Only rank 0 - log_test_success(self.rank, "Subgroup shrinking successful") - - dist.barrier() # Sync before default group test - - # Shrink default group: exclude last rank - ranks_to_exclude = [self.world_size - 1] - if self.rank not in ranks_to_exclude: - shrunk_default = c10d.shrink_group(ranks_to_exclude) - expected_size = self.world_size - 1 - self.assertEqual(shrunk_default.size(), expected_size) - - # Test collective on shrunk default group - tensor = torch.full((1,), self.rank).cuda(device) - c10d.all_reduce(tensor, group=shrunk_default) - expected_sum = sum( - range(self.world_size - 1) - ) # 0 + 1 + ... + (world_size-2) - self.assertEqual(tensor.item(), expected_sum) - log_test_success(self.rank, "Default group shrinking successful") - - # Note: After shrinking default group, the old subgroup is invalid - # due to global rank reassignment - - dist.destroy_process_group() - - def _test_shrink_group_with_flag(self, shrink_flag, flag_name, rank_to_exclude): - """Helper method to test shrink_group with a specific flag.""" - if self.world_size < 2: - log_test_info(self.rank, f"Skipping (needs ≥2 GPUs, got {self.world_size})") - return - ranks_to_exclude = [rank_to_exclude] - log_test_info(self.rank, f"Using {flag_name} flag (value: {shrink_flag})") - if flag_name == "NCCL_SHRINK_ABORT": - log_test_info( - self.rank, - "ABORT flag will terminate ongoing operations before shrinking", - ) - - self._perform_shrink_test( - ranks_to_exclude, f"{flag_name} flag test", shrink_flags=shrink_flag - ) - - @requires_nccl_shrink() - @requires_world_size(2) - def test_shrink_group_flags(self): - """Test shrink_group with different shrink flags.""" - # Test ABORT flags - log_test_info(self.rank, "Testing NCCL_SHRINK_ABORT flag") - self._test_shrink_group_with_flag(NCCL_SHRINK_ABORT, "NCCL_SHRINK_ABORT", 1) - - @requires_nccl_shrink() - @requires_world_size(2) - def test_shrink_group_nccl_config(self): - """Verify that passing NCCL config via pg_options influences the shrunk group's backend options.""" - device, pg = self._setup_shrink_test("config") - if self.rank == self.world_size - 1: - # excluded rank should not call shrink_group - dist.destroy_process_group() - return - - # Prepare pg_options with NCCL config overrides - # Capture parent's current backend options to ensure we can prove override vs inherit - parent_backend = pg._get_backend(torch.device("cuda")) - parent_hp = parent_backend.options.is_high_priority_stream - 
parent_blocking = parent_backend.options.config.blocking - - # Choose overrides that differ from the parent (flip where possible) - override_hp = not parent_hp - if parent_blocking in (0, 1): - override_blocking = 1 - parent_blocking - else: - # If undefined or unexpected, set to 1 which is a concrete value - override_blocking = 1 - - opts = c10d.ProcessGroupNCCL.Options() - opts.is_high_priority_stream = override_hp - opts.config.blocking = override_blocking - - shrunk_pg = c10d.shrink_group([self.world_size - 1], pg_options=opts) - - # Validate backend options propagated - backend = shrunk_pg._get_backend(torch.device("cuda")) - # is_high_priority_stream should exactly match our override and differ from parent - self.assertEqual(backend.options.is_high_priority_stream, override_hp) - self.assertNotEqual(backend.options.is_high_priority_stream, parent_hp) - # config is a struct; check representative field and difference from parent when meaningful - self.assertEqual(backend.options.config.blocking, override_blocking) - if parent_blocking in (0, 1): - self.assertNotEqual(backend.options.config.blocking, parent_blocking) - - dist.destroy_process_group() - - @requires_nccl_shrink() - @requires_world_size(2) - def test_shrink_group_performance(self): - """Test shrink_group performance and regression detection.""" - import time - - ranks_to_exclude = self._get_default_ranks_to_exclude() - is_excluded = self.rank in ranks_to_exclude - - if not ranks_to_exclude: - log_test_info(self.rank, "Skipping performance test (world_size=1)") - return - - log_test_info(self.rank, f"Performance test with {self.world_size} processes") - device, pg = self._setup_shrink_test("performance") - - if not is_excluded: - log_test_info(self.rank, "Measuring shrink_group performance") - start_time = time.time() - shrunk_pg = c10d.shrink_group(ranks_to_exclude) - end_time = time.time() - - elapsed_time = end_time - start_time - log_test_info(self.rank, f"shrink_group: {elapsed_time:.3f}s") - - # Regression check: should complete within reasonable time - self.assertLess( - elapsed_time, - 30.0, - f"shrink_group took {elapsed_time:.3f}s, possible regression", - ) - - # Test collective performance - expected_size = self.world_size - len(ranks_to_exclude) - self._validate_shrunk_group(shrunk_pg, expected_size, "performance") - - collective_start = time.time() - _ = self._test_collective_on_shrunk_group( - shrunk_pg, device, ranks_to_exclude, "performance" - ) - collective_time = time.time() - collective_start - - log_test_info(self.rank, f"all_reduce: {collective_time:.3f}s") - log_test_success(self.rank, "Performance test passed") - else: - log_test_info(self.rank, "Excluded rank - waiting") - - dist.destroy_process_group() - - @requires_nccl_shrink() - @requires_world_size(4) - def test_shrink_group_multiple_exclusions(self): - """Test shrink_group with multiple ranks excluded at once.""" - # Scale exclusions with world size - ranks_to_exclude = list(range(2, self.world_size, 2)) # Every other rank from 2 - - self._perform_shrink_test(ranks_to_exclude, "Multiple exclusions test") - - @requires_nccl_shrink() - @requires_world_size(3) - def test_shrink_group_multiple_iterations(self): - """Test multiple shrink operations in sequence.""" - log_test_info( - self.rank, - f"Starting test_shrink_group_multiple_iterations with world_size={self.world_size}", - ) - - store = c10d.FileStore(self.file_name, self.world_size) - device = torch.device(f"cuda:{self.rank}") - _ = self._create_process_group_nccl(store, self.opts(), 
device_id=device) - - # Track current effective world size throughout shrinking operations - current_world_size = self.world_size - log_test_info(self.rank, f"Initial world_size: {current_world_size}") - - # First shrinking: exclude the last rank(s) - first_exclusion = [self.world_size - 1] - if self.world_size >= 6: - first_exclusion.append( - self.world_size - 2 - ) # Exclude last two ranks for larger sizes - - log_test_info(self.rank, f"First shrinking: excluding ranks {first_exclusion}") - - if self.rank not in first_exclusion: - # Only non-excluded ranks should call shrink_group - first_pg = c10d.shrink_group(first_exclusion) - self.assertIsNotNone(first_pg) - # IMPORTANT: Update world size after first shrinking - current_world_size = first_pg.size() - expected_first_size = self.world_size - len(first_exclusion) - log_test_info( - self.rank, - f"After first shrinking: world_size {self.world_size} -> {current_world_size}", - ) - self.assertEqual(first_pg.size(), expected_first_size) - - # Second shrinking: exclude another rank from the remaining group - # Choose a rank that's in the middle range - if current_world_size >= 3: - second_exclusion = [ - current_world_size - 1 - ] # Exclude the new "last" rank - log_test_info( - self.rank, - f"Second shrinking from group of size {current_world_size}: excluding ranks {second_exclusion}", - ) - - if self.rank not in second_exclusion: - # Only non-excluded ranks should call shrink_group for second iteration - second_pg = c10d.shrink_group(second_exclusion, group=first_pg) - self.assertIsNotNone(second_pg) - # IMPORTANT: Update world size after second shrinking - final_world_size = second_pg.size() - expected_final_size = current_world_size - len(second_exclusion) - log_test_info( - self.rank, - f"After second shrinking: world_size {current_world_size} -> {final_world_size}", - ) - self.assertEqual(second_pg.size(), expected_final_size) - - # Test collective on final group - tensor = torch.full((1,), self.rank).cuda(device) - log_test_info( - self.rank, - f"Performing all_reduce on final group (size {final_world_size}) with tensor: {tensor.item()}", - ) - c10d.all_reduce(tensor, group=second_pg) - log_test_info( - self.rank, - f"Final all_reduce completed, result: {tensor.item()}", - ) - - # Calculate expected sum of remaining ranks - all_excluded = set(first_exclusion + second_exclusion) - remaining_ranks = [ - r for r in range(self.world_size) if r not in all_excluded - ] - expected_sum = sum(remaining_ranks) - log_test_info( - self.rank, - f"Remaining ranks: {remaining_ranks}, expected sum: {expected_sum}, actual: {tensor.item()}", - ) - self.assertEqual(tensor.item(), expected_sum) - log_test_info(self.rank, "Final verification passed") - else: - log_test_info( - self.rank, - "This rank excluded in second shrinking, not calling shrink_group", - ) - else: - log_test_info( - self.rank, "Skipping second shrinking (remaining group too small)" - ) - else: - log_test_info( - self.rank, - "This rank excluded in first shrinking, not calling shrink_group", - ) - - log_test_info(self.rank, "Destroying process group") - dist.destroy_process_group() - log_test_info(self.rank, "test_shrink_group_multiple_iterations completed") - - # Helper methods for optimized shrink group tests - def _setup_shrink_test(self, test_suffix, world_size=None, warmup=True): - """Common setup for shrink group tests.""" - os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1" - world_size = world_size or self.world_size - store = c10d.FileStore(self.file_name + 
f"_{test_suffix}", world_size) - device = torch.device(f"cuda:{self.rank}") - c10d.init_process_group( - "nccl", - world_size=world_size, - rank=self.rank, - store=store, - pg_options=self.opts(), - device_id=device, - ) - pg = c10d.distributed_c10d._get_default_group() - - if warmup: - c10d.all_reduce(torch.ones(1).cuda(device), group=pg) - - return device, pg - - def _validate_shrunk_group(self, shrunk_pg, expected_size, test_name=""): - """Validate properties of a shrunk process group.""" - self.assertIsNotNone(shrunk_pg, f"{test_name}: shrunk_pg should not be None") - actual_size = shrunk_pg.size() - self.assertEqual( - actual_size, expected_size, f"{test_name}: group size mismatch" - ) - - new_rank = shrunk_pg.rank() - self.assertTrue( - 0 <= new_rank < expected_size, f"{test_name}: invalid new rank {new_rank}" - ) - - log_test_info( - self.rank, - f"{test_name}: world_size {self.world_size} -> {actual_size}, rank {self.rank} -> {new_rank}", - ) - return new_rank - - def _test_collective_on_shrunk_group( - self, shrunk_pg, device, ranks_to_exclude, test_name="" - ): - """Test collective communication on shrunk group and verify correctness.""" - test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32) - c10d.all_reduce(test_tensor, group=shrunk_pg) - - result = test_tensor.item() - expected_sum = sum( - r for r in range(self.world_size) if r not in ranks_to_exclude - ) - - self.assertEqual( - result, expected_sum, f"{test_name}: collective result mismatch" - ) - log_test_info( - self.rank, f"{test_name}: collective passed ({result} == {expected_sum})" - ) - return result - - def _perform_shrink_test( - self, ranks_to_exclude, test_name, shrink_flags=0, with_collective=True - ): - """Complete shrink test flow: setup, shrink, validate, test collective, cleanup. - - Consistent API: All ranks perform setup to initialize distributed environment. - ONLY non-excluded ranks call shrink_group() for both default and non-default groups. - Excluded ranks perform setup, then exit without calling shrink_group() or waiting. 
- """ - log_test_info(self.rank, f"{test_name} (world_size={self.world_size})") - - is_excluded = self.rank in ranks_to_exclude - log_test_info( - self.rank, - f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}", - ) - - # All ranks (including excluded ones) perform setup to initialize distributed environment - device, pg = self._setup_shrink_test(test_name.lower().replace(" ", "_")) - is_default_group = pg == c10d.distributed_c10d._get_default_group() - - if is_excluded: - log_test_info( - self.rank, - f"Excluded rank {self.rank} - setup complete, skipping shrink operation", - ) - if shrink_flags & NCCL_SHRINK_ABORT: - log_test_info(self.rank, f"Using abort for excluded rank {self.rank}") - pg._get_backend(torch.device(device)).abort() - log_test_info( - self.rank, f"cleanup resources for excluded rank {self.rank}" - ) - dist.destroy_process_group() - log_test_info(self.rank, f"Excluded rank {self.rank} - exit") - else: - log_test_info( - self.rank, f"Using regular destroy for excluded rank {self.rank}" - ) - dist.destroy_process_group() - return None - - # Only non-excluded ranks proceed with shrink - log_test_info( - self.rank, - f"Non-excluded rank calling shrink_group (default_group={is_default_group})", - ) - shrunk_pg = c10d.shrink_group(ranks_to_exclude, shrink_flags=shrink_flags) - log_test_info( - self.rank, - f"Non-excluded rank calling shrink_group (default_group={is_default_group}) done", - ) - - # Non-excluded ranks: validate and test the new group - expected_size = self.world_size - len(ranks_to_exclude) - _ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name) - - if with_collective: - _ = self._test_collective_on_shrunk_group( - shrunk_pg, device, ranks_to_exclude, test_name - ) - log_test_success(self.rank, f"{test_name} successful (shrink + collective)") - else: - log_test_success(self.rank, f"{test_name} successful (shrink only)") - - dist.destroy_process_group() - return shrunk_pg - - def _get_default_ranks_to_exclude(self): - """Get default ranks to exclude based on world size.""" - if self.world_size <= 1: - return [] - return [self.world_size - 1] # Exclude last rank by default - - @requires_nccl_shrink() - @requires_world_size(3) - def test_shrink_group_vs_abort_reinit_performance(self): - """Compare performance of shrink_group vs traditional abort+reinit (simplified for reliability).""" - log_test_info(self.rank, "=== TEST 1: abort+reinit ===") - - device, pg1 = self._setup_shrink_test("_perf_reinit") - torch.cuda.synchronize(device) - - # Test 1: Traditional abort + reinit - start_time = time.perf_counter() - dist.destroy_process_group() - - device, new_pg = self._setup_shrink_test("perf_shrink_test1") - reinit_time = time.perf_counter() - start_time - - # Test collective with original rank values for fair comparison (non-blocking mode) - test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32) - work = c10d.all_reduce(test_tensor, group=new_pg, async_op=True) - work.wait() - - torch.cuda.synchronize(device) - - # Verify correctness - expected_sum = sum(r for r in range(self.world_size)) - self.assertEqual(test_tensor.item(), expected_sum, "Reinit collective failed") - - log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s") - dist.destroy_process_group(new_pg) - - # Test 2: shrink_group with NCCL_SHRINK_ABORT - log_test_info(self.rank, "=== TEST 2: shrink_group ===") - - ranks_to_exclude = [self.world_size - 1] - is_excluded = self.rank in ranks_to_exclude - log_test_info( - self.rank, - f"Excluding ranks: 
{ranks_to_exclude}, am_excluded: {is_excluded}", - ) - - device, pg1 = self._setup_shrink_test("perf_shrink_test2") # Unique suffix - - shrink_time = 0 - if not is_excluded: - torch.cuda.synchronize(device) # Ensure accurate timing - start_time = time.perf_counter() - shrunk_pg = c10d.shrink_group( - ranks_to_exclude, shrink_flags=NCCL_SHRINK_ABORT - ) - c10d.all_reduce(torch.ones(1).cuda(device), group=shrunk_pg) - shrink_time = time.perf_counter() - start_time - - # Test collective communication on shrunk group (non-blocking mode) - test_tensor = torch.full( - (1,), self.rank, device=device, dtype=torch.float32 - ) - work = c10d.all_reduce(test_tensor, group=shrunk_pg, async_op=True) - work.wait() - - # Verify correctness - expected_sum = sum( - r for r in range(self.world_size) if r not in ranks_to_exclude - ) - self.assertEqual( - test_tensor.item(), - expected_sum, - "shrink_test: collective result mismatch", - ) - - torch.cuda.synchronize(device) # Ensure operations complete - log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s") - dist.destroy_process_group() - else: - log_test_info(self.rank, "Excluded from shrink test - exiting immediately") - dist.destroy_process_group() - return - - # Performance analysis (only for participating ranks) - if shrink_time > 0 and reinit_time > 0: - speedup = reinit_time / shrink_time - time_saved = reinit_time - shrink_time - - log_test_info(self.rank, "=== PERFORMANCE RESULTS ===") - log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s") - log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s") - log_test_info(self.rank, f"time_saved: {time_saved:+.4f}s") - log_test_info(self.rank, f"speedup: {speedup:.2f}x") - - if speedup > 1.1: - log_test_success(self.rank, "shrink_group significantly faster") - elif speedup > 0.9: - log_test_info(self.rank, "≈ comparable performance") - else: - log_test_warning(self.rank, "abort+reinit faster") - - log_test_info(self.rank, "Performance test completed") - @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_deterministic_mode_no_break(self): diff --git a/torch/csrc/distributed/c10d/Backend.hpp b/torch/csrc/distributed/c10d/Backend.hpp index 1ebf9394e064..655e0a5578c2 100644 --- a/torch/csrc/distributed/c10d/Backend.hpp +++ b/torch/csrc/distributed/c10d/Backend.hpp @@ -79,23 +79,6 @@ class TORCH_API Backend : public torch::CustomClassHolder { return false; } - virtual bool supportsShrinking() const { - return false; - } - - // Shrink the backend by excluding specified ranks. Backends that support - // communicator shrinking should override this and return a new backend - // instance representing the shrunken group. Backends may use opts_override - // to supply backend-specific options for the new group. 
- virtual c10::intrusive_ptr shrink( - const std::vector& /*ranks_to_exclude*/, - int /*shrink_flags*/ = 0, - const c10::intrusive_ptr& /*opts_override*/ = nullptr) { - TORCH_CHECK( - false, - c10::str("Backend ", getBackendName(), " does not support shrink")); - } - virtual void setTimeout(std::chrono::milliseconds timeout) { TORCH_CHECK( false, diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index a41f654b9ae2..8074cc98a04f 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -259,65 +259,6 @@ std::shared_ptr NCCLComm::split( } #endif -#ifdef NCCL_HAS_COMM_SHRINK -std::shared_ptr NCCLComm::shrink( - NCCLComm* source, - std::vector& ranks_to_exclude, - ncclConfig_t* config, - int shrinkFlags) { - // Preconditions are validated in ProcessGroupNCCL::shrink - - LOG(INFO) << "Rank " << source->rank_ << ": shrinking comm " << source->repr() - << " excluding " << ranks_to_exclude.size() << " ranks"; - - at::cuda::OptionalCUDAGuard gpuGuard(source->deviceIndex_); - auto comm = std::make_shared(); - - // This call will block until the source communicator is initialized - auto sourceComm = source->getNcclComm(); - - C10D_NCCL_CHECK_NONBLOCKING( - ncclCommShrink( - sourceComm, - ranks_to_exclude.data(), - ranks_to_exclude.size(), - reinterpret_cast(&(comm->ncclComm_)), - config, - shrinkFlags), - source->getNcclCommFailureReason()); - - // Wait for the child communicator to be ready - source->waitReady(true); - comm->initialized_ = true; - - // NCCL automatically assigns rank during shrink - query it efficiently - int assigned_rank; - try { - C10D_NCCL_CHECK( - ncclCommUserRank(comm->ncclComm_, &assigned_rank), std::nullopt); - comm->rank_ = assigned_rank; - } catch (const std::exception& e) { - // Fallback: if ncclCommUserRank fails, we can't determine the rank - LOG(ERROR) << "Failed to query NCCL-assigned rank: " << e.what(); - throw; - } - - // Child comm should be on the same device as parent comm - comm->deviceIndex_ = source->deviceIndex_; - if (config != nullptr) { - comm->nonBlocking_ = config->blocking == 0; - } else { - // Inherit parent behavior if no config provided - comm->nonBlocking_ = source->nonBlocking_; - } - - LOG(INFO) << "Rank " << source->rank_ << ": created shrunken comm " - << comm->repr() << " with NCCL-assigned rank " << assigned_rank; - - return comm; -} -#endif - void NCCLComm::finalize() { LockType lock(mutex_); if (aborted_) { diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index 142633b82374..fdd50f69ef3d 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -90,10 +90,6 @@ static_assert( #define NCCL_HAS_NVLS_CTAS #endif -#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0) -#define NCCL_HAS_COMM_SHRINK -#endif - // Macro to throw on a non-successful NCCL return value. 
#define C10D_NCCL_CHECK(cmd, failureReason) \ do { \ @@ -298,14 +294,6 @@ class NCCLComm { ncclConfig_t& config); #endif // NCCL_HAS_COMM_SPLIT -#ifdef NCCL_HAS_COMM_SHRINK - static std::shared_ptr shrink( - NCCLComm* source, - std::vector& ranks_to_exclude, - ncclConfig_t* config, - int shrinkFlags = 0); -#endif // NCCL_HAS_COMM_SHRINK - #if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP) std::unordered_map ncclCommDump(); #endif diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 1a63128f8ddf..9b615b9f16b0 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -165,7 +165,7 @@ ncclRedOpRAII getNcclReduceOp( } // Get a key string from device -inline std::string getKeyFromDevice(const at::Device& device) { +inline std::string getKeyFromDevice(at::Device& device) { return std::to_string(device.index()); } @@ -5838,139 +5838,6 @@ at::Tensor ProcessGroupNCCL::allocateTensor( return tensor; } -#ifdef NCCL_HAS_COMM_SHRINK -c10::intrusive_ptr ProcessGroupNCCL::shrink( - const std::vector& ranks_to_exclude, - int shrink_flags, - const c10::intrusive_ptr& opts_override) { - // Runtime version check with better error message - auto runtime_version = torch::cuda::nccl::version(); - TORCH_CHECK( - runtime_version >= NCCL_VERSION(2, 27, 0), - "ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later. " - "Found version: ", - runtime_version); - - // Early validation with detailed error messages - TORCH_CHECK_VALUE( - !ranks_to_exclude.empty(), "ranks_to_exclude cannot be empty"); - TORCH_CHECK_VALUE( - static_cast(ranks_to_exclude.size()) < size_, - "Cannot exclude all ranks (", - ranks_to_exclude.size(), - " >= ", - size_, - ")"); - - // Validate ranks and convert to int efficiently - std::vector int_ranks_to_exclude; - int_ranks_to_exclude.reserve(ranks_to_exclude.size()); - for (int64_t rank : ranks_to_exclude) { - TORCH_CHECK_VALUE( - rank >= 0 && rank < size_, - "Invalid rank ", - rank, - " for group size ", - size_); - int_ranks_to_exclude.push_back(static_cast(rank)); - } - - // Get primary communicator with better error context - auto primary_device_index = guessDeviceId(); - auto primary_device = at::Device(at::kCUDA, primary_device_index); - const auto primary_key = getKeyFromDevice(primary_device); - - std::shared_ptr primary_comm = getNCCLComm(primary_key); - TORCH_CHECK( - primary_comm, - "Primary NCCL communicator for device ", - primary_device, - " (key: ", - primary_key, - ") is not initialized"); - - // Cache device index before shrink operation - at::DeviceIndex parent_device_index = primary_comm->getDeviceIndex(); - - ncclConfig_t* config = nullptr; - // Default to inheriting from parent options - bool high_priority_stream = options_->is_high_priority_stream; - if (opts_override) { - auto nccl_opts = - c10::static_intrusive_pointer_cast( - opts_override); - config = &nccl_opts->config; - // If user provided override options, honor is_high_priority_stream as well - high_priority_stream = nccl_opts->is_high_priority_stream; - } - - std::shared_ptr shrunk_comm = NCCLComm::shrink( - primary_comm.get(), - int_ranks_to_exclude, - (config != nullptr ? 
config : &options_->config), - shrink_flags); - - // Calculate new size and get NCCL-assigned rank - int new_size = size_ - static_cast(ranks_to_exclude.size()); - int new_rank = shrunk_comm->rank_; - - // Create new ProcessGroupNCCL with optimized options cloning - auto new_store = store_->clone(); - auto new_opts = ProcessGroupNCCL::Options::create(high_priority_stream); - new_opts->timeout = options_->timeout; - if (config != nullptr) { - new_opts->config = *config; - } else { - new_opts->config = options_->config; - } - - auto new_pg = c10::make_intrusive( - new_store, new_rank, new_size, new_opts); - - // Set up the new process group with optimized device setup - new_pg->initializeDeviceStateForComm( - at::Device(at::kCUDA, parent_device_index), shrunk_comm); - - return c10::static_intrusive_pointer_cast(new_pg); -} - -#else // !NCCL_HAS_COMM_SHRINK -// Backend interface override: raise consistent error when shrink is -// unsupported. -c10::intrusive_ptr ProcessGroupNCCL::shrink( - const std::vector& /*ranks_to_exclude*/, - int /*shrink_flags*/, - const c10::intrusive_ptr& /*opts_override*/) { - TORCH_CHECK( - false, - "ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later, " - "but PyTorch was built with an older version or without NCCL shrink support."); -} - -#endif // NCCL_HAS_COMM_SHRINK - -void ProcessGroupNCCL::initializeDeviceStateForComm( - const at::Device& device, - std::shared_ptr comm) { - const auto key = getKeyFromDevice(device); - std::unique_lock lock(mutex_); - at::cuda::OptionalCUDAGuard gpuGuard(device); - - bool force_high = getCvarBool(TORCH_NCCL_HIGH_PRIORITY, false); - auto stream = at::cuda::getStreamFromPool( - options_->is_high_priority_stream || force_high); - - devNCCLCommMap_[key] = comm; - ncclStreams_.emplace(key, stream); - ncclEvents_.emplace(key, at::cuda::CUDAEvent(cudaEventDisableTiming)); - usedDeviceIdxs_.insert(device.index()); - - if (shouldAllCommunicatorsRegisterAllTensors()) { - std::lock_guard map_lock(ncclCommMemPoolMapMutex); - ncclCommMemPoolMap.emplace(std::move(comm), MemPoolSet{}); - } -} - } // namespace c10d #endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 2ead1a107394..286eab14d1a8 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -997,21 +997,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { ErrorType getError() override; - bool supportsShrinking() const override { -#ifdef NCCL_HAS_COMM_SHRINK - return true; -#else - return false; -#endif - } - - // Backend-style shrink override that returns a Backend instance. - c10::intrusive_ptr shrink( - const std::vector& ranks_to_exclude, - int shrink_flags = 0, - const c10::intrusive_ptr& opts_override = - nullptr) override; - std::shared_ptr getMemAllocator() override; // Allocate tensor from communication-optimized memory pool @@ -1080,12 +1065,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { int p2pRank = 0, bool isSendRecvSelf = false); - // Initialize device-specific state (comm, stream, event, bookkeeping) for a - // given communicator on this process group instance. - void initializeDeviceStateForComm( - const at::Device& device, - std::shared_ptr comm); - // Wrapper method which can be overridden for tests. 
virtual std::exception_ptr checkForNCCLErrors( std::shared_ptr& ncclComm); diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index f7d60e0cb62d..bdf2576efbe7 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -2730,23 +2730,12 @@ Arguments: "supports_time_estimate", &::c10d::Backend::supportsTimeEstimation, "(test whether the backend supports collective time estimation)") - .def_property_readonly( - "supports_shrinking", - &::c10d::Backend::supportsShrinking, - "(test whether the backend supports communicator shrinking)") .def( "set_timeout", &::c10d::Backend::setTimeout, py::arg("timeout"), py::call_guard(), R"(Sets the default timeout for all future operations.)") - .def( - "shrink", - &::c10d::Backend::shrink, - py::arg("ranks_to_exclude"), - py::arg("shrink_flags") = 0, - py::arg("opts_override") = nullptr, - py::call_guard()) .def( "broadcast", &::c10d::Backend::broadcast, diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 9156f51b6ad9..c39847176517 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -130,7 +130,6 @@ __all__ = [ "reduce_scatter_tensor", "get_node_local_rank", "split_group", - "shrink_group", ] _MPI_AVAILABLE = True @@ -5714,517 +5713,3 @@ def _get_process_group_name(pg: ProcessGroup) -> str: def _get_process_group_store(pg: ProcessGroup) -> Store: return _world.pg_map[pg][1] - - -# Shrink flags for process group backends -SHRINK_DEFAULT = 0x00 -SHRINK_ABORT = 0x01 - - -@_time_logger -def shrink_group( - ranks_to_exclude: list[int], - group: Optional[ProcessGroup] = None, - shrink_flags: int = SHRINK_DEFAULT, - pg_options: Optional[Any] = None, -) -> ProcessGroup: - """ - Shrinks a process group by excluding specified ranks. - - Creates and returns a new, smaller process group comprising only the ranks - from the original group that were not in the ``ranks_to_exclude`` list. - - Args: - ranks_to_exclude (List[int]): A list of ranks from the original - ``group`` to exclude from the new group. - group (ProcessGroup, optional): The process group to shrink. If ``None``, - the default process group is used. Defaults to ``None``. - shrink_flags (int, optional): Flags to control the shrinking behavior. - Can be ``SHRINK_DEFAULT`` (default) or ``SHRINK_ABORT``. - ``SHRINK_ABORT`` will attempt to terminate ongoing operations - in the parent communicator before shrinking. - Defaults to ``SHRINK_DEFAULT``. - pg_options (ProcessGroupOptions, optional): Backend-specific options to apply - to the shrunken process group. If provided, the backend will use - these options when creating the new group. If omitted, the new group - inherits defaults from the parent. - - Returns: - ProcessGroup: a new group comprised of the remaining ranks. If the - default group was shrunk, the returned group becomes the new default group. - - Raises: - TypeError: if the group’s backend does not support shrinking. - ValueError: if ``ranks_to_exclude`` is invalid (empty, out of bounds, - duplicates, or excludes all ranks). - RuntimeError: if an excluded rank calls this function or the backend - fails the operation. - - Notes: - - Only non-excluded ranks should call this function; excluded ranks - must not participate in the shrink operation. - - Shrinking the default group destroys all other process groups since - rank reassignment makes them inconsistent. 
- """ - # Step 1: Validate input parameters with comprehensive error checking - _validate_shrink_inputs(ranks_to_exclude, shrink_flags) - - # Step 2: Get target group and essential properties - target_group_info = _prepare_shrink_target_group(group) - - # Step 3: Validate backend requirements and availability - backend_impl = _validate_shrink_backend_requirements(target_group_info) - - # Step 4: Validate ranks against group and check for duplicates - excluded_ranks_set = _validate_and_process_excluded_ranks( - ranks_to_exclude, target_group_info - ) - - # Step 5: Execute the actual shrink operation (backend-specific) - new_backend = backend_impl.shrink( - sorted(excluded_ranks_set), - shrink_flags, - pg_options if pg_options is not None else None, - ) - - # Step 6: Handle cleanup and creation of new process group - target_group_info["pg_options_override"] = pg_options - return _finalize_shrunk_group(target_group_info, excluded_ranks_set, new_backend) - - -def _validate_shrink_inputs(ranks_to_exclude: list[int], shrink_flags: int) -> None: - """Validate input parameters for shrink_group.""" - if not isinstance(ranks_to_exclude, list): - raise TypeError( - f"ranks_to_exclude must be a list, but got {type(ranks_to_exclude).__name__}. " - f"Example: [1, 3, 5] to exclude ranks 1, 3, and 5." - ) - - if not ranks_to_exclude: - raise ValueError( - "ranks_to_exclude cannot be empty. To shrink a group, you must specify at least " - "one rank to exclude. Example: [failed_rank_id]" - ) - - # Validate shrink_flags with clear explanation of valid values - valid_flags = [SHRINK_DEFAULT, SHRINK_ABORT] - if not isinstance(shrink_flags, int) or shrink_flags not in valid_flags: - raise ValueError( - f"Invalid shrink_flags value: {shrink_flags}. Must be one of: " - f"SHRINK_DEFAULT ({SHRINK_DEFAULT}) or SHRINK_ABORT ({SHRINK_ABORT}). " - f"Use SHRINK_ABORT to abort ongoing operations before shrinking." - ) - - -def _prepare_shrink_target_group(group: Optional[ProcessGroup]) -> dict: - """Prepare and validate the target group for shrinking.""" - target_pg = group if group is not None else _get_default_group() - - # Cache frequently accessed properties to avoid repeated calls - group_size = int(target_pg.size()) - group_info = { - "process_group": target_pg, - "is_default_group": (target_pg == _get_default_group()), - "group_size": group_size, - "current_rank": target_pg.rank(), - "group_name": _get_process_group_name(target_pg), - } - - # Validate that we have a valid process group - if group_size <= 1: - raise ValueError( - f"Cannot shrink a process group with size {group_size}. " - f"Group must have at least 2 ranks to support shrinking." - ) - - return group_info - - -def _validate_shrink_backend_requirements(group_info: dict) -> Any: - """Return the backend implementation for the target group or raise if unsupported.""" - target_pg = group_info["process_group"] - group_name = group_info["group_name"] - - # Get the group's backend directly via ProcessGroup API. Prefer a bound device if present, - # otherwise try CUDA then fall back to CPU. 
- try: - preferred_device = getattr(target_pg, "bound_device_id", None) - if preferred_device is not None: - backend_impl = target_pg._get_backend(preferred_device) - else: - # Try CUDA first if available, else CPU - try: - backend_impl = target_pg._get_backend(torch.device("cuda")) - except Exception: - backend_impl = target_pg._get_backend(torch.device("cpu")) - except RuntimeError as e: - raise RuntimeError( - f"Cannot access device backend for process group '{group_name}'. " - f"Ensure the process group was initialized with a compatible device backend and devices are available." - ) from e - - try: - supports = bool(backend_impl.supports_shrinking) - except Exception: - supports = False - if not supports: - raise TypeError( - f"Process group backend for '{group_name}' does not support shrinking operations." - ) - - return backend_impl - - -def _validate_and_process_excluded_ranks( - ranks_to_exclude: list[int], group_info: dict -) -> set: - """Validate excluded ranks and convert to set for efficient operations.""" - group_size = group_info["group_size"] - current_rank = group_info["current_rank"] - - # Use set for O(1) duplicate detection and membership testing - excluded_ranks_set = set() - - # Validate each rank with detailed error messages - for i, rank in enumerate(ranks_to_exclude): - if not isinstance(rank, int): - raise TypeError( - f"All elements in ranks_to_exclude must be integers. " - f"Element at index {i} is {type(rank).__name__}: {rank}" - ) - - if not (0 <= rank < group_size): - raise ValueError( - f"Rank {rank} at index {i} is out of bounds for group size {group_size}. " - f"Valid ranks are in range [0, {group_size - 1}]." - ) - - if rank in excluded_ranks_set: - raise ValueError( - f"Duplicate rank {rank} found in ranks_to_exclude at index {i}. " - f"Each rank can only be excluded once." - ) - - excluded_ranks_set.add(rank) - - # Ensure we don't exclude all ranks - if len(excluded_ranks_set) >= group_size: - raise ValueError( - f"Cannot exclude all {group_size} ranks from process group. " - f"At least one rank must remain. Excluding {len(excluded_ranks_set)} ranks." - ) - - # Critical check: current rank should not be in excluded list - if current_rank in excluded_ranks_set: - raise RuntimeError( - f"Current rank {current_rank} is in the exclusion list and should not call shrink_group(). " - f"Only non-excluded ranks should participate in the shrinking operation. " - f"Excluded ranks should terminate their processes instead." 
- ) - - return excluded_ranks_set - - -def _finalize_shrunk_group( - group_info: dict, excluded_ranks_set: set, new_backend -) -> ProcessGroup: - """Clean up old group and create new shrunk process group.""" - target_pg = group_info["process_group"] - is_default_group = group_info["is_default_group"] - - # Handle default group dependencies - destroy other groups first - if is_default_group: - _destroy_all_other_groups(exclude_group=target_pg) - - # Gather original group metadata before cleanup - original_group_metadata = _extract_group_metadata(target_pg) - - # Calculate remaining ranks efficiently - original_ranks = get_process_group_ranks(target_pg) - remaining_ranks = [ - rank for rank in original_ranks if rank not in excluded_ranks_set - ] - - # Clean up the original group - _cleanup_original_group(target_pg, is_default_group) - - # Create and configure the new process group - new_pg = _create_shrunk_process_group( - new_backend, remaining_ranks, original_group_metadata, is_default_group - ) - - # Register the new group in global state - if is_default_group: - _update_default_pg(new_pg) - - # Update global state with new group information - rank_mapping = { - global_rank: group_rank - for group_rank, global_rank in enumerate(remaining_ranks) - } - _update_process_group_global_state( - pg=new_pg, - backend_name=original_group_metadata["backend_name"], - store=original_group_metadata["store"], - group_name=original_group_metadata["new_group_name"], - backend_config=original_group_metadata["backend_config"], - rank_mapping=rank_mapping, - ) - - return new_pg - - -def _extract_group_metadata(target_pg: ProcessGroup) -> dict: - """Extract metadata from the original group before cleanup.""" - original_backend_name, original_store = _world.pg_map[target_pg] - original_backend_config = _world.pg_backend_config.get(target_pg, "") - original_group_name = _get_process_group_name(target_pg) - - # Extract device binding information before cleanup to avoid accessing destroyed group - bound_device_id = None - if hasattr(target_pg, "bound_device_id"): - bound_device_id = target_pg.bound_device_id - - # Generate new group name for the shrunk group; hash for uniqueness across backends - remaining_ranks = list(get_process_group_ranks(target_pg)) - new_group_name = _process_group_name(remaining_ranks, use_hashed_name=True) - - return { - "backend_name": original_backend_name, - "store": original_store, - "backend_config": original_backend_config, - "original_group_name": original_group_name, - "new_group_name": new_group_name, - "bound_device_id": bound_device_id, # Safe to access after cleanup - } - - -def _cleanup_original_group(target_pg: ProcessGroup, is_default_group: bool) -> None: - """Clean up the original process group safely.""" - try: - destroy_process_group(target_pg) - except Exception as e: - group_type = "default" if is_default_group else "non-default" - logger.warning("Failed to destroy %s group during shrinking: %s", group_type, e) - - # Ensure global state cleanup even if destroy_process_group fails - _cleanup_process_group_global_state(target_pg) - - -def _create_shrunk_process_group( - new_backend, remaining_ranks: list[int], metadata: dict, is_default_group: bool -) -> ProcessGroup: - """Create and configure the new shrunk process group.""" - # Create new group properties - new_group_rank = new_backend.rank() - new_group_size = new_backend.size() - group_name = metadata["new_group_name"] - - # Generate descriptive group description - if is_default_group: - group_desc = 
"default:shrunken" - else: - group_desc = f"{metadata['original_group_name']}:shrunk" - - # Create process group with new communicator (clone the parent store like split does) - prefix_store = PrefixStore(f"{group_name}/", metadata["store"].clone()) - new_pg = ProcessGroup(prefix_store, new_group_rank, new_group_size) - - # Configure backend using the device type of the new backend's bound device if available, - # otherwise derive from the original group's bound device or fall back to CPU. - backend_device = metadata.get("bound_device_id") - if backend_device is None: - # Default to CPU if no bound device is present - backend_device = torch.device("cpu") - - # Choose backend enum based on device type - if backend_device.type == "cuda": - backend_type = ProcessGroup.BackendType.NCCL - else: - backend_type = ProcessGroup.BackendType.GLOO - - new_pg._register_backend(backend_device, backend_type, new_backend) - new_pg._set_default_backend(backend_type) - - # Inherit device binding from original group if it was bound - bound_device_id = metadata.get("bound_device_id") - if bound_device_id is not None: - new_pg.bound_device_id = bound_device_id - - # Set group metadata - new_pg._set_group_name(group_name) - new_pg._set_group_desc(group_desc) - - # Persist backend configuration overrides (if provided via shrink_group) - backend_config_override = metadata.get("backend_config") - if backend_config_override is not None: - # Store for introspection/debugging and potential backend hooks - _world.pg_backend_config[new_pg] = backend_config_override - - return new_pg - - -def _destroy_all_other_groups(exclude_group: Optional[ProcessGroup] = None) -> None: - """ - Destroy all process groups except the excluded group and clean up all global state. - - This is necessary when shrinking the default group because global ranks - are reassigned by NCCL, making all existing process groups inconsistent. - - Note: Uses abort for non-collective cleanup since excluded ranks may not - participate in collective operations. Backend cleanup is handled independently per group. - - Args: - exclude_group (ProcessGroup, optional): Process group to exclude from destruction. - If None, destroys all process groups. - """ - # Get list of groups to destroy (avoid modifying dict while iterating) - groups_to_destroy = [] - for pg in list(_world.pg_group_ranks.keys()): - if exclude_group is not None and pg == exclude_group: - continue - groups_to_destroy.append(pg) - - # Warn user about automatic destruction - if groups_to_destroy: - group_names = [_get_process_group_name(pg) for pg in groups_to_destroy] - logger.warning( - "Shrinking default group will destroy %d other process groups: %s. 
" - "This is necessary because shrinking the default group reassigns global ranks, " - "making existing groups inconsistent.", - len(groups_to_destroy), - ", ".join(group_names), - ) - - # Destroy each group and clean up global state - for pg in groups_to_destroy: - try: - # First call abort_process_group which handles the C++ cleanup non-collectively - _abort_process_group(pg) - except Exception as e: - # Log but don't fail - some groups might already be destroyed - logger.warning( - "Failed to abort process group %s: %s", - _get_process_group_name(pg), - e, - ) - - # Ensure all global state is cleaned up even if _abort_process_group fails - # or doesn't clean up everything - _cleanup_process_group_global_state(pg) - - -def _cleanup_process_group_global_state(pg: ProcessGroup) -> None: - """ - Clean up all global state associated with a process group. - - This function ensures complete cleanup of process group state from all - global dictionaries and registries, even if destroy_process_group fails - or doesn't clean up everything. This is critical when destroying multiple - groups to prevent inconsistent state. - - The cleanup removes the process group from: - - _world.pg_map (backend and store mapping) - - _world.pg_names (group name mapping) - - _world.pg_group_ranks (rank mappings) - - _world.pg_backend_config (backend configuration) - - _world.tags_to_pg and _world.pg_to_tag (tag mappings) - - _world.pg_coalesce_state (coalescing state) - - C++ internal registries via _unregister_process_group - - Args: - pg (ProcessGroup): The process group to clean up. - """ - try: - # Clean up main process group mappings - _world.pg_map.pop(pg, None) - _world.pg_group_ranks.pop(pg, None) - _world.pg_backend_config.pop(pg, None) - - # Clean up process group name mapping - group_name = _world.pg_names.pop(pg, None) - - # Clean up tag mappings - pg_tag = _world.pg_to_tag.pop(pg, None) - if pg_tag is not None and pg_tag in _world.tags_to_pg: - try: - _world.tags_to_pg[pg_tag].remove(pg) - # Remove the tag entry if list is empty - if not _world.tags_to_pg[pg_tag]: - _world.tags_to_pg.pop(pg_tag, None) - except (ValueError, KeyError): - # Process group was already removed from the list - pass - - # Clean up any registered process group names using C++ unregister function - if group_name is not None: - try: - _unregister_process_group(group_name) - except Exception: - # Process group name might not be registered or already unregistered - pass - - # Clean up coalesce state if present - _world.pg_coalesce_state.pop(pg, None) - - except Exception as e: - # Log cleanup failures but don't propagate - we want to continue with other cleanups - logger.warning("Failed to fully clean up global state for process group: %s", e) - - -def _update_process_group_global_state( - pg: ProcessGroup, - backend_name: str, - store: Store, - group_name: str, - backend_config: str, - rank_mapping: Optional[dict[int, int]] = None, - pg_tag: Optional[str] = None, - user_tag: Optional[str] = None, -) -> None: - """ - Update all global state dictionaries for a process group. - - This helper function consolidates the common pattern of updating multiple - global state dictionaries when creating or modifying process groups. - - Args: - pg (ProcessGroup): The process group to update state for. - backend_name (str): Backend name for pg_map. - store (Store): Store instance for pg_map. - group_name (str): Group name for pg_names and registration. - backend_config (str): Backend configuration string. 
- rank_mapping (Dict[int, int], optional): Global rank to group rank mapping. - If None, skips updating pg_group_ranks. - pg_tag (str, optional): Process group tag. If None, defaults to f"ptd:{group_name}". - user_tag (str, optional): User-provided tag for special tag handling. - If provided, creates "user:{user_tag}" tag and also adds to default "". - """ - # Update main process group mappings - _world.pg_map[pg] = (backend_name, store) - _world.pg_names[pg] = group_name - _world.pg_backend_config[pg] = backend_config - - # Register the process group name - _register_process_group(group_name, pg) - - # Update rank mapping if provided - if rank_mapping is not None: - _world.pg_group_ranks[pg] = rank_mapping - - # Handle tag management - if pg_tag is None: - pg_tag = f"ptd:{group_name}" - - if user_tag is not None: - # Special handling for user-provided tags - # Add to default "" tag first - _world.tags_to_pg.setdefault("", []).append(pg) - # Then create user-specific tag - user_pg_tag = f"user:{user_tag}" - _world.tags_to_pg.setdefault(user_pg_tag, []).append(pg) - _world.pg_to_tag[pg] = user_pg_tag - else: - # Standard process group tag - _world.tags_to_pg.setdefault(pg_tag, []).append(pg) - _world.pg_to_tag[pg] = pg_tag diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 91f09adf9e81..18384b311b93 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -238,47 +238,6 @@ def skip_if_lt_x_gpu(x): return decorator -def requires_world_size(n: int): - """ - Decorator to request a specific world size for a test. The test harness can - read this attribute to set the number of ranks to spawn. If there are fewer - than `n` CUDA devices available, the test should be skipped by the harness. - - Usage: - @require_world_size(3) - def test_something(self): - ... - """ - - def decorator(func): - func._required_world_size = n - available = torch.cuda.device_count() - return unittest.skipUnless( - available >= n, f"requires {n} GPUs, found {available}" - )(func) - - return decorator - - -def get_required_world_size(obj: Any, default: int) -> int: - """ - Returns the requested world size for the currently running unittest method on `obj` - if annotated via `@require_world_size(n)`, else returns `default`. - """ - try: - # Try MultiProcessTestCase helper first, then unittest fallback - test_name = ( - obj._current_test_name() # type: ignore[attr-defined] - if hasattr(obj, "_current_test_name") and callable(obj._current_test_name) - else obj._testMethodName - ) - fn = getattr(obj, test_name) - value = fn._required_world_size - return int(value) - except Exception: - return default - - # This decorator helps avoiding initializing cuda while testing other backends def nccl_skip_if_lt_x_gpu(backend, x): def decorator(func): @@ -408,13 +367,6 @@ def requires_nccl_version(version, msg): ) -def requires_nccl_shrink(): - """ - Require NCCL shrink support (NCCL available and version >= 2.27). - """ - return requires_nccl_version((2, 27), "Need NCCL 2.27+ for shrink_group") - - def requires_nccl(): return skip_but_pass_in_sandcastle_if( not c10d.is_nccl_available(),
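For reference, the `requires_world_size` / `get_required_world_size` pair deleted from `common_distributed.py` above was consumed by test classes roughly as sketched below. This is a minimal illustration based on the deleted helpers and the reverted `world_size` property in `test_c10d_nccl.py`; the class and test names are hypothetical.

```python
# Hypothetical test class showing how the pre-revert helpers were wired up.
# MultiProcessTestCase is the same base class used by test_c10d_nccl.py above.
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    get_required_world_size,  # pre-revert helper
    requires_world_size,      # pre-revert helper
)


class ExampleNCCLTest(MultiProcessTestCase):
    @property
    def world_size(self):
        # Default to 2 ranks unless the currently running test method was
        # annotated with a larger requirement via @requires_world_size(n).
        return get_required_world_size(self, 2)

    @requires_world_size(4)
    def test_needs_four_ranks(self):
        # Skipped automatically when fewer than 4 CUDA devices are visible.
        ...
```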
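Similarly, the `shrink_group` API removed by this patch was driven by the deleted tests roughly as follows. This is a minimal sketch of the surviving-rank / excluded-rank split used by `_perform_shrink_test`; it assumes a build that still ships `shrink_group` (NCCL >= 2.27), an already-initialized NCCL default group with one GPU per rank, and the helper name and choice of excluded rank are illustrative.

```python
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import SHRINK_ABORT, shrink_group


def shrink_out_last_rank(rank: int, world_size: int) -> None:
    device = torch.device(f"cuda:{rank}")
    ranks_to_exclude = [world_size - 1]

    if rank in ranks_to_exclude:
        # Excluded ranks must not call shrink_group(); they simply tear down.
        dist.destroy_process_group()
        return

    # Surviving ranks collectively shrink the default group; SHRINK_ABORT
    # additionally aborts in-flight work on the parent communicator first.
    new_pg = shrink_group(ranks_to_exclude, shrink_flags=SHRINK_ABORT)
    assert new_pg.size() == world_size - len(ranks_to_exclude)

    # Sanity-check the shrunk group with a collective (sum of surviving ranks).
    t = torch.full((1,), rank, device=device, dtype=torch.float32)
    dist.all_reduce(t, group=new_pg)
    assert t.item() == sum(r for r in range(world_size) if r not in ranks_to_exclude)

    dist.destroy_process_group()
```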