From fa0db212e717b6cb225159cb32ea3d83baa52381 Mon Sep 17 00:00:00 2001 From: Bruce Chang Date: Sun, 19 Oct 2025 18:00:08 +0000 Subject: [PATCH] shrink_group implementation to expose ncclCommShrink API (#164518) Closes #164529 To expose the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink) API to PyTorch. This is useful when you need to exclude certain GPUs or nodes from a collective operation, for example in fault tolerance scenarios or when dynamically adjusting resource utilization. For more info: [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator) Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518 Approved by: https://github.com/kwen2501 --- docs/source/distributed.md | 4 + test/distributed/logging_utils.py | 43 ++ test/distributed/test_c10d_nccl.py | 640 +++++++++++++++++- torch/csrc/distributed/c10d/Backend.hpp | 17 + torch/csrc/distributed/c10d/NCCLUtils.cpp | 59 ++ torch/csrc/distributed/c10d/NCCLUtils.hpp | 12 + .../distributed/c10d/ProcessGroupNCCL.cpp | 135 +++- .../distributed/c10d/ProcessGroupNCCL.hpp | 21 + torch/csrc/distributed/c10d/init.cpp | 11 + torch/distributed/distributed_c10d.py | 515 ++++++++++++++ torch/testing/_internal/common_distributed.py | 48 ++ 11 files changed, 1503 insertions(+), 2 deletions(-) create mode 100644 test/distributed/logging_utils.py diff --git a/docs/source/distributed.md b/docs/source/distributed.md index 5da02bb8a194..69df7be1fa80 100644 --- a/docs/source/distributed.md +++ b/docs/source/distributed.md @@ -394,6 +394,10 @@ an opaque group handle that can be given as a `group` argument to all collective .. autofunction:: new_group ``` +```{eval-rst} +.. autofunction:: torch.distributed.distributed_c10d.shrink_group +``` + ```{eval-rst} .. 
autofunction:: get_group_rank ``` diff --git a/test/distributed/logging_utils.py b/test/distributed/logging_utils.py new file mode 100644 index 000000000000..09a0adccfd80 --- /dev/null +++ b/test/distributed/logging_utils.py @@ -0,0 +1,43 @@ +import logging +import time + + +_start_time = time.time() +_logger = logging.getLogger(__name__) + + +def _ts(): + return time.time() - _start_time + + +def configure(level=logging.INFO, force=False): + try: + logging.basicConfig( + level=level, + format="%(asctime)s %(name)s %(levelname)s: %(message)s", + force=force, + ) + except TypeError: + logging.basicConfig( + level=level, format="%(asctime)s %(name)s %(levelname)s: %(message)s" + ) + + +def log_test_info(rank, message): + _logger.info("[%7.3fs][Rank %s] %s", _ts(), rank, message) + + +def log_test_success(rank, message): + _logger.info("[%7.3fs][Rank %s] ✅ %s", _ts(), rank, message) + + +def log_test_validation(rank, message): + _logger.info("[%7.3fs][Rank %s] ✓ %s", _ts(), rank, message) + + +def log_test_warning(rank, message): + _logger.warning("[%7.3fs][Rank %s] ⚠️ %s", _ts(), rank, message) + + +def log_test_error(rank, message): + _logger.error("[%7.3fs][Rank %s] ✗ %s", _ts(), rank, message) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 7410255d27a8..149622e9445c 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -2,6 +2,7 @@ import copy import json +import logging import os import pickle import random @@ -21,6 +22,7 @@ from unittest import mock, SkipTest import torch import torch.distributed as c10d import torch.distributed._functional_collectives as _functional_collectives +from torch.distributed.distributed_c10d import SHRINK_ABORT as NCCL_SHRINK_ABORT if not c10d.is_available() or not c10d.is_nccl_available(): @@ -47,12 +49,15 @@ from torch._C._distributed_c10d import ErrorType, OpType, WorkResult from torch.nn.parallel import DistributedDataParallel from torch.testing._internal.common_cuda import _get_torch_rocm_version, TEST_MULTIGPU from torch.testing._internal.common_distributed import ( + get_required_world_size, get_timeout, init_multigpu_helper, MultiProcessTestCase, requires_multicast_support, requires_nccl, + requires_nccl_shrink, requires_nccl_version, + requires_world_size, skip_if_lt_x_gpu, skip_if_rocm_multiprocess, sm_is_or_higher_than, @@ -87,6 +92,17 @@ BFLOAT16_AVAILABLE = torch.cuda.is_available() and ( torch.version.cuda is not None or torch.version.hip is not None ) +from logging_utils import ( + configure as _log_configure, + log_test_info, + log_test_success, + log_test_validation, + log_test_warning, +) + + +_log_configure(level=logging.INFO, force=True) + class RendezvousEnvTest(TestCase): @retry_on_connect_failures @@ -317,7 +333,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase): @property def world_size(self): - return 2 + return get_required_world_size(self, 2) @property def rank_to_GPU(self): @@ -1255,6 +1271,628 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase): pg_2 = c10d.new_group([0, 1]) self.assertEqual(pg_2.group_desc, "undefined") + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_basic(self): + """Test basic shrink_group functionality.""" + self._perform_shrink_test([1], "Basic shrink test") + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_validation(self): + """Test input validation in shrink_group.""" + device, pg = self._setup_shrink_test("validation") + + def _test_invalid_input(ranks, 
description, expected_exception): + """Helper to test invalid inputs.""" + try: + c10d.shrink_group(ranks) + self.fail(f"Expected {expected_exception.__name__} for {description}") + except expected_exception: + log_test_validation(self.rank, f"✓ {description}") + except Exception: + if expected_exception is Exception: # Accept any exception + log_test_validation(self.rank, f"✓ {description}") + else: + raise + + # Test cases + _test_invalid_input([], "Empty exclusion list", ValueError) + if self.world_size > 1: + _test_invalid_input([0, 0, 1], "Duplicate ranks", Exception) + _test_invalid_input([self.world_size + 1], "Out of bounds rank", Exception) + + log_test_success(self.rank, "All validation tests passed") + dist.destroy_process_group() + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_backend_properties(self): + """Test that backend properties are preserved after shrinking.""" + + test_name = "Backend Properties Test" + ranks_to_exclude = [0] + + # Reuse _setup_shrink_test for complete setup (device, environment, and process group) + device, pg = self._setup_shrink_test("backend_properties") + + # Follow _perform_shrink_test pattern from here + log_test_info(self.rank, f"{test_name} (world_size={self.world_size})") + + is_excluded = self.rank in ranks_to_exclude + log_test_info( + self.rank, + f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}", + ) + + # Store original backend property values (not references) before shrinking + original_timeout = None + original_high_priority = None + if not is_excluded: + original_backend = pg._get_backend(device) + original_timeout = original_backend.options._timeout + original_high_priority = original_backend.options.is_high_priority_stream + log_test_info( + self.rank, + f"Storing original backend properties: timeout={original_timeout}, high_priority={original_high_priority}", + ) + + if is_excluded: + log_test_info( + self.rank, + f"Excluded rank {self.rank} - setup complete, skipping shrink operation", + ) + dist.destroy_process_group() # hang without it + return + + # Only non-excluded ranks proceed with shrink (same as _perform_shrink_test) + log_test_info(self.rank, "Non-excluded rank calling shrink_group") + shrunk_pg = c10d.shrink_group(ranks_to_exclude) + + # Reuse _validate_shrunk_group helper (same as _perform_shrink_test) + expected_size = self.world_size - len(ranks_to_exclude) + _ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name) + + # Add custom backend properties validation + new_backend = shrunk_pg._get_backend(device) + log_test_info(self.rank, "Validating backend properties are preserved") + + new_timeout = new_backend.options._timeout + new_high_priority = new_backend.options.is_high_priority_stream + + log_test_info( + self.rank, + f"Timeout comparison - original: {original_timeout}, new: {new_timeout}", + ) + self.assertEqual( + original_timeout, new_timeout, f"{test_name}: timeout not preserved" + ) + + log_test_info( + self.rank, + f"High priority stream comparison - original: {original_high_priority}, new: {new_high_priority}", + ) + self.assertEqual( + original_high_priority, + new_high_priority, + f"{test_name}: high_priority_stream not preserved", + ) + + log_test_validation( + self.rank, f"{test_name}: Backend properties preserved successfully" + ) + log_test_success( + self.rank, f"{test_name} successful (shrink + backend validation)" + ) + + # Cleanup (same as _perform_shrink_test) + dist.destroy_process_group() + + @requires_nccl_shrink() + 
@requires_world_size(2) + def test_shrink_group_multiple_comms(self): + """Test shrink_group with multiple communicators and subgroup invalidation.""" + + device, pg = self._setup_shrink_test("multiple_comms") + + # Create subgroup [0, 1] and test shrinking it + subgroup = c10d.new_group([0, 1]) + if self.rank <= 1: + # Shrink subgroup: exclude rank 1 + if self.rank == 0: # Only rank 0 remains + shrunk_subgroup = c10d.shrink_group([1], group=subgroup) + self.assertEqual(shrunk_subgroup.size(), 1) + # Test communication on shrunk subgroup + tensor = torch.full((1,), self.rank).cuda(device) + c10d.all_reduce(tensor, group=shrunk_subgroup) + self.assertEqual(tensor.item(), 0) # Only rank 0 + log_test_success(self.rank, "Subgroup shrinking successful") + + dist.barrier() # Sync before default group test + + # Shrink default group: exclude last rank + ranks_to_exclude = [self.world_size - 1] + if self.rank not in ranks_to_exclude: + shrunk_default = c10d.shrink_group(ranks_to_exclude) + expected_size = self.world_size - 1 + self.assertEqual(shrunk_default.size(), expected_size) + + # Test collective on shrunk default group + tensor = torch.full((1,), self.rank).cuda(device) + c10d.all_reduce(tensor, group=shrunk_default) + expected_sum = sum( + range(self.world_size - 1) + ) # 0 + 1 + ... + (world_size-2) + self.assertEqual(tensor.item(), expected_sum) + log_test_success(self.rank, "Default group shrinking successful") + + # Note: After shrinking default group, the old subgroup is invalid + # due to global rank reassignment + + dist.destroy_process_group() + + def _test_shrink_group_with_flag(self, shrink_flag, flag_name, rank_to_exclude): + """Helper method to test shrink_group with a specific flag.""" + if self.world_size < 2: + log_test_info(self.rank, f"Skipping (needs ≥2 GPUs, got {self.world_size})") + return + ranks_to_exclude = [rank_to_exclude] + log_test_info(self.rank, f"Using {flag_name} flag (value: {shrink_flag})") + if flag_name == "NCCL_SHRINK_ABORT": + log_test_info( + self.rank, + "ABORT flag will terminate ongoing operations before shrinking", + ) + + self._perform_shrink_test( + ranks_to_exclude, f"{flag_name} flag test", shrink_flags=shrink_flag + ) + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_flags(self): + """Test shrink_group with different shrink flags.""" + # Test ABORT flags + log_test_info(self.rank, "Testing NCCL_SHRINK_ABORT flag") + self._test_shrink_group_with_flag(NCCL_SHRINK_ABORT, "NCCL_SHRINK_ABORT", 1) + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_nccl_config(self): + """Verify that passing NCCL config via pg_options influences the shrunk group's backend options.""" + device, pg = self._setup_shrink_test("config") + if self.rank == self.world_size - 1: + # excluded rank should not call shrink_group + dist.destroy_process_group() + return + + # Prepare pg_options with NCCL config overrides + # Capture parent's current backend options to ensure we can prove override vs inherit + parent_backend = pg._get_backend(torch.device("cuda")) + parent_hp = parent_backend.options.is_high_priority_stream + parent_blocking = parent_backend.options.config.blocking + + # Choose overrides that differ from the parent (flip where possible) + override_hp = not parent_hp + if parent_blocking in (0, 1): + override_blocking = 1 - parent_blocking + else: + # If undefined or unexpected, set to 1 which is a concrete value + override_blocking = 1 + + opts = c10d.ProcessGroupNCCL.Options() + opts.is_high_priority_stream = 
override_hp + opts.config.blocking = override_blocking + + shrunk_pg = c10d.shrink_group([self.world_size - 1], pg_options=opts) + + # Validate backend options propagated + backend = shrunk_pg._get_backend(torch.device("cuda")) + # is_high_priority_stream should exactly match our override and differ from parent + self.assertEqual(backend.options.is_high_priority_stream, override_hp) + self.assertNotEqual(backend.options.is_high_priority_stream, parent_hp) + # config is a struct; check representative field and difference from parent when meaningful + self.assertEqual(backend.options.config.blocking, override_blocking) + if parent_blocking in (0, 1): + self.assertNotEqual(backend.options.config.blocking, parent_blocking) + + dist.destroy_process_group() + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_performance(self): + """Test shrink_group performance and regression detection.""" + import time + + ranks_to_exclude = self._get_default_ranks_to_exclude() + is_excluded = self.rank in ranks_to_exclude + + if not ranks_to_exclude: + log_test_info(self.rank, "Skipping performance test (world_size=1)") + return + + log_test_info(self.rank, f"Performance test with {self.world_size} processes") + device, pg = self._setup_shrink_test("performance") + + if not is_excluded: + log_test_info(self.rank, "Measuring shrink_group performance") + start_time = time.time() + shrunk_pg = c10d.shrink_group(ranks_to_exclude) + end_time = time.time() + + elapsed_time = end_time - start_time + log_test_info(self.rank, f"shrink_group: {elapsed_time:.3f}s") + + # Regression check: should complete within reasonable time + self.assertLess( + elapsed_time, + 30.0, + f"shrink_group took {elapsed_time:.3f}s, possible regression", + ) + + # Test collective performance + expected_size = self.world_size - len(ranks_to_exclude) + self._validate_shrunk_group(shrunk_pg, expected_size, "performance") + + collective_start = time.time() + _ = self._test_collective_on_shrunk_group( + shrunk_pg, device, ranks_to_exclude, "performance" + ) + collective_time = time.time() - collective_start + + log_test_info(self.rank, f"all_reduce: {collective_time:.3f}s") + log_test_success(self.rank, "Performance test passed") + else: + log_test_info(self.rank, "Excluded rank - waiting") + + dist.destroy_process_group() + + @requires_nccl_shrink() + @requires_world_size(4) + def test_shrink_group_multiple_exclusions(self): + """Test shrink_group with multiple ranks excluded at once.""" + # Scale exclusions with world size + ranks_to_exclude = list(range(2, self.world_size, 2)) # Every other rank from 2 + + self._perform_shrink_test(ranks_to_exclude, "Multiple exclusions test") + + @requires_nccl_shrink() + @requires_world_size(3) + def test_shrink_group_multiple_iterations(self): + """Test multiple shrink operations in sequence.""" + log_test_info( + self.rank, + f"Starting test_shrink_group_multiple_iterations with world_size={self.world_size}", + ) + + store = c10d.FileStore(self.file_name, self.world_size) + device = torch.device(f"cuda:{self.rank}") + _ = self._create_process_group_nccl(store, self.opts(), device_id=device) + + # Track current effective world size throughout shrinking operations + current_world_size = self.world_size + log_test_info(self.rank, f"Initial world_size: {current_world_size}") + + # First shrinking: exclude the last rank(s) + first_exclusion = [self.world_size - 1] + if self.world_size >= 6: + first_exclusion.append( + self.world_size - 2 + ) # Exclude last two ranks for larger sizes + + 
log_test_info(self.rank, f"First shrinking: excluding ranks {first_exclusion}") + + if self.rank not in first_exclusion: + # Only non-excluded ranks should call shrink_group + first_pg = c10d.shrink_group(first_exclusion) + self.assertIsNotNone(first_pg) + # IMPORTANT: Update world size after first shrinking + current_world_size = first_pg.size() + expected_first_size = self.world_size - len(first_exclusion) + log_test_info( + self.rank, + f"After first shrinking: world_size {self.world_size} -> {current_world_size}", + ) + self.assertEqual(first_pg.size(), expected_first_size) + + # Second shrinking: exclude another rank from the remaining group + # Choose a rank that's in the middle range + if current_world_size >= 3: + second_exclusion = [ + current_world_size - 1 + ] # Exclude the new "last" rank + log_test_info( + self.rank, + f"Second shrinking from group of size {current_world_size}: excluding ranks {second_exclusion}", + ) + + if self.rank not in second_exclusion: + # Only non-excluded ranks should call shrink_group for second iteration + second_pg = c10d.shrink_group(second_exclusion, group=first_pg) + self.assertIsNotNone(second_pg) + # IMPORTANT: Update world size after second shrinking + final_world_size = second_pg.size() + expected_final_size = current_world_size - len(second_exclusion) + log_test_info( + self.rank, + f"After second shrinking: world_size {current_world_size} -> {final_world_size}", + ) + self.assertEqual(second_pg.size(), expected_final_size) + + # Test collective on final group + tensor = torch.full((1,), self.rank).cuda(device) + log_test_info( + self.rank, + f"Performing all_reduce on final group (size {final_world_size}) with tensor: {tensor.item()}", + ) + c10d.all_reduce(tensor, group=second_pg) + log_test_info( + self.rank, + f"Final all_reduce completed, result: {tensor.item()}", + ) + + # Calculate expected sum of remaining ranks + all_excluded = set(first_exclusion + second_exclusion) + remaining_ranks = [ + r for r in range(self.world_size) if r not in all_excluded + ] + expected_sum = sum(remaining_ranks) + log_test_info( + self.rank, + f"Remaining ranks: {remaining_ranks}, expected sum: {expected_sum}, actual: {tensor.item()}", + ) + self.assertEqual(tensor.item(), expected_sum) + log_test_info(self.rank, "Final verification passed") + else: + log_test_info( + self.rank, + "This rank excluded in second shrinking, not calling shrink_group", + ) + else: + log_test_info( + self.rank, "Skipping second shrinking (remaining group too small)" + ) + else: + log_test_info( + self.rank, + "This rank excluded in first shrinking, not calling shrink_group", + ) + + log_test_info(self.rank, "Destroying process group") + dist.destroy_process_group() + log_test_info(self.rank, "test_shrink_group_multiple_iterations completed") + + # Helper methods for optimized shrink group tests + def _setup_shrink_test(self, test_suffix, world_size=None, warmup=True): + """Common setup for shrink group tests.""" + os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1" + world_size = world_size or self.world_size + store = c10d.FileStore(self.file_name + f"_{test_suffix}", world_size) + device = torch.device(f"cuda:{self.rank}") + c10d.init_process_group( + "nccl", + world_size=world_size, + rank=self.rank, + store=store, + pg_options=self.opts(), + device_id=device, + ) + pg = c10d.distributed_c10d._get_default_group() + + if warmup: + c10d.all_reduce(torch.ones(1).cuda(device), group=pg) + + return device, pg + + def _validate_shrunk_group(self, shrunk_pg, expected_size, 
test_name=""): + """Validate properties of a shrunk process group.""" + self.assertIsNotNone(shrunk_pg, f"{test_name}: shrunk_pg should not be None") + actual_size = shrunk_pg.size() + self.assertEqual( + actual_size, expected_size, f"{test_name}: group size mismatch" + ) + + new_rank = shrunk_pg.rank() + self.assertTrue( + 0 <= new_rank < expected_size, f"{test_name}: invalid new rank {new_rank}" + ) + + log_test_info( + self.rank, + f"{test_name}: world_size {self.world_size} -> {actual_size}, rank {self.rank} -> {new_rank}", + ) + return new_rank + + def _test_collective_on_shrunk_group( + self, shrunk_pg, device, ranks_to_exclude, test_name="" + ): + """Test collective communication on shrunk group and verify correctness.""" + test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32) + c10d.all_reduce(test_tensor, group=shrunk_pg) + + result = test_tensor.item() + expected_sum = sum( + r for r in range(self.world_size) if r not in ranks_to_exclude + ) + + self.assertEqual( + result, expected_sum, f"{test_name}: collective result mismatch" + ) + log_test_info( + self.rank, f"{test_name}: collective passed ({result} == {expected_sum})" + ) + return result + + def _perform_shrink_test( + self, ranks_to_exclude, test_name, shrink_flags=0, with_collective=True + ): + """Complete shrink test flow: setup, shrink, validate, test collective, cleanup. + + Consistent API: All ranks perform setup to initialize distributed environment. + ONLY non-excluded ranks call shrink_group() for both default and non-default groups. + Excluded ranks perform setup, then exit without calling shrink_group() or waiting. + """ + log_test_info(self.rank, f"{test_name} (world_size={self.world_size})") + + is_excluded = self.rank in ranks_to_exclude + log_test_info( + self.rank, + f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}", + ) + + # All ranks (including excluded ones) perform setup to initialize distributed environment + device, pg = self._setup_shrink_test(test_name.lower().replace(" ", "_")) + is_default_group = pg == c10d.distributed_c10d._get_default_group() + + if is_excluded: + log_test_info( + self.rank, + f"Excluded rank {self.rank} - setup complete, skipping shrink operation", + ) + if shrink_flags & NCCL_SHRINK_ABORT: + log_test_info(self.rank, f"Using abort for excluded rank {self.rank}") + pg._get_backend(torch.device(device)).abort() + log_test_info( + self.rank, f"cleanup resources for excluded rank {self.rank}" + ) + dist.destroy_process_group() + log_test_info(self.rank, f"Excluded rank {self.rank} - exit") + else: + log_test_info( + self.rank, f"Using regular destroy for excluded rank {self.rank}" + ) + dist.destroy_process_group() + return None + + # Only non-excluded ranks proceed with shrink + log_test_info( + self.rank, + f"Non-excluded rank calling shrink_group (default_group={is_default_group})", + ) + shrunk_pg = c10d.shrink_group(ranks_to_exclude, shrink_flags=shrink_flags) + log_test_info( + self.rank, + f"Non-excluded rank calling shrink_group (default_group={is_default_group}) done", + ) + + # Non-excluded ranks: validate and test the new group + expected_size = self.world_size - len(ranks_to_exclude) + _ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name) + + if with_collective: + _ = self._test_collective_on_shrunk_group( + shrunk_pg, device, ranks_to_exclude, test_name + ) + log_test_success(self.rank, f"{test_name} successful (shrink + collective)") + else: + log_test_success(self.rank, f"{test_name} successful (shrink only)") 
+ + dist.destroy_process_group() + return shrunk_pg + + def _get_default_ranks_to_exclude(self): + """Get default ranks to exclude based on world size.""" + if self.world_size <= 1: + return [] + return [self.world_size - 1] # Exclude last rank by default + + @requires_nccl_shrink() + @requires_world_size(3) + def test_shrink_group_vs_abort_reinit_performance(self): + """Compare performance of shrink_group vs traditional abort+reinit (simplified for reliability).""" + log_test_info(self.rank, "=== TEST 1: abort+reinit ===") + + device, pg1 = self._setup_shrink_test("_perf_reinit") + torch.cuda.synchronize(device) + + # Test 1: Traditional abort + reinit + start_time = time.perf_counter() + dist.destroy_process_group() + + device, new_pg = self._setup_shrink_test("perf_shrink_test1") + reinit_time = time.perf_counter() - start_time + + # Test collective with original rank values for fair comparison (non-blocking mode) + test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32) + work = c10d.all_reduce(test_tensor, group=new_pg, async_op=True) + work.wait() + + torch.cuda.synchronize(device) + + # Verify correctness + expected_sum = sum(r for r in range(self.world_size)) + self.assertEqual(test_tensor.item(), expected_sum, "Reinit collective failed") + + log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s") + dist.destroy_process_group(new_pg) + + # Test 2: shrink_group with NCCL_SHRINK_ABORT + log_test_info(self.rank, "=== TEST 2: shrink_group ===") + + ranks_to_exclude = [self.world_size - 1] + is_excluded = self.rank in ranks_to_exclude + log_test_info( + self.rank, + f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}", + ) + + device, pg1 = self._setup_shrink_test("perf_shrink_test2") # Unique suffix + + shrink_time = 0 + if not is_excluded: + torch.cuda.synchronize(device) # Ensure accurate timing + start_time = time.perf_counter() + shrunk_pg = c10d.shrink_group( + ranks_to_exclude, shrink_flags=NCCL_SHRINK_ABORT + ) + c10d.all_reduce(torch.ones(1).cuda(device), group=shrunk_pg) + shrink_time = time.perf_counter() - start_time + + # Test collective communication on shrunk group (non-blocking mode) + test_tensor = torch.full( + (1,), self.rank, device=device, dtype=torch.float32 + ) + work = c10d.all_reduce(test_tensor, group=shrunk_pg, async_op=True) + work.wait() + + # Verify correctness + expected_sum = sum( + r for r in range(self.world_size) if r not in ranks_to_exclude + ) + self.assertEqual( + test_tensor.item(), + expected_sum, + "shrink_test: collective result mismatch", + ) + + torch.cuda.synchronize(device) # Ensure operations complete + log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s") + dist.destroy_process_group() + else: + log_test_info(self.rank, "Excluded from shrink test - exiting immediately") + dist.destroy_process_group() + return + + # Performance analysis (only for participating ranks) + if shrink_time > 0 and reinit_time > 0: + speedup = reinit_time / shrink_time + time_saved = reinit_time - shrink_time + + log_test_info(self.rank, "=== PERFORMANCE RESULTS ===") + log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s") + log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s") + log_test_info(self.rank, f"time_saved: {time_saved:+.4f}s") + log_test_info(self.rank, f"speedup: {speedup:.2f}x") + + if speedup > 1.1: + log_test_success(self.rank, "shrink_group significantly faster") + elif speedup > 0.9: + log_test_info(self.rank, "≈ comparable performance") + else: + log_test_warning(self.rank, 
"abort+reinit faster") + + log_test_info(self.rank, "Performance test completed") + @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_deterministic_mode_no_break(self): diff --git a/torch/csrc/distributed/c10d/Backend.hpp b/torch/csrc/distributed/c10d/Backend.hpp index 655e0a5578c2..1ebf9394e064 100644 --- a/torch/csrc/distributed/c10d/Backend.hpp +++ b/torch/csrc/distributed/c10d/Backend.hpp @@ -79,6 +79,23 @@ class TORCH_API Backend : public torch::CustomClassHolder { return false; } + virtual bool supportsShrinking() const { + return false; + } + + // Shrink the backend by excluding specified ranks. Backends that support + // communicator shrinking should override this and return a new backend + // instance representing the shrunken group. Backends may use opts_override + // to supply backend-specific options for the new group. + virtual c10::intrusive_ptr shrink( + const std::vector& /*ranks_to_exclude*/, + int /*shrink_flags*/ = 0, + const c10::intrusive_ptr& /*opts_override*/ = nullptr) { + TORCH_CHECK( + false, + c10::str("Backend ", getBackendName(), " does not support shrink")); + } + virtual void setTimeout(std::chrono::milliseconds timeout) { TORCH_CHECK( false, diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 8074cc98a04f..a41f654b9ae2 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -259,6 +259,65 @@ std::shared_ptr NCCLComm::split( } #endif +#ifdef NCCL_HAS_COMM_SHRINK +std::shared_ptr NCCLComm::shrink( + NCCLComm* source, + std::vector& ranks_to_exclude, + ncclConfig_t* config, + int shrinkFlags) { + // Preconditions are validated in ProcessGroupNCCL::shrink + + LOG(INFO) << "Rank " << source->rank_ << ": shrinking comm " << source->repr() + << " excluding " << ranks_to_exclude.size() << " ranks"; + + at::cuda::OptionalCUDAGuard gpuGuard(source->deviceIndex_); + auto comm = std::make_shared(); + + // This call will block until the source communicator is initialized + auto sourceComm = source->getNcclComm(); + + C10D_NCCL_CHECK_NONBLOCKING( + ncclCommShrink( + sourceComm, + ranks_to_exclude.data(), + ranks_to_exclude.size(), + reinterpret_cast(&(comm->ncclComm_)), + config, + shrinkFlags), + source->getNcclCommFailureReason()); + + // Wait for the child communicator to be ready + source->waitReady(true); + comm->initialized_ = true; + + // NCCL automatically assigns rank during shrink - query it efficiently + int assigned_rank; + try { + C10D_NCCL_CHECK( + ncclCommUserRank(comm->ncclComm_, &assigned_rank), std::nullopt); + comm->rank_ = assigned_rank; + } catch (const std::exception& e) { + // Fallback: if ncclCommUserRank fails, we can't determine the rank + LOG(ERROR) << "Failed to query NCCL-assigned rank: " << e.what(); + throw; + } + + // Child comm should be on the same device as parent comm + comm->deviceIndex_ = source->deviceIndex_; + if (config != nullptr) { + comm->nonBlocking_ = config->blocking == 0; + } else { + // Inherit parent behavior if no config provided + comm->nonBlocking_ = source->nonBlocking_; + } + + LOG(INFO) << "Rank " << source->rank_ << ": created shrunken comm " + << comm->repr() << " with NCCL-assigned rank " << assigned_rank; + + return comm; +} +#endif + void NCCLComm::finalize() { LockType lock(mutex_); if (aborted_) { diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index fdd50f69ef3d..142633b82374 100644 --- 
a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -90,6 +90,10 @@ static_assert( #define NCCL_HAS_NVLS_CTAS #endif +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0) +#define NCCL_HAS_COMM_SHRINK +#endif + // Macro to throw on a non-successful NCCL return value. #define C10D_NCCL_CHECK(cmd, failureReason) \ do { \ @@ -294,6 +298,14 @@ class NCCLComm { ncclConfig_t& config); #endif // NCCL_HAS_COMM_SPLIT +#ifdef NCCL_HAS_COMM_SHRINK + static std::shared_ptr shrink( + NCCLComm* source, + std::vector& ranks_to_exclude, + ncclConfig_t* config, + int shrinkFlags = 0); +#endif // NCCL_HAS_COMM_SHRINK + #if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP) std::unordered_map ncclCommDump(); #endif diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 9b615b9f16b0..1a63128f8ddf 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -165,7 +165,7 @@ ncclRedOpRAII getNcclReduceOp( } // Get a key string from device -inline std::string getKeyFromDevice(at::Device& device) { +inline std::string getKeyFromDevice(const at::Device& device) { return std::to_string(device.index()); } @@ -5838,6 +5838,139 @@ at::Tensor ProcessGroupNCCL::allocateTensor( return tensor; } +#ifdef NCCL_HAS_COMM_SHRINK +c10::intrusive_ptr ProcessGroupNCCL::shrink( + const std::vector& ranks_to_exclude, + int shrink_flags, + const c10::intrusive_ptr& opts_override) { + // Runtime version check with better error message + auto runtime_version = torch::cuda::nccl::version(); + TORCH_CHECK( + runtime_version >= NCCL_VERSION(2, 27, 0), + "ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later. " + "Found version: ", + runtime_version); + + // Early validation with detailed error messages + TORCH_CHECK_VALUE( + !ranks_to_exclude.empty(), "ranks_to_exclude cannot be empty"); + TORCH_CHECK_VALUE( + static_cast(ranks_to_exclude.size()) < size_, + "Cannot exclude all ranks (", + ranks_to_exclude.size(), + " >= ", + size_, + ")"); + + // Validate ranks and convert to int efficiently + std::vector int_ranks_to_exclude; + int_ranks_to_exclude.reserve(ranks_to_exclude.size()); + for (int64_t rank : ranks_to_exclude) { + TORCH_CHECK_VALUE( + rank >= 0 && rank < size_, + "Invalid rank ", + rank, + " for group size ", + size_); + int_ranks_to_exclude.push_back(static_cast(rank)); + } + + // Get primary communicator with better error context + auto primary_device_index = guessDeviceId(); + auto primary_device = at::Device(at::kCUDA, primary_device_index); + const auto primary_key = getKeyFromDevice(primary_device); + + std::shared_ptr primary_comm = getNCCLComm(primary_key); + TORCH_CHECK( + primary_comm, + "Primary NCCL communicator for device ", + primary_device, + " (key: ", + primary_key, + ") is not initialized"); + + // Cache device index before shrink operation + at::DeviceIndex parent_device_index = primary_comm->getDeviceIndex(); + + ncclConfig_t* config = nullptr; + // Default to inheriting from parent options + bool high_priority_stream = options_->is_high_priority_stream; + if (opts_override) { + auto nccl_opts = + c10::static_intrusive_pointer_cast( + opts_override); + config = &nccl_opts->config; + // If user provided override options, honor is_high_priority_stream as well + high_priority_stream = nccl_opts->is_high_priority_stream; + } + + std::shared_ptr shrunk_comm = NCCLComm::shrink( + primary_comm.get(), + int_ranks_to_exclude, + 
(config != nullptr ? config : &options_->config), + shrink_flags); + + // Calculate new size and get NCCL-assigned rank + int new_size = size_ - static_cast(ranks_to_exclude.size()); + int new_rank = shrunk_comm->rank_; + + // Create new ProcessGroupNCCL with optimized options cloning + auto new_store = store_->clone(); + auto new_opts = ProcessGroupNCCL::Options::create(high_priority_stream); + new_opts->timeout = options_->timeout; + if (config != nullptr) { + new_opts->config = *config; + } else { + new_opts->config = options_->config; + } + + auto new_pg = c10::make_intrusive( + new_store, new_rank, new_size, new_opts); + + // Set up the new process group with optimized device setup + new_pg->initializeDeviceStateForComm( + at::Device(at::kCUDA, parent_device_index), shrunk_comm); + + return c10::static_intrusive_pointer_cast(new_pg); +} + +#else // !NCCL_HAS_COMM_SHRINK +// Backend interface override: raise consistent error when shrink is +// unsupported. +c10::intrusive_ptr ProcessGroupNCCL::shrink( + const std::vector& /*ranks_to_exclude*/, + int /*shrink_flags*/, + const c10::intrusive_ptr& /*opts_override*/) { + TORCH_CHECK( + false, + "ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later, " + "but PyTorch was built with an older version or without NCCL shrink support."); +} + +#endif // NCCL_HAS_COMM_SHRINK + +void ProcessGroupNCCL::initializeDeviceStateForComm( + const at::Device& device, + std::shared_ptr comm) { + const auto key = getKeyFromDevice(device); + std::unique_lock lock(mutex_); + at::cuda::OptionalCUDAGuard gpuGuard(device); + + bool force_high = getCvarBool(TORCH_NCCL_HIGH_PRIORITY, false); + auto stream = at::cuda::getStreamFromPool( + options_->is_high_priority_stream || force_high); + + devNCCLCommMap_[key] = comm; + ncclStreams_.emplace(key, stream); + ncclEvents_.emplace(key, at::cuda::CUDAEvent(cudaEventDisableTiming)); + usedDeviceIdxs_.insert(device.index()); + + if (shouldAllCommunicatorsRegisterAllTensors()) { + std::lock_guard map_lock(ncclCommMemPoolMapMutex); + ncclCommMemPoolMap.emplace(std::move(comm), MemPoolSet{}); + } +} + } // namespace c10d #endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 286eab14d1a8..2ead1a107394 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -997,6 +997,21 @@ class TORCH_API ProcessGroupNCCL : public Backend { ErrorType getError() override; + bool supportsShrinking() const override { +#ifdef NCCL_HAS_COMM_SHRINK + return true; +#else + return false; +#endif + } + + // Backend-style shrink override that returns a Backend instance. + c10::intrusive_ptr shrink( + const std::vector& ranks_to_exclude, + int shrink_flags = 0, + const c10::intrusive_ptr& opts_override = + nullptr) override; + std::shared_ptr getMemAllocator() override; // Allocate tensor from communication-optimized memory pool @@ -1065,6 +1080,12 @@ class TORCH_API ProcessGroupNCCL : public Backend { int p2pRank = 0, bool isSendRecvSelf = false); + // Initialize device-specific state (comm, stream, event, bookkeeping) for a + // given communicator on this process group instance. + void initializeDeviceStateForComm( + const at::Device& device, + std::shared_ptr comm); + // Wrapper method which can be overridden for tests. 
virtual std::exception_ptr checkForNCCLErrors( std::shared_ptr& ncclComm); diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index bdf2576efbe7..f7d60e0cb62d 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -2730,12 +2730,23 @@ Arguments: "supports_time_estimate", &::c10d::Backend::supportsTimeEstimation, "(test whether the backend supports collective time estimation)") + .def_property_readonly( + "supports_shrinking", + &::c10d::Backend::supportsShrinking, + "(test whether the backend supports communicator shrinking)") .def( "set_timeout", &::c10d::Backend::setTimeout, py::arg("timeout"), py::call_guard(), R"(Sets the default timeout for all future operations.)") + .def( + "shrink", + &::c10d::Backend::shrink, + py::arg("ranks_to_exclude"), + py::arg("shrink_flags") = 0, + py::arg("opts_override") = nullptr, + py::call_guard()) .def( "broadcast", &::c10d::Backend::broadcast, diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index c39847176517..9156f51b6ad9 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -130,6 +130,7 @@ __all__ = [ "reduce_scatter_tensor", "get_node_local_rank", "split_group", + "shrink_group", ] _MPI_AVAILABLE = True @@ -5713,3 +5714,517 @@ def _get_process_group_name(pg: ProcessGroup) -> str: def _get_process_group_store(pg: ProcessGroup) -> Store: return _world.pg_map[pg][1] + + +# Shrink flags for process group backends +SHRINK_DEFAULT = 0x00 +SHRINK_ABORT = 0x01 + + +@_time_logger +def shrink_group( + ranks_to_exclude: list[int], + group: Optional[ProcessGroup] = None, + shrink_flags: int = SHRINK_DEFAULT, + pg_options: Optional[Any] = None, +) -> ProcessGroup: + """ + Shrinks a process group by excluding specified ranks. + + Creates and returns a new, smaller process group comprising only the ranks + from the original group that were not in the ``ranks_to_exclude`` list. + + Args: + ranks_to_exclude (List[int]): A list of ranks from the original + ``group`` to exclude from the new group. + group (ProcessGroup, optional): The process group to shrink. If ``None``, + the default process group is used. Defaults to ``None``. + shrink_flags (int, optional): Flags to control the shrinking behavior. + Can be ``SHRINK_DEFAULT`` (default) or ``SHRINK_ABORT``. + ``SHRINK_ABORT`` will attempt to terminate ongoing operations + in the parent communicator before shrinking. + Defaults to ``SHRINK_DEFAULT``. + pg_options (ProcessGroupOptions, optional): Backend-specific options to apply + to the shrunken process group. If provided, the backend will use + these options when creating the new group. If omitted, the new group + inherits defaults from the parent. + + Returns: + ProcessGroup: a new group comprised of the remaining ranks. If the + default group was shrunk, the returned group becomes the new default group. + + Raises: + TypeError: if the group’s backend does not support shrinking. + ValueError: if ``ranks_to_exclude`` is invalid (empty, out of bounds, + duplicates, or excludes all ranks). + RuntimeError: if an excluded rank calls this function or the backend + fails the operation. + + Notes: + - Only non-excluded ranks should call this function; excluded ranks + must not participate in the shrink operation. + - Shrinking the default group destroys all other process groups since + rank reassignment makes them inconsistent. 
+ """ + # Step 1: Validate input parameters with comprehensive error checking + _validate_shrink_inputs(ranks_to_exclude, shrink_flags) + + # Step 2: Get target group and essential properties + target_group_info = _prepare_shrink_target_group(group) + + # Step 3: Validate backend requirements and availability + backend_impl = _validate_shrink_backend_requirements(target_group_info) + + # Step 4: Validate ranks against group and check for duplicates + excluded_ranks_set = _validate_and_process_excluded_ranks( + ranks_to_exclude, target_group_info + ) + + # Step 5: Execute the actual shrink operation (backend-specific) + new_backend = backend_impl.shrink( + sorted(excluded_ranks_set), + shrink_flags, + pg_options if pg_options is not None else None, + ) + + # Step 6: Handle cleanup and creation of new process group + target_group_info["pg_options_override"] = pg_options + return _finalize_shrunk_group(target_group_info, excluded_ranks_set, new_backend) + + +def _validate_shrink_inputs(ranks_to_exclude: list[int], shrink_flags: int) -> None: + """Validate input parameters for shrink_group.""" + if not isinstance(ranks_to_exclude, list): + raise TypeError( + f"ranks_to_exclude must be a list, but got {type(ranks_to_exclude).__name__}. " + f"Example: [1, 3, 5] to exclude ranks 1, 3, and 5." + ) + + if not ranks_to_exclude: + raise ValueError( + "ranks_to_exclude cannot be empty. To shrink a group, you must specify at least " + "one rank to exclude. Example: [failed_rank_id]" + ) + + # Validate shrink_flags with clear explanation of valid values + valid_flags = [SHRINK_DEFAULT, SHRINK_ABORT] + if not isinstance(shrink_flags, int) or shrink_flags not in valid_flags: + raise ValueError( + f"Invalid shrink_flags value: {shrink_flags}. Must be one of: " + f"SHRINK_DEFAULT ({SHRINK_DEFAULT}) or SHRINK_ABORT ({SHRINK_ABORT}). " + f"Use SHRINK_ABORT to abort ongoing operations before shrinking." + ) + + +def _prepare_shrink_target_group(group: Optional[ProcessGroup]) -> dict: + """Prepare and validate the target group for shrinking.""" + target_pg = group if group is not None else _get_default_group() + + # Cache frequently accessed properties to avoid repeated calls + group_size = int(target_pg.size()) + group_info = { + "process_group": target_pg, + "is_default_group": (target_pg == _get_default_group()), + "group_size": group_size, + "current_rank": target_pg.rank(), + "group_name": _get_process_group_name(target_pg), + } + + # Validate that we have a valid process group + if group_size <= 1: + raise ValueError( + f"Cannot shrink a process group with size {group_size}. " + f"Group must have at least 2 ranks to support shrinking." + ) + + return group_info + + +def _validate_shrink_backend_requirements(group_info: dict) -> Any: + """Return the backend implementation for the target group or raise if unsupported.""" + target_pg = group_info["process_group"] + group_name = group_info["group_name"] + + # Get the group's backend directly via ProcessGroup API. Prefer a bound device if present, + # otherwise try CUDA then fall back to CPU. 
+ try: + preferred_device = getattr(target_pg, "bound_device_id", None) + if preferred_device is not None: + backend_impl = target_pg._get_backend(preferred_device) + else: + # Try CUDA first if available, else CPU + try: + backend_impl = target_pg._get_backend(torch.device("cuda")) + except Exception: + backend_impl = target_pg._get_backend(torch.device("cpu")) + except RuntimeError as e: + raise RuntimeError( + f"Cannot access device backend for process group '{group_name}'. " + f"Ensure the process group was initialized with a compatible device backend and devices are available." + ) from e + + try: + supports = bool(backend_impl.supports_shrinking) + except Exception: + supports = False + if not supports: + raise TypeError( + f"Process group backend for '{group_name}' does not support shrinking operations." + ) + + return backend_impl + + +def _validate_and_process_excluded_ranks( + ranks_to_exclude: list[int], group_info: dict +) -> set: + """Validate excluded ranks and convert to set for efficient operations.""" + group_size = group_info["group_size"] + current_rank = group_info["current_rank"] + + # Use set for O(1) duplicate detection and membership testing + excluded_ranks_set = set() + + # Validate each rank with detailed error messages + for i, rank in enumerate(ranks_to_exclude): + if not isinstance(rank, int): + raise TypeError( + f"All elements in ranks_to_exclude must be integers. " + f"Element at index {i} is {type(rank).__name__}: {rank}" + ) + + if not (0 <= rank < group_size): + raise ValueError( + f"Rank {rank} at index {i} is out of bounds for group size {group_size}. " + f"Valid ranks are in range [0, {group_size - 1}]." + ) + + if rank in excluded_ranks_set: + raise ValueError( + f"Duplicate rank {rank} found in ranks_to_exclude at index {i}. " + f"Each rank can only be excluded once." + ) + + excluded_ranks_set.add(rank) + + # Ensure we don't exclude all ranks + if len(excluded_ranks_set) >= group_size: + raise ValueError( + f"Cannot exclude all {group_size} ranks from process group. " + f"At least one rank must remain. Excluding {len(excluded_ranks_set)} ranks." + ) + + # Critical check: current rank should not be in excluded list + if current_rank in excluded_ranks_set: + raise RuntimeError( + f"Current rank {current_rank} is in the exclusion list and should not call shrink_group(). " + f"Only non-excluded ranks should participate in the shrinking operation. " + f"Excluded ranks should terminate their processes instead." 
+ ) + + return excluded_ranks_set + + +def _finalize_shrunk_group( + group_info: dict, excluded_ranks_set: set, new_backend +) -> ProcessGroup: + """Clean up old group and create new shrunk process group.""" + target_pg = group_info["process_group"] + is_default_group = group_info["is_default_group"] + + # Handle default group dependencies - destroy other groups first + if is_default_group: + _destroy_all_other_groups(exclude_group=target_pg) + + # Gather original group metadata before cleanup + original_group_metadata = _extract_group_metadata(target_pg) + + # Calculate remaining ranks efficiently + original_ranks = get_process_group_ranks(target_pg) + remaining_ranks = [ + rank for rank in original_ranks if rank not in excluded_ranks_set + ] + + # Clean up the original group + _cleanup_original_group(target_pg, is_default_group) + + # Create and configure the new process group + new_pg = _create_shrunk_process_group( + new_backend, remaining_ranks, original_group_metadata, is_default_group + ) + + # Register the new group in global state + if is_default_group: + _update_default_pg(new_pg) + + # Update global state with new group information + rank_mapping = { + global_rank: group_rank + for group_rank, global_rank in enumerate(remaining_ranks) + } + _update_process_group_global_state( + pg=new_pg, + backend_name=original_group_metadata["backend_name"], + store=original_group_metadata["store"], + group_name=original_group_metadata["new_group_name"], + backend_config=original_group_metadata["backend_config"], + rank_mapping=rank_mapping, + ) + + return new_pg + + +def _extract_group_metadata(target_pg: ProcessGroup) -> dict: + """Extract metadata from the original group before cleanup.""" + original_backend_name, original_store = _world.pg_map[target_pg] + original_backend_config = _world.pg_backend_config.get(target_pg, "") + original_group_name = _get_process_group_name(target_pg) + + # Extract device binding information before cleanup to avoid accessing destroyed group + bound_device_id = None + if hasattr(target_pg, "bound_device_id"): + bound_device_id = target_pg.bound_device_id + + # Generate new group name for the shrunk group; hash for uniqueness across backends + remaining_ranks = list(get_process_group_ranks(target_pg)) + new_group_name = _process_group_name(remaining_ranks, use_hashed_name=True) + + return { + "backend_name": original_backend_name, + "store": original_store, + "backend_config": original_backend_config, + "original_group_name": original_group_name, + "new_group_name": new_group_name, + "bound_device_id": bound_device_id, # Safe to access after cleanup + } + + +def _cleanup_original_group(target_pg: ProcessGroup, is_default_group: bool) -> None: + """Clean up the original process group safely.""" + try: + destroy_process_group(target_pg) + except Exception as e: + group_type = "default" if is_default_group else "non-default" + logger.warning("Failed to destroy %s group during shrinking: %s", group_type, e) + + # Ensure global state cleanup even if destroy_process_group fails + _cleanup_process_group_global_state(target_pg) + + +def _create_shrunk_process_group( + new_backend, remaining_ranks: list[int], metadata: dict, is_default_group: bool +) -> ProcessGroup: + """Create and configure the new shrunk process group.""" + # Create new group properties + new_group_rank = new_backend.rank() + new_group_size = new_backend.size() + group_name = metadata["new_group_name"] + + # Generate descriptive group description + if is_default_group: + group_desc = 
"default:shrunken" + else: + group_desc = f"{metadata['original_group_name']}:shrunk" + + # Create process group with new communicator (clone the parent store like split does) + prefix_store = PrefixStore(f"{group_name}/", metadata["store"].clone()) + new_pg = ProcessGroup(prefix_store, new_group_rank, new_group_size) + + # Configure backend using the device type of the new backend's bound device if available, + # otherwise derive from the original group's bound device or fall back to CPU. + backend_device = metadata.get("bound_device_id") + if backend_device is None: + # Default to CPU if no bound device is present + backend_device = torch.device("cpu") + + # Choose backend enum based on device type + if backend_device.type == "cuda": + backend_type = ProcessGroup.BackendType.NCCL + else: + backend_type = ProcessGroup.BackendType.GLOO + + new_pg._register_backend(backend_device, backend_type, new_backend) + new_pg._set_default_backend(backend_type) + + # Inherit device binding from original group if it was bound + bound_device_id = metadata.get("bound_device_id") + if bound_device_id is not None: + new_pg.bound_device_id = bound_device_id + + # Set group metadata + new_pg._set_group_name(group_name) + new_pg._set_group_desc(group_desc) + + # Persist backend configuration overrides (if provided via shrink_group) + backend_config_override = metadata.get("backend_config") + if backend_config_override is not None: + # Store for introspection/debugging and potential backend hooks + _world.pg_backend_config[new_pg] = backend_config_override + + return new_pg + + +def _destroy_all_other_groups(exclude_group: Optional[ProcessGroup] = None) -> None: + """ + Destroy all process groups except the excluded group and clean up all global state. + + This is necessary when shrinking the default group because global ranks + are reassigned by NCCL, making all existing process groups inconsistent. + + Note: Uses abort for non-collective cleanup since excluded ranks may not + participate in collective operations. Backend cleanup is handled independently per group. + + Args: + exclude_group (ProcessGroup, optional): Process group to exclude from destruction. + If None, destroys all process groups. + """ + # Get list of groups to destroy (avoid modifying dict while iterating) + groups_to_destroy = [] + for pg in list(_world.pg_group_ranks.keys()): + if exclude_group is not None and pg == exclude_group: + continue + groups_to_destroy.append(pg) + + # Warn user about automatic destruction + if groups_to_destroy: + group_names = [_get_process_group_name(pg) for pg in groups_to_destroy] + logger.warning( + "Shrinking default group will destroy %d other process groups: %s. 
" + "This is necessary because shrinking the default group reassigns global ranks, " + "making existing groups inconsistent.", + len(groups_to_destroy), + ", ".join(group_names), + ) + + # Destroy each group and clean up global state + for pg in groups_to_destroy: + try: + # First call abort_process_group which handles the C++ cleanup non-collectively + _abort_process_group(pg) + except Exception as e: + # Log but don't fail - some groups might already be destroyed + logger.warning( + "Failed to abort process group %s: %s", + _get_process_group_name(pg), + e, + ) + + # Ensure all global state is cleaned up even if _abort_process_group fails + # or doesn't clean up everything + _cleanup_process_group_global_state(pg) + + +def _cleanup_process_group_global_state(pg: ProcessGroup) -> None: + """ + Clean up all global state associated with a process group. + + This function ensures complete cleanup of process group state from all + global dictionaries and registries, even if destroy_process_group fails + or doesn't clean up everything. This is critical when destroying multiple + groups to prevent inconsistent state. + + The cleanup removes the process group from: + - _world.pg_map (backend and store mapping) + - _world.pg_names (group name mapping) + - _world.pg_group_ranks (rank mappings) + - _world.pg_backend_config (backend configuration) + - _world.tags_to_pg and _world.pg_to_tag (tag mappings) + - _world.pg_coalesce_state (coalescing state) + - C++ internal registries via _unregister_process_group + + Args: + pg (ProcessGroup): The process group to clean up. + """ + try: + # Clean up main process group mappings + _world.pg_map.pop(pg, None) + _world.pg_group_ranks.pop(pg, None) + _world.pg_backend_config.pop(pg, None) + + # Clean up process group name mapping + group_name = _world.pg_names.pop(pg, None) + + # Clean up tag mappings + pg_tag = _world.pg_to_tag.pop(pg, None) + if pg_tag is not None and pg_tag in _world.tags_to_pg: + try: + _world.tags_to_pg[pg_tag].remove(pg) + # Remove the tag entry if list is empty + if not _world.tags_to_pg[pg_tag]: + _world.tags_to_pg.pop(pg_tag, None) + except (ValueError, KeyError): + # Process group was already removed from the list + pass + + # Clean up any registered process group names using C++ unregister function + if group_name is not None: + try: + _unregister_process_group(group_name) + except Exception: + # Process group name might not be registered or already unregistered + pass + + # Clean up coalesce state if present + _world.pg_coalesce_state.pop(pg, None) + + except Exception as e: + # Log cleanup failures but don't propagate - we want to continue with other cleanups + logger.warning("Failed to fully clean up global state for process group: %s", e) + + +def _update_process_group_global_state( + pg: ProcessGroup, + backend_name: str, + store: Store, + group_name: str, + backend_config: str, + rank_mapping: Optional[dict[int, int]] = None, + pg_tag: Optional[str] = None, + user_tag: Optional[str] = None, +) -> None: + """ + Update all global state dictionaries for a process group. + + This helper function consolidates the common pattern of updating multiple + global state dictionaries when creating or modifying process groups. + + Args: + pg (ProcessGroup): The process group to update state for. + backend_name (str): Backend name for pg_map. + store (Store): Store instance for pg_map. + group_name (str): Group name for pg_names and registration. + backend_config (str): Backend configuration string. 
+ rank_mapping (Dict[int, int], optional): Global rank to group rank mapping. + If None, skips updating pg_group_ranks. + pg_tag (str, optional): Process group tag. If None, defaults to f"ptd:{group_name}". + user_tag (str, optional): User-provided tag for special tag handling. + If provided, creates "user:{user_tag}" tag and also adds to default "". + """ + # Update main process group mappings + _world.pg_map[pg] = (backend_name, store) + _world.pg_names[pg] = group_name + _world.pg_backend_config[pg] = backend_config + + # Register the process group name + _register_process_group(group_name, pg) + + # Update rank mapping if provided + if rank_mapping is not None: + _world.pg_group_ranks[pg] = rank_mapping + + # Handle tag management + if pg_tag is None: + pg_tag = f"ptd:{group_name}" + + if user_tag is not None: + # Special handling for user-provided tags + # Add to default "" tag first + _world.tags_to_pg.setdefault("", []).append(pg) + # Then create user-specific tag + user_pg_tag = f"user:{user_tag}" + _world.tags_to_pg.setdefault(user_pg_tag, []).append(pg) + _world.pg_to_tag[pg] = user_pg_tag + else: + # Standard process group tag + _world.tags_to_pg.setdefault(pg_tag, []).append(pg) + _world.pg_to_tag[pg] = pg_tag diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 18384b311b93..91f09adf9e81 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -238,6 +238,47 @@ def skip_if_lt_x_gpu(x): return decorator +def requires_world_size(n: int): + """ + Decorator to request a specific world size for a test. The test harness can + read this attribute to set the number of ranks to spawn. If there are fewer + than `n` CUDA devices available, the test should be skipped by the harness. + + Usage: + @require_world_size(3) + def test_something(self): + ... + """ + + def decorator(func): + func._required_world_size = n + available = torch.cuda.device_count() + return unittest.skipUnless( + available >= n, f"requires {n} GPUs, found {available}" + )(func) + + return decorator + + +def get_required_world_size(obj: Any, default: int) -> int: + """ + Returns the requested world size for the currently running unittest method on `obj` + if annotated via `@require_world_size(n)`, else returns `default`. + """ + try: + # Try MultiProcessTestCase helper first, then unittest fallback + test_name = ( + obj._current_test_name() # type: ignore[attr-defined] + if hasattr(obj, "_current_test_name") and callable(obj._current_test_name) + else obj._testMethodName + ) + fn = getattr(obj, test_name) + value = fn._required_world_size + return int(value) + except Exception: + return default + + # This decorator helps avoiding initializing cuda while testing other backends def nccl_skip_if_lt_x_gpu(backend, x): def decorator(func): @@ -367,6 +408,13 @@ def requires_nccl_version(version, msg): ) +def requires_nccl_shrink(): + """ + Require NCCL shrink support (NCCL available and version >= 2.27). + """ + return requires_nccl_version((2, 27), "Need NCCL 2.27+ for shrink_group") + + def requires_nccl(): return skip_but_pass_in_sandcastle_if( not c10d.is_nccl_available(),
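
For reference, here is a minimal sketch of how the API introduced by this patch could be used from Python. It is based on the `shrink_group` docstring and the tests above, not on additional documentation; it assumes the process group was initialized with the NCCL backend against NCCL >= 2.27, and `failed_rank` (together with whatever health-checking logic produces it) is a hypothetical placeholder.

```python
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import SHRINK_ABORT

# Assumes dist.init_process_group("nccl", ...) has already been called and
# that some external fault-detection mechanism (hypothetical here) has
# identified the rank to remove from the default group.
failed_rank = 3

if dist.get_rank() != failed_rank:
    # Only surviving (non-excluded) ranks call shrink_group(); the excluded
    # rank must not participate. SHRINK_ABORT asks the backend to abort
    # outstanding work on the parent communicator before shrinking.
    new_pg = dist.shrink_group([failed_rank], shrink_flags=SHRINK_ABORT)

    # Shrinking the default group makes the returned group the new default,
    # and NCCL reassigns ranks 0..new_pg.size() - 1 within it.
    t = torch.ones(1, device=torch.device("cuda", torch.cuda.current_device()))
    dist.all_reduce(t, group=new_pg)
else:
    # The excluded rank tears down its local state instead of joining the call.
    dist.destroy_process_group()
```

As in the tests, `shrink_flags` defaults to `SHRINK_DEFAULT`; `SHRINK_ABORT` is the variant exercised in the fault-tolerance scenarios, and a `pg_options` argument may be passed to override NCCL options (for example `is_high_priority_stream` or `config.blocking`) on the shrunk group.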