[1/N]Enable some tests in test_ops.TestCommon on Intel GPU (#159944)

For https://github.com/pytorch/pytorch/issues/114850, we will port aten unit tests to Intel GPU. This PR will work on some test case of test/test_ops.py. We could enable Intel GPU with following methods and try the best to keep the original code styles:

1. Extended XPUTestBase.get_all_devices to support multiple devices
2. Added skipXPU decorator
3. Extended onlyOn to support device list
4. Enabled 'xpu' for some test pathes
5. Added allow_xpu=True for supported test class.
6. Replaced onlyCUDA with onlyOn(['cuda', 'xpu']) for supported tests
7. Use skipIfXpu and skipXPU to disable unsupported test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159944
Approved by: https://github.com/guangyey, https://github.com/EikanWang, https://github.com/albanD
This commit is contained in:
Deng, Daisy
2025-09-29 09:08:04 +00:00
committed by PyTorch MergeBot
parent e1e5e040cd
commit 4fd70d4e7b
5 changed files with 76 additions and 15 deletions

View File

@ -28,9 +28,11 @@ from torch.testing._internal.common_device_type import (
onlyCPU, onlyCPU,
onlyCUDA, onlyCUDA,
onlyNativeDeviceTypesAnd, onlyNativeDeviceTypesAnd,
onlyOn,
OpDTypes, OpDTypes,
ops, ops,
skipMeta, skipMeta,
skipXPU,
) )
from torch.testing._internal.common_dtype import ( from torch.testing._internal.common_dtype import (
all_types_and_complex_and, all_types_and_complex_and,
@ -221,7 +223,7 @@ class TestCommon(TestCase):
assert len(filtered_ops) == 0, err_msg assert len(filtered_ops) == 0, err_msg
# Validates that each OpInfo works correctly on different CUDA devices # Validates that each OpInfo works correctly on different CUDA devices
@onlyCUDA @onlyOn(["cuda", "xpu"])
@deviceCountAtLeast(2) @deviceCountAtLeast(2)
@ops(op_db, allowed_dtypes=(torch.float32, torch.long)) @ops(op_db, allowed_dtypes=(torch.float32, torch.long))
def test_multiple_devices(self, devices, dtype, op): def test_multiple_devices(self, devices, dtype, op):
@ -331,6 +333,8 @@ class TestCommon(TestCase):
# NumPy does computation internally using double precision for many functions # NumPy does computation internally using double precision for many functions
# resulting in possible equality check failures. # resulting in possible equality check failures.
# skip windows case on CPU due to https://github.com/pytorch/pytorch/issues/129947 # skip windows case on CPU due to https://github.com/pytorch/pytorch/issues/129947
# XPU test will be enabled step by step. Skip the tests temporarily.
@skipXPU
@onlyNativeDeviceTypesAnd(["hpu"]) @onlyNativeDeviceTypesAnd(["hpu"])
@suppress_warnings @suppress_warnings
@ops(_ref_test_ops, allowed_dtypes=(torch.float64, torch.long, torch.complex128)) @ops(_ref_test_ops, allowed_dtypes=(torch.float64, torch.long, torch.complex128))
@ -340,7 +344,7 @@ class TestCommon(TestCase):
and op.formatted_name and op.formatted_name
in ("signal_windows_exponential", "signal_windows_bartlett") in ("signal_windows_exponential", "signal_windows_bartlett")
and dtype == torch.float64 and dtype == torch.float64
and "cuda" in device and ("cuda" in device or "xpu" in device)
or "cpu" in device or "cpu" in device
): # noqa: E121 ): # noqa: E121
raise unittest.SkipTest("XXX: raises tensor-likes are not close.") raise unittest.SkipTest("XXX: raises tensor-likes are not close.")
@ -353,7 +357,7 @@ class TestCommon(TestCase):
) )
# Tests that the cpu and gpu results are consistent # Tests that the cpu and gpu results are consistent
@onlyCUDA @onlyOn(["cuda", "xpu"])
@suppress_warnings @suppress_warnings
@slowTest @slowTest
@ops(_ops_and_refs_with_no_numpy_ref, dtypes=OpDTypes.any_common_cpu_cuda_one) @ops(_ops_and_refs_with_no_numpy_ref, dtypes=OpDTypes.any_common_cpu_cuda_one)
@ -385,6 +389,7 @@ class TestCommon(TestCase):
# Tests that experimental Python References can propagate shape, dtype, # Tests that experimental Python References can propagate shape, dtype,
# and device metadata properly. # and device metadata properly.
# See https://github.com/pytorch/pytorch/issues/78050 for a discussion of stride propagation. # See https://github.com/pytorch/pytorch/issues/78050 for a discussion of stride propagation.
@skipXPU
@onlyNativeDeviceTypesAnd(["hpu"]) @onlyNativeDeviceTypesAnd(["hpu"])
@ops(python_ref_db) @ops(python_ref_db)
@skipIfTorchInductor("Takes too long for inductor") @skipIfTorchInductor("Takes too long for inductor")
@ -580,6 +585,7 @@ class TestCommon(TestCase):
# Tests that experimental Python References perform the same computation # Tests that experimental Python References perform the same computation
# as the operators they reference, when operator calls in the torch # as the operators they reference, when operator calls in the torch
# namespace are remapped to the refs namespace (torch.foo becomes refs.foo). # namespace are remapped to the refs namespace (torch.foo becomes refs.foo).
@skipXPU
@onlyNativeDeviceTypesAnd(["hpu"]) @onlyNativeDeviceTypesAnd(["hpu"])
@ops(python_ref_db) @ops(python_ref_db)
@skipIfTorchInductor("Takes too long for inductor") @skipIfTorchInductor("Takes too long for inductor")
@ -598,6 +604,7 @@ class TestCommon(TestCase):
# Tests that experimental Python References perform the same computation # Tests that experimental Python References perform the same computation
# as the operators they reference, when operator calls in the torch # as the operators they reference, when operator calls in the torch
# namespace are preserved (torch.foo remains torch.foo). # namespace are preserved (torch.foo remains torch.foo).
@skipXPU
@onlyNativeDeviceTypesAnd(["hpu"]) @onlyNativeDeviceTypesAnd(["hpu"])
@ops(python_ref_db) @ops(python_ref_db)
@skipIfTorchInductor("Takes too long for inductor") @skipIfTorchInductor("Takes too long for inductor")
@ -633,6 +640,7 @@ class TestCommon(TestCase):
op.op = partial(make_traced(op.op), executor=executor) op.op = partial(make_traced(op.op), executor=executor)
self._ref_test_helper(contextlib.nullcontext, device, dtype, op) self._ref_test_helper(contextlib.nullcontext, device, dtype, op)
@skipXPU
@skipMeta @skipMeta
@onlyNativeDeviceTypesAnd(["hpu"]) @onlyNativeDeviceTypesAnd(["hpu"])
@ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none) @ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none)
@ -644,6 +652,7 @@ class TestCommon(TestCase):
out = op(si.input, *si.args, **si.kwargs) out = op(si.input, *si.args, **si.kwargs)
self.assertFalse(isinstance(out, type(NotImplemented))) self.assertFalse(isinstance(out, type(NotImplemented)))
@skipXPU
@skipMeta @skipMeta
@onlyNativeDeviceTypesAnd(["hpu"]) @onlyNativeDeviceTypesAnd(["hpu"])
@ops( @ops(
@ -667,6 +676,7 @@ class TestCommon(TestCase):
out = op(si.input, *si.args, **si.kwargs) out = op(si.input, *si.args, **si.kwargs)
self.assertFalse(isinstance(out, type(NotImplemented))) self.assertFalse(isinstance(out, type(NotImplemented)))
@skipXPU
@skipMeta @skipMeta
@onlyNativeDeviceTypesAnd(["hpu"]) @onlyNativeDeviceTypesAnd(["hpu"])
@ops( @ops(
@ -693,6 +703,7 @@ class TestCommon(TestCase):
# Tests that the function produces the same result when called with # Tests that the function produces the same result when called with
# noncontiguous tensors. # noncontiguous tensors.
@skipXPU
@with_tf32_off @with_tf32_off
@onlyNativeDeviceTypesAnd(["hpu"]) @onlyNativeDeviceTypesAnd(["hpu"])
@suppress_warnings @suppress_warnings
@ -785,6 +796,7 @@ class TestCommon(TestCase):
# incorrectly sized out parameter warning properly yet # incorrectly sized out parameter warning properly yet
# Cases test here: # Cases test here:
# - out= with the correct dtype and device, but the wrong shape # - out= with the correct dtype and device, but the wrong shape
@skipXPU
@ops(ops_and_refs, dtypes=OpDTypes.none) @ops(ops_and_refs, dtypes=OpDTypes.none)
def test_out_warning(self, device, op): def test_out_warning(self, device, op):
if TEST_WITH_TORCHDYNAMO and op.name == "_refs.clamp": if TEST_WITH_TORCHDYNAMO and op.name == "_refs.clamp":
@ -923,6 +935,7 @@ class TestCommon(TestCase):
# Case 3 and 4 are slightly different when the op is a factory function: # Case 3 and 4 are slightly different when the op is a factory function:
# - if device, dtype are NOT passed, any combination of dtype/device should be OK for out # - if device, dtype are NOT passed, any combination of dtype/device should be OK for out
# - if device, dtype are passed, device and dtype should match # - if device, dtype are passed, device and dtype should match
@skipXPU
@ops(ops_and_refs, dtypes=OpDTypes.any_one) @ops(ops_and_refs, dtypes=OpDTypes.any_one)
def test_out(self, device, dtype, op): def test_out(self, device, dtype, op):
# Prefers running in float32 but has a fallback for the first listed supported dtype # Prefers running in float32 but has a fallback for the first listed supported dtype
@ -1126,6 +1139,7 @@ class TestCommon(TestCase):
with self.assertRaises(exc_type, msg=msg_fail): with self.assertRaises(exc_type, msg=msg_fail):
op_out(out=out) op_out(out=out)
@skipXPU
@ops( @ops(
[ [
op op
@ -1164,6 +1178,7 @@ class TestCommon(TestCase):
with self.assertRaises(RuntimeError, msg=msg), maybe_skip_size_asserts(op): with self.assertRaises(RuntimeError, msg=msg), maybe_skip_size_asserts(op):
op(sample.input, *sample.args, **sample.kwargs, out=out) op(sample.input, *sample.args, **sample.kwargs, out=out)
@skipXPU
@ops(filter(reduction_dtype_filter, ops_and_refs), dtypes=(torch.int16,)) @ops(filter(reduction_dtype_filter, ops_and_refs), dtypes=(torch.int16,))
def test_out_integral_dtype(self, device, dtype, op): def test_out_integral_dtype(self, device, dtype, op):
def helper(with_out, expectFail, op_to_test, inputs, *args, **kwargs): def helper(with_out, expectFail, op_to_test, inputs, *args, **kwargs):
@ -1207,6 +1222,7 @@ class TestCommon(TestCase):
# Tests that the forward and backward passes of operations produce the # Tests that the forward and backward passes of operations produce the
# same values for the cross-product of op variants (method, inplace) # same values for the cross-product of op variants (method, inplace)
# against eager's gold standard op function variant # against eager's gold standard op function variant
@skipXPU
@_variant_ops(op_db) @_variant_ops(op_db)
def test_variant_consistency_eager(self, device, dtype, op): def test_variant_consistency_eager(self, device, dtype, op):
# Acquires variants (method variant, inplace variant, operator variant, inplace_operator variant, aliases) # Acquires variants (method variant, inplace variant, operator variant, inplace_operator variant, aliases)
@ -1387,6 +1403,7 @@ class TestCommon(TestCase):
# Reference testing for operations in complex32 against complex64. # Reference testing for operations in complex32 against complex64.
# NOTE: We test against complex64 as NumPy doesn't have a complex32 equivalent dtype. # NOTE: We test against complex64 as NumPy doesn't have a complex32 equivalent dtype.
@skipXPU
@ops(op_db, allowed_dtypes=(torch.complex32,)) @ops(op_db, allowed_dtypes=(torch.complex32,))
def test_complex_half_reference_testing(self, device, dtype, op): def test_complex_half_reference_testing(self, device, dtype, op):
if not op.supports_dtype(torch.complex32, device): if not op.supports_dtype(torch.complex32, device):
@ -1422,6 +1439,7 @@ class TestCommon(TestCase):
# `cfloat` input -> `float` output # `cfloat` input -> `float` output
self.assertEqual(actual, expected, exact_dtype=False) self.assertEqual(actual, expected, exact_dtype=False)
@skipXPU
@ops(op_db, allowed_dtypes=(torch.bool,)) @ops(op_db, allowed_dtypes=(torch.bool,))
def test_non_standard_bool_values(self, device, dtype, op): def test_non_standard_bool_values(self, device, dtype, op):
# Test boolean values other than 0x00 and 0x01 (gh-54789) # Test boolean values other than 0x00 and 0x01 (gh-54789)
@ -1450,6 +1468,7 @@ class TestCommon(TestCase):
# Validates that each OpInfo specifies its forward and backward dtypes # Validates that each OpInfo specifies its forward and backward dtypes
# correctly for CPU and CUDA devices # correctly for CPU and CUDA devices
@skipXPU
@skipMeta @skipMeta
@onlyNativeDeviceTypesAnd(["hpu"]) @onlyNativeDeviceTypesAnd(["hpu"])
@ops(ops_and_refs, dtypes=OpDTypes.none) @ops(ops_and_refs, dtypes=OpDTypes.none)
@ -1656,6 +1675,7 @@ class TestCommon(TestCase):
self.fail(msg) self.fail(msg)
# Validates that each OpInfo that sets promotes_int_to_float=True does as it says # Validates that each OpInfo that sets promotes_int_to_float=True does as it says
@skipXPU
@skipMeta @skipMeta
@onlyNativeDeviceTypesAnd(["hpu"]) @onlyNativeDeviceTypesAnd(["hpu"])
@ops( @ops(
@ -2845,7 +2865,7 @@ class TestFakeTensor(TestCase):
self.assertEqual(strided_result.layout, torch.strided) self.assertEqual(strided_result.layout, torch.strided)
instantiate_device_type_tests(TestCommon, globals()) instantiate_device_type_tests(TestCommon, globals(), allow_xpu=True)
instantiate_device_type_tests(TestCompositeCompliance, globals()) instantiate_device_type_tests(TestCompositeCompliance, globals())
instantiate_device_type_tests(TestMathBits, globals()) instantiate_device_type_tests(TestMathBits, globals())
instantiate_device_type_tests(TestRefsOpsInfo, globals(), only_for="cpu") instantiate_device_type_tests(TestRefsOpsInfo, globals(), only_for="cpu")

View File

@ -628,8 +628,17 @@ class XPUTestBase(DeviceTypeTestBase):
@classmethod @classmethod
def get_all_devices(cls): def get_all_devices(cls):
# currently only one device is supported on MPS backend # currently only one device is supported on MPS backend
primary_device_idx = int(cls.get_primary_device().split(":")[1])
num_devices = torch.xpu.device_count()
prim_device = cls.get_primary_device() prim_device = cls.get_primary_device()
return [prim_device] xpu_str = "xpu:{0}"
non_primary_devices = [
xpu_str.format(idx)
for idx in range(num_devices)
if idx != primary_device_idx
]
return [prim_device] + non_primary_devices
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
@ -1395,13 +1404,13 @@ class expectedFailure:
class onlyOn: class onlyOn:
def __init__(self, device_type): def __init__(self, device_type: Union[str, list]):
self.device_type = device_type self.device_type = device_type
def __call__(self, fn): def __call__(self, fn):
@wraps(fn) @wraps(fn)
def only_fn(slf, *args, **kwargs): def only_fn(slf, *args, **kwargs):
if self.device_type != slf.device_type: if slf.device_type not in self.device_type:
reason = f"Only runs on {self.device_type}" reason = f"Only runs on {self.device_type}"
raise unittest.SkipTest(reason) raise unittest.SkipTest(reason)
@ -1960,6 +1969,10 @@ def skipHPU(fn):
return skipHPUIf(True, "test doesn't work on HPU backend")(fn) return skipHPUIf(True, "test doesn't work on HPU backend")(fn)
def skipXPU(fn):
return skipXPUIf(True, "test doesn't work on XPU backend")(fn)
def skipPRIVATEUSE1(fn): def skipPRIVATEUSE1(fn):
return skipPRIVATEUSE1If(True, "test doesn't work on privateuse1 backend")(fn) return skipPRIVATEUSE1If(True, "test doesn't work on privateuse1 backend")(fn)

View File

@ -28,7 +28,7 @@ from torch.testing._internal.common_device_type import \
(onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
skipCUDAIfNoCusolver, skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIf, precisionOverride, skipCUDAIfNoCusolver, skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIf, precisionOverride,
skipCPUIfNoMklSparse, skipCPUIfNoMklSparse,
toleranceOverride, tol) toleranceOverride, tol, skipXPU)
from torch.testing._internal.common_cuda import ( from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
SM53OrLater, SM80OrLater, SM89OrLater, with_tf32_off, TEST_CUDNN, _get_torch_cuda_version, SM53OrLater, SM80OrLater, SM89OrLater, with_tf32_off, TEST_CUDNN, _get_torch_cuda_version,
@ -39,7 +39,7 @@ from torch.testing._internal.common_utils import (
TEST_WITH_ROCM, IS_FBCODE, IS_WINDOWS, IS_MACOS, IS_S390X, TEST_SCIPY, TEST_WITH_ROCM, IS_FBCODE, IS_WINDOWS, IS_MACOS, IS_S390X, TEST_SCIPY,
torch_to_numpy_dtype_dict, numpy_to_torch_dtype, TEST_WITH_ASAN, torch_to_numpy_dtype_dict, numpy_to_torch_dtype, TEST_WITH_ASAN,
GRADCHECK_NONDET_TOL, slowTest, TEST_WITH_SLOW, GRADCHECK_NONDET_TOL, slowTest, TEST_WITH_SLOW,
TEST_WITH_TORCHINDUCTOR, MACOS_VERSION TEST_WITH_TORCHINDUCTOR, MACOS_VERSION,
) )
from torch.testing._utils import wrapper_set_seed from torch.testing._utils import wrapper_set_seed
@ -13646,7 +13646,8 @@ op_db: list[OpInfo] = [
skipCUDAIf(not ((_get_torch_cuda_version() >= (11, 3)) skipCUDAIf(not ((_get_torch_cuda_version() >= (11, 3))
or (_get_torch_rocm_version() >= (5, 2))), or (_get_torch_rocm_version() >= (5, 2))),
"cusparseSDDMM was added in 11.2.1"), "cusparseSDDMM was added in 11.2.1"),
skipCPUIfNoMklSparse, ], skipCPUIfNoMklSparse,
skipXPU],
skips=( skips=(
# NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous # NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
@ -16373,7 +16374,7 @@ op_db: list[OpInfo] = [
supports_out=True, supports_out=True,
supports_forward_ad=False, supports_forward_ad=False,
supports_autograd=False, supports_autograd=False,
decorators=[skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')], decorators=[skipXPU, skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')],
skips=( skips=(
# Sample inputs isn't really parametrized on dtype # Sample inputs isn't really parametrized on dtype
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'), DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'),
@ -16500,7 +16501,8 @@ op_db: list[OpInfo] = [
# FIXME: mask_type == 2 (LowerRight) # FIXME: mask_type == 2 (LowerRight)
decorators=[ decorators=[
skipCUDAIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "This platform doesn't support efficient attention"), skipCUDAIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "This platform doesn't support efficient attention"),
skipCUDAIf(TEST_WITH_ROCM, "Efficient attention on ROCM doesn't support custom_mask_type==2")], skipCUDAIf(TEST_WITH_ROCM, "Efficient attention on ROCM doesn't support custom_mask_type==2"),
skipXPU],
skips=( skips=(
# Checking the scaler value of the philox seed and offset # Checking the scaler value of the philox seed and offset
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'), DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'),
@ -20863,6 +20865,8 @@ op_db: list[OpInfo] = [
# AssertionError: Tensor-likes are not close! # AssertionError: Tensor-likes are not close!
# Fails in cuda11.7 # Fails in cuda11.7
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu', device_type='cuda'), DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu', device_type='cuda'),
# AssertionError: Tensor-likes are not close!
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu', device_type='xpu'),
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),),), DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),),),
# In training mode, feature_alpha_dropout currently doesn't support inputs of complex dtype # In training mode, feature_alpha_dropout currently doesn't support inputs of complex dtype
# unlike when `train=False`, it supports complex inputs, hence 2 OpInfos to cover all cases # unlike when `train=False`, it supports complex inputs, hence 2 OpInfos to cover all cases

View File

@ -41,6 +41,7 @@ from torch.testing._internal.common_utils import (
skipIfSlowGradcheckEnv, skipIfSlowGradcheckEnv,
slowTest, slowTest,
TEST_WITH_ROCM, TEST_WITH_ROCM,
TEST_XPU,
) )
from torch.testing._internal.opinfo.core import ( from torch.testing._internal.opinfo.core import (
clone_sample, clone_sample,
@ -1766,7 +1767,12 @@ op_db: list[OpInfo] = [
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
skips=( skips=(
# linalg.lu_factor: LU without pivoting is not implemented on the CPU # linalg.lu_factor: LU without pivoting is not implemented on the CPU
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), DecorateInfo(
unittest.expectedFailure,
"TestCommon",
"test_compare_cpu",
active_if=(not TEST_XPU),
),
), ),
), ),
OpInfo( OpInfo(
@ -1782,7 +1788,12 @@ op_db: list[OpInfo] = [
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
skips=( skips=(
# linalg.lu_factor: LU without pivoting is not implemented on the CPU # linalg.lu_factor: LU without pivoting is not implemented on the CPU
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), DecorateInfo(
unittest.expectedFailure,
"TestCommon",
"test_compare_cpu",
active_if=(not TEST_XPU),
),
), ),
), ),
OpInfo( OpInfo(
@ -1799,7 +1810,12 @@ op_db: list[OpInfo] = [
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
skips=( skips=(
# linalg.lu_factor: LU without pivoting is not implemented on the CPU # linalg.lu_factor: LU without pivoting is not implemented on the CPU
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), DecorateInfo(
unittest.expectedFailure,
"TestCommon",
"test_compare_cpu",
active_if=(not TEST_XPU),
),
), ),
), ),
OpInfo( OpInfo(

View File

@ -448,6 +448,10 @@ op_db: list[OpInfo] = [
DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
# Greatest absolute difference: inf # Greatest absolute difference: inf
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
# Too slow
DecorateInfo(
unittest.skip, "TestCommon", "test_compare_cpu", device_type="xpu"
),
), ),
supports_one_python_scalar=True, supports_one_python_scalar=True,
supports_autograd=False, supports_autograd=False,
@ -474,6 +478,10 @@ op_db: list[OpInfo] = [
DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
# Greatest absolute difference: nan # Greatest absolute difference: nan
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
# Too slow
DecorateInfo(
unittest.skip, "TestCommon", "test_compare_cpu", device_type="xpu"
),
), ),
supports_one_python_scalar=True, supports_one_python_scalar=True,
supports_autograd=False, supports_autograd=False,