[Fix XPU CI][Inductor UT] Fix test cases broken by community. (#162933)

Fixes #162937

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162933
Approved by: https://github.com/EikanWang, https://github.com/jansel
This commit is contained in:
xinan.lin
2025-09-17 05:35:06 +00:00
committed by PyTorch MergeBot
parent f1eb99e2e4
commit 39450e7b00
4 changed files with 17 additions and 7 deletions

View File

@@ -18,7 +18,7 @@ from torch.nativert.backends._lower_utils import (
package_nativert_with_aoti_delegate,
)
from torch.testing._internal.common_utils import IS_WINDOWS
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
from torch.utils import _pytree as pytree
@@ -243,7 +243,7 @@ class TestNativeRT(TestCase):
parameters = []
for device in ["cpu", "cuda"]:
if device == "cuda" and not HAS_GPU:
if device == "cuda" and not HAS_CUDA_AND_TRITON:
continue
for module, sample_inputs in [
(get_module.__func__().to(device), (torch.randn(4, 4).to(device),)),
@@ -353,4 +353,6 @@ del test
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()
# nativert has not been supported on XPU yet.
if not torch.xpu.is_available():
run_tests()

View File

@@ -7124,6 +7124,8 @@ class AOTInductorTestsTemplate:
"libtorch_cuda.so",
"libc10_cuda.so",
"libtorch_cpu.so",
"libtorch_xpu.so",
"libc10_xpu.so",
}
with tempfile.TemporaryDirectory() as tmpdir:

View File

@@ -1603,15 +1603,16 @@ def get_cpp_torch_device_options(
raise OSError(xpu_error_string)
include_dirs += [os.path.join(ze_root, "include")]
libraries_dirs += [os.path.join(ze_root, "lib")]
libraries += ["c10_xpu", "sycl", "ze_loader", "torch_xpu"]
else:
# Suppress multi-line comment warnings in sycl headers
cflags += ["Wno-comment"]
libraries += ["c10_xpu", "sycl", "ze_loader", "torch_xpu"]
if not find_library("ze_loader"):
raise OSError(xpu_error_string)
libraries += ["ze_loader", "sycl"]
if link_libtorch:
libraries += ["c10_xpu", "torch_xpu"]
if device_type == "mps":
definitions.append(" USE_MPS")

View File

@@ -1306,7 +1306,6 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
def __init__(self) -> None:
super().__init__()
self.xpu_default_flex_config = {
(torch.float32, 64): FlexConfig(128, 32, 1, 16),
(torch.float32, 128): FlexConfig(128, 32, 1, 16),
@@ -2209,6 +2208,12 @@ class CPUMMPlusMMTemplateConfigHeuristic(
class XPUMMTemplateConfigHeuristic(MMTemplateConfigMixin, XPUConfigHeuristic):
"""Standard MM template heuristic for XPU"""
def __init__(self) -> None:
super().__init__()
# TODO(etaf): Design proper exhaustive search space for XPU.
self.exhaustive_configs = self.mm_configs
@register_template_heuristic(mm_template.uid, "xpu", op_name="addmm")
@register_template_heuristic(bmm_template.uid, "xpu", op_name="baddbmm")