[Inductor UT][Windows][XPU] Fix Inductor UT on XPU Windows. (#146481)

This PR fixes all the Inductor UT failures for the XPU backend on Windows that we found on a local machine. (Due to resource constraints, we have not yet set up an online Windows CI pipeline.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146481
Approved by: https://github.com/jansel, https://github.com/EikanWang
ghstack dependencies: #147347
xinan.lin
2025-02-21 10:02:52 +08:00
committed by PyTorch MergeBot
parent 2d433cf1ad
commit b11d5cd584
8 changed files with 28 additions and 20 deletions

View File

@@ -11,6 +11,7 @@ from pathlib import Path
 import torch
 from torch._inductor import config, test_operators
 from torch._inductor.utils import fresh_inductor_cache
+from torch.testing._internal.common_utils import skipIfWindows
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
 from torch.testing._internal.logging_utils import multiple_logs_to_string
@@ -230,6 +231,8 @@ op2.node.kernel = extern_kernels.mm""",
         # intentionally only cleanup on success so debugging test is easier
         shutil.rmtree(filename)
 
+    # The AOT compiler does not support Windows yet.
+    @skipIfWindows
     def test_debug_printer_const(self):
         """Test that having a const example_input does not break the debug printer."""

View File

@@ -7,6 +7,7 @@ import torch
 import torch._dynamo
 import torch.utils.cpp_extension
 from torch._C import FileCheck
+from torch.testing._internal.common_utils import skipIfWindows
 
 try:
@@ -90,9 +91,9 @@ class BaseExtensionBackendTests(TestCase):
         torch.testing._internal.common_utils.remove_cpp_extensions_build_root()
 
-        cls.lock.release()
         if os.path.exists(cls.lock_file):
             os.remove(cls.lock_file)
+        cls.lock.release()
 
     def setUp(self):
         torch._dynamo.reset()
@@ -114,6 +115,7 @@ class BaseExtensionBackendTests(TestCase):
 
 @unittest.skipIf(IS_FBCODE, "cpp_extension doesn't work in fbcode right now")
 class ExtensionBackendTests(BaseExtensionBackendTests):
+    @skipIfWindows
     def test_open_device_registration(self):
         torch.utils.rename_privateuse1_backend("extension_device")
         torch._register_device_module("extension_device", self.module)

View File

@@ -16,6 +16,7 @@ from torch._inductor.utils import run_and_get_triton_code
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     IS_MACOS,
+    IS_WINDOWS,
     parametrize,
 )
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
@@ -30,7 +31,7 @@ from torch.utils._sympy.functions import (
 
 # int64_t is long long on MacOS, but long on 64-bit Linux
-LONG_SUFFIX = "LL" if IS_MACOS else "L"
+LONG_SUFFIX = "LL" if IS_MACOS or IS_WINDOWS else "L"
 
 DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"
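Why Windows needs the same suffix as macOS: 64-bit Windows uses the LLP64 data model, where C `long` stays 32 bits, so an `int64_t` literal must carry the `long long` suffix there too (macOS defines `int64_t` as `long long` even though its `long` is 64-bit). A quick way to observe the width difference, as a sketch:

```python
import ctypes

# LP64 (64-bit Linux/macOS): C long is 8 bytes.
# LLP64 (64-bit Windows): C long is only 4 bytes, so a literal suffixed
# "L" would be the wrong width for int64_t; "LL" is required.
print(ctypes.sizeof(ctypes.c_long))
```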

View File

@@ -4136,9 +4136,9 @@ class CommonTemplate:
             x = torch.sum(x.view(int(x.shape[0] / 6), 6), dim=1)
             return torch.gather(x, 0, torch.trunc(y).to(torch.int64))
 
-        x1 = torch.randn(30)
-        x2 = torch.randn(36)
-        y = torch.ones(1, dtype=torch.float64)
+        x1 = torch.randn(30, device=self.device)
+        x2 = torch.randn(36, device=self.device)
+        y = torch.ones(1, dtype=torch.float64, device=self.device)
         self.assertEqual(torch.compile(fn)(x1, y), fn(x1, y))
         self.assertEqual(torch.compile(fn)(x2, y), fn(x2, y))

View File

@@ -22,6 +22,7 @@ from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
     parametrize,
     skipIfRocm,
+    skipIfWindows,
     skipIfXpu,
     TEST_WITH_ROCM,
 )
@@ -3362,6 +3363,7 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
         gm = make_fx(f, tracing_mode=tracing_mode)(x, x)
         self.assertEqual(gm(x, x), x + x)
 
+    @skipIfWindows(msg="AOTI/Cpp_Wrapper has not been enabled on Windows")
     @requires_gpu
     @patch.object(torch._inductor.config, "cpp_wrapper", True)
     @patch.object(torch._inductor.config, "triton.autotune_at_compile_time", True)

View File

@@ -2,20 +2,10 @@
 import importlib
 import os
 import sys
-import unittest
 
 import torch
 
-from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS
-
-if IS_WINDOWS and IS_CI:
-    sys.stderr.write(
-        "Windows CI does not have necessary dependencies for test_xpu_basic yet\n"
-    )
-    if __name__ == "__main__":
-        sys.exit(0)
-    raise unittest.SkipTest("requires sympy/functorch/filelock")
 
 importlib.import_module("filelock")
 
 pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

View File

@@ -172,7 +172,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
         from torch._inductor.codecache import CppWrapperCodeCache
 
         cpp_wrapper_src = (
-        '''
+        r'''
 """
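The wrapper source here is embedded inside a generated Python string, and once Windows build paths (with backslashes) can appear in it, a plain `'''` literal would interpret sequences like `\t` or `\n` as escape characters when the generated module is evaluated. The `r'''` prefix keeps every backslash verbatim. A minimal illustration (the path is made up):

```python
plain = 'C:\temp\new\a.h'   # \t, \n, \a are interpreted as tab/newline/bell
raw = r'C:\temp\new\a.h'    # raw string: backslashes kept verbatim

print(plain)  # garbled: the path is corrupted by escape processing
print(raw)    # C:\temp\new\a.h
```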

View File

@@ -1,15 +1,17 @@
 import dataclasses
 import logging
 import os
+import shutil
 import uuid
 from pathlib import Path
 from typing import Optional
 
 from torch._dynamo.utils import counters, dynamo_timed, set_feature_use
 from torch._utils_internal import justknobs_check
+from torch.utils._filelock import FileLock
 
 from .runtime.runtime_utils import triton_cache_dir
-from .utils import GPU_KERNEL_BIN_EXTS
+from .utils import _IS_WINDOWS, GPU_KERNEL_BIN_EXTS
 
 
 log = logging.getLogger(__name__)
@@ -238,7 +240,7 @@ class TritonBundler:
                 )
                 continue
 
-            Path(directory).mkdir(parents=True, exist_ok=True)
+            Path(basedir).mkdir(parents=True, exist_ok=True)
 
             # Random ID to avoid any collisions
             rnd_id = str(uuid.uuid4())
@@ -260,6 +262,14 @@ class TritonBundler:
                 # Each kernel has bunch of files like .cubin(for cuda), spv(for xpu), .json, .ttir
                 # Just append one of them without the extension
                 kernel_names.append(Path(artifact.filename).stem)
 
-            # Atomic on POSIX systems
-            os.replace(tmp_dir, directory)
+            if _IS_WINDOWS:
+                with FileLock(directory + ".lock"):
+                    if os.path.exists(directory):
+                        shutil.rmtree(directory)
+                    os.replace(tmp_dir, directory)
+            else:
+                # Atomic on POSIX systems
+                os.replace(tmp_dir, directory)
 
         return TritonBundlerMetadata(kernel_names)
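For context on the branch above: `os.replace` maps to an atomic `rename(2)` on POSIX, but on Windows it cannot replace an existing directory, so the bundle is published by deleting the stale destination first, serialized by a sidecar file lock so concurrent writers don't collide. A self-contained sketch of the same stage-then-swap pattern (the function name and the use of the `filelock` package are illustrative assumptions; the PR itself uses `torch.utils._filelock.FileLock`):

```python
import os
import shutil
import sys

from filelock import FileLock  # pip install filelock

def publish_dir(tmp_dir: str, directory: str) -> None:
    """Move a fully-built temp directory into its final location."""
    if sys.platform == "win32":
        # Windows: rename cannot replace an existing directory, so remove
        # the old one first; the lock serializes concurrent publishers.
        with FileLock(directory + ".lock"):
            if os.path.exists(directory):
                shutil.rmtree(directory)
            os.replace(tmp_dir, directory)
    else:
        # POSIX: rename(2) swaps the directory in atomically (the
        # destination must not already exist as a non-empty directory).
        os.replace(tmp_dir, directory)
```

Note that the `.lock` sidecar file is left behind after publishing; that is the usual cost of file-based locking, and it keeps readers from observing a half-deleted directory during the swap.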