mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Fix XPU CI][Inductor UT] Fix test cases broken by community. (#162933)
Fixes #162937 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162933 Approved by: https://github.com/EikanWang, https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
f1eb99e2e4
commit
39450e7b00
@ -18,7 +18,7 @@ from torch.nativert.backends._lower_utils import (
|
||||
package_nativert_with_aoti_delegate,
|
||||
)
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS
|
||||
from torch.testing._internal.inductor_utils import HAS_GPU
|
||||
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
|
||||
from torch.utils import _pytree as pytree
|
||||
|
||||
|
||||
@ -243,7 +243,7 @@ class TestNativeRT(TestCase):
|
||||
|
||||
parameters = []
|
||||
for device in ["cpu", "cuda"]:
|
||||
if device == "cuda" and not HAS_GPU:
|
||||
if device == "cuda" and not HAS_CUDA_AND_TRITON:
|
||||
continue
|
||||
for module, sample_inputs in [
|
||||
(get_module.__func__().to(device), (torch.randn(4, 4).to(device),)),
|
||||
@ -353,4 +353,6 @@ del test
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
||||
# nativert has not been supported on XPU yet.
|
||||
if not torch.xpu.is_available():
|
||||
run_tests()
|
||||
|
||||
@ -7124,6 +7124,8 @@ class AOTInductorTestsTemplate:
|
||||
"libtorch_cuda.so",
|
||||
"libc10_cuda.so",
|
||||
"libtorch_cpu.so",
|
||||
"libtorch_xpu.so",
|
||||
"libc10_xpu.so",
|
||||
}
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
|
||||
@ -1603,15 +1603,16 @@ def get_cpp_torch_device_options(
|
||||
raise OSError(xpu_error_string)
|
||||
include_dirs += [os.path.join(ze_root, "include")]
|
||||
libraries_dirs += [os.path.join(ze_root, "lib")]
|
||||
libraries += ["c10_xpu", "sycl", "ze_loader", "torch_xpu"]
|
||||
else:
|
||||
# Suppress multi-line comment warnings in sycl headers
|
||||
cflags += ["Wno-comment"]
|
||||
libraries += ["c10_xpu", "sycl", "ze_loader", "torch_xpu"]
|
||||
|
||||
if not find_library("ze_loader"):
|
||||
raise OSError(xpu_error_string)
|
||||
|
||||
libraries += ["ze_loader", "sycl"]
|
||||
if link_libtorch:
|
||||
libraries += ["c10_xpu", "torch_xpu"]
|
||||
|
||||
if device_type == "mps":
|
||||
definitions.append(" USE_MPS")
|
||||
|
||||
|
||||
@ -1306,7 +1306,6 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.xpu_default_flex_config = {
|
||||
(torch.float32, 64): FlexConfig(128, 32, 1, 16),
|
||||
(torch.float32, 128): FlexConfig(128, 32, 1, 16),
|
||||
@ -2209,6 +2208,12 @@ class CPUMMPlusMMTemplateConfigHeuristic(
|
||||
class XPUMMTemplateConfigHeuristic(MMTemplateConfigMixin, XPUConfigHeuristic):
|
||||
"""Standard MM template heuristic for XPU"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
# TODO(etaf): Design proper exhaustive search space for XPU.
|
||||
self.exhaustive_configs = self.mm_configs
|
||||
|
||||
|
||||
@register_template_heuristic(mm_template.uid, "xpu", op_name="addmm")
|
||||
@register_template_heuristic(bmm_template.uid, "xpu", op_name="baddbmm")
|
||||
|
||||
Reference in New Issue
Block a user