From 40ea6e418a324ef8ca34e85176dec1a496621f11 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 10 Sep 2025 20:51:30 +0000 Subject: [PATCH] Revert "Fix decorators skipping NCCL tests (#158846)" This reverts commit c2388201fc85b0748173212de5a17514c7a71f21. Reverted https://github.com/pytorch/pytorch/pull/158846 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing some inductor tests ([comment](https://github.com/pytorch/pytorch/pull/158846#issuecomment-3276471387)) --- .../fsdp/test_fully_shard_logging.py | 6 +- test/distributed/test_functional_api.py | 25 ++++-- torch/testing/_internal/common_distributed.py | 89 +++++++++++++------ .../_shard/sharded_tensor/__init__.py | 9 +- .../distributed/_tensor/common_dtensor.py | 7 +- .../_internal/distributed/distributed_test.py | 18 ++-- 6 files changed, 103 insertions(+), 51 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_logging.py b/test/distributed/_composable/fsdp/test_fully_shard_logging.py index 9b666eb55ba0..c9450a2b8f47 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_logging.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_logging.py @@ -1,7 +1,7 @@ # Owner(s): ["module: fsdp"] import functools import os -import unittest +import unittest.mock import torch.distributed as dist from torch._dynamo.test_case import run_tests @@ -37,9 +37,9 @@ import torch import torch.distributed as dist import torch.nn as nn from torch.distributed.fsdp import fully_shard -logger = logging.getLogger("torch.distributed.fsdp.fully_shard") +logger = logging.getLogger("torch.distributed._composable.fsdp") logger.setLevel(logging.DEBUG) -device = '{device_type.type}' +device = {device_type.type} torch.manual_seed(0) model = nn.Sequential(*[nn.Linear(4, 4, device=device, bias=False) for _ in range(2)]) for layer in model: diff --git a/test/distributed/test_functional_api.py b/test/distributed/test_functional_api.py index a21eb0dbf444..b5522fe2bef0 100644 --- a/test/distributed/test_functional_api.py +++ b/test/distributed/test_functional_api.py @@ -13,7 +13,6 @@ from functorch import make_fx from torch._inductor.utils import run_and_get_code from torch.testing import FileCheck from torch.testing._internal.common_device_type import instantiate_device_type_tests -from torch.testing._internal.common_distributed import exit_if_lt_x_accelerators from torch.testing._internal.inductor_utils import HAS_GPU @@ -25,7 +24,7 @@ from torch.testing._internal.common_distributed import ( DistributedTestBase, MultiThreadedTestCase, requires_accelerator_dist_backend, - skip_if_no_gpu, + TEST_SKIPS, ) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -480,14 +479,25 @@ elif TEST_XPU: BACKEND = dist.Backend.XCCL +# allows you to check for multiple accelerator irrespective of device type +# to add new device types to this check simply follow the same format +# and append an elif with the conditional and appropriate device count function for your new device +def exit_if_lt_x_accelerators(x): + if torch.accelerator.is_available(): + if torch.accelerator.device_count() < x: + sys.exit(TEST_SKIPS[f"multi-accelerator-{x}"].exit_code) + + def with_comms(func=None): if func is None: return partial(with_comms) @wraps(func) def wrapper(self, *args, **kwargs): - if BACKEND in (dist.Backend.NCCL, dist.Backend.XCCL): - exit_if_lt_x_accelerators(self.world_size) + if ( + BACKEND == dist.Backend.NCCL or BACKEND == dist.Backend.XCCL + ) and 
torch.accelerator.device_count() < self.world_size: + sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) kwargs["device"] = DEVICE self.pg = self.create_pg(device=DEVICE) @@ -500,9 +510,9 @@ def with_comms(func=None): class TestCollectivesWithDistributedBackend(DistributedTestBase): - @skip_if_no_gpu @with_comms() def test_all_gather_into_tensor_coalesced(self, device): + exit_if_lt_x_accelerators(self.world_size) tensors = [ torch.ones([4], device=device), torch.ones([4], device=device) + 1, @@ -574,8 +584,9 @@ class TestCollectivesWithDistributedBackend(DistributedTestBase): compiled_allreduce(torch.randn(8, device=device), self.pg) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @skip_if_no_gpu def test_tracing_with_fakepg(self, device=DEVICE): + exit_if_lt_x_accelerators(self.world_size) + def allreduce(t, pg): return ft_c.all_reduce(t, "sum", pg) @@ -616,9 +627,9 @@ class TestDistributedBackendCollectivesWithWorldSize4( def world_size(self): return 4 - @skip_if_no_gpu @with_comms() def test_permute_tensor_with_sub_group(self, device): + exit_if_lt_x_accelerators(self.world_size) mesh_dim_names = ["dp", "tp"] mesh_2d = dt.init_device_mesh( diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index d9d07dddea3d..c1f75697fe88 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -118,26 +118,14 @@ def requires_ddp_rank(device): return device in DDP_RANK_DEVICES -def exit_if_lt_x_cuda_devs(x): - """Exit process unless at least the given number of CUDA devices are available""" - if torch.cuda.device_count() < x: - sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code) - - -# allows you to check for multiple accelerator irrespective of device type -# to add new device types to this check simply follow the same format -# and append an elif with the conditional and appropriate device count function for your new device -def exit_if_lt_x_accelerators(x): - if torch.accelerator.device_count() < x: - sys.exit(TEST_SKIPS[f"multi-accelerator-{x}"].exit_code) - - def skip_if_no_gpu(func): """Skips if the world size exceeds the number of GPUs, ensuring that if the test is run, each rank has its own GPU via ``torch.cuda.device(rank)``.""" @wraps(func) def wrapper(*args, **kwargs): + if not (TEST_CUDA or TEST_HPU or TEST_XPU): + sys.exit(TEST_SKIPS["no_cuda"].exit_code) world_size = int(os.environ["WORLD_SIZE"]) if TEST_CUDA and torch.cuda.device_count() < world_size: sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code) @@ -148,9 +136,7 @@ def skip_if_no_gpu(func): return func(*args, **kwargs) - return unittest.skipUnless( - TEST_CUDA or TEST_HPU or TEST_XPU, TEST_SKIPS["no_cuda"].message - )(wrapper) + return wrapper # TODO (kwen2501): what is the purpose of this decorator? 
Tests with this @@ -182,16 +168,33 @@ def skip_if_odd_worldsize(func): def require_n_gpus_for_nccl_backend(n, backend): - return skip_if_lt_x_gpu(n) if backend == "nccl" else unittest.skipIf(False, None) + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + if backend == "nccl" and torch.cuda.device_count() < n: + sys.exit(TEST_SKIPS[f"multi-gpu-{n}"].exit_code) + else: + return func(*args, **kwargs) + + return wrapper + + return decorator def import_transformers_or_skip(): - try: - from transformers import AutoModelForMaskedLM, BertConfig # noqa: F401 + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + from transformers import AutoModelForMaskedLM, BertConfig # noqa: F401 - return unittest.skipIf(False) - except ImportError: - return unittest.skip(TEST_SKIPS["importerror"].message) + return func(*args, **kwargs) + except ImportError: + sys.exit(TEST_SKIPS["importerror"].exit_code) + + return wrapper + + return decorator def at_least_x_gpu(x): @@ -205,7 +208,36 @@ def at_least_x_gpu(x): def skip_if_lt_x_gpu(x): - return unittest.skipUnless(at_least_x_gpu(x), TEST_SKIPS[f"multi-gpu-{x}"].message) + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + if torch.cuda.is_available() and torch.cuda.device_count() >= x: + return func(*args, **kwargs) + if TEST_HPU and torch.hpu.device_count() >= x: + return func(*args, **kwargs) + if TEST_XPU and torch.xpu.device_count() >= x: + return func(*args, **kwargs) + sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code) + + return wrapper + + return decorator + + +# This decorator helps avoiding initializing cuda while testing other backends +def nccl_skip_if_lt_x_gpu(backend, x): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + if backend != "nccl": + return func(*args, **kwargs) + if torch.cuda.is_available() and torch.cuda.device_count() >= x: + return func(*args, **kwargs) + sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code) + + return wrapper + + return decorator def verify_ddp_error_logged(model_DDP, err_substr): @@ -392,7 +424,14 @@ def requires_multicast_support(): def skip_if_rocm_multiprocess(func): """Skips a test for ROCm""" func.skip_if_rocm_multiprocess = True - return unittest.skipUnless(TEST_WITH_ROCM, TEST_SKIPS["skipIfRocm"].message)(func) + + @wraps(func) + def wrapper(*args, **kwargs): + if not TEST_WITH_ROCM: + return func(*args, **kwargs) + sys.exit(TEST_SKIPS["skipIfRocm"].exit_code) + + return wrapper def skip_if_win32(): diff --git a/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py b/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py index a0a38837c14b..60c744ac1a84 100644 --- a/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py +++ b/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py @@ -7,9 +7,8 @@ import torch import torch.distributed as dist from torch.distributed import rpc from torch.testing._internal.common_distributed import ( - exit_if_lt_x_cuda_devs, MultiProcessTestCase, - require_n_gpus_for_nccl_backend, + TEST_SKIPS, tp_transports, ) @@ -95,10 +94,10 @@ def with_comms(func=None, init_rpc=True, backend="nccl"): @wraps(func) def wrapper(self, *args, **kwargs): - if backend == "nccl": - exit_if_lt_x_cuda_devs(self.world_size) + if backend == "nccl" and torch.cuda.device_count() < self.world_size: + sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) self.init_comms(init_rpc=init_rpc, backend=backend) func(self, *args, **kwargs) 
self.destroy_comms(destroy_rpc=init_rpc) - return require_n_gpus_for_nccl_backend(1, backend)(wrapper) + return wrapper diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index 9758fa5d1e7d..e25e08fbf509 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -3,6 +3,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import itertools +import sys from collections.abc import Iterator, Sequence from dataclasses import dataclass from functools import partial, wraps @@ -30,12 +31,12 @@ from torch.distributed.tensor.parallel import ( SequenceParallel, ) from torch.testing._internal.common_distributed import ( - exit_if_lt_x_cuda_devs, MultiProcContinuousTest, MultiProcessTestCase, MultiThreadedTestCase, run_subtests, skip_if_lt_x_gpu, + TEST_SKIPS, ) from torch.testing._internal.common_utils import TEST_CUDA, TEST_HPU, TEST_XPU from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec @@ -373,8 +374,8 @@ class DTensorTestBase(MultiProcessTestCase): return init_device_mesh(self.device_type, (self.world_size,)) def init_pg(self, eager_init, backend: Optional[str] = None) -> None: - if "nccl" in self.backend: - exit_if_lt_x_cuda_devs(self.world_size) + if "nccl" in self.backend and torch.cuda.device_count() < self.world_size: + sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) if backend is None: backend = self.backend diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 21d51b66ad03..024fd47285ae 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -59,10 +59,10 @@ from torch.testing._internal.common_distributed import ( captured_output, cleanup_temp_dir, DistTestCases, - exit_if_lt_x_cuda_devs, init_multigpu_helper, initialize_temp_directories, MultiProcessTestCase, + nccl_skip_if_lt_x_gpu, require_n_gpus_for_nccl_backend, requires_nccl_version, simple_sparse_reduce_tests, @@ -609,8 +609,10 @@ class TestDistBackend(MultiProcessTestCase): self.rank = rank self.file_name = file_name - if torch.cuda.is_available(): - exit_if_lt_x_cuda_devs(int(self.world_size)) + if torch.cuda.is_available() and torch.cuda.device_count() < int( + self.world_size + ): + sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) try: pg_timeout_seconds = CUSTOM_PG_TIMEOUT.get(test_name, default_pg_timeout) timeout = timedelta(seconds=pg_timeout_seconds) @@ -5342,7 +5344,7 @@ class DistributedTest: BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo", "get_future is only supported on mpi, nccl and gloo", ) - @require_n_gpus_for_nccl_backend(2, BACKEND) + @nccl_skip_if_lt_x_gpu(BACKEND, 2) def test_accumulate_gradients_no_sync(self): """ Runs _test_accumulate_gradients_no_sync using default inputs @@ -5353,7 +5355,7 @@ class DistributedTest: BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo", "get_future is only supported on mpi, nccl and gloo", ) - @require_n_gpus_for_nccl_backend(2, BACKEND) + @nccl_skip_if_lt_x_gpu(BACKEND, 2) def test_accumulate_gradients_no_sync_grad_is_view(self): """ Runs _test_accumulate_gradients_no_sync using default inputs @@ -5364,7 +5366,7 @@ class DistributedTest: BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo", "get_future is only supported on mpi, nccl and gloo", ) - 
@require_n_gpus_for_nccl_backend(2, BACKEND) + @nccl_skip_if_lt_x_gpu(BACKEND, 2) def test_accumulate_gradients_no_sync_allreduce_hook(self): """ Runs multiple iterations on _test_accumulate_gradients_no_sync @@ -5392,7 +5394,7 @@ class DistributedTest: BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo", "get_future is only supported on mpi, nccl and gloo", ) - @require_n_gpus_for_nccl_backend(2, BACKEND) + @nccl_skip_if_lt_x_gpu(BACKEND, 2) def test_accumulate_gradients_no_sync_allreduce_with_then_hook(self): """ Runs multiple iterations on _test_accumulate_gradients_no_sync using allreduce @@ -5426,7 +5428,7 @@ class DistributedTest: BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo", "get_future is only supported on mpi, nccl and gloo", ) - @require_n_gpus_for_nccl_backend(2, BACKEND) + @nccl_skip_if_lt_x_gpu(BACKEND, 2) def test_get_future(self): def mult(fut): return [t * 3 for t in fut.wait()]
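
For context, the wrapper-style skip pattern that this revert restores in torch/testing/_internal/common_distributed.py looks roughly like the sketch below. TestSkip and the TEST_SKIPS table here are simplified stand-ins for the real entries in the PyTorch test internals, and only nccl_skip_if_lt_x_gpu mirrors a decorator from the diff above; treat this as an illustration of the pattern, not the actual implementation.

    # Minimal sketch of the wrapper-style skip pattern restored by this revert.
    # TestSkip/TEST_SKIPS are illustrative stand-ins, not the PyTorch internals.
    import sys
    from collections import namedtuple
    from functools import wraps

    import torch

    TestSkip = namedtuple("TestSkip", "exit_code message")

    # Stand-in skip table: each entry maps a skip reason to a distinct exit code
    # that the multiprocess test harness can translate back into a skipped test.
    TEST_SKIPS = {
        "multi-gpu-2": TestSkip(80, "Need at least 2 CUDA devices"),
    }

    def nccl_skip_if_lt_x_gpu(backend, x):
        """Skip (by exiting the per-rank subprocess) when running the NCCL
        backend with fewer than x GPUs; other backends run unconditionally,
        so CUDA is never initialized for them."""
        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                if backend != "nccl":
                    return func(*args, **kwargs)
                if torch.cuda.is_available() and torch.cuda.device_count() >= x:
                    return func(*args, **kwargs)
                # Exit with a well-known code instead of raising
                # unittest.SkipTest in the parent process.
                sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)

            return wrapper

        return decorator

Each rank of a MultiProcessTestCase-style test runs in its own subprocess, so signaling a skip through a process exit code lets the parent harness report the case as skipped without the decorator touching CUDA at collection time; the reverted #158846 had replaced these wrappers with unittest.skipUnless/skip decorators evaluated in the parent process, which is what the removed lines in the hunks above show.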