Add a variable for some testcases. (#124708)

Some testcases can use 'TEST_PRIVATEUSE1_DEVICE_TYPE' to make adapting these testcases on others device more convenient.

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124708
Approved by: https://github.com/albanD
This commit is contained in:
yan-yhy
2024-05-01 23:19:07 +00:00
committed by PyTorch MergeBot
parent c451d108da
commit 6cfb55dd5d
2 changed files with 5 additions and 3 deletions

View File

@ -12,7 +12,7 @@ import unittest
from torch import nan
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
TestCase, run_tests, skipIfTorchDynamo, torch_to_numpy_dtype_dict, IS_JETSON)
TestCase, run_tests, skipIfTorchDynamo, torch_to_numpy_dtype_dict, IS_JETSON, TEST_PRIVATEUSE1_DEVICE_TYPE)
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, onlyCPU, onlyCUDA, dtypes, onlyNativeDeviceTypes,
dtypesIfCUDA, largeTensorTest)
@ -631,7 +631,7 @@ class TestShapeOps(TestCase):
"scalar type Long",
lambda: torch.nonzero(tensor, out=torch.empty([], dtype=torch.float, device=device))
)
if self.device_type == 'cuda':
if self.device_type == 'cuda' or self.device_type == TEST_PRIVATEUSE1_DEVICE_TYPE:
self.assertRaisesRegex(
RuntimeError,
"on the same device",

View File

@ -1235,7 +1235,9 @@ TEST_MPS = torch.backends.mps.is_available()
TEST_XPU = torch.xpu.is_available()
TEST_CUDA = torch.cuda.is_available()
custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name(), None)
TEST_PRIVATEUSE1 = True if (hasattr(custom_device_mod, "is_available") and custom_device_mod.is_available()) else False
custom_device_is_available = hasattr(custom_device_mod, "is_available") and custom_device_mod.is_available()
TEST_PRIVATEUSE1 = True if custom_device_is_available else False
TEST_PRIVATEUSE1_DEVICE_TYPE = torch._C._get_privateuse1_backend_name()
TEST_NUMBA = _check_module_exists('numba')
TEST_TRANSFORMERS = _check_module_exists('transformers')
TEST_DILL = _check_module_exists('dill')