Compare commits

...

3 Commits

Author SHA1 Message Date
8dad7b7296 Add type hints and combine require_n_gpus_for_nccl_backend with nccl_skip_if_lt_x_gpu
Added type hints to several functions for better clarity and type checking.

Also make nccl_skip_if_lt_x_gpu an alias for require_n_gpus_for_nccl_backend, as they have the same semantics
2025-11-10 09:40:20 +01:00
a119ca7390 Reduce usage of sys.exit to skip tests
Factor out `exit_if_lt_x_gpu`
Replace checks with `unittest.skip*` where possible
2025-11-10 09:40:20 +01:00
738c04b014 Fix MultiProcess failure on nodes with 1 GPU
The decorators called `sys.exit` only when the test function itself was
invoked, which is AFTER the `setUp` call that forks the processes and
(potentially) uses a GPU/NCCL-based barrier, so "n GPUs" had to be
present before the check for "n GPUs" could even run.

Rewrite those decorators to use `unittest.skipIf`, so the `setUp`
function is never even entered (a before/after sketch follows the
commit list below).
This also exposed that `require_n_gpus_for_nccl_backend` is the same as
`nccl_skip_if_lt_x_gpu`; the former has the better name, so I removed
the latter.

Fixes #89686
2025-11-10 09:40:18 +01:00
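
For readers skimming the diff below, here is a minimal, self-contained sketch of the rewrite described in the commit messages above. It is illustrative only, not the PyTorch helpers themselves: `fake_device_count` stands in for the real device query and the exit-code value is arbitrary.

```python
# Illustrative sketch only, not the PyTorch helpers themselves.
# `fake_device_count` stands in for torch.cuda.device_count().
import functools
import sys
import unittest


def fake_device_count() -> int:
    return 1  # pretend this node has a single GPU


def old_style_skip_if_lt_x_gpu(x):
    # Old shape: the check lives inside the wrapper, so it only fires when the
    # test body is called -- i.e. after setUp() has already forked processes
    # and possibly hit a GPU/NCCL barrier.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if fake_device_count() < x:
                sys.exit(42)  # exit code the test harness maps back to a skip
            return func(*args, **kwargs)
        return wrapper
    return decorator


def new_style_skip_if_lt_x_gpu(x):
    # New shape: unittest records the skip on the test itself, and
    # TestCase.run() honours it before setUp() is ever called.
    return unittest.skipUnless(fake_device_count() >= x, f"need at least {x} GPUs")
```

The key difference is when the check runs: a `unittest.skip*` marker is honoured by `TestCase.run()` before `setUp` is called, whereas the old wrapper could not run until after `setUp` had already forked the worker processes.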
5 changed files with 55 additions and 108 deletions

View File

@@ -13,6 +13,7 @@ from functorch import make_fx
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import exit_if_lt_x_accelerators
from torch.testing._internal.inductor_utils import HAS_GPU
@@ -24,7 +25,7 @@ from torch.testing._internal.common_distributed import (
DistributedTestBase,
MultiThreadedTestCase,
requires_accelerator_dist_backend,
TEST_SKIPS,
skip_if_no_gpu,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
@@ -479,25 +480,14 @@ elif TEST_XPU:
BACKEND = dist.Backend.XCCL
# allows you to check for multiple accelerator irrespective of device type
# to add new device types to this check simply follow the same format
# and append an elif with the conditional and appropriate device count function for your new device
def exit_if_lt_x_accelerators(x):
if torch.accelerator.is_available():
if torch.accelerator.device_count() < x:
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
def with_comms(func=None):
if func is None:
return partial(with_comms)
@wraps(func)
def wrapper(self, *args, **kwargs):
if (
BACKEND == dist.Backend.NCCL or BACKEND == dist.Backend.XCCL
) and torch.accelerator.device_count() < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
if BACKEND in (dist.Backend.NCCL, dist.Backend.XCCL):
exit_if_lt_x_accelerators(self.world_size)
kwargs["device"] = DEVICE
self.pg = self.create_pg(device=DEVICE)
@@ -510,9 +500,9 @@ def with_comms(func=None):
class TestCollectivesWithDistributedBackend(DistributedTestBase):
@skip_if_no_gpu
@with_comms()
def test_all_gather_into_tensor_coalesced(self, device):
exit_if_lt_x_accelerators(self.world_size)
tensors = [
torch.ones([4], device=device),
torch.ones([4], device=device) + 1,
@@ -584,9 +574,8 @@ class TestCollectivesWithDistributedBackend(DistributedTestBase):
compiled_allreduce(torch.randn(8, device=device), self.pg)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_no_gpu
def test_tracing_with_fakepg(self, device=DEVICE):
exit_if_lt_x_accelerators(self.world_size)
def allreduce(t, pg):
return ft_c.all_reduce(t, "sum", pg)
@@ -627,9 +616,9 @@ class TestDistributedBackendCollectivesWithWorldSize4(
def world_size(self):
return 4
@skip_if_no_gpu
@with_comms()
def test_permute_tensor_with_sub_group(self, device):
exit_if_lt_x_accelerators(self.world_size)
mesh_dim_names = ["dp", "tp"]
mesh_2d = dt.init_device_mesh(

View File

@@ -118,18 +118,30 @@ class DistTestCases:
backend_feature["xpu"] = {"xccl"}
def requires_ddp_rank(device):
def requires_ddp_rank(device) -> bool:
return device in DDP_RANK_DEVICES
def exit_if_lt_x_cuda_devs(x: int):
"""Exit process unless at least the given number of CUDA devices are available"""
if torch.cuda.device_count() < x:
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
# allows you to check for multiple accelerator irrespective of device type
# to add new device types to this check simply follow the same format
# and append an elif with the conditional and appropriate device count function for your new device
def exit_if_lt_x_accelerators(x: int):
if torch.accelerator.device_count() < x:
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
def skip_if_no_gpu(func):
"""Skips if the world size exceeds the number of GPUs, ensuring that if the
test is run, each rank has its own GPU via ``torch.cuda.device(rank)``."""
@wraps(func)
def wrapper(*args, **kwargs):
if not (TEST_CUDA or TEST_HPU or TEST_XPU):
sys.exit(TEST_SKIPS["no_cuda"].exit_code)
world_size = int(os.environ["WORLD_SIZE"])
if TEST_CUDA and torch.cuda.device_count() < world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
@@ -140,7 +152,9 @@ def skip_if_no_gpu(func):
return func(*args, **kwargs)
return wrapper
return unittest.skipUnless(
TEST_CUDA or TEST_HPU or TEST_XPU, TEST_SKIPS["no_cuda"].message
)(wrapper)
# TODO (kwen2501): what is the purpose of this decorator? Tests with this
@@ -171,37 +185,16 @@ def skip_if_odd_worldsize(func):
return wrapper
def require_n_gpus_for_nccl_backend(n, backend):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
if backend == "nccl" and torch.cuda.device_count() < n:
sys.exit(TEST_SKIPS[f"multi-gpu-{n}"].exit_code)
else:
return func(*args, **kwargs)
return wrapper
return decorator
def import_transformers_or_skip():
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
from transformers import AutoModelForMaskedLM, BertConfig # noqa: F401
try:
from transformers import AutoModelForMaskedLM, BertConfig # noqa: F401
return func(*args, **kwargs)
except ImportError:
sys.exit(TEST_SKIPS["importerror"].exit_code)
return wrapper
return decorator
return unittest.skipIf(False, "Dummy")
except ImportError:
return unittest.skip(TEST_SKIPS["importerror"].message)
def at_least_x_gpu(x):
def at_least_x_gpu(x: int) -> bool:
if TEST_CUDA and torch.cuda.device_count() >= x:
return True
if TEST_HPU and torch.hpu.device_count() >= x:
@@ -211,31 +204,8 @@ def at_least_x_gpu(x):
return False
def _maybe_handle_skip_if_lt_x_gpu(args, msg) -> bool:
_handle_test_skip = getattr(args[0], "_handle_test_skip", None)
if len(args) == 0 or _handle_test_skip is None:
return False
_handle_test_skip(msg)
return True
def skip_if_lt_x_gpu(x):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
if torch.cuda.is_available() and torch.cuda.device_count() >= x:
return func(*args, **kwargs)
if TEST_HPU and torch.hpu.device_count() >= x:
return func(*args, **kwargs)
if TEST_XPU and torch.xpu.device_count() >= x:
return func(*args, **kwargs)
test_skip = TEST_SKIPS[f"multi-gpu-{x}"]
if not _maybe_handle_skip_if_lt_x_gpu(args, test_skip.message):
sys.exit(test_skip.exit_code)
return wrapper
return decorator
def skip_if_lt_x_gpu(x: int):
return unittest.skipUnless(at_least_x_gpu(x), TEST_SKIPS[f"multi-gpu-{x}"].message)
def requires_world_size(n: int):
@@ -280,21 +250,14 @@ def get_required_world_size(obj: Any, default: int) -> int:
# This decorator helps avoiding initializing cuda while testing other backends
def nccl_skip_if_lt_x_gpu(backend, x):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
if backend != "nccl":
return func(*args, **kwargs)
if torch.cuda.is_available() and torch.cuda.device_count() >= x:
return func(*args, **kwargs)
test_skip = TEST_SKIPS[f"multi-gpu-{x}"]
if not _maybe_handle_skip_if_lt_x_gpu(args, test_skip.message):
sys.exit(test_skip.exit_code)
def require_n_gpus_for_nccl_backend(n: int, backend: str):
return unittest.skipUnless(
backend != "nccl" or at_least_x_gpu(n), TEST_SKIPS[f"multi-gpu-{n}"].message
)
return wrapper
return decorator
def nccl_skip_if_lt_x_gpu(backend: str, x: int):
return require_n_gpus_for_nccl_backend(x, backend)
def verify_ddp_error_logged(model_DDP, err_substr):
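
As a usage note, here is a hedged sketch of how the consolidated decorators are applied after this change. It assumes only the helpers defined in the hunks above; the class name, test names, and the `BACKEND` constant are invented for illustration.

```python
# Hypothetical usage sketch; ExampleNcclTest and its test names are invented.
import unittest

from torch.testing._internal.common_distributed import (
    require_n_gpus_for_nccl_backend,
    skip_if_lt_x_gpu,
)

BACKEND = "nccl"  # placeholder for however the real test module picks its backend


class ExampleNcclTest(unittest.TestCase):
    @require_n_gpus_for_nccl_backend(2, BACKEND)
    def test_two_gpu_collective(self):
        # Skipped up front when BACKEND == "nccl" and fewer than 2 GPUs exist.
        pass

    @skip_if_lt_x_gpu(2)
    def test_two_accelerator_collective(self):
        # Skipped up front when fewer than 2 CUDA/HPU/XPU devices are available.
        pass


if __name__ == "__main__":
    unittest.main()
```

Because both helpers now return plain `unittest` skip decorators, the skip is attached to the test method itself and a machine with too few devices never reaches `setUp`.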

View File

@@ -7,8 +7,9 @@ import torch
import torch.distributed as dist
from torch.distributed import rpc
from torch.testing._internal.common_distributed import (
exit_if_lt_x_cuda_devs,
MultiProcessTestCase,
TEST_SKIPS,
require_n_gpus_for_nccl_backend,
tp_transports,
)
@@ -94,10 +95,10 @@ def with_comms(func=None, init_rpc=True, backend="nccl"):
@wraps(func)
def wrapper(self, *args, **kwargs):
if backend == "nccl" and torch.cuda.device_count() < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
if backend == "nccl":
exit_if_lt_x_cuda_devs(self.world_size)
self.init_comms(init_rpc=init_rpc, backend=backend)
func(self, *args, **kwargs)
self.destroy_comms(destroy_rpc=init_rpc)
return wrapper
return require_n_gpus_for_nccl_backend(1, backend)(wrapper)

View File

@@ -5,7 +5,6 @@
import contextlib
import functools
import itertools
import sys
import types
from collections.abc import Callable, Iterator, Sequence
from dataclasses import dataclass
@@ -40,12 +39,12 @@ from torch.distributed.tensor.parallel import (
SequenceParallel,
)
from torch.testing._internal.common_distributed import (
exit_if_lt_x_cuda_devs,
MultiProcContinuousTest,
MultiProcessTestCase,
MultiThreadedTestCase,
run_subtests,
skip_if_lt_x_gpu,
TEST_SKIPS,
)
from torch.testing._internal.common_utils import (
TEST_CUDA,
@@ -393,8 +392,8 @@ class DTensorTestBase(MultiProcessTestCase):
return init_device_mesh(self.device_type, (self.world_size,))
def init_pg(self, eager_init, backend: Optional[str] = None) -> None:
if "nccl" in self.backend and torch.cuda.device_count() < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
if "nccl" in self.backend:
exit_if_lt_x_cuda_devs(self.world_size)
curr_backend = dist.get_default_backend_for_device(self.device_type)
@@ -710,9 +709,6 @@ class LocalDTensorTestBase(DTensorTestBase):
def is_local_tensor_enabled(self) -> bool:
return True
def _handle_test_skip(self, msg: str) -> None:
self.skipTest(msg)
def _get_local_tensor_mode(self):
return LocalTensorMode(frozenset(range(self.world_size)))
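
Not every check can become a `unittest` skip: guards that depend on per-rank state still run inside the already-spawned process, which is why `exit_if_lt_x_cuda_devs` keeps the `sys.exit(TEST_SKIPS[...])` pattern (see the `init_pg` hunk above). The sketch below shows the division of labour. It is a schematic only: the class is invented, the usual process-spawning `setUp` boilerplate is omitted, and the statement that the parent maps the child's exit code back to a skip is my reading of the multiprocess harness rather than something stated in this diff.

```python
# Schematic only: a real MultiProcessTestCase subclass needs the usual
# process-spawning setUp boilerplate, omitted here for brevity.
import torch
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    exit_if_lt_x_cuda_devs,
    skip_if_lt_x_gpu,
)


class ExampleMultiProcessTest(MultiProcessTestCase):  # hypothetical test class
    @property
    def world_size(self) -> int:
        return 2

    # Decided in the parent process, before setUp can fork any ranks.
    @skip_if_lt_x_gpu(2)
    def test_something_per_rank(self):
        # Runs inside each spawned rank; checks that only make sense here
        # still signal "skip" via the well-known TEST_SKIPS exit code,
        # which (assumption) the parent translates back into a skipped test.
        if torch.cuda.is_available():
            exit_if_lt_x_cuda_devs(self.world_size)
        # ... actual per-rank test body ...
```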

View File

@@ -60,10 +60,10 @@ from torch.testing._internal.common_distributed import (
captured_output,
cleanup_temp_dir,
DistTestCases,
exit_if_lt_x_cuda_devs,
init_multigpu_helper,
initialize_temp_directories,
MultiProcessTestCase,
nccl_skip_if_lt_x_gpu,
require_n_gpus_for_nccl_backend,
requires_nccl_version,
simple_sparse_reduce_tests,
@@ -603,10 +603,8 @@ class TestDistBackend(MultiProcessTestCase):
self.rank = rank
self.file_name = file_name
if torch.cuda.is_available() and torch.cuda.device_count() < int(
self.world_size
):
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
if torch.cuda.is_available():
exit_if_lt_x_cuda_devs(int(self.world_size))
try:
pg_timeout_seconds = CUSTOM_PG_TIMEOUT.get(test_name, default_pg_timeout)
timeout = timedelta(seconds=pg_timeout_seconds)
@@ -5346,7 +5344,7 @@ class DistributedTest:
BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
"get_future is only supported on mpi, nccl and gloo",
)
@nccl_skip_if_lt_x_gpu(BACKEND, 2)
@require_n_gpus_for_nccl_backend(2, BACKEND)
def test_accumulate_gradients_no_sync(self):
"""
Runs _test_accumulate_gradients_no_sync using default inputs
@@ -5357,7 +5355,7 @@ class DistributedTest:
BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
"get_future is only supported on mpi, nccl and gloo",
)
@nccl_skip_if_lt_x_gpu(BACKEND, 2)
@require_n_gpus_for_nccl_backend(2, BACKEND)
def test_accumulate_gradients_no_sync_grad_is_view(self):
"""
Runs _test_accumulate_gradients_no_sync using default inputs
@@ -5368,7 +5366,7 @@ class DistributedTest:
BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
"get_future is only supported on mpi, nccl and gloo",
)
@nccl_skip_if_lt_x_gpu(BACKEND, 2)
@require_n_gpus_for_nccl_backend(2, BACKEND)
def test_accumulate_gradients_no_sync_allreduce_hook(self):
"""
Runs multiple iterations on _test_accumulate_gradients_no_sync
@@ -5396,7 +5394,7 @@ class DistributedTest:
BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
"get_future is only supported on mpi, nccl and gloo",
)
@nccl_skip_if_lt_x_gpu(BACKEND, 2)
@require_n_gpus_for_nccl_backend(2, BACKEND)
def test_accumulate_gradients_no_sync_allreduce_with_then_hook(self):
"""
Runs multiple iterations on _test_accumulate_gradients_no_sync using allreduce
@@ -5430,7 +5428,7 @@ class DistributedTest:
BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
"get_future is only supported on mpi, nccl and gloo",
)
@nccl_skip_if_lt_x_gpu(BACKEND, 2)
@require_n_gpus_for_nccl_backend(2, BACKEND)
def test_get_future(self):
def mult(fut):
return [t * 3 for t in fut.wait()]