[Inductor UT][Windows][XPU] Fix Inductor UT on XPU Windows. (#146481)

This PR fixes all the Inductor UT failures for the XPU backend on Windows that we found on a local machine. (Due to resource constraints, we have not yet set up an online Windows CI pipeline.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146481
Approved by: https://github.com/jansel, https://github.com/EikanWang
ghstack dependencies: #147347
This commit is contained in:
xinan.lin
2025-02-21 10:02:52 +08:00
committed by PyTorch MergeBot
parent 2d433cf1ad
commit b11d5cd584
8 changed files with 28 additions and 20 deletions

View File

@ -11,6 +11,7 @@ from pathlib import Path
import torch import torch
from torch._inductor import config, test_operators from torch._inductor import config, test_operators
from torch._inductor.utils import fresh_inductor_cache from torch._inductor.utils import fresh_inductor_cache
from torch.testing._internal.common_utils import skipIfWindows
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
from torch.testing._internal.logging_utils import multiple_logs_to_string from torch.testing._internal.logging_utils import multiple_logs_to_string
@ -230,6 +231,8 @@ op2.node.kernel = extern_kernels.mm""",
# intentionally only cleanup on success so debugging test is easier # intentionally only cleanup on success so debugging test is easier
shutil.rmtree(filename) shutil.rmtree(filename)
# AOT compiler have not supported windows yet.
@skipIfWindows
def test_debug_printer_const(self): def test_debug_printer_const(self):
"""Test that having a const example_input does not break the debug printer.""" """Test that having a const example_input does not break the debug printer."""

View File

@ -7,6 +7,7 @@ import torch
import torch._dynamo import torch._dynamo
import torch.utils.cpp_extension import torch.utils.cpp_extension
from torch._C import FileCheck from torch._C import FileCheck
from torch.testing._internal.common_utils import skipIfWindows
try: try:
@ -90,9 +91,9 @@ class BaseExtensionBackendTests(TestCase):
torch.testing._internal.common_utils.remove_cpp_extensions_build_root() torch.testing._internal.common_utils.remove_cpp_extensions_build_root()
cls.lock.release()
if os.path.exists(cls.lock_file): if os.path.exists(cls.lock_file):
os.remove(cls.lock_file) os.remove(cls.lock_file)
cls.lock.release()
def setUp(self): def setUp(self):
torch._dynamo.reset() torch._dynamo.reset()
@ -114,6 +115,7 @@ class BaseExtensionBackendTests(TestCase):
@unittest.skipIf(IS_FBCODE, "cpp_extension doesn't work in fbcode right now") @unittest.skipIf(IS_FBCODE, "cpp_extension doesn't work in fbcode right now")
class ExtensionBackendTests(BaseExtensionBackendTests): class ExtensionBackendTests(BaseExtensionBackendTests):
@skipIfWindows
def test_open_device_registration(self): def test_open_device_registration(self):
torch.utils.rename_privateuse1_backend("extension_device") torch.utils.rename_privateuse1_backend("extension_device")
torch._register_device_module("extension_device", self.module) torch._register_device_module("extension_device", self.module)

View File

@ -16,6 +16,7 @@ from torch._inductor.utils import run_and_get_triton_code
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
instantiate_parametrized_tests, instantiate_parametrized_tests,
IS_MACOS, IS_MACOS,
IS_WINDOWS,
parametrize, parametrize,
) )
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
@ -30,7 +31,7 @@ from torch.utils._sympy.functions import (
# int64_t is long long on MacOS, but long on 64-bit Linux # int64_t is long long on MacOS, but long on 64-bit Linux
LONG_SUFFIX = "LL" if IS_MACOS else "L" LONG_SUFFIX = "LL" if IS_MACOS or IS_WINDOWS else "L"
DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1" DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"

View File

@ -4136,9 +4136,9 @@ class CommonTemplate:
x = torch.sum(x.view(int(x.shape[0] / 6), 6), dim=1) x = torch.sum(x.view(int(x.shape[0] / 6), 6), dim=1)
return torch.gather(x, 0, torch.trunc(y).to(torch.int64)) return torch.gather(x, 0, torch.trunc(y).to(torch.int64))
x1 = torch.randn(30) x1 = torch.randn(30, device=self.device)
x2 = torch.randn(36) x2 = torch.randn(36, device=self.device)
y = torch.ones(1, dtype=torch.float64) y = torch.ones(1, dtype=torch.float64, device=self.device)
self.assertEqual(torch.compile(fn)(x1, y), fn(x1, y)) self.assertEqual(torch.compile(fn)(x1, y), fn(x1, y))
self.assertEqual(torch.compile(fn)(x2, y), fn(x2, y)) self.assertEqual(torch.compile(fn)(x2, y), fn(x2, y))

View File

@ -22,6 +22,7 @@ from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
parametrize, parametrize,
skipIfRocm, skipIfRocm,
skipIfWindows,
skipIfXpu, skipIfXpu,
TEST_WITH_ROCM, TEST_WITH_ROCM,
) )
@ -3362,6 +3363,7 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
gm = make_fx(f, tracing_mode=tracing_mode)(x, x) gm = make_fx(f, tracing_mode=tracing_mode)(x, x)
self.assertEqual(gm(x, x), x + x) self.assertEqual(gm(x, x), x + x)
@skipIfWindows(msg="AOTI/Cpp_Wrapper have not enabled on Windows")
@requires_gpu @requires_gpu
@patch.object(torch._inductor.config, "cpp_wrapper", True) @patch.object(torch._inductor.config, "cpp_wrapper", True)
@patch.object(torch._inductor.config, "triton.autotune_at_compile_time", True) @patch.object(torch._inductor.config, "triton.autotune_at_compile_time", True)

View File

@ -2,20 +2,10 @@
import importlib import importlib
import os import os
import sys import sys
import unittest
import torch import torch
from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS
if IS_WINDOWS and IS_CI:
sys.stderr.write(
"Windows CI does not have necessary dependencies for test_xpu_basic yet\n"
)
if __name__ == "__main__":
sys.exit(0)
raise unittest.SkipTest("requires sympy/functorch/filelock")
importlib.import_module("filelock") importlib.import_module("filelock")
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

View File

@ -172,7 +172,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
from torch._inductor.codecache import CppWrapperCodeCache from torch._inductor.codecache import CppWrapperCodeCache
cpp_wrapper_src = ( cpp_wrapper_src = (
''' r'''
""" """
) )

View File

@ -1,15 +1,17 @@
import dataclasses import dataclasses
import logging import logging
import os import os
import shutil
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from torch._dynamo.utils import counters, dynamo_timed, set_feature_use from torch._dynamo.utils import counters, dynamo_timed, set_feature_use
from torch._utils_internal import justknobs_check from torch._utils_internal import justknobs_check
from torch.utils._filelock import FileLock
from .runtime.runtime_utils import triton_cache_dir from .runtime.runtime_utils import triton_cache_dir
from .utils import GPU_KERNEL_BIN_EXTS from .utils import _IS_WINDOWS, GPU_KERNEL_BIN_EXTS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -238,7 +240,7 @@ class TritonBundler:
) )
continue continue
Path(directory).mkdir(parents=True, exist_ok=True) Path(basedir).mkdir(parents=True, exist_ok=True)
# Random ID to avoid any collisions # Random ID to avoid any collisions
rnd_id = str(uuid.uuid4()) rnd_id = str(uuid.uuid4())
@ -260,6 +262,14 @@ class TritonBundler:
# Each kernel has bunch of files like .cubin(for cuda), spv(for xpu), .json, .ttir # Each kernel has bunch of files like .cubin(for cuda), spv(for xpu), .json, .ttir
# Just append one of them without the extension # Just append one of them without the extension
kernel_names.append(Path(artifact.filename).stem) kernel_names.append(Path(artifact.filename).stem)
# Atomic on POSIX systems
os.replace(tmp_dir, directory) if _IS_WINDOWS:
with FileLock(directory + ".lock"):
if os.path.exists(directory):
shutil.rmtree(directory)
os.replace(tmp_dir, directory)
else:
# Atomic on POSIX systems
os.replace(tmp_dir, directory)
return TritonBundlerMetadata(kernel_names) return TritonBundlerMetadata(kernel_names)