Revert "shrink_group implementation to expose ncclCommShrink API (#164518)"

This reverts commit fa0db212e717b6cb225159cb32ea3d83baa52381.

Reverted https://github.com/pytorch/pytorch/pull/164518 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/164518#issuecomment-3419893217))
This commit is contained in:
PyTorch MergeBot
2025-10-19 19:20:44 +00:00
parent fa0db212e7
commit 633a3b7f67
11 changed files with 2 additions and 1503 deletions

View File

@ -238,47 +238,6 @@ def skip_if_lt_x_gpu(x):
return decorator
def requires_world_size(n: int):
"""
Decorator to request a specific world size for a test. The test harness can
read this attribute to set the number of ranks to spawn. If there are fewer
than `n` CUDA devices available, the test should be skipped by the harness.
Usage:
@require_world_size(3)
def test_something(self):
...
"""
def decorator(func):
func._required_world_size = n
available = torch.cuda.device_count()
return unittest.skipUnless(
available >= n, f"requires {n} GPUs, found {available}"
)(func)
return decorator
def get_required_world_size(obj: Any, default: int) -> int:
"""
Returns the requested world size for the currently running unittest method on `obj`
if annotated via `@require_world_size(n)`, else returns `default`.
"""
try:
# Try MultiProcessTestCase helper first, then unittest fallback
test_name = (
obj._current_test_name() # type: ignore[attr-defined]
if hasattr(obj, "_current_test_name") and callable(obj._current_test_name)
else obj._testMethodName
)
fn = getattr(obj, test_name)
value = fn._required_world_size
return int(value)
except Exception:
return default
# This decorator helps avoiding initializing cuda while testing other backends
def nccl_skip_if_lt_x_gpu(backend, x):
def decorator(func):
@ -408,13 +367,6 @@ def requires_nccl_version(version, msg):
)
def requires_nccl_shrink():
"""
Require NCCL shrink support (NCCL available and version >= 2.27).
"""
return requires_nccl_version((2, 27), "Need NCCL 2.27+ for shrink_group")
def requires_nccl():
return skip_but_pass_in_sandcastle_if(
not c10d.is_nccl_available(),