shrink_group implementation to expose ncclCommShrink API (#164518)

Closes #164529

This exposes the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink) API to PyTorch.

This is useful when you need to exclude certain GPUs or nodes from a collective operation, for example in fault tolerance scenarios or when dynamically adjusting resource utilization.

For more info: [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator)
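A rough usage sketch is shown below. The entry point name `shrink_group` is taken from the PR title; the exact module, signature, and return type should be checked against the diff, so treat this as illustrative only.

```python
# Illustrative sketch only: assumes the PR exposes a `shrink_group(ranks_to_exclude)`
# helper that returns a new, smaller process group. Launch with torchrun so that
# MASTER_ADDR/MASTER_PORT and rank/world-size environment variables are set.
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)  # assumes one GPU per rank on a single node

unhealthy = [3]  # ranks detected as failed/unreachable

if rank not in unhealthy:
    # Surviving ranks collectively shrink the default group, excluding the bad ranks.
    new_pg = dist.shrink_group(unhealthy)  # hypothetical name/signature
    t = torch.ones(1, device="cuda")
    dist.all_reduce(t, group=new_pg)  # runs only across the surviving ranks
```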

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518
Approved by: https://github.com/Skylion007, https://github.com/syed-ahmed, https://github.com/kwen2501
Bruce Chang
2025-10-17 17:55:00 +00:00
committed by PyTorch MergeBot
parent 39e0a832c9
commit a032510db3
11 changed files with 1503 additions and 2 deletions


@@ -228,6 +228,47 @@ def skip_if_lt_x_gpu(x):

```python
    return decorator


def requires_world_size(n: int):
    """
    Decorator to request a specific world size for a test. The test harness can
    read this attribute to set the number of ranks to spawn. If there are fewer
    than `n` CUDA devices available, the test should be skipped by the harness.

    Usage:
        @requires_world_size(3)
        def test_something(self):
            ...
    """

    def decorator(func):
        func._required_world_size = n
        available = torch.cuda.device_count()
        return unittest.skipUnless(
            available >= n, f"requires {n} GPUs, found {available}"
        )(func)

    return decorator


def get_required_world_size(obj: Any, default: int) -> int:
    """
    Returns the requested world size for the currently running unittest method
    on `obj` if annotated via `@requires_world_size(n)`, else returns `default`.
    """
    try:
        # Try the MultiProcessTestCase helper first, then the unittest fallback
        test_name = (
            obj._current_test_name()  # type: ignore[attr-defined]
            if hasattr(obj, "_current_test_name") and callable(obj._current_test_name)
            else obj._testMethodName
        )
        fn = getattr(obj, test_name)
        value = fn._required_world_size
        return int(value)
    except Exception:
        return default


# This decorator helps avoiding initializing cuda while testing other backends
def nccl_skip_if_lt_x_gpu(backend, x):
    def decorator(func):
```
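Taken together, a test class can derive its rank count from the decorator. The sketch below assumes these helpers land in `torch.testing._internal.common_distributed` alongside the existing decorators shown above; the class and test names are illustrative.

```python
# Sketch of how the two new helpers are meant to be used together.
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    get_required_world_size,
    requires_world_size,
)


class ShrinkGroupTest(MultiProcessTestCase):
    @property
    def world_size(self) -> int:
        # Falls back to 2 ranks unless the currently running test asked for more.
        return get_required_world_size(self, 2)

    @requires_world_size(3)
    def test_needs_three_ranks(self):
        # Skipped automatically when fewer than 3 CUDA devices are present.
        ...
```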
@@ -355,6 +396,13 @@ def requires_nccl_version(version, msg):

```python
        )


def requires_nccl_shrink():
    """
    Require NCCL shrink support (NCCL available and version >= 2.27).
    """
    return requires_nccl_version((2, 27), "Need NCCL 2.27+ for shrink_group")


def requires_nccl():
    return skip_but_pass_in_sandcastle_if(
        not c10d.is_nccl_available(),
```
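A shrink test can then combine the new version gate with the existing NCCL gate, for example (sketch; the class and test names are illustrative, and the import location is assumed as above):

```python
# Sketch: gating a shrink test on NCCL availability and version >= 2.27.
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_nccl,
    requires_nccl_shrink,
)


class NCCLShrinkTest(MultiProcessTestCase):
    @requires_nccl()
    @requires_nccl_shrink()
    def test_shrink_group_basic(self):
        # Runs only when NCCL is available and its version is >= 2.27.
        ...
```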