mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add a variable for some testcases. (#124708)
Some testcases can use 'TEST_PRIVATEUSE1_DEVICE_TYPE' to make adapting these testcases on others device more convenient. Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/124708 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
c451d108da
commit
6cfb55dd5d
@ -12,7 +12,7 @@ import unittest
|
||||
from torch import nan
|
||||
from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase, run_tests, skipIfTorchDynamo, torch_to_numpy_dtype_dict, IS_JETSON)
|
||||
TestCase, run_tests, skipIfTorchDynamo, torch_to_numpy_dtype_dict, IS_JETSON, TEST_PRIVATEUSE1_DEVICE_TYPE)
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests, onlyCPU, onlyCUDA, dtypes, onlyNativeDeviceTypes,
|
||||
dtypesIfCUDA, largeTensorTest)
|
||||
@ -631,7 +631,7 @@ class TestShapeOps(TestCase):
|
||||
"scalar type Long",
|
||||
lambda: torch.nonzero(tensor, out=torch.empty([], dtype=torch.float, device=device))
|
||||
)
|
||||
if self.device_type == 'cuda':
|
||||
if self.device_type == 'cuda' or self.device_type == TEST_PRIVATEUSE1_DEVICE_TYPE:
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"on the same device",
|
||||
|
||||
Reference in New Issue
Block a user