mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
shrink_group implementation to expose ncclCommShrink API (#164518)
Closes #164529 To expose the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink) API to PyTorch. This is useful when you need to exclude certain GPUs or nodes from a collective operation, for example in fault tolerance scenarios or when dynamically adjusting resource utilization. For more info: [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator) Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518 Approved by: https://github.com/Skylion007, https://github.com/syed-ahmed, https://github.com/kwen2501
This commit is contained in:
committed by
PyTorch MergeBot
parent
39e0a832c9
commit
a032510db3
43
test/distributed/logging_utils.py
Normal file
43
test/distributed/logging_utils.py
Normal file
@ -0,0 +1,43 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
|
||||
_start_time = time.time()
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _ts():
|
||||
return time.time() - _start_time
|
||||
|
||||
|
||||
def configure(level=logging.INFO, force=False):
|
||||
try:
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format="%(asctime)s %(name)s %(levelname)s: %(message)s",
|
||||
force=force,
|
||||
)
|
||||
except TypeError:
|
||||
logging.basicConfig(
|
||||
level=level, format="%(asctime)s %(name)s %(levelname)s: %(message)s"
|
||||
)
|
||||
|
||||
|
||||
def log_test_info(rank, message):
|
||||
_logger.info("[%7.3fs][Rank %s] %s", _ts(), rank, message)
|
||||
|
||||
|
||||
def log_test_success(rank, message):
|
||||
_logger.info("[%7.3fs][Rank %s] ✅ %s", _ts(), rank, message)
|
||||
|
||||
|
||||
def log_test_validation(rank, message):
|
||||
_logger.info("[%7.3fs][Rank %s] ✓ %s", _ts(), rank, message)
|
||||
|
||||
|
||||
def log_test_warning(rank, message):
|
||||
_logger.warning("[%7.3fs][Rank %s] ⚠️ %s", _ts(), rank, message)
|
||||
|
||||
|
||||
def log_test_error(rank, message):
|
||||
_logger.error("[%7.3fs][Rank %s] ✗ %s", _ts(), rank, message)
|
@ -2,6 +2,7 @@
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
@ -21,6 +22,7 @@ from unittest import mock, SkipTest
|
||||
import torch
|
||||
import torch.distributed as c10d
|
||||
import torch.distributed._functional_collectives as _functional_collectives
|
||||
from torch.distributed.distributed_c10d import SHRINK_ABORT as NCCL_SHRINK_ABORT
|
||||
|
||||
|
||||
if not c10d.is_available() or not c10d.is_nccl_available():
|
||||
@ -47,12 +49,15 @@ from torch._C._distributed_c10d import ErrorType, OpType, WorkResult
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.testing._internal.common_cuda import _get_torch_rocm_version, TEST_MULTIGPU
|
||||
from torch.testing._internal.common_distributed import (
|
||||
get_required_world_size,
|
||||
get_timeout,
|
||||
init_multigpu_helper,
|
||||
MultiProcessTestCase,
|
||||
requires_multicast_support,
|
||||
requires_nccl,
|
||||
requires_nccl_shrink,
|
||||
requires_nccl_version,
|
||||
requires_world_size,
|
||||
skip_if_lt_x_gpu,
|
||||
skip_if_rocm_multiprocess,
|
||||
sm_is_or_higher_than,
|
||||
@ -87,6 +92,17 @@ BFLOAT16_AVAILABLE = torch.cuda.is_available() and (
|
||||
torch.version.cuda is not None or torch.version.hip is not None
|
||||
)
|
||||
|
||||
from logging_utils import (
|
||||
configure as _log_configure,
|
||||
log_test_info,
|
||||
log_test_success,
|
||||
log_test_validation,
|
||||
log_test_warning,
|
||||
)
|
||||
|
||||
|
||||
_log_configure(level=logging.INFO, force=True)
|
||||
|
||||
|
||||
class RendezvousEnvTest(TestCase):
|
||||
@retry_on_connect_failures
|
||||
@ -317,7 +333,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return 2
|
||||
return get_required_world_size(self, 2)
|
||||
|
||||
@property
|
||||
def rank_to_GPU(self):
|
||||
@ -1255,6 +1271,628 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
||||
pg_2 = c10d.new_group([0, 1])
|
||||
self.assertEqual(pg_2.group_desc, "undefined")
|
||||
|
||||
@requires_nccl_shrink()
|
||||
@requires_world_size(2)
|
||||
def test_shrink_group_basic(self):
|
||||
"""Test basic shrink_group functionality."""
|
||||
self._perform_shrink_test([1], "Basic shrink test")
|
||||
|
||||
@requires_nccl_shrink()
|
||||
@requires_world_size(2)
|
||||
def test_shrink_group_validation(self):
|
||||
"""Test input validation in shrink_group."""
|
||||
device, pg = self._setup_shrink_test("validation")
|
||||
|
||||
def _test_invalid_input(ranks, description, expected_exception):
|
||||
"""Helper to test invalid inputs."""
|
||||
try:
|
||||
c10d.shrink_group(ranks)
|
||||
self.fail(f"Expected {expected_exception.__name__} for {description}")
|
||||
except expected_exception:
|
||||
log_test_validation(self.rank, f"✓ {description}")
|
||||
except Exception:
|
||||
if expected_exception == Exception: # Accept any exception
|
||||
log_test_validation(self.rank, f"✓ {description}")
|
||||
else:
|
||||
raise
|
||||
|
||||
# Test cases
|
||||
_test_invalid_input([], "Empty exclusion list", ValueError)
|
||||
if self.world_size > 1:
|
||||
_test_invalid_input([0, 0, 1], "Duplicate ranks", Exception)
|
||||
_test_invalid_input([self.world_size + 1], "Out of bounds rank", Exception)
|
||||
|
||||
log_test_success(self.rank, "All validation tests passed")
|
||||
dist.destroy_process_group()
|
||||
|
||||
@requires_nccl_shrink()
|
||||
@requires_world_size(2)
|
||||
def test_shrink_group_backend_properties(self):
|
||||
"""Test that backend properties are preserved after shrinking."""
|
||||
|
||||
test_name = "Backend Properties Test"
|
||||
ranks_to_exclude = [0]
|
||||
|
||||
# Reuse _setup_shrink_test for complete setup (device, environment, and process group)
|
||||
device, pg = self._setup_shrink_test("backend_properties")
|
||||
|
||||
# Follow _perform_shrink_test pattern from here
|
||||
log_test_info(self.rank, f"{test_name} (world_size={self.world_size})")
|
||||
|
||||
is_excluded = self.rank in ranks_to_exclude
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
|
||||
)
|
||||
|
||||
# Store original backend property values (not references) before shrinking
|
||||
original_timeout = None
|
||||
original_high_priority = None
|
||||
if not is_excluded:
|
||||
original_backend = pg._get_backend(device)
|
||||
original_timeout = original_backend.options._timeout
|
||||
original_high_priority = original_backend.options.is_high_priority_stream
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Storing original backend properties: timeout={original_timeout}, high_priority={original_high_priority}",
|
||||
)
|
||||
|
||||
if is_excluded:
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Excluded rank {self.rank} - setup complete, skipping shrink operation",
|
||||
)
|
||||
dist.destroy_process_group() # hang without it
|
||||
return
|
||||
|
||||
# Only non-excluded ranks proceed with shrink (same as _perform_shrink_test)
|
||||
log_test_info(self.rank, "Non-excluded rank calling shrink_group")
|
||||
shrunk_pg = c10d.shrink_group(ranks_to_exclude)
|
||||
|
||||
# Reuse _validate_shrunk_group helper (same as _perform_shrink_test)
|
||||
expected_size = self.world_size - len(ranks_to_exclude)
|
||||
_ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name)
|
||||
|
||||
# Add custom backend properties validation
|
||||
new_backend = shrunk_pg._get_backend(device)
|
||||
log_test_info(self.rank, "Validating backend properties are preserved")
|
||||
|
||||
new_timeout = new_backend.options._timeout
|
||||
new_high_priority = new_backend.options.is_high_priority_stream
|
||||
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Timeout comparison - original: {original_timeout}, new: {new_timeout}",
|
||||
)
|
||||
self.assertEqual(
|
||||
original_timeout, new_timeout, f"{test_name}: timeout not preserved"
|
||||
)
|
||||
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"High priority stream comparison - original: {original_high_priority}, new: {new_high_priority}",
|
||||
)
|
||||
self.assertEqual(
|
||||
original_high_priority,
|
||||
new_high_priority,
|
||||
f"{test_name}: high_priority_stream not preserved",
|
||||
)
|
||||
|
||||
log_test_validation(
|
||||
self.rank, f"{test_name}: Backend properties preserved successfully"
|
||||
)
|
||||
log_test_success(
|
||||
self.rank, f"{test_name} successful (shrink + backend validation)"
|
||||
)
|
||||
|
||||
# Cleanup (same as _perform_shrink_test)
|
||||
dist.destroy_process_group()
|
||||
|
||||
@requires_nccl_shrink()
|
||||
@requires_world_size(2)
|
||||
def test_shrink_group_multiple_comms(self):
|
||||
"""Test shrink_group with multiple communicators and subgroup invalidation."""
|
||||
|
||||
device, pg = self._setup_shrink_test("multiple_comms")
|
||||
|
||||
# Create subgroup [0, 1] and test shrinking it
|
||||
subgroup = c10d.new_group([0, 1])
|
||||
if self.rank <= 1:
|
||||
# Shrink subgroup: exclude rank 1
|
||||
if self.rank == 0: # Only rank 0 remains
|
||||
shrunk_subgroup = c10d.shrink_group([1], group=subgroup)
|
||||
self.assertEqual(shrunk_subgroup.size(), 1)
|
||||
# Test communication on shrunk subgroup
|
||||
tensor = torch.full((1,), self.rank).cuda(device)
|
||||
c10d.all_reduce(tensor, group=shrunk_subgroup)
|
||||
self.assertEqual(tensor.item(), 0) # Only rank 0
|
||||
log_test_success(self.rank, "Subgroup shrinking successful")
|
||||
|
||||
dist.barrier() # Sync before default group test
|
||||
|
||||
# Shrink default group: exclude last rank
|
||||
ranks_to_exclude = [self.world_size - 1]
|
||||
if self.rank not in ranks_to_exclude:
|
||||
shrunk_default = c10d.shrink_group(ranks_to_exclude)
|
||||
expected_size = self.world_size - 1
|
||||
self.assertEqual(shrunk_default.size(), expected_size)
|
||||
|
||||
# Test collective on shrunk default group
|
||||
tensor = torch.full((1,), self.rank).cuda(device)
|
||||
c10d.all_reduce(tensor, group=shrunk_default)
|
||||
expected_sum = sum(
|
||||
range(self.world_size - 1)
|
||||
) # 0 + 1 + ... + (world_size-2)
|
||||
self.assertEqual(tensor.item(), expected_sum)
|
||||
log_test_success(self.rank, "Default group shrinking successful")
|
||||
|
||||
# Note: After shrinking default group, the old subgroup is invalid
|
||||
# due to global rank reassignment
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
def _test_shrink_group_with_flag(self, shrink_flag, flag_name, rank_to_exclude):
|
||||
"""Helper method to test shrink_group with a specific flag."""
|
||||
if self.world_size < 2:
|
||||
log_test_info(self.rank, f"Skipping (needs ≥2 GPUs, got {self.world_size})")
|
||||
return
|
||||
ranks_to_exclude = [rank_to_exclude]
|
||||
log_test_info(self.rank, f"Using {flag_name} flag (value: {shrink_flag})")
|
||||
if flag_name == "NCCL_SHRINK_ABORT":
|
||||
log_test_info(
|
||||
self.rank,
|
||||
"ABORT flag will terminate ongoing operations before shrinking",
|
||||
)
|
||||
|
||||
self._perform_shrink_test(
|
||||
ranks_to_exclude, f"{flag_name} flag test", shrink_flags=shrink_flag
|
||||
)
|
||||
|
||||
@requires_nccl_shrink()
|
||||
@requires_world_size(2)
|
||||
def test_shrink_group_flags(self):
|
||||
"""Test shrink_group with different shrink flags."""
|
||||
# Test ABORT flags
|
||||
log_test_info(self.rank, "Testing NCCL_SHRINK_ABORT flag")
|
||||
self._test_shrink_group_with_flag(NCCL_SHRINK_ABORT, "NCCL_SHRINK_ABORT", 1)
|
||||
|
||||
@requires_nccl_shrink()
|
||||
@requires_world_size(2)
|
||||
def test_shrink_group_nccl_config(self):
|
||||
"""Verify that passing NCCL config via pg_options influences the shrunk group's backend options."""
|
||||
device, pg = self._setup_shrink_test("config")
|
||||
if self.rank == self.world_size - 1:
|
||||
# excluded rank should not call shrink_group
|
||||
dist.destroy_process_group()
|
||||
return
|
||||
|
||||
# Prepare pg_options with NCCL config overrides
|
||||
# Capture parent's current backend options to ensure we can prove override vs inherit
|
||||
parent_backend = pg._get_backend(torch.device("cuda"))
|
||||
parent_hp = parent_backend.options.is_high_priority_stream
|
||||
parent_blocking = parent_backend.options.config.blocking
|
||||
|
||||
# Choose overrides that differ from the parent (flip where possible)
|
||||
override_hp = not parent_hp
|
||||
if parent_blocking in (0, 1):
|
||||
override_blocking = 1 - parent_blocking
|
||||
else:
|
||||
# If undefined or unexpected, set to 1 which is a concrete value
|
||||
override_blocking = 1
|
||||
|
||||
opts = c10d.ProcessGroupNCCL.Options()
|
||||
opts.is_high_priority_stream = override_hp
|
||||
opts.config.blocking = override_blocking
|
||||
|
||||
shrunk_pg = c10d.shrink_group([self.world_size - 1], pg_options=opts)
|
||||
|
||||
# Validate backend options propagated
|
||||
backend = shrunk_pg._get_backend(torch.device("cuda"))
|
||||
# is_high_priority_stream should exactly match our override and differ from parent
|
||||
self.assertEqual(backend.options.is_high_priority_stream, override_hp)
|
||||
self.assertNotEqual(backend.options.is_high_priority_stream, parent_hp)
|
||||
# config is a struct; check representative field and difference from parent when meaningful
|
||||
self.assertEqual(backend.options.config.blocking, override_blocking)
|
||||
if parent_blocking in (0, 1):
|
||||
self.assertNotEqual(backend.options.config.blocking, parent_blocking)
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
@requires_nccl_shrink()
|
||||
@requires_world_size(2)
|
||||
def test_shrink_group_performance(self):
|
||||
"""Test shrink_group performance and regression detection."""
|
||||
import time
|
||||
|
||||
ranks_to_exclude = self._get_default_ranks_to_exclude()
|
||||
is_excluded = self.rank in ranks_to_exclude
|
||||
|
||||
if not ranks_to_exclude:
|
||||
log_test_info(self.rank, "Skipping performance test (world_size=1)")
|
||||
return
|
||||
|
||||
log_test_info(self.rank, f"Performance test with {self.world_size} processes")
|
||||
device, pg = self._setup_shrink_test("performance")
|
||||
|
||||
if not is_excluded:
|
||||
log_test_info(self.rank, "Measuring shrink_group performance")
|
||||
start_time = time.time()
|
||||
shrunk_pg = c10d.shrink_group(ranks_to_exclude)
|
||||
end_time = time.time()
|
||||
|
||||
elapsed_time = end_time - start_time
|
||||
log_test_info(self.rank, f"shrink_group: {elapsed_time:.3f}s")
|
||||
|
||||
# Regression check: should complete within reasonable time
|
||||
self.assertLess(
|
||||
elapsed_time,
|
||||
30.0,
|
||||
f"shrink_group took {elapsed_time:.3f}s, possible regression",
|
||||
)
|
||||
|
||||
# Test collective performance
|
||||
expected_size = self.world_size - len(ranks_to_exclude)
|
||||
self._validate_shrunk_group(shrunk_pg, expected_size, "performance")
|
||||
|
||||
collective_start = time.time()
|
||||
_ = self._test_collective_on_shrunk_group(
|
||||
shrunk_pg, device, ranks_to_exclude, "performance"
|
||||
)
|
||||
collective_time = time.time() - collective_start
|
||||
|
||||
log_test_info(self.rank, f"all_reduce: {collective_time:.3f}s")
|
||||
log_test_success(self.rank, "Performance test passed")
|
||||
else:
|
||||
log_test_info(self.rank, "Excluded rank - waiting")
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
@requires_nccl_shrink()
|
||||
@requires_world_size(4)
|
||||
def test_shrink_group_multiple_exclusions(self):
|
||||
"""Test shrink_group with multiple ranks excluded at once."""
|
||||
# Scale exclusions with world size
|
||||
ranks_to_exclude = list(range(2, self.world_size, 2)) # Every other rank from 2
|
||||
|
||||
self._perform_shrink_test(ranks_to_exclude, "Multiple exclusions test")
|
||||
|
||||
@requires_nccl_shrink()
|
||||
@requires_world_size(3)
|
||||
def test_shrink_group_multiple_iterations(self):
|
||||
"""Test multiple shrink operations in sequence."""
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Starting test_shrink_group_multiple_iterations with world_size={self.world_size}",
|
||||
)
|
||||
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
device = torch.device(f"cuda:{self.rank}")
|
||||
_ = self._create_process_group_nccl(store, self.opts(), device_id=device)
|
||||
|
||||
# Track current effective world size throughout shrinking operations
|
||||
current_world_size = self.world_size
|
||||
log_test_info(self.rank, f"Initial world_size: {current_world_size}")
|
||||
|
||||
# First shrinking: exclude the last rank(s)
|
||||
first_exclusion = [self.world_size - 1]
|
||||
if self.world_size >= 6:
|
||||
first_exclusion.append(
|
||||
self.world_size - 2
|
||||
) # Exclude last two ranks for larger sizes
|
||||
|
||||
log_test_info(self.rank, f"First shrinking: excluding ranks {first_exclusion}")
|
||||
|
||||
if self.rank not in first_exclusion:
|
||||
# Only non-excluded ranks should call shrink_group
|
||||
first_pg = c10d.shrink_group(first_exclusion)
|
||||
self.assertIsNotNone(first_pg)
|
||||
# IMPORTANT: Update world size after first shrinking
|
||||
current_world_size = first_pg.size()
|
||||
expected_first_size = self.world_size - len(first_exclusion)
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"After first shrinking: world_size {self.world_size} -> {current_world_size}",
|
||||
)
|
||||
self.assertEqual(first_pg.size(), expected_first_size)
|
||||
|
||||
# Second shrinking: exclude another rank from the remaining group
|
||||
# Choose a rank that's in the middle range
|
||||
if current_world_size >= 3:
|
||||
second_exclusion = [
|
||||
current_world_size - 1
|
||||
] # Exclude the new "last" rank
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Second shrinking from group of size {current_world_size}: excluding ranks {second_exclusion}",
|
||||
)
|
||||
|
||||
if self.rank not in second_exclusion:
|
||||
# Only non-excluded ranks should call shrink_group for second iteration
|
||||
second_pg = c10d.shrink_group(second_exclusion, group=first_pg)
|
||||
self.assertIsNotNone(second_pg)
|
||||
# IMPORTANT: Update world size after second shrinking
|
||||
final_world_size = second_pg.size()
|
||||
expected_final_size = current_world_size - len(second_exclusion)
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"After second shrinking: world_size {current_world_size} -> {final_world_size}",
|
||||
)
|
||||
self.assertEqual(second_pg.size(), expected_final_size)
|
||||
|
||||
# Test collective on final group
|
||||
tensor = torch.full((1,), self.rank).cuda(device)
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Performing all_reduce on final group (size {final_world_size}) with tensor: {tensor.item()}",
|
||||
)
|
||||
c10d.all_reduce(tensor, group=second_pg)
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Final all_reduce completed, result: {tensor.item()}",
|
||||
)
|
||||
|
||||
# Calculate expected sum of remaining ranks
|
||||
all_excluded = set(first_exclusion + second_exclusion)
|
||||
remaining_ranks = [
|
||||
r for r in range(self.world_size) if r not in all_excluded
|
||||
]
|
||||
expected_sum = sum(remaining_ranks)
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Remaining ranks: {remaining_ranks}, expected sum: {expected_sum}, actual: {tensor.item()}",
|
||||
)
|
||||
self.assertEqual(tensor.item(), expected_sum)
|
||||
log_test_info(self.rank, "Final verification passed")
|
||||
else:
|
||||
log_test_info(
|
||||
self.rank,
|
||||
"This rank excluded in second shrinking, not calling shrink_group",
|
||||
)
|
||||
else:
|
||||
log_test_info(
|
||||
self.rank, "Skipping second shrinking (remaining group too small)"
|
||||
)
|
||||
else:
|
||||
log_test_info(
|
||||
self.rank,
|
||||
"This rank excluded in first shrinking, not calling shrink_group",
|
||||
)
|
||||
|
||||
log_test_info(self.rank, "Destroying process group")
|
||||
dist.destroy_process_group()
|
||||
log_test_info(self.rank, "test_shrink_group_multiple_iterations completed")
|
||||
|
||||
# Helper methods for optimized shrink group tests
|
||||
def _setup_shrink_test(self, test_suffix, world_size=None, warmup=True):
|
||||
"""Common setup for shrink group tests."""
|
||||
os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1"
|
||||
world_size = world_size or self.world_size
|
||||
store = c10d.FileStore(self.file_name + f"_{test_suffix}", world_size)
|
||||
device = torch.device(f"cuda:{self.rank}")
|
||||
c10d.init_process_group(
|
||||
"nccl",
|
||||
world_size=world_size,
|
||||
rank=self.rank,
|
||||
store=store,
|
||||
pg_options=self.opts(),
|
||||
device_id=device,
|
||||
)
|
||||
pg = c10d.distributed_c10d._get_default_group()
|
||||
|
||||
if warmup:
|
||||
c10d.all_reduce(torch.ones(1).cuda(device), group=pg)
|
||||
|
||||
return device, pg
|
||||
|
||||
def _validate_shrunk_group(self, shrunk_pg, expected_size, test_name=""):
|
||||
"""Validate properties of a shrunk process group."""
|
||||
self.assertIsNotNone(shrunk_pg, f"{test_name}: shrunk_pg should not be None")
|
||||
actual_size = shrunk_pg.size()
|
||||
self.assertEqual(
|
||||
actual_size, expected_size, f"{test_name}: group size mismatch"
|
||||
)
|
||||
|
||||
new_rank = shrunk_pg.rank()
|
||||
self.assertTrue(
|
||||
0 <= new_rank < expected_size, f"{test_name}: invalid new rank {new_rank}"
|
||||
)
|
||||
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"{test_name}: world_size {self.world_size} -> {actual_size}, rank {self.rank} -> {new_rank}",
|
||||
)
|
||||
return new_rank
|
||||
|
||||
def _test_collective_on_shrunk_group(
|
||||
self, shrunk_pg, device, ranks_to_exclude, test_name=""
|
||||
):
|
||||
"""Test collective communication on shrunk group and verify correctness."""
|
||||
test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32)
|
||||
c10d.all_reduce(test_tensor, group=shrunk_pg)
|
||||
|
||||
result = test_tensor.item()
|
||||
expected_sum = sum(
|
||||
r for r in range(self.world_size) if r not in ranks_to_exclude
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
result, expected_sum, f"{test_name}: collective result mismatch"
|
||||
)
|
||||
log_test_info(
|
||||
self.rank, f"{test_name}: collective passed ({result} == {expected_sum})"
|
||||
)
|
||||
return result
|
||||
|
||||
def _perform_shrink_test(
|
||||
self, ranks_to_exclude, test_name, shrink_flags=0, with_collective=True
|
||||
):
|
||||
"""Complete shrink test flow: setup, shrink, validate, test collective, cleanup.
|
||||
|
||||
Consistent API: All ranks perform setup to initialize distributed environment.
|
||||
ONLY non-excluded ranks call shrink_group() for both default and non-default groups.
|
||||
Excluded ranks perform setup, then exit without calling shrink_group() or waiting.
|
||||
"""
|
||||
log_test_info(self.rank, f"{test_name} (world_size={self.world_size})")
|
||||
|
||||
is_excluded = self.rank in ranks_to_exclude
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
|
||||
)
|
||||
|
||||
# All ranks (including excluded ones) perform setup to initialize distributed environment
|
||||
device, pg = self._setup_shrink_test(test_name.lower().replace(" ", "_"))
|
||||
is_default_group = pg == c10d.distributed_c10d._get_default_group()
|
||||
|
||||
if is_excluded:
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Excluded rank {self.rank} - setup complete, skipping shrink operation",
|
||||
)
|
||||
if shrink_flags & NCCL_SHRINK_ABORT:
|
||||
log_test_info(self.rank, f"Using abort for excluded rank {self.rank}")
|
||||
pg._get_backend(torch.device(device)).abort()
|
||||
log_test_info(
|
||||
self.rank, f"cleanup resources for excluded rank {self.rank}"
|
||||
)
|
||||
dist.destroy_process_group()
|
||||
log_test_info(self.rank, f"Excluded rank {self.rank} - exit")
|
||||
else:
|
||||
log_test_info(
|
||||
self.rank, f"Using regular destroy for excluded rank {self.rank}"
|
||||
)
|
||||
dist.destroy_process_group()
|
||||
return None
|
||||
|
||||
# Only non-excluded ranks proceed with shrink
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Non-excluded rank calling shrink_group (default_group={is_default_group})",
|
||||
)
|
||||
shrunk_pg = c10d.shrink_group(ranks_to_exclude, shrink_flags=shrink_flags)
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Non-excluded rank calling shrink_group (default_group={is_default_group}) done",
|
||||
)
|
||||
|
||||
# Non-excluded ranks: validate and test the new group
|
||||
expected_size = self.world_size - len(ranks_to_exclude)
|
||||
_ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name)
|
||||
|
||||
if with_collective:
|
||||
_ = self._test_collective_on_shrunk_group(
|
||||
shrunk_pg, device, ranks_to_exclude, test_name
|
||||
)
|
||||
log_test_success(self.rank, f"{test_name} successful (shrink + collective)")
|
||||
else:
|
||||
log_test_success(self.rank, f"{test_name} successful (shrink only)")
|
||||
|
||||
dist.destroy_process_group()
|
||||
return shrunk_pg
|
||||
|
||||
def _get_default_ranks_to_exclude(self):
|
||||
"""Get default ranks to exclude based on world size."""
|
||||
if self.world_size <= 1:
|
||||
return []
|
||||
return [self.world_size - 1] # Exclude last rank by default
|
||||
|
||||
@requires_nccl_shrink()
|
||||
@requires_world_size(3)
|
||||
def test_shrink_group_vs_abort_reinit_performance(self):
|
||||
"""Compare performance of shrink_group vs traditional abort+reinit (simplified for reliability)."""
|
||||
log_test_info(self.rank, "=== TEST 1: abort+reinit ===")
|
||||
|
||||
device, pg1 = self._setup_shrink_test("_perf_reinit")
|
||||
torch.cuda.synchronize(device)
|
||||
|
||||
# Test 1: Traditional abort + reinit
|
||||
start_time = time.perf_counter()
|
||||
dist.destroy_process_group()
|
||||
|
||||
device, new_pg = self._setup_shrink_test("perf_shrink_test1")
|
||||
reinit_time = time.perf_counter() - start_time
|
||||
|
||||
# Test collective with original rank values for fair comparison (non-blocking mode)
|
||||
test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32)
|
||||
work = c10d.all_reduce(test_tensor, group=new_pg, async_op=True)
|
||||
work.wait()
|
||||
|
||||
torch.cuda.synchronize(device)
|
||||
|
||||
# Verify correctness
|
||||
expected_sum = sum(r for r in range(self.world_size))
|
||||
self.assertEqual(test_tensor.item(), expected_sum, "Reinit collective failed")
|
||||
|
||||
log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s")
|
||||
dist.destroy_process_group(new_pg)
|
||||
|
||||
# Test 2: shrink_group with NCCL_SHRINK_ABORT
|
||||
log_test_info(self.rank, "=== TEST 2: shrink_group ===")
|
||||
|
||||
ranks_to_exclude = [self.world_size - 1]
|
||||
is_excluded = self.rank in ranks_to_exclude
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
|
||||
)
|
||||
|
||||
device, pg1 = self._setup_shrink_test("perf_shrink_test2") # Unique suffix
|
||||
|
||||
shrink_time = 0
|
||||
if not is_excluded:
|
||||
torch.cuda.synchronize(device) # Ensure accurate timing
|
||||
start_time = time.perf_counter()
|
||||
shrunk_pg = c10d.shrink_group(
|
||||
ranks_to_exclude, shrink_flags=NCCL_SHRINK_ABORT
|
||||
)
|
||||
c10d.all_reduce(torch.ones(1).cuda(device), group=shrunk_pg)
|
||||
shrink_time = time.perf_counter() - start_time
|
||||
|
||||
# Test collective communication on shrunk group (non-blocking mode)
|
||||
test_tensor = torch.full(
|
||||
(1,), self.rank, device=device, dtype=torch.float32
|
||||
)
|
||||
work = c10d.all_reduce(test_tensor, group=shrunk_pg, async_op=True)
|
||||
work.wait()
|
||||
|
||||
# Verify correctness
|
||||
expected_sum = sum(
|
||||
r for r in range(self.world_size) if r not in ranks_to_exclude
|
||||
)
|
||||
self.assertEqual(
|
||||
test_tensor.item(),
|
||||
expected_sum,
|
||||
"shrink_test: collective result mismatch",
|
||||
)
|
||||
|
||||
torch.cuda.synchronize(device) # Ensure operations complete
|
||||
log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s")
|
||||
dist.destroy_process_group()
|
||||
else:
|
||||
log_test_info(self.rank, "Excluded from shrink test - exiting immediately")
|
||||
dist.destroy_process_group()
|
||||
return
|
||||
|
||||
# Performance analysis (only for participating ranks)
|
||||
if shrink_time > 0 and reinit_time > 0:
|
||||
speedup = reinit_time / shrink_time
|
||||
time_saved = reinit_time - shrink_time
|
||||
|
||||
log_test_info(self.rank, "=== PERFORMANCE RESULTS ===")
|
||||
log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s")
|
||||
log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s")
|
||||
log_test_info(self.rank, f"time_saved: {time_saved:+.4f}s")
|
||||
log_test_info(self.rank, f"speedup: {speedup:.2f}x")
|
||||
|
||||
if speedup > 1.1:
|
||||
log_test_success(self.rank, "shrink_group significantly faster")
|
||||
elif speedup > 0.9:
|
||||
log_test_info(self.rank, "≈ comparable performance")
|
||||
else:
|
||||
log_test_warning(self.rank, "abort+reinit faster")
|
||||
|
||||
log_test_info(self.rank, "Performance test completed")
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
def test_deterministic_mode_no_break(self):
|
||||
|
Reference in New Issue
Block a user