Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
shrink_group implementation to expose ncclCommShrink API (#164518)

Closes #164529. Exposes the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink) API to PyTorch. This is useful when certain GPUs or nodes need to be excluded from a collective operation, for example in fault-tolerance scenarios or when dynamically adjusting resource utilization. For more info, see [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518
Approved by: https://github.com/kwen2501

Committed by: PyTorch MergeBot
Commit: fa0db212e7
Parent: 15ff1cd28b
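As a rough illustration of the intended use, the sketch below assumes a multi-rank job in which one rank has failed and the surviving ranks want to continue on a smaller communicator. `shrink_group` and `SHRINK_ABORT` come from this PR; the failure-detection logic and the helper name are hypothetical.

```python
# Minimal sketch (assumes this PR's shrink_group API and NCCL >= 2.27).
# Failure detection is left abstract and is not part of this change.
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import SHRINK_ABORT, shrink_group

def drop_failed_ranks(failed_ranks: list[int]) -> None:
    if dist.get_rank() in failed_ranks:
        # Excluded ranks must not call shrink_group(); they simply tear down.
        dist.destroy_process_group()
        return
    # Surviving ranks shrink the default group; SHRINK_ABORT first aborts
    # any in-flight NCCL operations on the parent communicator.
    new_pg = shrink_group(failed_ranks, shrink_flags=SHRINK_ABORT)
    # The shrunk group becomes the new default group, so collectives keep working.
    t = torch.ones(1, device="cuda")
    dist.all_reduce(t, group=new_pg)
```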
@@ -394,6 +394,10 @@ an opaque group handle that can be given as a `group` argument to all collective
.. autofunction:: new_group
```

```{eval-rst}
.. autofunction:: torch.distributed.distributed_c10d.shrink_group
```

```{eval-rst}
.. autofunction:: get_group_rank
```
test/distributed/logging_utils.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import logging
import time


_start_time = time.time()
_logger = logging.getLogger(__name__)


def _ts():
    return time.time() - _start_time


def configure(level=logging.INFO, force=False):
    try:
        logging.basicConfig(
            level=level,
            format="%(asctime)s %(name)s %(levelname)s: %(message)s",
            force=force,
        )
    except TypeError:
        logging.basicConfig(
            level=level, format="%(asctime)s %(name)s %(levelname)s: %(message)s"
        )


def log_test_info(rank, message):
    _logger.info("[%7.3fs][Rank %s] %s", _ts(), rank, message)


def log_test_success(rank, message):
    _logger.info("[%7.3fs][Rank %s] ✅ %s", _ts(), rank, message)


def log_test_validation(rank, message):
    _logger.info("[%7.3fs][Rank %s] ✓ %s", _ts(), rank, message)


def log_test_warning(rank, message):
    _logger.warning("[%7.3fs][Rank %s] ⚠️ %s", _ts(), rank, message)


def log_test_error(rank, message):
    _logger.error("[%7.3fs][Rank %s] ✗ %s", _ts(), rank, message)
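For orientation, the helpers above are thin wrappers around the standard `logging` module; a test would typically use them as in this small, hypothetical snippet:

```python
# Hypothetical usage of the helpers from test/distributed/logging_utils.py.
from logging_utils import configure, log_test_info, log_test_success

configure()  # the try/except above falls back when basicConfig lacks force= (Python < 3.8)
log_test_info(0, "starting shrink test")
log_test_success(0, "shrink test finished")
```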
@@ -2,6 +2,7 @@
import copy
import json
import logging
import os
import pickle
import random
@@ -21,6 +22,7 @@ from unittest import mock, SkipTest
import torch
import torch.distributed as c10d
import torch.distributed._functional_collectives as _functional_collectives
from torch.distributed.distributed_c10d import SHRINK_ABORT as NCCL_SHRINK_ABORT


if not c10d.is_available() or not c10d.is_nccl_available():
@@ -47,12 +49,15 @@ from torch._C._distributed_c10d import ErrorType, OpType, WorkResult
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_cuda import _get_torch_rocm_version, TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
    get_required_world_size,
    get_timeout,
    init_multigpu_helper,
    MultiProcessTestCase,
    requires_multicast_support,
    requires_nccl,
    requires_nccl_shrink,
    requires_nccl_version,
    requires_world_size,
    skip_if_lt_x_gpu,
    skip_if_rocm_multiprocess,
    sm_is_or_higher_than,
@@ -87,6 +92,17 @@ BFLOAT16_AVAILABLE = torch.cuda.is_available() and (
    torch.version.cuda is not None or torch.version.hip is not None
)

from logging_utils import (
    configure as _log_configure,
    log_test_info,
    log_test_success,
    log_test_validation,
    log_test_warning,
)


_log_configure(level=logging.INFO, force=True)


class RendezvousEnvTest(TestCase):
    @retry_on_connect_failures
@@ -317,7 +333,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):

    @property
    def world_size(self):
-        return 2
+        return get_required_world_size(self, 2)

    @property
    def rank_to_GPU(self):
@@ -1255,6 +1271,628 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
        pg_2 = c10d.new_group([0, 1])
        self.assertEqual(pg_2.group_desc, "undefined")

    @requires_nccl_shrink()
    @requires_world_size(2)
    def test_shrink_group_basic(self):
        """Test basic shrink_group functionality."""
        self._perform_shrink_test([1], "Basic shrink test")

    @requires_nccl_shrink()
    @requires_world_size(2)
    def test_shrink_group_validation(self):
        """Test input validation in shrink_group."""
        device, pg = self._setup_shrink_test("validation")

        def _test_invalid_input(ranks, description, expected_exception):
            """Helper to test invalid inputs."""
            try:
                c10d.shrink_group(ranks)
                self.fail(f"Expected {expected_exception.__name__} for {description}")
            except expected_exception:
                log_test_validation(self.rank, f"✓ {description}")
            except Exception:
                if expected_exception is Exception:  # Accept any exception
                    log_test_validation(self.rank, f"✓ {description}")
                else:
                    raise

        # Test cases
        _test_invalid_input([], "Empty exclusion list", ValueError)
        if self.world_size > 1:
            _test_invalid_input([0, 0, 1], "Duplicate ranks", Exception)
        _test_invalid_input([self.world_size + 1], "Out of bounds rank", Exception)

        log_test_success(self.rank, "All validation tests passed")
        dist.destroy_process_group()

    @requires_nccl_shrink()
    @requires_world_size(2)
    def test_shrink_group_backend_properties(self):
        """Test that backend properties are preserved after shrinking."""

        test_name = "Backend Properties Test"
        ranks_to_exclude = [0]

        # Reuse _setup_shrink_test for complete setup (device, environment, and process group)
        device, pg = self._setup_shrink_test("backend_properties")

        # Follow _perform_shrink_test pattern from here
        log_test_info(self.rank, f"{test_name} (world_size={self.world_size})")

        is_excluded = self.rank in ranks_to_exclude
        log_test_info(
            self.rank,
            f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
        )

        # Store original backend property values (not references) before shrinking
        original_timeout = None
        original_high_priority = None
        if not is_excluded:
            original_backend = pg._get_backend(device)
            original_timeout = original_backend.options._timeout
            original_high_priority = original_backend.options.is_high_priority_stream
            log_test_info(
                self.rank,
                f"Storing original backend properties: timeout={original_timeout}, high_priority={original_high_priority}",
            )

        if is_excluded:
            log_test_info(
                self.rank,
                f"Excluded rank {self.rank} - setup complete, skipping shrink operation",
            )
            dist.destroy_process_group()  # hang without it
            return

        # Only non-excluded ranks proceed with shrink (same as _perform_shrink_test)
        log_test_info(self.rank, "Non-excluded rank calling shrink_group")
        shrunk_pg = c10d.shrink_group(ranks_to_exclude)

        # Reuse _validate_shrunk_group helper (same as _perform_shrink_test)
        expected_size = self.world_size - len(ranks_to_exclude)
        _ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name)

        # Add custom backend properties validation
        new_backend = shrunk_pg._get_backend(device)
        log_test_info(self.rank, "Validating backend properties are preserved")

        new_timeout = new_backend.options._timeout
        new_high_priority = new_backend.options.is_high_priority_stream

        log_test_info(
            self.rank,
            f"Timeout comparison - original: {original_timeout}, new: {new_timeout}",
        )
        self.assertEqual(
            original_timeout, new_timeout, f"{test_name}: timeout not preserved"
        )

        log_test_info(
            self.rank,
            f"High priority stream comparison - original: {original_high_priority}, new: {new_high_priority}",
        )
        self.assertEqual(
            original_high_priority,
            new_high_priority,
            f"{test_name}: high_priority_stream not preserved",
        )

        log_test_validation(
            self.rank, f"{test_name}: Backend properties preserved successfully"
        )
        log_test_success(
            self.rank, f"{test_name} successful (shrink + backend validation)"
        )

        # Cleanup (same as _perform_shrink_test)
        dist.destroy_process_group()

    @requires_nccl_shrink()
    @requires_world_size(2)
    def test_shrink_group_multiple_comms(self):
        """Test shrink_group with multiple communicators and subgroup invalidation."""

        device, pg = self._setup_shrink_test("multiple_comms")

        # Create subgroup [0, 1] and test shrinking it
        subgroup = c10d.new_group([0, 1])
        if self.rank <= 1:
            # Shrink subgroup: exclude rank 1
            if self.rank == 0:  # Only rank 0 remains
                shrunk_subgroup = c10d.shrink_group([1], group=subgroup)
                self.assertEqual(shrunk_subgroup.size(), 1)
                # Test communication on shrunk subgroup
                tensor = torch.full((1,), self.rank).cuda(device)
                c10d.all_reduce(tensor, group=shrunk_subgroup)
                self.assertEqual(tensor.item(), 0)  # Only rank 0
                log_test_success(self.rank, "Subgroup shrinking successful")

        dist.barrier()  # Sync before default group test

        # Shrink default group: exclude last rank
        ranks_to_exclude = [self.world_size - 1]
        if self.rank not in ranks_to_exclude:
            shrunk_default = c10d.shrink_group(ranks_to_exclude)
            expected_size = self.world_size - 1
            self.assertEqual(shrunk_default.size(), expected_size)

            # Test collective on shrunk default group
            tensor = torch.full((1,), self.rank).cuda(device)
            c10d.all_reduce(tensor, group=shrunk_default)
            expected_sum = sum(
                range(self.world_size - 1)
            )  # 0 + 1 + ... + (world_size-2)
            self.assertEqual(tensor.item(), expected_sum)
            log_test_success(self.rank, "Default group shrinking successful")

        # Note: After shrinking default group, the old subgroup is invalid
        # due to global rank reassignment

        dist.destroy_process_group()

    def _test_shrink_group_with_flag(self, shrink_flag, flag_name, rank_to_exclude):
        """Helper method to test shrink_group with a specific flag."""
        if self.world_size < 2:
            log_test_info(self.rank, f"Skipping (needs ≥2 GPUs, got {self.world_size})")
            return
        ranks_to_exclude = [rank_to_exclude]
        log_test_info(self.rank, f"Using {flag_name} flag (value: {shrink_flag})")
        if flag_name == "NCCL_SHRINK_ABORT":
            log_test_info(
                self.rank,
                "ABORT flag will terminate ongoing operations before shrinking",
            )

        self._perform_shrink_test(
            ranks_to_exclude, f"{flag_name} flag test", shrink_flags=shrink_flag
        )

    @requires_nccl_shrink()
    @requires_world_size(2)
    def test_shrink_group_flags(self):
        """Test shrink_group with different shrink flags."""
        # Test ABORT flags
        log_test_info(self.rank, "Testing NCCL_SHRINK_ABORT flag")
        self._test_shrink_group_with_flag(NCCL_SHRINK_ABORT, "NCCL_SHRINK_ABORT", 1)

    @requires_nccl_shrink()
    @requires_world_size(2)
    def test_shrink_group_nccl_config(self):
        """Verify that passing NCCL config via pg_options influences the shrunk group's backend options."""
        device, pg = self._setup_shrink_test("config")
        if self.rank == self.world_size - 1:
            # excluded rank should not call shrink_group
            dist.destroy_process_group()
            return

        # Prepare pg_options with NCCL config overrides
        # Capture parent's current backend options to ensure we can prove override vs inherit
        parent_backend = pg._get_backend(torch.device("cuda"))
        parent_hp = parent_backend.options.is_high_priority_stream
        parent_blocking = parent_backend.options.config.blocking

        # Choose overrides that differ from the parent (flip where possible)
        override_hp = not parent_hp
        if parent_blocking in (0, 1):
            override_blocking = 1 - parent_blocking
        else:
            # If undefined or unexpected, set to 1 which is a concrete value
            override_blocking = 1

        opts = c10d.ProcessGroupNCCL.Options()
        opts.is_high_priority_stream = override_hp
        opts.config.blocking = override_blocking

        shrunk_pg = c10d.shrink_group([self.world_size - 1], pg_options=opts)

        # Validate backend options propagated
        backend = shrunk_pg._get_backend(torch.device("cuda"))
        # is_high_priority_stream should exactly match our override and differ from parent
        self.assertEqual(backend.options.is_high_priority_stream, override_hp)
        self.assertNotEqual(backend.options.is_high_priority_stream, parent_hp)
        # config is a struct; check representative field and difference from parent when meaningful
        self.assertEqual(backend.options.config.blocking, override_blocking)
        if parent_blocking in (0, 1):
            self.assertNotEqual(backend.options.config.blocking, parent_blocking)

        dist.destroy_process_group()

    @requires_nccl_shrink()
    @requires_world_size(2)
    def test_shrink_group_performance(self):
        """Test shrink_group performance and regression detection."""
        import time

        ranks_to_exclude = self._get_default_ranks_to_exclude()
        is_excluded = self.rank in ranks_to_exclude

        if not ranks_to_exclude:
            log_test_info(self.rank, "Skipping performance test (world_size=1)")
            return

        log_test_info(self.rank, f"Performance test with {self.world_size} processes")
        device, pg = self._setup_shrink_test("performance")

        if not is_excluded:
            log_test_info(self.rank, "Measuring shrink_group performance")
            start_time = time.time()
            shrunk_pg = c10d.shrink_group(ranks_to_exclude)
            end_time = time.time()

            elapsed_time = end_time - start_time
            log_test_info(self.rank, f"shrink_group: {elapsed_time:.3f}s")

            # Regression check: should complete within reasonable time
            self.assertLess(
                elapsed_time,
                30.0,
                f"shrink_group took {elapsed_time:.3f}s, possible regression",
            )

            # Test collective performance
            expected_size = self.world_size - len(ranks_to_exclude)
            self._validate_shrunk_group(shrunk_pg, expected_size, "performance")

            collective_start = time.time()
            _ = self._test_collective_on_shrunk_group(
                shrunk_pg, device, ranks_to_exclude, "performance"
            )
            collective_time = time.time() - collective_start

            log_test_info(self.rank, f"all_reduce: {collective_time:.3f}s")
            log_test_success(self.rank, "Performance test passed")
        else:
            log_test_info(self.rank, "Excluded rank - waiting")

        dist.destroy_process_group()

    @requires_nccl_shrink()
    @requires_world_size(4)
    def test_shrink_group_multiple_exclusions(self):
        """Test shrink_group with multiple ranks excluded at once."""
        # Scale exclusions with world size
        ranks_to_exclude = list(range(2, self.world_size, 2))  # Every other rank from 2

        self._perform_shrink_test(ranks_to_exclude, "Multiple exclusions test")

    @requires_nccl_shrink()
    @requires_world_size(3)
    def test_shrink_group_multiple_iterations(self):
        """Test multiple shrink operations in sequence."""
        log_test_info(
            self.rank,
            f"Starting test_shrink_group_multiple_iterations with world_size={self.world_size}",
        )

        store = c10d.FileStore(self.file_name, self.world_size)
        device = torch.device(f"cuda:{self.rank}")
        _ = self._create_process_group_nccl(store, self.opts(), device_id=device)

        # Track current effective world size throughout shrinking operations
        current_world_size = self.world_size
        log_test_info(self.rank, f"Initial world_size: {current_world_size}")

        # First shrinking: exclude the last rank(s)
        first_exclusion = [self.world_size - 1]
        if self.world_size >= 6:
            first_exclusion.append(
                self.world_size - 2
            )  # Exclude last two ranks for larger sizes

        log_test_info(self.rank, f"First shrinking: excluding ranks {first_exclusion}")

        if self.rank not in first_exclusion:
            # Only non-excluded ranks should call shrink_group
            first_pg = c10d.shrink_group(first_exclusion)
            self.assertIsNotNone(first_pg)
            # IMPORTANT: Update world size after first shrinking
            current_world_size = first_pg.size()
            expected_first_size = self.world_size - len(first_exclusion)
            log_test_info(
                self.rank,
                f"After first shrinking: world_size {self.world_size} -> {current_world_size}",
            )
            self.assertEqual(first_pg.size(), expected_first_size)

            # Second shrinking: exclude another rank from the remaining group
            # Choose a rank that's in the middle range
            if current_world_size >= 3:
                second_exclusion = [
                    current_world_size - 1
                ]  # Exclude the new "last" rank
                log_test_info(
                    self.rank,
                    f"Second shrinking from group of size {current_world_size}: excluding ranks {second_exclusion}",
                )

                if self.rank not in second_exclusion:
                    # Only non-excluded ranks should call shrink_group for second iteration
                    second_pg = c10d.shrink_group(second_exclusion, group=first_pg)
                    self.assertIsNotNone(second_pg)
                    # IMPORTANT: Update world size after second shrinking
                    final_world_size = second_pg.size()
                    expected_final_size = current_world_size - len(second_exclusion)
                    log_test_info(
                        self.rank,
                        f"After second shrinking: world_size {current_world_size} -> {final_world_size}",
                    )
                    self.assertEqual(second_pg.size(), expected_final_size)

                    # Test collective on final group
                    tensor = torch.full((1,), self.rank).cuda(device)
                    log_test_info(
                        self.rank,
                        f"Performing all_reduce on final group (size {final_world_size}) with tensor: {tensor.item()}",
                    )
                    c10d.all_reduce(tensor, group=second_pg)
                    log_test_info(
                        self.rank,
                        f"Final all_reduce completed, result: {tensor.item()}",
                    )

                    # Calculate expected sum of remaining ranks
                    all_excluded = set(first_exclusion + second_exclusion)
                    remaining_ranks = [
                        r for r in range(self.world_size) if r not in all_excluded
                    ]
                    expected_sum = sum(remaining_ranks)
                    log_test_info(
                        self.rank,
                        f"Remaining ranks: {remaining_ranks}, expected sum: {expected_sum}, actual: {tensor.item()}",
                    )
                    self.assertEqual(tensor.item(), expected_sum)
                    log_test_info(self.rank, "Final verification passed")
                else:
                    log_test_info(
                        self.rank,
                        "This rank excluded in second shrinking, not calling shrink_group",
                    )
            else:
                log_test_info(
                    self.rank, "Skipping second shrinking (remaining group too small)"
                )
        else:
            log_test_info(
                self.rank,
                "This rank excluded in first shrinking, not calling shrink_group",
            )

        log_test_info(self.rank, "Destroying process group")
        dist.destroy_process_group()
        log_test_info(self.rank, "test_shrink_group_multiple_iterations completed")

    # Helper methods for optimized shrink group tests
    def _setup_shrink_test(self, test_suffix, world_size=None, warmup=True):
        """Common setup for shrink group tests."""
        os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1"
        world_size = world_size or self.world_size
        store = c10d.FileStore(self.file_name + f"_{test_suffix}", world_size)
        device = torch.device(f"cuda:{self.rank}")
        c10d.init_process_group(
            "nccl",
            world_size=world_size,
            rank=self.rank,
            store=store,
            pg_options=self.opts(),
            device_id=device,
        )
        pg = c10d.distributed_c10d._get_default_group()

        if warmup:
            c10d.all_reduce(torch.ones(1).cuda(device), group=pg)

        return device, pg

    def _validate_shrunk_group(self, shrunk_pg, expected_size, test_name=""):
        """Validate properties of a shrunk process group."""
        self.assertIsNotNone(shrunk_pg, f"{test_name}: shrunk_pg should not be None")
        actual_size = shrunk_pg.size()
        self.assertEqual(
            actual_size, expected_size, f"{test_name}: group size mismatch"
        )

        new_rank = shrunk_pg.rank()
        self.assertTrue(
            0 <= new_rank < expected_size, f"{test_name}: invalid new rank {new_rank}"
        )

        log_test_info(
            self.rank,
            f"{test_name}: world_size {self.world_size} -> {actual_size}, rank {self.rank} -> {new_rank}",
        )
        return new_rank

    def _test_collective_on_shrunk_group(
        self, shrunk_pg, device, ranks_to_exclude, test_name=""
    ):
        """Test collective communication on shrunk group and verify correctness."""
        test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32)
        c10d.all_reduce(test_tensor, group=shrunk_pg)

        result = test_tensor.item()
        expected_sum = sum(
            r for r in range(self.world_size) if r not in ranks_to_exclude
        )

        self.assertEqual(
            result, expected_sum, f"{test_name}: collective result mismatch"
        )
        log_test_info(
            self.rank, f"{test_name}: collective passed ({result} == {expected_sum})"
        )
        return result

    def _perform_shrink_test(
        self, ranks_to_exclude, test_name, shrink_flags=0, with_collective=True
    ):
        """Complete shrink test flow: setup, shrink, validate, test collective, cleanup.

        Consistent API: All ranks perform setup to initialize the distributed environment.
        ONLY non-excluded ranks call shrink_group() for both default and non-default groups.
        Excluded ranks perform setup, then exit without calling shrink_group() or waiting.
        """
        log_test_info(self.rank, f"{test_name} (world_size={self.world_size})")

        is_excluded = self.rank in ranks_to_exclude
        log_test_info(
            self.rank,
            f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
        )

        # All ranks (including excluded ones) perform setup to initialize distributed environment
        device, pg = self._setup_shrink_test(test_name.lower().replace(" ", "_"))
        is_default_group = pg == c10d.distributed_c10d._get_default_group()

        if is_excluded:
            log_test_info(
                self.rank,
                f"Excluded rank {self.rank} - setup complete, skipping shrink operation",
            )
            if shrink_flags & NCCL_SHRINK_ABORT:
                log_test_info(self.rank, f"Using abort for excluded rank {self.rank}")
                pg._get_backend(torch.device(device)).abort()
                log_test_info(
                    self.rank, f"cleanup resources for excluded rank {self.rank}"
                )
                dist.destroy_process_group()
                log_test_info(self.rank, f"Excluded rank {self.rank} - exit")
            else:
                log_test_info(
                    self.rank, f"Using regular destroy for excluded rank {self.rank}"
                )
                dist.destroy_process_group()
            return None

        # Only non-excluded ranks proceed with shrink
        log_test_info(
            self.rank,
            f"Non-excluded rank calling shrink_group (default_group={is_default_group})",
        )
        shrunk_pg = c10d.shrink_group(ranks_to_exclude, shrink_flags=shrink_flags)
        log_test_info(
            self.rank,
            f"Non-excluded rank calling shrink_group (default_group={is_default_group}) done",
        )

        # Non-excluded ranks: validate and test the new group
        expected_size = self.world_size - len(ranks_to_exclude)
        _ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name)

        if with_collective:
            _ = self._test_collective_on_shrunk_group(
                shrunk_pg, device, ranks_to_exclude, test_name
            )
            log_test_success(self.rank, f"{test_name} successful (shrink + collective)")
        else:
            log_test_success(self.rank, f"{test_name} successful (shrink only)")

        dist.destroy_process_group()
        return shrunk_pg

    def _get_default_ranks_to_exclude(self):
        """Get default ranks to exclude based on world size."""
        if self.world_size <= 1:
            return []
        return [self.world_size - 1]  # Exclude last rank by default

    @requires_nccl_shrink()
    @requires_world_size(3)
    def test_shrink_group_vs_abort_reinit_performance(self):
        """Compare performance of shrink_group vs traditional abort+reinit (simplified for reliability)."""
        log_test_info(self.rank, "=== TEST 1: abort+reinit ===")

        device, pg1 = self._setup_shrink_test("_perf_reinit")
        torch.cuda.synchronize(device)

        # Test 1: Traditional abort + reinit
        start_time = time.perf_counter()
        dist.destroy_process_group()

        device, new_pg = self._setup_shrink_test("perf_shrink_test1")
        reinit_time = time.perf_counter() - start_time

        # Test collective with original rank values for fair comparison (non-blocking mode)
        test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32)
        work = c10d.all_reduce(test_tensor, group=new_pg, async_op=True)
        work.wait()

        torch.cuda.synchronize(device)

        # Verify correctness
        expected_sum = sum(r for r in range(self.world_size))
        self.assertEqual(test_tensor.item(), expected_sum, "Reinit collective failed")

        log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s")
        dist.destroy_process_group(new_pg)

        # Test 2: shrink_group with NCCL_SHRINK_ABORT
        log_test_info(self.rank, "=== TEST 2: shrink_group ===")

        ranks_to_exclude = [self.world_size - 1]
        is_excluded = self.rank in ranks_to_exclude
        log_test_info(
            self.rank,
            f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
        )

        device, pg1 = self._setup_shrink_test("perf_shrink_test2")  # Unique suffix

        shrink_time = 0
        if not is_excluded:
            torch.cuda.synchronize(device)  # Ensure accurate timing
            start_time = time.perf_counter()
            shrunk_pg = c10d.shrink_group(
                ranks_to_exclude, shrink_flags=NCCL_SHRINK_ABORT
            )
            c10d.all_reduce(torch.ones(1).cuda(device), group=shrunk_pg)
            shrink_time = time.perf_counter() - start_time

            # Test collective communication on shrunk group (non-blocking mode)
            test_tensor = torch.full(
                (1,), self.rank, device=device, dtype=torch.float32
            )
            work = c10d.all_reduce(test_tensor, group=shrunk_pg, async_op=True)
            work.wait()

            # Verify correctness
            expected_sum = sum(
                r for r in range(self.world_size) if r not in ranks_to_exclude
            )
            self.assertEqual(
                test_tensor.item(),
                expected_sum,
                "shrink_test: collective result mismatch",
            )

            torch.cuda.synchronize(device)  # Ensure operations complete
            log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s")
            dist.destroy_process_group()
        else:
            log_test_info(self.rank, "Excluded from shrink test - exiting immediately")
            dist.destroy_process_group()
            return

        # Performance analysis (only for participating ranks)
        if shrink_time > 0 and reinit_time > 0:
            speedup = reinit_time / shrink_time
            time_saved = reinit_time - shrink_time

            log_test_info(self.rank, "=== PERFORMANCE RESULTS ===")
            log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s")
            log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s")
            log_test_info(self.rank, f"time_saved: {time_saved:+.4f}s")
            log_test_info(self.rank, f"speedup: {speedup:.2f}x")

            if speedup > 1.1:
                log_test_success(self.rank, "shrink_group significantly faster")
            elif speedup > 0.9:
                log_test_info(self.rank, "≈ comparable performance")
            else:
                log_test_warning(self.rank, "abort+reinit faster")

        log_test_info(self.rank, "Performance test completed")

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_deterministic_mode_no_break(self):
@@ -79,6 +79,23 @@ class TORCH_API Backend : public torch::CustomClassHolder {
    return false;
  }

  virtual bool supportsShrinking() const {
    return false;
  }

  // Shrink the backend by excluding specified ranks. Backends that support
  // communicator shrinking should override this and return a new backend
  // instance representing the shrunken group. Backends may use opts_override
  // to supply backend-specific options for the new group.
  virtual c10::intrusive_ptr<Backend> shrink(
      const std::vector<int64_t>& /*ranks_to_exclude*/,
      int /*shrink_flags*/ = 0,
      const c10::intrusive_ptr<Options>& /*opts_override*/ = nullptr) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support shrink"));
  }

  virtual void setTimeout(std::chrono::milliseconds timeout) {
    TORCH_CHECK(
        false,
@@ -259,6 +259,65 @@ std::shared_ptr<NCCLComm> NCCLComm::split(
}
#endif

#ifdef NCCL_HAS_COMM_SHRINK
std::shared_ptr<NCCLComm> NCCLComm::shrink(
    NCCLComm* source,
    std::vector<int>& ranks_to_exclude,
    ncclConfig_t* config,
    int shrinkFlags) {
  // Preconditions are validated in ProcessGroupNCCL::shrink

  LOG(INFO) << "Rank " << source->rank_ << ": shrinking comm " << source->repr()
            << " excluding " << ranks_to_exclude.size() << " ranks";

  at::cuda::OptionalCUDAGuard gpuGuard(source->deviceIndex_);
  auto comm = std::make_shared<NCCLComm>();

  // This call will block until the source communicator is initialized
  auto sourceComm = source->getNcclComm();

  C10D_NCCL_CHECK_NONBLOCKING(
      ncclCommShrink(
          sourceComm,
          ranks_to_exclude.data(),
          ranks_to_exclude.size(),
          reinterpret_cast<ncclComm_t*>(&(comm->ncclComm_)),
          config,
          shrinkFlags),
      source->getNcclCommFailureReason());

  // Wait for the child communicator to be ready
  source->waitReady(true);
  comm->initialized_ = true;

  // NCCL automatically assigns rank during shrink - query it efficiently
  int assigned_rank;
  try {
    C10D_NCCL_CHECK(
        ncclCommUserRank(comm->ncclComm_, &assigned_rank), std::nullopt);
    comm->rank_ = assigned_rank;
  } catch (const std::exception& e) {
    // Fallback: if ncclCommUserRank fails, we can't determine the rank
    LOG(ERROR) << "Failed to query NCCL-assigned rank: " << e.what();
    throw;
  }

  // Child comm should be on the same device as parent comm
  comm->deviceIndex_ = source->deviceIndex_;
  if (config != nullptr) {
    comm->nonBlocking_ = config->blocking == 0;
  } else {
    // Inherit parent behavior if no config provided
    comm->nonBlocking_ = source->nonBlocking_;
  }

  LOG(INFO) << "Rank " << source->rank_ << ": created shrunken comm "
            << comm->repr() << " with NCCL-assigned rank " << assigned_rank;

  return comm;
}
#endif

void NCCLComm::finalize() {
  LockType lock(mutex_);
  if (aborted_) {
@@ -90,6 +90,10 @@ static_assert(
#define NCCL_HAS_NVLS_CTAS
#endif

#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0)
#define NCCL_HAS_COMM_SHRINK
#endif

// Macro to throw on a non-successful NCCL return value.
#define C10D_NCCL_CHECK(cmd, failureReason) \
  do {                                      \
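The `NCCL_HAS_COMM_SHRINK` guard above is a compile-time check against NCCL 2.27 headers; `ProcessGroupNCCL::shrink` additionally checks the runtime version. A user-level sanity check can be done from Python before relying on the feature; this is a hedged sketch, not part of the diff:

```python
# Rough runtime guard (assumes a CUDA/NCCL build).
# torch.cuda.nccl.version() returns a (major, minor, patch) tuple.
import torch

def nccl_supports_shrink() -> bool:
    return torch.cuda.is_available() and torch.cuda.nccl.version() >= (2, 27, 0)
```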
@@ -294,6 +298,14 @@ class NCCLComm {
      ncclConfig_t& config);
#endif // NCCL_HAS_COMM_SPLIT

#ifdef NCCL_HAS_COMM_SHRINK
  static std::shared_ptr<NCCLComm> shrink(
      NCCLComm* source,
      std::vector<int>& ranks_to_exclude,
      ncclConfig_t* config,
      int shrinkFlags = 0);
#endif // NCCL_HAS_COMM_SHRINK

#if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP)
  std::unordered_map<std::string, std::string> ncclCommDump();
#endif
@@ -165,7 +165,7 @@ ncclRedOpRAII getNcclReduceOp(
}

// Get a key string from device
-inline std::string getKeyFromDevice(at::Device& device) {
+inline std::string getKeyFromDevice(const at::Device& device) {
  return std::to_string(device.index());
}
@@ -5838,6 +5838,139 @@ at::Tensor ProcessGroupNCCL::allocateTensor(
  return tensor;
}

#ifdef NCCL_HAS_COMM_SHRINK
c10::intrusive_ptr<Backend> ProcessGroupNCCL::shrink(
    const std::vector<int64_t>& ranks_to_exclude,
    int shrink_flags,
    const c10::intrusive_ptr<Backend::Options>& opts_override) {
  // Runtime version check with better error message
  auto runtime_version = torch::cuda::nccl::version();
  TORCH_CHECK(
      runtime_version >= NCCL_VERSION(2, 27, 0),
      "ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later. "
      "Found version: ",
      runtime_version);

  // Early validation with detailed error messages
  TORCH_CHECK_VALUE(
      !ranks_to_exclude.empty(), "ranks_to_exclude cannot be empty");
  TORCH_CHECK_VALUE(
      static_cast<int>(ranks_to_exclude.size()) < size_,
      "Cannot exclude all ranks (",
      ranks_to_exclude.size(),
      " >= ",
      size_,
      ")");

  // Validate ranks and convert to int efficiently
  std::vector<int> int_ranks_to_exclude;
  int_ranks_to_exclude.reserve(ranks_to_exclude.size());
  for (int64_t rank : ranks_to_exclude) {
    TORCH_CHECK_VALUE(
        rank >= 0 && rank < size_,
        "Invalid rank ",
        rank,
        " for group size ",
        size_);
    int_ranks_to_exclude.push_back(static_cast<int>(rank));
  }

  // Get primary communicator with better error context
  auto primary_device_index = guessDeviceId();
  auto primary_device = at::Device(at::kCUDA, primary_device_index);
  const auto primary_key = getKeyFromDevice(primary_device);

  std::shared_ptr<NCCLComm> primary_comm = getNCCLComm(primary_key);
  TORCH_CHECK(
      primary_comm,
      "Primary NCCL communicator for device ",
      primary_device,
      " (key: ",
      primary_key,
      ") is not initialized");

  // Cache device index before shrink operation
  at::DeviceIndex parent_device_index = primary_comm->getDeviceIndex();

  ncclConfig_t* config = nullptr;
  // Default to inheriting from parent options
  bool high_priority_stream = options_->is_high_priority_stream;
  if (opts_override) {
    auto nccl_opts =
        c10::static_intrusive_pointer_cast<ProcessGroupNCCL::Options>(
            opts_override);
    config = &nccl_opts->config;
    // If user provided override options, honor is_high_priority_stream as well
    high_priority_stream = nccl_opts->is_high_priority_stream;
  }

  std::shared_ptr<NCCLComm> shrunk_comm = NCCLComm::shrink(
      primary_comm.get(),
      int_ranks_to_exclude,
      (config != nullptr ? config : &options_->config),
      shrink_flags);

  // Calculate new size and get NCCL-assigned rank
  int new_size = size_ - static_cast<int>(ranks_to_exclude.size());
  int new_rank = shrunk_comm->rank_;

  // Create new ProcessGroupNCCL with optimized options cloning
  auto new_store = store_->clone();
  auto new_opts = ProcessGroupNCCL::Options::create(high_priority_stream);
  new_opts->timeout = options_->timeout;
  if (config != nullptr) {
    new_opts->config = *config;
  } else {
    new_opts->config = options_->config;
  }

  auto new_pg = c10::make_intrusive<ProcessGroupNCCL>(
      new_store, new_rank, new_size, new_opts);

  // Set up the new process group with optimized device setup
  new_pg->initializeDeviceStateForComm(
      at::Device(at::kCUDA, parent_device_index), shrunk_comm);

  return c10::static_intrusive_pointer_cast<Backend>(new_pg);
}

#else // !NCCL_HAS_COMM_SHRINK
// Backend interface override: raise consistent error when shrink is
// unsupported.
c10::intrusive_ptr<Backend> ProcessGroupNCCL::shrink(
    const std::vector<int64_t>& /*ranks_to_exclude*/,
    int /*shrink_flags*/,
    const c10::intrusive_ptr<Backend::Options>& /*opts_override*/) {
  TORCH_CHECK(
      false,
      "ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later, "
      "but PyTorch was built with an older version or without NCCL shrink support.");
}

#endif // NCCL_HAS_COMM_SHRINK

void ProcessGroupNCCL::initializeDeviceStateForComm(
    const at::Device& device,
    std::shared_ptr<NCCLComm> comm) {
  const auto key = getKeyFromDevice(device);
  std::unique_lock<std::mutex> lock(mutex_);
  at::cuda::OptionalCUDAGuard gpuGuard(device);

  bool force_high = getCvarBool(TORCH_NCCL_HIGH_PRIORITY, false);
  auto stream = at::cuda::getStreamFromPool(
      options_->is_high_priority_stream || force_high);

  devNCCLCommMap_[key] = comm;
  ncclStreams_.emplace(key, stream);
  ncclEvents_.emplace(key, at::cuda::CUDAEvent(cudaEventDisableTiming));
  usedDeviceIdxs_.insert(device.index());

  if (shouldAllCommunicatorsRegisterAllTensors()) {
    std::lock_guard<std::mutex> map_lock(ncclCommMemPoolMapMutex);
    ncclCommMemPoolMap.emplace(std::move(comm), MemPoolSet{});
  }
}

} // namespace c10d

#endif // USE_C10D_NCCL
@@ -997,6 +997,21 @@ class TORCH_API ProcessGroupNCCL : public Backend {

  ErrorType getError() override;

  bool supportsShrinking() const override {
#ifdef NCCL_HAS_COMM_SHRINK
    return true;
#else
    return false;
#endif
  }

  // Backend-style shrink override that returns a Backend instance.
  c10::intrusive_ptr<Backend> shrink(
      const std::vector<int64_t>& ranks_to_exclude,
      int shrink_flags = 0,
      const c10::intrusive_ptr<Backend::Options>& opts_override =
          nullptr) override;

  std::shared_ptr<c10::Allocator> getMemAllocator() override;

  // Allocate tensor from communication-optimized memory pool
@@ -1065,6 +1080,12 @@ class TORCH_API ProcessGroupNCCL : public Backend {
      int p2pRank = 0,
      bool isSendRecvSelf = false);

  // Initialize device-specific state (comm, stream, event, bookkeeping) for a
  // given communicator on this process group instance.
  void initializeDeviceStateForComm(
      const at::Device& device,
      std::shared_ptr<NCCLComm> comm);

  // Wrapper method which can be overridden for tests.
  virtual std::exception_ptr checkForNCCLErrors(
      std::shared_ptr<NCCLComm>& ncclComm);
@@ -2730,12 +2730,23 @@ Arguments:
          "supports_time_estimate",
          &::c10d::Backend::supportsTimeEstimation,
          "(test whether the backend supports collective time estimation)")
      .def_property_readonly(
          "supports_shrinking",
          &::c10d::Backend::supportsShrinking,
          "(test whether the backend supports communicator shrinking)")
      .def(
          "set_timeout",
          &::c10d::Backend::setTimeout,
          py::arg("timeout"),
          py::call_guard<py::gil_scoped_release>(),
          R"(Sets the default timeout for all future operations.)")
      .def(
          "shrink",
          &::c10d::Backend::shrink,
          py::arg("ranks_to_exclude"),
          py::arg("shrink_flags") = 0,
          py::arg("opts_override") = nullptr,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "broadcast",
          &::c10d::Backend::broadcast,
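These bindings surface both the capability flag and the low-level entry point on the backend object; `shrink_group` (below) is the supported front end. A hedged sketch of probing the new surface from Python (note that `_get_backend` is a private API used here only for illustration):

```python
# Sketch only: checking the new pybind-exposed capability flag.
import torch
import torch.distributed as dist

pg = dist.distributed_c10d._get_default_group()
backend = pg._get_backend(torch.device("cuda"))
if getattr(backend, "supports_shrinking", False):
    print("backend can shrink communicators")  # expected on NCCL >= 2.27 builds
```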
@@ -130,6 +130,7 @@ __all__ = [
    "reduce_scatter_tensor",
    "get_node_local_rank",
    "split_group",
    "shrink_group",
]

_MPI_AVAILABLE = True
@@ -5713,3 +5714,517 @@ def _get_process_group_name(pg: ProcessGroup) -> str:

def _get_process_group_store(pg: ProcessGroup) -> Store:
    return _world.pg_map[pg][1]


# Shrink flags for process group backends
SHRINK_DEFAULT = 0x00
SHRINK_ABORT = 0x01


@_time_logger
def shrink_group(
    ranks_to_exclude: list[int],
    group: Optional[ProcessGroup] = None,
    shrink_flags: int = SHRINK_DEFAULT,
    pg_options: Optional[Any] = None,
) -> ProcessGroup:
    """
    Shrinks a process group by excluding specified ranks.

    Creates and returns a new, smaller process group comprising only the ranks
    from the original group that were not in the ``ranks_to_exclude`` list.

    Args:
        ranks_to_exclude (List[int]): A list of ranks from the original
            ``group`` to exclude from the new group.
        group (ProcessGroup, optional): The process group to shrink. If ``None``,
            the default process group is used. Defaults to ``None``.
        shrink_flags (int, optional): Flags to control the shrinking behavior.
            Can be ``SHRINK_DEFAULT`` (default) or ``SHRINK_ABORT``.
            ``SHRINK_ABORT`` will attempt to terminate ongoing operations
            in the parent communicator before shrinking.
            Defaults to ``SHRINK_DEFAULT``.
        pg_options (ProcessGroupOptions, optional): Backend-specific options to apply
            to the shrunken process group. If provided, the backend will use
            these options when creating the new group. If omitted, the new group
            inherits defaults from the parent.

    Returns:
        ProcessGroup: a new group comprised of the remaining ranks. If the
        default group was shrunk, the returned group becomes the new default group.

    Raises:
        TypeError: if the group's backend does not support shrinking.
        ValueError: if ``ranks_to_exclude`` is invalid (empty, out of bounds,
            duplicates, or excludes all ranks).
        RuntimeError: if an excluded rank calls this function or the backend
            fails the operation.

    Notes:
        - Only non-excluded ranks should call this function; excluded ranks
          must not participate in the shrink operation.
        - Shrinking the default group destroys all other process groups since
          rank reassignment makes them inconsistent.
    """
    # Step 1: Validate input parameters with comprehensive error checking
    _validate_shrink_inputs(ranks_to_exclude, shrink_flags)

    # Step 2: Get target group and essential properties
    target_group_info = _prepare_shrink_target_group(group)

    # Step 3: Validate backend requirements and availability
    backend_impl = _validate_shrink_backend_requirements(target_group_info)

    # Step 4: Validate ranks against group and check for duplicates
    excluded_ranks_set = _validate_and_process_excluded_ranks(
        ranks_to_exclude, target_group_info
    )

    # Step 5: Execute the actual shrink operation (backend-specific)
    new_backend = backend_impl.shrink(
        sorted(excluded_ranks_set),
        shrink_flags,
        pg_options if pg_options is not None else None,
    )

    # Step 6: Handle cleanup and creation of new process group
    target_group_info["pg_options_override"] = pg_options
    return _finalize_shrunk_group(target_group_info, excluded_ranks_set, new_backend)
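To make the two optional knobs concrete, here is a hedged example combining abort semantics with backend option overrides; the option values are illustrative only:

```python
# Illustrative only: exclude the last rank with abort semantics and override NCCL options.
import torch.distributed as dist
from torch.distributed.distributed_c10d import SHRINK_ABORT, shrink_group

opts = dist.ProcessGroupNCCL.Options()
opts.is_high_priority_stream = True  # example override
opts.config.blocking = 0             # example: non-blocking communicator

new_pg = shrink_group(
    [dist.get_world_size() - 1],
    shrink_flags=SHRINK_ABORT,
    pg_options=opts,
)
```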

def _validate_shrink_inputs(ranks_to_exclude: list[int], shrink_flags: int) -> None:
    """Validate input parameters for shrink_group."""
    if not isinstance(ranks_to_exclude, list):
        raise TypeError(
            f"ranks_to_exclude must be a list, but got {type(ranks_to_exclude).__name__}. "
            f"Example: [1, 3, 5] to exclude ranks 1, 3, and 5."
        )

    if not ranks_to_exclude:
        raise ValueError(
            "ranks_to_exclude cannot be empty. To shrink a group, you must specify at least "
            "one rank to exclude. Example: [failed_rank_id]"
        )

    # Validate shrink_flags with clear explanation of valid values
    valid_flags = [SHRINK_DEFAULT, SHRINK_ABORT]
    if not isinstance(shrink_flags, int) or shrink_flags not in valid_flags:
        raise ValueError(
            f"Invalid shrink_flags value: {shrink_flags}. Must be one of: "
            f"SHRINK_DEFAULT ({SHRINK_DEFAULT}) or SHRINK_ABORT ({SHRINK_ABORT}). "
            f"Use SHRINK_ABORT to abort ongoing operations before shrinking."
        )


def _prepare_shrink_target_group(group: Optional[ProcessGroup]) -> dict:
    """Prepare and validate the target group for shrinking."""
    target_pg = group if group is not None else _get_default_group()

    # Cache frequently accessed properties to avoid repeated calls
    group_size = int(target_pg.size())
    group_info = {
        "process_group": target_pg,
        "is_default_group": (target_pg == _get_default_group()),
        "group_size": group_size,
        "current_rank": target_pg.rank(),
        "group_name": _get_process_group_name(target_pg),
    }

    # Validate that we have a valid process group
    if group_size <= 1:
        raise ValueError(
            f"Cannot shrink a process group with size {group_size}. "
            f"Group must have at least 2 ranks to support shrinking."
        )

    return group_info


def _validate_shrink_backend_requirements(group_info: dict) -> Any:
    """Return the backend implementation for the target group or raise if unsupported."""
    target_pg = group_info["process_group"]
    group_name = group_info["group_name"]

    # Get the group's backend directly via ProcessGroup API. Prefer a bound device if present,
    # otherwise try CUDA then fall back to CPU.
    try:
        preferred_device = getattr(target_pg, "bound_device_id", None)
        if preferred_device is not None:
            backend_impl = target_pg._get_backend(preferred_device)
        else:
            # Try CUDA first if available, else CPU
            try:
                backend_impl = target_pg._get_backend(torch.device("cuda"))
            except Exception:
                backend_impl = target_pg._get_backend(torch.device("cpu"))
    except RuntimeError as e:
        raise RuntimeError(
            f"Cannot access device backend for process group '{group_name}'. "
            f"Ensure the process group was initialized with a compatible device backend and devices are available."
        ) from e

    try:
        supports = bool(backend_impl.supports_shrinking)
    except Exception:
        supports = False
    if not supports:
        raise TypeError(
            f"Process group backend for '{group_name}' does not support shrinking operations."
        )

    return backend_impl


def _validate_and_process_excluded_ranks(
    ranks_to_exclude: list[int], group_info: dict
) -> set:
    """Validate excluded ranks and convert to set for efficient operations."""
    group_size = group_info["group_size"]
    current_rank = group_info["current_rank"]

    # Use set for O(1) duplicate detection and membership testing
    excluded_ranks_set = set()

    # Validate each rank with detailed error messages
    for i, rank in enumerate(ranks_to_exclude):
        if not isinstance(rank, int):
            raise TypeError(
                f"All elements in ranks_to_exclude must be integers. "
                f"Element at index {i} is {type(rank).__name__}: {rank}"
            )

        if not (0 <= rank < group_size):
            raise ValueError(
                f"Rank {rank} at index {i} is out of bounds for group size {group_size}. "
                f"Valid ranks are in range [0, {group_size - 1}]."
            )

        if rank in excluded_ranks_set:
            raise ValueError(
                f"Duplicate rank {rank} found in ranks_to_exclude at index {i}. "
                f"Each rank can only be excluded once."
            )

        excluded_ranks_set.add(rank)

    # Ensure we don't exclude all ranks
    if len(excluded_ranks_set) >= group_size:
        raise ValueError(
            f"Cannot exclude all {group_size} ranks from process group. "
            f"At least one rank must remain. Excluding {len(excluded_ranks_set)} ranks."
        )

    # Critical check: current rank should not be in excluded list
    if current_rank in excluded_ranks_set:
        raise RuntimeError(
            f"Current rank {current_rank} is in the exclusion list and should not call shrink_group(). "
            f"Only non-excluded ranks should participate in the shrinking operation. "
            f"Excluded ranks should terminate their processes instead."
        )

    return excluded_ranks_set
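As a quick illustration of the validation rules above, each of these calls is expected to fail on a 2-rank default group (hedged; error messages paraphrased):

```python
# Hypothetical session on a 2-rank group exercising the checks above.
from torch.distributed.distributed_c10d import shrink_group

shrink_group([])      # ValueError: ranks_to_exclude cannot be empty
shrink_group([0, 0])  # ValueError: duplicate rank 0
shrink_group([5])     # ValueError: rank 5 out of bounds for group size 2
shrink_group([0, 1])  # ValueError: cannot exclude all ranks
```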
def _finalize_shrunk_group(
|
||||
group_info: dict, excluded_ranks_set: set, new_backend
|
||||
) -> ProcessGroup:
|
||||
"""Clean up old group and create new shrunk process group."""
|
||||
target_pg = group_info["process_group"]
|
||||
is_default_group = group_info["is_default_group"]
|
||||
|
||||
# Handle default group dependencies - destroy other groups first
|
||||
if is_default_group:
|
||||
_destroy_all_other_groups(exclude_group=target_pg)
|
||||
|
||||
# Gather original group metadata before cleanup
|
||||
original_group_metadata = _extract_group_metadata(target_pg)
|
||||
|
||||
# Calculate remaining ranks efficiently
|
||||
original_ranks = get_process_group_ranks(target_pg)
|
||||
remaining_ranks = [
|
||||
rank for rank in original_ranks if rank not in excluded_ranks_set
|
||||
]
|
||||
|
||||
# Clean up the original group
|
||||
_cleanup_original_group(target_pg, is_default_group)
|
||||
|
||||
# Create and configure the new process group
|
||||
new_pg = _create_shrunk_process_group(
|
||||
new_backend, remaining_ranks, original_group_metadata, is_default_group
|
||||
)
|
||||
|
||||
# Register the new group in global state
|
||||
if is_default_group:
|
||||
_update_default_pg(new_pg)
|
||||
|
||||
# Update global state with new group information
|
||||
rank_mapping = {
|
||||
global_rank: group_rank
|
||||
for group_rank, global_rank in enumerate(remaining_ranks)
|
||||
}
|
||||
_update_process_group_global_state(
|
||||
pg=new_pg,
|
||||
backend_name=original_group_metadata["backend_name"],
|
||||
store=original_group_metadata["store"],
|
||||
group_name=original_group_metadata["new_group_name"],
|
||||
backend_config=original_group_metadata["backend_config"],
|
||||
rank_mapping=rank_mapping,
|
||||
)
|
||||
|
||||
return new_pg
|
||||
|
||||
|
||||
def _extract_group_metadata(target_pg: ProcessGroup) -> dict:
|
||||
"""Extract metadata from the original group before cleanup."""
|
||||
original_backend_name, original_store = _world.pg_map[target_pg]
|
||||
original_backend_config = _world.pg_backend_config.get(target_pg, "")
|
||||
original_group_name = _get_process_group_name(target_pg)
|
||||
|
||||
# Extract device binding information before cleanup to avoid accessing destroyed group
|
||||
bound_device_id = None
|
||||
if hasattr(target_pg, "bound_device_id"):
|
||||
bound_device_id = target_pg.bound_device_id
|
||||
|
||||
# Generate new group name for the shrunk group; hash for uniqueness across backends
|
||||
remaining_ranks = list(get_process_group_ranks(target_pg))
|
||||
new_group_name = _process_group_name(remaining_ranks, use_hashed_name=True)
|
||||
|
||||
return {
|
||||
"backend_name": original_backend_name,
|
||||
"store": original_store,
|
||||
"backend_config": original_backend_config,
|
||||
"original_group_name": original_group_name,
|
||||
"new_group_name": new_group_name,
|
||||
"bound_device_id": bound_device_id, # Safe to access after cleanup
|
||||
}
|
||||
|
||||
|
||||
def _cleanup_original_group(target_pg: ProcessGroup, is_default_group: bool) -> None:
|
||||
"""Clean up the original process group safely."""
|
||||
try:
|
||||
destroy_process_group(target_pg)
|
||||
except Exception as e:
|
||||
group_type = "default" if is_default_group else "non-default"
|
||||
logger.warning("Failed to destroy %s group during shrinking: %s", group_type, e)
|
||||
|
||||
# Ensure global state cleanup even if destroy_process_group fails
|
||||
_cleanup_process_group_global_state(target_pg)
|
||||
|
||||
|
||||
def _create_shrunk_process_group(
|
||||
new_backend, remaining_ranks: list[int], metadata: dict, is_default_group: bool
|
||||
) -> ProcessGroup:
|
||||
"""Create and configure the new shrunk process group."""
|
||||
# Create new group properties
|
||||
new_group_rank = new_backend.rank()
|
||||
new_group_size = new_backend.size()
|
||||
group_name = metadata["new_group_name"]
|
||||
|
||||
# Generate descriptive group description
|
||||
if is_default_group:
|
||||
group_desc = "default:shrunken"
|
||||
else:
|
||||
group_desc = f"{metadata['original_group_name']}:shrunk"
|
||||
|
||||
# Create process group with new communicator (clone the parent store like split does)
|
||||
prefix_store = PrefixStore(f"{group_name}/", metadata["store"].clone())
|
||||
new_pg = ProcessGroup(prefix_store, new_group_rank, new_group_size)
|
||||
|
||||
# Configure backend using the device type of the new backend's bound device if available,
|
||||
# otherwise derive from the original group's bound device or fall back to CPU.
|
||||
backend_device = metadata.get("bound_device_id")
|
||||
if backend_device is None:
|
||||
# Default to CPU if no bound device is present
|
||||
backend_device = torch.device("cpu")
|
||||
|
||||
# Choose backend enum based on device type
|
||||
if backend_device.type == "cuda":
|
||||
backend_type = ProcessGroup.BackendType.NCCL
|
||||
else:
|
||||
backend_type = ProcessGroup.BackendType.GLOO
|
||||
|
||||
new_pg._register_backend(backend_device, backend_type, new_backend)
|
||||
new_pg._set_default_backend(backend_type)
|
||||
|
||||
# Inherit device binding from original group if it was bound
|
||||
bound_device_id = metadata.get("bound_device_id")
|
||||
if bound_device_id is not None:
|
||||
new_pg.bound_device_id = bound_device_id
|
||||
|
||||
# Set group metadata
|
||||
new_pg._set_group_name(group_name)
|
||||
new_pg._set_group_desc(group_desc)
|
||||
|
||||
# Persist backend configuration overrides (if provided via shrink_group)
|
||||
backend_config_override = metadata.get("backend_config")
|
||||
if backend_config_override is not None:
|
||||
# Store for introspection/debugging and potential backend hooks
|
||||
_world.pg_backend_config[new_pg] = backend_config_override
|
||||
|
||||
return new_pg
|
||||
|
||||
|
||||
def _destroy_all_other_groups(exclude_group: Optional[ProcessGroup] = None) -> None:
    """
    Destroy all process groups except the excluded group and clean up all global state.

    This is necessary when shrinking the default group because global ranks
    are reassigned by NCCL, making all existing process groups inconsistent.

    Note: Uses abort for non-collective cleanup since excluded ranks may not
    participate in collective operations. Backend cleanup is handled independently per group.

    Args:
        exclude_group (ProcessGroup, optional): Process group to exclude from destruction.
            If None, destroys all process groups.
    """
    # Get list of groups to destroy (avoid modifying dict while iterating)
    groups_to_destroy = []
    for pg in list(_world.pg_group_ranks.keys()):
        if exclude_group is not None and pg == exclude_group:
            continue
        groups_to_destroy.append(pg)

    # Warn user about automatic destruction
    if groups_to_destroy:
        group_names = [_get_process_group_name(pg) for pg in groups_to_destroy]
        logger.warning(
            "Shrinking default group will destroy %d other process groups: %s. "
            "This is necessary because shrinking the default group reassigns global ranks, "
            "making existing groups inconsistent.",
            len(groups_to_destroy),
            ", ".join(group_names),
        )

    # Destroy each group and clean up global state
    for pg in groups_to_destroy:
        try:
            # First call abort_process_group which handles the C++ cleanup non-collectively
            _abort_process_group(pg)
        except Exception as e:
            # Log but don't fail - some groups might already be destroyed
            logger.warning(
                "Failed to abort process group %s: %s",
                _get_process_group_name(pg),
                e,
            )

        # Ensure all global state is cleaned up even if _abort_process_group fails
        # or doesn't clean up everything
        _cleanup_process_group_global_state(pg)

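# Note on the iteration pattern above: the keys are snapshotted with
# list(_world.pg_group_ranks.keys()) because each _cleanup_process_group_global_state
# call pops entries from that dict, and mutating a dict while iterating it directly
# raises RuntimeError. A minimal standalone sketch of the same pattern (names are
# illustrative, a plain dict stands in for the global registry):
#
#   registry = {"pg_a": 0, "pg_b": 1}
#   for key in list(registry.keys()):   # iterate over a copy of the keys
#       registry.pop(key, None)         # safe to mutate the live dict here
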
def _cleanup_process_group_global_state(pg: ProcessGroup) -> None:
    """
    Clean up all global state associated with a process group.

    This function ensures complete cleanup of process group state from all
    global dictionaries and registries, even if destroy_process_group fails
    or doesn't clean up everything. This is critical when destroying multiple
    groups to prevent inconsistent state.

    The cleanup removes the process group from:
    - _world.pg_map (backend and store mapping)
    - _world.pg_names (group name mapping)
    - _world.pg_group_ranks (rank mappings)
    - _world.pg_backend_config (backend configuration)
    - _world.tags_to_pg and _world.pg_to_tag (tag mappings)
    - _world.pg_coalesce_state (coalescing state)
    - C++ internal registries via _unregister_process_group

    Args:
        pg (ProcessGroup): The process group to clean up.
    """
    try:
        # Clean up main process group mappings
        _world.pg_map.pop(pg, None)
        _world.pg_group_ranks.pop(pg, None)
        _world.pg_backend_config.pop(pg, None)

        # Clean up process group name mapping
        group_name = _world.pg_names.pop(pg, None)

        # Clean up tag mappings
        pg_tag = _world.pg_to_tag.pop(pg, None)
        if pg_tag is not None and pg_tag in _world.tags_to_pg:
            try:
                _world.tags_to_pg[pg_tag].remove(pg)
                # Remove the tag entry if list is empty
                if not _world.tags_to_pg[pg_tag]:
                    _world.tags_to_pg.pop(pg_tag, None)
            except (ValueError, KeyError):
                # Process group was already removed from the list
                pass

        # Clean up any registered process group names using C++ unregister function
        if group_name is not None:
            try:
                _unregister_process_group(group_name)
            except Exception:
                # Process group name might not be registered or already unregistered
                pass

        # Clean up coalesce state if present
        _world.pg_coalesce_state.pop(pg, None)

    except Exception as e:
        # Log cleanup failures but don't propagate - we want to continue with other cleanups
        logger.warning("Failed to fully clean up global state for process group: %s", e)

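# The .pop(key, None) calls above make this helper idempotent: if a mapping was
# already removed (for example by destroy_process_group or _abort_process_group),
# a second cleanup pass is a no-op instead of raising KeyError, so callers can
# invoke it defensively after either teardown path.
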
def _update_process_group_global_state(
    pg: ProcessGroup,
    backend_name: str,
    store: Store,
    group_name: str,
    backend_config: str,
    rank_mapping: Optional[dict[int, int]] = None,
    pg_tag: Optional[str] = None,
    user_tag: Optional[str] = None,
) -> None:
    """
    Update all global state dictionaries for a process group.

    This helper function consolidates the common pattern of updating multiple
    global state dictionaries when creating or modifying process groups.

    Args:
        pg (ProcessGroup): The process group to update state for.
        backend_name (str): Backend name for pg_map.
        store (Store): Store instance for pg_map.
        group_name (str): Group name for pg_names and registration.
        backend_config (str): Backend configuration string.
        rank_mapping (Dict[int, int], optional): Global rank to group rank mapping.
            If None, skips updating pg_group_ranks.
        pg_tag (str, optional): Process group tag. If None, defaults to f"ptd:{group_name}".
        user_tag (str, optional): User-provided tag for special tag handling.
            If provided, creates "user:{user_tag}" tag and also adds to default "".
    """
    # Update main process group mappings
    _world.pg_map[pg] = (backend_name, store)
    _world.pg_names[pg] = group_name
    _world.pg_backend_config[pg] = backend_config

    # Register the process group name
    _register_process_group(group_name, pg)

    # Update rank mapping if provided
    if rank_mapping is not None:
        _world.pg_group_ranks[pg] = rank_mapping

    # Handle tag management
    if pg_tag is None:
        pg_tag = f"ptd:{group_name}"

    if user_tag is not None:
        # Special handling for user-provided tags
        # Add to default "" tag first
        _world.tags_to_pg.setdefault("", []).append(pg)
        # Then create user-specific tag
        user_pg_tag = f"user:{user_tag}"
        _world.tags_to_pg.setdefault(user_pg_tag, []).append(pg)
        _world.pg_to_tag[pg] = user_pg_tag
    else:
        # Standard process group tag
        _world.tags_to_pg.setdefault(pg_tag, []).append(pg)
        _world.pg_to_tag[pg] = pg_tag

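# Resulting bookkeeping, sketched for two hypothetical calls (all argument values
# here are illustrative only):
#   _update_process_group_global_state(pg, "nccl", store, "shrunk_pg_1", "cuda:nccl")
#       -> _world.pg_to_tag[pg] == "ptd:shrunk_pg_1"
#   _update_process_group_global_state(pg, "nccl", store, "shrunk_pg_1", "cuda:nccl",
#                                      user_tag="trainers")
#       -> _world.pg_to_tag[pg] == "user:trainers", and pg is also listed under the
#          default "" tag in _world.tags_to_pg
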
@ -238,6 +238,47 @@ def skip_if_lt_x_gpu(x):
    return decorator


def requires_world_size(n: int):
    """
    Decorator to request a specific world size for a test. The test harness can
    read this attribute to set the number of ranks to spawn. If there are fewer
    than `n` CUDA devices available, the test should be skipped by the harness.

    Usage:
        @requires_world_size(3)
        def test_something(self):
            ...
    """

    def decorator(func):
        func._required_world_size = n
        available = torch.cuda.device_count()
        return unittest.skipUnless(
            available >= n, f"requires {n} GPUs, found {available}"
        )(func)

    return decorator

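# Minimal usage sketch (assumes a MultiProcessTestCase-style harness that consults
# get_required_world_size below when deciding how many ranks to spawn; the class
# and test names here are hypothetical):
#
#   class ShrinkGroupTest(MultiProcessTestCase):
#       @requires_world_size(3)
#       def test_needs_three_ranks(self):
#           ...  # skipped automatically when fewer than 3 GPUs are visible
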
def get_required_world_size(obj: Any, default: int) -> int:
    """
    Returns the requested world size for the currently running unittest method on `obj`
    if annotated via `@requires_world_size(n)`, else returns `default`.
    """
    try:
        # Try MultiProcessTestCase helper first, then unittest fallback
        test_name = (
            obj._current_test_name()  # type: ignore[attr-defined]
            if hasattr(obj, "_current_test_name") and callable(obj._current_test_name)
            else obj._testMethodName
        )
        fn = getattr(obj, test_name)
        value = fn._required_world_size
        return int(value)
    except Exception:
        return default

# This decorator helps avoid initializing CUDA while testing other backends
def nccl_skip_if_lt_x_gpu(backend, x):
    def decorator(func):
@ -367,6 +408,13 @@ def requires_nccl_version(version, msg):
        )


def requires_nccl_shrink():
    """
    Require NCCL shrink support (NCCL available and version >= 2.27).
    """
    return requires_nccl_version((2, 27), "Need NCCL 2.27+ for shrink_group")

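# Illustrative usage: the decorators defined above can be stacked on a test method
# (the method name here is hypothetical):
#
#   @requires_nccl_shrink()
#   @requires_world_size(4)
#   def test_shrink_group_excludes_rank(self):
#       ...
#
# so the test runs only with NCCL >= 2.27 and at least 4 visible GPUs.
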
def requires_nccl():
    return skip_but_pass_in_sandcastle_if(
        not c10d.is_nccl_available(),