Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-23 14:59:34 +08:00

Compare: cpp-docs-d... → ciflow/ind... (25 commits)

2cdd4617c8, 13e04b57e1, 1c7fe8f861, 4e643422f6, 3c3b278872,
49bbcc5833, 292d8e6caa, 80f4d11f12, 99f9d8fee0, 9de148417c,
9c0e3db285, 6ed6905d29, 08af8a99e9, 69210a3ecc, 2be01db423,
c8f6c13c96, f8926ed88c, 495607e655, 7f603f54d1, 35466f9ef9,
65300600dd, 95d52c623d, e2f2803882, 8d6c701f80, eb161310a5
@@ -52,7 +52,7 @@ struct FusedAdagradMathFunctor {
  using opmath_t = at::opmath_type<scalar_t>;

  C10_DEVICE __forceinline__ void operator()(
-     int chunk_size,
+     int64_t chunk_size,
      FusedOptimizerTensorListMetadata<3>& tl,
      const float* lr_ptr,
      const double& lr,
@@ -133,4 +133,4 @@ struct FusedAdagradMathFunctor {

} // namespace

-} // namespace at::native
+} // namespace at::native
@@ -1,8 +1,8 @@

-add_loop_eager,compile_time_instruction_count,3070000000,0.1
+add_loop_eager,compile_time_instruction_count,3184000000,0.1

-add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1
+add_loop_eager_dynamic,compile_time_instruction_count,4595000000,0.1

@@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1

-basic_modules_ListOfLinears_eager,compile_time_instruction_count,1048000000,0.1
+basic_modules_ListOfLinears_eager,compile_time_instruction_count,1096000000,0.1

@@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,

-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17720000000,0.1

@@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000

-update_hint_regression,compile_time_instruction_count,1719000000,0.1
+update_hint_regression,compile_time_instruction_count,1645000000,0.1

-sum_floordiv_regression,compile_time_instruction_count,3686995725,0.1
+sum_floordiv_regression,compile_time_instruction_count,3813000000,0.1

@@ -50,31 +50,31 @@ symint_sum_loop,compile_time_instruction_count,4299000000,0.1

-aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1869000000,0.1
+aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1793000000,0.1

-aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5281000000,0.1
+aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5120000000,0.1

-aotdispatcher_partitioner_cpu,compile_time_instruction_count,8333000000,0.1
+aotdispatcher_partitioner_cpu,compile_time_instruction_count,7936000000,0.1

-aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1909000000,0.1
+aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1848000000,0.1

-aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3442000000,0.1
+aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3152000000,0.1

-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9239000000,0.1
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,8301000000,0.1

-mm_loop_inductor_gpu,compile_time_instruction_count,4820968837,0.1
+mm_loop_inductor_gpu,compile_time_instruction_count,4958000000,0.1

@@ -82,8 +82,8 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,9051000000,0.1

-basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1
+basic_NestedModule_eager,compile_time_instruction_count,9990000000,0.1

-basic_InlineMod_eager,compile_time_instruction_count,7618000000,0.1
+basic_InlineMod_eager,compile_time_instruction_count,8126000000,0.1
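Each row above reads as benchmark_name,metric,expected_value,tolerance, and the 0.1 in the last column looks like a relative tolerance (roughly ±10%). A minimal sketch of how such a threshold could be checked; the CSV layout is the only thing taken from the rows above, and the helper itself is illustrative rather than PyTorch's actual benchmark harness:

```python
import csv


def check_benchmark(path: str, name: str, measured: float) -> bool:
    """Compare a measured instruction count against the recorded expectation."""
    with open(path) as f:
        for row in csv.reader(f):
            if len(row) == 4 and row[0] == name:
                expected, tol = float(row[2]), float(row[3])
                # e.g. expected 3184000000 with tol 0.1 accepts anything within +/-10%
                return abs(measured - expected) <= tol * expected
    raise KeyError(f"no expected value recorded for {name}")
```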
@@ -394,6 +394,10 @@ an opaque group handle that can be given as a `group` argument to all collective
.. autofunction:: new_group
```

```{eval-rst}
.. autofunction:: torch.distributed.distributed_c10d.shrink_group
```

```{eval-rst}
.. autofunction:: get_group_rank
```
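The docs change above only registers the autofunction for the new `shrink_group` API. Based on how the API is exercised in the NCCL tests later in this diff, a hedged usage sketch; the keyword names `pg_options` and `shrink_flags` and the `SHRINK_ABORT` constant are taken from those tests, and the defaults are assumptions:

```python
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import SHRINK_ABORT

# Assumes the default NCCL process group is already initialized with a device_id,
# and that this process is NOT one of the excluded ranks (excluded ranks must not
# call shrink_group; see the test helpers below).
ranks_to_exclude = [dist.get_world_size() - 1]
if dist.get_rank() not in ranks_to_exclude:
    new_pg = dist.shrink_group(ranks_to_exclude, shrink_flags=SHRINK_ABORT)
    t = torch.ones(1, device=f"cuda:{dist.get_rank()}")
    dist.all_reduce(t, group=new_pg)  # collective now spans only the surviving ranks
```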
test/distributed/logging_utils.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import logging
import time


_start_time = time.time()
_logger = logging.getLogger(__name__)


def _ts():
    return time.time() - _start_time


def configure(level=logging.INFO, force=False):
    try:
        logging.basicConfig(
            level=level,
            format="%(asctime)s %(name)s %(levelname)s: %(message)s",
            force=force,
        )
    except TypeError:
        logging.basicConfig(
            level=level, format="%(asctime)s %(name)s %(levelname)s: %(message)s"
        )


def log_test_info(rank, message):
    _logger.info("[%7.3fs][Rank %s] %s", _ts(), rank, message)


def log_test_success(rank, message):
    _logger.info("[%7.3fs][Rank %s] ✅ %s", _ts(), rank, message)


def log_test_validation(rank, message):
    _logger.info("[%7.3fs][Rank %s] ✓ %s", _ts(), rank, message)


def log_test_warning(rank, message):
    _logger.warning("[%7.3fs][Rank %s] ⚠️ %s", _ts(), rank, message)


def log_test_error(rank, message):
    _logger.error("[%7.3fs][Rank %s] ✗ %s", _ts(), rank, message)
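A short usage sketch of the helpers defined above, assuming the module's directory is on sys.path the same way the NCCL tests below import it; the rank value is illustrative:

```python
import logging
from logging_utils import configure, log_test_info, log_test_success

# force= is passed through to logging.basicConfig; the try/except in configure()
# falls back gracefully on Python versions without the force argument.
configure(level=logging.INFO, force=True)
log_test_info(0, "starting shrink_group test")
log_test_success(0, "shrink_group test passed")
```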
@@ -2,6 +2,7 @@

import copy
import json
import logging
import os
import pickle
import random
@@ -21,6 +22,7 @@ from unittest import mock, SkipTest
import torch
import torch.distributed as c10d
import torch.distributed._functional_collectives as _functional_collectives
from torch.distributed.distributed_c10d import SHRINK_ABORT as NCCL_SHRINK_ABORT


if not c10d.is_available() or not c10d.is_nccl_available():
@@ -47,12 +49,15 @@ from torch._C._distributed_c10d import ErrorType, OpType, WorkResult
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_cuda import _get_torch_rocm_version, TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
    get_required_world_size,
    get_timeout,
    init_multigpu_helper,
    MultiProcessTestCase,
    requires_multicast_support,
    requires_nccl,
    requires_nccl_shrink,
    requires_nccl_version,
    requires_world_size,
    skip_if_lt_x_gpu,
    skip_if_rocm_multiprocess,
    sm_is_or_higher_than,
@@ -87,6 +92,17 @@ BFLOAT16_AVAILABLE = torch.cuda.is_available() and (
    torch.version.cuda is not None or torch.version.hip is not None
)

from logging_utils import (
    configure as _log_configure,
    log_test_info,
    log_test_success,
    log_test_validation,
    log_test_warning,
)


_log_configure(level=logging.INFO, force=True)


class RendezvousEnvTest(TestCase):
    @retry_on_connect_failures
@@ -317,7 +333,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):

    @property
    def world_size(self):
-        return 2
+        return get_required_world_size(self, 2)

    @property
    def rank_to_GPU(self):
@@ -1255,6 +1271,628 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
        pg_2 = c10d.new_group([0, 1])
        self.assertEqual(pg_2.group_desc, "undefined")

    @requires_nccl_shrink()
    @requires_world_size(2)
    def test_shrink_group_basic(self):
        """Test basic shrink_group functionality."""
        self._perform_shrink_test([1], "Basic shrink test")

    @requires_nccl_shrink()
    @requires_world_size(2)
    def test_shrink_group_validation(self):
        """Test input validation in shrink_group."""
        device, pg = self._setup_shrink_test("validation")

        def _test_invalid_input(ranks, description, expected_exception):
            """Helper to test invalid inputs."""
            try:
                c10d.shrink_group(ranks)
                self.fail(f"Expected {expected_exception.__name__} for {description}")
            except expected_exception:
                log_test_validation(self.rank, f"✓ {description}")
            except Exception:
                if expected_exception is Exception:  # Accept any exception
                    log_test_validation(self.rank, f"✓ {description}")
                else:
                    raise

        # Test cases
        _test_invalid_input([], "Empty exclusion list", ValueError)
        if self.world_size > 1:
            _test_invalid_input([0, 0, 1], "Duplicate ranks", Exception)
        _test_invalid_input([self.world_size + 1], "Out of bounds rank", Exception)

        log_test_success(self.rank, "All validation tests passed")
        dist.destroy_process_group()

    @requires_nccl_shrink()
    @requires_world_size(2)
    def test_shrink_group_backend_properties(self):
        """Test that backend properties are preserved after shrinking."""

        test_name = "Backend Properties Test"
        ranks_to_exclude = [0]

        # Reuse _setup_shrink_test for complete setup (device, environment, and process group)
        device, pg = self._setup_shrink_test("backend_properties")

        # Follow _perform_shrink_test pattern from here
        log_test_info(self.rank, f"{test_name} (world_size={self.world_size})")

        is_excluded = self.rank in ranks_to_exclude
        log_test_info(
            self.rank,
            f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
        )

        # Store original backend property values (not references) before shrinking
        original_timeout = None
        original_high_priority = None
        if not is_excluded:
            original_backend = pg._get_backend(device)
            original_timeout = original_backend.options._timeout
            original_high_priority = original_backend.options.is_high_priority_stream
            log_test_info(
                self.rank,
                f"Storing original backend properties: timeout={original_timeout}, high_priority={original_high_priority}",
            )

        if is_excluded:
            log_test_info(
                self.rank,
                f"Excluded rank {self.rank} - setup complete, skipping shrink operation",
            )
            dist.destroy_process_group()  # hang without it
            return

        # Only non-excluded ranks proceed with shrink (same as _perform_shrink_test)
        log_test_info(self.rank, "Non-excluded rank calling shrink_group")
        shrunk_pg = c10d.shrink_group(ranks_to_exclude)

        # Reuse _validate_shrunk_group helper (same as _perform_shrink_test)
        expected_size = self.world_size - len(ranks_to_exclude)
        _ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name)

        # Add custom backend properties validation
        new_backend = shrunk_pg._get_backend(device)
        log_test_info(self.rank, "Validating backend properties are preserved")

        new_timeout = new_backend.options._timeout
        new_high_priority = new_backend.options.is_high_priority_stream

        log_test_info(
            self.rank,
            f"Timeout comparison - original: {original_timeout}, new: {new_timeout}",
        )
        self.assertEqual(
            original_timeout, new_timeout, f"{test_name}: timeout not preserved"
        )

        log_test_info(
            self.rank,
            f"High priority stream comparison - original: {original_high_priority}, new: {new_high_priority}",
        )
        self.assertEqual(
            original_high_priority,
            new_high_priority,
            f"{test_name}: high_priority_stream not preserved",
        )

        log_test_validation(
            self.rank, f"{test_name}: Backend properties preserved successfully"
        )
        log_test_success(
            self.rank, f"{test_name} successful (shrink + backend validation)"
        )

        # Cleanup (same as _perform_shrink_test)
        dist.destroy_process_group()

    @requires_nccl_shrink()
    @requires_world_size(2)
    def test_shrink_group_multiple_comms(self):
        """Test shrink_group with multiple communicators and subgroup invalidation."""

        device, pg = self._setup_shrink_test("multiple_comms")

        # Create subgroup [0, 1] and test shrinking it
        subgroup = c10d.new_group([0, 1])
        if self.rank <= 1:
            # Shrink subgroup: exclude rank 1
            if self.rank == 0:  # Only rank 0 remains
                shrunk_subgroup = c10d.shrink_group([1], group=subgroup)
                self.assertEqual(shrunk_subgroup.size(), 1)
                # Test communication on shrunk subgroup
                tensor = torch.full((1,), self.rank).cuda(device)
                c10d.all_reduce(tensor, group=shrunk_subgroup)
                self.assertEqual(tensor.item(), 0)  # Only rank 0
                log_test_success(self.rank, "Subgroup shrinking successful")

        dist.barrier()  # Sync before default group test

        # Shrink default group: exclude last rank
        ranks_to_exclude = [self.world_size - 1]
        if self.rank not in ranks_to_exclude:
            shrunk_default = c10d.shrink_group(ranks_to_exclude)
            expected_size = self.world_size - 1
            self.assertEqual(shrunk_default.size(), expected_size)

            # Test collective on shrunk default group
            tensor = torch.full((1,), self.rank).cuda(device)
            c10d.all_reduce(tensor, group=shrunk_default)
            expected_sum = sum(
                range(self.world_size - 1)
            )  # 0 + 1 + ... + (world_size-2)
            self.assertEqual(tensor.item(), expected_sum)
            log_test_success(self.rank, "Default group shrinking successful")

        # Note: After shrinking default group, the old subgroup is invalid
        # due to global rank reassignment

        dist.destroy_process_group()

    def _test_shrink_group_with_flag(self, shrink_flag, flag_name, rank_to_exclude):
        """Helper method to test shrink_group with a specific flag."""
        if self.world_size < 2:
            log_test_info(self.rank, f"Skipping (needs ≥2 GPUs, got {self.world_size})")
            return
        ranks_to_exclude = [rank_to_exclude]
        log_test_info(self.rank, f"Using {flag_name} flag (value: {shrink_flag})")
        if flag_name == "NCCL_SHRINK_ABORT":
            log_test_info(
                self.rank,
                "ABORT flag will terminate ongoing operations before shrinking",
            )

        self._perform_shrink_test(
            ranks_to_exclude, f"{flag_name} flag test", shrink_flags=shrink_flag
        )

    @requires_nccl_shrink()
    @requires_world_size(2)
    def test_shrink_group_flags(self):
        """Test shrink_group with different shrink flags."""
        # Test ABORT flags
        log_test_info(self.rank, "Testing NCCL_SHRINK_ABORT flag")
        self._test_shrink_group_with_flag(NCCL_SHRINK_ABORT, "NCCL_SHRINK_ABORT", 1)

    @requires_nccl_shrink()
    @requires_world_size(2)
    def test_shrink_group_nccl_config(self):
        """Verify that passing NCCL config via pg_options influences the shrunk group's backend options."""
        device, pg = self._setup_shrink_test("config")
        if self.rank == self.world_size - 1:
            # excluded rank should not call shrink_group
            dist.destroy_process_group()
            return

        # Prepare pg_options with NCCL config overrides
        # Capture parent's current backend options to ensure we can prove override vs inherit
        parent_backend = pg._get_backend(torch.device("cuda"))
        parent_hp = parent_backend.options.is_high_priority_stream
        parent_blocking = parent_backend.options.config.blocking

        # Choose overrides that differ from the parent (flip where possible)
        override_hp = not parent_hp
        if parent_blocking in (0, 1):
            override_blocking = 1 - parent_blocking
        else:
            # If undefined or unexpected, set to 1 which is a concrete value
            override_blocking = 1

        opts = c10d.ProcessGroupNCCL.Options()
        opts.is_high_priority_stream = override_hp
        opts.config.blocking = override_blocking

        shrunk_pg = c10d.shrink_group([self.world_size - 1], pg_options=opts)

        # Validate backend options propagated
        backend = shrunk_pg._get_backend(torch.device("cuda"))
        # is_high_priority_stream should exactly match our override and differ from parent
        self.assertEqual(backend.options.is_high_priority_stream, override_hp)
        self.assertNotEqual(backend.options.is_high_priority_stream, parent_hp)
        # config is a struct; check representative field and difference from parent when meaningful
        self.assertEqual(backend.options.config.blocking, override_blocking)
        if parent_blocking in (0, 1):
            self.assertNotEqual(backend.options.config.blocking, parent_blocking)

        dist.destroy_process_group()

    @requires_nccl_shrink()
    @requires_world_size(2)
    def test_shrink_group_performance(self):
        """Test shrink_group performance and regression detection."""
        import time

        ranks_to_exclude = self._get_default_ranks_to_exclude()
        is_excluded = self.rank in ranks_to_exclude

        if not ranks_to_exclude:
            log_test_info(self.rank, "Skipping performance test (world_size=1)")
            return

        log_test_info(self.rank, f"Performance test with {self.world_size} processes")
        device, pg = self._setup_shrink_test("performance")

        if not is_excluded:
            log_test_info(self.rank, "Measuring shrink_group performance")
            start_time = time.time()
            shrunk_pg = c10d.shrink_group(ranks_to_exclude)
            end_time = time.time()

            elapsed_time = end_time - start_time
            log_test_info(self.rank, f"shrink_group: {elapsed_time:.3f}s")

            # Regression check: should complete within reasonable time
            self.assertLess(
                elapsed_time,
                30.0,
                f"shrink_group took {elapsed_time:.3f}s, possible regression",
            )

            # Test collective performance
            expected_size = self.world_size - len(ranks_to_exclude)
            self._validate_shrunk_group(shrunk_pg, expected_size, "performance")

            collective_start = time.time()
            _ = self._test_collective_on_shrunk_group(
                shrunk_pg, device, ranks_to_exclude, "performance"
            )
            collective_time = time.time() - collective_start

            log_test_info(self.rank, f"all_reduce: {collective_time:.3f}s")
            log_test_success(self.rank, "Performance test passed")
        else:
            log_test_info(self.rank, "Excluded rank - waiting")

        dist.destroy_process_group()

    @requires_nccl_shrink()
    @requires_world_size(4)
    def test_shrink_group_multiple_exclusions(self):
        """Test shrink_group with multiple ranks excluded at once."""
        # Scale exclusions with world size
        ranks_to_exclude = list(range(2, self.world_size, 2))  # Every other rank from 2

        self._perform_shrink_test(ranks_to_exclude, "Multiple exclusions test")

    @requires_nccl_shrink()
    @requires_world_size(3)
    def test_shrink_group_multiple_iterations(self):
        """Test multiple shrink operations in sequence."""
        log_test_info(
            self.rank,
            f"Starting test_shrink_group_multiple_iterations with world_size={self.world_size}",
        )

        store = c10d.FileStore(self.file_name, self.world_size)
        device = torch.device(f"cuda:{self.rank}")
        _ = self._create_process_group_nccl(store, self.opts(), device_id=device)

        # Track current effective world size throughout shrinking operations
        current_world_size = self.world_size
        log_test_info(self.rank, f"Initial world_size: {current_world_size}")

        # First shrinking: exclude the last rank(s)
        first_exclusion = [self.world_size - 1]
        if self.world_size >= 6:
            first_exclusion.append(
                self.world_size - 2
            )  # Exclude last two ranks for larger sizes

        log_test_info(self.rank, f"First shrinking: excluding ranks {first_exclusion}")

        if self.rank not in first_exclusion:
            # Only non-excluded ranks should call shrink_group
            first_pg = c10d.shrink_group(first_exclusion)
            self.assertIsNotNone(first_pg)
            # IMPORTANT: Update world size after first shrinking
            current_world_size = first_pg.size()
            expected_first_size = self.world_size - len(first_exclusion)
            log_test_info(
                self.rank,
                f"After first shrinking: world_size {self.world_size} -> {current_world_size}",
            )
            self.assertEqual(first_pg.size(), expected_first_size)

            # Second shrinking: exclude another rank from the remaining group
            # Choose a rank that's in the middle range
            if current_world_size >= 3:
                second_exclusion = [
                    current_world_size - 1
                ]  # Exclude the new "last" rank
                log_test_info(
                    self.rank,
                    f"Second shrinking from group of size {current_world_size}: excluding ranks {second_exclusion}",
                )

                if self.rank not in second_exclusion:
                    # Only non-excluded ranks should call shrink_group for second iteration
                    second_pg = c10d.shrink_group(second_exclusion, group=first_pg)
                    self.assertIsNotNone(second_pg)
                    # IMPORTANT: Update world size after second shrinking
                    final_world_size = second_pg.size()
                    expected_final_size = current_world_size - len(second_exclusion)
                    log_test_info(
                        self.rank,
                        f"After second shrinking: world_size {current_world_size} -> {final_world_size}",
                    )
                    self.assertEqual(second_pg.size(), expected_final_size)

                    # Test collective on final group
                    tensor = torch.full((1,), self.rank).cuda(device)
                    log_test_info(
                        self.rank,
                        f"Performing all_reduce on final group (size {final_world_size}) with tensor: {tensor.item()}",
                    )
                    c10d.all_reduce(tensor, group=second_pg)
                    log_test_info(
                        self.rank,
                        f"Final all_reduce completed, result: {tensor.item()}",
                    )

                    # Calculate expected sum of remaining ranks
                    all_excluded = set(first_exclusion + second_exclusion)
                    remaining_ranks = [
                        r for r in range(self.world_size) if r not in all_excluded
                    ]
                    expected_sum = sum(remaining_ranks)
                    log_test_info(
                        self.rank,
                        f"Remaining ranks: {remaining_ranks}, expected sum: {expected_sum}, actual: {tensor.item()}",
                    )
                    self.assertEqual(tensor.item(), expected_sum)
                    log_test_info(self.rank, "Final verification passed")
                else:
                    log_test_info(
                        self.rank,
                        "This rank excluded in second shrinking, not calling shrink_group",
                    )
            else:
                log_test_info(
                    self.rank, "Skipping second shrinking (remaining group too small)"
                )
        else:
            log_test_info(
                self.rank,
                "This rank excluded in first shrinking, not calling shrink_group",
            )

        log_test_info(self.rank, "Destroying process group")
        dist.destroy_process_group()
        log_test_info(self.rank, "test_shrink_group_multiple_iterations completed")

    # Helper methods for optimized shrink group tests
    def _setup_shrink_test(self, test_suffix, world_size=None, warmup=True):
        """Common setup for shrink group tests."""
        os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1"
        world_size = world_size or self.world_size
        store = c10d.FileStore(self.file_name + f"_{test_suffix}", world_size)
        device = torch.device(f"cuda:{self.rank}")
        c10d.init_process_group(
            "nccl",
            world_size=world_size,
            rank=self.rank,
            store=store,
            pg_options=self.opts(),
            device_id=device,
        )
        pg = c10d.distributed_c10d._get_default_group()

        if warmup:
            c10d.all_reduce(torch.ones(1).cuda(device), group=pg)

        return device, pg

    def _validate_shrunk_group(self, shrunk_pg, expected_size, test_name=""):
        """Validate properties of a shrunk process group."""
        self.assertIsNotNone(shrunk_pg, f"{test_name}: shrunk_pg should not be None")
        actual_size = shrunk_pg.size()
        self.assertEqual(
            actual_size, expected_size, f"{test_name}: group size mismatch"
        )

        new_rank = shrunk_pg.rank()
        self.assertTrue(
            0 <= new_rank < expected_size, f"{test_name}: invalid new rank {new_rank}"
        )

        log_test_info(
            self.rank,
            f"{test_name}: world_size {self.world_size} -> {actual_size}, rank {self.rank} -> {new_rank}",
        )
        return new_rank

    def _test_collective_on_shrunk_group(
        self, shrunk_pg, device, ranks_to_exclude, test_name=""
    ):
        """Test collective communication on shrunk group and verify correctness."""
        test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32)
        c10d.all_reduce(test_tensor, group=shrunk_pg)

        result = test_tensor.item()
        expected_sum = sum(
            r for r in range(self.world_size) if r not in ranks_to_exclude
        )

        self.assertEqual(
            result, expected_sum, f"{test_name}: collective result mismatch"
        )
        log_test_info(
            self.rank, f"{test_name}: collective passed ({result} == {expected_sum})"
        )
        return result

    def _perform_shrink_test(
        self, ranks_to_exclude, test_name, shrink_flags=0, with_collective=True
    ):
        """Complete shrink test flow: setup, shrink, validate, test collective, cleanup.

        Consistent API: All ranks perform setup to initialize distributed environment.
        ONLY non-excluded ranks call shrink_group() for both default and non-default groups.
        Excluded ranks perform setup, then exit without calling shrink_group() or waiting.
        """
        log_test_info(self.rank, f"{test_name} (world_size={self.world_size})")

        is_excluded = self.rank in ranks_to_exclude
        log_test_info(
            self.rank,
            f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
        )

        # All ranks (including excluded ones) perform setup to initialize distributed environment
        device, pg = self._setup_shrink_test(test_name.lower().replace(" ", "_"))
        is_default_group = pg == c10d.distributed_c10d._get_default_group()

        if is_excluded:
            log_test_info(
                self.rank,
                f"Excluded rank {self.rank} - setup complete, skipping shrink operation",
            )
            if shrink_flags & NCCL_SHRINK_ABORT:
                log_test_info(self.rank, f"Using abort for excluded rank {self.rank}")
                pg._get_backend(torch.device(device)).abort()
                log_test_info(
                    self.rank, f"cleanup resources for excluded rank {self.rank}"
                )
                dist.destroy_process_group()
                log_test_info(self.rank, f"Excluded rank {self.rank} - exit")
            else:
                log_test_info(
                    self.rank, f"Using regular destroy for excluded rank {self.rank}"
                )
                dist.destroy_process_group()
            return None

        # Only non-excluded ranks proceed with shrink
        log_test_info(
            self.rank,
            f"Non-excluded rank calling shrink_group (default_group={is_default_group})",
        )
        shrunk_pg = c10d.shrink_group(ranks_to_exclude, shrink_flags=shrink_flags)
        log_test_info(
            self.rank,
            f"Non-excluded rank calling shrink_group (default_group={is_default_group}) done",
        )

        # Non-excluded ranks: validate and test the new group
        expected_size = self.world_size - len(ranks_to_exclude)
        _ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name)

        if with_collective:
            _ = self._test_collective_on_shrunk_group(
                shrunk_pg, device, ranks_to_exclude, test_name
            )
            log_test_success(self.rank, f"{test_name} successful (shrink + collective)")
        else:
            log_test_success(self.rank, f"{test_name} successful (shrink only)")

        dist.destroy_process_group()
        return shrunk_pg

    def _get_default_ranks_to_exclude(self):
        """Get default ranks to exclude based on world size."""
        if self.world_size <= 1:
            return []
        return [self.world_size - 1]  # Exclude last rank by default

    @requires_nccl_shrink()
    @requires_world_size(3)
    def test_shrink_group_vs_abort_reinit_performance(self):
        """Compare performance of shrink_group vs traditional abort+reinit (simplified for reliability)."""
        log_test_info(self.rank, "=== TEST 1: abort+reinit ===")

        device, pg1 = self._setup_shrink_test("_perf_reinit")
        torch.cuda.synchronize(device)

        # Test 1: Traditional abort + reinit
        start_time = time.perf_counter()
        dist.destroy_process_group()

        device, new_pg = self._setup_shrink_test("perf_shrink_test1")
        reinit_time = time.perf_counter() - start_time

        # Test collective with original rank values for fair comparison (non-blocking mode)
        test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32)
        work = c10d.all_reduce(test_tensor, group=new_pg, async_op=True)
        work.wait()

        torch.cuda.synchronize(device)

        # Verify correctness
        expected_sum = sum(r for r in range(self.world_size))
        self.assertEqual(test_tensor.item(), expected_sum, "Reinit collective failed")

        log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s")
        dist.destroy_process_group(new_pg)

        # Test 2: shrink_group with NCCL_SHRINK_ABORT
        log_test_info(self.rank, "=== TEST 2: shrink_group ===")

        ranks_to_exclude = [self.world_size - 1]
        is_excluded = self.rank in ranks_to_exclude
        log_test_info(
            self.rank,
            f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
        )

        device, pg1 = self._setup_shrink_test("perf_shrink_test2")  # Unique suffix

        shrink_time = 0
        if not is_excluded:
            torch.cuda.synchronize(device)  # Ensure accurate timing
            start_time = time.perf_counter()
            shrunk_pg = c10d.shrink_group(
                ranks_to_exclude, shrink_flags=NCCL_SHRINK_ABORT
            )
            c10d.all_reduce(torch.ones(1).cuda(device), group=shrunk_pg)
            shrink_time = time.perf_counter() - start_time

            # Test collective communication on shrunk group (non-blocking mode)
            test_tensor = torch.full(
                (1,), self.rank, device=device, dtype=torch.float32
            )
            work = c10d.all_reduce(test_tensor, group=shrunk_pg, async_op=True)
            work.wait()

            # Verify correctness
            expected_sum = sum(
                r for r in range(self.world_size) if r not in ranks_to_exclude
            )
            self.assertEqual(
                test_tensor.item(),
                expected_sum,
                "shrink_test: collective result mismatch",
            )

            torch.cuda.synchronize(device)  # Ensure operations complete
            log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s")
            dist.destroy_process_group()
        else:
            log_test_info(self.rank, "Excluded from shrink test - exiting immediately")
            dist.destroy_process_group()
            return

        # Performance analysis (only for participating ranks)
        if shrink_time > 0 and reinit_time > 0:
            speedup = reinit_time / shrink_time
            time_saved = reinit_time - shrink_time

            log_test_info(self.rank, "=== PERFORMANCE RESULTS ===")
            log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s")
            log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s")
            log_test_info(self.rank, f"time_saved: {time_saved:+.4f}s")
            log_test_info(self.rank, f"speedup: {speedup:.2f}x")

            if speedup > 1.1:
                log_test_success(self.rank, "shrink_group significantly faster")
            elif speedup > 0.9:
                log_test_info(self.rank, "≈ comparable performance")
            else:
                log_test_warning(self.rank, "abort+reinit faster")

        log_test_info(self.rank, "Performance test completed")

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_deterministic_mode_no_break(self):
@@ -153,7 +153,9 @@ def _get_custom_policy(no_recompute_list=None, must_recompute_list=None):
    return _custom_policy


-class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
+class ActivationCheckpointingViaTagsTests(
+    torch._dynamo.test_case.TestCaseWithNestedGraphBreaks
+):
    def _validate(
        self,
        fn,
@@ -77,7 +77,7 @@ def customized_ctx_manager_with_graph_break(mode):
        torch._C._set_grad_enabled(prev)


-class CtxManagerTests(torch._dynamo.test_case.TestCase):
+class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
    def test_no_grad(self):
        def fn1(a, b):
            x = a + 1
@@ -1706,7 +1706,7 @@ class GraphModule(torch.nn.Module):
        cnts = torch._dynamo.testing.CompileCounter()
        opt_f = torch.compile(f, backend=cnts)
        opt_f(torch.randn(2, 2, 2, 2).to(dtype=torch.float16))
-        self.assertEqual(cnts.frame_count, 3)
+        self.assertEqual(cnts.frame_count, 2)

    # test sdpa_kernel graph break with 2 arguments
    def test_sdpa_kernel_ctx_manager3(self):
@@ -1836,14 +1836,18 @@ class GraphModule(torch.nn.Module):
        self.assertEqual(len(counters["graph_break"]), 1)


-class ContextlibContextManagerTests(torch._dynamo.test_case.TestCase):
+class ContextlibContextManagerTests(
+    torch._dynamo.test_case.TestCaseWithNestedGraphBreaks
+):
    def setUp(self):
        super().setUp()
        self._prev = torch._dynamo.config.enable_trace_contextlib
        self._u_prev = torch._dynamo.config.enable_trace_unittest
        torch._dynamo.config.enable_trace_contextlib = True
        torch._dynamo.config.enable_trace_unittest = True

    def tearDown(self):
        super().tearDown()
        torch._dynamo.config.enable_trace_contextlib = self._prev
        torch._dynamo.config.enable_trace_unittest = self._u_prev
@@ -17,7 +17,7 @@ from torch.testing._internal.common_utils import (
)


-class GeneratorTestsBase(torch._dynamo.test_case.TestCase):
+class GeneratorTestsBase(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
    def setUp(self):
        super().setUp()
        self._old = torch._dynamo.config.enable_faithful_generator_behavior
@@ -1,86 +1,29 @@
# Owner(s): ["module: dynamo"]
import sys
import unittest

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
-from torch._dynamo import config
-from torch._dynamo.testing import make_test_cls_with_patches
-
-
-try:
-    # from . import test_ctx_manager
-    pass
-except ImportError:
-    # import test_aot_autograd
-    # import test_ctx_manager
-
-    # import test_export
-    # import test_functions
-    # import test_higher_order_ops
-    # import test_misc
-    # import test_modules
-    # import test_repros
-    # import test_sdpa
-    # import test_subgraphs
-    pass
-
-
-test_classes = {}
-
-
-def make_nested_cls(cls):
-    suffix = "_nested_graph_breaks"
-
-    cls_prefix = "NestedGraphBreaks"
-
-    test_class = make_test_cls_with_patches(
-        cls,
-        cls_prefix,
-        suffix,
-        (config, "debug_force_nested_calls", True),
-        (config, "debug_force_graph_break_on_leaf_return", True),
-        (config, "debug_disable_compile_counter", True),
-        xfail_prop="_expected_failure_nested_graph_breaks",
-    )
-
-    test_classes[test_class.__name__] = test_class
-    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
-    # globals()[test_class.__name__] = test_class
-    test_class.__module__ = __name__
-    return test_class
-
-
-tests = [
-    # test_ctx_manager.CtxManagerTests,
-    # test_functions.FunctionTests,
-    # test_misc.MiscTests,
-    # test_repros.ReproTests,
-    # test_modules.NNModuleTests,
-    # test_subgraphs.SubGraphTests,
-    # test_higher_order_ops.HigherOrderOpTests,
-    # test_higher_order_ops.FuncTorchHigherOrderOpTests,
-    # test_aot_autograd.AotAutogradFallbackTests,
-    # test_sdpa.TestSDPA,
-]
-test = None
-for test in tests:
-    make_nested_cls(test)
-del test


# for use in test_side_effects_globals
global1, global2, global3, global4 = (torch.zeros(3),) * 4


-class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
-    def setUp(self):
-        super().setUp()
-        torch._dynamo.config.nested_graph_breaks = True
+class CustomizedCtxManager:
+    def __init__(self, x):
+        self.x = x
+        torch._dynamo.graph_break()

-    def tearDown(self):
-        super().tearDown()
-        torch._dynamo.config.nested_graph_breaks = False
+    def __enter__(self):
+        pass
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        pass
+
+
+class NestedGraphBreakTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
    def test_single_graph_break(self):
        # NOTE marking f1, f2, f3 as global
        # prevents them from being freevars
@@ -211,6 +154,31 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 14)

    def test_counters(self):
        global f1, f2, f3, f4

        def f1(x):
            x = x + 1
            torch._dynamo.graph_break()
            return x + 2

        def f2(x):
            return f1(x + 4) + 8

        def f3(x):
            x = x + 16
            for _ in range(1):
                x = f2(x)
            return x + 32

        @torch.compile(backend="eager")
        def f4(x):
            return f3(x + 64) + 128

        self.assertEqual(f4(torch.zeros(3)), torch.zeros(3) + 255)
        self.assertEqual(len(torch._dynamo.utils.counters["graph_break"]), 2)
        breakpoint()

    def test_supported_ctx_manager(self):
        global check, check_disabled, f1, f2, f3
@@ -612,6 +580,211 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
        self.assertEqual(cnts.op_count, 4)
        self.assertEqual(torch._dynamo.utils.counters["frames"]["total"], 3)

    def test_generator_nested_graph_break(self):
        def gen(x):
            yield x + 1
            torch._dynamo.graph_break()
            yield x + 2

        def fn(x):
            x = x + 4
            return list(gen(x))

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(fn)
        x = torch.zeros(3)
        res = fn(x)
        # NOTE: if we enable nested graph breaks on inlined generators, we expect
        # some sort of internal dynamo failure
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        # fn should be skipped
        self.assertEqual(cnts.frame_count, 0)

        def outer(x):
            x = x + 8
            return fn(x)[0] + 16

        cnts.clear()
        torch.compiler.reset()

        opt_fn = torch._dynamo.optimize(backend=cnts)(outer)
        x = torch.zeros(3)
        res = outer(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
        # only outer should be traced
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 2)

    def test_return_after_graph_break_nested(self):
        # With improper implementation, returning immediately after a nested graph
        # break may skip the rest of the top-level frame.
        def f2(inner, x):
            x += 2
            return inner(x)

        @torch.compile(backend="eager")
        def f3(inner, x):
            result = f2(inner, x)
            x += 4
            if result is not None:
                x += result
            return x

        # test normal graph break
        x = torch.zeros(3)

        def inner1(x):
            x += 1
            return torch._dynamo.graph_break()

        ref = f3(inner1, x)
        self.assertEqual(ref, torch.zeros(3) + 7)

        # test step graph break
        x = torch.zeros(3)

        def inner2(x):
            x += 1
            return torch._dynamo.step_unsupported()

        ref = f3(inner2, x)
        self.assertEqual(ref, torch.zeros(3) + 7)

        # test store attr graph break
        # NOTE: we do this manual bytecode generation hack since the only RETURN_*
        # instruction that can follow STORE_ATTR is RETURN_CONST, which was removed in 3.14+.

        # make sure inner3's code options are compatible with the instructions below
        def inner3(x):
            x.attr = 1000

        new_inst = torch._dynamo.bytecode_transformation.create_instruction
        insts = [
            new_inst("LOAD_CONST", argval=1000),
            new_inst("LOAD_CONST", argval=2000),
            new_inst("LOAD_FAST", argval="x"),
            new_inst("STORE_ATTR", argval="attr"),
            new_inst("RETURN_VALUE"),
        ]
        if sys.version_info >= (3, 11):
            insts = [new_inst("RESUME", arg=0)] + insts
        code_keys = torch._dynamo.bytecode_transformation.get_code_keys()
        code_options = {k: getattr(inner3.__code__, k) for k in code_keys}
        _, inner3_code = (
            torch._dynamo.bytecode_transformation.clean_and_assemble_instructions(
                insts, code_keys, code_options
            )
        )
        inner3.__code__ = inner3_code

        x = torch.zeros(3)
        ref = f3(inner3, x)
        self.assertEqual(ref, torch.zeros(3) + 1006)

        # dynamic branching is harder to test - the other tests should be enough cover

        # test every function returning
        @torch.compiler.disable
        def inner5(x):
            x += 8
            return x

        def inner4(x):
            x += 1
            return inner5(x)

        @torch.compile(backend="eager")
        def f4(x):
            x += 4
            return f2(inner4, x)

        x = torch.zeros(3)
        ref = f4(x)
        self.assertEqual(ref, torch.zeros(3) + 15)

    def test_return_after_graph_break_deep_nested(self):
        @torch.compiler.disable
        def f1(x):
            return x + 1

        def f2(x):
            return f1(x + 2)

        def f3(x):
            return f2(x + 4)

        def f4(x):
            x = f3(x + 8)
            return x + 16

        def f5(x):
            return f4(x + 32)

        def f6(x):
            return f5(x + 64)

        def f7(x):
            x = f6(x + 128)
            return x + 256

        @torch.compile(backend="eager")
        def f8(x):
            return f7(x + 512)

        x = torch.zeros(3)
        ref = f8(x)
        self.assertEqual(ref, torch.zeros(3) + 1023)

        # check that only 2 resume functions are created
        self.assertEqual(len(torch._dynamo.utils.counters["resumes"]), 2)
        for name in ("resume_in_f4", "resume_in_f7"):
            self.assertTrue(
                any(
                    name in key
                    for key in torch._dynamo.utils.counters["resumes"].keys()
                )
            )

    @unittest.expectedFailure
    def test_nested_decorated_function(self):
        # decorator must call ContextWrappingVariable.cleanup_assert to trigger this test
        def f(x):
            @torch.autocast("cpu")
            def inner(y):
                y = y + 1
                torch._dynamo.graph_break()
                return y + 1

            return inner(x)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f)
        x = torch.zeros(3)
        res = f(x)
        ref = opt_fn(x)
        print(ref, res)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 6)

    @unittest.expectedFailure
    def test_nested_graph_break_in_custom_ctx_manager_init(self):
        def f(x):
            with CustomizedCtxManager(x):
                return x + 1

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(backend=cnts)(f)
        x = torch.zeros(3)
        res = f(x)
        ref = opt_fn(x)
        print(ref, res)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 2)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
test/dynamo/test_nested_graph_breaks_wrapped.py (new file, 136 lines)
@@ -0,0 +1,136 @@
# Owner(s): ["module: dynamo"]
import unittest

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import make_test_cls_with_patches


try:
    from . import test_activation_checkpointing, test_ctx_manager, test_misc
except ImportError:
    import test_activation_checkpointing
    import test_ctx_manager
    import test_misc


test_classes = {}


def make_nested_cls(cls, strong):
    config = torch._dynamo.config

    if strong:
        # A strong nested graph break test - will graph break at every leaf function's return
        test_class = make_test_cls_with_patches(
            cls,
            "NestedGraphBreaksStrong",
            "_nested_graph_breaks_strong",
            (config, "nested_graph_breaks", True),
            (config, "debug_force_nested_calls", True),
            (config, "debug_force_graph_break_on_leaf_return", True),
            (config, "debug_disable_compile_counter", True),
            xfail_prop="_expected_failure_nested_graph_breaks_strong",
        )
    else:
        test_class = make_test_cls_with_patches(
            cls,
            "NestedGraphBreaks",
            "_nested_graph_breaks",
            (config, "nested_graph_breaks", True),
            (config, "debug_force_nested_calls", True),
            (config, "debug_disable_compile_counter", True),
            xfail_prop="_expected_failure_nested_graph_breaks",
        )

    test_classes[test_class.__name__] = test_class
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
    globals()[test_class.__name__] = test_class
    test_class.__module__ = __name__


tests = [
    getattr(
        test_activation_checkpointing, "ActivationCheckpointingViaTagsTestsCUDA", None
    ),
    test_ctx_manager.CtxManagerTests,
    test_misc.MiscTests,
]

strong_tests = []
test = None
for test in tests:
    if not test:
        continue
    make_nested_cls(test, False)

for test in strong_tests:
    make_nested_cls(test, True)

del test

xfails = [
    # multiple exit due to nested graph break in decorator
    # NestedGraphBreaksStrongCtxManagerTests.test_disable_saved_tensors_hooks_nested_graph_breaks_strong,  # noqa: F821
    # NestedGraphBreaksStrongCtxManagerTests.test_disable_saved_tensors_hooks_prev_disabled_nested_graph_breaks_strong,  # noqa: F821
    # NestedGraphBreaksStrongCtxManagerTests.test_disable_saved_tensors_hooks_prev_disabled_nested_nested_graph_breaks_strong,  # noqa: F821
    # graph break in context manager __init__
    # NestedGraphBreaksStrongCtxManagerTests.test_generic_context_manager_CustomizedCtxManager_nested_graph_breaks_strong,  # noqa: F821
    # NestedGraphBreaksStrongCtxManagerTests.test_generic_context_manager_customized_ctx_manager_nested_graph_breaks_strong,  # noqa: F821
    # NestedGraphBreaksStrongCtxManagerTests.test_generic_context_manager_with_graph_break_CustomizedCtxManager_nested_graph_breaks_strong,  # noqa: F821
    # NestedGraphBreaksStrongCtxManagerTests.test_generic_context_manager_with_graph_break_customized_ctx_manager_nested_graph_breaks_strong,  # noqa: F821
    # NestedGraphBreaksStrongCtxManagerTests.test_generic_ctx_manager_with_graph_break_CustomizedCtxManagerWithGraphBreak_nested_graph_breaks_strong,  # noqa: F821
    # NestedGraphBreaksStrongCtxManagerTests.test_generic_ctx_manager_with_graph_break_customized_ctx_manager_with_graph_break_nested_graph_breaks_strong,  # noqa: F821
    # NestedGraphBreaksStrongCtxManagerTests.test_nested_generic_context_manager_CustomizedCtxManager_nested_graph_breaks_strong,  # noqa: F821
    # NestedGraphBreaksStrongCtxManagerTests.test_nested_generic_context_manager_customized_ctx_manager_nested_graph_breaks_strong,  # noqa: F821
    # NestedGraphBreaksStrongCtxManagerTests.test_nested_generic_context_manager_with_graph_break_CustomizedCtxManager_nested_graph_breaks_strong,  # noqa: F821
    # NestedGraphBreaksStrongCtxManagerTests.test_nested_generic_context_manager_with_graph_break_customized_ctx_manager_nested_graph_breaks_strong,  # noqa: F821
    # NestedGraphBreaksStrongCtxManagerTests.test_return_context_manager_nested_graph_breaks_strong,  # noqa: F821
    # NestedGraphBreaksStrongCtxManagerTests.test_return_context_manager_with_graph_break_nested_graph_breaks_strong,  # noqa: F821
    # recursion limit exceeded
    # NestedGraphBreaksStrongCtxManagerTests.test_cuda_stream_compared_with_constant_nested_graph_breaks_strong,  # noqa: F821
    # variable naming issues
    NestedGraphBreaksMiscTests.test_flat_name_to_original_fqn_nested_graph_breaks,  # noqa: F821
    NestedGraphBreaksMiscTests.test_compare_shapes_with_constant_nested_graph_breaks,  # noqa: F821
    NestedGraphBreaksMiscTests.test_guard_failure_fn2_nested_graph_breaks,  # noqa: F821
    NestedGraphBreaksMiscTests.test_guard_failure_fn_shape_control_nested_graph_breaks,  # noqa: F821
    NestedGraphBreaksMiscTests.test_guard_failure_fn_tensor_iter_nested_graph_breaks,  # noqa: F821
    NestedGraphBreaksMiscTests.test_guard_filter_fn_by_name_and_value_nested_graph_breaks,  # noqa: F821
    NestedGraphBreaksMiscTests.test_guard_sym_node_fstring_when_used_nested_graph_breaks,  # noqa: F821
    NestedGraphBreaksMiscTests.test_symint_as_device_kwarg_multi_gpu_nested_graph_breaks,  # noqa: F821
    NestedGraphBreaksMiscTests.test_sys_modules_nested_graph_breaks,  # noqa: F821
    # counters["graph_breaks"] issues
    NestedGraphBreaksMiscTests.test_data_ptr_graph_break_aten_nested_graph_breaks,  # noqa: F821
    # nested graph break removes duplicate graph break
    NestedGraphBreaksMiscTests.test_duplicate_graph_break_log_nested_graph_breaks,  # noqa: F821
    # doesn't work due to debug_force_nested_calls wrapping the top frame
    NestedGraphBreaksMiscTests.test_dynamo_cache_invalidate_nested_graph_breaks,  # noqa: F821
    NestedGraphBreaksMiscTests.test_dynamo_cache_move_to_front_nested_graph_breaks,  # noqa: F821
    NestedGraphBreaksMiscTests.test_dynamo_reset_clears_cache_nested_graph_breaks,  # noqa: F821
    NestedGraphBreaksMiscTests.test_fail_on_recompile_error_message_nested_graph_breaks,  # noqa: F821
    NestedGraphBreaksMiscTests.test_get_cache_entry_nested_graph_breaks,  # noqa: F821
    NestedGraphBreaksMiscTests.test_getattrvariable_as_python_constant_nested_graph_breaks,  # noqa: F821
    NestedGraphBreaksMiscTests.test_precompile_entries_nested_graph_breaks,  # noqa: F821
    NestedGraphBreaksMiscTests.test_precompile_entry_hit_nested_graph_breaks,  # noqa: F821
    NestedGraphBreaksMiscTests.test_precompile_fail_on_recompile_nested_graph_breaks,  # noqa: F821
    NestedGraphBreaksMiscTests.test_torch_guards_stack_frame_register_inlining_deep_nested_graph_breaks,  # noqa: F821
    NestedGraphBreaksMiscTests.test_torch_guards_stack_frame_register_inlining_nested_graph_breaks,  # noqa: F821
    # differing op_count
    NestedGraphBreaksMiscTests.test_nested_closure_nested_graph_breaks,  # noqa: F821
    NestedGraphBreaksMiscTests.test_return_nested_function_nested_graph_breaks,  # noqa: F821
    # unknown
    NestedGraphBreaksMiscTests.test_inspect_signature_bind_non_user_function_nested_graph_breaks,  # noqa: F821
]

case = None

for case in xfails:
    unittest.expectedFailure(case)

del case, xfails

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
@@ -1,4 +1,5 @@
# Owner(s): ["module: dynamo"]
import functools
from unittest.mock import patch

import torch
@@ -548,6 +549,30 @@ class RecompileTests(torch._dynamo.test_case.TestCase):
        f(x, foo1)
        self.assertEqual(counter.frame_count, 2)

    @dc.patch(recompile_limit=1, fail_on_recompile_limit_hit=True)
    def test_wrap_inline_recompiles(self):
        inp = torch.ones(3)

        wrap_fns = (
            torch._dynamo.external_utils.wrap_inline,
            functools.partial(
                torch._dynamo.external_utils.wrap_inline_with_error_on_graph_break,
                error_on_graph_break=True,
            ),
            functools.partial(
                torch._dynamo.external_utils.wrap_inline_with_error_on_graph_break,
                error_on_graph_break=False,
            ),
        )

        for fn in wrap_fns:
            for i in range(2):
                opt_fn = torch.compile(
                    fn(lambda x: x + i),
                    backend="eager",
                )
                self.assertEqual(inp + i, opt_fn(inp))

    def test_no_recompile_over_unused_objects(self):
        # This is a regression test case that imitates
        # https://github.com/city96/ComfyUI-GGUF/blob/47bec6147569a138dd30ad3e14f190a36a3be456/ops.py#L169-L182
@@ -6,6 +6,7 @@ import builtins
import collections
import contextlib
import copy
import gc
import functools
import inspect
import io
@@ -19,6 +20,7 @@ import traceback
import types
import typing
import unittest
import weakref
import warnings
from math import sqrt
from torch.multiprocessing import Process
@@ -1624,6 +1626,25 @@ class TestFX(JitTestCase):

        self.assertTrue(neg not in relu.users)

    @skipIfTorchDynamo("Dynamo does not free right away")
    def test_prepend_does_not_leak(self):
        g = Graph()
        x = g.placeholder("x")
        relu = g.call_function(torch.relu, (x,))
        neg = g.call_function(torch.neg, (x,))

        relu.prepend(neg)

        ref = weakref.ref(neg)
        g.erase_node(neg)
        del g
        del x
        del relu
        del neg
        gc.collect()

        self.assertIsNone(ref())

    def test_remove_uses_with_custom_filter(self):
        g: torch.fx.Graph = Graph()
        x: torch.fx.Node = g.placeholder("x")
@@ -2758,6 +2758,12 @@ class _NodeBase:
        return_type: Any,
    ) -> None: ...
    def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
    def _prepend(self, n: FxNode) -> None: ...
    def _remove_from_list(self) -> None: ...
    def __lt__(self, n: Self) -> _bool: ...
    def __gt__(self, n: Self) -> _bool: ...
    def __le__(self, n: Self) -> _bool: ...
    def __ge__(self, n: Self) -> _bool: ...

class _NodeIter(Iterator[FxNode]):
    def __init__(self, root: FxNode, reversed: _bool) -> None: ...
@ -154,6 +154,7 @@ def reset() -> None:
|
||||
TensorifyState.clear()
|
||||
torch._dynamo.utils.warn_once_cache.clear()
|
||||
torch._dynamo.utils.user_obj_id_to_weakref.clear()
|
||||
torch._dynamo.symbolic_convert._debug_force_graph_break_on_leaf_return_disable_codes.clear()
|
||||
torch._C._autograd._saved_tensors_hooks_set_tracing(False)
|
||||
|
||||
|
||||
|
@ -500,7 +500,8 @@ issue_3_13_0_warning = True
|
||||
# traced FX graph is empty when RETURN_* is traced.
|
||||
allow_empty_graphs = False
|
||||
|
||||
# Used for testing - forces all top-level functions to be nested when traced with Dynamo
|
||||
# Used for testing - forces all top-level functions to be nested when traced with Dynamo.
|
||||
# There are slight differences between this config and wrap_top_frame.
|
||||
debug_force_nested_calls = False
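A minimal sketch of how this testing knob might be toggled through the standard config patcher (the `dc.patch` style is the one used by the tests earlier in this diff; `debug_force_nested_calls` is added by this patch, and the exact interaction with compile-time wrapping is an assumption based on the eval_frame hunk below):

```python
import torch
import torch._dynamo.config as dc

def f(x):
    return x + 1

# Patch before compiling: the wrapping decision is made when torch.compile
# wraps the function, not at call time (see the eval_frame hunk later in this diff).
with dc.patch(debug_force_nested_calls=True):
    opt_f = torch.compile(f, backend="eager")
    opt_f(torch.ones(3))
```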
|
||||
|
||||
# Used for testing - forces a graph break when a function
|
||||
|
@ -837,8 +837,12 @@ class _TorchDynamoContext:
|
||||
filename = inspect.getsourcefile(fn)
|
||||
except TypeError:
|
||||
filename = None
|
||||
if config.debug_force_nested_calls:
|
||||
if config.debug_force_nested_calls and filename not in DONT_WRAP_FILES:
|
||||
fn = external_utils.wrap_inline(fn)
|
||||
# Create a new code object for `fn` so that functions have different
|
||||
# recompilation caches.
|
||||
# Copy hack since deepcopy doesn't actually give a new code object
|
||||
fn.__code__ = fn.__code__.replace(co_varnames=fn.__code__.co_varnames) # type: ignore[attr-defined]
|
||||
elif config.wrap_top_frame or (
|
||||
(filename is None or trace_rules.check(fn))
|
||||
and (
|
||||
|
@ -21,7 +21,9 @@ Key functionality groups:
|
||||
"""
|
||||
|
||||
import functools
|
||||
import types
|
||||
import warnings
|
||||
import weakref
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
from typing_extensions import deprecated, ParamSpec
|
||||
|
||||
@ -58,15 +60,32 @@ else:
|
||||
return torch.compiler.is_compiling()
|
||||
|
||||
|
||||
def deepcopy_code(code: types.CodeType) -> types.CodeType:
|
||||
# copy hack since deepcopy doesn't actually give a new code object
|
||||
return code.replace(co_varnames=code.co_varnames) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
_wrap_inline_cache: weakref.WeakKeyDictionary[Callable[..., Any], types.CodeType] = (
|
||||
weakref.WeakKeyDictionary()
|
||||
)
|
||||
|
||||
|
||||
def wrap_inline(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
"""
|
||||
Create an extra frame around fn that is not in skipfiles.
|
||||
|
||||
This extra frame has its own per-fn code object so that we don't recompile
|
||||
due to multiple calls to wrap_inline with different fn's.
|
||||
"""
|
||||
|
||||
@functools.wraps(fn)
|
||||
def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
if fn not in _wrap_inline_cache:
|
||||
_wrap_inline_cache[fn] = deepcopy_code(inner.__code__) # type: ignore[attr-defined]
|
||||
|
||||
inner.__code__ = _wrap_inline_cache[fn] # type: ignore[attr-defined]
|
||||
return inner
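A small sketch of the caching behavior described in the docstring above. `wrap_inline` already exists in `torch._dynamo.external_utils`; the per-`fn` code-object cache is what this hunk adds, so both assertions reflect the patched behavior rather than the released API:

```python
from torch._dynamo.external_utils import wrap_inline

def f(x):
    return x + 1

def g(x):
    return x * 2

wf, wg = wrap_inline(f), wrap_inline(g)
# Each wrapped fn gets its own code object, so Dynamo's per-code recompile
# caches for wf and wg stay separate.
assert wf.__code__ is not wg.__code__
# Re-wrapping the same fn reuses the cached code object.
assert wf.__code__ is wrap_inline(f).__code__
```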
@ -231,6 +250,11 @@ def call_accumulate_grad(
|
||||
variable.grad = updated_grad[0]
|
||||
|
||||
|
||||
_wrap_inline_with_error_on_graph_break_cache: weakref.WeakKeyDictionary[
|
||||
Callable[..., Any], dict[bool, types.CodeType]
|
||||
] = weakref.WeakKeyDictionary()
|
||||
|
||||
|
||||
def wrap_inline_with_error_on_graph_break(
|
||||
fn: Callable[_P, _R], error_on_graph_break: bool
|
||||
) -> Callable[_P, _R]:
|
||||
@ -249,6 +273,14 @@ def wrap_inline_with_error_on_graph_break(
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
if fn not in _wrap_inline_with_error_on_graph_break_cache:
|
||||
_wrap_inline_with_error_on_graph_break_cache[fn] = {}
|
||||
cache = _wrap_inline_with_error_on_graph_break_cache[fn]
|
||||
if error_on_graph_break not in cache:
|
||||
cache[error_on_graph_break] = deepcopy_code(wrapper.__code__) # type: ignore[attr-defined]
|
||||
|
||||
wrapper.__code__ = cache[error_on_graph_break] # type: ignore[attr-defined]
|
||||
|
||||
return wrapper
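The cache here is keyed by `(fn, error_on_graph_break)`. A hedged sketch of the intended reuse, using only names introduced in this patch (not part of the released API):

```python
from torch._dynamo.external_utils import wrap_inline_with_error_on_graph_break

def f(x):
    return x + 1

w_true = wrap_inline_with_error_on_graph_break(f, True)
w_false = wrap_inline_with_error_on_graph_break(f, False)
# Different flag values for the same fn get distinct code objects ...
assert w_true.__code__ is not w_false.__code__
# ... while the same (fn, flag) pair reuses the cached code object.
assert w_true.__code__ is wrap_inline_with_error_on_graph_break(f, True).__code__
```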
|
||||
|
||||
|
||||
|
@ -47,7 +47,6 @@ from collections import deque
|
||||
from traceback import StackSummary
|
||||
from typing import Any, Callable, cast, NoReturn, Optional, TYPE_CHECKING, Union
|
||||
from typing_extensions import TypeAlias, TypeIs
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch._logging
|
||||
@ -80,7 +79,6 @@ from .bytecode_transformation import (
|
||||
create_dup_top,
|
||||
create_instruction,
|
||||
create_jump_absolute,
|
||||
create_load_const,
|
||||
create_rot_n,
|
||||
create_swap,
|
||||
get_code_keys,
|
||||
@ -677,18 +675,14 @@ def generic_jump(
|
||||
)
|
||||
self.pop()
|
||||
|
||||
if_next = self.codegen_fix_leaf_stack(
|
||||
all_stack_locals_metadata[0], self.next_instruction
|
||||
) + self.create_call_resume_at(
|
||||
if_next = self.create_call_resume_at(
|
||||
self.next_instruction,
|
||||
all_stack_locals_metadata,
|
||||
)
|
||||
if push:
|
||||
self.push(value)
|
||||
assert inst.target is not None
|
||||
if_jump = self.codegen_fix_leaf_stack(
|
||||
all_stack_locals_metadata[0], inst.target
|
||||
) + self.create_call_resume_at(
|
||||
if_jump = self.create_call_resume_at(
|
||||
inst.target,
|
||||
all_stack_locals_metadata,
|
||||
)
|
||||
@ -1024,10 +1018,7 @@ def break_graph_if_unsupported(
|
||||
for _ in range(push):
|
||||
self.push(UnknownVariable())
|
||||
self.output.add_output_instructions(
|
||||
self.codegen_fix_leaf_stack(
|
||||
all_stack_locals_metadata[0], self.next_instruction
|
||||
)
|
||||
+ self.create_call_resume_at(
|
||||
self.create_call_resume_at(
|
||||
self.next_instruction,
|
||||
all_stack_locals_metadata,
|
||||
)
|
||||
@ -1163,6 +1154,11 @@ class ExceptionStack:
|
||||
__repr__ = __str__
|
||||
|
||||
|
||||
_debug_force_graph_break_on_leaf_return_disable_codes: weakref.WeakSet[
|
||||
types.CodeType
|
||||
] = weakref.WeakSet()
|
||||
|
||||
|
||||
class InstructionTranslatorBase(
|
||||
metaclass=BytecodeDispatchTableMeta,
|
||||
):
|
||||
@ -1526,27 +1522,69 @@ class InstructionTranslatorBase(
|
||||
# frame 1 stack + locals,
|
||||
# ], leaf_resume result
|
||||
|
||||
# pop frame N cells and locals
|
||||
cg.extend_output(
|
||||
[
|
||||
*create_copy(2),
|
||||
cg.create_load_const(0),
|
||||
create_instruction("DELETE_SUBSCR"),
|
||||
*create_copy(3),
|
||||
cg.create_load_const(0),
|
||||
create_instruction("DELETE_SUBSCR"),
|
||||
]
|
||||
)
|
||||
|
||||
# add the leaf_resume result to frame N-1 stack
|
||||
num_stack = all_stack_locals_metadata[1].num_stack
|
||||
cg.extend_output(
|
||||
[
|
||||
create_instruction("BUILD_LIST", arg=1),
|
||||
*create_copy(2),
|
||||
cg.create_load_const(1),
|
||||
cg.create_load_const(0),
|
||||
cg.create_binary_subscr(),
|
||||
*create_binary_slice(num_stack, num_stack, True),
|
||||
]
|
||||
)
|
||||
self.parent.push(UnknownVariable())
|
||||
all_stack_locals_metadata[1].num_stack += 1
|
||||
|
||||
# pop frame N cells and locals
|
||||
# current frame state
|
||||
# cells, frame_values
|
||||
# extract frame N-1 stack to stack
|
||||
cg.extend_output(
|
||||
[
|
||||
*create_copy(1),
|
||||
create_dup_top(),
|
||||
cg.create_load_const(0),
|
||||
create_instruction("DELETE_SUBSCR"),
|
||||
cg.create_binary_subscr(),
|
||||
*create_binary_slice(0, num_stack + 1),
|
||||
]
|
||||
)
|
||||
|
||||
# current frame state
|
||||
# cells, frame_values, frame N-1 stack + leaf_resume result
|
||||
# remove frame N-1 stack from frame_values
|
||||
cg.extend_output(
|
||||
# frame_values[0] = frame_values[0][num_stack + 1:]
|
||||
[
|
||||
*create_copy(2),
|
||||
cg.create_load_const(0),
|
||||
create_instruction("DELETE_SUBSCR"),
|
||||
cg.create_binary_subscr(),
|
||||
create_dup_top(),
|
||||
*create_binary_slice(num_stack + 1, None),
|
||||
*create_swap(2),
|
||||
cg.create_load_const(0),
|
||||
create_instruction("STORE_SUBSCR"),
|
||||
]
|
||||
)
|
||||
|
||||
# current frame state
|
||||
# cells, frame_values, frame N-1 stack + leaf_resume result
|
||||
# unpack the stack (need to unpack twice since UNPACK_SEQUENCE unpacks in reverse order)
|
||||
cg.extend_output(
|
||||
[
|
||||
create_instruction("UNPACK_SEQUENCE", arg=num_stack + 1),
|
||||
create_instruction("BUILD_LIST", arg=num_stack + 1),
|
||||
create_instruction("UNPACK_SEQUENCE", arg=num_stack + 1),
|
||||
]
|
||||
)
|
||||
|
||||
@ -1554,12 +1592,11 @@ class InstructionTranslatorBase(
|
||||
# current frame state
|
||||
# [frame N-1 cells, ..., frame 1 cells],
|
||||
# [
|
||||
# frame N-1 stack (including leaf_resume result) + locals,
|
||||
# frame N-1 locals,
|
||||
# frame N-2 stack + locals,
|
||||
# ...,
|
||||
# frame 1 stack + locals,
|
||||
# ],
|
||||
self.parent.push(UnknownVariable())
|
||||
all_stack_locals_metadata[1].num_stack += 1
|
||||
# ], *(frame N-1 stack), leaf_resume result
|
||||
self.output.add_output_instructions(
|
||||
cg.get_instructions()
|
||||
+ self.parent.create_call_resume_at(
|
||||
@ -2641,10 +2678,7 @@ class InstructionTranslatorBase(
|
||||
self.output.add_output_instructions([copy.copy(inst)])
|
||||
self.popn(2)
|
||||
self.output.add_output_instructions(
|
||||
self.codegen_fix_leaf_stack(
|
||||
all_stack_locals_metadata[0], self.next_instruction
|
||||
)
|
||||
+ self.create_call_resume_at(
|
||||
self.create_call_resume_at(
|
||||
self.next_instruction,
|
||||
all_stack_locals_metadata,
|
||||
)
|
||||
@ -2690,47 +2724,6 @@ class InstructionTranslatorBase(
|
||||
)
|
||||
return insts
|
||||
|
||||
def codegen_fix_leaf_stack(
|
||||
self, meta: StackLocalsMetadata, resume_inst: Instruction
|
||||
) -> list[Instruction]:
|
||||
"""
|
||||
Fixes the stack values of the current/leaf frame (self).
|
||||
|
||||
Expects the TOS to be:
|
||||
[
|
||||
frame N locals,
|
||||
frame N-1 stack + locals,
|
||||
...,
|
||||
frame 1 stack + locals
|
||||
], *(frame N stack (post-unsupported instruction))
|
||||
|
||||
Rearranges the TOS to become:
|
||||
[
|
||||
frame N stack + locals,
|
||||
...,
|
||||
frame 1 stack + locals
|
||||
]
|
||||
|
||||
Args:
|
||||
- meta: metadata for the leaf frame returned from OutputGraph.compile_subgraph
|
||||
- resume_inst: if the resume instruction is a return instruction, then don't return any instructions
|
||||
"""
|
||||
if resume_inst.opname in ("RETURN_VALUE", "RETURN_CONST"):
|
||||
return []
|
||||
# move frame N stack to the frame values list
|
||||
current_num_stack = len(self.stack) - len(meta.stack_null_idxes)
|
||||
meta.num_stack = current_num_stack
|
||||
return [
|
||||
create_instruction("BUILD_LIST", arg=current_num_stack),
|
||||
*create_copy(2),
|
||||
# frame_values, frame N stack, frame_values
|
||||
create_load_const(0),
|
||||
create_instruction("BINARY_SUBSCR"),
|
||||
*create_binary_slice(0, 0, True),
|
||||
# frame_values[0][0:0] = frame N stack
|
||||
# frame_values left on top of stack
|
||||
]
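At the Python level, the `create_binary_slice(0, 0, True)` sequence emitted above amounts to an in-place slice insertion on the first frame-values entry. The bytecode helper itself is internal; this is only an illustration of the list operation named in the comment `frame_values[0][0:0] = frame N stack`:

```python
# Illustrative values only; "loc_*" stand in for frame locals, "tos*" for stack entries.
frame_values = [["loc_a", "loc_b"], ["outer_stack", "outer_loc"]]
leaf_stack = ["tos1", "tos0"]

frame_values[0][0:0] = leaf_stack  # prepend the leaf frame's stack to its locals
assert frame_values[0] == ["tos1", "tos0", "loc_a", "loc_b"]
```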
|
||||
|
||||
def create_resume(
|
||||
self,
|
||||
idx: int,
|
||||
@ -2918,6 +2911,8 @@ class InstructionTranslatorBase(
|
||||
new_code, self.f_globals["__name__"], package_name
|
||||
)
|
||||
|
||||
counters["resumes"][new_code.co_name] += 1
|
||||
|
||||
return new_code, resume_name
|
||||
|
||||
def create_call_resume_at(
|
||||
@ -2926,20 +2921,20 @@ class InstructionTranslatorBase(
|
||||
all_stack_locals_metadata: list[StackLocalsMetadata],
|
||||
) -> list[Instruction]:
|
||||
"""
|
||||
Codegen all resume function(s) from the frame stack starting at `self` and call them.
|
||||
Codegen all resume function(s) from the frame stack starting at `self`, call them,
|
||||
and return the result.
|
||||
Assumes that the unsupported instruction has already been run.
|
||||
|
||||
Expects the stack to be in the state:
|
||||
[frame N cells, ..., frame 1 cells],
|
||||
Expects the TOS to be:
|
||||
[
|
||||
frame N stack + locals,
|
||||
frame N locals,
|
||||
frame N-1 stack + locals,
|
||||
...,
|
||||
frame 1 stack + locals
|
||||
]
|
||||
], *(frame N stack (post-unsupported instruction))
|
||||
|
||||
Pops the cells and frame values list from the stack.
|
||||
Also includes a return instruction (stack expected to be empty after return).
|
||||
Leaves the result of calling the resume functions on the stack and returns it
|
||||
(empty stack after return).
|
||||
|
||||
Args:
|
||||
- inst: the instruction of the current (deepest) frame to resume at
|
||||
@ -2949,31 +2944,141 @@ class InstructionTranslatorBase(
|
||||
|
||||
self.instruction_pointer = None
|
||||
|
||||
current_num_stack = len(self.stack) - len(
|
||||
all_stack_locals_metadata[0].stack_null_idxes
|
||||
)
|
||||
all_stack_locals_metadata[0].num_stack = current_num_stack
|
||||
|
||||
if inst.opname in ("RETURN_VALUE", "RETURN_CONST"):
|
||||
return self.codegen_return_with_pops(
|
||||
inst, all_stack_locals_metadata[0].num_stack
|
||||
)
|
||||
|
||||
cg = PyCodegen(self.output.root_tx)
|
||||
|
||||
# NOTE: We do not need to codegen frames whose resume instruction is RETURN_VALUE
|
||||
# We could also do something similar for RETURN_CONST, but a lot more code is necessary
|
||||
# since we would need to track RETURN_CONST values and inject the constant in the right places.
|
||||
|
||||
# Filter out tx'es that are resuming on RETURN_*.
|
||||
txes: list[InstructionTranslatorBase] = []
|
||||
idxes: list[int] = []
|
||||
resume_insts: list[Instruction] = []
|
||||
cur_tx: Optional[InstructionTranslatorBase] = self
|
||||
idx = 0
|
||||
resume_codes: list[types.CodeType] = []
|
||||
resume_names = []
|
||||
while cur_tx is not None:
|
||||
if cur_tx is self:
|
||||
resume_inst = inst
|
||||
else:
|
||||
resume_inst = cur_tx.next_instruction
|
||||
if (
|
||||
not (
|
||||
config.debug_force_graph_break_on_leaf_return
|
||||
and self.current_instruction.opname == "NOP"
|
||||
and self.current_instruction.argval == "GRAPH_BREAK_IF_LEAF"
|
||||
and cur_tx is self
|
||||
)
|
||||
and resume_inst.opname != "RETURN_VALUE"
|
||||
):
|
||||
txes.append(cur_tx)
|
||||
idxes.append(idx)
|
||||
resume_insts.append(resume_inst)
|
||||
|
||||
cur_tx = cur_tx.parent
|
||||
idx += 1
|
||||
|
||||
current_num_stack = len(self.stack) - len(
|
||||
all_stack_locals_metadata[0].stack_null_idxes
|
||||
)
|
||||
|
||||
# Every tx is returning - no need to call a resume function.
|
||||
if not txes:
|
||||
# Pop everything but TOS, then return the TOS.
|
||||
# Frame N's stack must have length >= 1 since it's about to RETURN_VALUE.
|
||||
# Frame N's stack should in fact have length == 1, because debug builds of CPython
# expect an empty stack after a return, but there is no written guarantee of this.
|
||||
assert current_num_stack >= 1
|
||||
cg.extend_output(create_swap(current_num_stack + 2))
|
||||
for _ in range(current_num_stack + 1):
|
||||
cg.append_output(create_instruction("POP_TOP"))
|
||||
cg.append_output(create_instruction("RETURN_VALUE"))
|
||||
|
||||
return cg.get_instructions()
|
||||
|
||||
# Let frame k be the deepest frame whose resume instruction is not RETURN_VALUE
|
||||
# - If k == N, then the frame N stack is prepended to the frame N locals.
|
||||
# - If k != N, then frame N's TOS is added to frame k's stack.
|
||||
|
||||
# Rearrange the TOS to be compatible with create_resume and codegen_call_resume:
|
||||
# [
|
||||
# frame N stack + locals,
|
||||
# ...,
|
||||
# frame 1 stack + locals
|
||||
# ]
|
||||
|
||||
# create the stack values that should be moved
|
||||
if txes[0] is self:
|
||||
# Frame N is non-returning, pack all of frame N's stack to
|
||||
# be moved to frame N's frame values
|
||||
cg.append_output(create_instruction("BUILD_LIST", arg=current_num_stack))
|
||||
# frame N stack is not yet on the frame N's frame values
|
||||
stack_insert_idx = 0
|
||||
all_stack_locals_metadata[0].num_stack = current_num_stack
|
||||
else:
|
||||
# Frame N is returning. Let frame k be the deepest non-returning frame.
|
||||
# Add frame N's TOS to frame k's stack.
|
||||
# pop frame N stack except TOS
|
||||
cg.extend_output(create_swap(current_num_stack))
|
||||
for _ in range(current_num_stack - 1):
|
||||
cg.append_output(create_instruction("POP_TOP"))
|
||||
cg.append_output(create_instruction("BUILD_LIST", arg=1))
|
||||
# frame k stack is already on frame k's frame values
|
||||
stack_insert_idx = all_stack_locals_metadata[idxes[0]].num_stack
|
||||
all_stack_locals_metadata[idxes[0]].num_stack += 1
|
||||
txes[0].push(UnknownVariable())
|
||||
|
||||
# move the predetermined stack value(s) to the deepest non-returning frame
|
||||
cg.extend_output(
|
||||
[
|
||||
*create_copy(2),
|
||||
# frame_values, return_const, frame_values
|
||||
cg.create_load_const(idxes[0]),
|
||||
cg.create_binary_subscr(),
|
||||
*create_binary_slice(stack_insert_idx, stack_insert_idx, True),
|
||||
# frame_values[idxes[0]][stack_insert_idx:stack_insert_idx] = frame N stack/[return_const/TOS]
|
||||
# frame_values left on top of stack
|
||||
]
|
||||
)
|
||||
|
||||
# filter out frame values of skipped tx'es
|
||||
filter_insts = []
|
||||
for idx in idxes:
|
||||
filter_insts.extend(
|
||||
[
|
||||
create_dup_top(),
|
||||
cg.create_load_const(idx),
|
||||
cg.create_binary_subscr(),
|
||||
*create_swap(2),
|
||||
]
|
||||
)
|
||||
# TOS: cells, frame_values[idxes[0]], ..., frame_values[idxes[...]], frame_values
|
||||
filter_insts.extend(
|
||||
[
|
||||
create_instruction("POP_TOP"),
|
||||
create_instruction("BUILD_LIST", arg=len(idxes)),
|
||||
]
|
||||
)
|
||||
# TOS: cells, filtered frame_values
|
||||
|
||||
cg.extend_output(filter_insts)
|
||||
# filter out cells of skipped tx'es using the same instructions in filter_insts,
|
||||
# but with cells as TOS instead of frame values
|
||||
cg.extend_output(
|
||||
[
|
||||
*create_swap(2),
|
||||
*copy.deepcopy(filter_insts),
|
||||
*create_swap(2),
|
||||
]
|
||||
)
|
||||
# TOS: filtered cells, filtered frame_values
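The filtering bytecode built above is, at the Python level, just an index-based selection over the frame-values (and cells) lists. A plain-Python illustration of the effect; the real code operates on the interpreter stack rather than on named lists:

```python
idxes = [0, 2]                          # indices of the non-returning frames (example values)
frame_values = ["frame3", "frame2", "frame1"]

filtered = [frame_values[i] for i in idxes]
assert filtered == ["frame3", "frame1"]
```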
|
||||
|
||||
resume_codes: list[types.CodeType] = []
|
||||
resume_names = []
|
||||
for i, cur_tx in enumerate(txes):
|
||||
resume_code, resume_name = cur_tx.create_resume(
|
||||
idx,
|
||||
resume_inst,
|
||||
all_stack_locals_metadata[idx],
|
||||
i,
|
||||
resume_insts[i],
|
||||
all_stack_locals_metadata[idxes[i]],
|
||||
resume_codes,
|
||||
cg,
|
||||
cur_tx is self,
|
||||
@ -2982,11 +3087,17 @@ class InstructionTranslatorBase(
|
||||
resume_codes.append(resume_code)
|
||||
resume_names.append(resume_name)
|
||||
|
||||
cur_tx = cur_tx.parent
|
||||
idx += 1
|
||||
if (
|
||||
config.debug_force_graph_break_on_leaf_return
|
||||
and self.current_instruction.opname == "NOP"
|
||||
and self.current_instruction.argval == "GRAPH_BREAK_IF_LEAF"
|
||||
):
|
||||
_debug_force_graph_break_on_leaf_return_disable_codes.add(resume_codes[0])
|
||||
|
||||
self.codegen_call_resume(resume_codes, resume_names, cg)
|
||||
return cg.get_instructions() + [create_instruction("RETURN_VALUE")]
|
||||
cg.append_output(create_instruction("RETURN_VALUE"))
|
||||
|
||||
return cg.get_instructions()
|
||||
|
||||
@staticmethod
|
||||
def codegen_call_resume(
|
||||
@ -3119,9 +3230,14 @@ class InstructionTranslatorBase(
|
||||
return (
|
||||
all(b.can_restore() for b in self.block_stack)
|
||||
and not self.one_graph
|
||||
and not self.error_on_graph_break
|
||||
and (
|
||||
not self.error_on_graph_break
|
||||
or config.debug_force_graph_break_on_leaf_return
|
||||
)
|
||||
and not self.is_tracing_resume_prologue
|
||||
and not self.active_generic_context_managers
|
||||
# Do not allow nested graph breaks in HOPs
|
||||
and self.output.current_tracer.parent is None
|
||||
)
|
||||
|
||||
@break_graph_if_unsupported(push=0)
|
||||
@ -3361,7 +3477,10 @@ class InstructionTranslatorBase(
|
||||
|
||||
def NOP(self, inst: Instruction) -> None:
|
||||
# Dynamo-specific testing behavior
|
||||
if inst.argval == "GRAPH_BREAK_IF_LEAF":
|
||||
if (
|
||||
self.f_code not in _debug_force_graph_break_on_leaf_return_disable_codes
|
||||
and inst.argval == "GRAPH_BREAK_IF_LEAF"
|
||||
):
|
||||
self.graph_break_on_leaf_function(inst)
|
||||
|
||||
def POP_TOP(self, inst: Instruction) -> None:
|
||||
@ -4480,9 +4599,8 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
|
||||
@classmethod
|
||||
def inline_call(cls, parent: Any, func: Any, args: Any, kwargs: Any) -> Any:
|
||||
with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
|
||||
tracer = cls.build_inline_tracer(parent, func, args, kwargs)
|
||||
return tracer.inline_call_()
|
||||
tracer = cls.build_inline_tracer(parent, func, args, kwargs)
|
||||
return tracer.inline_call_()
|
||||
|
||||
@staticmethod
|
||||
def check_inlineable(func: Any) -> trace_rules.SkipResult:
|
||||
@ -4941,6 +5059,10 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
|
||||
self.generator_exhausted = False
|
||||
self.is_generator_from_ctx_manager = False
|
||||
|
||||
def should_compile_partial_graph(self) -> bool:
|
||||
# resuming from a graph break inside an inlined generator is not supported
|
||||
return False
|
||||
|
||||
def YIELD_VALUE(self, inst: Instruction) -> None:
|
||||
top = self.pop()
|
||||
self.generated_items.append(top)
|
||||
|
@ -102,16 +102,36 @@ class TestCase(TorchTestCase):
|
||||
torch.set_grad_enabled(self._prior_is_grad_enabled)
|
||||
|
||||
def assertEqual(self, x: Any, y: Any, *args: Any, **kwargs: Any) -> None: # type: ignore[override]
|
||||
if (
|
||||
config.debug_disable_compile_counter
|
||||
and isinstance(x, utils.CompileCounterInt)
|
||||
or isinstance(y, utils.CompileCounterInt)
|
||||
):
|
||||
return
|
||||
if config.debug_disable_compile_counter:
|
||||
if isinstance(x, utils.CompileCounterInt) or isinstance(
|
||||
y, utils.CompileCounterInt
|
||||
):
|
||||
return
|
||||
# skip checks like self.assertEqual(len(counters["graph_break"]), 1)
|
||||
if (
|
||||
(cur_frame := inspect.currentframe())
|
||||
and (upper_frame := cur_frame.f_back)
|
||||
and (upper_code := inspect.getframeinfo(upper_frame).code_context)
|
||||
and "counters" in upper_code[0]
|
||||
):
|
||||
return
|
||||
return super().assertEqual(x, y, *args, **kwargs)
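The caller check above relies on `inspect` being able to see the calling line's source; `code_context` can be `None` when source is unavailable (for example in a REPL), which is why the chained condition above guards each step. A standalone sketch of the mechanism:

```python
import inspect

def caller_source_line() -> str:
    frame = inspect.currentframe().f_back          # the caller's frame (CPython)
    info = inspect.getframeinfo(frame)             # includes one line of code_context
    return info.code_context[0] if info.code_context else ""

line = caller_source_line()  # when run from a file, this is the calling source line
assert not line or "caller_source_line()" in line
```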
|
||||
|
||||
# assertExpectedInline might also need to be disabled for wrapped nested
|
||||
# graph break tests
|
||||
def assertExpectedInline(self, *args: Any, **kwargs: Any) -> None: # type: ignore[override]
|
||||
if config.debug_disable_compile_counter:
|
||||
return
|
||||
return super().assertExpectedInline(*args, **kwargs)
|
||||
|
||||
|
||||
class TestCaseWithNestedGraphBreaks(TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.prev_nested_graph_breaks = torch._dynamo.config.nested_graph_breaks
|
||||
torch._dynamo.config.nested_graph_breaks = True
|
||||
|
||||
def tearDown(self) -> None:
|
||||
super().tearDown()
|
||||
torch._dynamo.config.nested_graph_breaks = self.prev_nested_graph_breaks
|
||||
|
||||
|
||||
class CPythonTestCase(TestCase):
|
||||
|
@ -33,7 +33,7 @@ from torch import fx
|
||||
from torch._dynamo.backends.debugging import aot_eager
|
||||
from torch._dynamo.output_graph import OutputGraph
|
||||
|
||||
from . import config, eval_frame, optimize_assert, reset
|
||||
from . import config, eval_frame, optimize, reset
|
||||
from .bytecode_transformation import (
|
||||
create_instruction,
|
||||
debug_checks,
|
||||
@ -379,7 +379,7 @@ def standard_test(
|
||||
correct1 = fn(*args1)
|
||||
correct2 = fn(*args2)
|
||||
reset()
|
||||
opt_fn = optimize_assert(actual)(fn)
|
||||
opt_fn = optimize(actual, error_on_graph_break=True)(fn)
|
||||
val1a = opt_fn(*args1)
|
||||
val2a = opt_fn(*args2)
|
||||
val1b = opt_fn(*args1)
|
||||
|
@ -35,7 +35,6 @@ from collections.abc import Sequence
|
||||
from types import FunctionType
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
|
||||
from typing_extensions import Never
|
||||
from unittest.mock import patch
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
import torch
|
||||
@ -69,7 +68,6 @@ from ..utils import (
|
||||
check_constant_args,
|
||||
check_unspec_or_constant_args,
|
||||
cmp_name_to_op_mapping,
|
||||
counters,
|
||||
identity,
|
||||
is_function,
|
||||
is_wrapper_or_member_descriptor,
|
||||
@ -712,8 +710,7 @@ class LocalGeneratorObjectVariable(VariableTracker):
|
||||
# Hierarchically, tx can be seen as the parent of the inline tracer
|
||||
# created on call_function. Any exception needs to be propagated to tx
|
||||
# for Dynamo to behave correctly
|
||||
with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
|
||||
return tracer.inline_call_()
|
||||
return tracer.inline_call_()
|
||||
except ObservedException as e:
|
||||
tracer.generator_exhausted = True
|
||||
raise e
|
||||
@ -723,8 +720,6 @@ class LocalGeneratorObjectVariable(VariableTracker):
|
||||
except Unsupported as e:
|
||||
torch._dynamo.eval_frame.skip_code(self.get_code())
|
||||
raise SkipFrame from e
|
||||
finally:
|
||||
counters["unimplemented"] |= counters["inline_call"]
|
||||
|
||||
def call_obj_hasattr(self, tx, name):
|
||||
if name in self.python_type().__dict__:
|
||||
|
@ -79,6 +79,23 @@ class TORCH_API Backend : public torch::CustomClassHolder {
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual bool supportsShrinking() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Shrink the backend by excluding specified ranks. Backends that support
|
||||
// communicator shrinking should override this and return a new backend
|
||||
// instance representing the shrunken group. Backends may use opts_override
|
||||
// to supply backend-specific options for the new group.
|
||||
virtual c10::intrusive_ptr<Backend> shrink(
|
||||
const std::vector<int64_t>& /*ranks_to_exclude*/,
|
||||
int /*shrink_flags*/ = 0,
|
||||
const c10::intrusive_ptr<Options>& /*opts_override*/ = nullptr) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
c10::str("Backend ", getBackendName(), " does not support shrink"));
|
||||
}
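For orientation, a sketch of the Python surface that ends up exposed for this hook. Property and method names are taken from the pybind11 registration later in this diff, and the `_get_backend` call mirrors how `shrink_group` locates the backend below; treat this as a sketch rather than a public API:

```python
import torch
import torch.distributed as dist

# After dist.init_process_group(...) has been called:
pg = dist.group.WORLD                                   # any initialized ProcessGroup
backend = pg._get_backend(torch.device("cuda"))
if backend.supports_shrinking:                          # False unless the backend overrides shrink()
    # Exclude rank 1; 0 is the default shrink flag (SHRINK_DEFAULT in the Python wrapper).
    new_backend = backend.shrink([1], 0)
```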
|
||||
|
||||
virtual void setTimeout(std::chrono::milliseconds timeout) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
|
@ -259,6 +259,65 @@ std::shared_ptr<NCCLComm> NCCLComm::split(
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef NCCL_HAS_COMM_SHRINK
|
||||
std::shared_ptr<NCCLComm> NCCLComm::shrink(
|
||||
NCCLComm* source,
|
||||
std::vector<int>& ranks_to_exclude,
|
||||
ncclConfig_t* config,
|
||||
int shrinkFlags) {
|
||||
// Preconditions are validated in ProcessGroupNCCL::shrink
|
||||
|
||||
LOG(INFO) << "Rank " << source->rank_ << ": shrinking comm " << source->repr()
|
||||
<< " excluding " << ranks_to_exclude.size() << " ranks";
|
||||
|
||||
at::cuda::OptionalCUDAGuard gpuGuard(source->deviceIndex_);
|
||||
auto comm = std::make_shared<NCCLComm>();
|
||||
|
||||
// This call will block until the source communicator is initialized
|
||||
auto sourceComm = source->getNcclComm();
|
||||
|
||||
C10D_NCCL_CHECK_NONBLOCKING(
|
||||
ncclCommShrink(
|
||||
sourceComm,
|
||||
ranks_to_exclude.data(),
|
||||
ranks_to_exclude.size(),
|
||||
reinterpret_cast<ncclComm_t*>(&(comm->ncclComm_)),
|
||||
config,
|
||||
shrinkFlags),
|
||||
source->getNcclCommFailureReason());
|
||||
|
||||
// Wait for the child communicator to be ready
|
||||
source->waitReady(true);
|
||||
comm->initialized_ = true;
|
||||
|
||||
// NCCL automatically assigns rank during shrink - query it efficiently
|
||||
int assigned_rank;
|
||||
try {
|
||||
C10D_NCCL_CHECK(
|
||||
ncclCommUserRank(comm->ncclComm_, &assigned_rank), std::nullopt);
|
||||
comm->rank_ = assigned_rank;
|
||||
} catch (const std::exception& e) {
|
||||
// Fallback: if ncclCommUserRank fails, we can't determine the rank
|
||||
LOG(ERROR) << "Failed to query NCCL-assigned rank: " << e.what();
|
||||
throw;
|
||||
}
|
||||
|
||||
// Child comm should be on the same device as parent comm
|
||||
comm->deviceIndex_ = source->deviceIndex_;
|
||||
if (config != nullptr) {
|
||||
comm->nonBlocking_ = config->blocking == 0;
|
||||
} else {
|
||||
// Inherit parent behavior if no config provided
|
||||
comm->nonBlocking_ = source->nonBlocking_;
|
||||
}
|
||||
|
||||
LOG(INFO) << "Rank " << source->rank_ << ": created shrunken comm "
|
||||
<< comm->repr() << " with NCCL-assigned rank " << assigned_rank;
|
||||
|
||||
return comm;
|
||||
}
|
||||
#endif
|
||||
|
||||
void NCCLComm::finalize() {
|
||||
LockType lock(mutex_);
|
||||
if (aborted_) {
|
||||
|
@ -90,6 +90,10 @@ static_assert(
|
||||
#define NCCL_HAS_NVLS_CTAS
|
||||
#endif
|
||||
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0)
|
||||
#define NCCL_HAS_COMM_SHRINK
|
||||
#endif
|
||||
|
||||
// Macro to throw on a non-successful NCCL return value.
|
||||
#define C10D_NCCL_CHECK(cmd, failureReason) \
|
||||
do { \
|
||||
@ -294,6 +298,14 @@ class NCCLComm {
|
||||
ncclConfig_t& config);
|
||||
#endif // NCCL_HAS_COMM_SPLIT
|
||||
|
||||
#ifdef NCCL_HAS_COMM_SHRINK
|
||||
static std::shared_ptr<NCCLComm> shrink(
|
||||
NCCLComm* source,
|
||||
std::vector<int>& ranks_to_exclude,
|
||||
ncclConfig_t* config,
|
||||
int shrinkFlags = 0);
|
||||
#endif // NCCL_HAS_COMM_SHRINK
|
||||
|
||||
#if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP)
|
||||
std::unordered_map<std::string, std::string> ncclCommDump();
|
||||
#endif
|
||||
|
@ -165,7 +165,7 @@ ncclRedOpRAII getNcclReduceOp(
|
||||
}
|
||||
|
||||
// Get a key string from device
|
||||
inline std::string getKeyFromDevice(at::Device& device) {
|
||||
inline std::string getKeyFromDevice(const at::Device& device) {
|
||||
return std::to_string(device.index());
|
||||
}
|
||||
|
||||
@ -5838,6 +5838,139 @@ at::Tensor ProcessGroupNCCL::allocateTensor(
|
||||
return tensor;
|
||||
}
|
||||
|
||||
#ifdef NCCL_HAS_COMM_SHRINK
|
||||
c10::intrusive_ptr<Backend> ProcessGroupNCCL::shrink(
|
||||
const std::vector<int64_t>& ranks_to_exclude,
|
||||
int shrink_flags,
|
||||
const c10::intrusive_ptr<Backend::Options>& opts_override) {
|
||||
// Runtime version check with better error message
|
||||
auto runtime_version = torch::cuda::nccl::version();
|
||||
TORCH_CHECK(
|
||||
runtime_version >= NCCL_VERSION(2, 27, 0),
|
||||
"ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later. "
|
||||
"Found version: ",
|
||||
runtime_version);
|
||||
|
||||
// Early validation with detailed error messages
|
||||
TORCH_CHECK_VALUE(
|
||||
!ranks_to_exclude.empty(), "ranks_to_exclude cannot be empty");
|
||||
TORCH_CHECK_VALUE(
|
||||
static_cast<int>(ranks_to_exclude.size()) < size_,
|
||||
"Cannot exclude all ranks (",
|
||||
ranks_to_exclude.size(),
|
||||
" >= ",
|
||||
size_,
|
||||
")");
|
||||
|
||||
// Validate ranks and convert to int efficiently
|
||||
std::vector<int> int_ranks_to_exclude;
|
||||
int_ranks_to_exclude.reserve(ranks_to_exclude.size());
|
||||
for (int64_t rank : ranks_to_exclude) {
|
||||
TORCH_CHECK_VALUE(
|
||||
rank >= 0 && rank < size_,
|
||||
"Invalid rank ",
|
||||
rank,
|
||||
" for group size ",
|
||||
size_);
|
||||
int_ranks_to_exclude.push_back(static_cast<int>(rank));
|
||||
}
|
||||
|
||||
// Get primary communicator with better error context
|
||||
auto primary_device_index = guessDeviceId();
|
||||
auto primary_device = at::Device(at::kCUDA, primary_device_index);
|
||||
const auto primary_key = getKeyFromDevice(primary_device);
|
||||
|
||||
std::shared_ptr<NCCLComm> primary_comm = getNCCLComm(primary_key);
|
||||
TORCH_CHECK(
|
||||
primary_comm,
|
||||
"Primary NCCL communicator for device ",
|
||||
primary_device,
|
||||
" (key: ",
|
||||
primary_key,
|
||||
") is not initialized");
|
||||
|
||||
// Cache device index before shrink operation
|
||||
at::DeviceIndex parent_device_index = primary_comm->getDeviceIndex();
|
||||
|
||||
ncclConfig_t* config = nullptr;
|
||||
// Default to inheriting from parent options
|
||||
bool high_priority_stream = options_->is_high_priority_stream;
|
||||
if (opts_override) {
|
||||
auto nccl_opts =
|
||||
c10::static_intrusive_pointer_cast<ProcessGroupNCCL::Options>(
|
||||
opts_override);
|
||||
config = &nccl_opts->config;
|
||||
// If user provided override options, honor is_high_priority_stream as well
|
||||
high_priority_stream = nccl_opts->is_high_priority_stream;
|
||||
}
|
||||
|
||||
std::shared_ptr<NCCLComm> shrunk_comm = NCCLComm::shrink(
|
||||
primary_comm.get(),
|
||||
int_ranks_to_exclude,
|
||||
(config != nullptr ? config : &options_->config),
|
||||
shrink_flags);
|
||||
|
||||
// Calculate new size and get NCCL-assigned rank
|
||||
int new_size = size_ - static_cast<int>(ranks_to_exclude.size());
|
||||
int new_rank = shrunk_comm->rank_;
|
||||
|
||||
// Create new ProcessGroupNCCL with optimized options cloning
|
||||
auto new_store = store_->clone();
|
||||
auto new_opts = ProcessGroupNCCL::Options::create(high_priority_stream);
|
||||
new_opts->timeout = options_->timeout;
|
||||
if (config != nullptr) {
|
||||
new_opts->config = *config;
|
||||
} else {
|
||||
new_opts->config = options_->config;
|
||||
}
|
||||
|
||||
auto new_pg = c10::make_intrusive<ProcessGroupNCCL>(
|
||||
new_store, new_rank, new_size, new_opts);
|
||||
|
||||
// Set up the new process group with optimized device setup
|
||||
new_pg->initializeDeviceStateForComm(
|
||||
at::Device(at::kCUDA, parent_device_index), shrunk_comm);
|
||||
|
||||
return c10::static_intrusive_pointer_cast<Backend>(new_pg);
|
||||
}
|
||||
|
||||
#else // !NCCL_HAS_COMM_SHRINK
|
||||
// Backend interface override: raise consistent error when shrink is
|
||||
// unsupported.
|
||||
c10::intrusive_ptr<Backend> ProcessGroupNCCL::shrink(
|
||||
const std::vector<int64_t>& /*ranks_to_exclude*/,
|
||||
int /*shrink_flags*/,
|
||||
const c10::intrusive_ptr<Backend::Options>& /*opts_override*/) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later, "
|
||||
"but PyTorch was built with an older version or without NCCL shrink support.");
|
||||
}
|
||||
|
||||
#endif // NCCL_HAS_COMM_SHRINK
|
||||
|
||||
void ProcessGroupNCCL::initializeDeviceStateForComm(
|
||||
const at::Device& device,
|
||||
std::shared_ptr<NCCLComm> comm) {
|
||||
const auto key = getKeyFromDevice(device);
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
at::cuda::OptionalCUDAGuard gpuGuard(device);
|
||||
|
||||
bool force_high = getCvarBool(TORCH_NCCL_HIGH_PRIORITY, false);
|
||||
auto stream = at::cuda::getStreamFromPool(
|
||||
options_->is_high_priority_stream || force_high);
|
||||
|
||||
devNCCLCommMap_[key] = comm;
|
||||
ncclStreams_.emplace(key, stream);
|
||||
ncclEvents_.emplace(key, at::cuda::CUDAEvent(cudaEventDisableTiming));
|
||||
usedDeviceIdxs_.insert(device.index());
|
||||
|
||||
if (shouldAllCommunicatorsRegisterAllTensors()) {
|
||||
std::lock_guard<std::mutex> map_lock(ncclCommMemPoolMapMutex);
|
||||
ncclCommMemPoolMap.emplace(std::move(comm), MemPoolSet{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace c10d
|
||||
|
||||
#endif // USE_C10D_NCCL
|
||||
|
@ -997,6 +997,21 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
|
||||
ErrorType getError() override;
|
||||
|
||||
bool supportsShrinking() const override {
|
||||
#ifdef NCCL_HAS_COMM_SHRINK
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Backend-style shrink override that returns a Backend instance.
|
||||
c10::intrusive_ptr<Backend> shrink(
|
||||
const std::vector<int64_t>& ranks_to_exclude,
|
||||
int shrink_flags = 0,
|
||||
const c10::intrusive_ptr<Backend::Options>& opts_override =
|
||||
nullptr) override;
|
||||
|
||||
std::shared_ptr<c10::Allocator> getMemAllocator() override;
|
||||
|
||||
// Allocate tensor from communication-optimized memory pool
|
||||
@ -1065,6 +1080,12 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
int p2pRank = 0,
|
||||
bool isSendRecvSelf = false);
|
||||
|
||||
// Initialize device-specific state (comm, stream, event, bookkeeping) for a
|
||||
// given communicator on this process group instance.
|
||||
void initializeDeviceStateForComm(
|
||||
const at::Device& device,
|
||||
std::shared_ptr<NCCLComm> comm);
|
||||
|
||||
// Wrapper method which can be overridden for tests.
|
||||
virtual std::exception_ptr checkForNCCLErrors(
|
||||
std::shared_ptr<NCCLComm>& ncclComm);
|
||||
|
@ -2734,12 +2734,23 @@ Arguments:
|
||||
"supports_time_estimate",
|
||||
&::c10d::Backend::supportsTimeEstimation,
|
||||
"(test whether the backend supports collective time estimation)")
|
||||
.def_property_readonly(
|
||||
"supports_shrinking",
|
||||
&::c10d::Backend::supportsShrinking,
|
||||
"(test whether the backend supports communicator shrinking)")
|
||||
.def(
|
||||
"set_timeout",
|
||||
&::c10d::Backend::setTimeout,
|
||||
py::arg("timeout"),
|
||||
py::call_guard<py::gil_scoped_release>(),
|
||||
R"(Sets the default timeout for all future operations.)")
|
||||
.def(
|
||||
"shrink",
|
||||
&::c10d::Backend::shrink,
|
||||
py::arg("ranks_to_exclude"),
|
||||
py::arg("shrink_flags") = 0,
|
||||
py::arg("opts_override") = nullptr,
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def(
|
||||
"broadcast",
|
||||
&::c10d::Backend::broadcast,
|
||||
|
@ -1,11 +1,15 @@
|
||||
#include <torch/csrc/fx/node.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/SmallVector.h>
|
||||
#include <structmember.h>
|
||||
#include <torch/csrc/utils/object_ptr.h>
|
||||
#include <torch/csrc/utils/pythoncapi_compat.h>
|
||||
#include <algorithm>
|
||||
|
||||
namespace {
|
||||
|
||||
using NodeSortKey = c10::SmallVector<int64_t, 4>;
|
||||
struct NodeBase;
|
||||
|
||||
// Thrown to exit out of a C++ function and return an error to Python.
|
||||
@ -163,7 +167,41 @@ struct NodeBase {
|
||||
PyObject* users;
|
||||
PyObject* _repr_fn;
|
||||
PyObject* meta;
|
||||
PyObject* _sort_key;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||
alignas(NodeSortKey) char sort_key_buf[sizeof(NodeSortKey)];
|
||||
|
||||
inline NodeSortKey& sort_key() {
|
||||
return *reinterpret_cast<NodeSortKey*>(sort_key_buf);
|
||||
}
|
||||
|
||||
inline void set_prev(NodeBase* value) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value);
|
||||
Py_INCREF(reinterpret_cast<PyObject*>(value));
|
||||
NodeBase* old = _prev;
|
||||
_prev = value;
|
||||
Py_DECREF(reinterpret_cast<PyObject*>(old));
|
||||
}
|
||||
|
||||
inline void set_next(NodeBase* value) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value);
|
||||
Py_INCREF(reinterpret_cast<PyObject*>(value));
|
||||
NodeBase* old = _next;
|
||||
_next = value;
|
||||
Py_DECREF(reinterpret_cast<PyObject*>(old));
|
||||
}
|
||||
|
||||
// Equivalent to:
|
||||
// p, n = self._prev, self._next
|
||||
// p._next, n._prev = n, p
|
||||
inline void remove_from_list() {
|
||||
if (this->_prev == this && this->_next == this) {
|
||||
return;
|
||||
}
|
||||
NodeBase* p = this->_prev;
|
||||
NodeBase* n = this->_next;
|
||||
p->set_next(n);
|
||||
n->set_prev(p);
|
||||
}
|
||||
};
|
||||
|
||||
static PyObject* NodeBase_new(
|
||||
@ -173,6 +211,8 @@ static PyObject* NodeBase_new(
|
||||
PyObject* self = type->tp_alloc(type, 0);
|
||||
if (!self)
|
||||
return nullptr;
|
||||
new (reinterpret_cast<NodeBase*>(self)->sort_key_buf)
|
||||
NodeSortKey(); // placement new does not allocate
|
||||
return self;
|
||||
}
|
||||
|
||||
@ -201,7 +241,6 @@ static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
|
||||
self->users = PyDict_New();
|
||||
self->_repr_fn = Py_NewRef(Py_None);
|
||||
self->meta = PyDict_New();
|
||||
self->_sort_key = PyTuple_New(0);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -221,7 +260,6 @@ static struct PyMemberDef NodeBase_members[] = {
|
||||
{"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr},
|
||||
{"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr},
|
||||
{"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr},
|
||||
{"_sort_key", T_OBJECT_EX, offsetof(NodeBase, _sort_key), 0, nullptr},
|
||||
{nullptr} /* Sentinel */
|
||||
};
|
||||
|
||||
@ -239,7 +277,6 @@ static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
|
||||
Py_VISIT(self->users);
|
||||
Py_VISIT(self->_repr_fn);
|
||||
Py_VISIT(self->meta);
|
||||
Py_VISIT(self->_sort_key);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -257,12 +294,12 @@ static int NodeBase_clear(NodeBase* self) {
|
||||
Py_CLEAR(self->users);
|
||||
Py_CLEAR(self->_repr_fn);
|
||||
Py_CLEAR(self->meta);
|
||||
Py_CLEAR(self->_sort_key);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void NodeBase_dealloc(PyObject* self) {
|
||||
PyObject_GC_UnTrack(self);
|
||||
reinterpret_cast<NodeBase*>(self)->sort_key().~NodeSortKey();
|
||||
(void)NodeBase_clear((NodeBase*)self);
|
||||
Py_TYPE(self)->tp_free(self);
|
||||
}
|
||||
@ -321,15 +358,195 @@ static PyObject* NodeBase__update_args_kwargs(
|
||||
}
|
||||
}
|
||||
|
||||
static PyObject* NodeBase__remove_from_list(
|
||||
PyObject* self,
|
||||
PyObject* _ignored) {
|
||||
reinterpret_cast<NodeBase*>(self)->remove_from_list();
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) {
|
||||
if (self_ == arg) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
if (!is_node(arg)) {
|
||||
PyErr_SetString(PyExc_TypeError, "_prepend() argument must be a Node");
|
||||
return nullptr;
|
||||
}
|
||||
NodeBase* self = reinterpret_cast<NodeBase*>(self_);
|
||||
NodeBase* x = reinterpret_cast<NodeBase*>(arg);
|
||||
if (self->graph != x->graph) {
|
||||
PyErr_SetString(
|
||||
PyExc_AssertionError,
|
||||
"Attempting to move a Node into a different Graph");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
x->remove_from_list();
|
||||
NodeBase* p = self->_prev;
|
||||
p->set_next(x);
|
||||
x->set_prev(p);
|
||||
x->set_next(self);
|
||||
self->set_prev(x);
|
||||
|
||||
// Now compute x.sort_key()
|
||||
const NodeSortKey& psk = x->_prev->sort_key();
|
||||
const NodeSortKey& nsk = x->_next->sort_key();
|
||||
if (psk.size() > nsk.size()) {
|
||||
// prefix = psk[: len(nsk)+1]
|
||||
size_t slice_len = nsk.size() + 1;
|
||||
NodeSortKey prefix(psk.begin(), psk.begin() + slice_len);
|
||||
// last element is idx => increment by 1
|
||||
prefix.back()++;
|
||||
x->sort_key() = std::move(prefix);
|
||||
} else if (psk.size() < nsk.size()) {
|
||||
// prefix = nsk[: len(psk)+1]
|
||||
size_t slice_len = psk.size() + 1;
|
||||
NodeSortKey prefix(nsk.begin(), nsk.begin() + slice_len);
|
||||
// last element is idx => decrement by 1
|
||||
prefix.back()--;
|
||||
x->sort_key() = std::move(prefix);
|
||||
} else {
|
||||
// same length => add a 0
|
||||
x->sort_key() = psk;
|
||||
x->sort_key().emplace_back(0);
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
}
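The sort-key arithmetic above keeps nodes lexicographically ordered without renumbering the whole list. A pure-Python rendering of the same three cases, as an illustration of the scheme rather than the C++ code path:

```python
def key_between(prev_key: tuple, next_key: tuple) -> tuple:
    """Pick a key between prev_key and next_key under the graph's ordering invariants."""
    if len(prev_key) > len(next_key):
        k = list(prev_key[: len(next_key) + 1])
        k[-1] += 1                      # bump the last shared-depth index
    elif len(prev_key) < len(next_key):
        k = list(next_key[: len(prev_key) + 1])
        k[-1] -= 1                      # step back just below next_key
    else:
        k = list(prev_key) + [0]        # equal length: descend one level
    return tuple(k)

assert (0,) < key_between((0,), (1,)) < (1,)
assert (0, 0) < key_between((0, 0), (1,)) < (1,)
assert (0,) < key_between((0,), (0, 1)) < (0, 1)
```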
|
||||
|
||||
// __lt__(self, other): Return self.sort_key < other.sort_key
|
||||
static PyObject* NodeBase___lt__(PyObject* self, PyObject* other) {
|
||||
// METH_O => one argument: 'other'
|
||||
if (!is_node(other)) {
|
||||
Py_RETURN_NOTIMPLEMENTED;
|
||||
}
|
||||
const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
|
||||
const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
|
||||
bool less = std::lexicographical_compare(
|
||||
lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
|
||||
if (less)
|
||||
Py_RETURN_TRUE;
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
|
||||
// __gt__(self, other): Return self.sort_key > other.sort_key
|
||||
static PyObject* NodeBase___gt__(PyObject* self, PyObject* other) {
|
||||
if (!is_node(other)) {
|
||||
Py_RETURN_NOTIMPLEMENTED;
|
||||
}
|
||||
const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
|
||||
const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
|
||||
// "a > b" is equivalent to "b < a"
|
||||
bool greater = std::lexicographical_compare(
|
||||
rhs.begin(), rhs.end(), lhs.begin(), lhs.end());
|
||||
if (greater)
|
||||
Py_RETURN_TRUE;
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
|
||||
static PyObject* NodeBase___ge__(PyObject* self, PyObject* other) {
|
||||
if (self == other) {
|
||||
Py_RETURN_TRUE;
|
||||
}
|
||||
return NodeBase___gt__(self, other);
|
||||
}
|
||||
|
||||
// __le__(self, other): Return True if self is other, else self.sort_key < other.sort_key
|
||||
static PyObject* NodeBase___le__(PyObject* self, PyObject* other) {
|
||||
if (self == other) {
|
||||
Py_RETURN_TRUE;
|
||||
}
|
||||
return NodeBase___lt__(self, other);
|
||||
}
|
||||
|
||||
// Convert the NodeBase::sort_key (a SmallVector of int64_t) into a Python tuple of ints
|
||||
// Only used by pickle/__getstate__
|
||||
static PyObject* NodeBase_get_sort_key(PyObject* self, void* /*closure*/) {
|
||||
NodeBase* node = reinterpret_cast<NodeBase*>(self);
|
||||
const NodeSortKey& vec = node->sort_key();
|
||||
Py_ssize_t n = static_cast<Py_ssize_t>(vec.size());
|
||||
THPObjectPtr tuple(PyTuple_New(n));
|
||||
if (!tuple) {
|
||||
return nullptr; // Out of memory
|
||||
}
|
||||
for (Py_ssize_t i = 0; i < n; i++) {
|
||||
PyObject* value = PyLong_FromSsize_t(vec[i]);
|
||||
if (!value) {
|
||||
return nullptr;
|
||||
}
|
||||
PyTuple_SET_ITEM(tuple.get(), i, value);
|
||||
}
|
||||
return tuple.release();
|
||||
}
|
||||
|
||||
// Setter for NodeBase::sort_key: expects a Python tuple of ints, e.g.
|
||||
// node._sort_key = (1,2,3) Only used by pickle/__setstate__
|
||||
static int NodeBase_set_sort_key(
|
||||
PyObject* self,
|
||||
PyObject* value,
|
||||
void* /*closure*/) {
|
||||
NodeBase* node = reinterpret_cast<NodeBase*>(self);
|
||||
if (!PyTuple_Check(value)) {
|
||||
PyErr_SetString(PyExc_TypeError, "_sort_key must be a tuple of ints");
|
||||
return -1;
|
||||
}
|
||||
Py_ssize_t size = PyTuple_GET_SIZE(value);
|
||||
NodeSortKey new_vec;
|
||||
new_vec.reserve(size);
|
||||
for (Py_ssize_t i = 0; i < size; i++) {
|
||||
int64_t val = PyLong_AsSsize_t(PyTuple_GET_ITEM(value, i));
|
||||
if (val == -1 && PyErr_Occurred()) {
|
||||
return -1;
|
||||
}
|
||||
new_vec.emplace_back(val);
|
||||
}
|
||||
node->sort_key() = std::move(new_vec);
|
||||
return 0;
|
||||
}
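A hedged round-trip sketch of the getter/setter pair, assuming (as the comments say) that `_sort_key` is only meant for pickle support and is exposed on `torch.fx` nodes through `_NodeBase` after this patch:

```python
import torch
import torch.fx as fx

g = fx.Graph()
x = g.placeholder("x")
y = g.call_function(torch.relu, (x,))

key = y._sort_key           # getter: tuple of ints describing y's position
y._sort_key = key           # setter: accepts a tuple of ints (pickle/__setstate__ path)
assert isinstance(key, tuple)
```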
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||
static PyMethodDef NodeBase_methods[] = {
|
||||
{"_update_args_kwargs",
|
||||
(PyCFunction)(void*)(NodeBase__update_args_kwargs),
|
||||
METH_FASTCALL,
|
||||
"Internal method: do not call directly."},
|
||||
{"_remove_from_list",
|
||||
(PyCFunction)(void*)(NodeBase__remove_from_list),
|
||||
METH_NOARGS,
|
||||
"Internal method: do not call directly."},
|
||||
{"_prepend",
|
||||
(PyCFunction)(void*)(NodeBase__prepend),
|
||||
METH_O,
|
||||
"Internal method: do not call directly."},
|
||||
{"__lt__",
|
||||
(PyCFunction)(void*)NodeBase___lt__,
|
||||
METH_O,
|
||||
"Return True if self.sort_key < other.sort_key"},
|
||||
{"__gt__",
|
||||
(PyCFunction)(void*)NodeBase___gt__,
|
||||
METH_O,
|
||||
"Return True if self.sort_key > other.sort_key"},
|
||||
{"__ge__",
|
||||
(PyCFunction)(void*)NodeBase___ge__,
|
||||
METH_O,
|
||||
"Return True if self.sort_key >= other.sort_key"},
|
||||
{"__le__",
|
||||
(PyCFunction)(void*)NodeBase___le__,
|
||||
METH_O,
|
||||
"Return True if self.sort_key <= other.sort_key"},
|
||||
{nullptr, nullptr, 0, nullptr} // Sentinel
|
||||
};
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||
static PyGetSetDef NodeBase_getset[] = {
|
||||
{"_sort_key", // attribute name in Python
|
||||
(getter)NodeBase_get_sort_key, // C getter function
|
||||
(setter)NodeBase_set_sort_key, // C setter function
|
||||
(char*)"The sort key as a tuple of ints", // docstring
|
||||
nullptr},
|
||||
{nullptr, nullptr, nullptr, nullptr, nullptr} // Sentinel
|
||||
};
|
||||
|
||||
PyTypeObject NodeBaseType = {
|
||||
PyVarObject_HEAD_INIT(nullptr, 0)
|
||||
"torch._C._NodeBase", /* tp_name */
|
||||
@ -361,7 +578,7 @@ PyTypeObject NodeBaseType = {
|
||||
nullptr, /* tp_iternext */
|
||||
NodeBase_methods, /* tp_methods */
|
||||
NodeBase_members, /* tp_members */
|
||||
nullptr, /* tp_getset */
|
||||
NodeBase_getset, /* tp_getset */
|
||||
nullptr, /* tp_base */
|
||||
nullptr, /* tp_dict */
|
||||
nullptr, /* tp_descr_get */
|
||||
|
@ -130,6 +130,7 @@ __all__ = [
|
||||
"reduce_scatter_tensor",
|
||||
"get_node_local_rank",
|
||||
"split_group",
|
||||
"shrink_group",
|
||||
]
|
||||
|
||||
_MPI_AVAILABLE = True
|
||||
@ -5713,3 +5714,521 @@ def _get_process_group_name(pg: ProcessGroup) -> str:
|
||||
|
||||
def _get_process_group_store(pg: ProcessGroup) -> Store:
|
||||
return _world.pg_map[pg][1]
|
||||
|
||||
|
||||
# Shrink flags for process group backends
|
||||
SHRINK_DEFAULT = 0x00
|
||||
SHRINK_ABORT = 0x01
|
||||
|
||||
|
||||
@_time_logger
|
||||
def shrink_group(
|
||||
ranks_to_exclude: list[int],
|
||||
group: Optional[ProcessGroup] = None,
|
||||
shrink_flags: int = SHRINK_DEFAULT,
|
||||
pg_options: Optional[Any] = None,
|
||||
) -> ProcessGroup:
|
||||
"""
|
||||
Shrinks a process group by excluding specified ranks.
|
||||
|
||||
Creates and returns a new, smaller process group comprising only the ranks
|
||||
from the original group that were not in the ``ranks_to_exclude`` list.
|
||||
|
||||
Args:
|
||||
ranks_to_exclude (List[int]): A list of ranks from the original
|
||||
``group`` to exclude from the new group.
|
||||
group (ProcessGroup, optional): The process group to shrink. If ``None``,
|
||||
the default process group is used. Defaults to ``None``.
|
||||
shrink_flags (int, optional): Flags to control the shrinking behavior.
|
||||
Can be ``SHRINK_DEFAULT`` (default) or ``SHRINK_ABORT``.
|
||||
``SHRINK_ABORT`` will attempt to terminate ongoing operations
|
||||
in the parent communicator before shrinking.
|
||||
Defaults to ``SHRINK_DEFAULT``.
|
||||
pg_options (ProcessGroupOptions, optional): Backend-specific options to apply
|
||||
to the shrunken process group. If provided, the backend will use
|
||||
these options when creating the new group. If omitted, the new group
|
||||
inherits defaults from the parent.
|
||||
|
||||
Returns:
|
||||
ProcessGroup: a new group comprised of the remaining ranks. If the
|
||||
default group was shrunk, the returned group becomes the new default group.
|
||||
|
||||
Raises:
|
||||
TypeError: if the group’s backend does not support shrinking.
|
||||
ValueError: if ``ranks_to_exclude`` is invalid (empty, out of bounds,
|
||||
duplicates, or excludes all ranks).
|
||||
RuntimeError: if an excluded rank calls this function or the backend
|
||||
fails the operation.
|
||||
|
||||
Notes:
|
||||
- Only non-excluded ranks should call this function; excluded ranks
|
||||
must not participate in the shrink operation.
|
||||
- Shrinking the default group destroys all other process groups since
|
||||
rank reassignment makes them inconsistent.
|
||||
"""
|
||||
# Step 1: Validate input parameters with comprehensive error checking
|
||||
_validate_shrink_inputs(ranks_to_exclude, shrink_flags)
|
||||
|
||||
# Step 2: Get target group and essential properties
|
||||
target_group_info = _prepare_shrink_target_group(group)
|
||||
|
||||
# Step 3: Validate backend requirements and availability
|
||||
backend_impl = _validate_shrink_backend_requirements(target_group_info)
|
||||
|
||||
# Step 4: Validate ranks against group and check for duplicates
|
||||
excluded_ranks_set = _validate_and_process_excluded_ranks(
|
||||
ranks_to_exclude, target_group_info
|
||||
)
|
||||
|
||||
# Step 5: Execute the actual shrink operation (backend-specific)
|
||||
new_backend = backend_impl.shrink(
|
||||
sorted(excluded_ranks_set),
|
||||
shrink_flags,
|
||||
pg_options if pg_options is not None else None,
|
||||
)
|
||||
|
||||
# Step 6: Handle cleanup and creation of new process group
|
||||
target_group_info["pg_options_override"] = pg_options
|
||||
return _finalize_shrunk_group(target_group_info, excluded_ranks_set, new_backend)
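A usage sketch for the surviving ranks. `failed_rank` is a placeholder; `SHRINK_ABORT` is the module-level flag defined above, and whether these names are re-exported beyond `torch.distributed.distributed_c10d` is an assumption:

```python
import torch.distributed as dist
from torch.distributed.distributed_c10d import SHRINK_ABORT, shrink_group

failed_rank = 3  # placeholder: the rank known to have dropped out

# Called only by surviving ranks; excluded ranks must not participate.
new_group = shrink_group(
    ranks_to_exclude=[failed_rank],
    shrink_flags=SHRINK_ABORT,   # abort pending work in the parent communicator first
)
# If the default group was shrunk, new_group also becomes the new default group.
```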
|
||||
|
||||
|
||||
def _validate_shrink_inputs(ranks_to_exclude: list[int], shrink_flags: int) -> None:
|
||||
"""Validate input parameters for shrink_group."""
|
||||
if not isinstance(ranks_to_exclude, list):
|
||||
raise TypeError(
|
||||
f"ranks_to_exclude must be a list, but got {type(ranks_to_exclude).__name__}. "
|
||||
f"Example: [1, 3, 5] to exclude ranks 1, 3, and 5."
|
||||
)
|
||||
|
||||
if not ranks_to_exclude:
|
||||
raise ValueError(
|
||||
"ranks_to_exclude cannot be empty. To shrink a group, you must specify at least "
|
||||
"one rank to exclude. Example: [failed_rank_id]"
|
||||
)
|
||||
|
||||
# Validate shrink_flags with clear explanation of valid values
|
||||
valid_flags = [SHRINK_DEFAULT, SHRINK_ABORT]
|
||||
if not isinstance(shrink_flags, int) or shrink_flags not in valid_flags:
|
||||
raise ValueError(
|
||||
f"Invalid shrink_flags value: {shrink_flags}. Must be one of: "
|
||||
f"SHRINK_DEFAULT ({SHRINK_DEFAULT}) or SHRINK_ABORT ({SHRINK_ABORT}). "
|
||||
f"Use SHRINK_ABORT to abort ongoing operations before shrinking."
|
||||
)
|
||||
|
||||
|
||||
def _prepare_shrink_target_group(group: Optional[ProcessGroup]) -> dict:
|
||||
"""Prepare and validate the target group for shrinking."""
|
||||
target_pg = group if group is not None else _get_default_group()
|
||||
|
||||
# Cache frequently accessed properties to avoid repeated calls
|
||||
group_size = int(target_pg.size())
|
||||
group_info = {
|
||||
"process_group": target_pg,
|
||||
"is_default_group": (target_pg == _get_default_group()),
|
||||
"group_size": group_size,
|
||||
"current_rank": target_pg.rank(),
|
||||
"group_name": _get_process_group_name(target_pg),
|
||||
}
|
||||
|
||||
# Validate that we have a valid process group
|
||||
if group_size <= 1:
|
||||
raise ValueError(
|
||||
f"Cannot shrink a process group with size {group_size}. "
|
||||
f"Group must have at least 2 ranks to support shrinking."
|
||||
)
|
||||
|
||||
return group_info
|
||||
|
||||
|
||||
def _validate_shrink_backend_requirements(group_info: dict) -> Any:
|
||||
"""Return the backend implementation for the target group or raise if unsupported."""
|
||||
target_pg = group_info["process_group"]
|
||||
group_name = group_info["group_name"]
|
||||
|
||||
# Get the group's backend directly via ProcessGroup API. Prefer a bound device if present,
|
||||
# otherwise try CUDA then fall back to CPU.
|
||||
try:
|
||||
preferred_device = getattr(target_pg, "bound_device_id", None)
|
||||
if preferred_device is not None:
|
||||
backend_impl = target_pg._get_backend(preferred_device)
|
||||
else:
|
||||
# Try CUDA first if available, else CPU
|
||||
try:
|
||||
backend_impl = target_pg._get_backend(torch.device("cuda"))
|
||||
except Exception:
|
||||
backend_impl = target_pg._get_backend(torch.device("cpu"))
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError(
|
||||
f"Cannot access device backend for process group '{group_name}'. "
|
||||
f"Ensure the process group was initialized with a compatible device backend and devices are available."
|
||||
) from e
|
||||
|
||||
try:
|
||||
supports = bool(backend_impl.supports_shrinking)
|
||||
except Exception:
|
||||
supports = False
|
||||
if not supports:
|
||||
raise TypeError(
|
||||
f"Process group backend for '{group_name}' does not support shrinking operations."
|
||||
)
|
||||
|
||||
return backend_impl
|
||||
|
||||
|
||||
def _validate_and_process_excluded_ranks(
|
||||
ranks_to_exclude: list[int], group_info: dict
|
||||
) -> set:
|
||||
"""Validate excluded ranks and convert to set for efficient operations."""
|
||||
group_size = group_info["group_size"]
|
||||
current_rank = group_info["current_rank"]
|
||||
|
||||
# Use set for O(1) duplicate detection and membership testing
|
||||
excluded_ranks_set = set()
|
||||
|
||||
# Validate each rank with detailed error messages
|
||||
for i, rank in enumerate(ranks_to_exclude):
|
||||
if not isinstance(rank, int):
|
||||
raise TypeError(
|
||||
f"All elements in ranks_to_exclude must be integers. "
|
||||
f"Element at index {i} is {type(rank).__name__}: {rank}"
|
||||
)
|
||||
|
||||
if not (0 <= rank < group_size):
|
||||
raise ValueError(
|
||||
f"Rank {rank} at index {i} is out of bounds for group size {group_size}. "
|
||||
f"Valid ranks are in range [0, {group_size - 1}]."
|
||||
)
|
||||
|
||||
if rank in excluded_ranks_set:
|
||||
raise ValueError(
|
||||
f"Duplicate rank {rank} found in ranks_to_exclude at index {i}. "
|
||||
f"Each rank can only be excluded once."
|
||||
)
|
||||
|
||||
excluded_ranks_set.add(rank)
|
||||
|
||||
# Ensure we don't exclude all ranks
|
||||
if len(excluded_ranks_set) >= group_size:
|
||||
raise ValueError(
|
||||
f"Cannot exclude all {group_size} ranks from process group. "
|
||||
f"At least one rank must remain. Excluding {len(excluded_ranks_set)} ranks."
|
||||
)
|
||||
|
||||
# Critical check: current rank should not be in excluded list
|
||||
if current_rank in excluded_ranks_set:
|
||||
raise RuntimeError(
|
||||
f"Current rank {current_rank} is in the exclusion list and should not call shrink_group(). "
|
||||
f"Only non-excluded ranks should participate in the shrinking operation. "
|
||||
f"Excluded ranks should terminate their processes instead."
|
||||
)
|
||||
|
||||
return excluded_ranks_set
|
||||
|
||||
|
||||
def _finalize_shrunk_group(
    group_info: dict, excluded_ranks_set: set, new_backend
) -> ProcessGroup:
    """Clean up old group and create new shrunk process group."""
    target_pg = group_info["process_group"]
    is_default_group = group_info["is_default_group"]

    # Handle default group dependencies - destroy other groups first
    if is_default_group:
        _destroy_all_other_groups(exclude_group=target_pg)

    # Gather original group metadata before cleanup
    original_group_metadata = _extract_group_metadata(target_pg)

    # Calculate remaining ranks efficiently
    original_ranks = get_process_group_ranks(target_pg)
    remaining_ranks = [
        rank for rank in original_ranks if rank not in excluded_ranks_set
    ]

    # Clean up the original group
    _cleanup_original_group(target_pg, is_default_group)

    # Create and configure the new process group
    new_pg = _create_shrunk_process_group(
        new_backend, remaining_ranks, original_group_metadata, is_default_group
    )

    # Register the new group in global state
    if is_default_group:
        _update_default_pg(new_pg)

    # Update global state with new group information
    rank_mapping = {
        global_rank: group_rank
        for group_rank, global_rank in enumerate(remaining_ranks)
    }
    _update_process_group_global_state(
        pg=new_pg,
        backend_name=original_group_metadata["backend_name"],
        store=original_group_metadata["store"],
        group_name=original_group_metadata["new_group_name"],
        backend_config=original_group_metadata["backend_config"],
        rank_mapping=rank_mapping,
    )

    return new_pg


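The `rank_mapping` comprehension above densely re-indexes the surviving global ranks; a standalone illustration with made-up numbers:

# Suppose ranks 1 and 4 were excluded from a 5-rank group.
remaining_ranks = [0, 2, 3]  # surviving global ranks, in their original order
rank_mapping = {
    global_rank: group_rank
    for group_rank, global_rank in enumerate(remaining_ranks)
}
assert rank_mapping == {0: 0, 2: 1, 3: 2}  # global rank -> new contiguous group rank
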
def _extract_group_metadata(target_pg: ProcessGroup) -> dict:
    """Extract metadata from the original group before cleanup."""
    original_backend_name, original_store = _world.pg_map[target_pg]
    original_backend_config = _world.pg_backend_config.get(target_pg, "")
    original_group_name = _get_process_group_name(target_pg)

    # Extract device binding information before cleanup to avoid accessing destroyed group
    bound_device_id = None
    if hasattr(target_pg, "bound_device_id"):
        bound_device_id = target_pg.bound_device_id

    # Generate new group name for the shrunk group; hash for uniqueness across backends
    remaining_ranks = list(get_process_group_ranks(target_pg))
    new_group_name = _process_group_name(remaining_ranks, use_hashed_name=True)

    return {
        "backend_name": original_backend_name,
        "store": original_store,
        "backend_config": original_backend_config,
        "original_group_name": original_group_name,
        "new_group_name": new_group_name,
        "bound_device_id": bound_device_id,  # Safe to access after cleanup
    }


def _cleanup_original_group(target_pg: ProcessGroup, is_default_group: bool) -> None:
    """Clean up the original process group safely."""
    try:
        destroy_process_group(target_pg)
    except Exception as e:
        group_type = "default" if is_default_group else "non-default"
        logger.warning(
            "Failed to destroy %s group during shrinking: %s", group_type, str(e)
        )

    # Ensure global state cleanup even if destroy_process_group fails
    _cleanup_process_group_global_state(target_pg)


def _create_shrunk_process_group(
    new_backend, remaining_ranks: list[int], metadata: dict, is_default_group: bool
) -> ProcessGroup:
    """Create and configure the new shrunk process group."""
    # Create new group properties
    new_group_rank = new_backend.rank()
    new_group_size = new_backend.size()
    group_name = metadata["new_group_name"]

    # Generate descriptive group description
    if is_default_group:
        group_desc = "default:shrunken"
    else:
        group_desc = f"{metadata['original_group_name']}:shrunk"

    # Create process group with new communicator (clone the parent store like split does)
    prefix_store = PrefixStore(f"{group_name}/", metadata["store"].clone())
    new_pg = ProcessGroup(prefix_store, new_group_rank, new_group_size)

    # Configure backend using the device type of the new backend's bound device if available,
    # otherwise derive from the original group's bound device or fall back to CPU.
    backend_device = metadata.get("bound_device_id")
    if backend_device is None:
        # Default to CPU if no bound device is present
        backend_device = torch.device("cpu")

    # Choose backend enum based on device type
    if backend_device.type == "cuda":
        backend_type = ProcessGroup.BackendType.NCCL
    else:
        backend_type = ProcessGroup.BackendType.GLOO

    new_pg._register_backend(backend_device, backend_type, new_backend)
    new_pg._set_default_backend(backend_type)

    # Inherit device binding from original group if it was bound
    bound_device_id = metadata.get("bound_device_id")
    if bound_device_id is not None:
        new_pg.bound_device_id = bound_device_id

    # Set group metadata
    new_pg._set_group_name(group_name)
    new_pg._set_group_desc(group_desc)

    # Persist backend configuration overrides (if provided via shrink_group)
    backend_config_override = metadata.get("backend_config")
    if backend_config_override is not None:
        # Store for introspection/debugging and potential backend hooks
        _world.pg_backend_config[new_pg] = backend_config_override

    return new_pg


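A condensed sketch of the device-to-backend-type selection used above, assuming a distributed-enabled PyTorch build; it models only the CUDA/NCCL vs. GLOO split shown in this diff:

import torch
from torch.distributed import ProcessGroup


def pick_backend_type(bound_device_id):
    # A bound CUDA device selects NCCL; anything else selects GLOO
    # (defaulting to a CPU device when no device is bound), mirroring the branch above.
    device = bound_device_id if bound_device_id is not None else torch.device("cpu")
    backend = (
        ProcessGroup.BackendType.NCCL
        if device.type == "cuda"
        else ProcessGroup.BackendType.GLOO
    )
    return device, backend


print(pick_backend_type(None))
print(pick_backend_type(torch.device("cuda:0")))
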
def _destroy_all_other_groups(exclude_group: Optional[ProcessGroup] = None) -> None:
    """
    Destroy all process groups except the excluded group and clean up all global state.

    This is necessary when shrinking the default group because global ranks
    are reassigned by NCCL, making all existing process groups inconsistent.

    Note: Uses abort for non-collective cleanup since excluded ranks may not
    participate in collective operations. Backend cleanup is handled independently per group.

    Args:
        exclude_group (ProcessGroup, optional): Process group to exclude from destruction.
            If None, destroys all process groups.
    """
    # Get list of groups to destroy (avoid modifying dict while iterating)
    groups_to_destroy = []
    for pg in list(_world.pg_group_ranks.keys()):
        if exclude_group is not None and pg == exclude_group:
            continue
        groups_to_destroy.append(pg)

    # Warn user about automatic destruction
    if groups_to_destroy:
        group_names = [_get_process_group_name(pg) for pg in groups_to_destroy]
        logger.warning(
            "Shrinking default group will destroy %d other process groups: %s. "
            "This is necessary because shrinking the default group reassigns global ranks, "
            "making existing groups inconsistent.",
            len(groups_to_destroy),
            ", ".join(group_names),
        )

    # Destroy each group and clean up global state
    for pg in groups_to_destroy:
        try:
            # First call abort_process_group which handles the C++ cleanup non-collectively
            _abort_process_group(pg)
        except Exception as e:
            # Log but don't fail - some groups might already be destroyed
            logger.warning(
                "Failed to abort process group %s: %s",
                _get_process_group_name(pg),
                str(e),
            )

        # Ensure all global state is cleaned up even if _abort_process_group fails
        # or doesn't clean up everything
        _cleanup_process_group_global_state(pg)


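The loop above snapshots the registry keys before mutating it; the same pattern on a plain dict, unrelated to any real process-group state:

registry = {"pg_a": 1, "pg_b": 2, "pg_c": 3}
keep = "pg_b"

# Iterate over a copy of the keys so entries can be removed while looping.
for name in list(registry.keys()):
    if name != keep:
        registry.pop(name, None)

assert registry == {"pg_b": 2}
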
def _cleanup_process_group_global_state(pg: ProcessGroup) -> None:
    """
    Clean up all global state associated with a process group.

    This function ensures complete cleanup of process group state from all
    global dictionaries and registries, even if destroy_process_group fails
    or doesn't clean up everything. This is critical when destroying multiple
    groups to prevent inconsistent state.

    The cleanup removes the process group from:
    - _world.pg_map (backend and store mapping)
    - _world.pg_names (group name mapping)
    - _world.pg_group_ranks (rank mappings)
    - _world.pg_backend_config (backend configuration)
    - _world.tags_to_pg and _world.pg_to_tag (tag mappings)
    - _world.pg_coalesce_state (coalescing state)
    - C++ internal registries via _unregister_process_group

    Args:
        pg (ProcessGroup): The process group to clean up.
    """
    try:
        # Clean up main process group mappings
        _world.pg_map.pop(pg, None)
        _world.pg_group_ranks.pop(pg, None)
        _world.pg_backend_config.pop(pg, None)

        # Clean up process group name mapping
        group_name = _world.pg_names.pop(pg, None)

        # Clean up tag mappings
        pg_tag = _world.pg_to_tag.pop(pg, None)
        if pg_tag is not None and pg_tag in _world.tags_to_pg:
            try:
                _world.tags_to_pg[pg_tag].remove(pg)
                # Remove the tag entry if list is empty
                if not _world.tags_to_pg[pg_tag]:
                    _world.tags_to_pg.pop(pg_tag, None)
            except (ValueError, KeyError):
                # Process group was already removed from the list
                pass

        # Clean up any registered process group names using C++ unregister function
        if group_name is not None:
            try:
                _unregister_process_group(group_name)
            except Exception:
                # Process group name might not be registered or already unregistered
                pass

        # Clean up coalesce state if present
        _world.pg_coalesce_state.pop(pg, None)

    except Exception as e:
        # Log cleanup failures but don't propagate - we want to continue with other cleanups
        logger.warning(
            "Failed to fully clean up global state for process group: %s", str(e)
        )


def _update_process_group_global_state(
    pg: ProcessGroup,
    backend_name: str,
    store: Store,
    group_name: str,
    backend_config: str,
    rank_mapping: Optional[dict[int, int]] = None,
    pg_tag: Optional[str] = None,
    user_tag: Optional[str] = None,
) -> None:
    """
    Update all global state dictionaries for a process group.

    This helper function consolidates the common pattern of updating multiple
    global state dictionaries when creating or modifying process groups.

    Args:
        pg (ProcessGroup): The process group to update state for.
        backend_name (str): Backend name for pg_map.
        store (Store): Store instance for pg_map.
        group_name (str): Group name for pg_names and registration.
        backend_config (str): Backend configuration string.
        rank_mapping (Dict[int, int], optional): Global rank to group rank mapping.
            If None, skips updating pg_group_ranks.
        pg_tag (str, optional): Process group tag. If None, defaults to f"ptd:{group_name}".
        user_tag (str, optional): User-provided tag for special tag handling.
            If provided, creates "user:{user_tag}" tag and also adds to default "".
    """
    # Update main process group mappings
    _world.pg_map[pg] = (backend_name, store)
    _world.pg_names[pg] = group_name
    _world.pg_backend_config[pg] = backend_config

    # Register the process group name
    _register_process_group(group_name, pg)

    # Update rank mapping if provided
    if rank_mapping is not None:
        _world.pg_group_ranks[pg] = rank_mapping

    # Handle tag management
    if pg_tag is None:
        pg_tag = f"ptd:{group_name}"

    if user_tag is not None:
        # Special handling for user-provided tags
        # Add to default "" tag first
        _world.tags_to_pg.setdefault("", []).append(pg)
        # Then create user-specific tag
        user_pg_tag = f"user:{user_tag}"
        _world.tags_to_pg.setdefault(user_pg_tag, []).append(pg)
        _world.pg_to_tag[pg] = user_pg_tag
    else:
        # Standard process group tag
        _world.tags_to_pg.setdefault(pg_tag, []).append(pg)
        _world.pg_to_tag[pg] = pg_tag

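The tag bookkeeping above boils down to a small naming rule; an illustrative pure function (not part of the diff) capturing just that rule:

def resolve_pg_tag(group_name: str, pg_tag=None, user_tag=None) -> str:
    # An explicit user tag wins and is namespaced as "user:<tag>"; otherwise
    # the provided tag, or the default "ptd:<group_name>", is used.
    if user_tag is not None:
        return f"user:{user_tag}"
    return pg_tag if pg_tag is not None else f"ptd:{group_name}"


assert resolve_pg_tag("shrunk_group_0") == "ptd:shrunk_group_0"
assert resolve_pg_tag("shrunk_group_0", user_tag="my_tag") == "user:my_tag"

In the real helper a user-tagged group is additionally appended to the default "" tag list, which this sketch leaves out.
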
@ -385,41 +385,7 @@ class Node(_NodeBase):
        Args:
            x (Node): The node to put before this node. Must be a member of the same graph.
        """
        assert self.graph == x.graph, "Attempting to move a Node into a different Graph"
        if self == x:
            log.debug(
                "Trying to prepend a node to itself. This behavior has no effect on the graph."
            )
            return
        x._remove_from_list()
        p = self._prev
        p._next, x._prev = x, p
        x._next, self._prev = self, x

        # compute x._sort_key
        psk = x._prev._sort_key
        nsk = x._next._sort_key
        if len(psk) > len(nsk):
            idx: int
            *prefix, idx = psk[: len(nsk) + 1]
            x._sort_key = (*prefix, idx + 1)
        elif len(psk) < len(nsk):
            *prefix, idx = nsk[: len(psk) + 1]
            x._sort_key = (*prefix, idx - 1)
        else:  # same length, increase length by 1
            x._sort_key = (*psk, 0)

    def __gt__(self, other: "Node") -> bool:
        return self._sort_key > other._sort_key

    def __lt__(self, other: "Node") -> bool:
        return self._sort_key < other._sort_key

    def __ge__(self, other: "Node") -> bool:
        return self > other or self == other

    def __le__(self, other: "Node") -> bool:
        return self < other or self == other
        self._prepend(x)

    @compatibility(is_backward_compatible=True)
    def append(self, x: "Node") -> None:
@ -430,11 +396,7 @@ class Node(_NodeBase):
        Args:
            x (Node): The node to put after this node. Must be a member of the same graph.
        """
        self._next.prepend(x)

    def _remove_from_list(self) -> None:
        p, n = self._prev, self._next
        p._next, n._prev = n, p
        self._next._prepend(x)

    @property
    def args(self) -> tuple[Argument, ...]:
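Whether the list splice happens in Python or in the `_NodeBase` fast path, the public `Node.prepend`/`Node.append` contract is unchanged; a small self-contained check using the public torch.fx API:

import operator

import torch
import torch.fx as fx

g = fx.Graph()
x = g.placeholder("x")
relu = g.call_function(torch.relu, (x,))
add = g.call_function(operator.add, (x, x))  # created after `relu`
g.output(relu)

# Move `add` so it sits directly before `relu` instead of after it.
relu.prepend(add)
assert add.next is relu and relu.prev is add
assert [n.op for n in g.nodes] == ["placeholder", "call_function", "call_function", "output"]
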
@ -238,6 +238,47 @@ def skip_if_lt_x_gpu(x):
    return decorator


def requires_world_size(n: int):
    """
    Decorator to request a specific world size for a test. The test harness can
    read this attribute to set the number of ranks to spawn. If there are fewer
    than `n` CUDA devices available, the test is skipped.

    Usage:
        @requires_world_size(3)
        def test_something(self):
            ...
    """

    def decorator(func):
        func._required_world_size = n
        available = torch.cuda.device_count()
        return unittest.skipUnless(
            available >= n, f"requires {n} GPUs, found {available}"
        )(func)

    return decorator


def get_required_world_size(obj: Any, default: int) -> int:
    """
    Returns the requested world size for the currently running unittest method on `obj`
    if annotated via `@requires_world_size(n)`, else returns `default`.
    """
    try:
        # Try MultiProcessTestCase helper first, then unittest fallback
        test_name = (
            obj._current_test_name()  # type: ignore[attr-defined]
            if hasattr(obj, "_current_test_name") and callable(obj._current_test_name)
            else obj._testMethodName
        )
        fn = getattr(obj, test_name)
        value = fn._required_world_size
        return int(value)
    except Exception:
        return default


# This decorator helps avoid initializing CUDA while testing other backends
def nccl_skip_if_lt_x_gpu(backend, x):
    def decorator(func):
@ -367,6 +408,13 @@ def requires_nccl_version(version, msg):
    )


def requires_nccl_shrink():
    """
    Require NCCL shrink support (NCCL available and version >= 2.27).
    """
    return requires_nccl_version((2, 27), "Need NCCL 2.27+ for shrink_group")


def requires_nccl():
    return skip_but_pass_in_sandcastle_if(
        not c10d.is_nccl_available(),
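A usage sketch for the new helpers, assuming this diff is applied so they are importable from torch.testing._internal.common_distributed; the test class and body are hypothetical:

from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    get_required_world_size,
    requires_nccl_shrink,
    requires_world_size,
)


class ShrinkGroupTest(MultiProcessTestCase):
    @property
    def world_size(self) -> int:
        # Honor @requires_world_size on the current test, defaulting to 2 ranks.
        return get_required_world_size(self, 2)

    @requires_nccl_shrink()
    @requires_world_size(3)
    def test_needs_three_ranks(self):
        ...  # runs only with NCCL >= 2.27 and at least 3 GPUs
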
@ -1836,7 +1836,7 @@ def make_dynamo_test(
    Decorator function to create a dynamo test case. A function annotated with
    this decorator takes as input a unittest object.
    """
    from torch._dynamo.testing import CompileCounter, reset, optimize_assert
    from torch._dynamo.testing import CompileCounter, reset, optimize
    if fn is None:
        return lambda fn: make_dynamo_test(fn)

@ -1852,7 +1852,7 @@ def make_dynamo_test(

        dummy()
        reset()
        opt_fn = optimize_assert(actual)(dummy)
        opt_fn = optimize(actual, error_on_graph_break=True)(dummy)
        opt_fn()
        reset()

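A minimal sketch mirroring the two changed lines above; it assumes `optimize` and the `error_on_graph_break` keyword are available from torch._dynamo.testing as shown in this diff:

import torch
from torch._dynamo.testing import CompileCounter, reset, optimize


def fn(x):
    return torch.relu(x) + 1


counter = CompileCounter()
reset()
# Previously optimize_assert(counter)(fn); with error_on_graph_break=True a
# graph break raises instead of silently falling back to eager.
opt_fn = optimize(counter, error_on_graph_break=True)(fn)
opt_fn(torch.randn(4))
assert counter.frame_count == 1
reset()
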