mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add a variable for some testcases. (#124708)
Some testcases can use 'TEST_PRIVATEUSE1_DEVICE_TYPE' to make adapting these testcases on others device more convenient. Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/124708 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
c451d108da
commit
6cfb55dd5d
@ -12,7 +12,7 @@ import unittest
|
||||
from torch import nan
|
||||
from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase, run_tests, skipIfTorchDynamo, torch_to_numpy_dtype_dict, IS_JETSON)
|
||||
TestCase, run_tests, skipIfTorchDynamo, torch_to_numpy_dtype_dict, IS_JETSON, TEST_PRIVATEUSE1_DEVICE_TYPE)
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests, onlyCPU, onlyCUDA, dtypes, onlyNativeDeviceTypes,
|
||||
dtypesIfCUDA, largeTensorTest)
|
||||
@ -631,7 +631,7 @@ class TestShapeOps(TestCase):
|
||||
"scalar type Long",
|
||||
lambda: torch.nonzero(tensor, out=torch.empty([], dtype=torch.float, device=device))
|
||||
)
|
||||
if self.device_type == 'cuda':
|
||||
if self.device_type == 'cuda' or self.device_type == TEST_PRIVATEUSE1_DEVICE_TYPE:
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"on the same device",
|
||||
|
@ -1235,7 +1235,9 @@ TEST_MPS = torch.backends.mps.is_available()
|
||||
TEST_XPU = torch.xpu.is_available()
|
||||
TEST_CUDA = torch.cuda.is_available()
|
||||
custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name(), None)
|
||||
TEST_PRIVATEUSE1 = True if (hasattr(custom_device_mod, "is_available") and custom_device_mod.is_available()) else False
|
||||
custom_device_is_available = hasattr(custom_device_mod, "is_available") and custom_device_mod.is_available()
|
||||
TEST_PRIVATEUSE1 = True if custom_device_is_available else False
|
||||
TEST_PRIVATEUSE1_DEVICE_TYPE = torch._C._get_privateuse1_backend_name()
|
||||
TEST_NUMBA = _check_module_exists('numba')
|
||||
TEST_TRANSFORMERS = _check_module_exists('transformers')
|
||||
TEST_DILL = _check_module_exists('dill')
|
||||
|
Reference in New Issue
Block a user