Introduce TEST_ACCELERATOR and TEST_MULTIACCELERATOR to simplify UT (#167196)

# Motivation
This PR introduces two variables (`TEST_ACCELERATOR` and `TEST_MULTIACCELERATOR`) to simplify generalizing unit tests across accelerator backends. Because out-of-tree backends may be imported after `common_utils` is loaded, these variables are defined as lazy values that are only evaluated on first use.
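
To make the import-order issue concrete, here is a minimal sketch (illustrative, not part of the PR): an eager flag computed when `common_utils` is imported can miss an out-of-tree backend that registers its device later, while a lazy flag is only evaluated on first use.

```python
import torch
from torch.testing._internal.common_utils import LazyVal  # added by this PR

# Eager: evaluated immediately at import time, possibly before an
# out-of-tree backend module has registered its device with torch.
eager_flag = torch.accelerator.is_available()

# Lazy: the callback is stored, not called; evaluation happens on first use.
lazy_flag = LazyVal(lambda: torch.accelerator.is_available())

# ... an out-of-tree backend (e.g. a PrivateUse1 module) may be imported here ...

if lazy_flag:  # __bool__ invokes the callback now and caches the result
    print("accelerator available")
```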

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167196
Approved by: https://github.com/albanD
commit 292bd62c71
parent 0e512ee9f0
Author:       Yu, Guangye
Date:         2025-11-06 13:36:00 +00:00
Committed by: PyTorch MergeBot

4 changed files with 51 additions and 41 deletions


@@ -46,6 +46,7 @@ from torch.testing._internal.common_utils import (
     parametrize,
     run_tests,
     skip_but_pass_in_sandcastle_if,
+    TEST_MULTIACCELERATOR,
 )
@@ -56,7 +57,6 @@ batch_size = 64
 torch.manual_seed(0)
 device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
 backend = dist.get_default_backend_for_device(device_type)
-TEST_MULTIACCELERATOR = torch.accelerator.device_count() >= 2


 @dataclass


@@ -24,6 +24,7 @@ from torch.testing._internal.common_utils import (
     parametrize,
     run_tests,
     skip_but_pass_in_sandcastle_if,
+    TEST_MULTIACCELERATOR,
 )

 from torch.utils._pytree import tree_map_only
@@ -34,7 +35,6 @@ chunks = 8
 device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
 backend = dist.get_default_backend_for_device(device_type)
-TEST_MULTIACCELERATOR = torch.accelerator.device_count() >= 2


 torch.manual_seed(0)


@@ -5,17 +5,22 @@ import sys
 import unittest

 import torch
-from torch.testing._internal.common_utils import NoTest, run_tests, TEST_MPS, TestCase
+from torch.testing._internal.common_utils import (
+    NoTest,
+    run_tests,
+    TEST_ACCELERATOR,
+    TEST_MPS,
+    TEST_MULTIACCELERATOR,
+    TestCase,
+)

-if not torch.accelerator.is_available():
+if not TEST_ACCELERATOR:
     print("No available accelerator detected, skipping tests", file=sys.stderr)
     TestCase = NoTest  # noqa: F811
     # Skip because failing when run on cuda build with no GPU, see #150059 for example
     sys.exit()

-TEST_MULTIACCELERATOR = torch.accelerator.device_count() > 1


 class TestAccelerator(TestCase):
     def test_current_accelerator(self):


@@ -1468,6 +1468,44 @@ def is_privateuse1_backend_available():
     return (is_available := getattr(privateuse1_backend_module, "is_available", None)) and is_available()

+
+def make_lazy_class(cls):
+    def lazy_init(self, cb):
+        self._cb = cb
+        self._value = None
+
+    cls.__init__ = lazy_init
+    for basename in [
+        "add", "sub", "mul", "truediv", "floordiv", "mod", "divmod", "pow",
+        "lshift", "rshift", "and", "or", "xor", "neg", "pos", "abs", "invert",
+        "eq", "ne", "lt", "le", "gt", "ge", "bool", "int", "index",
+    ]:
+        name = f"__{basename}__"
+
+        def inner_wrapper(name):
+            use_operator = basename not in ("bool", "int")
+
+            def wrapped(self, *args, **kwargs):
+                if self._cb is not None:
+                    self._value = self._cb()
+                    self._cb = None
+                if not use_operator:
+                    return getattr(self._value, name)(*args, **kwargs)
+                else:
+                    return getattr(operator, name)(self._value, *args, **kwargs)
+
+            return wrapped
+
+        setattr(cls, name, inner_wrapper(name))
+    return cls
+
+
+@make_lazy_class
+class LazyVal:
+    pass
+
+
 IS_FILESYSTEM_UTF8_ENCODING = sys.getfilesystemencoding() == 'utf-8'

 TEST_NUMPY = _check_module_exists('numpy')
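
For orientation (commentary, not part of the diff): `make_lazy_class` replaces `__init__` so that an instance stores a zero-argument callback, and it generates the listed dunder methods so that the first truthiness, comparison, or arithmetic operation invokes the callback once and caches the result. A small sanity-check sketch of that behavior, assuming the post-PR `common_utils`:

```python
from torch.testing._internal.common_utils import LazyVal

calls = 0

def probe():
    global calls
    calls += 1
    return 2

v = LazyVal(probe)       # nothing is evaluated yet; calls == 0
assert bool(v) is True   # __bool__ forces the callback; calls == 1
assert v + 1 == 3        # __add__ reuses the cached value via operator.__add__
assert calls == 1        # the callback ran exactly once
```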
@@ -1480,6 +1518,8 @@ MACOS_VERSION = float('.'.join(platform.mac_ver()[0].split('.')[:2]) or -1)
 TEST_XPU = torch.xpu.is_available()
 TEST_HPU = bool(hasattr(torch, "hpu") and torch.hpu.is_available())
 TEST_CUDA = torch.cuda.is_available()
+TEST_ACCELERATOR = LazyVal(lambda: torch.accelerator.is_available())  # type: ignore[call-arg]
+TEST_MULTIACCELERATOR = LazyVal(lambda: torch.accelerator.device_count() > 1)  # type: ignore[call-arg]
 custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name(), None)
 TEST_PRIVATEUSE1 = is_privateuse1_backend_available()
 TEST_PRIVATEUSE1_DEVICE_TYPE = torch._C._get_privateuse1_backend_name()
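
With the flags defined lazily in `common_utils`, test modules can gate on them directly instead of recomputing device counts, as the two test-file diffs above now do. A hypothetical example of the pattern (`MyDistributedTest` is illustrative, not from this PR):

```python
import unittest

from torch.testing._internal.common_utils import (
    run_tests,
    TEST_MULTIACCELERATOR,
    TestCase,
)


class MyDistributedTest(TestCase):
    # `not TEST_MULTIACCELERATOR` forces the lazy value, but only when this
    # test module is imported, after its own backend imports have run.
    @unittest.skipIf(not TEST_MULTIACCELERATOR, "requires at least two devices")
    def test_two_device_collective(self):
        ...  # test body elided


if __name__ == "__main__":
    run_tests()
```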
@@ -5601,37 +5641,7 @@ class TestGradients(TestCase):
         if not op.supports_autograd and not op.supports_forward_ad:
             self.skipTest("Skipped! autograd not supported.")

-
-def make_lazy_class(cls):
-    def lazy_init(self, cb):
-        self._cb = cb
-        self._value = None
-
-    cls.__init__ = lazy_init
-    for basename in [
-        "add", "sub", "mul", "truediv", "floordiv", "mod", "divmod", "pow",
-        "lshift", "rshift", "and", "or", "xor", "neg", "pos", "abs", "invert",
-        "eq", "ne", "lt", "le", "gt", "ge", "bool", "int", "index",
-    ]:
-        name = f"__{basename}__"
-
-        def inner_wrapper(name):
-            use_operator = basename not in ("bool", "int")
-
-            def wrapped(self, *args, **kwargs):
-                if self._cb is not None:
-                    self._value = self._cb()
-                    self._cb = None
-                if not use_operator:
-                    return getattr(self._value, name)(*args, **kwargs)
-                else:
-                    return getattr(operator, name)(self._value, *args, **kwargs)
-
-            return wrapped
-
-        setattr(cls, name, inner_wrapper(name))
-    return cls

 # Base TestCase for NT tests; used to define common helpers, etc.
# Base TestCase for NT tests; used to define common helpers, etc.
@@ -5676,11 +5686,6 @@ class NestedTensorTestCase(TestCase):
             nested_tensor_module._tensor_symint_registry = original_tensor_symint_registry


-@make_lazy_class
-class LazyVal:
-    pass
-
-
 def munge_exc(e, *, suppress_suffix=True, suppress_prefix=True, file=None, skip=0):
     from torch._dynamo.trace_rules import _as_posix_path