[Inductor UT][Windows][XPU] Fix Inductor UT on XPU Windows. (#146481)
This PR fixes all the Inductor UT failures for the XPU backend on Windows that we found on a local machine. (Due to resource constraints, we have not yet set up an online Windows CI pipeline.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146481
Approved by: https://github.com/jansel, https://github.com/EikanWang
ghstack dependencies: #147347
Committed by: PyTorch MergeBot
Parent: 2d433cf1ad
Commit: b11d5cd584
@@ -11,6 +11,7 @@ from pathlib import Path
 import torch
 from torch._inductor import config, test_operators
 from torch._inductor.utils import fresh_inductor_cache
+from torch.testing._internal.common_utils import skipIfWindows
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
 from torch.testing._internal.logging_utils import multiple_logs_to_string
 
@@ -230,6 +231,8 @@ op2.node.kernel = extern_kernels.mm""",
         # intentionally only cleanup on success so debugging test is easier
         shutil.rmtree(filename)
 
+    # The AOT compiler does not support Windows yet.
+    @skipIfWindows
     def test_debug_printer_const(self):
         """Test that having a const example_input does not break the debug printer."""
 
@@ -7,6 +7,7 @@ import torch
 import torch._dynamo
 import torch.utils.cpp_extension
 from torch._C import FileCheck
+from torch.testing._internal.common_utils import skipIfWindows
 
 
 try:
@@ -90,9 +91,9 @@ class BaseExtensionBackendTests(TestCase):
 
         torch.testing._internal.common_utils.remove_cpp_extensions_build_root()
 
+        cls.lock.release()
         if os.path.exists(cls.lock_file):
             os.remove(cls.lock_file)
-        cls.lock.release()
 
     def setUp(self):
         torch._dynamo.reset()
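The reordering above is itself a Windows fix: FileLock keeps its lock file open while the lock is held, and Windows cannot delete an open file, so the lock must be released before the file is removed. A minimal sketch of the safe teardown order, with the lock-file path invented for illustration:

import os

from torch.utils._filelock import FileLock

lock_file = "extension_build.lock"  # hypothetical path
lock = FileLock(lock_file)
lock.acquire()
# ... do the work that the lock protects ...
lock.release()  # release first: Windows refuses to delete an open file
if os.path.exists(lock_file):
    os.remove(lock_file)  # now safe on POSIX and Windows alike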
@@ -114,6 +115,7 @@ class BaseExtensionBackendTests(TestCase):
 
 @unittest.skipIf(IS_FBCODE, "cpp_extension doesn't work in fbcode right now")
 class ExtensionBackendTests(BaseExtensionBackendTests):
+    @skipIfWindows
     def test_open_device_registration(self):
         torch.utils.rename_privateuse1_backend("extension_device")
         torch._register_device_module("extension_device", self.module)
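This PR uses the skipIfWindows decorator in both of its call forms, bare and with an explicit message; a minimal sketch (the test class and names are placeholders):

from torch.testing._internal.common_utils import skipIfWindows, TestCase

class ExampleTests(TestCase):
    @skipIfWindows  # bare form: skips with the default message
    def test_bare(self):
        pass

    @skipIfWindows(msg="reason shown in the skip report")  # keyword form
    def test_with_message(self):
        pass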
@@ -16,6 +16,7 @@ from torch._inductor.utils import run_and_get_triton_code
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     IS_MACOS,
+    IS_WINDOWS,
     parametrize,
 )
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
@@ -30,7 +31,7 @@ from torch.utils._sympy.functions import (
 
 
 # int64_t is long long on MacOS, but long on 64-bit Linux
-LONG_SUFFIX = "LL" if IS_MACOS else "L"
+LONG_SUFFIX = "LL" if IS_MACOS or IS_WINDOWS else "L"
 DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"
 
 
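The suffix change above is about data models, not a test bug: 64-bit Windows is LLP64 (long is 32-bit), so int64_t maps to long long there just as on macOS, while LP64 Linux maps it to plain long. A small illustration of the literal a C++ code generator would have to emit, assuming these platform strings:

import sys

def int64_literal(value: int) -> str:
    # LP64 Linux: int64_t is long, suffix "L".
    # macOS and LLP64 Windows: int64_t is long long, suffix "LL".
    suffix = "LL" if sys.platform in ("darwin", "win32") else "L"
    return f"{value}{suffix}"

assert int64_literal(6) in ("6L", "6LL")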
@@ -4136,9 +4136,9 @@ class CommonTemplate:
             x = torch.sum(x.view(int(x.shape[0] / 6), 6), dim=1)
             return torch.gather(x, 0, torch.trunc(y).to(torch.int64))
 
-        x1 = torch.randn(30)
-        x2 = torch.randn(36)
-        y = torch.ones(1, dtype=torch.float64)
+        x1 = torch.randn(30, device=self.device)
+        x2 = torch.randn(36, device=self.device)
+        y = torch.ones(1, dtype=torch.float64, device=self.device)
 
         self.assertEqual(torch.compile(fn)(x1, y), fn(x1, y))
         self.assertEqual(torch.compile(fn)(x2, y), fn(x2, y))
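Creating the inputs on self.device is what lets this CommonTemplate test actually exercise the XPU (or any GPU) backend instead of silently staying on CPU. A standalone sketch of the same pattern; the device probe here is an assumption for illustration, not the test harness's mechanism:

import torch

# Fall back to CPU when no XPU device is present.
device = "xpu" if torch.xpu.is_available() else "cpu"

def fn(x, y):
    x = torch.sum(x.view(int(x.shape[0] / 6), 6), dim=1)
    return torch.gather(x, 0, torch.trunc(y).to(torch.int64))

x1 = torch.randn(30, device=device)  # inputs live on the target device
y = torch.ones(1, dtype=torch.float64, device=device)
torch.testing.assert_close(torch.compile(fn)(x1, y), fn(x1, y))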
@@ -22,6 +22,7 @@ from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
     parametrize,
     skipIfRocm,
+    skipIfWindows,
     skipIfXpu,
     TEST_WITH_ROCM,
 )
@@ -3362,6 +3363,7 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
             gm = make_fx(f, tracing_mode=tracing_mode)(x, x)
             self.assertEqual(gm(x, x), x + x)
 
+    @skipIfWindows(msg="AOTI/Cpp_Wrapper has not been enabled on Windows")
     @requires_gpu
     @patch.object(torch._inductor.config, "cpp_wrapper", True)
     @patch.object(torch._inductor.config, "triton.autotune_at_compile_time", True)
@@ -2,20 +2,10 @@
 import importlib
 import os
 import sys
-import unittest
 
 import torch
-from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS
 
 
-if IS_WINDOWS and IS_CI:
-    sys.stderr.write(
-        "Windows CI does not have necessary dependencies for test_xpu_basic yet\n"
-    )
-    if __name__ == "__main__":
-        sys.exit(0)
-    raise unittest.SkipTest("requires sympy/functorch/filelock")
-
 importlib.import_module("filelock")
 
 pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@@ -172,7 +172,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
         from torch._inductor.codecache import CppWrapperCodeCache
 
         cpp_wrapper_src = (
-            '''
+            r'''
 """
         )
 
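The added r prefix is plausibly a Windows-path fix: the wrapper template ends up containing backslashed paths and escape-like sequences, and a raw literal passes them through to the generated source unmodified. A tiny illustration of the difference, unrelated to the actual template contents:

# Without the r prefix, Python interprets backslash escapes itself.
cooked = '''printf("a\\tb\\n");'''   # backslashes must be doubled
raw = r'''printf("a\tb\n");'''       # written exactly as C++ expects
assert cooked == raw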
@@ -1,15 +1,17 @@
 import dataclasses
 import logging
 import os
+import shutil
 import uuid
 from pathlib import Path
 from typing import Optional
 
 from torch._dynamo.utils import counters, dynamo_timed, set_feature_use
 from torch._utils_internal import justknobs_check
+from torch.utils._filelock import FileLock
 
 from .runtime.runtime_utils import triton_cache_dir
-from .utils import GPU_KERNEL_BIN_EXTS
+from .utils import _IS_WINDOWS, GPU_KERNEL_BIN_EXTS
 
 
 log = logging.getLogger(__name__)
@@ -238,7 +240,7 @@ class TritonBundler:
                 )
                 continue
 
-            Path(directory).mkdir(parents=True, exist_ok=True)
+            Path(basedir).mkdir(parents=True, exist_ok=True)
 
             # Random ID to avoid any collisions
             rnd_id = str(uuid.uuid4())
@@ -260,6 +262,14 @@ class TritonBundler:
                 # Each kernel has bunch of files like .cubin(for cuda), spv(for xpu), .json, .ttir
                 # Just append one of them without the extension
                 kernel_names.append(Path(artifact.filename).stem)
-            # Atomic on POSIX systems
-            os.replace(tmp_dir, directory)
+            if _IS_WINDOWS:
+                with FileLock(directory + ".lock"):
+                    if os.path.exists(directory):
+                        shutil.rmtree(directory)
+                    os.replace(tmp_dir, directory)
+            else:
+                # Atomic on POSIX systems
+                os.replace(tmp_dir, directory)
 
         return TritonBundlerMetadata(kernel_names)
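The branch added above works around a rename difference between platforms: os.replace is an atomic rename on POSIX, but on Windows it raises when the destination directory already exists, so the bundle is published under a file lock after clearing the stale directory. A standalone sketch of the same pattern, assuming both paths are on the same filesystem:

import os
import shutil
import sys

from torch.utils._filelock import FileLock

def publish_dir(tmp_dir: str, directory: str) -> None:
    if sys.platform == "win32":
        # Windows: cannot rename over an existing directory, so take a
        # lock, remove the old contents, then move the new ones in.
        with FileLock(directory + ".lock"):
            if os.path.exists(directory):
                shutil.rmtree(directory)
            os.replace(tmp_dir, directory)
    else:
        # POSIX: rename over the destination is atomic.
        os.replace(tmp_dir, directory)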