From 3dabc351bb5581f69825eee6b24fbac9f9260241 Mon Sep 17 00:00:00 2001
From: "xinan.lin"
Date: Mon, 16 Jun 2025 14:30:20 -0700
Subject: [PATCH] [Break XPU] Fix XPU UT failures introduced by community.
 (#156091)

Fixes #15089, Fixes #156063, Fixes #155689, Fixes #155692, Fixes #156146

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156091
Approved by: https://github.com/jansel
---
 test/dynamo/test_precompile_context.py                     | 6 +++---
 test/inductor/test_torchinductor_codegen_dynamic_shapes.py | 1 -
 test/inductor/test_torchinductor_dynamic_shapes.py         | 1 -
 test/test_openreg.py                                       | 2 ++
 4 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/test/dynamo/test_precompile_context.py b/test/dynamo/test_precompile_context.py
index 34742259113d..cc826f0add37 100644
--- a/test/dynamo/test_precompile_context.py
+++ b/test/dynamo/test_precompile_context.py
@@ -11,7 +11,7 @@ from torch._functorch._aot_autograd.autograd_cache import (
     BundledAOTAutogradCacheEntry,
 )
 from torch._inductor.test_case import TestCase as InductorTestCase
-from torch.testing._internal.inductor_utils import requires_triton
+from torch.testing._internal.inductor_utils import GPU_TYPE, requires_triton


 @functorch_config.patch({"enable_autograd_cache": True})
@@ -39,7 +39,7 @@ class PrecompileContextTests(InductorTestCase):
         compiled_fn = torch.compile(simple_function)

         # Run the compiled function
-        x = torch.randn(10, device="cuda", requires_grad=True)
+        x = torch.randn(10, device=GPU_TYPE, requires_grad=True)
         result = compiled_fn(x)
         result.sum().backward()
         # Check that PrecompileContext._new_cache_artifacts_by_key has length 1
@@ -80,7 +80,7 @@ class PrecompileContextTests(InductorTestCase):
         compiled_fn = torch.compile(simple_function)

         # Run the compiled function
-        x = torch.randn(10, device="cuda", requires_grad=True)
+        x = torch.randn(10, device=GPU_TYPE, requires_grad=True)
         result = compiled_fn(x)
         result.sum().backward()
         # Check that PrecompileContext._new_cache_artifacts_by_key has length 1
diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
index 6e9dce76a2b8..b11183ed83ed 100644
--- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
@@ -242,7 +242,6 @@ test_failures = {
     "test_pointwise_laguerre_polynomial_l_dynamic_shapes": TestFailure(("cuda", "xpu")),
     "test_pointwise_legendre_polynomial_p_dynamic_shapes": TestFailure(("cuda", "xpu")),
     "test_polar_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True),
-    "test_randint_distribution_dynamic_shapes": TestFailure(("cuda", "xpu")),
     "test_randn_generator_dynamic_shapes": TestFailure(("cpu",)),
     "test_randn_like_empty_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_single_elem_dynamic_shapes": TestFailure(("cpu",)),
diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py
index f6e985596463..93e0026d2399 100644
--- a/test/inductor/test_torchinductor_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_dynamic_shapes.py
@@ -61,7 +61,6 @@ test_failures = {
     "test_AllenaiLongformerBase_repro_dynamic_shapes": TestFailure(
         ("cpu", "cuda", "xpu")
     ),
-    "test_randint_distribution_dynamic_shapes": TestFailure(("xpu",)),
 }
 if not torch._inductor.config.cpp_wrapper:
     test_failures["test_conv_inference_heuristics_dynamic_shapes"] = TestFailure(
diff --git a/test/test_openreg.py b/test/test_openreg.py
index 85948fd85b39..b8dbb18bd527 100644
--- a/test/test_openreg.py
+++ b/test/test_openreg.py
@@ -13,6 +13,7 @@ from torch.testing._internal.common_utils import (
     IS_LINUX,
     run_tests,
     skipIfTorchDynamo,
+    skipIfXpu,
     TestCase,
 )

@@ -365,6 +366,7 @@ class TestOpenReg(TestCase):
         self.assertEqual(y.to(device="cpu"), torch.tensor([[1, 1], [2, 2], [3, 3]]))
         self.assertEqual(x.data_ptr(), y.data_ptr())

+    @skipIfXpu(msg="missing kernel for openreg")
     def test_quantize(self):
         x = torch.randn(3, 4, 5, dtype=torch.float32, device="openreg")
         quantized_tensor = torch.quantize_per_tensor(x, 0.1, 10, torch.qint8)
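--
Note (reviewer sketch, not part of the patch): the core pattern above is
replacing hard-coded device="cuda" with GPU_TYPE from
torch.testing._internal.inductor_utils, so the same test body runs on
whichever GPU backend (CUDA or XPU) the build provides. A minimal
self-contained sketch of that pattern follows; the HAS_GPU availability
flag from the same module is an assumption here, not something this
patch touches:

    import torch
    from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU

    def run_device_agnostic_check() -> None:
        # Assumed guard: skip quietly when no CUDA/XPU device is present.
        if not HAS_GPU:
            return
        # GPU_TYPE resolves to "cuda" or "xpu" at runtime, so this
        # allocation exercises whichever backend is available.
        x = torch.randn(10, device=GPU_TYPE, requires_grad=True)
        result = torch.compile(torch.sin)(x)
        result.sum().backward()
        assert x.grad is not None

    if __name__ == "__main__":
        run_device_agnostic_check()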