Compare commits

...

3 Commits

8dad7b7296 Add type hints and combine require_n_gpus_for_nccl_backend with nccl_skip_if_lt_x_gpu
Added type hints to several functions for better clarity and type checking.

Also make nccl_skip_if_lt_x_gpu an alias for require_n_gpus_for_nccl_backend, as the two have the same semantics.
2025-11-10 09:40:20 +01:00
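For illustration, the aliasing described in this commit boils down to roughly the following sketch (the skip message is a stand-in for the real TEST_SKIPS entry; this is not the verbatim PyTorch code):

import unittest

import torch


def require_n_gpus_for_nccl_backend(n: int, backend: str):
    """Skip the test unless the nccl backend has at least `n` CUDA devices."""
    return unittest.skipUnless(
        backend != "nccl" or torch.cuda.device_count() >= n,
        f"requires at least {n} GPUs for the nccl backend",  # stand-in message
    )


def nccl_skip_if_lt_x_gpu(backend: str, x: int):
    # Same semantics as the function above; kept only as an alias so existing
    # call sites continue to work.
    return require_n_gpus_for_nccl_backend(x, backend)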
a119ca7390 Reduce usage of sys.exit to skip tests
Factor out `exit_if_lt_x_gpu`
Replace checks with `unittest.skip*` where possible
2025-11-10 09:40:20 +01:00
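In condensed form, the replacement pattern is roughly the following sketch (the exit code and message are stand-ins for PyTorch's TEST_SKIPS table):

import sys
import unittest

import torch

MULTI_GPU_EXIT_CODE = 75  # stand-in for TEST_SKIPS[f"multi-gpu-{x}"].exit_code


def exit_if_lt_x_cuda_devs(x: int) -> None:
    # Factored-out helper: exit the current (already spawned) worker process
    # when fewer than x CUDA devices are present.
    if torch.cuda.device_count() < x:
        sys.exit(MULTI_GPU_EXIT_CODE)


def skip_if_lt_x_gpu(x: int):
    # Preferred form: a plain unittest skip decorator, decided before any
    # worker process is spawned.
    return unittest.skipUnless(
        torch.cuda.device_count() >= x, f"requires at least {x} GPUs"
    )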
738c04b014 Fix MultiProcess failure on nodes with 1 GPU
The decorators call `sys.exit` only when the test function itself is invoked, which happens AFTER the `setup` call. `setup` forks the worker processes and (potentially) uses a GPU/NCCL-based barrier, so "n GPUs" must already be present before the check for "n GPUs" has even run.

Rewrite those decorators to use `unittest.skipIf`, which skips the test without ever entering the `setup` function.
This also exposed that `require_n_gpus_for_nccl_backend` is the same as
`nccl_skip_if_lt_x_gpu`; the former has the better name, so the latter was
removed.

Fixes #89686
2025-11-10 09:40:18 +01:00
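To make the ordering problem concrete, a minimal sketch (the decorator names and exit code are illustrative, not the actual PyTorch helpers):

import sys
import unittest
from functools import wraps

import torch

NUM_GPUS_NEEDED = 2  # illustrative requirement


def old_style_skip(func):
    # The check lives inside the wrapper, so it only runs once the test method
    # is invoked -- i.e. after setup has already forked the worker processes
    # and possibly hit a GPU/NCCL-based barrier.
    @wraps(func)
    def wrapper(*args, **kwargs):
        if torch.cuda.device_count() < NUM_GPUS_NEEDED:
            sys.exit(75)  # stand-in skip exit code
        return func(*args, **kwargs)

    return wrapper


# With unittest.skipIf the condition is evaluated at decoration time and the
# runner honors the skip marker before setUp runs, so setup is never entered.
new_style_skip = unittest.skipIf(
    torch.cuda.device_count() < NUM_GPUS_NEEDED, f"needs {NUM_GPUS_NEEDED} GPUs"
)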
5 changed files with 55 additions and 108 deletions

View File

@@ -13,6 +13,7 @@ from functorch import make_fx
 from torch._inductor.utils import run_and_get_code
 from torch.testing import FileCheck
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
+from torch.testing._internal.common_distributed import exit_if_lt_x_accelerators
 from torch.testing._internal.inductor_utils import HAS_GPU
@@ -24,7 +25,7 @@ from torch.testing._internal.common_distributed import (
     DistributedTestBase,
     MultiThreadedTestCase,
     requires_accelerator_dist_backend,
-    TEST_SKIPS,
+    skip_if_no_gpu,
 )
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -479,25 +480,14 @@ elif TEST_XPU:
     BACKEND = dist.Backend.XCCL

-# allows you to check for multiple accelerator irrespective of device type
-# to add new device types to this check simply follow the same format
-# and append an elif with the conditional and appropriate device count function for your new device
-def exit_if_lt_x_accelerators(x):
-    if torch.accelerator.is_available():
-        if torch.accelerator.device_count() < x:
-            sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)

 def with_comms(func=None):
     if func is None:
         return partial(with_comms)

     @wraps(func)
     def wrapper(self, *args, **kwargs):
-        if (
-            BACKEND == dist.Backend.NCCL or BACKEND == dist.Backend.XCCL
-        ) and torch.accelerator.device_count() < self.world_size:
-            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
+        if BACKEND in (dist.Backend.NCCL, dist.Backend.XCCL):
+            exit_if_lt_x_accelerators(self.world_size)

         kwargs["device"] = DEVICE
         self.pg = self.create_pg(device=DEVICE)
@@ -510,9 +500,9 @@ def with_comms(func=None):
 class TestCollectivesWithDistributedBackend(DistributedTestBase):
+    @skip_if_no_gpu
     @with_comms()
     def test_all_gather_into_tensor_coalesced(self, device):
-        exit_if_lt_x_accelerators(self.world_size)
         tensors = [
             torch.ones([4], device=device),
             torch.ones([4], device=device) + 1,
@@ -584,9 +574,8 @@ class TestCollectivesWithDistributedBackend(DistributedTestBase):
         compiled_allreduce(torch.randn(8, device=device), self.pg)

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @skip_if_no_gpu
     def test_tracing_with_fakepg(self, device=DEVICE):
-        exit_if_lt_x_accelerators(self.world_size)
         def allreduce(t, pg):
             return ft_c.all_reduce(t, "sum", pg)
@@ -627,9 +616,9 @@ class TestDistributedBackendCollectivesWithWorldSize4(
     def world_size(self):
         return 4

+    @skip_if_no_gpu
     @with_comms()
     def test_permute_tensor_with_sub_group(self, device):
-        exit_if_lt_x_accelerators(self.world_size)
         mesh_dim_names = ["dp", "tp"]
         mesh_2d = dt.init_device_mesh(

View File

@@ -118,18 +118,30 @@ class DistTestCases:
     backend_feature["xpu"] = {"xccl"}

-def requires_ddp_rank(device):
+def requires_ddp_rank(device) -> bool:
     return device in DDP_RANK_DEVICES

+def exit_if_lt_x_cuda_devs(x: int):
+    """Exit process unless at least the given number of CUDA devices are available"""
+    if torch.cuda.device_count() < x:
+        sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
+
+
+# allows you to check for multiple accelerator irrespective of device type
+# to add new device types to this check simply follow the same format
+# and append an elif with the conditional and appropriate device count function for your new device
+def exit_if_lt_x_accelerators(x: int):
+    if torch.accelerator.device_count() < x:
+        sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
+
+
 def skip_if_no_gpu(func):
     """Skips if the world size exceeds the number of GPUs, ensuring that if the
     test is run, each rank has its own GPU via ``torch.cuda.device(rank)``."""

     @wraps(func)
     def wrapper(*args, **kwargs):
-        if not (TEST_CUDA or TEST_HPU or TEST_XPU):
-            sys.exit(TEST_SKIPS["no_cuda"].exit_code)
         world_size = int(os.environ["WORLD_SIZE"])
         if TEST_CUDA and torch.cuda.device_count() < world_size:
             sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
@@ -140,7 +152,9 @@ def skip_if_no_gpu(func):
         return func(*args, **kwargs)

-    return wrapper
+    return unittest.skipUnless(
+        TEST_CUDA or TEST_HPU or TEST_XPU, TEST_SKIPS["no_cuda"].message
+    )(wrapper)

 # TODO (kwen2501): what is the purpose of this decorator? Tests with this
@@ -171,37 +185,16 @@ def skip_if_odd_worldsize(func):
     return wrapper

-def require_n_gpus_for_nccl_backend(n, backend):
-    def decorator(func):
-        @wraps(func)
-        def wrapper(*args, **kwargs):
-            if backend == "nccl" and torch.cuda.device_count() < n:
-                sys.exit(TEST_SKIPS[f"multi-gpu-{n}"].exit_code)
-            else:
-                return func(*args, **kwargs)
-
-        return wrapper
-
-    return decorator

 def import_transformers_or_skip():
-    def decorator(func):
-        @wraps(func)
-        def wrapper(*args, **kwargs):
-            try:
-                from transformers import AutoModelForMaskedLM, BertConfig  # noqa: F401
-
-                return func(*args, **kwargs)
-            except ImportError:
-                sys.exit(TEST_SKIPS["importerror"].exit_code)
-
-        return wrapper
-
-    return decorator
+    try:
+        from transformers import AutoModelForMaskedLM, BertConfig  # noqa: F401
+
+        return unittest.skipIf(False, "Dummy")
+    except ImportError:
+        return unittest.skip(TEST_SKIPS["importerror"].message)

-def at_least_x_gpu(x):
+def at_least_x_gpu(x: int) -> bool:
     if TEST_CUDA and torch.cuda.device_count() >= x:
         return True
     if TEST_HPU and torch.hpu.device_count() >= x:
@@ -211,31 +204,8 @@ def at_least_x_gpu(x):
     return False

-def _maybe_handle_skip_if_lt_x_gpu(args, msg) -> bool:
-    _handle_test_skip = getattr(args[0], "_handle_test_skip", None)
-    if len(args) == 0 or _handle_test_skip is None:
-        return False
-    _handle_test_skip(msg)
-    return True
-
-
-def skip_if_lt_x_gpu(x):
-    def decorator(func):
-        @wraps(func)
-        def wrapper(*args, **kwargs):
-            if torch.cuda.is_available() and torch.cuda.device_count() >= x:
-                return func(*args, **kwargs)
-            if TEST_HPU and torch.hpu.device_count() >= x:
-                return func(*args, **kwargs)
-            if TEST_XPU and torch.xpu.device_count() >= x:
-                return func(*args, **kwargs)
-            test_skip = TEST_SKIPS[f"multi-gpu-{x}"]
-            if not _maybe_handle_skip_if_lt_x_gpu(args, test_skip.message):
-                sys.exit(test_skip.exit_code)
-
-        return wrapper
-
-    return decorator
+def skip_if_lt_x_gpu(x: int):
+    return unittest.skipUnless(at_least_x_gpu(x), TEST_SKIPS[f"multi-gpu-{x}"].message)

 def requires_world_size(n: int):
@@ -280,21 +250,14 @@ def get_required_world_size(obj: Any, default: int) -> int:
 # This decorator helps avoiding initializing cuda while testing other backends
-def nccl_skip_if_lt_x_gpu(backend, x):
-    def decorator(func):
-        @wraps(func)
-        def wrapper(*args, **kwargs):
-            if backend != "nccl":
-                return func(*args, **kwargs)
-            if torch.cuda.is_available() and torch.cuda.device_count() >= x:
-                return func(*args, **kwargs)
-            test_skip = TEST_SKIPS[f"multi-gpu-{x}"]
-            if not _maybe_handle_skip_if_lt_x_gpu(args, test_skip.message):
-                sys.exit(test_skip.exit_code)
-
-        return wrapper
-
-    return decorator
+def require_n_gpus_for_nccl_backend(n: int, backend: str):
+    return unittest.skipUnless(
+        backend != "nccl" or at_least_x_gpu(n), TEST_SKIPS[f"multi-gpu-{n}"].message
+    )
+
+
+def nccl_skip_if_lt_x_gpu(backend: str, x: int):
+    return require_n_gpus_for_nccl_backend(x, backend)

 def verify_ddp_error_logged(model_DDP, err_substr):
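A hedged usage sketch of the rewritten helpers above; the test class and thresholds are invented for illustration and use a plain unittest.TestCase instead of the multi-process harness:

import unittest

from torch.testing._internal.common_distributed import (
    require_n_gpus_for_nccl_backend,
    skip_if_lt_x_gpu,
)


class ExampleSkipUsage(unittest.TestCase):  # hypothetical test class
    @skip_if_lt_x_gpu(2)  # plain unittest skip: no worker process is ever spawned
    def test_needs_two_gpus(self):
        self.assertTrue(True)

    @require_n_gpus_for_nccl_backend(2, "nccl")  # only gates the nccl backend
    def test_nccl_collective(self):
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()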

View File

@@ -7,8 +7,9 @@ import torch
 import torch.distributed as dist
 from torch.distributed import rpc
 from torch.testing._internal.common_distributed import (
+    exit_if_lt_x_cuda_devs,
     MultiProcessTestCase,
-    TEST_SKIPS,
+    require_n_gpus_for_nccl_backend,
     tp_transports,
 )
@@ -94,10 +95,10 @@ def with_comms(func=None, init_rpc=True, backend="nccl"):
     @wraps(func)
     def wrapper(self, *args, **kwargs):
-        if backend == "nccl" and torch.cuda.device_count() < self.world_size:
-            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
+        if backend == "nccl":
+            exit_if_lt_x_cuda_devs(self.world_size)
         self.init_comms(init_rpc=init_rpc, backend=backend)
         func(self, *args, **kwargs)
         self.destroy_comms(destroy_rpc=init_rpc)

-    return wrapper
+    return require_n_gpus_for_nccl_backend(1, backend)(wrapper)
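The `require_n_gpus_for_nccl_backend(1, backend)(wrapper)` line stacks a unittest skip on top of the wrapper; in isolation the composition looks like this sketch (simplified names, single-GPU requirement hard-coded):

import unittest
from functools import wraps

import torch


def with_comms_sketch(func, backend: str = "nccl"):
    # Inner wrapper: per-test communicator setup/teardown would live here.
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # ... init comms, run the test, destroy comms ...
        return func(self, *args, **kwargs)

    # unittest.skipUnless returns a decorator; applying it to `wrapper` attaches
    # the skip marker that the runner checks before the test (or its setUp) runs.
    return unittest.skipUnless(
        backend != "nccl" or torch.cuda.device_count() >= 1,
        "nccl backend needs at least one CUDA device",
    )(wrapper)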

View File

@@ -5,7 +5,6 @@
 import contextlib
 import functools
 import itertools
-import sys
 import types
 from collections.abc import Callable, Iterator, Sequence
 from dataclasses import dataclass
@@ -40,12 +39,12 @@ from torch.distributed.tensor.parallel import (
     SequenceParallel,
 )
 from torch.testing._internal.common_distributed import (
+    exit_if_lt_x_cuda_devs,
     MultiProcContinuousTest,
     MultiProcessTestCase,
     MultiThreadedTestCase,
     run_subtests,
     skip_if_lt_x_gpu,
-    TEST_SKIPS,
 )
 from torch.testing._internal.common_utils import (
     TEST_CUDA,
@@ -393,8 +392,8 @@ class DTensorTestBase(MultiProcessTestCase):
         return init_device_mesh(self.device_type, (self.world_size,))

     def init_pg(self, eager_init, backend: Optional[str] = None) -> None:
-        if "nccl" in self.backend and torch.cuda.device_count() < self.world_size:
-            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
+        if "nccl" in self.backend:
+            exit_if_lt_x_cuda_devs(self.world_size)

         curr_backend = dist.get_default_backend_for_device(self.device_type)
@@ -710,9 +709,6 @@ class LocalDTensorTestBase(DTensorTestBase):
     def is_local_tensor_enabled(self) -> bool:
         return True

-    def _handle_test_skip(self, msg: str) -> None:
-        self.skipTest(msg)
-
     def _get_local_tensor_mode(self):
         return LocalTensorMode(frozenset(range(self.world_size)))

View File

@@ -60,10 +60,10 @@ from torch.testing._internal.common_distributed import (
     captured_output,
     cleanup_temp_dir,
     DistTestCases,
+    exit_if_lt_x_cuda_devs,
     init_multigpu_helper,
     initialize_temp_directories,
     MultiProcessTestCase,
-    nccl_skip_if_lt_x_gpu,
     require_n_gpus_for_nccl_backend,
     requires_nccl_version,
     simple_sparse_reduce_tests,
@@ -603,10 +603,8 @@ class TestDistBackend(MultiProcessTestCase):
         self.rank = rank
         self.file_name = file_name

-        if torch.cuda.is_available() and torch.cuda.device_count() < int(
-            self.world_size
-        ):
-            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
+        if torch.cuda.is_available():
+            exit_if_lt_x_cuda_devs(int(self.world_size))
         try:
             pg_timeout_seconds = CUSTOM_PG_TIMEOUT.get(test_name, default_pg_timeout)
             timeout = timedelta(seconds=pg_timeout_seconds)
@@ -5346,7 +5344,7 @@ class DistributedTest:
             BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
             "get_future is only supported on mpi, nccl and gloo",
         )
-        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        @require_n_gpus_for_nccl_backend(2, BACKEND)
         def test_accumulate_gradients_no_sync(self):
             """
             Runs _test_accumulate_gradients_no_sync using default inputs
@@ -5357,7 +5355,7 @@ class DistributedTest:
             BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
             "get_future is only supported on mpi, nccl and gloo",
         )
-        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        @require_n_gpus_for_nccl_backend(2, BACKEND)
        def test_accumulate_gradients_no_sync_grad_is_view(self):
             """
             Runs _test_accumulate_gradients_no_sync using default inputs
@@ -5368,7 +5366,7 @@ class DistributedTest:
             BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
             "get_future is only supported on mpi, nccl and gloo",
         )
-        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        @require_n_gpus_for_nccl_backend(2, BACKEND)
         def test_accumulate_gradients_no_sync_allreduce_hook(self):
             """
             Runs multiple iterations on _test_accumulate_gradients_no_sync
@@ -5396,7 +5394,7 @@ class DistributedTest:
             BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
             "get_future is only supported on mpi, nccl and gloo",
         )
-        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        @require_n_gpus_for_nccl_backend(2, BACKEND)
         def test_accumulate_gradients_no_sync_allreduce_with_then_hook(self):
             """
             Runs multiple iterations on _test_accumulate_gradients_no_sync using allreduce
@@ -5430,7 +5428,7 @@ class DistributedTest:
             BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
             "get_future is only supported on mpi, nccl and gloo",
         )
-        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        @require_n_gpus_for_nccl_backend(2, BACKEND)
         def test_get_future(self):
             def mult(fut):
                 return [t * 3 for t in fut.wait()]