Revert "shrink_group implementation to expose ncclCommShrink API (#164518)"

This reverts commit a032510db38e8331afa08f7635d146f9cefdd0ab.

Reverted https://github.com/pytorch/pytorch/pull/164518 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/164518#issuecomment-3416718767))
This commit is contained in:
PyTorch MergeBot
2025-10-17 18:55:53 +00:00
parent 7a65770013
commit fae74cd52f
11 changed files with 2 additions and 1503 deletions

View File

@ -228,47 +228,6 @@ def skip_if_lt_x_gpu(x):
return decorator
def requires_world_size(n: int):
"""
Decorator to request a specific world size for a test. The test harness can
read this attribute to set the number of ranks to spawn. If there are fewer
than `n` CUDA devices available, the test should be skipped by the harness.
Usage:
@require_world_size(3)
def test_something(self):
...
"""
def decorator(func):
func._required_world_size = n
available = torch.cuda.device_count()
return unittest.skipUnless(
available >= n, f"requires {n} GPUs, found {available}"
)(func)
return decorator
def get_required_world_size(obj: Any, default: int) -> int:
"""
Returns the requested world size for the currently running unittest method on `obj`
if annotated via `@require_world_size(n)`, else returns `default`.
"""
try:
# Try MultiProcessTestCase helper first, then unittest fallback
test_name = (
obj._current_test_name() # type: ignore[attr-defined]
if hasattr(obj, "_current_test_name") and callable(obj._current_test_name)
else obj._testMethodName
)
fn = getattr(obj, test_name)
value = fn._required_world_size
return int(value)
except Exception:
return default
# This decorator helps avoiding initializing cuda while testing other backends
def nccl_skip_if_lt_x_gpu(backend, x):
def decorator(func):
@ -396,13 +355,6 @@ def requires_nccl_version(version, msg):
)
def requires_nccl_shrink():
"""
Require NCCL shrink support (NCCL available and version >= 2.27).
"""
return requires_nccl_version((2, 27), "Need NCCL 2.27+ for shrink_group")
def requires_nccl():
return skip_but_pass_in_sandcastle_if(
not c10d.is_nccl_available(),