mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "shrink_group implementation to expose ncclCommShrink API (#164518)"
This reverts commit fa0db212e717b6cb225159cb32ea3d83baa52381. Reverted https://github.com/pytorch/pytorch/pull/164518 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/164518#issuecomment-3419893217))
This commit is contained in:
@ -394,10 +394,6 @@ an opaque group handle that can be given as a `group` argument to all collective
|
|||||||
.. autofunction:: new_group
|
.. autofunction:: new_group
|
||||||
```
|
```
|
||||||
|
|
||||||
```{eval-rst}
|
|
||||||
.. autofunction:: torch.distributed.distributed_c10d.shrink_group
|
|
||||||
```
|
|
||||||
|
|
||||||
```{eval-rst}
|
```{eval-rst}
|
||||||
.. autofunction:: get_group_rank
|
.. autofunction:: get_group_rank
|
||||||
```
|
```
|
||||||
|
@ -1,43 +0,0 @@
|
|||||||
import logging
|
|
||||||
import time
|
|
||||||
|
|
||||||
|
|
||||||
_start_time = time.time()
|
|
||||||
_logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def _ts():
|
|
||||||
return time.time() - _start_time
|
|
||||||
|
|
||||||
|
|
||||||
def configure(level=logging.INFO, force=False):
|
|
||||||
try:
|
|
||||||
logging.basicConfig(
|
|
||||||
level=level,
|
|
||||||
format="%(asctime)s %(name)s %(levelname)s: %(message)s",
|
|
||||||
force=force,
|
|
||||||
)
|
|
||||||
except TypeError:
|
|
||||||
logging.basicConfig(
|
|
||||||
level=level, format="%(asctime)s %(name)s %(levelname)s: %(message)s"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def log_test_info(rank, message):
|
|
||||||
_logger.info("[%7.3fs][Rank %s] %s", _ts(), rank, message)
|
|
||||||
|
|
||||||
|
|
||||||
def log_test_success(rank, message):
|
|
||||||
_logger.info("[%7.3fs][Rank %s] ✅ %s", _ts(), rank, message)
|
|
||||||
|
|
||||||
|
|
||||||
def log_test_validation(rank, message):
|
|
||||||
_logger.info("[%7.3fs][Rank %s] ✓ %s", _ts(), rank, message)
|
|
||||||
|
|
||||||
|
|
||||||
def log_test_warning(rank, message):
|
|
||||||
_logger.warning("[%7.3fs][Rank %s] ⚠️ %s", _ts(), rank, message)
|
|
||||||
|
|
||||||
|
|
||||||
def log_test_error(rank, message):
|
|
||||||
_logger.error("[%7.3fs][Rank %s] ✗ %s", _ts(), rank, message)
|
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import random
|
import random
|
||||||
@ -22,7 +21,6 @@ from unittest import mock, SkipTest
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as c10d
|
import torch.distributed as c10d
|
||||||
import torch.distributed._functional_collectives as _functional_collectives
|
import torch.distributed._functional_collectives as _functional_collectives
|
||||||
from torch.distributed.distributed_c10d import SHRINK_ABORT as NCCL_SHRINK_ABORT
|
|
||||||
|
|
||||||
|
|
||||||
if not c10d.is_available() or not c10d.is_nccl_available():
|
if not c10d.is_available() or not c10d.is_nccl_available():
|
||||||
@ -49,15 +47,12 @@ from torch._C._distributed_c10d import ErrorType, OpType, WorkResult
|
|||||||
from torch.nn.parallel import DistributedDataParallel
|
from torch.nn.parallel import DistributedDataParallel
|
||||||
from torch.testing._internal.common_cuda import _get_torch_rocm_version, TEST_MULTIGPU
|
from torch.testing._internal.common_cuda import _get_torch_rocm_version, TEST_MULTIGPU
|
||||||
from torch.testing._internal.common_distributed import (
|
from torch.testing._internal.common_distributed import (
|
||||||
get_required_world_size,
|
|
||||||
get_timeout,
|
get_timeout,
|
||||||
init_multigpu_helper,
|
init_multigpu_helper,
|
||||||
MultiProcessTestCase,
|
MultiProcessTestCase,
|
||||||
requires_multicast_support,
|
requires_multicast_support,
|
||||||
requires_nccl,
|
requires_nccl,
|
||||||
requires_nccl_shrink,
|
|
||||||
requires_nccl_version,
|
requires_nccl_version,
|
||||||
requires_world_size,
|
|
||||||
skip_if_lt_x_gpu,
|
skip_if_lt_x_gpu,
|
||||||
skip_if_rocm_multiprocess,
|
skip_if_rocm_multiprocess,
|
||||||
sm_is_or_higher_than,
|
sm_is_or_higher_than,
|
||||||
@ -92,17 +87,6 @@ BFLOAT16_AVAILABLE = torch.cuda.is_available() and (
|
|||||||
torch.version.cuda is not None or torch.version.hip is not None
|
torch.version.cuda is not None or torch.version.hip is not None
|
||||||
)
|
)
|
||||||
|
|
||||||
from logging_utils import (
|
|
||||||
configure as _log_configure,
|
|
||||||
log_test_info,
|
|
||||||
log_test_success,
|
|
||||||
log_test_validation,
|
|
||||||
log_test_warning,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
_log_configure(level=logging.INFO, force=True)
|
|
||||||
|
|
||||||
|
|
||||||
class RendezvousEnvTest(TestCase):
|
class RendezvousEnvTest(TestCase):
|
||||||
@retry_on_connect_failures
|
@retry_on_connect_failures
|
||||||
@ -333,7 +317,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def world_size(self):
|
def world_size(self):
|
||||||
return get_required_world_size(self, 2)
|
return 2
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def rank_to_GPU(self):
|
def rank_to_GPU(self):
|
||||||
@ -1271,628 +1255,6 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
|||||||
pg_2 = c10d.new_group([0, 1])
|
pg_2 = c10d.new_group([0, 1])
|
||||||
self.assertEqual(pg_2.group_desc, "undefined")
|
self.assertEqual(pg_2.group_desc, "undefined")
|
||||||
|
|
||||||
@requires_nccl_shrink()
|
|
||||||
@requires_world_size(2)
|
|
||||||
def test_shrink_group_basic(self):
|
|
||||||
"""Test basic shrink_group functionality."""
|
|
||||||
self._perform_shrink_test([1], "Basic shrink test")
|
|
||||||
|
|
||||||
@requires_nccl_shrink()
|
|
||||||
@requires_world_size(2)
|
|
||||||
def test_shrink_group_validation(self):
|
|
||||||
"""Test input validation in shrink_group."""
|
|
||||||
device, pg = self._setup_shrink_test("validation")
|
|
||||||
|
|
||||||
def _test_invalid_input(ranks, description, expected_exception):
|
|
||||||
"""Helper to test invalid inputs."""
|
|
||||||
try:
|
|
||||||
c10d.shrink_group(ranks)
|
|
||||||
self.fail(f"Expected {expected_exception.__name__} for {description}")
|
|
||||||
except expected_exception:
|
|
||||||
log_test_validation(self.rank, f"✓ {description}")
|
|
||||||
except Exception:
|
|
||||||
if expected_exception is Exception: # Accept any exception
|
|
||||||
log_test_validation(self.rank, f"✓ {description}")
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Test cases
|
|
||||||
_test_invalid_input([], "Empty exclusion list", ValueError)
|
|
||||||
if self.world_size > 1:
|
|
||||||
_test_invalid_input([0, 0, 1], "Duplicate ranks", Exception)
|
|
||||||
_test_invalid_input([self.world_size + 1], "Out of bounds rank", Exception)
|
|
||||||
|
|
||||||
log_test_success(self.rank, "All validation tests passed")
|
|
||||||
dist.destroy_process_group()
|
|
||||||
|
|
||||||
@requires_nccl_shrink()
|
|
||||||
@requires_world_size(2)
|
|
||||||
def test_shrink_group_backend_properties(self):
|
|
||||||
"""Test that backend properties are preserved after shrinking."""
|
|
||||||
|
|
||||||
test_name = "Backend Properties Test"
|
|
||||||
ranks_to_exclude = [0]
|
|
||||||
|
|
||||||
# Reuse _setup_shrink_test for complete setup (device, environment, and process group)
|
|
||||||
device, pg = self._setup_shrink_test("backend_properties")
|
|
||||||
|
|
||||||
# Follow _perform_shrink_test pattern from here
|
|
||||||
log_test_info(self.rank, f"{test_name} (world_size={self.world_size})")
|
|
||||||
|
|
||||||
is_excluded = self.rank in ranks_to_exclude
|
|
||||||
log_test_info(
|
|
||||||
self.rank,
|
|
||||||
f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store original backend property values (not references) before shrinking
|
|
||||||
original_timeout = None
|
|
||||||
original_high_priority = None
|
|
||||||
if not is_excluded:
|
|
||||||
original_backend = pg._get_backend(device)
|
|
||||||
original_timeout = original_backend.options._timeout
|
|
||||||
original_high_priority = original_backend.options.is_high_priority_stream
|
|
||||||
log_test_info(
|
|
||||||
self.rank,
|
|
||||||
f"Storing original backend properties: timeout={original_timeout}, high_priority={original_high_priority}",
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_excluded:
|
|
||||||
log_test_info(
|
|
||||||
self.rank,
|
|
||||||
f"Excluded rank {self.rank} - setup complete, skipping shrink operation",
|
|
||||||
)
|
|
||||||
dist.destroy_process_group() # hang without it
|
|
||||||
return
|
|
||||||
|
|
||||||
# Only non-excluded ranks proceed with shrink (same as _perform_shrink_test)
|
|
||||||
log_test_info(self.rank, "Non-excluded rank calling shrink_group")
|
|
||||||
shrunk_pg = c10d.shrink_group(ranks_to_exclude)
|
|
||||||
|
|
||||||
# Reuse _validate_shrunk_group helper (same as _perform_shrink_test)
|
|
||||||
expected_size = self.world_size - len(ranks_to_exclude)
|
|
||||||
_ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name)
|
|
||||||
|
|
||||||
# Add custom backend properties validation
|
|
||||||
new_backend = shrunk_pg._get_backend(device)
|
|
||||||
log_test_info(self.rank, "Validating backend properties are preserved")
|
|
||||||
|
|
||||||
new_timeout = new_backend.options._timeout
|
|
||||||
new_high_priority = new_backend.options.is_high_priority_stream
|
|
||||||
|
|
||||||
log_test_info(
|
|
||||||
self.rank,
|
|
||||||
f"Timeout comparison - original: {original_timeout}, new: {new_timeout}",
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
original_timeout, new_timeout, f"{test_name}: timeout not preserved"
|
|
||||||
)
|
|
||||||
|
|
||||||
log_test_info(
|
|
||||||
self.rank,
|
|
||||||
f"High priority stream comparison - original: {original_high_priority}, new: {new_high_priority}",
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
original_high_priority,
|
|
||||||
new_high_priority,
|
|
||||||
f"{test_name}: high_priority_stream not preserved",
|
|
||||||
)
|
|
||||||
|
|
||||||
log_test_validation(
|
|
||||||
self.rank, f"{test_name}: Backend properties preserved successfully"
|
|
||||||
)
|
|
||||||
log_test_success(
|
|
||||||
self.rank, f"{test_name} successful (shrink + backend validation)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Cleanup (same as _perform_shrink_test)
|
|
||||||
dist.destroy_process_group()
|
|
||||||
|
|
||||||
@requires_nccl_shrink()
|
|
||||||
@requires_world_size(2)
|
|
||||||
def test_shrink_group_multiple_comms(self):
|
|
||||||
"""Test shrink_group with multiple communicators and subgroup invalidation."""
|
|
||||||
|
|
||||||
device, pg = self._setup_shrink_test("multiple_comms")
|
|
||||||
|
|
||||||
# Create subgroup [0, 1] and test shrinking it
|
|
||||||
subgroup = c10d.new_group([0, 1])
|
|
||||||
if self.rank <= 1:
|
|
||||||
# Shrink subgroup: exclude rank 1
|
|
||||||
if self.rank == 0: # Only rank 0 remains
|
|
||||||
shrunk_subgroup = c10d.shrink_group([1], group=subgroup)
|
|
||||||
self.assertEqual(shrunk_subgroup.size(), 1)
|
|
||||||
# Test communication on shrunk subgroup
|
|
||||||
tensor = torch.full((1,), self.rank).cuda(device)
|
|
||||||
c10d.all_reduce(tensor, group=shrunk_subgroup)
|
|
||||||
self.assertEqual(tensor.item(), 0) # Only rank 0
|
|
||||||
log_test_success(self.rank, "Subgroup shrinking successful")
|
|
||||||
|
|
||||||
dist.barrier() # Sync before default group test
|
|
||||||
|
|
||||||
# Shrink default group: exclude last rank
|
|
||||||
ranks_to_exclude = [self.world_size - 1]
|
|
||||||
if self.rank not in ranks_to_exclude:
|
|
||||||
shrunk_default = c10d.shrink_group(ranks_to_exclude)
|
|
||||||
expected_size = self.world_size - 1
|
|
||||||
self.assertEqual(shrunk_default.size(), expected_size)
|
|
||||||
|
|
||||||
# Test collective on shrunk default group
|
|
||||||
tensor = torch.full((1,), self.rank).cuda(device)
|
|
||||||
c10d.all_reduce(tensor, group=shrunk_default)
|
|
||||||
expected_sum = sum(
|
|
||||||
range(self.world_size - 1)
|
|
||||||
) # 0 + 1 + ... + (world_size-2)
|
|
||||||
self.assertEqual(tensor.item(), expected_sum)
|
|
||||||
log_test_success(self.rank, "Default group shrinking successful")
|
|
||||||
|
|
||||||
# Note: After shrinking default group, the old subgroup is invalid
|
|
||||||
# due to global rank reassignment
|
|
||||||
|
|
||||||
dist.destroy_process_group()
|
|
||||||
|
|
||||||
def _test_shrink_group_with_flag(self, shrink_flag, flag_name, rank_to_exclude):
|
|
||||||
"""Helper method to test shrink_group with a specific flag."""
|
|
||||||
if self.world_size < 2:
|
|
||||||
log_test_info(self.rank, f"Skipping (needs ≥2 GPUs, got {self.world_size})")
|
|
||||||
return
|
|
||||||
ranks_to_exclude = [rank_to_exclude]
|
|
||||||
log_test_info(self.rank, f"Using {flag_name} flag (value: {shrink_flag})")
|
|
||||||
if flag_name == "NCCL_SHRINK_ABORT":
|
|
||||||
log_test_info(
|
|
||||||
self.rank,
|
|
||||||
"ABORT flag will terminate ongoing operations before shrinking",
|
|
||||||
)
|
|
||||||
|
|
||||||
self._perform_shrink_test(
|
|
||||||
ranks_to_exclude, f"{flag_name} flag test", shrink_flags=shrink_flag
|
|
||||||
)
|
|
||||||
|
|
||||||
@requires_nccl_shrink()
|
|
||||||
@requires_world_size(2)
|
|
||||||
def test_shrink_group_flags(self):
|
|
||||||
"""Test shrink_group with different shrink flags."""
|
|
||||||
# Test ABORT flags
|
|
||||||
log_test_info(self.rank, "Testing NCCL_SHRINK_ABORT flag")
|
|
||||||
self._test_shrink_group_with_flag(NCCL_SHRINK_ABORT, "NCCL_SHRINK_ABORT", 1)
|
|
||||||
|
|
||||||
@requires_nccl_shrink()
|
|
||||||
@requires_world_size(2)
|
|
||||||
def test_shrink_group_nccl_config(self):
|
|
||||||
"""Verify that passing NCCL config via pg_options influences the shrunk group's backend options."""
|
|
||||||
device, pg = self._setup_shrink_test("config")
|
|
||||||
if self.rank == self.world_size - 1:
|
|
||||||
# excluded rank should not call shrink_group
|
|
||||||
dist.destroy_process_group()
|
|
||||||
return
|
|
||||||
|
|
||||||
# Prepare pg_options with NCCL config overrides
|
|
||||||
# Capture parent's current backend options to ensure we can prove override vs inherit
|
|
||||||
parent_backend = pg._get_backend(torch.device("cuda"))
|
|
||||||
parent_hp = parent_backend.options.is_high_priority_stream
|
|
||||||
parent_blocking = parent_backend.options.config.blocking
|
|
||||||
|
|
||||||
# Choose overrides that differ from the parent (flip where possible)
|
|
||||||
override_hp = not parent_hp
|
|
||||||
if parent_blocking in (0, 1):
|
|
||||||
override_blocking = 1 - parent_blocking
|
|
||||||
else:
|
|
||||||
# If undefined or unexpected, set to 1 which is a concrete value
|
|
||||||
override_blocking = 1
|
|
||||||
|
|
||||||
opts = c10d.ProcessGroupNCCL.Options()
|
|
||||||
opts.is_high_priority_stream = override_hp
|
|
||||||
opts.config.blocking = override_blocking
|
|
||||||
|
|
||||||
shrunk_pg = c10d.shrink_group([self.world_size - 1], pg_options=opts)
|
|
||||||
|
|
||||||
# Validate backend options propagated
|
|
||||||
backend = shrunk_pg._get_backend(torch.device("cuda"))
|
|
||||||
# is_high_priority_stream should exactly match our override and differ from parent
|
|
||||||
self.assertEqual(backend.options.is_high_priority_stream, override_hp)
|
|
||||||
self.assertNotEqual(backend.options.is_high_priority_stream, parent_hp)
|
|
||||||
# config is a struct; check representative field and difference from parent when meaningful
|
|
||||||
self.assertEqual(backend.options.config.blocking, override_blocking)
|
|
||||||
if parent_blocking in (0, 1):
|
|
||||||
self.assertNotEqual(backend.options.config.blocking, parent_blocking)
|
|
||||||
|
|
||||||
dist.destroy_process_group()
|
|
||||||
|
|
||||||
@requires_nccl_shrink()
|
|
||||||
@requires_world_size(2)
|
|
||||||
def test_shrink_group_performance(self):
|
|
||||||
"""Test shrink_group performance and regression detection."""
|
|
||||||
import time
|
|
||||||
|
|
||||||
ranks_to_exclude = self._get_default_ranks_to_exclude()
|
|
||||||
is_excluded = self.rank in ranks_to_exclude
|
|
||||||
|
|
||||||
if not ranks_to_exclude:
|
|
||||||
log_test_info(self.rank, "Skipping performance test (world_size=1)")
|
|
||||||
return
|
|
||||||
|
|
||||||
log_test_info(self.rank, f"Performance test with {self.world_size} processes")
|
|
||||||
device, pg = self._setup_shrink_test("performance")
|
|
||||||
|
|
||||||
if not is_excluded:
|
|
||||||
log_test_info(self.rank, "Measuring shrink_group performance")
|
|
||||||
start_time = time.time()
|
|
||||||
shrunk_pg = c10d.shrink_group(ranks_to_exclude)
|
|
||||||
end_time = time.time()
|
|
||||||
|
|
||||||
elapsed_time = end_time - start_time
|
|
||||||
log_test_info(self.rank, f"shrink_group: {elapsed_time:.3f}s")
|
|
||||||
|
|
||||||
# Regression check: should complete within reasonable time
|
|
||||||
self.assertLess(
|
|
||||||
elapsed_time,
|
|
||||||
30.0,
|
|
||||||
f"shrink_group took {elapsed_time:.3f}s, possible regression",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test collective performance
|
|
||||||
expected_size = self.world_size - len(ranks_to_exclude)
|
|
||||||
self._validate_shrunk_group(shrunk_pg, expected_size, "performance")
|
|
||||||
|
|
||||||
collective_start = time.time()
|
|
||||||
_ = self._test_collective_on_shrunk_group(
|
|
||||||
shrunk_pg, device, ranks_to_exclude, "performance"
|
|
||||||
)
|
|
||||||
collective_time = time.time() - collective_start
|
|
||||||
|
|
||||||
log_test_info(self.rank, f"all_reduce: {collective_time:.3f}s")
|
|
||||||
log_test_success(self.rank, "Performance test passed")
|
|
||||||
else:
|
|
||||||
log_test_info(self.rank, "Excluded rank - waiting")
|
|
||||||
|
|
||||||
dist.destroy_process_group()
|
|
||||||
|
|
||||||
@requires_nccl_shrink()
|
|
||||||
@requires_world_size(4)
|
|
||||||
def test_shrink_group_multiple_exclusions(self):
|
|
||||||
"""Test shrink_group with multiple ranks excluded at once."""
|
|
||||||
# Scale exclusions with world size
|
|
||||||
ranks_to_exclude = list(range(2, self.world_size, 2)) # Every other rank from 2
|
|
||||||
|
|
||||||
self._perform_shrink_test(ranks_to_exclude, "Multiple exclusions test")
|
|
||||||
|
|
||||||
@requires_nccl_shrink()
|
|
||||||
@requires_world_size(3)
|
|
||||||
def test_shrink_group_multiple_iterations(self):
|
|
||||||
"""Test multiple shrink operations in sequence."""
|
|
||||||
log_test_info(
|
|
||||||
self.rank,
|
|
||||||
f"Starting test_shrink_group_multiple_iterations with world_size={self.world_size}",
|
|
||||||
)
|
|
||||||
|
|
||||||
store = c10d.FileStore(self.file_name, self.world_size)
|
|
||||||
device = torch.device(f"cuda:{self.rank}")
|
|
||||||
_ = self._create_process_group_nccl(store, self.opts(), device_id=device)
|
|
||||||
|
|
||||||
# Track current effective world size throughout shrinking operations
|
|
||||||
current_world_size = self.world_size
|
|
||||||
log_test_info(self.rank, f"Initial world_size: {current_world_size}")
|
|
||||||
|
|
||||||
# First shrinking: exclude the last rank(s)
|
|
||||||
first_exclusion = [self.world_size - 1]
|
|
||||||
if self.world_size >= 6:
|
|
||||||
first_exclusion.append(
|
|
||||||
self.world_size - 2
|
|
||||||
) # Exclude last two ranks for larger sizes
|
|
||||||
|
|
||||||
log_test_info(self.rank, f"First shrinking: excluding ranks {first_exclusion}")
|
|
||||||
|
|
||||||
if self.rank not in first_exclusion:
|
|
||||||
# Only non-excluded ranks should call shrink_group
|
|
||||||
first_pg = c10d.shrink_group(first_exclusion)
|
|
||||||
self.assertIsNotNone(first_pg)
|
|
||||||
# IMPORTANT: Update world size after first shrinking
|
|
||||||
current_world_size = first_pg.size()
|
|
||||||
expected_first_size = self.world_size - len(first_exclusion)
|
|
||||||
log_test_info(
|
|
||||||
self.rank,
|
|
||||||
f"After first shrinking: world_size {self.world_size} -> {current_world_size}",
|
|
||||||
)
|
|
||||||
self.assertEqual(first_pg.size(), expected_first_size)
|
|
||||||
|
|
||||||
# Second shrinking: exclude another rank from the remaining group
|
|
||||||
# Choose a rank that's in the middle range
|
|
||||||
if current_world_size >= 3:
|
|
||||||
second_exclusion = [
|
|
||||||
current_world_size - 1
|
|
||||||
] # Exclude the new "last" rank
|
|
||||||
log_test_info(
|
|
||||||
self.rank,
|
|
||||||
f"Second shrinking from group of size {current_world_size}: excluding ranks {second_exclusion}",
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.rank not in second_exclusion:
|
|
||||||
# Only non-excluded ranks should call shrink_group for second iteration
|
|
||||||
second_pg = c10d.shrink_group(second_exclusion, group=first_pg)
|
|
||||||
self.assertIsNotNone(second_pg)
|
|
||||||
# IMPORTANT: Update world size after second shrinking
|
|
||||||
final_world_size = second_pg.size()
|
|
||||||
expected_final_size = current_world_size - len(second_exclusion)
|
|
||||||
log_test_info(
|
|
||||||
self.rank,
|
|
||||||
f"After second shrinking: world_size {current_world_size} -> {final_world_size}",
|
|
||||||
)
|
|
||||||
self.assertEqual(second_pg.size(), expected_final_size)
|
|
||||||
|
|
||||||
# Test collective on final group
|
|
||||||
tensor = torch.full((1,), self.rank).cuda(device)
|
|
||||||
log_test_info(
|
|
||||||
self.rank,
|
|
||||||
f"Performing all_reduce on final group (size {final_world_size}) with tensor: {tensor.item()}",
|
|
||||||
)
|
|
||||||
c10d.all_reduce(tensor, group=second_pg)
|
|
||||||
log_test_info(
|
|
||||||
self.rank,
|
|
||||||
f"Final all_reduce completed, result: {tensor.item()}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate expected sum of remaining ranks
|
|
||||||
all_excluded = set(first_exclusion + second_exclusion)
|
|
||||||
remaining_ranks = [
|
|
||||||
r for r in range(self.world_size) if r not in all_excluded
|
|
||||||
]
|
|
||||||
expected_sum = sum(remaining_ranks)
|
|
||||||
log_test_info(
|
|
||||||
self.rank,
|
|
||||||
f"Remaining ranks: {remaining_ranks}, expected sum: {expected_sum}, actual: {tensor.item()}",
|
|
||||||
)
|
|
||||||
self.assertEqual(tensor.item(), expected_sum)
|
|
||||||
log_test_info(self.rank, "Final verification passed")
|
|
||||||
else:
|
|
||||||
log_test_info(
|
|
||||||
self.rank,
|
|
||||||
"This rank excluded in second shrinking, not calling shrink_group",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
log_test_info(
|
|
||||||
self.rank, "Skipping second shrinking (remaining group too small)"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
log_test_info(
|
|
||||||
self.rank,
|
|
||||||
"This rank excluded in first shrinking, not calling shrink_group",
|
|
||||||
)
|
|
||||||
|
|
||||||
log_test_info(self.rank, "Destroying process group")
|
|
||||||
dist.destroy_process_group()
|
|
||||||
log_test_info(self.rank, "test_shrink_group_multiple_iterations completed")
|
|
||||||
|
|
||||||
# Helper methods for optimized shrink group tests
|
|
||||||
def _setup_shrink_test(self, test_suffix, world_size=None, warmup=True):
|
|
||||||
"""Common setup for shrink group tests."""
|
|
||||||
os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1"
|
|
||||||
world_size = world_size or self.world_size
|
|
||||||
store = c10d.FileStore(self.file_name + f"_{test_suffix}", world_size)
|
|
||||||
device = torch.device(f"cuda:{self.rank}")
|
|
||||||
c10d.init_process_group(
|
|
||||||
"nccl",
|
|
||||||
world_size=world_size,
|
|
||||||
rank=self.rank,
|
|
||||||
store=store,
|
|
||||||
pg_options=self.opts(),
|
|
||||||
device_id=device,
|
|
||||||
)
|
|
||||||
pg = c10d.distributed_c10d._get_default_group()
|
|
||||||
|
|
||||||
if warmup:
|
|
||||||
c10d.all_reduce(torch.ones(1).cuda(device), group=pg)
|
|
||||||
|
|
||||||
return device, pg
|
|
||||||
|
|
||||||
def _validate_shrunk_group(self, shrunk_pg, expected_size, test_name=""):
|
|
||||||
"""Validate properties of a shrunk process group."""
|
|
||||||
self.assertIsNotNone(shrunk_pg, f"{test_name}: shrunk_pg should not be None")
|
|
||||||
actual_size = shrunk_pg.size()
|
|
||||||
self.assertEqual(
|
|
||||||
actual_size, expected_size, f"{test_name}: group size mismatch"
|
|
||||||
)
|
|
||||||
|
|
||||||
new_rank = shrunk_pg.rank()
|
|
||||||
self.assertTrue(
|
|
||||||
0 <= new_rank < expected_size, f"{test_name}: invalid new rank {new_rank}"
|
|
||||||
)
|
|
||||||
|
|
||||||
log_test_info(
|
|
||||||
self.rank,
|
|
||||||
f"{test_name}: world_size {self.world_size} -> {actual_size}, rank {self.rank} -> {new_rank}",
|
|
||||||
)
|
|
||||||
return new_rank
|
|
||||||
|
|
||||||
def _test_collective_on_shrunk_group(
|
|
||||||
self, shrunk_pg, device, ranks_to_exclude, test_name=""
|
|
||||||
):
|
|
||||||
"""Test collective communication on shrunk group and verify correctness."""
|
|
||||||
test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32)
|
|
||||||
c10d.all_reduce(test_tensor, group=shrunk_pg)
|
|
||||||
|
|
||||||
result = test_tensor.item()
|
|
||||||
expected_sum = sum(
|
|
||||||
r for r in range(self.world_size) if r not in ranks_to_exclude
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
result, expected_sum, f"{test_name}: collective result mismatch"
|
|
||||||
)
|
|
||||||
log_test_info(
|
|
||||||
self.rank, f"{test_name}: collective passed ({result} == {expected_sum})"
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _perform_shrink_test(
|
|
||||||
self, ranks_to_exclude, test_name, shrink_flags=0, with_collective=True
|
|
||||||
):
|
|
||||||
"""Complete shrink test flow: setup, shrink, validate, test collective, cleanup.
|
|
||||||
|
|
||||||
Consistent API: All ranks perform setup to initialize distributed environment.
|
|
||||||
ONLY non-excluded ranks call shrink_group() for both default and non-default groups.
|
|
||||||
Excluded ranks perform setup, then exit without calling shrink_group() or waiting.
|
|
||||||
"""
|
|
||||||
log_test_info(self.rank, f"{test_name} (world_size={self.world_size})")
|
|
||||||
|
|
||||||
is_excluded = self.rank in ranks_to_exclude
|
|
||||||
log_test_info(
|
|
||||||
self.rank,
|
|
||||||
f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# All ranks (including excluded ones) perform setup to initialize distributed environment
|
|
||||||
device, pg = self._setup_shrink_test(test_name.lower().replace(" ", "_"))
|
|
||||||
is_default_group = pg == c10d.distributed_c10d._get_default_group()
|
|
||||||
|
|
||||||
if is_excluded:
|
|
||||||
log_test_info(
|
|
||||||
self.rank,
|
|
||||||
f"Excluded rank {self.rank} - setup complete, skipping shrink operation",
|
|
||||||
)
|
|
||||||
if shrink_flags & NCCL_SHRINK_ABORT:
|
|
||||||
log_test_info(self.rank, f"Using abort for excluded rank {self.rank}")
|
|
||||||
pg._get_backend(torch.device(device)).abort()
|
|
||||||
log_test_info(
|
|
||||||
self.rank, f"cleanup resources for excluded rank {self.rank}"
|
|
||||||
)
|
|
||||||
dist.destroy_process_group()
|
|
||||||
log_test_info(self.rank, f"Excluded rank {self.rank} - exit")
|
|
||||||
else:
|
|
||||||
log_test_info(
|
|
||||||
self.rank, f"Using regular destroy for excluded rank {self.rank}"
|
|
||||||
)
|
|
||||||
dist.destroy_process_group()
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Only non-excluded ranks proceed with shrink
|
|
||||||
log_test_info(
|
|
||||||
self.rank,
|
|
||||||
f"Non-excluded rank calling shrink_group (default_group={is_default_group})",
|
|
||||||
)
|
|
||||||
shrunk_pg = c10d.shrink_group(ranks_to_exclude, shrink_flags=shrink_flags)
|
|
||||||
log_test_info(
|
|
||||||
self.rank,
|
|
||||||
f"Non-excluded rank calling shrink_group (default_group={is_default_group}) done",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Non-excluded ranks: validate and test the new group
|
|
||||||
expected_size = self.world_size - len(ranks_to_exclude)
|
|
||||||
_ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name)
|
|
||||||
|
|
||||||
if with_collective:
|
|
||||||
_ = self._test_collective_on_shrunk_group(
|
|
||||||
shrunk_pg, device, ranks_to_exclude, test_name
|
|
||||||
)
|
|
||||||
log_test_success(self.rank, f"{test_name} successful (shrink + collective)")
|
|
||||||
else:
|
|
||||||
log_test_success(self.rank, f"{test_name} successful (shrink only)")
|
|
||||||
|
|
||||||
dist.destroy_process_group()
|
|
||||||
return shrunk_pg
|
|
||||||
|
|
||||||
def _get_default_ranks_to_exclude(self):
|
|
||||||
"""Get default ranks to exclude based on world size."""
|
|
||||||
if self.world_size <= 1:
|
|
||||||
return []
|
|
||||||
return [self.world_size - 1] # Exclude last rank by default
|
|
||||||
|
|
||||||
@requires_nccl_shrink()
|
|
||||||
@requires_world_size(3)
|
|
||||||
def test_shrink_group_vs_abort_reinit_performance(self):
|
|
||||||
"""Compare performance of shrink_group vs traditional abort+reinit (simplified for reliability)."""
|
|
||||||
log_test_info(self.rank, "=== TEST 1: abort+reinit ===")
|
|
||||||
|
|
||||||
device, pg1 = self._setup_shrink_test("_perf_reinit")
|
|
||||||
torch.cuda.synchronize(device)
|
|
||||||
|
|
||||||
# Test 1: Traditional abort + reinit
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
dist.destroy_process_group()
|
|
||||||
|
|
||||||
device, new_pg = self._setup_shrink_test("perf_shrink_test1")
|
|
||||||
reinit_time = time.perf_counter() - start_time
|
|
||||||
|
|
||||||
# Test collective with original rank values for fair comparison (non-blocking mode)
|
|
||||||
test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32)
|
|
||||||
work = c10d.all_reduce(test_tensor, group=new_pg, async_op=True)
|
|
||||||
work.wait()
|
|
||||||
|
|
||||||
torch.cuda.synchronize(device)
|
|
||||||
|
|
||||||
# Verify correctness
|
|
||||||
expected_sum = sum(r for r in range(self.world_size))
|
|
||||||
self.assertEqual(test_tensor.item(), expected_sum, "Reinit collective failed")
|
|
||||||
|
|
||||||
log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s")
|
|
||||||
dist.destroy_process_group(new_pg)
|
|
||||||
|
|
||||||
# Test 2: shrink_group with NCCL_SHRINK_ABORT
|
|
||||||
log_test_info(self.rank, "=== TEST 2: shrink_group ===")
|
|
||||||
|
|
||||||
ranks_to_exclude = [self.world_size - 1]
|
|
||||||
is_excluded = self.rank in ranks_to_exclude
|
|
||||||
log_test_info(
|
|
||||||
self.rank,
|
|
||||||
f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
|
|
||||||
)
|
|
||||||
|
|
||||||
device, pg1 = self._setup_shrink_test("perf_shrink_test2") # Unique suffix
|
|
||||||
|
|
||||||
shrink_time = 0
|
|
||||||
if not is_excluded:
|
|
||||||
torch.cuda.synchronize(device) # Ensure accurate timing
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
shrunk_pg = c10d.shrink_group(
|
|
||||||
ranks_to_exclude, shrink_flags=NCCL_SHRINK_ABORT
|
|
||||||
)
|
|
||||||
c10d.all_reduce(torch.ones(1).cuda(device), group=shrunk_pg)
|
|
||||||
shrink_time = time.perf_counter() - start_time
|
|
||||||
|
|
||||||
# Test collective communication on shrunk group (non-blocking mode)
|
|
||||||
test_tensor = torch.full(
|
|
||||||
(1,), self.rank, device=device, dtype=torch.float32
|
|
||||||
)
|
|
||||||
work = c10d.all_reduce(test_tensor, group=shrunk_pg, async_op=True)
|
|
||||||
work.wait()
|
|
||||||
|
|
||||||
# Verify correctness
|
|
||||||
expected_sum = sum(
|
|
||||||
r for r in range(self.world_size) if r not in ranks_to_exclude
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
test_tensor.item(),
|
|
||||||
expected_sum,
|
|
||||||
"shrink_test: collective result mismatch",
|
|
||||||
)
|
|
||||||
|
|
||||||
torch.cuda.synchronize(device) # Ensure operations complete
|
|
||||||
log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s")
|
|
||||||
dist.destroy_process_group()
|
|
||||||
else:
|
|
||||||
log_test_info(self.rank, "Excluded from shrink test - exiting immediately")
|
|
||||||
dist.destroy_process_group()
|
|
||||||
return
|
|
||||||
|
|
||||||
# Performance analysis (only for participating ranks)
|
|
||||||
if shrink_time > 0 and reinit_time > 0:
|
|
||||||
speedup = reinit_time / shrink_time
|
|
||||||
time_saved = reinit_time - shrink_time
|
|
||||||
|
|
||||||
log_test_info(self.rank, "=== PERFORMANCE RESULTS ===")
|
|
||||||
log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s")
|
|
||||||
log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s")
|
|
||||||
log_test_info(self.rank, f"time_saved: {time_saved:+.4f}s")
|
|
||||||
log_test_info(self.rank, f"speedup: {speedup:.2f}x")
|
|
||||||
|
|
||||||
if speedup > 1.1:
|
|
||||||
log_test_success(self.rank, "shrink_group significantly faster")
|
|
||||||
elif speedup > 0.9:
|
|
||||||
log_test_info(self.rank, "≈ comparable performance")
|
|
||||||
else:
|
|
||||||
log_test_warning(self.rank, "abort+reinit faster")
|
|
||||||
|
|
||||||
log_test_info(self.rank, "Performance test completed")
|
|
||||||
|
|
||||||
@requires_nccl()
|
@requires_nccl()
|
||||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||||
def test_deterministic_mode_no_break(self):
|
def test_deterministic_mode_no_break(self):
|
||||||
|
@ -79,23 +79,6 @@ class TORCH_API Backend : public torch::CustomClassHolder {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reports whether this backend can produce a smaller backend by excluding
// ranks (see shrink()). The base implementation returns false; a backend
// that implements shrink() is expected to override this as well.
virtual bool supportsShrinking() const {
  return false;
}
|
|
||||||
|
|
||||||
// Shrink the backend by excluding specified ranks. Backends that support
// communicator shrinking should override this and return a new backend
// instance representing the shrunken group. Backends may use opts_override
// to supply backend-specific options for the new group.
//
// The base implementation ignores all three arguments and always fails the
// TORCH_CHECK below with a message that names the backend, so it never
// returns a value.
virtual c10::intrusive_ptr<Backend> shrink(
    const std::vector<int64_t>& /*ranks_to_exclude*/,
    int /*shrink_flags*/ = 0,
    const c10::intrusive_ptr<Options>& /*opts_override*/ = nullptr) {
  TORCH_CHECK(
      false,
      c10::str("Backend ", getBackendName(), " does not support shrink"));
}
|
|
||||||
|
|
||||||
virtual void setTimeout(std::chrono::milliseconds timeout) {
|
virtual void setTimeout(std::chrono::milliseconds timeout) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
false,
|
false,
|
||||||
|
@ -259,65 +259,6 @@ std::shared_ptr<NCCLComm> NCCLComm::split(
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef NCCL_HAS_COMM_SHRINK
|
|
||||||
// Builds a new NCCLComm from `source` by calling ncclCommShrink with the
// given exclusion list, config and flags.
//
// source           - parent communicator; its device index is used for the
//                    CUDA guard and copied onto the child.
// ranks_to_exclude - ranks handed to ncclCommShrink (pointer + count).
// config           - passed straight to ncclCommShrink; when non-null its
//                    `blocking` field decides the child's nonBlocking_ flag,
//                    otherwise the parent's flag is inherited.
// shrinkFlags      - passed straight to ncclCommShrink.
//
// Returns the child communicator with rank_, deviceIndex_, nonBlocking_ and
// initialized_ filled in. Errors from the NCCL check macros propagate; a
// failure to query the assigned rank is logged and rethrown.
std::shared_ptr<NCCLComm> NCCLComm::shrink(
    NCCLComm* source,
    std::vector<int>& ranks_to_exclude,
    ncclConfig_t* config,
    int shrinkFlags) {
  // Preconditions are validated in ProcessGroupNCCL::shrink

  LOG(INFO) << "Rank " << source->rank_ << ": shrinking comm " << source->repr()
            << " excluding " << ranks_to_exclude.size() << " ranks";

  // Run the NCCL calls below with the parent's device selected.
  at::cuda::OptionalCUDAGuard gpuGuard(source->deviceIndex_);
  auto comm = std::make_shared<NCCLComm>();

  // This call will block until the source communicator is initialized
  auto sourceComm = source->getNcclComm();

  // The new handle is written directly into comm->ncclComm_.
  C10D_NCCL_CHECK_NONBLOCKING(
      ncclCommShrink(
          sourceComm,
          ranks_to_exclude.data(),
          ranks_to_exclude.size(),
          reinterpret_cast<ncclComm_t*>(&(comm->ncclComm_)),
          config,
          shrinkFlags),
      source->getNcclCommFailureReason());

  // Wait for the child communicator to be ready
  // NOTE(review): the wait is issued on `source`, not on `comm` -- confirm
  // that waiting on the parent is what makes the child usable here.
  source->waitReady(true);
  comm->initialized_ = true;

  // NCCL automatically assigns rank during shrink - query it efficiently
  int assigned_rank;
  try {
    C10D_NCCL_CHECK(
        ncclCommUserRank(comm->ncclComm_, &assigned_rank), std::nullopt);
    comm->rank_ = assigned_rank;
  } catch (const std::exception& e) {
    // Fallback: if ncclCommUserRank fails, we can't determine the rank
    // There is no recovery path: log for diagnostics and rethrow unchanged.
    LOG(ERROR) << "Failed to query NCCL-assigned rank: " << e.what();
    throw;
  }

  // Child comm should be on the same device as parent comm
  comm->deviceIndex_ = source->deviceIndex_;
  if (config != nullptr) {
    comm->nonBlocking_ = config->blocking == 0;
  } else {
    // Inherit parent behavior if no config provided
    comm->nonBlocking_ = source->nonBlocking_;
  }

  // The log line is keyed by the parent's rank and also reports the rank
  // NCCL assigned in the new communicator.
  LOG(INFO) << "Rank " << source->rank_ << ": created shrunken comm "
            << comm->repr() << " with NCCL-assigned rank " << assigned_rank;

  return comm;
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
void NCCLComm::finalize() {
|
void NCCLComm::finalize() {
|
||||||
LockType lock(mutex_);
|
LockType lock(mutex_);
|
||||||
if (aborted_) {
|
if (aborted_) {
|
||||||
|
@ -90,10 +90,6 @@ static_assert(
|
|||||||
#define NCCL_HAS_NVLS_CTAS
|
#define NCCL_HAS_NVLS_CTAS
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Feature gate for communicator shrinking: defined only when the NCCL
// headers used at build time report version 2.27.0 or newer.
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0)
#define NCCL_HAS_COMM_SHRINK
#endif
|
|
||||||
|
|
||||||
// Macro to throw on a non-successful NCCL return value.
|
// Macro to throw on a non-successful NCCL return value.
|
||||||
#define C10D_NCCL_CHECK(cmd, failureReason) \
|
#define C10D_NCCL_CHECK(cmd, failureReason) \
|
||||||
do { \
|
do { \
|
||||||
@ -298,14 +294,6 @@ class NCCLComm {
|
|||||||
ncclConfig_t& config);
|
ncclConfig_t& config);
|
||||||
#endif // NCCL_HAS_COMM_SPLIT
|
#endif // NCCL_HAS_COMM_SPLIT
|
||||||
|
|
||||||
#ifdef NCCL_HAS_COMM_SHRINK
|
|
||||||
static std::shared_ptr<NCCLComm> shrink(
|
|
||||||
NCCLComm* source,
|
|
||||||
std::vector<int>& ranks_to_exclude,
|
|
||||||
ncclConfig_t* config,
|
|
||||||
int shrinkFlags = 0);
|
|
||||||
#endif // NCCL_HAS_COMM_SHRINK
|
|
||||||
|
|
||||||
#if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP)
|
#if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP)
|
||||||
std::unordered_map<std::string, std::string> ncclCommDump();
|
std::unordered_map<std::string, std::string> ncclCommDump();
|
||||||
#endif
|
#endif
|
||||||
|
@ -165,7 +165,7 @@ ncclRedOpRAII getNcclReduceOp(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get a key string from device
|
// Get a key string from device
|
||||||
inline std::string getKeyFromDevice(const at::Device& device) {
|
inline std::string getKeyFromDevice(at::Device& device) {
|
||||||
return std::to_string(device.index());
|
return std::to_string(device.index());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -5838,139 +5838,6 @@ at::Tensor ProcessGroupNCCL::allocateTensor(
|
|||||||
return tensor;
|
return tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef NCCL_HAS_COMM_SHRINK
|
|
||||||
// Creates a new ProcessGroupNCCL that contains every rank of this group
// except those listed in `ranks_to_exclude`.
//
// ranks_to_exclude - ranks (in this group's numbering) to drop. Must be
//                    non-empty, in [0, size_), free of duplicates, and must
//                    leave at least one rank.
// shrink_flags     - forwarded unchanged to NCCLComm::shrink.
// opts_override    - optional ProcessGroupNCCL::Options; when given, its
//                    NCCL config and is_high_priority_stream are used for
//                    the new group instead of this group's options.
//
// Returns the new group as a Backend pointer. Raises a value error for an
// invalid exclusion list and a regular check failure when the runtime NCCL
// is too old or the primary communicator is not initialized.
c10::intrusive_ptr<Backend> ProcessGroupNCCL::shrink(
    const std::vector<int64_t>& ranks_to_exclude,
    int shrink_flags,
    const c10::intrusive_ptr<Backend::Options>& opts_override) {
  // Runtime version check with better error message
  auto runtime_version = torch::cuda::nccl::version();
  TORCH_CHECK(
      runtime_version >= NCCL_VERSION(2, 27, 0),
      "ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later. "
      "Found version: ",
      runtime_version);

  // Early validation with detailed error messages
  TORCH_CHECK_VALUE(
      !ranks_to_exclude.empty(), "ranks_to_exclude cannot be empty");
  TORCH_CHECK_VALUE(
      static_cast<int>(ranks_to_exclude.size()) < size_,
      "Cannot exclude all ranks (",
      ranks_to_exclude.size(),
      " >= ",
      size_,
      ")");

  // Validate ranks and convert to int. Duplicates are rejected here because
  // the size of the new group is derived from the number of excluded ranks;
  // a repeated rank would make that size disagree with the communicator.
  std::vector<int> int_ranks_to_exclude;
  int_ranks_to_exclude.reserve(ranks_to_exclude.size());
  std::vector<bool> already_excluded(static_cast<size_t>(size_), false);
  for (int64_t rank : ranks_to_exclude) {
    TORCH_CHECK_VALUE(
        rank >= 0 && rank < size_,
        "Invalid rank ",
        rank,
        " for group size ",
        size_);
    TORCH_CHECK_VALUE(
        !already_excluded[static_cast<size_t>(rank)],
        "Duplicate rank ",
        rank,
        " in ranks_to_exclude");
    already_excluded[static_cast<size_t>(rank)] = true;
    int_ranks_to_exclude.push_back(static_cast<int>(rank));
  }

  // Get primary communicator with better error context
  auto primary_device_index = guessDeviceId();
  auto primary_device = at::Device(at::kCUDA, primary_device_index);
  const auto primary_key = getKeyFromDevice(primary_device);

  std::shared_ptr<NCCLComm> primary_comm = getNCCLComm(primary_key);
  TORCH_CHECK(
      primary_comm,
      "Primary NCCL communicator for device ",
      primary_device,
      " (key: ",
      primary_key,
      ") is not initialized");

  // Cache device index before shrink operation
  at::DeviceIndex parent_device_index = primary_comm->getDeviceIndex();

  ncclConfig_t* config = nullptr;
  // Default to inheriting from parent options
  bool high_priority_stream = options_->is_high_priority_stream;
  if (opts_override) {
    auto nccl_opts =
        c10::static_intrusive_pointer_cast<ProcessGroupNCCL::Options>(
            opts_override);
    config = &nccl_opts->config;
    // If user provided override options, honor is_high_priority_stream as well
    high_priority_stream = nccl_opts->is_high_priority_stream;
  }

  std::shared_ptr<NCCLComm> shrunk_comm = NCCLComm::shrink(
      primary_comm.get(),
      int_ranks_to_exclude,
      (config != nullptr ? config : &options_->config),
      shrink_flags);

  // Calculate new size (from the validated, duplicate-free list) and take
  // the rank that NCCL assigned to this process in the new communicator.
  int new_size = size_ - static_cast<int>(int_ranks_to_exclude.size());
  int new_rank = shrunk_comm->rank_;

  // Create the new ProcessGroupNCCL: cloned store, parent's timeout, and
  // either the override config or the parent's config.
  auto new_store = store_->clone();
  auto new_opts = ProcessGroupNCCL::Options::create(high_priority_stream);
  new_opts->timeout = options_->timeout;
  if (config != nullptr) {
    new_opts->config = *config;
  } else {
    new_opts->config = options_->config;
  }

  auto new_pg = c10::make_intrusive<ProcessGroupNCCL>(
      new_store, new_rank, new_size, new_opts);

  // Install the shrunken communicator on the parent's device in the new
  // group.
  new_pg->initializeDeviceStateForComm(
      at::Device(at::kCUDA, parent_device_index), shrunk_comm);

  return c10::static_intrusive_pointer_cast<Backend>(new_pg);
}
|
|
||||||
|
|
||||||
#else // !NCCL_HAS_COMM_SHRINK
|
|
||||||
// Backend interface override: raise consistent error when shrink is
// unsupported.
//
// This definition is compiled only when NCCL_HAS_COMM_SHRINK is not defined.
// It ignores its arguments and always fails the check below.
c10::intrusive_ptr<Backend> ProcessGroupNCCL::shrink(
    const std::vector<int64_t>& /*ranks_to_exclude*/,
    int /*shrink_flags*/,
    const c10::intrusive_ptr<Backend::Options>& /*opts_override*/) {
  TORCH_CHECK(
      false,
      "ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later, "
      "but PyTorch was built with an older version or without NCCL shrink support.");
}
|
|
||||||
|
|
||||||
#endif // NCCL_HAS_COMM_SHRINK
|
|
||||||
|
|
||||||
void ProcessGroupNCCL::initializeDeviceStateForComm(
|
|
||||||
const at::Device& device,
|
|
||||||
std::shared_ptr<NCCLComm> comm) {
|
|
||||||
const auto key = getKeyFromDevice(device);
|
|
||||||
std::unique_lock<std::mutex> lock(mutex_);
|
|
||||||
at::cuda::OptionalCUDAGuard gpuGuard(device);
|
|
||||||
|
|
||||||
bool force_high = getCvarBool(TORCH_NCCL_HIGH_PRIORITY, false);
|
|
||||||
auto stream = at::cuda::getStreamFromPool(
|
|
||||||
options_->is_high_priority_stream || force_high);
|
|
||||||
|
|
||||||
devNCCLCommMap_[key] = comm;
|
|
||||||
ncclStreams_.emplace(key, stream);
|
|
||||||
ncclEvents_.emplace(key, at::cuda::CUDAEvent(cudaEventDisableTiming));
|
|
||||||
usedDeviceIdxs_.insert(device.index());
|
|
||||||
|
|
||||||
if (shouldAllCommunicatorsRegisterAllTensors()) {
|
|
||||||
std::lock_guard<std::mutex> map_lock(ncclCommMemPoolMapMutex);
|
|
||||||
ncclCommMemPoolMap.emplace(std::move(comm), MemPoolSet{});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace c10d
|
} // namespace c10d
|
||||||
|
|
||||||
#endif // USE_C10D_NCCL
|
#endif // USE_C10D_NCCL
|
||||||
|
@ -997,21 +997,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
|||||||
|
|
||||||
ErrorType getError() override;
|
ErrorType getError() override;
|
||||||
|
|
||||||
// Reports whether shrink() is usable in this build: true when
// NCCL_HAS_COMM_SHRINK was defined at compile time, false otherwise.
bool supportsShrinking() const override {
#ifdef NCCL_HAS_COMM_SHRINK
  return true;
#else
  return false;
#endif
}
|
|
||||||
|
|
||||||
// Backend-style shrink override that returns a Backend instance.
|
|
||||||
c10::intrusive_ptr<Backend> shrink(
|
|
||||||
const std::vector<int64_t>& ranks_to_exclude,
|
|
||||||
int shrink_flags = 0,
|
|
||||||
const c10::intrusive_ptr<Backend::Options>& opts_override =
|
|
||||||
nullptr) override;
|
|
||||||
|
|
||||||
std::shared_ptr<c10::Allocator> getMemAllocator() override;
|
std::shared_ptr<c10::Allocator> getMemAllocator() override;
|
||||||
|
|
||||||
// Allocate tensor from communication-optimized memory pool
|
// Allocate tensor from communication-optimized memory pool
|
||||||
@ -1080,12 +1065,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
|||||||
int p2pRank = 0,
|
int p2pRank = 0,
|
||||||
bool isSendRecvSelf = false);
|
bool isSendRecvSelf = false);
|
||||||
|
|
||||||
// Initialize device-specific state (comm, stream, event, bookkeeping) for a
|
|
||||||
// given communicator on this process group instance.
|
|
||||||
void initializeDeviceStateForComm(
|
|
||||||
const at::Device& device,
|
|
||||||
std::shared_ptr<NCCLComm> comm);
|
|
||||||
|
|
||||||
// Wrapper method which can be overridden for tests.
|
// Wrapper method which can be overridden for tests.
|
||||||
virtual std::exception_ptr checkForNCCLErrors(
|
virtual std::exception_ptr checkForNCCLErrors(
|
||||||
std::shared_ptr<NCCLComm>& ncclComm);
|
std::shared_ptr<NCCLComm>& ncclComm);
|
||||||
|
@ -2730,23 +2730,12 @@ Arguments:
|
|||||||
"supports_time_estimate",
|
"supports_time_estimate",
|
||||||
&::c10d::Backend::supportsTimeEstimation,
|
&::c10d::Backend::supportsTimeEstimation,
|
||||||
"(test whether the backend supports collective time estimation)")
|
"(test whether the backend supports collective time estimation)")
|
||||||
.def_property_readonly(
|
|
||||||
"supports_shrinking",
|
|
||||||
&::c10d::Backend::supportsShrinking,
|
|
||||||
"(test whether the backend supports communicator shrinking)")
|
|
||||||
.def(
|
.def(
|
||||||
"set_timeout",
|
"set_timeout",
|
||||||
&::c10d::Backend::setTimeout,
|
&::c10d::Backend::setTimeout,
|
||||||
py::arg("timeout"),
|
py::arg("timeout"),
|
||||||
py::call_guard<py::gil_scoped_release>(),
|
py::call_guard<py::gil_scoped_release>(),
|
||||||
R"(Sets the default timeout for all future operations.)")
|
R"(Sets the default timeout for all future operations.)")
|
||||||
.def(
|
|
||||||
"shrink",
|
|
||||||
&::c10d::Backend::shrink,
|
|
||||||
py::arg("ranks_to_exclude"),
|
|
||||||
py::arg("shrink_flags") = 0,
|
|
||||||
py::arg("opts_override") = nullptr,
|
|
||||||
py::call_guard<py::gil_scoped_release>())
|
|
||||||
.def(
|
.def(
|
||||||
"broadcast",
|
"broadcast",
|
||||||
&::c10d::Backend::broadcast,
|
&::c10d::Backend::broadcast,
|
||||||
|
@ -130,7 +130,6 @@ __all__ = [
|
|||||||
"reduce_scatter_tensor",
|
"reduce_scatter_tensor",
|
||||||
"get_node_local_rank",
|
"get_node_local_rank",
|
||||||
"split_group",
|
"split_group",
|
||||||
"shrink_group",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
_MPI_AVAILABLE = True
|
_MPI_AVAILABLE = True
|
||||||
@ -5714,517 +5713,3 @@ def _get_process_group_name(pg: ProcessGroup) -> str:
|
|||||||
|
|
||||||
def _get_process_group_store(pg: ProcessGroup) -> Store:
|
def _get_process_group_store(pg: ProcessGroup) -> Store:
|
||||||
return _world.pg_map[pg][1]
|
return _world.pg_map[pg][1]
|
||||||
|
|
||||||
|
|
||||||
# Shrink flags for process group backends
# These are the only two values accepted as ``shrink_flags`` by
# ``shrink_group``; the chosen value is forwarded to the backend's ``shrink``.
SHRINK_DEFAULT = 0x00
SHRINK_ABORT = 0x01
|
|
||||||
|
|
||||||
|
|
||||||
@_time_logger
def shrink_group(
    ranks_to_exclude: list[int],
    group: Optional[ProcessGroup] = None,
    shrink_flags: int = SHRINK_DEFAULT,
    pg_options: Optional[Any] = None,
) -> ProcessGroup:
    """
    Build a smaller process group by dropping ranks from an existing one.

    The returned group contains exactly the ranks of ``group`` that are not
    listed in ``ranks_to_exclude``.

    Args:
        ranks_to_exclude (List[int]): Ranks of ``group`` that must not be part
            of the new group.
        group (ProcessGroup, optional): Group to shrink. ``None`` (the
            default) selects the default process group.
        shrink_flags (int, optional): Either ``SHRINK_DEFAULT`` (the default)
            or ``SHRINK_ABORT``. ``SHRINK_ABORT`` will attempt to terminate
            ongoing operations in the parent communicator before shrinking.
        pg_options (ProcessGroupOptions, optional): Backend-specific options
            handed to the backend for the new group. When omitted, the new
            group inherits defaults from the parent.

    Returns:
        ProcessGroup: the new group made of the remaining ranks. When the
        default group is shrunk, the returned group becomes the new default
        group.

    Raises:
        TypeError: if ``ranks_to_exclude`` is not a list of integers, or the
            group's backend does not support shrinking.
        ValueError: if ``ranks_to_exclude`` is empty, out of bounds, contains
            duplicates, or excludes every rank; or if ``shrink_flags`` is not
            a recognised value.
        RuntimeError: if an excluded rank calls this function or the backend
            fails the operation.

    Notes:
        - Only the ranks that remain in the group should call this function;
          excluded ranks must not take part.
        - Shrinking the default group destroys all other process groups since
          rank reassignment makes them inconsistent.
    """
    # Reject malformed arguments before any process-group state is touched.
    _validate_shrink_inputs(ranks_to_exclude, shrink_flags)

    # Resolve the group to shrink and the backend that will perform it.
    info = _prepare_shrink_target_group(group)
    backend = _validate_shrink_backend_requirements(info)

    # Check the exclusion list against the resolved group.
    excluded = _validate_and_process_excluded_ranks(ranks_to_exclude, info)

    # The backend receives the excluded ranks in ascending order.
    shrunk_backend = backend.shrink(sorted(excluded), shrink_flags, pg_options)

    # Record the override on the info dict, then swap the old group for the
    # new one in the module's bookkeeping.
    info["pg_options_override"] = pg_options
    return _finalize_shrunk_group(info, excluded, shrunk_backend)
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_shrink_inputs(ranks_to_exclude: list[int], shrink_flags: int) -> None:
|
|
||||||
"""Validate input parameters for shrink_group."""
|
|
||||||
if not isinstance(ranks_to_exclude, list):
|
|
||||||
raise TypeError(
|
|
||||||
f"ranks_to_exclude must be a list, but got {type(ranks_to_exclude).__name__}. "
|
|
||||||
f"Example: [1, 3, 5] to exclude ranks 1, 3, and 5."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not ranks_to_exclude:
|
|
||||||
raise ValueError(
|
|
||||||
"ranks_to_exclude cannot be empty. To shrink a group, you must specify at least "
|
|
||||||
"one rank to exclude. Example: [failed_rank_id]"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate shrink_flags with clear explanation of valid values
|
|
||||||
valid_flags = [SHRINK_DEFAULT, SHRINK_ABORT]
|
|
||||||
if not isinstance(shrink_flags, int) or shrink_flags not in valid_flags:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid shrink_flags value: {shrink_flags}. Must be one of: "
|
|
||||||
f"SHRINK_DEFAULT ({SHRINK_DEFAULT}) or SHRINK_ABORT ({SHRINK_ABORT}). "
|
|
||||||
f"Use SHRINK_ABORT to abort ongoing operations before shrinking."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _prepare_shrink_target_group(group: Optional[ProcessGroup]) -> dict:
    """Resolve the group to shrink and snapshot the properties used later.

    ``None`` selects the default process group. The returned dict has the keys
    ``process_group``, ``is_default_group``, ``group_size``, ``current_rank``
    and ``group_name``.

    Raises:
        ValueError: if the resolved group has fewer than two ranks.
    """
    if group is None:
        target_pg = _get_default_group()
    else:
        target_pg = group

    # Read the size once; it is reused for the check and stored in the result.
    group_size = int(target_pg.size())

    group_info = {}
    group_info["process_group"] = target_pg
    group_info["is_default_group"] = target_pg == _get_default_group()
    group_info["group_size"] = group_size
    group_info["current_rank"] = target_pg.rank()
    group_info["group_name"] = _get_process_group_name(target_pg)

    # A group of one rank (or fewer) has nothing that could be excluded.
    if group_size > 1:
        return group_info

    raise ValueError(
        f"Cannot shrink a process group with size {group_size}. "
        f"Group must have at least 2 ranks to support shrinking."
    )
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_shrink_backend_requirements(group_info: dict) -> Any:
|
|
||||||
"""Return the backend implementation for the target group or raise if unsupported."""
|
|
||||||
target_pg = group_info["process_group"]
|
|
||||||
group_name = group_info["group_name"]
|
|
||||||
|
|
||||||
# Get the group's backend directly via ProcessGroup API. Prefer a bound device if present,
|
|
||||||
# otherwise try CUDA then fall back to CPU.
|
|
||||||
try:
|
|
||||||
preferred_device = getattr(target_pg, "bound_device_id", None)
|
|
||||||
if preferred_device is not None:
|
|
||||||
backend_impl = target_pg._get_backend(preferred_device)
|
|
||||||
else:
|
|
||||||
# Try CUDA first if available, else CPU
|
|
||||||
try:
|
|
||||||
backend_impl = target_pg._get_backend(torch.device("cuda"))
|
|
||||||
except Exception:
|
|
||||||
backend_impl = target_pg._get_backend(torch.device("cpu"))
|
|
||||||
except RuntimeError as e:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Cannot access device backend for process group '{group_name}'. "
|
|
||||||
f"Ensure the process group was initialized with a compatible device backend and devices are available."
|
|
||||||
) from e
|
|
||||||
|
|
||||||
try:
|
|
||||||
supports = bool(backend_impl.supports_shrinking)
|
|
||||||
except Exception:
|
|
||||||
supports = False
|
|
||||||
if not supports:
|
|
||||||
raise TypeError(
|
|
||||||
f"Process group backend for '{group_name}' does not support shrinking operations."
|
|
||||||
)
|
|
||||||
|
|
||||||
return backend_impl
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_and_process_excluded_ranks(
|
|
||||||
ranks_to_exclude: list[int], group_info: dict
|
|
||||||
) -> set:
|
|
||||||
"""Validate excluded ranks and convert to set for efficient operations."""
|
|
||||||
group_size = group_info["group_size"]
|
|
||||||
current_rank = group_info["current_rank"]
|
|
||||||
|
|
||||||
# Use set for O(1) duplicate detection and membership testing
|
|
||||||
excluded_ranks_set = set()
|
|
||||||
|
|
||||||
# Validate each rank with detailed error messages
|
|
||||||
for i, rank in enumerate(ranks_to_exclude):
|
|
||||||
if not isinstance(rank, int):
|
|
||||||
raise TypeError(
|
|
||||||
f"All elements in ranks_to_exclude must be integers. "
|
|
||||||
f"Element at index {i} is {type(rank).__name__}: {rank}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not (0 <= rank < group_size):
|
|
||||||
raise ValueError(
|
|
||||||
f"Rank {rank} at index {i} is out of bounds for group size {group_size}. "
|
|
||||||
f"Valid ranks are in range [0, {group_size - 1}]."
|
|
||||||
)
|
|
||||||
|
|
||||||
if rank in excluded_ranks_set:
|
|
||||||
raise ValueError(
|
|
||||||
f"Duplicate rank {rank} found in ranks_to_exclude at index {i}. "
|
|
||||||
f"Each rank can only be excluded once."
|
|
||||||
)
|
|
||||||
|
|
||||||
excluded_ranks_set.add(rank)
|
|
||||||
|
|
||||||
# Ensure we don't exclude all ranks
|
|
||||||
if len(excluded_ranks_set) >= group_size:
|
|
||||||
raise ValueError(
|
|
||||||
f"Cannot exclude all {group_size} ranks from process group. "
|
|
||||||
f"At least one rank must remain. Excluding {len(excluded_ranks_set)} ranks."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Critical check: current rank should not be in excluded list
|
|
||||||
if current_rank in excluded_ranks_set:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Current rank {current_rank} is in the exclusion list and should not call shrink_group(). "
|
|
||||||
f"Only non-excluded ranks should participate in the shrinking operation. "
|
|
||||||
f"Excluded ranks should terminate their processes instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
return excluded_ranks_set
|
|
||||||
|
|
||||||
|
|
||||||
def _finalize_shrunk_group(
    group_info: dict, excluded_ranks_set: set, new_backend
) -> ProcessGroup:
    """Replace the original group with a new group built around ``new_backend``.

    In order: when the default group is being shrunk, every other group is
    destroyed first; metadata and the surviving ranks are read from the
    original group; the original group is cleaned up; the new group is created
    and, for the default group, installed as the new default; finally the
    module's global state is updated for the new group.
    """
    old_pg = group_info["process_group"]
    shrinking_default = group_info["is_default_group"]

    if shrinking_default:
        _destroy_all_other_groups(exclude_group=old_pg)

    # Everything needed from the old group is read before it is cleaned up.
    meta = _extract_group_metadata(old_pg)
    survivors = []
    for rank in get_process_group_ranks(old_pg):
        if rank not in excluded_ranks_set:
            survivors.append(rank)

    _cleanup_original_group(old_pg, shrinking_default)

    shrunk_pg = _create_shrunk_process_group(
        new_backend, survivors, meta, shrinking_default
    )

    if shrinking_default:
        _update_default_pg(shrunk_pg)

    # Each surviving rank maps to its position in the surviving-rank list.
    position_of = dict(zip(survivors, range(len(survivors))))
    _update_process_group_global_state(
        pg=shrunk_pg,
        backend_name=meta["backend_name"],
        store=meta["store"],
        group_name=meta["new_group_name"],
        backend_config=meta["backend_config"],
        rank_mapping=position_of,
    )

    return shrunk_pg
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_group_metadata(target_pg: ProcessGroup) -> dict:
    """Snapshot what is needed from the original group before it is cleaned up.

    Returns a dict with the keys ``backend_name``, ``store``,
    ``backend_config``, ``original_group_name``, ``new_group_name`` and
    ``bound_device_id`` (``None`` when the group has no such attribute).
    """
    backend_name, store = _world.pg_map[target_pg]
    backend_config = _world.pg_backend_config.get(target_pg, "")
    current_name = _get_process_group_name(target_pg)

    # Read the device binding now so the group is not touched after cleanup.
    device_binding = getattr(target_pg, "bound_device_id", None)

    # The hashed name is derived from the rank list of the group as passed in.
    # NOTE(review): these are the group's current ranks, not the ranks that
    # remain after exclusion -- confirm that is the intended input.
    group_ranks = list(get_process_group_ranks(target_pg))
    shrunk_name = _process_group_name(group_ranks, use_hashed_name=True)

    return dict(
        backend_name=backend_name,
        store=store,
        backend_config=backend_config,
        original_group_name=current_name,
        new_group_name=shrunk_name,
        bound_device_id=device_binding,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def _cleanup_original_group(target_pg: ProcessGroup, is_default_group: bool) -> None:
    """Destroy the original group, then clear its global state.

    A failure while destroying the group is logged as a warning and does not
    propagate; the global-state cleanup runs in either case.
    """
    try:
        destroy_process_group(target_pg)
    except Exception as e:
        logger.warning(
            "Failed to destroy %s group during shrinking: %s",
            "default" if is_default_group else "non-default",
            e,
        )

    # Runs whether or not the destroy call above succeeded.
    _cleanup_process_group_global_state(target_pg)
|
|
||||||
|
|
||||||
|
|
||||||
def _create_shrunk_process_group(
    new_backend, remaining_ranks: list[int], metadata: dict, is_default_group: bool
) -> ProcessGroup:
    """Wrap ``new_backend`` in a new ``ProcessGroup`` and configure it.

    Rank and size come from ``new_backend``; name, store, device binding and
    backend config come from ``metadata``. ``remaining_ranks`` is accepted but
    not read here.
    """
    group_name = metadata["new_group_name"]
    rank_in_group = new_backend.rank()
    world = new_backend.size()

    group_desc = (
        "default:shrunken"
        if is_default_group
        else f"{metadata['original_group_name']}:shrunk"
    )

    # The new group gets its own prefix on a clone of the original store.
    store = PrefixStore(f"{group_name}/", metadata["store"].clone())
    new_pg = ProcessGroup(store, rank_in_group, world)

    # The backend is registered under the original group's bound device, or
    # under CPU when there was none; the backend enum follows the device type.
    # NOTE(review): with no bound device the backend is registered as
    # CPU/GLOO even if it was obtained for CUDA -- confirm this is intended.
    bound_device_id = metadata.get("bound_device_id")
    backend_device = (
        torch.device("cpu") if bound_device_id is None else bound_device_id
    )
    backend_type = (
        ProcessGroup.BackendType.NCCL
        if backend_device.type == "cuda"
        else ProcessGroup.BackendType.GLOO
    )

    new_pg._register_backend(backend_device, backend_type, new_backend)
    new_pg._set_default_backend(backend_type)

    # Carry the device binding over only when the original group had one.
    if bound_device_id is not None:
        new_pg.bound_device_id = bound_device_id

    new_pg._set_group_name(group_name)
    new_pg._set_group_desc(group_desc)

    # Keep the original group's backend config associated with the new group.
    backend_config = metadata.get("backend_config")
    if backend_config is not None:
        _world.pg_backend_config[new_pg] = backend_config

    return new_pg
|
|
||||||
|
|
||||||
|
|
||||||
def _destroy_all_other_groups(exclude_group: Optional[ProcessGroup] = None) -> None:
    """
    Abort every tracked process group except ``exclude_group`` and purge its global state.

    The groups considered are the keys of ``_world.pg_group_ranks``. Each one is
    torn down with ``_abort_process_group`` (the original author chose abort so the
    teardown does not depend on a collective that excluded ranks would not join),
    and then ``_cleanup_process_group_global_state`` is always invoked for it,
    whether or not the abort succeeded.

    A single warning listing the affected groups is logged up front; abort
    failures are logged per group and never propagated.

    Args:
        exclude_group (ProcessGroup, optional): Group to leave untouched.
            When None, every tracked group is destroyed.
    """
    # Build the full victim list before touching anything: the cleanup below
    # mutates _world.pg_group_ranks.
    doomed = [
        candidate
        for candidate in list(_world.pg_group_ranks.keys())
        if exclude_group is None or not (candidate == exclude_group)
    ]

    # Tell the user up front which groups are about to disappear.
    if doomed:
        doomed_names = ", ".join(_get_process_group_name(item) for item in doomed)
        logger.warning(
            "Shrinking default group will destroy %d other process groups: %s. "
            "This is necessary because shrinking the default group reassigns global ranks, "
            "making existing groups inconsistent.",
            len(doomed),
            doomed_names,
        )

    for victim in doomed:
        try:
            # Non-collective teardown of the backend side.
            _abort_process_group(victim)
        except Exception as abort_error:
            # Best effort: the group may already be gone, keep going.
            logger.warning(
                "Failed to abort process group %s: %s",
                _get_process_group_name(victim),
                abort_error,
            )

        # Runs regardless of the abort outcome so no bookkeeping is left behind.
        _cleanup_process_group_global_state(victim)
|
|
||||||
|
|
||||||
|
|
||||||
def _cleanup_process_group_global_state(pg: ProcessGroup) -> None:
    """
    Remove every reference to ``pg`` from the module-level process group registries.

    Each step tolerates the entry already being absent, so the function can be
    called after a partially successful abort/destroy without raising.

    The cleanup removes the process group from:
    - _world.pg_map (backend and store mapping)
    - _world.pg_group_ranks (rank mappings)
    - _world.pg_backend_config (backend configuration)
    - _world.pg_names (group name mapping)
    - _world.pg_to_tag and the matching list in _world.tags_to_pg
    - the default ``""`` list in _world.tags_to_pg, where groups created with a
      user tag are additionally recorded
    - the name registry, via _unregister_process_group
    - _world.pg_coalesce_state (coalescing state)

    Any unexpected failure is logged as a warning and swallowed so that callers
    iterating over several groups can continue with the remaining ones.

    Args:
        pg (ProcessGroup): The process group to clean up.
    """

    def _discard_from_tag(tag: str, drop_if_empty: bool) -> None:
        # Remove pg from one tag list; a missing tag or missing member is a no-op.
        members = _world.tags_to_pg.get(tag)
        if members is None:
            return
        try:
            members.remove(pg)
        except ValueError:
            # Process group was already removed from the list
            return
        if drop_if_empty and not members:
            _world.tags_to_pg.pop(tag, None)

    try:
        # Clean up main process group mappings
        _world.pg_map.pop(pg, None)
        _world.pg_group_ranks.pop(pg, None)
        _world.pg_backend_config.pop(pg, None)

        # Clean up process group name mapping
        group_name = _world.pg_names.pop(pg, None)

        # Clean up tag mappings: first the tag recorded for this group ...
        pg_tag = _world.pg_to_tag.pop(pg, None)
        if pg_tag is not None:
            _discard_from_tag(pg_tag, drop_if_empty=True)

        # ... then the default "" list. pg_to_tag only remembers a single tag,
        # but a group registered with a user tag is appended to "" as well, so
        # without this step a destroyed group would linger there.
        _discard_from_tag("", drop_if_empty=False)

        # Clean up any registered process group names using the unregister function
        if group_name is not None:
            try:
                _unregister_process_group(group_name)
            except Exception as unregister_error:
                # Deliberately best effort: the name might never have been
                # registered or may already be unregistered. Leave a trace
                # for debugging instead of hiding it completely.
                logger.debug(
                    "Ignoring failure to unregister process group %s: %s",
                    group_name,
                    unregister_error,
                )

        # Clean up coalesce state if present
        _world.pg_coalesce_state.pop(pg, None)

    except Exception as e:
        # Log cleanup failures but don't propagate - we want to continue with other cleanups
        logger.warning("Failed to fully clean up global state for process group: %s", e)
|
|
||||||
|
|
||||||
|
|
||||||
def _update_process_group_global_state(
    pg: ProcessGroup,
    backend_name: str,
    store: Store,
    group_name: str,
    backend_config: str,
    rank_mapping: Optional[dict[int, int]] = None,
    pg_tag: Optional[str] = None,
    user_tag: Optional[str] = None,
) -> None:
    """
    Record a process group in all of the module-level registries in one place.

    Writes ``_world.pg_map``, ``_world.pg_names`` and ``_world.pg_backend_config``,
    registers the group under ``group_name``, optionally stores the rank mapping,
    and finally files the group under exactly one entry of ``_world.pg_to_tag``
    (plus the matching ``_world.tags_to_pg`` list).

    Args:
        pg (ProcessGroup): The process group to update state for.
        backend_name (str): Backend name stored in pg_map.
        store (Store): Store instance stored in pg_map.
        group_name (str): Group name for pg_names and registration.
        backend_config (str): Backend configuration string.
        rank_mapping (Dict[int, int], optional): Global rank to group rank mapping.
            If None, pg_group_ranks is left untouched.
        pg_tag (str, optional): Process group tag. If None, defaults to
            f"ptd:{group_name}". Ignored when ``user_tag`` is given.
        user_tag (str, optional): User-provided tag. When given, the group is
            appended to the default "" tag list and filed under "user:{user_tag}".
    """
    # Core lookups keyed by the group object.
    _world.pg_map[pg] = (backend_name, store)
    _world.pg_names[pg] = group_name
    _world.pg_backend_config[pg] = backend_config

    # Make the group resolvable by name.
    _register_process_group(group_name, pg)

    # Rank mapping is optional; callers may install it separately.
    if rank_mapping is not None:
        _world.pg_group_ranks[pg] = rank_mapping

    # Decide which tag the group is filed under.
    if user_tag is None:
        # Standard process group tag, falling back to the "ptd:" default.
        effective_tag = pg_tag if pg_tag is not None else f"ptd:{group_name}"
    else:
        # User-tagged groups are also listed under the default "" tag,
        # and that entry is appended first.
        _world.tags_to_pg.setdefault("", []).append(pg)
        effective_tag = f"user:{user_tag}"

    _world.tags_to_pg.setdefault(effective_tag, []).append(pg)
    _world.pg_to_tag[pg] = effective_tag
|
|
||||||
|
@ -238,47 +238,6 @@ def skip_if_lt_x_gpu(x):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def requires_world_size(n: int):
    """
    Decorator that tags a test with the world size it wants and gates it on GPU count.

    The wrapped function receives a ``_required_world_size`` attribute set to ``n``
    so a harness can read how many ranks to spawn. The test is skipped unless
    ``torch.cuda.device_count()`` reports at least ``n`` devices; the count is
    queried once, when the decorator is applied.

    Usage:
        @requires_world_size(3)
        def test_something(self):
            ...
    """

    def decorator(func):
        # Tag first so the attribute is present on whatever object is returned.
        func._required_world_size = n
        gpu_count = torch.cuda.device_count()
        skip_reason = f"requires {n} GPUs, found {gpu_count}"
        gate = unittest.skipUnless(gpu_count >= n, skip_reason)
        return gate(func)

    return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def get_required_world_size(obj: Any, default: int) -> int:
|
|
||||||
"""
|
|
||||||
Returns the requested world size for the currently running unittest method on `obj`
|
|
||||||
if annotated via `@require_world_size(n)`, else returns `default`.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Try MultiProcessTestCase helper first, then unittest fallback
|
|
||||||
test_name = (
|
|
||||||
obj._current_test_name() # type: ignore[attr-defined]
|
|
||||||
if hasattr(obj, "_current_test_name") and callable(obj._current_test_name)
|
|
||||||
else obj._testMethodName
|
|
||||||
)
|
|
||||||
fn = getattr(obj, test_name)
|
|
||||||
value = fn._required_world_size
|
|
||||||
return int(value)
|
|
||||||
except Exception:
|
|
||||||
return default
|
|
||||||
|
|
||||||
|
|
||||||
# This decorator helps avoiding initializing cuda while testing other backends
|
# This decorator helps avoiding initializing cuda while testing other backends
|
||||||
def nccl_skip_if_lt_x_gpu(backend, x):
|
def nccl_skip_if_lt_x_gpu(backend, x):
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
@ -408,13 +367,6 @@ def requires_nccl_version(version, msg):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def requires_nccl_shrink():
    """
    Gate a test on the NCCL version needed for shrink_group.

    Delegates to ``requires_nccl_version`` with a minimum version of (2, 27)
    and returns whatever that helper returns.
    """
    minimum_version = (2, 27)
    skip_message = "Need NCCL 2.27+ for shrink_group"
    return requires_nccl_version(minimum_version, skip_message)
|
|
||||||
|
|
||||||
|
|
||||||
def requires_nccl():
|
def requires_nccl():
|
||||||
return skip_but_pass_in_sandcastle_if(
|
return skip_but_pass_in_sandcastle_if(
|
||||||
not c10d.is_nccl_available(),
|
not c10d.is_nccl_available(),
|
||||||
|
Reference in New Issue
Block a user