Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[1/N] Enable some tests in test_ops.TestCommon on Intel GPU (#159944)

For https://github.com/pytorch/pytorch/issues/114850, we will port aten unit tests to Intel GPU. This PR works on some test cases in test/test_ops.py. We enable Intel GPU with the following methods, keeping the original code style as much as possible:

1. Extended XPUTestBase.get_all_devices to support multiple devices.
2. Added the skipXPU decorator.
3. Extended onlyOn to support a device list.
4. Enabled 'xpu' for some test paths.
5. Added allow_xpu=True for supported test classes.
6. Replaced onlyCUDA with onlyOn(['cuda', 'xpu']) for supported tests.
7. Used skipIfXpu and skipXPU to disable unsupported tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159944
Approved by: https://github.com/guangyey, https://github.com/EikanWang, https://github.com/albanD
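To make the new hooks concrete, here is a minimal usage sketch. The test class `MyXPUSmokeTest` and its test bodies are hypothetical, written only for illustration; `onlyOn` with a device list, `skipXPU`, and `allow_xpu=True` are the pieces this PR adds or extends.

```python
import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyOn,
    skipXPU,
)
from torch.testing._internal.common_utils import TestCase, run_tests


class MyXPUSmokeTest(TestCase):
    # onlyOn now accepts a list of device types, so one decorator
    # covers both the CUDA and XPU instantiations of this test.
    @onlyOn(["cuda", "xpu"])
    def test_add(self, device):
        x = torch.ones(4, device=device)
        self.assertEqual((x + x).sum().item(), 8.0)

    # skipXPU skips a test on the XPU backend only, leaving CPU/CUDA untouched.
    @skipXPU
    def test_not_ready_on_xpu(self, device):
        self.assertEqual(torch.zeros(2, device=device).numel(), 2)


# allow_xpu=True lets the XPU device-specific variant be generated
# alongside the CPU/CUDA ones.
instantiate_device_type_tests(MyXPUSmokeTest, globals(), allow_xpu=True)

if __name__ == "__main__":
    run_tests()
```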
Committed by: PyTorch MergeBot
Parent: e1e5e040cd
Commit: 4fd70d4e7b
@@ -28,9 +28,11 @@ from torch.testing._internal.common_device_type import (
     onlyCPU,
     onlyCUDA,
     onlyNativeDeviceTypesAnd,
+    onlyOn,
     OpDTypes,
     ops,
     skipMeta,
+    skipXPU,
 )
 from torch.testing._internal.common_dtype import (
     all_types_and_complex_and,
@@ -221,7 +223,7 @@ class TestCommon(TestCase):
         assert len(filtered_ops) == 0, err_msg

     # Validates that each OpInfo works correctly on different CUDA devices
-    @onlyCUDA
+    @onlyOn(["cuda", "xpu"])
     @deviceCountAtLeast(2)
     @ops(op_db, allowed_dtypes=(torch.float32, torch.long))
     def test_multiple_devices(self, devices, dtype, op):
@@ -331,6 +333,8 @@ class TestCommon(TestCase):
     # NumPy does computation internally using double precision for many functions
     # resulting in possible equality check failures.
     # skip windows case on CPU due to https://github.com/pytorch/pytorch/issues/129947
+    # XPU test will be enabled step by step. Skip the tests temporarily.
+    @skipXPU
     @onlyNativeDeviceTypesAnd(["hpu"])
     @suppress_warnings
     @ops(_ref_test_ops, allowed_dtypes=(torch.float64, torch.long, torch.complex128))
@@ -340,7 +344,7 @@ class TestCommon(TestCase):
             and op.formatted_name
             in ("signal_windows_exponential", "signal_windows_bartlett")
             and dtype == torch.float64
-            and "cuda" in device
+            and ("cuda" in device or "xpu" in device)
             or "cpu" in device
         ): # noqa: E121
             raise unittest.SkipTest("XXX: raises tensor-likes are not close.")
@@ -353,7 +357,7 @@ class TestCommon(TestCase):
         )

     # Tests that the cpu and gpu results are consistent
-    @onlyCUDA
+    @onlyOn(["cuda", "xpu"])
     @suppress_warnings
     @slowTest
     @ops(_ops_and_refs_with_no_numpy_ref, dtypes=OpDTypes.any_common_cpu_cuda_one)
@@ -385,6 +389,7 @@ class TestCommon(TestCase):
     # Tests that experimental Python References can propagate shape, dtype,
     # and device metadata properly.
     # See https://github.com/pytorch/pytorch/issues/78050 for a discussion of stride propagation.
+    @skipXPU
     @onlyNativeDeviceTypesAnd(["hpu"])
     @ops(python_ref_db)
     @skipIfTorchInductor("Takes too long for inductor")
@@ -580,6 +585,7 @@ class TestCommon(TestCase):
     # Tests that experimental Python References perform the same computation
     # as the operators they reference, when operator calls in the torch
     # namespace are remapped to the refs namespace (torch.foo becomes refs.foo).
+    @skipXPU
     @onlyNativeDeviceTypesAnd(["hpu"])
     @ops(python_ref_db)
     @skipIfTorchInductor("Takes too long for inductor")
@@ -598,6 +604,7 @@ class TestCommon(TestCase):
     # Tests that experimental Python References perform the same computation
     # as the operators they reference, when operator calls in the torch
     # namespace are preserved (torch.foo remains torch.foo).
+    @skipXPU
     @onlyNativeDeviceTypesAnd(["hpu"])
     @ops(python_ref_db)
     @skipIfTorchInductor("Takes too long for inductor")
@@ -633,6 +640,7 @@ class TestCommon(TestCase):
         op.op = partial(make_traced(op.op), executor=executor)
         self._ref_test_helper(contextlib.nullcontext, device, dtype, op)

+    @skipXPU
     @skipMeta
     @onlyNativeDeviceTypesAnd(["hpu"])
     @ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none)
@@ -644,6 +652,7 @@ class TestCommon(TestCase):
                 out = op(si.input, *si.args, **si.kwargs)
                 self.assertFalse(isinstance(out, type(NotImplemented)))

+    @skipXPU
     @skipMeta
     @onlyNativeDeviceTypesAnd(["hpu"])
     @ops(
@@ -667,6 +676,7 @@ class TestCommon(TestCase):
                 out = op(si.input, *si.args, **si.kwargs)
                 self.assertFalse(isinstance(out, type(NotImplemented)))

+    @skipXPU
     @skipMeta
     @onlyNativeDeviceTypesAnd(["hpu"])
     @ops(
@@ -693,6 +703,7 @@ class TestCommon(TestCase):

     # Tests that the function produces the same result when called with
     # noncontiguous tensors.
+    @skipXPU
     @with_tf32_off
     @onlyNativeDeviceTypesAnd(["hpu"])
     @suppress_warnings
@@ -785,6 +796,7 @@ class TestCommon(TestCase):
     # incorrectly sized out parameter warning properly yet
     # Cases test here:
     # - out= with the correct dtype and device, but the wrong shape
+    @skipXPU
     @ops(ops_and_refs, dtypes=OpDTypes.none)
     def test_out_warning(self, device, op):
         if TEST_WITH_TORCHDYNAMO and op.name == "_refs.clamp":
@@ -923,6 +935,7 @@ class TestCommon(TestCase):
     # Case 3 and 4 are slightly different when the op is a factory function:
     # - if device, dtype are NOT passed, any combination of dtype/device should be OK for out
     # - if device, dtype are passed, device and dtype should match
+    @skipXPU
     @ops(ops_and_refs, dtypes=OpDTypes.any_one)
     def test_out(self, device, dtype, op):
         # Prefers running in float32 but has a fallback for the first listed supported dtype
@@ -1126,6 +1139,7 @@ class TestCommon(TestCase):
             with self.assertRaises(exc_type, msg=msg_fail):
                 op_out(out=out)

+    @skipXPU
     @ops(
         [
             op
@@ -1164,6 +1178,7 @@ class TestCommon(TestCase):
         with self.assertRaises(RuntimeError, msg=msg), maybe_skip_size_asserts(op):
             op(sample.input, *sample.args, **sample.kwargs, out=out)

+    @skipXPU
     @ops(filter(reduction_dtype_filter, ops_and_refs), dtypes=(torch.int16,))
     def test_out_integral_dtype(self, device, dtype, op):
         def helper(with_out, expectFail, op_to_test, inputs, *args, **kwargs):
@@ -1207,6 +1222,7 @@ class TestCommon(TestCase):
     # Tests that the forward and backward passes of operations produce the
     # same values for the cross-product of op variants (method, inplace)
     # against eager's gold standard op function variant
+    @skipXPU
     @_variant_ops(op_db)
     def test_variant_consistency_eager(self, device, dtype, op):
         # Acquires variants (method variant, inplace variant, operator variant, inplace_operator variant, aliases)
@@ -1387,6 +1403,7 @@ class TestCommon(TestCase):

     # Reference testing for operations in complex32 against complex64.
     # NOTE: We test against complex64 as NumPy doesn't have a complex32 equivalent dtype.
+    @skipXPU
     @ops(op_db, allowed_dtypes=(torch.complex32,))
     def test_complex_half_reference_testing(self, device, dtype, op):
         if not op.supports_dtype(torch.complex32, device):
@@ -1422,6 +1439,7 @@ class TestCommon(TestCase):
             # `cfloat` input -> `float` output
             self.assertEqual(actual, expected, exact_dtype=False)

+    @skipXPU
     @ops(op_db, allowed_dtypes=(torch.bool,))
     def test_non_standard_bool_values(self, device, dtype, op):
         # Test boolean values other than 0x00 and 0x01 (gh-54789)
@@ -1450,6 +1468,7 @@ class TestCommon(TestCase):

     # Validates that each OpInfo specifies its forward and backward dtypes
     # correctly for CPU and CUDA devices
+    @skipXPU
     @skipMeta
     @onlyNativeDeviceTypesAnd(["hpu"])
     @ops(ops_and_refs, dtypes=OpDTypes.none)
@@ -1656,6 +1675,7 @@ class TestCommon(TestCase):
             self.fail(msg)

     # Validates that each OpInfo that sets promotes_int_to_float=True does as it says
+    @skipXPU
     @skipMeta
     @onlyNativeDeviceTypesAnd(["hpu"])
     @ops(
@@ -2845,7 +2865,7 @@ class TestFakeTensor(TestCase):
         self.assertEqual(strided_result.layout, torch.strided)


-instantiate_device_type_tests(TestCommon, globals())
+instantiate_device_type_tests(TestCommon, globals(), allow_xpu=True)
 instantiate_device_type_tests(TestCompositeCompliance, globals())
 instantiate_device_type_tests(TestMathBits, globals())
 instantiate_device_type_tests(TestRefsOpsInfo, globals(), only_for="cpu")
@@ -628,8 +628,17 @@ class XPUTestBase(DeviceTypeTestBase):
     @classmethod
     def get_all_devices(cls):
         # currently only one device is supported on MPS backend
+        primary_device_idx = int(cls.get_primary_device().split(":")[1])
+        num_devices = torch.xpu.device_count()
+
         prim_device = cls.get_primary_device()
-        return [prim_device]
+        xpu_str = "xpu:{0}"
+        non_primary_devices = [
+            xpu_str.format(idx)
+            for idx in range(num_devices)
+            if idx != primary_device_idx
+        ]
+        return [prim_device] + non_primary_devices

     @classmethod
     def setUpClass(cls):
@@ -1395,13 +1404,13 @@ class expectedFailure:


 class onlyOn:
-    def __init__(self, device_type):
+    def __init__(self, device_type: Union[str, list]):
         self.device_type = device_type

     def __call__(self, fn):
         @wraps(fn)
         def only_fn(slf, *args, **kwargs):
-            if self.device_type != slf.device_type:
+            if slf.device_type not in self.device_type:
                 reason = f"Only runs on {self.device_type}"
                 raise unittest.SkipTest(reason)

@@ -1960,6 +1969,10 @@ def skipHPU(fn):
     return skipHPUIf(True, "test doesn't work on HPU backend")(fn)


+def skipXPU(fn):
+    return skipXPUIf(True, "test doesn't work on XPU backend")(fn)
+
+
 def skipPRIVATEUSE1(fn):
     return skipPRIVATEUSE1If(True, "test doesn't work on privateuse1 backend")(fn)

@@ -28,7 +28,7 @@ from torch.testing._internal.common_device_type import \
     (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
      skipCUDAIfNoCusolver, skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIf, precisionOverride,
      skipCPUIfNoMklSparse,
-     toleranceOverride, tol)
+     toleranceOverride, tol, skipXPU)
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
     SM53OrLater, SM80OrLater, SM89OrLater, with_tf32_off, TEST_CUDNN, _get_torch_cuda_version,
@@ -39,7 +39,7 @@ from torch.testing._internal.common_utils import (
     TEST_WITH_ROCM, IS_FBCODE, IS_WINDOWS, IS_MACOS, IS_S390X, TEST_SCIPY,
     torch_to_numpy_dtype_dict, numpy_to_torch_dtype, TEST_WITH_ASAN,
     GRADCHECK_NONDET_TOL, slowTest, TEST_WITH_SLOW,
-    TEST_WITH_TORCHINDUCTOR, MACOS_VERSION
+    TEST_WITH_TORCHINDUCTOR, MACOS_VERSION,
 )
 from torch.testing._utils import wrapper_set_seed

@@ -13646,7 +13646,8 @@ op_db: list[OpInfo] = [
            skipCUDAIf(not ((_get_torch_cuda_version() >= (11, 3))
                or (_get_torch_rocm_version() >= (5, 2))),
                "cusparseSDDMM was added in 11.2.1"),
-           skipCPUIfNoMklSparse, ],
+           skipCPUIfNoMklSparse,
+           skipXPU],
        skips=(
            # NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous
            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
@@ -16373,7 +16374,7 @@ op_db: list[OpInfo] = [
        supports_out=True,
        supports_forward_ad=False,
        supports_autograd=False,
-       decorators=[skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')],
+       decorators=[skipXPU, skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')],
        skips=(
            # Sample inputs isn't really parametrized on dtype
            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'),
@@ -16500,7 +16501,8 @@ op_db: list[OpInfo] = [
        # FIXME: mask_type == 2 (LowerRight)
        decorators=[
            skipCUDAIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "This platform doesn't support efficient attention"),
-           skipCUDAIf(TEST_WITH_ROCM, "Efficient attention on ROCM doesn't support custom_mask_type==2")],
+           skipCUDAIf(TEST_WITH_ROCM, "Efficient attention on ROCM doesn't support custom_mask_type==2"),
+           skipXPU],
        skips=(
            # Checking the scaler value of the philox seed and offset
            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'),
@@ -20863,6 +20865,8 @@ op_db: list[OpInfo] = [
            # AssertionError: Tensor-likes are not close!
            # Fails in cuda11.7
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu', device_type='cuda'),
+           # AssertionError: Tensor-likes are not close!
+           DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu', device_type='xpu'),
            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),),),
    # In training mode, feature_alpha_dropout currently doesn't support inputs of complex dtype
    # unlike when `train=False`, it supports complex inputs, hence 2 OpInfos to cover all cases
@@ -41,6 +41,7 @@ from torch.testing._internal.common_utils import (
     skipIfSlowGradcheckEnv,
     slowTest,
     TEST_WITH_ROCM,
+    TEST_XPU,
 )
 from torch.testing._internal.opinfo.core import (
     clone_sample,
@@ -1766,7 +1767,12 @@ op_db: list[OpInfo] = [
         decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
         skips=(
             # linalg.lu_factor: LU without pivoting is not implemented on the CPU
-            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestCommon",
+                "test_compare_cpu",
+                active_if=(not TEST_XPU),
+            ),
         ),
     ),
     OpInfo(
@@ -1782,7 +1788,12 @@ op_db: list[OpInfo] = [
         decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
         skips=(
             # linalg.lu_factor: LU without pivoting is not implemented on the CPU
-            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestCommon",
+                "test_compare_cpu",
+                active_if=(not TEST_XPU),
+            ),
         ),
     ),
     OpInfo(
@@ -1799,7 +1810,12 @@ op_db: list[OpInfo] = [
         decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
         skips=(
             # linalg.lu_factor: LU without pivoting is not implemented on the CPU
-            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestCommon",
+                "test_compare_cpu",
+                active_if=(not TEST_XPU),
+            ),
         ),
     ),
     OpInfo(
@@ -448,6 +448,10 @@ op_db: list[OpInfo] = [
             DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
             # Greatest absolute difference: inf
             DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+            # Too slow
+            DecorateInfo(
+                unittest.skip, "TestCommon", "test_compare_cpu", device_type="xpu"
+            ),
         ),
         supports_one_python_scalar=True,
         supports_autograd=False,
@@ -474,6 +478,10 @@ op_db: list[OpInfo] = [
             DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
             # Greatest absolute difference: nan
             DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+            # Too slow
+            DecorateInfo(
+                unittest.skip, "TestCommon", "test_compare_cpu", device_type="xpu"
+            ),
         ),
         supports_one_python_scalar=True,
         supports_autograd=False,
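As a closing note, the multi-device enumeration added to XPUTestBase.get_all_devices above keeps the primary device first and appends every other index. A minimal self-contained sketch of that logic follows; `primary_device` and `num_devices` are plain parameters standing in for `cls.get_primary_device()` and `torch.xpu.device_count()`, so this is an illustration of the pattern rather than the real class method.

```python
def get_all_devices(primary_device: str, num_devices: int) -> list[str]:
    # e.g. primary_device="xpu:0", num_devices=3 -> ["xpu:0", "xpu:1", "xpu:2"]
    primary_device_idx = int(primary_device.split(":")[1])
    non_primary_devices = [
        f"xpu:{idx}" for idx in range(num_devices) if idx != primary_device_idx
    ]
    return [primary_device] + non_primary_devices


# The primary device always leads the list, matching how the CUDA test base behaves.
assert get_all_devices("xpu:0", 3) == ["xpu:0", "xpu:1", "xpu:2"]
assert get_all_devices("xpu:1", 3) == ["xpu:1", "xpu:0", "xpu:2"]
```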