mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
Introduce TEST_ACCELERATOR and TEST_MULTIACCELERATOR to simplify UT (#167196)
# Motivation This PR aims to introduce two variables (`TEST_ACCELERATOR` and `TEST_MULTIACCELERATOR`) to simplify UT generalization. Since out-of-tree backends may be imported later, these variables are defined as lazy values. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167196 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
0e512ee9f0
commit
292bd62c71
@ -46,6 +46,7 @@ from torch.testing._internal.common_utils import (
|
||||
parametrize,
|
||||
run_tests,
|
||||
skip_but_pass_in_sandcastle_if,
|
||||
TEST_MULTIACCELERATOR,
|
||||
)
|
||||
|
||||
|
||||
@ -56,7 +57,6 @@ batch_size = 64
|
||||
torch.manual_seed(0)
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
backend = dist.get_default_backend_for_device(device_type)
|
||||
TEST_MULTIACCELERATOR = torch.accelerator.device_count() >= 2
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -24,6 +24,7 @@ from torch.testing._internal.common_utils import (
|
||||
parametrize,
|
||||
run_tests,
|
||||
skip_but_pass_in_sandcastle_if,
|
||||
TEST_MULTIACCELERATOR,
|
||||
)
|
||||
from torch.utils._pytree import tree_map_only
|
||||
|
||||
@ -34,7 +35,6 @@ chunks = 8
|
||||
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
backend = dist.get_default_backend_for_device(device_type)
|
||||
TEST_MULTIACCELERATOR = torch.accelerator.device_count() >= 2
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
@ -5,17 +5,22 @@ import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import NoTest, run_tests, TEST_MPS, TestCase
|
||||
from torch.testing._internal.common_utils import (
|
||||
NoTest,
|
||||
run_tests,
|
||||
TEST_ACCELERATOR,
|
||||
TEST_MPS,
|
||||
TEST_MULTIACCELERATOR,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
|
||||
if not torch.accelerator.is_available():
|
||||
if not TEST_ACCELERATOR:
|
||||
print("No available accelerator detected, skipping tests", file=sys.stderr)
|
||||
TestCase = NoTest # noqa: F811
|
||||
# Skip because failing when run on cuda build with no GPU, see #150059 for example
|
||||
sys.exit()
|
||||
|
||||
TEST_MULTIACCELERATOR = torch.accelerator.device_count() > 1
|
||||
|
||||
|
||||
class TestAccelerator(TestCase):
|
||||
def test_current_accelerator(self):
|
||||
|
||||
@ -1468,6 +1468,44 @@ def is_privateuse1_backend_available():
|
||||
return (is_available := getattr(privateuse1_backend_module, "is_available", None)) and is_available()
|
||||
|
||||
|
||||
def make_lazy_class(cls):
|
||||
|
||||
def lazy_init(self, cb):
|
||||
self._cb = cb
|
||||
self._value = None
|
||||
|
||||
cls.__init__ = lazy_init
|
||||
|
||||
for basename in [
|
||||
"add", "sub", "mul", "truediv", "floordiv", "mod", "divmod", "pow",
|
||||
"lshift", "rshift", "and", "or", "xor", "neg", "pos", "abs", "invert",
|
||||
"eq", "ne", "lt", "le", "gt", "ge", "bool", "int", "index",
|
||||
]:
|
||||
name = f"__{basename}__"
|
||||
|
||||
def inner_wrapper(name):
|
||||
use_operator = basename not in ("bool", "int")
|
||||
|
||||
def wrapped(self, *args, **kwargs):
|
||||
if self._cb is not None:
|
||||
self._value = self._cb()
|
||||
self._cb = None
|
||||
if not use_operator:
|
||||
return getattr(self._value, name)(*args, **kwargs)
|
||||
else:
|
||||
return getattr(operator, name)(self._value, *args, **kwargs)
|
||||
return wrapped
|
||||
|
||||
setattr(cls, name, inner_wrapper(name))
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
@make_lazy_class
|
||||
class LazyVal:
|
||||
pass
|
||||
|
||||
|
||||
IS_FILESYSTEM_UTF8_ENCODING = sys.getfilesystemencoding() == 'utf-8'
|
||||
|
||||
TEST_NUMPY = _check_module_exists('numpy')
|
||||
@ -1480,6 +1518,8 @@ MACOS_VERSION = float('.'.join(platform.mac_ver()[0].split('.')[:2]) or -1)
|
||||
TEST_XPU = torch.xpu.is_available()
|
||||
TEST_HPU = bool(hasattr(torch, "hpu") and torch.hpu.is_available())
|
||||
TEST_CUDA = torch.cuda.is_available()
|
||||
TEST_ACCELERATOR = LazyVal(lambda: torch.accelerator.is_available()) # type: ignore[call-arg]
|
||||
TEST_MULTIACCELERATOR = LazyVal(lambda: torch.accelerator.device_count() > 1) # type: ignore[call-arg]
|
||||
custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name(), None)
|
||||
TEST_PRIVATEUSE1 = is_privateuse1_backend_available()
|
||||
TEST_PRIVATEUSE1_DEVICE_TYPE = torch._C._get_privateuse1_backend_name()
|
||||
@ -5601,37 +5641,7 @@ class TestGradients(TestCase):
|
||||
if not op.supports_autograd and not op.supports_forward_ad:
|
||||
self.skipTest("Skipped! autograd not supported.")
|
||||
|
||||
def make_lazy_class(cls):
|
||||
|
||||
def lazy_init(self, cb):
|
||||
self._cb = cb
|
||||
self._value = None
|
||||
|
||||
cls.__init__ = lazy_init
|
||||
|
||||
for basename in [
|
||||
"add", "sub", "mul", "truediv", "floordiv", "mod", "divmod", "pow",
|
||||
"lshift", "rshift", "and", "or", "xor", "neg", "pos", "abs", "invert",
|
||||
"eq", "ne", "lt", "le", "gt", "ge", "bool", "int", "index",
|
||||
]:
|
||||
name = f"__{basename}__"
|
||||
|
||||
def inner_wrapper(name):
|
||||
use_operator = basename not in ("bool", "int")
|
||||
|
||||
def wrapped(self, *args, **kwargs):
|
||||
if self._cb is not None:
|
||||
self._value = self._cb()
|
||||
self._cb = None
|
||||
if not use_operator:
|
||||
return getattr(self._value, name)(*args, **kwargs)
|
||||
else:
|
||||
return getattr(operator, name)(self._value, *args, **kwargs)
|
||||
return wrapped
|
||||
|
||||
setattr(cls, name, inner_wrapper(name))
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
# Base TestCase for NT tests; used to define common helpers, etc.
|
||||
@ -5676,11 +5686,6 @@ class NestedTensorTestCase(TestCase):
|
||||
nested_tensor_module._tensor_symint_registry = original_tensor_symint_registry
|
||||
|
||||
|
||||
@make_lazy_class
|
||||
class LazyVal:
|
||||
pass
|
||||
|
||||
|
||||
def munge_exc(e, *, suppress_suffix=True, suppress_prefix=True, file=None, skip=0):
|
||||
from torch._dynamo.trace_rules import _as_posix_path
|
||||
|
||||
|
||||
Reference in New Issue
Block a user