mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
shrink_group implementation to expose ncclCommShrink API (#164518)
Closes #164529. Exposes the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink) API to PyTorch. This is useful when you need to exclude certain GPUs or nodes from a collective operation, for example in fault-tolerance scenarios or when dynamically adjusting resource utilization. For more info: [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518
Approved by: https://github.com/Skylion007, https://github.com/syed-ahmed, https://github.com/kwen2501
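As a rough illustration (not code from this commit), the sketch below shows how a shrunk group might be used for fault tolerance. The `torch.distributed.distributed_c10d.shrink_group` import path and its argument (a list of ranks to exclude) are assumptions based on the PR title and description, not on the diff shown below.

    # Hypothetical usage sketch; the shrink_group import path and signature are
    # assumptions, not taken from this page.
    import torch
    import torch.distributed as dist
    from torch.distributed.distributed_c10d import shrink_group  # assumed location

    def run(rank: int, world_size: int) -> None:
        # Assumes MASTER_ADDR/MASTER_PORT are set in the environment.
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)

        t = torch.ones(1, device="cuda")
        dist.all_reduce(t)  # collective over the full world

        # Suppose the last rank is unhealthy and must be dropped from future collectives.
        bad_rank = world_size - 1
        if rank != bad_rank:
            # Assumed to wrap ncclCommShrink and return a smaller process group
            # that excludes the listed ranks.
            surviving_group = shrink_group([bad_rank])
            dist.all_reduce(t, group=surviving_group)

        dist.destroy_process_group()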
Committed by: PyTorch MergeBot
Parent: 39e0a832c9
Commit: a032510db3
@@ -228,6 +228,47 @@ def skip_if_lt_x_gpu(x):
     return decorator
 
 
+def requires_world_size(n: int):
+    """
+    Decorator to request a specific world size for a test. The test harness can
+    read this attribute to set the number of ranks to spawn. If there are fewer
+    than `n` CUDA devices available, the test should be skipped by the harness.
+
+    Usage:
+        @requires_world_size(3)
+        def test_something(self):
+            ...
+    """
+
+    def decorator(func):
+        func._required_world_size = n
+        available = torch.cuda.device_count()
+        return unittest.skipUnless(
+            available >= n, f"requires {n} GPUs, found {available}"
+        )(func)
+
+    return decorator
+
+
+def get_required_world_size(obj: Any, default: int) -> int:
+    """
+    Returns the requested world size for the currently running unittest method on `obj`
+    if annotated via `@requires_world_size(n)`, else returns `default`.
+    """
+    try:
+        # Try MultiProcessTestCase helper first, then unittest fallback
+        test_name = (
+            obj._current_test_name()  # type: ignore[attr-defined]
+            if hasattr(obj, "_current_test_name") and callable(obj._current_test_name)
+            else obj._testMethodName
+        )
+        fn = getattr(obj, test_name)
+        value = fn._required_world_size
+        return int(value)
+    except Exception:
+        return default
+
+
 # This decorator helps avoiding initializing cuda while testing other backends
 def nccl_skip_if_lt_x_gpu(backend, x):
     def decorator(func):
@@ -355,6 +396,13 @@ def requires_nccl_version(version, msg):
     )
 
 
+def requires_nccl_shrink():
+    """
+    Require NCCL shrink support (NCCL available and version >= 2.27).
+    """
+    return requires_nccl_version((2, 27), "Need NCCL 2.27+ for shrink_group")
+
+
 def requires_nccl():
     return skip_but_pass_in_sandcastle_if(
         not c10d.is_nccl_available(),
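For orientation, here is a rough sketch (not part of this commit) of how the helpers added above might be combined in a distributed test. The import module, the `MultiProcessTestCase` base class, and the `world_size` property are assumptions about the surrounding test harness.

    # Illustrative only; assumes a MultiProcessTestCase-style harness whose runner
    # consults get_required_world_size when deciding how many ranks to spawn.
    from torch.testing._internal.common_distributed import (  # assumed module for these helpers
        MultiProcessTestCase,
        get_required_world_size,
        requires_nccl_shrink,
        requires_world_size,
    )

    class ShrinkGroupTest(MultiProcessTestCase):
        @property
        def world_size(self) -> int:
            # Honor @requires_world_size(n) if the current test is annotated,
            # otherwise fall back to 2 ranks.
            return get_required_world_size(self, 2)

        @requires_nccl_shrink()   # skip unless NCCL is available and is version 2.27+
        @requires_world_size(4)   # request 4 ranks; skip if fewer than 4 GPUs are present
        def test_shrink_excludes_one_rank(self):
            ...  # set up NCCL, shrink the group, run a collective on the survivors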