Compare commits

...

25 Commits

Author SHA1 Message Date
2cdd4617c8 Update on "[dynamo] prevent recompilation limit exceeded on external_utils.wrap_inline"
This will prevent recompilations (i.e. logs to TORCH_LOGS="recompiles") due to wrap_inline's guard on `fn`. If we call wrap_inline on enough different functions, we would hit the recompile limit.
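A rough sketch of the failure mode described above, mirroring the `test_wrap_inline_recompiles` test added later in this stack; `wrap_inline` is an internal dynamo helper (not public API), and the loop bound and lambda here are illustrative only:

```python
# Illustrative only: wrap_inline is an internal torch._dynamo helper.
import torch
from torch._dynamo.external_utils import wrap_inline

inp = torch.ones(3)
for i in range(8):
    # Each iteration wraps a brand-new function object. If wrap_inline guards
    # on the identity of `fn`, every call recompiles the wrapper frame, so
    # enough distinct functions eventually trip the recompile limit and show
    # up under TORCH_LOGS="recompiles".
    opt_fn = torch.compile(wrap_inline(lambda x, i=i: x + i), backend="eager")
    opt_fn(inp)
```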

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-21 13:30:06 -07:00
13e04b57e1 Update base for Update on "[dynamo] prevent recompilation limit exceeded on external_utils.wrap_inline"
This will prevent recompilations (i.e. logs to TORCH_LOGS="recompiles") due to wrap_inline's guard on `fn`. If we call wrap_inline on enough different functions, we would hit the recompile limit.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-21 13:30:05 -07:00
1c7fe8f861 [BugFix] chunk_size should always be int64_t (#165971)
Inspired by https://github.com/pytorch/pytorch/pull/156872
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165971
Approved by: https://github.com/albanD
2025-10-21 19:52:47 +00:00
4e643422f6 shrink_group implementation to expose ncclCommShrink API (#164518)
Closes #164529

This exposes the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink) API to PyTorch.

This is useful when you need to exclude certain GPUs or nodes from a collective operation, for example in fault tolerance scenarios or when dynamically adjusting resource utilization.

For more info:  [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator)
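A minimal usage sketch, based on the tests added in `test_c10d_nccl.py` in this PR (the rank selection, tensor values, and one-CUDA-device-per-rank setup are assumed; only surviving ranks call the new API):

```python
# Sketch assuming an already-initialized NCCL default process group.
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import SHRINK_ABORT

ranks_to_exclude = [dist.get_world_size() - 1]  # e.g. drop the last rank
if dist.get_rank() not in ranks_to_exclude:
    # Only non-excluded ranks call shrink_group; SHRINK_ABORT aborts any
    # in-flight work on the parent communicator before shrinking.
    shrunk_pg = dist.shrink_group(ranks_to_exclude, shrink_flags=SHRINK_ABORT)
    t = torch.ones(1, device="cuda")
    dist.all_reduce(t, group=shrunk_pg)  # collective on the smaller group
```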

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518
Approved by: https://github.com/kwen2501
2025-10-21 19:47:33 +00:00
3c3b278872 [reland][fx] Move Node._prepend/Node._remove_from_list to C++ (#165882)
Relands #148261 that was reverted by #150542

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165882
Approved by: https://github.com/ezyang
2025-10-21 19:43:55 +00:00
49bbcc5833 Update on "[dynamo] prevent recompilation limit exceeded on external_utils.wrap_inline"
This will prevent recompilations (i.e. logs to TORCH_LOGS="recompiles") due to wrap_inline's guard on `fn`. If we call wrap_inline on enough different functions, we would hit the recompile limit.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-21 11:59:20 -07:00
292d8e6caa Update base for Update on "[dynamo] prevent recompilation limit exceeded on external_utils.wrap_inline"
This will prevent recompilations (i.e. logs to TORCH_LOGS="recompiles") due to wrap_inline's guard on `fn`. If we call wrap_inline on enough different functions, we would hit the recompile limit.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-21 11:59:20 -07:00
80f4d11f12 Update on "[dynamo] prevent recompilation limit exceeded on external_utils.wrap_inline"
This will prevent recompilations (i.e. logs to TORCH_LOGS="recompiles") due to wrap_inline's guard on `fn`. If we call wrap_inline on enough different functions, we would hit the recompile limit.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-20 15:23:59 -07:00
99f9d8fee0 Update base for Update on "[dynamo] prevent recompilation limit exceeded on external_utils.wrap_inline"
This will prevent recompilations (i.e. logs to TORCH_LOGS="recompiles") due to wrap_inline's guard on `fn`. If we call wrap_inline on enough different functions, we would hit the recompile limit.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-20 15:23:59 -07:00
9de148417c Update on "[dynamo] prevent recompilation limit exceeded on external_utils.wrap_inline"
This will prevent recompilations (i.e. logs to TORCH_LOGS="recompiles") due to wrap_inline's guard on `fn`. If we call wrap_inline on enough different functions, we would hit the recompile limit.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-20 11:38:18 -07:00
9c0e3db285 Update base for Update on "[dynamo] prevent recompilation limit exceeded on external_utils.wrap_inline"
This will prevent recompilations (i.e. logs to TORCH_LOGS="recompiles") due to wrap_inline's guard on `fn`. If we call wrap_inline on enough different functions, we would hit the recompile limit.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-20 11:38:18 -07:00
6ed6905d29 Update on "[dynamo] prevent recompilation limit exceeded on external_utils.wrap_inline"
This will prevent recompilations (i.e. logs to TORCH_LOGS="recompiles") due to wrap_inline's guard on `fn`. If we call wrap_inline on enough different functions, we would hit the recompile limit.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-17 17:41:57 -07:00
08af8a99e9 Update base for Update on "[dynamo] prevent recompilation limit exceeded on external_utils.wrap_inline"
This will prevent recompilations (i.e. logs to TORCH_LOGS="recompiles") due to wrap_inline's guard on `fn`. If we call wrap_inline on enough different functions, we would hit the recompile limit.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-17 17:41:57 -07:00
69210a3ecc Update on "[dynamo] prevent recompilation limit exceeded on external_utils.wrap_inline"
This will prevent recompilations (i.e. logs to TORCH_LOGS="recompiles") due to wrap_inline's guard on `fn`. If we call wrap_inline on enough different functions, we would hit the recompile limit.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-15 17:02:55 -07:00
2be01db423 Update base for Update on "[dynamo] prevent recompilation limit exceeded on external_utils.wrap_inline"
This will prevent recompilations (i.e. logs to TORCH_LOGS="recompiles") due to wrap_inline's guard on `fn`. If we call wrap_inline on enough different functions, we would hit the recompile limit.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-15 17:02:55 -07:00
c8f6c13c96 Update on "[dynamo] prevent recompilation limit exceeded on external_utils.wrap_inline"
This will prevent recompilations (i.e. logs to TORCH_LOGS="recompiles") due to wrap_inline's guard on `fn`. If we call wrap_inline on enough different functions, we would hit the recompile limit.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-14 15:19:54 -07:00
f8926ed88c Update base for Update on "[dynamo] prevent recompilation limit exceeded on external_utils.wrap_inline"
This will prevent recompilations (i.e. logs to TORCH_LOGS="recompiles") due to wrap_inline's guard on `fn`. If we call wrap_inline on enough different functions, we would hit the recompile limit.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-14 15:19:54 -07:00
495607e655 Update on "[dynamo] prevent recompilation limit exceeded on external_utils.wrap_inline"
This will prevent recompilations (i.e. logs to TORCH_LOGS="recompiles") due to wrap_inline's guard on `fn`. If we call wrap_inline on enough different functions, we would hit the recompile limit.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-08 10:50:05 -07:00
7f603f54d1 Update base for Update on "[dynamo] prevent recompilation limit exceeded on external_utils.wrap_inline"
This will prevent recompilations (i.e. logs to TORCH_LOGS="recompiles") due to wrap_inline's guard on `fn`. If we call wrap_inline on enough different functions, we would hit the recompile limit.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-08 10:50:05 -07:00
35466f9ef9 Update on "[dynamo] prevent recompilation limit exceeded on external_utils.wrap_inline"
This will prevent recompilations (i.e. logs to TORCH_LOGS="recompiles") due to wrap_inline's guard on `fn`. If we call wrap_inline on enough different functions, we would hit the recompile limit.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-07 16:04:04 -07:00
65300600dd Update base for Update on "[dynamo] prevent recompilation limit exceeded on external_utils.wrap_inline"
This will prevent recompilations (i.e. logs to TORCH_LOGS="recompiles") due to wrap_inline's guard on `fn`. If we call wrap_inline on enough different functions, we would hit the recompile limit.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-10-07 16:04:04 -07:00
95d52c623d Update on "[dynamo] prevent recompilation limit exceeded on external_utils.wrap_inline"
This will prevent recompilations (i.e. logs to TORCH_LOGS="recompiles") due to wrap_inline's guard on `fn`. If we call wrap_inline on enough different functions, we would hit the recompile limit.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-09-26 18:20:49 -07:00
e2f2803882 Update base for Update on "[dynamo] prevent recompilation limit exceeded on external_utils.wrap_inline"
This will prevent recompilations (i.e. logs to TORCH_LOGS="recompiles") due to wrap_inline's guard on `fn`. If we call wrap_inline on enough different functions, we would hit the recompile limit.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-09-26 18:20:49 -07:00
8d6c701f80 Update on "[dynamo] prevent recompilations on external_utils.wrap_inline"
This will prevent recompilations due to wrap_inline's guard on `fn` from hitting the recompilation limit.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
2025-09-22 12:33:14 -07:00
eb161310a5 [dynamo] prevent recompilations on external_utils.wrap_inline
[ghstack-poisoned]
2025-09-22 03:44:24 -07:00
33 changed files with 2482 additions and 256 deletions

View File

@ -52,7 +52,7 @@ struct FusedAdagradMathFunctor {
using opmath_t = at::opmath_type<scalar_t>;
C10_DEVICE __forceinline__ void operator()(
int chunk_size,
int64_t chunk_size,
FusedOptimizerTensorListMetadata<3>& tl,
const float* lr_ptr,
const double& lr,
@ -133,4 +133,4 @@ struct FusedAdagradMathFunctor {
} // namespace
} // namespace at::native
} // namespace at::native

View File

@ -1,8 +1,8 @@
add_loop_eager,compile_time_instruction_count,3070000000,0.1
add_loop_eager,compile_time_instruction_count,3184000000,0.1
add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1
add_loop_eager_dynamic,compile_time_instruction_count,4595000000,0.1
@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1048000000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1096000000,0.1
@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17720000000,0.1
@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000
update_hint_regression,compile_time_instruction_count,1719000000,0.1
update_hint_regression,compile_time_instruction_count,1645000000,0.1
sum_floordiv_regression,compile_time_instruction_count,3686995725,0.1
sum_floordiv_regression,compile_time_instruction_count,3813000000,0.1
@ -50,31 +50,31 @@ symint_sum_loop,compile_time_instruction_count,4299000000,0.1
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1869000000,0.1
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1793000000,0.1
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5281000000,0.1
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5120000000,0.1
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8333000000,0.1
aotdispatcher_partitioner_cpu,compile_time_instruction_count,7936000000,0.1
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1909000000,0.1
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1848000000,0.1
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3442000000,0.1
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3152000000,0.1
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9239000000,0.1
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,8301000000,0.1
mm_loop_inductor_gpu,compile_time_instruction_count,4820968837,0.1
mm_loop_inductor_gpu,compile_time_instruction_count,4958000000,0.1
@ -82,8 +82,8 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,9051000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,9990000000,0.1
basic_InlineMod_eager,compile_time_instruction_count,7618000000,0.1
basic_InlineMod_eager,compile_time_instruction_count,8126000000,0.1

Row Benchmark Metric Old New Threshold
1 add_loop_eager compile_time_instruction_count 3070000000 3184000000 0.1
2 add_loop_eager_dynamic compile_time_instruction_count 4432000000 4595000000 0.1
3 add_loop_inductor compile_time_instruction_count 29660000000 29660000000 0.1
4 add_loop_inductor_dynamic_gpu compile_time_instruction_count 39910000000 39910000000 0.1
5 add_loop_inductor_gpu compile_time_instruction_count 26800000000 26800000000 0.1
6 basic_modules_ListOfLinears_eager compile_time_instruction_count 1048000000 1096000000 0.1
7 basic_modules_ListOfLinears_inductor compile_time_instruction_count 15240000000 15240000000 0.1
8 basic_modules_ListOfLinears_inductor_gpu_force_shape_pad compile_time_instruction_count 17020000000 17720000000 0.1
18 aotdispatcher_training_nosubclass_cpu compile_time_instruction_count 3442000000 3152000000 0.1
19 aotdispatcher_training_subclass_cpu compile_time_instruction_count 9239000000 8301000000 0.1
20 mm_loop_inductor_gpu compile_time_instruction_count 4820968837 4958000000 0.1
21 mm_loop_inductor_dynamic_gpu compile_time_instruction_count 9051000000 9051000000 0.1
22 basic_NestedModule_eager compile_time_instruction_count 9554000000 9990000000 0.1
23 basic_InlineMod_eager compile_time_instruction_count 7618000000 8126000000 0.1

View File

@ -394,6 +394,10 @@ an opaque group handle that can be given as a `group` argument to all collective
.. autofunction:: new_group
```
```{eval-rst}
.. autofunction:: torch.distributed.distributed_c10d.shrink_group
```
```{eval-rst}
.. autofunction:: get_group_rank
```

View File

@ -0,0 +1,43 @@
import logging
import time
_start_time = time.time()
_logger = logging.getLogger(__name__)
def _ts():
return time.time() - _start_time
def configure(level=logging.INFO, force=False):
try:
logging.basicConfig(
level=level,
format="%(asctime)s %(name)s %(levelname)s: %(message)s",
force=force,
)
except TypeError:
logging.basicConfig(
level=level, format="%(asctime)s %(name)s %(levelname)s: %(message)s"
)
def log_test_info(rank, message):
_logger.info("[%7.3fs][Rank %s] %s", _ts(), rank, message)
def log_test_success(rank, message):
_logger.info("[%7.3fs][Rank %s] ✅ %s", _ts(), rank, message)
def log_test_validation(rank, message):
_logger.info("[%7.3fs][Rank %s] ✓ %s", _ts(), rank, message)
def log_test_warning(rank, message):
_logger.warning("[%7.3fs][Rank %s] ⚠️ %s", _ts(), rank, message)
def log_test_error(rank, message):
_logger.error("[%7.3fs][Rank %s] ✗ %s", _ts(), rank, message)

View File

@ -2,6 +2,7 @@
import copy
import json
import logging
import os
import pickle
import random
@ -21,6 +22,7 @@ from unittest import mock, SkipTest
import torch
import torch.distributed as c10d
import torch.distributed._functional_collectives as _functional_collectives
from torch.distributed.distributed_c10d import SHRINK_ABORT as NCCL_SHRINK_ABORT
if not c10d.is_available() or not c10d.is_nccl_available():
@ -47,12 +49,15 @@ from torch._C._distributed_c10d import ErrorType, OpType, WorkResult
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_cuda import _get_torch_rocm_version, TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
get_required_world_size,
get_timeout,
init_multigpu_helper,
MultiProcessTestCase,
requires_multicast_support,
requires_nccl,
requires_nccl_shrink,
requires_nccl_version,
requires_world_size,
skip_if_lt_x_gpu,
skip_if_rocm_multiprocess,
sm_is_or_higher_than,
@ -87,6 +92,17 @@ BFLOAT16_AVAILABLE = torch.cuda.is_available() and (
torch.version.cuda is not None or torch.version.hip is not None
)
from logging_utils import (
configure as _log_configure,
log_test_info,
log_test_success,
log_test_validation,
log_test_warning,
)
_log_configure(level=logging.INFO, force=True)
class RendezvousEnvTest(TestCase):
@retry_on_connect_failures
@ -317,7 +333,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
@property
def world_size(self):
return 2
return get_required_world_size(self, 2)
@property
def rank_to_GPU(self):
@ -1255,6 +1271,628 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
pg_2 = c10d.new_group([0, 1])
self.assertEqual(pg_2.group_desc, "undefined")
@requires_nccl_shrink()
@requires_world_size(2)
def test_shrink_group_basic(self):
"""Test basic shrink_group functionality."""
self._perform_shrink_test([1], "Basic shrink test")
@requires_nccl_shrink()
@requires_world_size(2)
def test_shrink_group_validation(self):
"""Test input validation in shrink_group."""
device, pg = self._setup_shrink_test("validation")
def _test_invalid_input(ranks, description, expected_exception):
"""Helper to test invalid inputs."""
try:
c10d.shrink_group(ranks)
self.fail(f"Expected {expected_exception.__name__} for {description}")
except expected_exception:
log_test_validation(self.rank, f"{description}")
except Exception:
if expected_exception is Exception: # Accept any exception
log_test_validation(self.rank, f"{description}")
else:
raise
# Test cases
_test_invalid_input([], "Empty exclusion list", ValueError)
if self.world_size > 1:
_test_invalid_input([0, 0, 1], "Duplicate ranks", Exception)
_test_invalid_input([self.world_size + 1], "Out of bounds rank", Exception)
log_test_success(self.rank, "All validation tests passed")
dist.destroy_process_group()
@requires_nccl_shrink()
@requires_world_size(2)
def test_shrink_group_backend_properties(self):
"""Test that backend properties are preserved after shrinking."""
test_name = "Backend Properties Test"
ranks_to_exclude = [0]
# Reuse _setup_shrink_test for complete setup (device, environment, and process group)
device, pg = self._setup_shrink_test("backend_properties")
# Follow _perform_shrink_test pattern from here
log_test_info(self.rank, f"{test_name} (world_size={self.world_size})")
is_excluded = self.rank in ranks_to_exclude
log_test_info(
self.rank,
f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
)
# Store original backend property values (not references) before shrinking
original_timeout = None
original_high_priority = None
if not is_excluded:
original_backend = pg._get_backend(device)
original_timeout = original_backend.options._timeout
original_high_priority = original_backend.options.is_high_priority_stream
log_test_info(
self.rank,
f"Storing original backend properties: timeout={original_timeout}, high_priority={original_high_priority}",
)
if is_excluded:
log_test_info(
self.rank,
f"Excluded rank {self.rank} - setup complete, skipping shrink operation",
)
dist.destroy_process_group() # hang without it
return
# Only non-excluded ranks proceed with shrink (same as _perform_shrink_test)
log_test_info(self.rank, "Non-excluded rank calling shrink_group")
shrunk_pg = c10d.shrink_group(ranks_to_exclude)
# Reuse _validate_shrunk_group helper (same as _perform_shrink_test)
expected_size = self.world_size - len(ranks_to_exclude)
_ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name)
# Add custom backend properties validation
new_backend = shrunk_pg._get_backend(device)
log_test_info(self.rank, "Validating backend properties are preserved")
new_timeout = new_backend.options._timeout
new_high_priority = new_backend.options.is_high_priority_stream
log_test_info(
self.rank,
f"Timeout comparison - original: {original_timeout}, new: {new_timeout}",
)
self.assertEqual(
original_timeout, new_timeout, f"{test_name}: timeout not preserved"
)
log_test_info(
self.rank,
f"High priority stream comparison - original: {original_high_priority}, new: {new_high_priority}",
)
self.assertEqual(
original_high_priority,
new_high_priority,
f"{test_name}: high_priority_stream not preserved",
)
log_test_validation(
self.rank, f"{test_name}: Backend properties preserved successfully"
)
log_test_success(
self.rank, f"{test_name} successful (shrink + backend validation)"
)
# Cleanup (same as _perform_shrink_test)
dist.destroy_process_group()
@requires_nccl_shrink()
@requires_world_size(2)
def test_shrink_group_multiple_comms(self):
"""Test shrink_group with multiple communicators and subgroup invalidation."""
device, pg = self._setup_shrink_test("multiple_comms")
# Create subgroup [0, 1] and test shrinking it
subgroup = c10d.new_group([0, 1])
if self.rank <= 1:
# Shrink subgroup: exclude rank 1
if self.rank == 0: # Only rank 0 remains
shrunk_subgroup = c10d.shrink_group([1], group=subgroup)
self.assertEqual(shrunk_subgroup.size(), 1)
# Test communication on shrunk subgroup
tensor = torch.full((1,), self.rank).cuda(device)
c10d.all_reduce(tensor, group=shrunk_subgroup)
self.assertEqual(tensor.item(), 0) # Only rank 0
log_test_success(self.rank, "Subgroup shrinking successful")
dist.barrier() # Sync before default group test
# Shrink default group: exclude last rank
ranks_to_exclude = [self.world_size - 1]
if self.rank not in ranks_to_exclude:
shrunk_default = c10d.shrink_group(ranks_to_exclude)
expected_size = self.world_size - 1
self.assertEqual(shrunk_default.size(), expected_size)
# Test collective on shrunk default group
tensor = torch.full((1,), self.rank).cuda(device)
c10d.all_reduce(tensor, group=shrunk_default)
expected_sum = sum(
range(self.world_size - 1)
) # 0 + 1 + ... + (world_size-2)
self.assertEqual(tensor.item(), expected_sum)
log_test_success(self.rank, "Default group shrinking successful")
# Note: After shrinking default group, the old subgroup is invalid
# due to global rank reassignment
dist.destroy_process_group()
def _test_shrink_group_with_flag(self, shrink_flag, flag_name, rank_to_exclude):
"""Helper method to test shrink_group with a specific flag."""
if self.world_size < 2:
log_test_info(self.rank, f"Skipping (needs ≥2 GPUs, got {self.world_size})")
return
ranks_to_exclude = [rank_to_exclude]
log_test_info(self.rank, f"Using {flag_name} flag (value: {shrink_flag})")
if flag_name == "NCCL_SHRINK_ABORT":
log_test_info(
self.rank,
"ABORT flag will terminate ongoing operations before shrinking",
)
self._perform_shrink_test(
ranks_to_exclude, f"{flag_name} flag test", shrink_flags=shrink_flag
)
@requires_nccl_shrink()
@requires_world_size(2)
def test_shrink_group_flags(self):
"""Test shrink_group with different shrink flags."""
# Test ABORT flags
log_test_info(self.rank, "Testing NCCL_SHRINK_ABORT flag")
self._test_shrink_group_with_flag(NCCL_SHRINK_ABORT, "NCCL_SHRINK_ABORT", 1)
@requires_nccl_shrink()
@requires_world_size(2)
def test_shrink_group_nccl_config(self):
"""Verify that passing NCCL config via pg_options influences the shrunk group's backend options."""
device, pg = self._setup_shrink_test("config")
if self.rank == self.world_size - 1:
# excluded rank should not call shrink_group
dist.destroy_process_group()
return
# Prepare pg_options with NCCL config overrides
# Capture parent's current backend options to ensure we can prove override vs inherit
parent_backend = pg._get_backend(torch.device("cuda"))
parent_hp = parent_backend.options.is_high_priority_stream
parent_blocking = parent_backend.options.config.blocking
# Choose overrides that differ from the parent (flip where possible)
override_hp = not parent_hp
if parent_blocking in (0, 1):
override_blocking = 1 - parent_blocking
else:
# If undefined or unexpected, set to 1 which is a concrete value
override_blocking = 1
opts = c10d.ProcessGroupNCCL.Options()
opts.is_high_priority_stream = override_hp
opts.config.blocking = override_blocking
shrunk_pg = c10d.shrink_group([self.world_size - 1], pg_options=opts)
# Validate backend options propagated
backend = shrunk_pg._get_backend(torch.device("cuda"))
# is_high_priority_stream should exactly match our override and differ from parent
self.assertEqual(backend.options.is_high_priority_stream, override_hp)
self.assertNotEqual(backend.options.is_high_priority_stream, parent_hp)
# config is a struct; check representative field and difference from parent when meaningful
self.assertEqual(backend.options.config.blocking, override_blocking)
if parent_blocking in (0, 1):
self.assertNotEqual(backend.options.config.blocking, parent_blocking)
dist.destroy_process_group()
@requires_nccl_shrink()
@requires_world_size(2)
def test_shrink_group_performance(self):
"""Test shrink_group performance and regression detection."""
import time
ranks_to_exclude = self._get_default_ranks_to_exclude()
is_excluded = self.rank in ranks_to_exclude
if not ranks_to_exclude:
log_test_info(self.rank, "Skipping performance test (world_size=1)")
return
log_test_info(self.rank, f"Performance test with {self.world_size} processes")
device, pg = self._setup_shrink_test("performance")
if not is_excluded:
log_test_info(self.rank, "Measuring shrink_group performance")
start_time = time.time()
shrunk_pg = c10d.shrink_group(ranks_to_exclude)
end_time = time.time()
elapsed_time = end_time - start_time
log_test_info(self.rank, f"shrink_group: {elapsed_time:.3f}s")
# Regression check: should complete within reasonable time
self.assertLess(
elapsed_time,
30.0,
f"shrink_group took {elapsed_time:.3f}s, possible regression",
)
# Test collective performance
expected_size = self.world_size - len(ranks_to_exclude)
self._validate_shrunk_group(shrunk_pg, expected_size, "performance")
collective_start = time.time()
_ = self._test_collective_on_shrunk_group(
shrunk_pg, device, ranks_to_exclude, "performance"
)
collective_time = time.time() - collective_start
log_test_info(self.rank, f"all_reduce: {collective_time:.3f}s")
log_test_success(self.rank, "Performance test passed")
else:
log_test_info(self.rank, "Excluded rank - waiting")
dist.destroy_process_group()
@requires_nccl_shrink()
@requires_world_size(4)
def test_shrink_group_multiple_exclusions(self):
"""Test shrink_group with multiple ranks excluded at once."""
# Scale exclusions with world size
ranks_to_exclude = list(range(2, self.world_size, 2)) # Every other rank from 2
self._perform_shrink_test(ranks_to_exclude, "Multiple exclusions test")
@requires_nccl_shrink()
@requires_world_size(3)
def test_shrink_group_multiple_iterations(self):
"""Test multiple shrink operations in sequence."""
log_test_info(
self.rank,
f"Starting test_shrink_group_multiple_iterations with world_size={self.world_size}",
)
store = c10d.FileStore(self.file_name, self.world_size)
device = torch.device(f"cuda:{self.rank}")
_ = self._create_process_group_nccl(store, self.opts(), device_id=device)
# Track current effective world size throughout shrinking operations
current_world_size = self.world_size
log_test_info(self.rank, f"Initial world_size: {current_world_size}")
# First shrinking: exclude the last rank(s)
first_exclusion = [self.world_size - 1]
if self.world_size >= 6:
first_exclusion.append(
self.world_size - 2
) # Exclude last two ranks for larger sizes
log_test_info(self.rank, f"First shrinking: excluding ranks {first_exclusion}")
if self.rank not in first_exclusion:
# Only non-excluded ranks should call shrink_group
first_pg = c10d.shrink_group(first_exclusion)
self.assertIsNotNone(first_pg)
# IMPORTANT: Update world size after first shrinking
current_world_size = first_pg.size()
expected_first_size = self.world_size - len(first_exclusion)
log_test_info(
self.rank,
f"After first shrinking: world_size {self.world_size} -> {current_world_size}",
)
self.assertEqual(first_pg.size(), expected_first_size)
# Second shrinking: exclude another rank from the remaining group
# Choose a rank that's in the middle range
if current_world_size >= 3:
second_exclusion = [
current_world_size - 1
] # Exclude the new "last" rank
log_test_info(
self.rank,
f"Second shrinking from group of size {current_world_size}: excluding ranks {second_exclusion}",
)
if self.rank not in second_exclusion:
# Only non-excluded ranks should call shrink_group for second iteration
second_pg = c10d.shrink_group(second_exclusion, group=first_pg)
self.assertIsNotNone(second_pg)
# IMPORTANT: Update world size after second shrinking
final_world_size = second_pg.size()
expected_final_size = current_world_size - len(second_exclusion)
log_test_info(
self.rank,
f"After second shrinking: world_size {current_world_size} -> {final_world_size}",
)
self.assertEqual(second_pg.size(), expected_final_size)
# Test collective on final group
tensor = torch.full((1,), self.rank).cuda(device)
log_test_info(
self.rank,
f"Performing all_reduce on final group (size {final_world_size}) with tensor: {tensor.item()}",
)
c10d.all_reduce(tensor, group=second_pg)
log_test_info(
self.rank,
f"Final all_reduce completed, result: {tensor.item()}",
)
# Calculate expected sum of remaining ranks
all_excluded = set(first_exclusion + second_exclusion)
remaining_ranks = [
r for r in range(self.world_size) if r not in all_excluded
]
expected_sum = sum(remaining_ranks)
log_test_info(
self.rank,
f"Remaining ranks: {remaining_ranks}, expected sum: {expected_sum}, actual: {tensor.item()}",
)
self.assertEqual(tensor.item(), expected_sum)
log_test_info(self.rank, "Final verification passed")
else:
log_test_info(
self.rank,
"This rank excluded in second shrinking, not calling shrink_group",
)
else:
log_test_info(
self.rank, "Skipping second shrinking (remaining group too small)"
)
else:
log_test_info(
self.rank,
"This rank excluded in first shrinking, not calling shrink_group",
)
log_test_info(self.rank, "Destroying process group")
dist.destroy_process_group()
log_test_info(self.rank, "test_shrink_group_multiple_iterations completed")
# Helper methods for optimized shrink group tests
def _setup_shrink_test(self, test_suffix, world_size=None, warmup=True):
"""Common setup for shrink group tests."""
os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1"
world_size = world_size or self.world_size
store = c10d.FileStore(self.file_name + f"_{test_suffix}", world_size)
device = torch.device(f"cuda:{self.rank}")
c10d.init_process_group(
"nccl",
world_size=world_size,
rank=self.rank,
store=store,
pg_options=self.opts(),
device_id=device,
)
pg = c10d.distributed_c10d._get_default_group()
if warmup:
c10d.all_reduce(torch.ones(1).cuda(device), group=pg)
return device, pg
def _validate_shrunk_group(self, shrunk_pg, expected_size, test_name=""):
"""Validate properties of a shrunk process group."""
self.assertIsNotNone(shrunk_pg, f"{test_name}: shrunk_pg should not be None")
actual_size = shrunk_pg.size()
self.assertEqual(
actual_size, expected_size, f"{test_name}: group size mismatch"
)
new_rank = shrunk_pg.rank()
self.assertTrue(
0 <= new_rank < expected_size, f"{test_name}: invalid new rank {new_rank}"
)
log_test_info(
self.rank,
f"{test_name}: world_size {self.world_size} -> {actual_size}, rank {self.rank} -> {new_rank}",
)
return new_rank
def _test_collective_on_shrunk_group(
self, shrunk_pg, device, ranks_to_exclude, test_name=""
):
"""Test collective communication on shrunk group and verify correctness."""
test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32)
c10d.all_reduce(test_tensor, group=shrunk_pg)
result = test_tensor.item()
expected_sum = sum(
r for r in range(self.world_size) if r not in ranks_to_exclude
)
self.assertEqual(
result, expected_sum, f"{test_name}: collective result mismatch"
)
log_test_info(
self.rank, f"{test_name}: collective passed ({result} == {expected_sum})"
)
return result
def _perform_shrink_test(
self, ranks_to_exclude, test_name, shrink_flags=0, with_collective=True
):
"""Complete shrink test flow: setup, shrink, validate, test collective, cleanup.
Consistent API: All ranks perform setup to initialize distributed environment.
ONLY non-excluded ranks call shrink_group() for both default and non-default groups.
Excluded ranks perform setup, then exit without calling shrink_group() or waiting.
"""
log_test_info(self.rank, f"{test_name} (world_size={self.world_size})")
is_excluded = self.rank in ranks_to_exclude
log_test_info(
self.rank,
f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
)
# All ranks (including excluded ones) perform setup to initialize distributed environment
device, pg = self._setup_shrink_test(test_name.lower().replace(" ", "_"))
is_default_group = pg == c10d.distributed_c10d._get_default_group()
if is_excluded:
log_test_info(
self.rank,
f"Excluded rank {self.rank} - setup complete, skipping shrink operation",
)
if shrink_flags & NCCL_SHRINK_ABORT:
log_test_info(self.rank, f"Using abort for excluded rank {self.rank}")
pg._get_backend(torch.device(device)).abort()
log_test_info(
self.rank, f"cleanup resources for excluded rank {self.rank}"
)
dist.destroy_process_group()
log_test_info(self.rank, f"Excluded rank {self.rank} - exit")
else:
log_test_info(
self.rank, f"Using regular destroy for excluded rank {self.rank}"
)
dist.destroy_process_group()
return None
# Only non-excluded ranks proceed with shrink
log_test_info(
self.rank,
f"Non-excluded rank calling shrink_group (default_group={is_default_group})",
)
shrunk_pg = c10d.shrink_group(ranks_to_exclude, shrink_flags=shrink_flags)
log_test_info(
self.rank,
f"Non-excluded rank calling shrink_group (default_group={is_default_group}) done",
)
# Non-excluded ranks: validate and test the new group
expected_size = self.world_size - len(ranks_to_exclude)
_ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name)
if with_collective:
_ = self._test_collective_on_shrunk_group(
shrunk_pg, device, ranks_to_exclude, test_name
)
log_test_success(self.rank, f"{test_name} successful (shrink + collective)")
else:
log_test_success(self.rank, f"{test_name} successful (shrink only)")
dist.destroy_process_group()
return shrunk_pg
def _get_default_ranks_to_exclude(self):
"""Get default ranks to exclude based on world size."""
if self.world_size <= 1:
return []
return [self.world_size - 1] # Exclude last rank by default
@requires_nccl_shrink()
@requires_world_size(3)
def test_shrink_group_vs_abort_reinit_performance(self):
"""Compare performance of shrink_group vs traditional abort+reinit (simplified for reliability)."""
log_test_info(self.rank, "=== TEST 1: abort+reinit ===")
device, pg1 = self._setup_shrink_test("_perf_reinit")
torch.cuda.synchronize(device)
# Test 1: Traditional abort + reinit
start_time = time.perf_counter()
dist.destroy_process_group()
device, new_pg = self._setup_shrink_test("perf_shrink_test1")
reinit_time = time.perf_counter() - start_time
# Test collective with original rank values for fair comparison (non-blocking mode)
test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32)
work = c10d.all_reduce(test_tensor, group=new_pg, async_op=True)
work.wait()
torch.cuda.synchronize(device)
# Verify correctness
expected_sum = sum(r for r in range(self.world_size))
self.assertEqual(test_tensor.item(), expected_sum, "Reinit collective failed")
log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s")
dist.destroy_process_group(new_pg)
# Test 2: shrink_group with NCCL_SHRINK_ABORT
log_test_info(self.rank, "=== TEST 2: shrink_group ===")
ranks_to_exclude = [self.world_size - 1]
is_excluded = self.rank in ranks_to_exclude
log_test_info(
self.rank,
f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
)
device, pg1 = self._setup_shrink_test("perf_shrink_test2") # Unique suffix
shrink_time = 0
if not is_excluded:
torch.cuda.synchronize(device) # Ensure accurate timing
start_time = time.perf_counter()
shrunk_pg = c10d.shrink_group(
ranks_to_exclude, shrink_flags=NCCL_SHRINK_ABORT
)
c10d.all_reduce(torch.ones(1).cuda(device), group=shrunk_pg)
shrink_time = time.perf_counter() - start_time
# Test collective communication on shrunk group (non-blocking mode)
test_tensor = torch.full(
(1,), self.rank, device=device, dtype=torch.float32
)
work = c10d.all_reduce(test_tensor, group=shrunk_pg, async_op=True)
work.wait()
# Verify correctness
expected_sum = sum(
r for r in range(self.world_size) if r not in ranks_to_exclude
)
self.assertEqual(
test_tensor.item(),
expected_sum,
"shrink_test: collective result mismatch",
)
torch.cuda.synchronize(device) # Ensure operations complete
log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s")
dist.destroy_process_group()
else:
log_test_info(self.rank, "Excluded from shrink test - exiting immediately")
dist.destroy_process_group()
return
# Performance analysis (only for participating ranks)
if shrink_time > 0 and reinit_time > 0:
speedup = reinit_time / shrink_time
time_saved = reinit_time - shrink_time
log_test_info(self.rank, "=== PERFORMANCE RESULTS ===")
log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s")
log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s")
log_test_info(self.rank, f"time_saved: {time_saved:+.4f}s")
log_test_info(self.rank, f"speedup: {speedup:.2f}x")
if speedup > 1.1:
log_test_success(self.rank, "shrink_group significantly faster")
elif speedup > 0.9:
log_test_info(self.rank, "≈ comparable performance")
else:
log_test_warning(self.rank, "abort+reinit faster")
log_test_info(self.rank, "Performance test completed")
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_deterministic_mode_no_break(self):

View File

@ -153,7 +153,9 @@ def _get_custom_policy(no_recompute_list=None, must_recompute_list=None):
return _custom_policy
class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
class ActivationCheckpointingViaTagsTests(
torch._dynamo.test_case.TestCaseWithNestedGraphBreaks
):
def _validate(
self,
fn,

View File

@ -77,7 +77,7 @@ def customized_ctx_manager_with_graph_break(mode):
torch._C._set_grad_enabled(prev)
class CtxManagerTests(torch._dynamo.test_case.TestCase):
class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
def test_no_grad(self):
def fn1(a, b):
x = a + 1
@ -1706,7 +1706,7 @@ class GraphModule(torch.nn.Module):
cnts = torch._dynamo.testing.CompileCounter()
opt_f = torch.compile(f, backend=cnts)
opt_f(torch.randn(2, 2, 2, 2).to(dtype=torch.float16))
self.assertEqual(cnts.frame_count, 3)
self.assertEqual(cnts.frame_count, 2)
# test sdpa_kernel graph break with 2 arguments
def test_sdpa_kernel_ctx_manager3(self):
@ -1836,14 +1836,18 @@ class GraphModule(torch.nn.Module):
self.assertEqual(len(counters["graph_break"]), 1)
class ContextlibContextManagerTests(torch._dynamo.test_case.TestCase):
class ContextlibContextManagerTests(
torch._dynamo.test_case.TestCaseWithNestedGraphBreaks
):
def setUp(self):
super().setUp()
self._prev = torch._dynamo.config.enable_trace_contextlib
self._u_prev = torch._dynamo.config.enable_trace_unittest
torch._dynamo.config.enable_trace_contextlib = True
torch._dynamo.config.enable_trace_unittest = True
def tearDown(self):
super().tearDown()
torch._dynamo.config.enable_trace_contextlib = self._prev
torch._dynamo.config.enable_trace_unittest = self._u_prev

View File

@ -17,7 +17,7 @@ from torch.testing._internal.common_utils import (
)
class GeneratorTestsBase(torch._dynamo.test_case.TestCase):
class GeneratorTestsBase(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
def setUp(self):
super().setUp()
self._old = torch._dynamo.config.enable_faithful_generator_behavior

View File

@ -1,86 +1,29 @@
# Owner(s): ["module: dynamo"]
import sys
import unittest
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo import config
from torch._dynamo.testing import make_test_cls_with_patches
try:
# from . import test_ctx_manager
pass
except ImportError:
# import test_aot_autograd
# import test_ctx_manager
# import test_export
# import test_functions
# import test_higher_order_ops
# import test_misc
# import test_modules
# import test_repros
# import test_sdpa
# import test_subgraphs
pass
test_classes = {}
def make_nested_cls(cls):
suffix = "_nested_graph_breaks"
cls_prefix = "NestedGraphBreaks"
test_class = make_test_cls_with_patches(
cls,
cls_prefix,
suffix,
(config, "debug_force_nested_calls", True),
(config, "debug_force_graph_break_on_leaf_return", True),
(config, "debug_disable_compile_counter", True),
xfail_prop="_expected_failure_nested_graph_breaks",
)
test_classes[test_class.__name__] = test_class
# REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
# globals()[test_class.__name__] = test_class
test_class.__module__ = __name__
return test_class
tests = [
# test_ctx_manager.CtxManagerTests,
# test_functions.FunctionTests,
# test_misc.MiscTests,
# test_repros.ReproTests,
# test_modules.NNModuleTests,
# test_subgraphs.SubGraphTests,
# test_higher_order_ops.HigherOrderOpTests,
# test_higher_order_ops.FuncTorchHigherOrderOpTests,
# test_aot_autograd.AotAutogradFallbackTests,
# test_sdpa.TestSDPA,
]
test = None
for test in tests:
make_nested_cls(test)
del test
# for use in test_side_effects_globals
global1, global2, global3, global4 = (torch.zeros(3),) * 4
class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
def setUp(self):
super().setUp()
torch._dynamo.config.nested_graph_breaks = True
class CustomizedCtxManager:
def __init__(self, x):
self.x = x
torch._dynamo.graph_break()
def tearDown(self):
super().tearDown()
torch._dynamo.config.nested_graph_breaks = False
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, traceback):
pass
class NestedGraphBreakTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
def test_single_graph_break(self):
# NOTE marking f1, f2, f3 as global
# prevents them from being freevars
@ -211,6 +154,31 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 14)
def test_counters(self):
global f1, f2, f3, f4
def f1(x):
x = x + 1
torch._dynamo.graph_break()
return x + 2
def f2(x):
return f1(x + 4) + 8
def f3(x):
x = x + 16
for _ in range(1):
x = f2(x)
return x + 32
@torch.compile(backend="eager")
def f4(x):
return f3(x + 64) + 128
self.assertEqual(f4(torch.zeros(3)), torch.zeros(3) + 255)
self.assertEqual(len(torch._dynamo.utils.counters["graph_break"]), 2)
def test_supported_ctx_manager(self):
global check, check_disabled, f1, f2, f3
@ -612,6 +580,211 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
self.assertEqual(cnts.op_count, 4)
self.assertEqual(torch._dynamo.utils.counters["frames"]["total"], 3)
def test_generator_nested_graph_break(self):
def gen(x):
yield x + 1
torch._dynamo.graph_break()
yield x + 2
def fn(x):
x = x + 4
return list(gen(x))
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(backend=cnts)(fn)
x = torch.zeros(3)
res = fn(x)
# NOTE: if we enable nested graph breaks on inlined generators, we expect
# some sort of internal dynamo failure
ref = opt_fn(x)
self.assertEqual(ref, res)
# fn should be skipped
self.assertEqual(cnts.frame_count, 0)
def outer(x):
x = x + 8
return fn(x)[0] + 16
cnts.clear()
torch.compiler.reset()
opt_fn = torch._dynamo.optimize(backend=cnts)(outer)
x = torch.zeros(3)
res = outer(x)
ref = opt_fn(x)
self.assertEqual(ref, res)
# only outer should be traced
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 2)
def test_return_after_graph_break_nested(self):
# With improper implementation, returning immediately after a nested graph
# break may skip the rest of the top-level frame.
def f2(inner, x):
x += 2
return inner(x)
@torch.compile(backend="eager")
def f3(inner, x):
result = f2(inner, x)
x += 4
if result is not None:
x += result
return x
# test normal graph break
x = torch.zeros(3)
def inner1(x):
x += 1
return torch._dynamo.graph_break()
ref = f3(inner1, x)
self.assertEqual(ref, torch.zeros(3) + 7)
# test step graph break
x = torch.zeros(3)
def inner2(x):
x += 1
return torch._dynamo.step_unsupported()
ref = f3(inner2, x)
self.assertEqual(ref, torch.zeros(3) + 7)
# test store attr graph break
# NOTE: we do this manual bytecode generation hack since the only RETURN_*
# instruction that can follow STORE_ATTR is RETURN_CONST, which was removed in 3.14+.
# make sure inner3's code options are compatible with the instructions below
def inner3(x):
x.attr = 1000
new_inst = torch._dynamo.bytecode_transformation.create_instruction
insts = [
new_inst("LOAD_CONST", argval=1000),
new_inst("LOAD_CONST", argval=2000),
new_inst("LOAD_FAST", argval="x"),
new_inst("STORE_ATTR", argval="attr"),
new_inst("RETURN_VALUE"),
]
if sys.version_info >= (3, 11):
insts = [new_inst("RESUME", arg=0)] + insts
code_keys = torch._dynamo.bytecode_transformation.get_code_keys()
code_options = {k: getattr(inner3.__code__, k) for k in code_keys}
_, inner3_code = (
torch._dynamo.bytecode_transformation.clean_and_assemble_instructions(
insts, code_keys, code_options
)
)
inner3.__code__ = inner3_code
x = torch.zeros(3)
ref = f3(inner3, x)
self.assertEqual(ref, torch.zeros(3) + 1006)
# dynamic branching is harder to test - the other tests should be enough cover
# test every function returning
@torch.compiler.disable
def inner5(x):
x += 8
return x
def inner4(x):
x += 1
return inner5(x)
@torch.compile(backend="eager")
def f4(x):
x += 4
return f2(inner4, x)
x = torch.zeros(3)
ref = f4(x)
self.assertEqual(ref, torch.zeros(3) + 15)
def test_return_after_graph_break_deep_nested(self):
@torch.compiler.disable
def f1(x):
return x + 1
def f2(x):
return f1(x + 2)
def f3(x):
return f2(x + 4)
def f4(x):
x = f3(x + 8)
return x + 16
def f5(x):
return f4(x + 32)
def f6(x):
return f5(x + 64)
def f7(x):
x = f6(x + 128)
return x + 256
@torch.compile(backend="eager")
def f8(x):
return f7(x + 512)
x = torch.zeros(3)
ref = f8(x)
self.assertEqual(ref, torch.zeros(3) + 1023)
# check that only 2 resume functions are created
self.assertEqual(len(torch._dynamo.utils.counters["resumes"]), 2)
for name in ("resume_in_f4", "resume_in_f7"):
self.assertTrue(
any(
name in key
for key in torch._dynamo.utils.counters["resumes"].keys()
)
)
@unittest.expectedFailure
def test_nested_decorated_function(self):
# decorator must call ContextWrappingVariable.cleanup_assert to trigger this test
def f(x):
@torch.autocast("cpu")
def inner(y):
y = y + 1
torch._dynamo.graph_break()
return y + 1
return inner(x)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(backend=cnts)(f)
x = torch.zeros(3)
res = f(x)
ref = opt_fn(x)
print(ref, res)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 6)
@unittest.expectedFailure
def test_nested_graph_break_in_custom_ctx_manager_init(self):
def f(x):
with CustomizedCtxManager(x):
return x + 1
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(backend=cnts)(f)
x = torch.zeros(3)
res = f(x)
ref = opt_fn(x)
print(ref, res)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 2)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -0,0 +1,136 @@
# Owner(s): ["module: dynamo"]
import unittest
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import make_test_cls_with_patches
try:
from . import test_activation_checkpointing, test_ctx_manager, test_misc
except ImportError:
import test_activation_checkpointing
import test_ctx_manager
import test_misc
test_classes = {}
def make_nested_cls(cls, strong):
config = torch._dynamo.config
if strong:
# A strong nested graph break test - will graph break at every leaf function's return
test_class = make_test_cls_with_patches(
cls,
"NestedGraphBreaksStrong",
"_nested_graph_breaks_strong",
(config, "nested_graph_breaks", True),
(config, "debug_force_nested_calls", True),
(config, "debug_force_graph_break_on_leaf_return", True),
(config, "debug_disable_compile_counter", True),
xfail_prop="_expected_failure_nested_graph_breaks_strong",
)
else:
test_class = make_test_cls_with_patches(
cls,
"NestedGraphBreaks",
"_nested_graph_breaks",
(config, "nested_graph_breaks", True),
(config, "debug_force_nested_calls", True),
(config, "debug_disable_compile_counter", True),
xfail_prop="_expected_failure_nested_graph_breaks",
)
test_classes[test_class.__name__] = test_class
# REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
globals()[test_class.__name__] = test_class
test_class.__module__ = __name__
tests = [
getattr(
test_activation_checkpointing, "ActivationCheckpointingViaTagsTestsCUDA", None
),
test_ctx_manager.CtxManagerTests,
test_misc.MiscTests,
]
strong_tests = []
test = None
for test in tests:
if not test:
continue
make_nested_cls(test, False)
for test in strong_tests:
make_nested_cls(test, True)
del test
xfails = [
# multiple exit due to nested graph break in decorator
# NestedGraphBreaksStrongCtxManagerTests.test_disable_saved_tensors_hooks_nested_graph_breaks_strong, # noqa: F821
# NestedGraphBreaksStrongCtxManagerTests.test_disable_saved_tensors_hooks_prev_disabled_nested_graph_breaks_strong, # noqa: F821
# NestedGraphBreaksStrongCtxManagerTests.test_disable_saved_tensors_hooks_prev_disabled_nested_nested_graph_breaks_strong, # noqa: F821
# graph break in context manager __init__
# NestedGraphBreaksStrongCtxManagerTests.test_generic_context_manager_CustomizedCtxManager_nested_graph_breaks_strong, # noqa: F821
# NestedGraphBreaksStrongCtxManagerTests.test_generic_context_manager_customized_ctx_manager_nested_graph_breaks_strong, # noqa: F821
# NestedGraphBreaksStrongCtxManagerTests.test_generic_context_manager_with_graph_break_CustomizedCtxManager_nested_graph_breaks_strong, # noqa: F821
# NestedGraphBreaksStrongCtxManagerTests.test_generic_context_manager_with_graph_break_customized_ctx_manager_nested_graph_breaks_strong, # noqa: F821
# NestedGraphBreaksStrongCtxManagerTests.test_generic_ctx_manager_with_graph_break_CustomizedCtxManagerWithGraphBreak_nested_graph_breaks_strong, # noqa: F821
# NestedGraphBreaksStrongCtxManagerTests.test_generic_ctx_manager_with_graph_break_customized_ctx_manager_with_graph_break_nested_graph_breaks_strong, # noqa: F821
# NestedGraphBreaksStrongCtxManagerTests.test_nested_generic_context_manager_CustomizedCtxManager_nested_graph_breaks_strong, # noqa: F821
# NestedGraphBreaksStrongCtxManagerTests.test_nested_generic_context_manager_customized_ctx_manager_nested_graph_breaks_strong, # noqa: F821
# NestedGraphBreaksStrongCtxManagerTests.test_nested_generic_context_manager_with_graph_break_CustomizedCtxManager_nested_graph_breaks_strong, # noqa: F821
# NestedGraphBreaksStrongCtxManagerTests.test_nested_generic_context_manager_with_graph_break_customized_ctx_manager_nested_graph_breaks_strong, # noqa: F821
# NestedGraphBreaksStrongCtxManagerTests.test_return_context_manager_nested_graph_breaks_strong, # noqa: F821
# NestedGraphBreaksStrongCtxManagerTests.test_return_context_manager_with_graph_break_nested_graph_breaks_strong, # noqa: F821
# recursion limit exceeded
# NestedGraphBreaksStrongCtxManagerTests.test_cuda_stream_compared_with_constant_nested_graph_breaks_strong, # noqa: F821
# variable naming issues
NestedGraphBreaksMiscTests.test_flat_name_to_original_fqn_nested_graph_breaks, # noqa: F821
NestedGraphBreaksMiscTests.test_compare_shapes_with_constant_nested_graph_breaks, # noqa: F821
NestedGraphBreaksMiscTests.test_guard_failure_fn2_nested_graph_breaks, # noqa: F821
NestedGraphBreaksMiscTests.test_guard_failure_fn_shape_control_nested_graph_breaks, # noqa: F821
NestedGraphBreaksMiscTests.test_guard_failure_fn_tensor_iter_nested_graph_breaks, # noqa: F821
NestedGraphBreaksMiscTests.test_guard_filter_fn_by_name_and_value_nested_graph_breaks, # noqa: F821
NestedGraphBreaksMiscTests.test_guard_sym_node_fstring_when_used_nested_graph_breaks, # noqa: F821
NestedGraphBreaksMiscTests.test_symint_as_device_kwarg_multi_gpu_nested_graph_breaks, # noqa: F821
NestedGraphBreaksMiscTests.test_sys_modules_nested_graph_breaks, # noqa: F821
# counters["graph_breaks"] issues
NestedGraphBreaksMiscTests.test_data_ptr_graph_break_aten_nested_graph_breaks, # noqa: F821
# nested graph break removes duplicate graph break
NestedGraphBreaksMiscTests.test_duplicate_graph_break_log_nested_graph_breaks, # noqa: F821
# doesn't work due to debug_force_nested_calls wrapping the top frame
NestedGraphBreaksMiscTests.test_dynamo_cache_invalidate_nested_graph_breaks, # noqa: F821
NestedGraphBreaksMiscTests.test_dynamo_cache_move_to_front_nested_graph_breaks, # noqa: F821
NestedGraphBreaksMiscTests.test_dynamo_reset_clears_cache_nested_graph_breaks, # noqa: F821
NestedGraphBreaksMiscTests.test_fail_on_recompile_error_message_nested_graph_breaks, # noqa: F821
NestedGraphBreaksMiscTests.test_get_cache_entry_nested_graph_breaks, # noqa: F821
NestedGraphBreaksMiscTests.test_getattrvariable_as_python_constant_nested_graph_breaks, # noqa: F821
NestedGraphBreaksMiscTests.test_precompile_entries_nested_graph_breaks, # noqa: F821
NestedGraphBreaksMiscTests.test_precompile_entry_hit_nested_graph_breaks, # noqa: F821
NestedGraphBreaksMiscTests.test_precompile_fail_on_recompile_nested_graph_breaks, # noqa: F821
NestedGraphBreaksMiscTests.test_torch_guards_stack_frame_register_inlining_deep_nested_graph_breaks, # noqa: F821
NestedGraphBreaksMiscTests.test_torch_guards_stack_frame_register_inlining_nested_graph_breaks, # noqa: F821
# differing op_count
NestedGraphBreaksMiscTests.test_nested_closure_nested_graph_breaks, # noqa: F821
NestedGraphBreaksMiscTests.test_return_nested_function_nested_graph_breaks, # noqa: F821
# unknown
NestedGraphBreaksMiscTests.test_inspect_signature_bind_non_user_function_nested_graph_breaks, # noqa: F821
]
case = None
for case in xfails:
unittest.expectedFailure(case)
del case, xfails
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()

View File

@ -1,4 +1,5 @@
# Owner(s): ["module: dynamo"]
import functools
from unittest.mock import patch
import torch
@ -548,6 +549,30 @@ class RecompileTests(torch._dynamo.test_case.TestCase):
f(x, foo1)
self.assertEqual(counter.frame_count, 2)
@dc.patch(recompile_limit=1, fail_on_recompile_limit_hit=True)
def test_wrap_inline_recompiles(self):
inp = torch.ones(3)
wrap_fns = (
torch._dynamo.external_utils.wrap_inline,
functools.partial(
torch._dynamo.external_utils.wrap_inline_with_error_on_graph_break,
error_on_graph_break=True,
),
functools.partial(
torch._dynamo.external_utils.wrap_inline_with_error_on_graph_break,
error_on_graph_break=False,
),
)
for fn in wrap_fns:
for i in range(2):
opt_fn = torch.compile(
fn(lambda x: x + i),
backend="eager",
)
self.assertEqual(inp + i, opt_fn(inp))
def test_no_recompile_over_unused_objects(self):
# This is a regression test case that imitates
# https://github.com/city96/ComfyUI-GGUF/blob/47bec6147569a138dd30ad3e14f190a36a3be456/ops.py#L169-L182

View File

@ -6,6 +6,7 @@ import builtins
import collections
import contextlib
import copy
import gc
import functools
import inspect
import io
@ -19,6 +20,7 @@ import traceback
import types
import typing
import unittest
import weakref
import warnings
from math import sqrt
from torch.multiprocessing import Process
@ -1624,6 +1626,25 @@ class TestFX(JitTestCase):
self.assertTrue(neg not in relu.users)
@skipIfTorchDynamo("Dynamo does not free right away")
def test_prepend_does_not_leak(self):
g = Graph()
x = g.placeholder("x")
relu = g.call_function(torch.relu, (x,))
neg = g.call_function(torch.neg, (x,))
relu.prepend(neg)
ref = weakref.ref(neg)
g.erase_node(neg)
del g
del x
del relu
del neg
gc.collect()
self.assertIsNone(ref())
def test_remove_uses_with_custom_filter(self):
g: torch.fx.Graph = Graph()
x: torch.fx.Node = g.placeholder("x")

View File

@ -2758,6 +2758,12 @@ class _NodeBase:
return_type: Any,
) -> None: ...
def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ...
def _prepend(self, n: FxNode) -> None: ...
def _remove_from_list(self) -> None: ...
def __lt__(self, n: Self) -> _bool: ...
def __gt__(self, n: Self) -> _bool: ...
def __le__(self, n: Self) -> _bool: ...
def __ge__(self, n: Self) -> _bool: ...
class _NodeIter(Iterator[FxNode]):
def __init__(self, root: FxNode, reversed: _bool) -> None: ...

View File

@ -154,6 +154,7 @@ def reset() -> None:
TensorifyState.clear()
torch._dynamo.utils.warn_once_cache.clear()
torch._dynamo.utils.user_obj_id_to_weakref.clear()
torch._dynamo.symbolic_convert._debug_force_graph_break_on_leaf_return_disable_codes.clear()
torch._C._autograd._saved_tensors_hooks_set_tracing(False)

View File

@ -500,7 +500,8 @@ issue_3_13_0_warning = True
# traced FX graph is empty when RETURN_* is traced.
allow_empty_graphs = False
# Used for testing - forces all top-level functions to be nested when traced with Dynamo
# Used for testing - forces all top-level functions to be nested when traced with Dynamo.
# There are slight differences between this config and wrap_top_frame.
debug_force_nested_calls = False
# Used for testing - forces a graph break when a function

View File

@ -837,8 +837,12 @@ class _TorchDynamoContext:
filename = inspect.getsourcefile(fn)
except TypeError:
filename = None
if config.debug_force_nested_calls:
if config.debug_force_nested_calls and filename not in DONT_WRAP_FILES:
fn = external_utils.wrap_inline(fn)
# Create a new code object for `fn` so that functions have different
# recompilation caches.
# Copy hack since deepcopy doesn't actually give a new code object
fn.__code__ = fn.__code__.replace(co_varnames=fn.__code__.co_varnames) # type: ignore[attr-defined]
elif config.wrap_top_frame or (
(filename is None or trace_rules.check(fn))
and (

View File

@ -21,7 +21,9 @@ Key functionality groups:
"""
import functools
import types
import warnings
import weakref
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import deprecated, ParamSpec
@ -58,15 +60,32 @@ else:
return torch.compiler.is_compiling()
def deepcopy_code(code: types.CodeType) -> types.CodeType:
# copy hack since deepcopy doesn't actually give a new code object
return code.replace(co_varnames=code.co_varnames) # type: ignore[attr-defined]
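# Weakly maps each wrapped fn to the single code object used for its wrapper,
# so repeated wrap_inline(fn) calls on the same fn reuse one code object.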
_wrap_inline_cache: weakref.WeakKeyDictionary[Callable[..., Any], types.CodeType] = (
weakref.WeakKeyDictionary()
)
def wrap_inline(fn: Callable[_P, _R]) -> Callable[_P, _R]:
"""
Create an extra frame around fn that is not in skipfiles.
This extra frame has its own per-fn code object so that we don't recompile
due to multiple calls to wrap_inline with different fns.
"""
@functools.wraps(fn)
def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
return fn(*args, **kwargs)
if fn not in _wrap_inline_cache:
_wrap_inline_cache[fn] = deepcopy_code(inner.__code__) # type: ignore[attr-defined]
inner.__code__ = _wrap_inline_cache[fn] # type: ignore[attr-defined]
return inner
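# A quick illustration of the caching behavior above (sketch only; f and g are
# arbitrary example functions): wrappers of the same fn share a code object,
# wrappers of different fns do not, so guard failures on `fn` accumulate per
# wrapped function instead of on one shared wrapper code object:
#   f = lambda x: x + 1
#   g = lambda x: x - 1
#   assert wrap_inline(f).__code__ is wrap_inline(f).__code__
#   assert wrap_inline(f).__code__ is not wrap_inline(g).__code__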
@ -231,6 +250,11 @@ def call_accumulate_grad(
variable.grad = updated_grad[0]
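# Like _wrap_inline_cache, but keyed first by fn and then by the
# error_on_graph_break flag: fn -> {error_on_graph_break: code object}.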
_wrap_inline_with_error_on_graph_break_cache: weakref.WeakKeyDictionary[
Callable[..., Any], dict[bool, types.CodeType]
] = weakref.WeakKeyDictionary()
def wrap_inline_with_error_on_graph_break(
fn: Callable[_P, _R], error_on_graph_break: bool
) -> Callable[_P, _R]:
@ -249,6 +273,14 @@ def wrap_inline_with_error_on_graph_break(
with torch._dynamo.error_on_graph_break(False):
return fn(*args, **kwargs)
if fn not in _wrap_inline_with_error_on_graph_break_cache:
_wrap_inline_with_error_on_graph_break_cache[fn] = {}
cache = _wrap_inline_with_error_on_graph_break_cache[fn]
if error_on_graph_break not in cache:
cache[error_on_graph_break] = deepcopy_code(wrapper.__code__) # type: ignore[attr-defined]
wrapper.__code__ = cache[error_on_graph_break] # type: ignore[attr-defined]
return wrapper

View File

@ -47,7 +47,6 @@ from collections import deque
from traceback import StackSummary
from typing import Any, Callable, cast, NoReturn, Optional, TYPE_CHECKING, Union
from typing_extensions import TypeAlias, TypeIs
from unittest.mock import patch
import torch
import torch._logging
@ -80,7 +79,6 @@ from .bytecode_transformation import (
create_dup_top,
create_instruction,
create_jump_absolute,
create_load_const,
create_rot_n,
create_swap,
get_code_keys,
@ -677,18 +675,14 @@ def generic_jump(
)
self.pop()
if_next = self.codegen_fix_leaf_stack(
all_stack_locals_metadata[0], self.next_instruction
) + self.create_call_resume_at(
if_next = self.create_call_resume_at(
self.next_instruction,
all_stack_locals_metadata,
)
if push:
self.push(value)
assert inst.target is not None
if_jump = self.codegen_fix_leaf_stack(
all_stack_locals_metadata[0], inst.target
) + self.create_call_resume_at(
if_jump = self.create_call_resume_at(
inst.target,
all_stack_locals_metadata,
)
@ -1024,10 +1018,7 @@ def break_graph_if_unsupported(
for _ in range(push):
self.push(UnknownVariable())
self.output.add_output_instructions(
self.codegen_fix_leaf_stack(
all_stack_locals_metadata[0], self.next_instruction
)
+ self.create_call_resume_at(
self.create_call_resume_at(
self.next_instruction,
all_stack_locals_metadata,
)
@ -1163,6 +1154,11 @@ class ExceptionStack:
__repr__ = __str__
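# Resume-function code objects for which the GRAPH_BREAK_IF_LEAF NOP is ignored,
# so a forced leaf graph break (debug_force_graph_break_on_leaf_return) is not
# re-triggered when the generated resume function runs. Cleared by torch._dynamo.reset().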
_debug_force_graph_break_on_leaf_return_disable_codes: weakref.WeakSet[
types.CodeType
] = weakref.WeakSet()
class InstructionTranslatorBase(
metaclass=BytecodeDispatchTableMeta,
):
@ -1526,27 +1522,69 @@ class InstructionTranslatorBase(
# frame 1 stack + locals,
# ], leaf_resume result
# pop frame N cells and locals
cg.extend_output(
[
*create_copy(2),
cg.create_load_const(0),
create_instruction("DELETE_SUBSCR"),
*create_copy(3),
cg.create_load_const(0),
create_instruction("DELETE_SUBSCR"),
]
)
# add the leaf_resume result to frame N-1 stack
num_stack = all_stack_locals_metadata[1].num_stack
cg.extend_output(
[
create_instruction("BUILD_LIST", arg=1),
*create_copy(2),
cg.create_load_const(1),
cg.create_load_const(0),
cg.create_binary_subscr(),
*create_binary_slice(num_stack, num_stack, True),
]
)
self.parent.push(UnknownVariable())
all_stack_locals_metadata[1].num_stack += 1
# pop frame N cells and locals
# current frame state
# cells, frame_values
# extract frame N-1 stack to stack
cg.extend_output(
[
*create_copy(1),
create_dup_top(),
cg.create_load_const(0),
create_instruction("DELETE_SUBSCR"),
cg.create_binary_subscr(),
*create_binary_slice(0, num_stack + 1),
]
)
# current frame state
# cells, frame_values, frame N-1 stack + leaf_resume result
# remove frame N-1 stack from frame_values
cg.extend_output(
# frame_values[0] = frame_values[0][num_stack + 1:]
[
*create_copy(2),
cg.create_load_const(0),
create_instruction("DELETE_SUBSCR"),
cg.create_binary_subscr(),
create_dup_top(),
*create_binary_slice(num_stack + 1, None),
*create_swap(2),
cg.create_load_const(0),
create_instruction("STORE_SUBSCR"),
]
)
# current frame state
# cells, frame_values, frame N-1 stack + leaf_resume result
# unpack the stack (need to unpack twice since UNPACK_SEQUENCE unpacks in reverse order)
cg.extend_output(
[
create_instruction("UNPACK_SEQUENCE", arg=num_stack + 1),
create_instruction("BUILD_LIST", arg=num_stack + 1),
create_instruction("UNPACK_SEQUENCE", arg=num_stack + 1),
]
)
@ -1554,12 +1592,11 @@ class InstructionTranslatorBase(
# current frame state
# [frame N-1 cells, ..., frame 1 cells],
# [
# frame N-1 stack (including leaf_resume result) + locals,
# frame N-1 locals,
# frame N-2 stack + locals,
# ...,
# frame 1 stack + locals,
# ],
self.parent.push(UnknownVariable())
all_stack_locals_metadata[1].num_stack += 1
# ], *(frame N-1 stack), leaf_resume result
self.output.add_output_instructions(
cg.get_instructions()
+ self.parent.create_call_resume_at(
@ -2641,10 +2678,7 @@ class InstructionTranslatorBase(
self.output.add_output_instructions([copy.copy(inst)])
self.popn(2)
self.output.add_output_instructions(
self.codegen_fix_leaf_stack(
all_stack_locals_metadata[0], self.next_instruction
)
+ self.create_call_resume_at(
self.create_call_resume_at(
self.next_instruction,
all_stack_locals_metadata,
)
@ -2690,47 +2724,6 @@ class InstructionTranslatorBase(
)
return insts
def codegen_fix_leaf_stack(
self, meta: StackLocalsMetadata, resume_inst: Instruction
) -> list[Instruction]:
"""
Fixes the stack values of the current/leaf frame (self).
Expects the TOS to be:
[
frame N locals,
frame N-1 stack + locals,
...,
frame 1 stack + locals
], *(frame N stack (post-unsupported instruction))
Rearranges the TOS to become:
[
frame N stack + locals,
...,
frame 1 stack + locals
]
Args:
- meta: metadata for the leaf frame returned from OutputGraph.compile_subgraph
- resume_inst: if the resume instruction is a return instruction, then don't return any instructions
"""
if resume_inst.opname in ("RETURN_VALUE", "RETURN_CONST"):
return []
# move frame N stack to the frame values list
current_num_stack = len(self.stack) - len(meta.stack_null_idxes)
meta.num_stack = current_num_stack
return [
create_instruction("BUILD_LIST", arg=current_num_stack),
*create_copy(2),
# frame_values, frame N stack, frame_values
create_load_const(0),
create_instruction("BINARY_SUBSCR"),
*create_binary_slice(0, 0, True),
# frame_values[0][0:0] = frame N stack
# frame_values left on top of stack
]
def create_resume(
self,
idx: int,
@ -2918,6 +2911,8 @@ class InstructionTranslatorBase(
new_code, self.f_globals["__name__"], package_name
)
counters["resumes"][new_code.co_name] += 1
return new_code, resume_name
def create_call_resume_at(
@ -2926,20 +2921,20 @@ class InstructionTranslatorBase(
all_stack_locals_metadata: list[StackLocalsMetadata],
) -> list[Instruction]:
"""
Codegen all resume function(s) from the frame stack starting at `self` and call them.
Codegen all resume function(s) from the frame stack starting at `self`, call them,
and return the result.
Assumes that the unsupported instruction has already been run.
Expects the stack to be in the state:
[frame N cells, ..., frame 1 cells],
Expects the TOS to be:
[
frame N stack + locals,
frame N locals,
frame N-1 stack + locals,
...,
frame 1 stack + locals
]
], *(frame N stack (post-unsupported instruction))
Pops the cells and frame values list from the stack.
Also includes a return instruction (stack expected to be empty after return).
Leaves the result of calling the resume functions on the stack and returns it
(empty stack after return).
Args:
- inst: the instruction of the current (deepest) frame to resume at
@ -2949,31 +2944,141 @@ class InstructionTranslatorBase(
self.instruction_pointer = None
current_num_stack = len(self.stack) - len(
all_stack_locals_metadata[0].stack_null_idxes
)
all_stack_locals_metadata[0].num_stack = current_num_stack
if inst.opname in ("RETURN_VALUE", "RETURN_CONST"):
return self.codegen_return_with_pops(
inst, all_stack_locals_metadata[0].num_stack
)
cg = PyCodegen(self.output.root_tx)
# NOTE: We do not need to codegen frames whose resume instruction is RETURN_VALUE
# We could also do something similar for RETURN_CONST, but a lot more code is necessary
# since we would need to track RETURN_CONST values and inject the constant in the right places.
# Filter out tx'es that are resuming on RETURN_*.
txes: list[InstructionTranslatorBase] = []
idxes: list[int] = []
resume_insts: list[Instruction] = []
cur_tx: Optional[InstructionTranslatorBase] = self
idx = 0
resume_codes: list[types.CodeType] = []
resume_names = []
while cur_tx is not None:
if cur_tx is self:
resume_inst = inst
else:
resume_inst = cur_tx.next_instruction
if (
not (
config.debug_force_graph_break_on_leaf_return
and self.current_instruction.opname == "NOP"
and self.current_instruction.argval == "GRAPH_BREAK_IF_LEAF"
and cur_tx is self
)
and resume_inst.opname != "RETURN_VALUE"
):
txes.append(cur_tx)
idxes.append(idx)
resume_insts.append(resume_inst)
cur_tx = cur_tx.parent
idx += 1
current_num_stack = len(self.stack) - len(
all_stack_locals_metadata[0].stack_null_idxes
)
# Every tx is returning - no need to call a resume function.
if not txes:
# Pop everything but TOS, then return the TOS.
# Frame N's stack must have length >= 1 since it's about to RETURN_VALUE.
# Frame N actually should have stack length == 1, because debug CPython expects
# empty stacks after return, but there is no guarantee written down anywhere.
assert current_num_stack >= 1
cg.extend_output(create_swap(current_num_stack + 2))
for _ in range(current_num_stack + 1):
cg.append_output(create_instruction("POP_TOP"))
cg.append_output(create_instruction("RETURN_VALUE"))
return cg.get_instructions()
# Let frame k be the deepest frame where the resume function is not RETURN_VALUE
# - If k == N, then the frame N stack is prepended to the frame N locals.
# - If k != N, then frame N's TOS is added to frame k's stack.
# Rearrange the TOS to be compatible with create_resume and codegen_call_resume:
# [
# frame N stack + locals,
# ...,
# frame 1 stack + locals
# ]
# create the stack values that should be moved
if txes[0] is self:
# Frame N is non-returning, pack all of frame N's stack to
# be moved to frame N's frame values
cg.append_output(create_instruction("BUILD_LIST", arg=current_num_stack))
# frame N stack is not yet on the frame N's frame values
stack_insert_idx = 0
all_stack_locals_metadata[0].num_stack = current_num_stack
else:
# Frame N is returning. Let frame k be the deepest non-returning frame.
# Add frame N's TOS to frame k's stack.
# pop frame N stack except TOS
cg.extend_output(create_swap(current_num_stack))
for _ in range(current_num_stack - 1):
cg.append_output(create_instruction("POP_TOP"))
cg.append_output(create_instruction("BUILD_LIST", arg=1))
# frame k stack is already on frame k's frame values
stack_insert_idx = all_stack_locals_metadata[idxes[0]].num_stack
all_stack_locals_metadata[idxes[0]].num_stack += 1
txes[0].push(UnknownVariable())
# move the predetermined stack value(s) to the deepest non-returning frame
cg.extend_output(
[
*create_copy(2),
# frame_values, return_const, frame_values
cg.create_load_const(idxes[0]),
cg.create_binary_subscr(),
*create_binary_slice(stack_insert_idx, stack_insert_idx, True),
# frame_values[idxes[0]][stack_insert_idx:stack_insert_idx] = frame N stack/[return_const/TOS]
# frame_values left on top of stack
]
)
# filter out frame values of skipped tx'es
filter_insts = []
for idx in idxes:
filter_insts.extend(
[
create_dup_top(),
cg.create_load_const(idx),
cg.create_binary_subscr(),
*create_swap(2),
]
)
# TOS: cells, frame_values[idxes[0]], ..., frame_values[idxes[...]], frame_values
filter_insts.extend(
[
create_instruction("POP_TOP"),
create_instruction("BUILD_LIST", arg=len(idxes)),
]
)
# TOS: cells, filtered frame_values
cg.extend_output(filter_insts)
# filter out cells of skipped tx'es using the same instructions in filter_insts,
# but with cells as TOS instead of frame values
cg.extend_output(
[
*create_swap(2),
*copy.deepcopy(filter_insts),
*create_swap(2),
]
)
# TOS: filtered cells, filtered frame_values
resume_codes: list[types.CodeType] = []
resume_names = []
for i, cur_tx in enumerate(txes):
resume_code, resume_name = cur_tx.create_resume(
idx,
resume_inst,
all_stack_locals_metadata[idx],
i,
resume_insts[i],
all_stack_locals_metadata[idxes[i]],
resume_codes,
cg,
cur_tx is self,
@ -2982,11 +3087,17 @@ class InstructionTranslatorBase(
resume_codes.append(resume_code)
resume_names.append(resume_name)
cur_tx = cur_tx.parent
idx += 1
if (
config.debug_force_graph_break_on_leaf_return
and self.current_instruction.opname == "NOP"
and self.current_instruction.argval == "GRAPH_BREAK_IF_LEAF"
):
_debug_force_graph_break_on_leaf_return_disable_codes.add(resume_codes[0])
self.codegen_call_resume(resume_codes, resume_names, cg)
return cg.get_instructions() + [create_instruction("RETURN_VALUE")]
cg.append_output(create_instruction("RETURN_VALUE"))
return cg.get_instructions()
@staticmethod
def codegen_call_resume(
@ -3119,9 +3230,14 @@ class InstructionTranslatorBase(
return (
all(b.can_restore() for b in self.block_stack)
and not self.one_graph
and not self.error_on_graph_break
and (
not self.error_on_graph_break
or config.debug_force_graph_break_on_leaf_return
)
and not self.is_tracing_resume_prologue
and not self.active_generic_context_managers
# Do not allow nested graph breaks in HOPs
and self.output.current_tracer.parent is None
)
@break_graph_if_unsupported(push=0)
@ -3361,7 +3477,10 @@ class InstructionTranslatorBase(
def NOP(self, inst: Instruction) -> None:
# Dynamo-specific testing behavior
if inst.argval == "GRAPH_BREAK_IF_LEAF":
if (
self.f_code not in _debug_force_graph_break_on_leaf_return_disable_codes
and inst.argval == "GRAPH_BREAK_IF_LEAF"
):
self.graph_break_on_leaf_function(inst)
def POP_TOP(self, inst: Instruction) -> None:
@ -4480,9 +4599,8 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
@classmethod
def inline_call(cls, parent: Any, func: Any, args: Any, kwargs: Any) -> Any:
with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
tracer = cls.build_inline_tracer(parent, func, args, kwargs)
return tracer.inline_call_()
tracer = cls.build_inline_tracer(parent, func, args, kwargs)
return tracer.inline_call_()
@staticmethod
def check_inlineable(func: Any) -> trace_rules.SkipResult:
@ -4941,6 +5059,10 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
self.generator_exhausted = False
self.is_generator_from_ctx_manager = False
def should_compile_partial_graph(self) -> bool:
# resuming after a graph break in an inlined generator is not supported
return False
def YIELD_VALUE(self, inst: Instruction) -> None:
top = self.pop()
self.generated_items.append(top)

View File

@ -102,16 +102,36 @@ class TestCase(TorchTestCase):
torch.set_grad_enabled(self._prior_is_grad_enabled)
def assertEqual(self, x: Any, y: Any, *args: Any, **kwargs: Any) -> None: # type: ignore[override]
if (
config.debug_disable_compile_counter
and isinstance(x, utils.CompileCounterInt)
or isinstance(y, utils.CompileCounterInt)
):
return
if config.debug_disable_compile_counter:
if isinstance(x, utils.CompileCounterInt) or isinstance(
y, utils.CompileCounterInt
):
return
# skip checks like self.assertEqual(len(counters["graph_break"]), 1)
if (
(cur_frame := inspect.currentframe())
and (upper_frame := cur_frame.f_back)
and (upper_code := inspect.getframeinfo(upper_frame).code_context)
and "counters" in upper_code[0]
):
return
return super().assertEqual(x, y, *args, **kwargs)
# assertExpectedInline might also need to be disabled for wrapped nested
# graph break tests
def assertExpectedInline(self, *args: Any, **kwargs: Any) -> None: # type: ignore[override]
if config.debug_disable_compile_counter:
return
return super().assertExpectedInline(*args, **kwargs)
class TestCaseWithNestedGraphBreaks(TestCase):
def setUp(self) -> None:
super().setUp()
self.prev_nested_graph_breaks = torch._dynamo.config.nested_graph_breaks
torch._dynamo.config.nested_graph_breaks = True
def tearDown(self) -> None:
super().tearDown()
torch._dynamo.config.nested_graph_breaks = self.prev_nested_graph_breaks
class CPythonTestCase(TestCase):

View File

@ -33,7 +33,7 @@ from torch import fx
from torch._dynamo.backends.debugging import aot_eager
from torch._dynamo.output_graph import OutputGraph
from . import config, eval_frame, optimize_assert, reset
from . import config, eval_frame, optimize, reset
from .bytecode_transformation import (
create_instruction,
debug_checks,
@ -379,7 +379,7 @@ def standard_test(
correct1 = fn(*args1)
correct2 = fn(*args2)
reset()
opt_fn = optimize_assert(actual)(fn)
opt_fn = optimize(actual, error_on_graph_break=True)(fn)
val1a = opt_fn(*args1)
val2a = opt_fn(*args2)
val1b = opt_fn(*args1)

View File

@ -35,7 +35,6 @@ from collections.abc import Sequence
from types import FunctionType
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
from typing_extensions import Never
from unittest.mock import patch
from weakref import WeakKeyDictionary
import torch
@ -69,7 +68,6 @@ from ..utils import (
check_constant_args,
check_unspec_or_constant_args,
cmp_name_to_op_mapping,
counters,
identity,
is_function,
is_wrapper_or_member_descriptor,
@ -712,8 +710,7 @@ class LocalGeneratorObjectVariable(VariableTracker):
# Hierarchically, tx can be seen as the parent of the inline tracer
# created on call_function. Any exception needs to be propagated to tx
# for Dynamo to behave correctly
with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
return tracer.inline_call_()
return tracer.inline_call_()
except ObservedException as e:
tracer.generator_exhausted = True
raise e
@ -723,8 +720,6 @@ class LocalGeneratorObjectVariable(VariableTracker):
except Unsupported as e:
torch._dynamo.eval_frame.skip_code(self.get_code())
raise SkipFrame from e
finally:
counters["unimplemented"] |= counters["inline_call"]
def call_obj_hasattr(self, tx, name):
if name in self.python_type().__dict__:

View File

@ -79,6 +79,23 @@ class TORCH_API Backend : public torch::CustomClassHolder {
return false;
}
virtual bool supportsShrinking() const {
return false;
}
// Shrink the backend by excluding specified ranks. Backends that support
// communicator shrinking should override this and return a new backend
// instance representing the shrunken group. Backends may use opts_override
// to supply backend-specific options for the new group.
virtual c10::intrusive_ptr<Backend> shrink(
const std::vector<int64_t>& /*ranks_to_exclude*/,
int /*shrink_flags*/ = 0,
const c10::intrusive_ptr<Options>& /*opts_override*/ = nullptr) {
TORCH_CHECK(
false,
c10::str("Backend ", getBackendName(), " does not support shrink"));
}
virtual void setTimeout(std::chrono::milliseconds timeout) {
TORCH_CHECK(
false,

View File

@ -259,6 +259,65 @@ std::shared_ptr<NCCLComm> NCCLComm::split(
}
#endif
#ifdef NCCL_HAS_COMM_SHRINK
std::shared_ptr<NCCLComm> NCCLComm::shrink(
NCCLComm* source,
std::vector<int>& ranks_to_exclude,
ncclConfig_t* config,
int shrinkFlags) {
// Preconditions are validated in ProcessGroupNCCL::shrink
LOG(INFO) << "Rank " << source->rank_ << ": shrinking comm " << source->repr()
<< " excluding " << ranks_to_exclude.size() << " ranks";
at::cuda::OptionalCUDAGuard gpuGuard(source->deviceIndex_);
auto comm = std::make_shared<NCCLComm>();
// This call will block until the source communicator is initialized
auto sourceComm = source->getNcclComm();
C10D_NCCL_CHECK_NONBLOCKING(
ncclCommShrink(
sourceComm,
ranks_to_exclude.data(),
ranks_to_exclude.size(),
reinterpret_cast<ncclComm_t*>(&(comm->ncclComm_)),
config,
shrinkFlags),
source->getNcclCommFailureReason());
// Wait for the child communicator to be ready
source->waitReady(true);
comm->initialized_ = true;
// NCCL automatically assigns rank during shrink - query it efficiently
int assigned_rank;
try {
C10D_NCCL_CHECK(
ncclCommUserRank(comm->ncclComm_, &assigned_rank), std::nullopt);
comm->rank_ = assigned_rank;
} catch (const std::exception& e) {
// Fallback: if ncclCommUserRank fails, we can't determine the rank
LOG(ERROR) << "Failed to query NCCL-assigned rank: " << e.what();
throw;
}
// Child comm should be on the same device as parent comm
comm->deviceIndex_ = source->deviceIndex_;
if (config != nullptr) {
comm->nonBlocking_ = config->blocking == 0;
} else {
// Inherit parent behavior if no config provided
comm->nonBlocking_ = source->nonBlocking_;
}
LOG(INFO) << "Rank " << source->rank_ << ": created shrunken comm "
<< comm->repr() << " with NCCL-assigned rank " << assigned_rank;
return comm;
}
#endif
void NCCLComm::finalize() {
LockType lock(mutex_);
if (aborted_) {

View File

@ -90,6 +90,10 @@ static_assert(
#define NCCL_HAS_NVLS_CTAS
#endif
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0)
#define NCCL_HAS_COMM_SHRINK
#endif
// Macro to throw on a non-successful NCCL return value.
#define C10D_NCCL_CHECK(cmd, failureReason) \
do { \
@ -294,6 +298,14 @@ class NCCLComm {
ncclConfig_t& config);
#endif // NCCL_HAS_COMM_SPLIT
#ifdef NCCL_HAS_COMM_SHRINK
static std::shared_ptr<NCCLComm> shrink(
NCCLComm* source,
std::vector<int>& ranks_to_exclude,
ncclConfig_t* config,
int shrinkFlags = 0);
#endif // NCCL_HAS_COMM_SHRINK
#if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP)
std::unordered_map<std::string, std::string> ncclCommDump();
#endif

View File

@ -165,7 +165,7 @@ ncclRedOpRAII getNcclReduceOp(
}
// Get a key string from device
inline std::string getKeyFromDevice(at::Device& device) {
inline std::string getKeyFromDevice(const at::Device& device) {
return std::to_string(device.index());
}
@ -5838,6 +5838,139 @@ at::Tensor ProcessGroupNCCL::allocateTensor(
return tensor;
}
#ifdef NCCL_HAS_COMM_SHRINK
c10::intrusive_ptr<Backend> ProcessGroupNCCL::shrink(
const std::vector<int64_t>& ranks_to_exclude,
int shrink_flags,
const c10::intrusive_ptr<Backend::Options>& opts_override) {
// Runtime version check with better error message
auto runtime_version = torch::cuda::nccl::version();
TORCH_CHECK(
runtime_version >= NCCL_VERSION(2, 27, 0),
"ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later. "
"Found version: ",
runtime_version);
// Early validation with detailed error messages
TORCH_CHECK_VALUE(
!ranks_to_exclude.empty(), "ranks_to_exclude cannot be empty");
TORCH_CHECK_VALUE(
static_cast<int>(ranks_to_exclude.size()) < size_,
"Cannot exclude all ranks (",
ranks_to_exclude.size(),
" >= ",
size_,
")");
// Validate ranks and convert to int efficiently
std::vector<int> int_ranks_to_exclude;
int_ranks_to_exclude.reserve(ranks_to_exclude.size());
for (int64_t rank : ranks_to_exclude) {
TORCH_CHECK_VALUE(
rank >= 0 && rank < size_,
"Invalid rank ",
rank,
" for group size ",
size_);
int_ranks_to_exclude.push_back(static_cast<int>(rank));
}
// Get primary communicator with better error context
auto primary_device_index = guessDeviceId();
auto primary_device = at::Device(at::kCUDA, primary_device_index);
const auto primary_key = getKeyFromDevice(primary_device);
std::shared_ptr<NCCLComm> primary_comm = getNCCLComm(primary_key);
TORCH_CHECK(
primary_comm,
"Primary NCCL communicator for device ",
primary_device,
" (key: ",
primary_key,
") is not initialized");
// Cache device index before shrink operation
at::DeviceIndex parent_device_index = primary_comm->getDeviceIndex();
ncclConfig_t* config = nullptr;
// Default to inheriting from parent options
bool high_priority_stream = options_->is_high_priority_stream;
if (opts_override) {
auto nccl_opts =
c10::static_intrusive_pointer_cast<ProcessGroupNCCL::Options>(
opts_override);
config = &nccl_opts->config;
// If user provided override options, honor is_high_priority_stream as well
high_priority_stream = nccl_opts->is_high_priority_stream;
}
std::shared_ptr<NCCLComm> shrunk_comm = NCCLComm::shrink(
primary_comm.get(),
int_ranks_to_exclude,
(config != nullptr ? config : &options_->config),
shrink_flags);
// Calculate new size and get NCCL-assigned rank
int new_size = size_ - static_cast<int>(ranks_to_exclude.size());
int new_rank = shrunk_comm->rank_;
// Create new ProcessGroupNCCL with optimized options cloning
auto new_store = store_->clone();
auto new_opts = ProcessGroupNCCL::Options::create(high_priority_stream);
new_opts->timeout = options_->timeout;
if (config != nullptr) {
new_opts->config = *config;
} else {
new_opts->config = options_->config;
}
auto new_pg = c10::make_intrusive<ProcessGroupNCCL>(
new_store, new_rank, new_size, new_opts);
// Set up the new process group with optimized device setup
new_pg->initializeDeviceStateForComm(
at::Device(at::kCUDA, parent_device_index), shrunk_comm);
return c10::static_intrusive_pointer_cast<Backend>(new_pg);
}
#else // !NCCL_HAS_COMM_SHRINK
// Backend interface override: raise consistent error when shrink is
// unsupported.
c10::intrusive_ptr<Backend> ProcessGroupNCCL::shrink(
const std::vector<int64_t>& /*ranks_to_exclude*/,
int /*shrink_flags*/,
const c10::intrusive_ptr<Backend::Options>& /*opts_override*/) {
TORCH_CHECK(
false,
"ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later, "
"but PyTorch was built with an older version or without NCCL shrink support.");
}
#endif // NCCL_HAS_COMM_SHRINK
void ProcessGroupNCCL::initializeDeviceStateForComm(
const at::Device& device,
std::shared_ptr<NCCLComm> comm) {
const auto key = getKeyFromDevice(device);
std::unique_lock<std::mutex> lock(mutex_);
at::cuda::OptionalCUDAGuard gpuGuard(device);
bool force_high = getCvarBool(TORCH_NCCL_HIGH_PRIORITY, false);
auto stream = at::cuda::getStreamFromPool(
options_->is_high_priority_stream || force_high);
devNCCLCommMap_[key] = comm;
ncclStreams_.emplace(key, stream);
ncclEvents_.emplace(key, at::cuda::CUDAEvent(cudaEventDisableTiming));
usedDeviceIdxs_.insert(device.index());
if (shouldAllCommunicatorsRegisterAllTensors()) {
std::lock_guard<std::mutex> map_lock(ncclCommMemPoolMapMutex);
ncclCommMemPoolMap.emplace(std::move(comm), MemPoolSet{});
}
}
} // namespace c10d
#endif // USE_C10D_NCCL

View File

@ -997,6 +997,21 @@ class TORCH_API ProcessGroupNCCL : public Backend {
ErrorType getError() override;
bool supportsShrinking() const override {
#ifdef NCCL_HAS_COMM_SHRINK
return true;
#else
return false;
#endif
}
// Backend-style shrink override that returns a Backend instance.
c10::intrusive_ptr<Backend> shrink(
const std::vector<int64_t>& ranks_to_exclude,
int shrink_flags = 0,
const c10::intrusive_ptr<Backend::Options>& opts_override =
nullptr) override;
std::shared_ptr<c10::Allocator> getMemAllocator() override;
// Allocate tensor from communication-optimized memory pool
@ -1065,6 +1080,12 @@ class TORCH_API ProcessGroupNCCL : public Backend {
int p2pRank = 0,
bool isSendRecvSelf = false);
// Initialize device-specific state (comm, stream, event, bookkeeping) for a
// given communicator on this process group instance.
void initializeDeviceStateForComm(
const at::Device& device,
std::shared_ptr<NCCLComm> comm);
// Wrapper method which can be overridden for tests.
virtual std::exception_ptr checkForNCCLErrors(
std::shared_ptr<NCCLComm>& ncclComm);

View File

@ -2734,12 +2734,23 @@ Arguments:
"supports_time_estimate",
&::c10d::Backend::supportsTimeEstimation,
"(test whether the backend supports collective time estimation)")
.def_property_readonly(
"supports_shrinking",
&::c10d::Backend::supportsShrinking,
"(test whether the backend supports communicator shrinking)")
.def(
"set_timeout",
&::c10d::Backend::setTimeout,
py::arg("timeout"),
py::call_guard<py::gil_scoped_release>(),
R"(Sets the default timeout for all future operations.)")
.def(
"shrink",
&::c10d::Backend::shrink,
py::arg("ranks_to_exclude"),
py::arg("shrink_flags") = 0,
py::arg("opts_override") = nullptr,
py::call_guard<py::gil_scoped_release>())
.def(
"broadcast",
&::c10d::Backend::broadcast,

View File

@ -1,11 +1,15 @@
#include <torch/csrc/fx/node.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <structmember.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pythoncapi_compat.h>
#include <algorithm>
namespace {
using NodeSortKey = c10::SmallVector<int64_t, 4>;
struct NodeBase;
// Thrown to exit out of a C++ function and return an error to Python.
@ -163,7 +167,41 @@ struct NodeBase {
PyObject* users;
PyObject* _repr_fn;
PyObject* meta;
PyObject* _sort_key;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
alignas(NodeSortKey) char sort_key_buf[sizeof(NodeSortKey)];
inline NodeSortKey& sort_key() {
return *reinterpret_cast<NodeSortKey*>(sort_key_buf);
}
inline void set_prev(NodeBase* value) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value);
Py_INCREF(reinterpret_cast<PyObject*>(value));
NodeBase* old = _prev;
_prev = value;
Py_DECREF(reinterpret_cast<PyObject*>(old));
}
inline void set_next(NodeBase* value) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value);
Py_INCREF(reinterpret_cast<PyObject*>(value));
NodeBase* old = _next;
_next = value;
Py_DECREF(reinterpret_cast<PyObject*>(old));
}
// Equivalent to:
// p, n = self._prev, self._next
// p._next, n._prev = n, p
inline void remove_from_list() {
if (this->_prev == this && this->_next == this) {
return;
}
NodeBase* p = this->_prev;
NodeBase* n = this->_next;
p->set_next(n);
n->set_prev(p);
}
};
static PyObject* NodeBase_new(
@ -173,6 +211,8 @@ static PyObject* NodeBase_new(
PyObject* self = type->tp_alloc(type, 0);
if (!self)
return nullptr;
new (reinterpret_cast<NodeBase*>(self)->sort_key_buf)
NodeSortKey(); // placement new does not allocate
return self;
}
@ -201,7 +241,6 @@ static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
self->users = PyDict_New();
self->_repr_fn = Py_NewRef(Py_None);
self->meta = PyDict_New();
self->_sort_key = PyTuple_New(0);
return 0;
}
@ -221,7 +260,6 @@ static struct PyMemberDef NodeBase_members[] = {
{"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr},
{"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr},
{"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr},
{"_sort_key", T_OBJECT_EX, offsetof(NodeBase, _sort_key), 0, nullptr},
{nullptr} /* Sentinel */
};
@ -239,7 +277,6 @@ static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
Py_VISIT(self->users);
Py_VISIT(self->_repr_fn);
Py_VISIT(self->meta);
Py_VISIT(self->_sort_key);
return 0;
}
@ -257,12 +294,12 @@ static int NodeBase_clear(NodeBase* self) {
Py_CLEAR(self->users);
Py_CLEAR(self->_repr_fn);
Py_CLEAR(self->meta);
Py_CLEAR(self->_sort_key);
return 0;
}
static void NodeBase_dealloc(PyObject* self) {
PyObject_GC_UnTrack(self);
reinterpret_cast<NodeBase*>(self)->sort_key().~NodeSortKey();
(void)NodeBase_clear((NodeBase*)self);
Py_TYPE(self)->tp_free(self);
}
@ -321,15 +358,195 @@ static PyObject* NodeBase__update_args_kwargs(
}
}
static PyObject* NodeBase__remove_from_list(
PyObject* self,
PyObject* _ignored) {
reinterpret_cast<NodeBase*>(self)->remove_from_list();
Py_RETURN_NONE;
}
static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) {
if (self_ == arg) {
Py_RETURN_NONE;
}
if (!is_node(arg)) {
PyErr_SetString(PyExc_TypeError, "_prepend() argument must be a Node");
return nullptr;
}
NodeBase* self = reinterpret_cast<NodeBase*>(self_);
NodeBase* x = reinterpret_cast<NodeBase*>(arg);
if (self->graph != x->graph) {
PyErr_SetString(
PyExc_AssertionError,
"Attempting to move a Node into a different Graph");
return nullptr;
}
x->remove_from_list();
NodeBase* p = self->_prev;
p->set_next(x);
x->set_prev(p);
x->set_next(self);
self->set_prev(x);
// Now compute x.sort_key()
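// Worked example of the three branches below (keys shown as tuples):
//   prev=(1,),  next=(2,)   -> same length -> x gets (1, 0)
//   prev=(1,0), next=(2,)   -> prev longer -> x gets (1, 1)
//   prev=(1,),  next=(1,0)  -> next longer -> x gets (1, -1)
// In each case prev < x < next under lexicographic comparison.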
const NodeSortKey& psk = x->_prev->sort_key();
const NodeSortKey& nsk = x->_next->sort_key();
if (psk.size() > nsk.size()) {
// prefix = psk[: len(nsk)+1]
size_t slice_len = nsk.size() + 1;
NodeSortKey prefix(psk.begin(), psk.begin() + slice_len);
// last element is idx => increment by 1
prefix.back()++;
x->sort_key() = std::move(prefix);
} else if (psk.size() < nsk.size()) {
// prefix = nsk[: len(psk)+1]
size_t slice_len = psk.size() + 1;
NodeSortKey prefix(nsk.begin(), nsk.begin() + slice_len);
// last element is idx => decrement by 1
prefix.back()--;
x->sort_key() = std::move(prefix);
} else {
// same length => add a 0
x->sort_key() = psk;
x->sort_key().emplace_back(0);
}
Py_RETURN_NONE;
}
// __lt__(self, other): Return self.sort_key < other.sort_key
static PyObject* NodeBase___lt__(PyObject* self, PyObject* other) {
// METH_O => one argument: 'other'
if (!is_node(other)) {
Py_RETURN_NOTIMPLEMENTED;
}
const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
bool less = std::lexicographical_compare(
lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
if (less)
Py_RETURN_TRUE;
Py_RETURN_FALSE;
}
// __gt__(self, other): Return self.sort_key > other.sort_key
static PyObject* NodeBase___gt__(PyObject* self, PyObject* other) {
if (!is_node(other)) {
Py_RETURN_NOTIMPLEMENTED;
}
const NodeSortKey& lhs = reinterpret_cast<NodeBase*>(self)->sort_key();
const NodeSortKey& rhs = reinterpret_cast<NodeBase*>(other)->sort_key();
// "a > b" is equivalent to "b < a"
bool greater = std::lexicographical_compare(
rhs.begin(), rhs.end(), lhs.begin(), lhs.end());
if (greater)
Py_RETURN_TRUE;
Py_RETURN_FALSE;
}
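// __ge__(self, other): Return True if self is other or self.sort_key > other.sort_key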
static PyObject* NodeBase___ge__(PyObject* self, PyObject* other) {
if (self == other) {
Py_RETURN_TRUE;
}
return NodeBase___gt__(self, other);
}
// __le__(self, other): Return True if self is other or self.sort_key < other.sort_key
static PyObject* NodeBase___le__(PyObject* self, PyObject* other) {
if (self == other) {
Py_RETURN_TRUE;
}
return NodeBase___lt__(self, other);
}
// Convert the NodeSortKey (SmallVector<int64_t>) into a Python tuple of ints
// Only used by pickle/__getstate__
static PyObject* NodeBase_get_sort_key(PyObject* self, void* /*closure*/) {
NodeBase* node = reinterpret_cast<NodeBase*>(self);
const NodeSortKey& vec = node->sort_key();
Py_ssize_t n = static_cast<Py_ssize_t>(vec.size());
THPObjectPtr tuple(PyTuple_New(n));
if (!tuple) {
return nullptr; // Out of memory
}
for (Py_ssize_t i = 0; i < n; i++) {
PyObject* value = PyLong_FromSsize_t(vec[i]);
if (!value) {
return nullptr;
}
PyTuple_SET_ITEM(tuple.get(), i, value);
}
return tuple.release();
}
// Setter for NodeBase::sort_key: expects a Python tuple of ints, e.g.
// node._sort_key = (1, 2, 3). Only used by pickle/__setstate__
static int NodeBase_set_sort_key(
PyObject* self,
PyObject* value,
void* /*closure*/) {
NodeBase* node = reinterpret_cast<NodeBase*>(self);
if (!PyTuple_Check(value)) {
PyErr_SetString(PyExc_TypeError, "_sort_key must be a tuple of ints");
return -1;
}
Py_ssize_t size = PyTuple_GET_SIZE(value);
NodeSortKey new_vec;
new_vec.reserve(size);
for (Py_ssize_t i = 0; i < size; i++) {
int64_t val = PyLong_AsSsize_t(PyTuple_GET_ITEM(value, i));
if (val == -1 && PyErr_Occurred()) {
return -1;
}
new_vec.emplace_back(val);
}
node->sort_key() = std::move(new_vec);
return 0;
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyMethodDef NodeBase_methods[] = {
{"_update_args_kwargs",
(PyCFunction)(void*)(NodeBase__update_args_kwargs),
METH_FASTCALL,
"Internal method: do not call directly."},
{"_remove_from_list",
(PyCFunction)(void*)(NodeBase__remove_from_list),
METH_NOARGS,
"Internal method: do not call directly."},
{"_prepend",
(PyCFunction)(void*)(NodeBase__prepend),
METH_O,
"Internal method: do not call directly."},
{"__lt__",
(PyCFunction)(void*)NodeBase___lt__,
METH_O,
"Return True if self.sort_key < other.sort_key"},
{"__gt__",
(PyCFunction)(void*)NodeBase___gt__,
METH_O,
"Return True if self.sort_key > other.sort_key"},
{"__ge__",
(PyCFunction)(void*)NodeBase___ge__,
METH_O,
"Return True if self.sort_key >= other.sort_key"},
{"__le__",
(PyCFunction)(void*)NodeBase___le__,
METH_O,
"Return True if self.sort_key <= other.sort_key"},
{nullptr, nullptr, 0, nullptr} // Sentinel
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyGetSetDef NodeBase_getset[] = {
{"_sort_key", // attribute name in Python
(getter)NodeBase_get_sort_key, // C getter function
(setter)NodeBase_set_sort_key, // C setter function
(char*)"The sort key as a tuple of ints", // docstring
nullptr},
{nullptr, nullptr, nullptr, nullptr, nullptr} // Sentinel
};
PyTypeObject NodeBaseType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C._NodeBase", /* tp_name */
@ -361,7 +578,7 @@ PyTypeObject NodeBaseType = {
nullptr, /* tp_iternext */
NodeBase_methods, /* tp_methods */
NodeBase_members, /* tp_members */
nullptr, /* tp_getset */
NodeBase_getset, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */

View File

@ -130,6 +130,7 @@ __all__ = [
"reduce_scatter_tensor",
"get_node_local_rank",
"split_group",
"shrink_group",
]
_MPI_AVAILABLE = True
@ -5713,3 +5714,521 @@ def _get_process_group_name(pg: ProcessGroup) -> str:
def _get_process_group_store(pg: ProcessGroup) -> Store:
return _world.pg_map[pg][1]
# Shrink flags for process group backends
SHRINK_DEFAULT = 0x00
SHRINK_ABORT = 0x01
@_time_logger
def shrink_group(
ranks_to_exclude: list[int],
group: Optional[ProcessGroup] = None,
shrink_flags: int = SHRINK_DEFAULT,
pg_options: Optional[Any] = None,
) -> ProcessGroup:
"""
Shrinks a process group by excluding specified ranks.
Creates and returns a new, smaller process group comprising only the ranks
from the original group that were not in the ``ranks_to_exclude`` list.
Args:
ranks_to_exclude (List[int]): A list of ranks from the original
``group`` to exclude from the new group.
group (ProcessGroup, optional): The process group to shrink. If ``None``,
the default process group is used. Defaults to ``None``.
shrink_flags (int, optional): Flags to control the shrinking behavior.
Can be ``SHRINK_DEFAULT`` (default) or ``SHRINK_ABORT``.
``SHRINK_ABORT`` will attempt to terminate ongoing operations
in the parent communicator before shrinking.
Defaults to ``SHRINK_DEFAULT``.
pg_options (ProcessGroupOptions, optional): Backend-specific options to apply
to the shrunken process group. If provided, the backend will use
these options when creating the new group. If omitted, the new group
inherits defaults from the parent.
Returns:
ProcessGroup: a new group comprised of the remaining ranks. If the
default group was shrunk, the returned group becomes the new default group.
Raises:
TypeError: if the group's backend does not support shrinking.
ValueError: if ``ranks_to_exclude`` is invalid (empty, out of bounds,
duplicates, or excludes all ranks).
RuntimeError: if an excluded rank calls this function or the backend
fails the operation.
Notes:
- Only non-excluded ranks should call this function; excluded ranks
must not participate in the shrink operation.
- Shrinking the default group destroys all other process groups since
rank reassignment makes them inconsistent.
"""
# Step 1: Validate input parameters with comprehensive error checking
_validate_shrink_inputs(ranks_to_exclude, shrink_flags)
# Step 2: Get target group and essential properties
target_group_info = _prepare_shrink_target_group(group)
# Step 3: Validate backend requirements and availability
backend_impl = _validate_shrink_backend_requirements(target_group_info)
# Step 4: Validate ranks against group and check for duplicates
excluded_ranks_set = _validate_and_process_excluded_ranks(
ranks_to_exclude, target_group_info
)
# Step 5: Execute the actual shrink operation (backend-specific)
new_backend = backend_impl.shrink(
sorted(excluded_ranks_set),
shrink_flags,
pg_options if pg_options is not None else None,
)
# Step 6: Handle cleanup and creation of new process group
target_group_info["pg_options_override"] = pg_options
return _finalize_shrunk_group(target_group_info, excluded_ranks_set, new_backend)
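# Minimal usage sketch (illustrative only; assumes the surviving ranks have
# already detected that global rank 3 is gone, and that rank 3 itself does not
# call shrink_group -- failure detection is outside the scope of this API):
#
#   import torch
#   import torch.distributed as dist
#   from torch.distributed.distributed_c10d import SHRINK_ABORT
#
#   # called collectively by every surviving rank of the default group
#   new_pg = dist.shrink_group([3], shrink_flags=SHRINK_ABORT)
#
#   # the returned group becomes the new default group and global ranks are
#   # reassigned, so subsequent collectives can use new_pg directly
#   dist.all_reduce(torch.ones(1, device="cuda"), group=new_pg)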
def _validate_shrink_inputs(ranks_to_exclude: list[int], shrink_flags: int) -> None:
"""Validate input parameters for shrink_group."""
if not isinstance(ranks_to_exclude, list):
raise TypeError(
f"ranks_to_exclude must be a list, but got {type(ranks_to_exclude).__name__}. "
f"Example: [1, 3, 5] to exclude ranks 1, 3, and 5."
)
if not ranks_to_exclude:
raise ValueError(
"ranks_to_exclude cannot be empty. To shrink a group, you must specify at least "
"one rank to exclude. Example: [failed_rank_id]"
)
# Validate shrink_flags with clear explanation of valid values
valid_flags = [SHRINK_DEFAULT, SHRINK_ABORT]
if not isinstance(shrink_flags, int) or shrink_flags not in valid_flags:
raise ValueError(
f"Invalid shrink_flags value: {shrink_flags}. Must be one of: "
f"SHRINK_DEFAULT ({SHRINK_DEFAULT}) or SHRINK_ABORT ({SHRINK_ABORT}). "
f"Use SHRINK_ABORT to abort ongoing operations before shrinking."
)
def _prepare_shrink_target_group(group: Optional[ProcessGroup]) -> dict:
"""Prepare and validate the target group for shrinking."""
target_pg = group if group is not None else _get_default_group()
# Cache frequently accessed properties to avoid repeated calls
group_size = int(target_pg.size())
group_info = {
"process_group": target_pg,
"is_default_group": (target_pg == _get_default_group()),
"group_size": group_size,
"current_rank": target_pg.rank(),
"group_name": _get_process_group_name(target_pg),
}
# Validate that we have a valid process group
if group_size <= 1:
raise ValueError(
f"Cannot shrink a process group with size {group_size}. "
f"Group must have at least 2 ranks to support shrinking."
)
return group_info
def _validate_shrink_backend_requirements(group_info: dict) -> Any:
"""Return the backend implementation for the target group or raise if unsupported."""
target_pg = group_info["process_group"]
group_name = group_info["group_name"]
# Get the group's backend directly via ProcessGroup API. Prefer a bound device if present,
# otherwise try CUDA then fall back to CPU.
try:
preferred_device = getattr(target_pg, "bound_device_id", None)
if preferred_device is not None:
backend_impl = target_pg._get_backend(preferred_device)
else:
# Try CUDA first if available, else CPU
try:
backend_impl = target_pg._get_backend(torch.device("cuda"))
except Exception:
backend_impl = target_pg._get_backend(torch.device("cpu"))
except RuntimeError as e:
raise RuntimeError(
f"Cannot access device backend for process group '{group_name}'. "
f"Ensure the process group was initialized with a compatible device backend and devices are available."
) from e
try:
supports = bool(backend_impl.supports_shrinking)
except Exception:
supports = False
if not supports:
raise TypeError(
f"Process group backend for '{group_name}' does not support shrinking operations."
)
return backend_impl
def _validate_and_process_excluded_ranks(
ranks_to_exclude: list[int], group_info: dict
) -> set:
"""Validate excluded ranks and convert to set for efficient operations."""
group_size = group_info["group_size"]
current_rank = group_info["current_rank"]
# Use set for O(1) duplicate detection and membership testing
excluded_ranks_set = set()
# Validate each rank with detailed error messages
for i, rank in enumerate(ranks_to_exclude):
if not isinstance(rank, int):
raise TypeError(
f"All elements in ranks_to_exclude must be integers. "
f"Element at index {i} is {type(rank).__name__}: {rank}"
)
if not (0 <= rank < group_size):
raise ValueError(
f"Rank {rank} at index {i} is out of bounds for group size {group_size}. "
f"Valid ranks are in range [0, {group_size - 1}]."
)
if rank in excluded_ranks_set:
raise ValueError(
f"Duplicate rank {rank} found in ranks_to_exclude at index {i}. "
f"Each rank can only be excluded once."
)
excluded_ranks_set.add(rank)
# Ensure we don't exclude all ranks
if len(excluded_ranks_set) >= group_size:
raise ValueError(
f"Cannot exclude all {group_size} ranks from process group. "
f"At least one rank must remain. Excluding {len(excluded_ranks_set)} ranks."
)
# Critical check: current rank should not be in excluded list
if current_rank in excluded_ranks_set:
raise RuntimeError(
f"Current rank {current_rank} is in the exclusion list and should not call shrink_group(). "
f"Only non-excluded ranks should participate in the shrinking operation. "
f"Excluded ranks should terminate their processes instead."
)
return excluded_ranks_set
def _finalize_shrunk_group(
group_info: dict, excluded_ranks_set: set, new_backend
) -> ProcessGroup:
"""Clean up old group and create new shrunk process group."""
target_pg = group_info["process_group"]
is_default_group = group_info["is_default_group"]
# Handle default group dependencies - destroy other groups first
if is_default_group:
_destroy_all_other_groups(exclude_group=target_pg)
# Gather original group metadata before cleanup
original_group_metadata = _extract_group_metadata(target_pg)
# Calculate remaining ranks efficiently
original_ranks = get_process_group_ranks(target_pg)
remaining_ranks = [
rank for rank in original_ranks if rank not in excluded_ranks_set
]
# Clean up the original group
_cleanup_original_group(target_pg, is_default_group)
# Create and configure the new process group
new_pg = _create_shrunk_process_group(
new_backend, remaining_ranks, original_group_metadata, is_default_group
)
# Register the new group in global state
if is_default_group:
_update_default_pg(new_pg)
# Update global state with new group information
rank_mapping = {
global_rank: group_rank
for group_rank, global_rank in enumerate(remaining_ranks)
}
_update_process_group_global_state(
pg=new_pg,
backend_name=original_group_metadata["backend_name"],
store=original_group_metadata["store"],
group_name=original_group_metadata["new_group_name"],
backend_config=original_group_metadata["backend_config"],
rank_mapping=rank_mapping,
)
return new_pg
def _extract_group_metadata(target_pg: ProcessGroup) -> dict:
"""Extract metadata from the original group before cleanup."""
original_backend_name, original_store = _world.pg_map[target_pg]
original_backend_config = _world.pg_backend_config.get(target_pg, "")
original_group_name = _get_process_group_name(target_pg)
# Extract device binding information before cleanup to avoid accessing destroyed group
bound_device_id = None
if hasattr(target_pg, "bound_device_id"):
bound_device_id = target_pg.bound_device_id
# Generate new group name for the shrunk group; hash for uniqueness across backends
remaining_ranks = list(get_process_group_ranks(target_pg))
new_group_name = _process_group_name(remaining_ranks, use_hashed_name=True)
return {
"backend_name": original_backend_name,
"store": original_store,
"backend_config": original_backend_config,
"original_group_name": original_group_name,
"new_group_name": new_group_name,
"bound_device_id": bound_device_id, # Safe to access after cleanup
}
def _cleanup_original_group(target_pg: ProcessGroup, is_default_group: bool) -> None:
"""Clean up the original process group safely."""
try:
destroy_process_group(target_pg)
except Exception as e:
group_type = "default" if is_default_group else "non-default"
logger.warning(
"Failed to destroy %s group during shrinking: %s", group_type, str(e)
)
# Ensure global state cleanup even if destroy_process_group fails
_cleanup_process_group_global_state(target_pg)
def _create_shrunk_process_group(
new_backend, remaining_ranks: list[int], metadata: dict, is_default_group: bool
) -> ProcessGroup:
"""Create and configure the new shrunk process group."""
# Create new group properties
new_group_rank = new_backend.rank()
new_group_size = new_backend.size()
group_name = metadata["new_group_name"]
# Generate descriptive group description
if is_default_group:
group_desc = "default:shrunken"
else:
group_desc = f"{metadata['original_group_name']}:shrunk"
# Create process group with new communicator (clone the parent store like split does)
prefix_store = PrefixStore(f"{group_name}/", metadata["store"].clone())
new_pg = ProcessGroup(prefix_store, new_group_rank, new_group_size)
# Configure backend using the device type of the new backend's bound device if available,
# otherwise derive from the original group's bound device or fall back to CPU.
backend_device = metadata.get("bound_device_id")
if backend_device is None:
# Default to CPU if no bound device is present
backend_device = torch.device("cpu")
# Choose backend enum based on device type
if backend_device.type == "cuda":
backend_type = ProcessGroup.BackendType.NCCL
else:
backend_type = ProcessGroup.BackendType.GLOO
new_pg._register_backend(backend_device, backend_type, new_backend)
new_pg._set_default_backend(backend_type)
# Inherit device binding from original group if it was bound
bound_device_id = metadata.get("bound_device_id")
if bound_device_id is not None:
new_pg.bound_device_id = bound_device_id
# Set group metadata
new_pg._set_group_name(group_name)
new_pg._set_group_desc(group_desc)
# Persist backend configuration overrides (if provided via shrink_group)
backend_config_override = metadata.get("backend_config")
if backend_config_override is not None:
# Store for introspection/debugging and potential backend hooks
_world.pg_backend_config[new_pg] = backend_config_override
return new_pg
def _destroy_all_other_groups(exclude_group: Optional[ProcessGroup] = None) -> None:
"""
Destroy all process groups except the excluded group and clean up all global state.
This is necessary when shrinking the default group because global ranks
are reassigned by NCCL, making all existing process groups inconsistent.
Note: Uses abort for non-collective cleanup since excluded ranks may not
participate in collective operations. Backend cleanup is handled independently per group.
Args:
exclude_group (ProcessGroup, optional): Process group to exclude from destruction.
If None, destroys all process groups.
"""
# Get list of groups to destroy (avoid modifying dict while iterating)
groups_to_destroy = []
for pg in list(_world.pg_group_ranks.keys()):
if exclude_group is not None and pg == exclude_group:
continue
groups_to_destroy.append(pg)
# Warn user about automatic destruction
if groups_to_destroy:
group_names = [_get_process_group_name(pg) for pg in groups_to_destroy]
logger.warning(
"Shrinking default group will destroy %d other process groups: %s. "
"This is necessary because shrinking the default group reassigns global ranks, "
"making existing groups inconsistent.",
len(groups_to_destroy),
", ".join(group_names),
)
# Destroy each group and clean up global state
for pg in groups_to_destroy:
try:
# First call abort_process_group which handles the C++ cleanup non-collectively
_abort_process_group(pg)
except Exception as e:
# Log but don't fail - some groups might already be destroyed
logger.warning(
"Failed to abort process group %s: %s",
_get_process_group_name(pg),
str(e),
)
# Ensure all global state is cleaned up even if _abort_process_group fails
# or doesn't clean up everything
_cleanup_process_group_global_state(pg)
def _cleanup_process_group_global_state(pg: ProcessGroup) -> None:
"""
Clean up all global state associated with a process group.
This function ensures complete cleanup of process group state from all
global dictionaries and registries, even if destroy_process_group fails
or doesn't clean up everything. This is critical when destroying multiple
groups to prevent inconsistent state.
The cleanup removes the process group from:
- _world.pg_map (backend and store mapping)
- _world.pg_names (group name mapping)
- _world.pg_group_ranks (rank mappings)
- _world.pg_backend_config (backend configuration)
- _world.tags_to_pg and _world.pg_to_tag (tag mappings)
- _world.pg_coalesce_state (coalescing state)
- C++ internal registries via _unregister_process_group
Args:
pg (ProcessGroup): The process group to clean up.
"""
try:
# Clean up main process group mappings
_world.pg_map.pop(pg, None)
_world.pg_group_ranks.pop(pg, None)
_world.pg_backend_config.pop(pg, None)
# Clean up process group name mapping
group_name = _world.pg_names.pop(pg, None)
# Clean up tag mappings
pg_tag = _world.pg_to_tag.pop(pg, None)
if pg_tag is not None and pg_tag in _world.tags_to_pg:
try:
_world.tags_to_pg[pg_tag].remove(pg)
# Remove the tag entry if list is empty
if not _world.tags_to_pg[pg_tag]:
_world.tags_to_pg.pop(pg_tag, None)
except (ValueError, KeyError):
# Process group was already removed from the list
pass
# Clean up any registered process group names using C++ unregister function
if group_name is not None:
try:
_unregister_process_group(group_name)
except Exception:
# Process group name might not be registered or already unregistered
pass
# Clean up coalesce state if present
_world.pg_coalesce_state.pop(pg, None)
except Exception as e:
# Log cleanup failures but don't propagate - we want to continue with other cleanups
logger.warning(
"Failed to fully clean up global state for process group: %s", str(e)
)
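# Illustrative sanity check (not part of this module, helper name made up):
# after the cleanup above, none of the per-group registries listed in the
# docstring should still reference the destroyed group.
def _example_assert_pg_state_removed(pg: ProcessGroup) -> None:
    for registry in (
        _world.pg_map,
        _world.pg_names,
        _world.pg_group_ranks,
        _world.pg_backend_config,
        _world.pg_to_tag,
        _world.pg_coalesce_state,
    ):
        assert pg not in registry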
def _update_process_group_global_state(
pg: ProcessGroup,
backend_name: str,
store: Store,
group_name: str,
backend_config: str,
rank_mapping: Optional[dict[int, int]] = None,
pg_tag: Optional[str] = None,
user_tag: Optional[str] = None,
) -> None:
"""
Update all global state dictionaries for a process group.
This helper function consolidates the common pattern of updating multiple
global state dictionaries when creating or modifying process groups.
Args:
pg (ProcessGroup): The process group to update state for.
backend_name (str): Backend name for pg_map.
store (Store): Store instance for pg_map.
group_name (str): Group name for pg_names and registration.
backend_config (str): Backend configuration string.
rank_mapping (Dict[int, int], optional): Global rank to group rank mapping.
If None, skips updating pg_group_ranks.
pg_tag (str, optional): Process group tag. If None, defaults to f"ptd:{group_name}".
user_tag (str, optional): User-provided tag for special tag handling.
If provided, the group is tagged "user:{user_tag}" and is also added under the default "" tag.
"""
# Update main process group mappings
_world.pg_map[pg] = (backend_name, store)
_world.pg_names[pg] = group_name
_world.pg_backend_config[pg] = backend_config
# Register the process group name
_register_process_group(group_name, pg)
# Update rank mapping if provided
if rank_mapping is not None:
_world.pg_group_ranks[pg] = rank_mapping
# Handle tag management
if pg_tag is None:
pg_tag = f"ptd:{group_name}"
if user_tag is not None:
# Special handling for user-provided tags
# Add to default "" tag first
_world.tags_to_pg.setdefault("", []).append(pg)
# Then create user-specific tag
user_pg_tag = f"user:{user_tag}"
_world.tags_to_pg.setdefault(user_pg_tag, []).append(pg)
_world.pg_to_tag[pg] = user_pg_tag
else:
# Standard process group tag
_world.tags_to_pg.setdefault(pg_tag, []).append(pg)
_world.pg_to_tag[pg] = pg_tag
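# Hypothetical call (not part of this module; the group, store, and all values
# below are placeholders, not taken from this PR). The common case files the
# group under its "ptd:<group_name>" tag; passing user_tag would instead tag it
# "user:<tag>" and also add it under the default "" tag.
def _example_register_shrunk_group(new_pg: ProcessGroup, store: Store) -> None:
    _update_process_group_global_state(
        pg=new_pg,
        backend_name="nccl",
        store=store,
        group_name="shrunk_default",  # must be unique among live groups
        backend_config="cuda:nccl",
        rank_mapping={0: 0, 1: 1},  # global rank -> group rank
        user_tag=None,  # -> tagged "ptd:shrunk_default"
    )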


@@ -385,41 +385,7 @@ class Node(_NodeBase):
Args:
x (Node): The node to put before this node. Must be a member of the same graph.
"""
assert self.graph == x.graph, "Attempting to move a Node into a different Graph"
if self == x:
log.debug(
"Trying to prepend a node to itself. This behavior has no effect on the graph."
)
return
x._remove_from_list()
p = self._prev
p._next, x._prev = x, p
x._next, self._prev = self, x
# compute x._sort_key
psk = x._prev._sort_key
nsk = x._next._sort_key
if len(psk) > len(nsk):
idx: int
*prefix, idx = psk[: len(nsk) + 1]
x._sort_key = (*prefix, idx + 1)
elif len(psk) < len(nsk):
*prefix, idx = nsk[: len(psk) + 1]
x._sort_key = (*prefix, idx - 1)
else: # same length, increase length by 1
x._sort_key = (*psk, 0)
def __gt__(self, other: "Node") -> bool:
return self._sort_key > other._sort_key
def __lt__(self, other: "Node") -> bool:
return self._sort_key < other._sort_key
def __ge__(self, other: "Node") -> bool:
return self > other or self == other
def __le__(self, other: "Node") -> bool:
return self < other or self == other
self._prepend(x)
@compatibility(is_backward_compatible=True)
def append(self, x: "Node") -> None:
@@ -430,11 +396,7 @@ class Node(_NodeBase):
Args:
x (Node): The node to put after this node. Must be a member of the same graph.
"""
self._next.prepend(x)
def _remove_from_list(self) -> None:
p, n = self._prev, self._next
p._next, n._prev = n, p
self._next._prepend(x)
@property
def args(self) -> tuple[Argument, ...]:

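The Python body removed above (now provided by the C++ implementation behind
self._prepend(x)) keeps nodes totally ordered through tuple-valued _sort_key
values interpolated between neighbors. A standalone sketch of that
interpolation idea, independent of torch.fx; interpolate_key is a made-up name:

def interpolate_key(psk: tuple, nsk: tuple) -> tuple:
    # Pick a key that sorts strictly between the previous key `psk` and the
    # next key `nsk`, mirroring the removed prepend() logic above.
    if len(psk) > len(nsk):
        *prefix, idx = psk[: len(nsk) + 1]
        return (*prefix, idx + 1)
    elif len(psk) < len(nsk):
        *prefix, idx = nsk[: len(psk) + 1]
        return (*prefix, idx - 1)
    else:  # equal lengths: extend by one position
        return (*psk, 0)

# Inserting between (0,) and (1,) yields (0, 0), which sorts between them.
assert (0,) < interpolate_key((0,), (1,)) < (1,)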

@@ -238,6 +238,47 @@ def skip_if_lt_x_gpu(x):
return decorator
def requires_world_size(n: int):
"""
Decorator to request a specific world size for a test. The test harness can
read the `_required_world_size` attribute to decide how many ranks to spawn.
If fewer than `n` CUDA devices are available, the test is skipped.
Usage:
    @requires_world_size(3)
    def test_something(self):
        ...
"""
def decorator(func):
func._required_world_size = n
available = torch.cuda.device_count()
return unittest.skipUnless(
available >= n, f"requires {n} GPUs, found {available}"
)(func)
return decorator
def get_required_world_size(obj: Any, default: int) -> int:
"""
Returns the requested world size for the currently running unittest method on `obj`
if annotated via `@requires_world_size(n)`; otherwise returns `default`.
"""
try:
# Try MultiProcessTestCase helper first, then unittest fallback
test_name = (
obj._current_test_name() # type: ignore[attr-defined]
if hasattr(obj, "_current_test_name") and callable(obj._current_test_name)
else obj._testMethodName
)
fn = getattr(obj, test_name)
value = fn._required_world_size
return int(value)
except Exception:
return default
# This decorator helps avoid initializing CUDA while testing other backends
def nccl_skip_if_lt_x_gpu(backend, x):
def decorator(func):
@@ -367,6 +408,13 @@ def requires_nccl_version(version, msg):
)
def requires_nccl_shrink():
"""
Require NCCL shrink support (NCCL available and version >= 2.27).
"""
return requires_nccl_version((2, 27), "Need NCCL 2.27+ for shrink_group")
def requires_nccl():
return skip_but_pass_in_sandcastle_if(
not c10d.is_nccl_available(),

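A hedged usage sketch combining the new helpers above; the test class, method
name, and the 2-rank fallback are placeholders, and the imports assume the
helpers live in torch.testing._internal.common_distributed as shown:

from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    get_required_world_size,
    requires_nccl_shrink,
    requires_world_size,
)

class _ExampleShrinkTest(MultiProcessTestCase):
    @property
    def world_size(self) -> int:
        # Honor @requires_world_size(n) on the running test method; fall back
        # to 2 ranks when the test does not request a specific size.
        return get_required_world_size(self, 2)

    @requires_nccl_shrink()
    @requires_world_size(3)
    def test_shrink_group_smoke(self) -> None:
        # Placeholder body: runs only with NCCL >= 2.27 and at least 3 GPUs.
        ...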

@@ -1836,7 +1836,7 @@ def make_dynamo_test(
Decorator function to create a dynamo test case. A function annotated with
this decorator takes a unittest object as input.
"""
from torch._dynamo.testing import CompileCounter, reset, optimize_assert
from torch._dynamo.testing import CompileCounter, reset, optimize
if fn is None:
return lambda fn: make_dynamo_test(fn)
@@ -1852,7 +1852,7 @@ def make_dynamo_test(
dummy()
reset()
opt_fn = optimize_assert(actual)(dummy)
opt_fn = optimize(actual, error_on_graph_break=True)(dummy)
opt_fn()
reset()