From 4fd70d4e7bd5607c218a5b8a4ac9d2c08a38efa9 Mon Sep 17 00:00:00 2001
From: "Deng, Daisy"
Date: Mon, 29 Sep 2025 09:08:04 +0000
Subject: [PATCH] [1/N] Enable some tests in test_ops.TestCommon on Intel GPU
 (#159944)

As part of https://github.com/pytorch/pytorch/issues/114850, we are porting
ATen unit tests to Intel GPU. This PR enables some of the test cases in
test/test_ops.py. Intel GPU support is enabled with the following changes,
keeping the original code style as much as possible:

1. Extended XPUTestBase.get_all_devices to support multiple devices
2. Added the skipXPU decorator
3. Extended onlyOn to accept a list of device types
4. Enabled 'xpu' for some test paths
5. Added allow_xpu=True for the supported test class
6. Replaced onlyCUDA with onlyOn(['cuda', 'xpu']) for supported tests
7. Used skipIfXpu and skipXPU to disable unsupported tests

A usage sketch of the new decorators is appended after the diff.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159944
Approved by: https://github.com/guangyey, https://github.com/EikanWang, https://github.com/albanD
---
 test/test_ops.py                              | 28 ++++++++++++++++---
 torch/testing/_internal/common_device_type.py | 19 +++++++++++--
 .../_internal/common_methods_invocations.py   | 14 ++++++----
 .../_internal/opinfo/definitions/linalg.py    | 22 +++++++++++++--
 .../_internal/opinfo/definitions/special.py   |  8 ++++++
 5 files changed, 76 insertions(+), 15 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 64b657c9294b..d95f01eceaa3 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -28,9 +28,11 @@ from torch.testing._internal.common_device_type import (
     onlyCPU,
     onlyCUDA,
     onlyNativeDeviceTypesAnd,
+    onlyOn,
     OpDTypes,
     ops,
     skipMeta,
+    skipXPU,
 )
 from torch.testing._internal.common_dtype import (
     all_types_and_complex_and,
@@ -221,7 +223,7 @@ class TestCommon(TestCase):
         assert len(filtered_ops) == 0, err_msg
 
     # Validates that each OpInfo works correctly on different CUDA devices
-    @onlyCUDA
+    @onlyOn(["cuda", "xpu"])
     @deviceCountAtLeast(2)
     @ops(op_db, allowed_dtypes=(torch.float32, torch.long))
     def test_multiple_devices(self, devices, dtype, op):
@@ -331,6 +333,8 @@ class TestCommon(TestCase):
     # NumPy does computation internally using double precision for many functions
     # resulting in possible equality check failures.
     # skip windows case on CPU due to https://github.com/pytorch/pytorch/issues/129947
+    # XPU tests will be enabled step by step. Skip the tests temporarily.
+    @skipXPU
     @onlyNativeDeviceTypesAnd(["hpu"])
     @suppress_warnings
     @ops(_ref_test_ops, allowed_dtypes=(torch.float64, torch.long, torch.complex128))
@@ -340,7 +344,7 @@ class TestCommon(TestCase):
             and op.formatted_name
             in ("signal_windows_exponential", "signal_windows_bartlett")
             and dtype == torch.float64
-            and "cuda" in device
+            and ("cuda" in device or "xpu" in device)
             or "cpu" in device
         ):  # noqa: E121
             raise unittest.SkipTest("XXX: raises tensor-likes are not close.")
@@ -353,7 +357,7 @@ class TestCommon(TestCase):
         )
 
     # Tests that the cpu and gpu results are consistent
-    @onlyCUDA
+    @onlyOn(["cuda", "xpu"])
     @suppress_warnings
     @slowTest
     @ops(_ops_and_refs_with_no_numpy_ref, dtypes=OpDTypes.any_common_cpu_cuda_one)
@@ -385,6 +389,7 @@ class TestCommon(TestCase):
     # Tests that experimental Python References can propagate shape, dtype,
     # and device metadata properly.
     # See https://github.com/pytorch/pytorch/issues/78050 for a discussion of stride propagation.
+    @skipXPU
     @onlyNativeDeviceTypesAnd(["hpu"])
     @ops(python_ref_db)
     @skipIfTorchInductor("Takes too long for inductor")
@@ -580,6 +585,7 @@ class TestCommon(TestCase):
     # Tests that experimental Python References perform the same computation
     # as the operators they reference, when operator calls in the torch
     # namespace are remapped to the refs namespace (torch.foo becomes refs.foo).
+    @skipXPU
     @onlyNativeDeviceTypesAnd(["hpu"])
     @ops(python_ref_db)
     @skipIfTorchInductor("Takes too long for inductor")
@@ -598,6 +604,7 @@ class TestCommon(TestCase):
     # Tests that experimental Python References perform the same computation
     # as the operators they reference, when operator calls in the torch
     # namespace are preserved (torch.foo remains torch.foo).
+    @skipXPU
     @onlyNativeDeviceTypesAnd(["hpu"])
     @ops(python_ref_db)
     @skipIfTorchInductor("Takes too long for inductor")
@@ -633,6 +640,7 @@ class TestCommon(TestCase):
         op.op = partial(make_traced(op.op), executor=executor)
         self._ref_test_helper(contextlib.nullcontext, device, dtype, op)
 
+    @skipXPU
     @skipMeta
     @onlyNativeDeviceTypesAnd(["hpu"])
     @ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none)
@@ -644,6 +652,7 @@ class TestCommon(TestCase):
                 out = op(si.input, *si.args, **si.kwargs)
                 self.assertFalse(isinstance(out, type(NotImplemented)))
 
+    @skipXPU
     @skipMeta
     @onlyNativeDeviceTypesAnd(["hpu"])
     @ops(
@@ -667,6 +676,7 @@ class TestCommon(TestCase):
                 out = op(si.input, *si.args, **si.kwargs)
                 self.assertFalse(isinstance(out, type(NotImplemented)))
 
+    @skipXPU
     @skipMeta
     @onlyNativeDeviceTypesAnd(["hpu"])
     @ops(
@@ -693,6 +703,7 @@ class TestCommon(TestCase):
 
     # Tests that the function produces the same result when called with
     # noncontiguous tensors.
+    @skipXPU
     @with_tf32_off
     @onlyNativeDeviceTypesAnd(["hpu"])
     @suppress_warnings
@@ -785,6 +796,7 @@ class TestCommon(TestCase):
     #   incorrectly sized out parameter warning properly yet
     # Cases test here:
     #   - out= with the correct dtype and device, but the wrong shape
+    @skipXPU
     @ops(ops_and_refs, dtypes=OpDTypes.none)
     def test_out_warning(self, device, op):
         if TEST_WITH_TORCHDYNAMO and op.name == "_refs.clamp":
@@ -923,6 +935,7 @@ class TestCommon(TestCase):
     # Case 3 and 4 are slightly different when the op is a factory function:
     #   - if device, dtype are NOT passed, any combination of dtype/device should be OK for out
     #   - if device, dtype are passed, device and dtype should match
+    @skipXPU
     @ops(ops_and_refs, dtypes=OpDTypes.any_one)
     def test_out(self, device, dtype, op):
         # Prefers running in float32 but has a fallback for the first listed supported dtype
@@ -1126,6 +1139,7 @@ class TestCommon(TestCase):
                 with self.assertRaises(exc_type, msg=msg_fail):
                     op_out(out=out)
 
+    @skipXPU
     @ops(
         [
             op
@@ -1164,6 +1178,7 @@ class TestCommon(TestCase):
             with self.assertRaises(RuntimeError, msg=msg), maybe_skip_size_asserts(op):
                 op(sample.input, *sample.args, **sample.kwargs, out=out)
 
+    @skipXPU
     @ops(filter(reduction_dtype_filter, ops_and_refs), dtypes=(torch.int16,))
     def test_out_integral_dtype(self, device, dtype, op):
         def helper(with_out, expectFail, op_to_test, inputs, *args, **kwargs):
@@ -1207,6 +1222,7 @@ class TestCommon(TestCase):
     # Tests that the forward and backward passes of operations produce the
     # same values for the cross-product of op variants (method, inplace)
     # against eager's gold standard op function variant
+    @skipXPU
     @_variant_ops(op_db)
     def test_variant_consistency_eager(self, device, dtype, op):
         # Acquires variants (method variant, inplace variant, operator variant, inplace_operator variant, aliases)
@@ -1387,6 +1403,7 @@ class TestCommon(TestCase):
 
     # Reference testing for operations in complex32 against complex64.
     # NOTE: We test against complex64 as NumPy doesn't have a complex32 equivalent dtype.
+    @skipXPU
     @ops(op_db, allowed_dtypes=(torch.complex32,))
     def test_complex_half_reference_testing(self, device, dtype, op):
         if not op.supports_dtype(torch.complex32, device):
@@ -1422,6 +1439,7 @@ class TestCommon(TestCase):
         # `cfloat` input -> `float` output
         self.assertEqual(actual, expected, exact_dtype=False)
 
+    @skipXPU
     @ops(op_db, allowed_dtypes=(torch.bool,))
     def test_non_standard_bool_values(self, device, dtype, op):
         # Test boolean values other than 0x00 and 0x01 (gh-54789)
@@ -1450,6 +1468,7 @@ class TestCommon(TestCase):
 
     # Validates that each OpInfo specifies its forward and backward dtypes
     # correctly for CPU and CUDA devices
+    @skipXPU
     @skipMeta
     @onlyNativeDeviceTypesAnd(["hpu"])
     @ops(ops_and_refs, dtypes=OpDTypes.none)
@@ -1656,6 +1675,7 @@ class TestCommon(TestCase):
             self.fail(msg)
 
     # Validates that each OpInfo that sets promotes_int_to_float=True does as it says
+    @skipXPU
     @skipMeta
     @onlyNativeDeviceTypesAnd(["hpu"])
     @ops(
@@ -2845,7 +2865,7 @@ class TestFakeTensor(TestCase):
             self.assertEqual(strided_result.layout, torch.strided)
 
 
-instantiate_device_type_tests(TestCommon, globals())
+instantiate_device_type_tests(TestCommon, globals(), allow_xpu=True)
 instantiate_device_type_tests(TestCompositeCompliance, globals())
 instantiate_device_type_tests(TestMathBits, globals())
 instantiate_device_type_tests(TestRefsOpsInfo, globals(), only_for="cpu")
diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py
index 43c7741c69aa..bfe9a5fb7aee 100644
--- a/torch/testing/_internal/common_device_type.py
+++ b/torch/testing/_internal/common_device_type.py
@@ -628,8 +628,17 @@ class XPUTestBase(DeviceTypeTestBase):
     @classmethod
     def get_all_devices(cls):
         # currently only one device is supported on MPS backend
+        primary_device_idx = int(cls.get_primary_device().split(":")[1])
+        num_devices = torch.xpu.device_count()
+
         prim_device = cls.get_primary_device()
-        return [prim_device]
+        xpu_str = "xpu:{0}"
+        non_primary_devices = [
+            xpu_str.format(idx)
+            for idx in range(num_devices)
+            if idx != primary_device_idx
+        ]
+        return [prim_device] + non_primary_devices
 
     @classmethod
     def setUpClass(cls):
@@ -1395,13 +1404,13 @@ class expectedFailure:
 
 
 class onlyOn:
-    def __init__(self, device_type):
+    def __init__(self, device_type: Union[str, list]):
         self.device_type = device_type
 
     def __call__(self, fn):
         @wraps(fn)
         def only_fn(slf, *args, **kwargs):
-            if self.device_type != slf.device_type:
+            if slf.device_type not in self.device_type:
                 reason = f"Only runs on {self.device_type}"
                 raise unittest.SkipTest(reason)
 
@@ -1960,6 +1969,10 @@ def skipHPU(fn):
     return skipHPUIf(True, "test doesn't work on HPU backend")(fn)
 
 
+def skipXPU(fn):
+    return skipXPUIf(True, "test doesn't work on XPU backend")(fn)
+
+
 def skipPRIVATEUSE1(fn):
     return skipPRIVATEUSE1If(True, "test doesn't work on privateuse1 backend")(fn)
 
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index f81104cbf4da..a30a9f1627d7 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -28,7 +28,7 @@ from torch.testing._internal.common_device_type import \
     (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma,
      skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfNoCusolver, skipCPUIfNoLapack, skipCPUIfNoFFT,
      skipCUDAIf, precisionOverride, skipCPUIfNoMklSparse,
-     toleranceOverride, tol)
+     toleranceOverride, tol, skipXPU)
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, SM53OrLater, SM80OrLater,
     SM89OrLater, with_tf32_off, TEST_CUDNN, _get_torch_cuda_version,
@@ -39,7 +39,7 @@ from torch.testing._internal.common_utils import (
     TEST_WITH_ROCM, IS_FBCODE, IS_WINDOWS, IS_MACOS, IS_S390X, TEST_SCIPY,
     torch_to_numpy_dtype_dict, numpy_to_torch_dtype, TEST_WITH_ASAN,
     GRADCHECK_NONDET_TOL, slowTest, TEST_WITH_SLOW,
-    TEST_WITH_TORCHINDUCTOR, MACOS_VERSION
+    TEST_WITH_TORCHINDUCTOR, MACOS_VERSION,
 )
 from torch.testing._utils import wrapper_set_seed
 
@@ -13646,7 +13646,8 @@ op_db: list[OpInfo] = [
                skipCUDAIf(not ((_get_torch_cuda_version() >= (11, 3))
                                or (_get_torch_rocm_version() >= (5, 2))),
                           "cusparseSDDMM was added in 11.2.1"),
-               skipCPUIfNoMklSparse, ],
+               skipCPUIfNoMklSparse,
+               skipXPU],
            skips=(
                # NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous
                DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
@@ -16373,7 +16374,7 @@ op_db: list[OpInfo] = [
            supports_out=True,
            supports_forward_ad=False,
            supports_autograd=False,
-           decorators=[skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')],
+           decorators=[skipXPU, skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')],
            skips=(
                # Sample inputs isn't really parametrized on dtype
                DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'),
@@ -16500,7 +16501,8 @@ op_db: list[OpInfo] = [
            # FIXME: mask_type == 2 (LowerRight)
            decorators=[
                skipCUDAIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "This platform doesn't support efficient attention"),
-               skipCUDAIf(TEST_WITH_ROCM, "Efficient attention on ROCM doesn't support custom_mask_type==2")],
+               skipCUDAIf(TEST_WITH_ROCM, "Efficient attention on ROCM doesn't support custom_mask_type==2"),
+               skipXPU],
            skips=(
                # Checking the scaler value of the philox seed and offset
                DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'),
@@ -20863,6 +20865,8 @@ op_db: list[OpInfo] = [
                # AssertionError: Tensor-likes are not close!
                # Fails in cuda11.7
                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu', device_type='cuda'),
+               # AssertionError: Tensor-likes are not close!
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu', device_type='xpu'),
                DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),),),
     # In training mode, feature_alpha_dropout currently doesn't support inputs of complex dtype
     # unlike when `train=False`, it supports complex inputs, hence 2 OpInfos to cover all cases
diff --git a/torch/testing/_internal/opinfo/definitions/linalg.py b/torch/testing/_internal/opinfo/definitions/linalg.py
index 9eeacf887084..23500bb3ad3a 100644
--- a/torch/testing/_internal/opinfo/definitions/linalg.py
+++ b/torch/testing/_internal/opinfo/definitions/linalg.py
@@ -41,6 +41,7 @@ from torch.testing._internal.common_utils import (
     skipIfSlowGradcheckEnv,
     slowTest,
     TEST_WITH_ROCM,
+    TEST_XPU,
 )
 from torch.testing._internal.opinfo.core import (
     clone_sample,
@@ -1766,7 +1767,12 @@ op_db: list[OpInfo] = [
         decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
         skips=(
             # linalg.lu_factor: LU without pivoting is not implemented on the CPU
-            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestCommon",
+                "test_compare_cpu",
+                active_if=(not TEST_XPU),
+            ),
         ),
     ),
     OpInfo(
@@ -1782,7 +1788,12 @@ op_db: list[OpInfo] = [
         decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
         skips=(
             # linalg.lu_factor: LU without pivoting is not implemented on the CPU
-            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestCommon",
+                "test_compare_cpu",
+                active_if=(not TEST_XPU),
+            ),
         ),
     ),
     OpInfo(
@@ -1799,7 +1810,12 @@ op_db: list[OpInfo] = [
         decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
         skips=(
             # linalg.lu_factor: LU without pivoting is not implemented on the CPU
-            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestCommon",
+                "test_compare_cpu",
+                active_if=(not TEST_XPU),
+            ),
         ),
     ),
     OpInfo(
diff --git a/torch/testing/_internal/opinfo/definitions/special.py b/torch/testing/_internal/opinfo/definitions/special.py
index 1418685e8832..d6dce75437d1 100644
--- a/torch/testing/_internal/opinfo/definitions/special.py
+++ b/torch/testing/_internal/opinfo/definitions/special.py
@@ -448,6 +448,10 @@ op_db: list[OpInfo] = [
             DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
             # Greatest absolute difference: inf
             DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+            # Too slow
+            DecorateInfo(
+                unittest.skip, "TestCommon", "test_compare_cpu", device_type="xpu"
+            ),
         ),
         supports_one_python_scalar=True,
         supports_autograd=False,
@@ -474,6 +478,10 @@ op_db: list[OpInfo] = [
             DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
             # Greatest absolute difference: nan
             DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+            # Too slow
+            DecorateInfo(
+                unittest.skip, "TestCommon", "test_compare_cpu", device_type="xpu"
+            ),
         ),
         supports_one_python_scalar=True,
         supports_autograd=False,
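
Usage sketch (illustrative, not part of the patch): the snippet below shows how the pieces this PR introduces, onlyOn accepting a device list, the new skipXPU decorator, and allow_xpu=True in instantiate_device_type_tests, are expected to combine in a device-generic test file. The test class name and test bodies are hypothetical.

# Hypothetical usage sketch (not part of the patch). Assumes a PyTorch source
# checkout so that torch.testing._internal is importable.
import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyOn,
    skipXPU,  # introduced by this patch
)
from torch.testing._internal.common_utils import run_tests, TestCase


class TestXpuUsageExample(TestCase):  # hypothetical test class
    # onlyOn now accepts a list, so a single test can target both CUDA and XPU.
    @onlyOn(["cuda", "xpu"])
    def test_runs_on_cuda_or_xpu(self, device):
        x = torch.ones(4, device=device)
        self.assertEqual((x + x).sum().item(), 8.0)

    # skipXPU unconditionally skips the test on XPU while other device types
    # still run, mirroring how TestCommon temporarily disables tests above.
    @skipXPU
    def test_not_yet_enabled_on_xpu(self, device):
        x = torch.arange(3, device=device)
        self.assertEqual(x[-1].item(), 2)


# allow_xpu=True generates an XPU variant of the class alongside CPU/CUDA ones,
# as this patch does for TestCommon.
instantiate_device_type_tests(TestXpuUsageExample, globals(), allow_xpu=True)

if __name__ == "__main__":
    run_tests()

The PYTORCH_TESTING_DEVICE_ONLY_FOR environment variable honored by common_device_type can then restrict a run of such a file to the generated XPU class only.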