Files
pytorch/test/inductor/test_extension_backend.py
Alexander Grund 9dded148d0 Fix test_extension_backend on non-AVX systems (#117272)
The test checks for a substring "loadu" in generated code. On AVX systems that line is:
> auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(i0))
However, on non-AVX systems it is:
> auto tmp0 = in_ptr0[static_cast<long>(i0)];

The difference depends on whether `codecache.valid_vec_isa_list()` is non-empty; see torch/_inductor/codegen/cpp.py:2639.

Modify the test to account for that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117272
Approved by: https://github.com/jgong5, https://github.com/jansel
2024-04-24 02:55:12 +00:00

168 lines
5.1 KiB
Python

# Owner(s): ["module: inductor"]
import os
import shutil
import sys
import unittest
import torch
import torch._dynamo
import torch.utils.cpp_extension
from torch._C import FileCheck
try:
from extension_backends.cpp.extension_codegen_backend import (
ExtensionCppWrapperCodegen,
ExtensionScheduling,
ExtensionWrapperCodegen,
)
except ImportError:
from .extension_backends.cpp.extension_codegen_backend import (
ExtensionCppWrapperCodegen,
ExtensionScheduling,
ExtensionWrapperCodegen,
)
import torch._inductor.config as config
from torch._inductor import codecache, metrics
from torch._inductor.codegen import cpp
from torch._inductor.codegen.common import (
get_scheduling_for_device,
get_wrapper_codegen_for_device,
register_backend_for_device,
)
from torch.testing._internal.common_utils import IS_FBCODE, IS_MACOS
try:
try:
from . import test_torchinductor
except ImportError:
import test_torchinductor
except unittest.SkipTest:
if __name__ == "__main__":
sys.exit(0)
raise
# Local aliases for the helper and base class taken from test_torchinductor,
# so the rest of this module can refer to them by their short names.
run_and_get_cpp_code = test_torchinductor.run_and_get_cpp_code
TestCase = test_torchinductor.TestCase
def remove_build_path():
    """Delete the default ``torch.utils.cpp_extension`` build root, if any.

    Nothing is removed on Windows (``sys.platform == "win32"``); there the
    extensions build folder is deliberately left in place. On other platforms
    the removal is best effort: errors raised while deleting are ignored.
    """
    if sys.platform != "win32":
        build_root = torch.utils.cpp_extension.get_default_build_root()
        if os.path.exists(build_root):
            shutil.rmtree(build_root, ignore_errors=True)
@unittest.skipIf(IS_FBCODE, "cpp_extension doesn't work in fbcode right now")
class ExtensionBackendTests(TestCase):
    """Check Inductor code generation for a C++ extension device backend.

    ``setUpClass`` builds ``extension_backends/cpp/extension_device.cpp`` with
    ``torch.utils.cpp_extension.load`` and stores the loaded module in
    ``module``. The test then registers that module as the
    ``extension_device`` backend and compiles a small function for it.
    """

    # Extension module loaded once per class by setUpClass.
    module = None

    @classmethod
    def setUpClass(cls):
        """Wipe the extension build root and build the ``extension_device`` module."""
        super().setUpClass()

        # Build Extension
        remove_build_path()
        source_file_path = os.path.dirname(os.path.abspath(__file__))
        source_file = os.path.join(
            source_file_path, "extension_backends/cpp/extension_device.cpp"
        )
        cls.module = torch.utils.cpp_extension.load(
            name="extension_device",
            sources=[
                str(source_file),
            ],
            extra_cflags=["-g"],
            verbose=True,
        )

    @classmethod
    def tearDownClass(cls):
        """Run the base-class teardown, then wipe the extension build root."""
        # NOTE(review): `_stack` is not defined in this class; presumably it is
        # provided by the base TestCase - confirm against test_torchinductor.
        cls._stack.close()
        super().tearDownClass()

        remove_build_path()

    def setUp(self):
        """Reset dynamo and run the test from this file's directory."""
        torch._dynamo.reset()
        super().setUp()

        # Check the precondition before touching the working directory so a
        # failure here cannot leave the process in the wrong directory.
        assert self.module is not None

        # cpp extensions use relative paths. Those paths are relative to
        # this file, so we'll change the working directory temporarily
        self.old_working_dir = os.getcwd()
        os.chdir(os.path.dirname(os.path.abspath(__file__)))
        # Restore via a cleanup rather than in tearDown: cleanups run even
        # when tearDown raises, so the directory change cannot leak into
        # later tests.
        self.addCleanup(os.chdir, self.old_working_dir)

    def tearDown(self):
        """Run the base-class teardown and reset dynamo.

        The working directory is restored by the cleanup registered in setUp.
        """
        super().tearDown()
        torch._dynamo.reset()

    def test_open_device_registration(self):
        """Register the extension backend and compile ``a * b + c`` for it.

        Verifies that the registered scheduling and wrapper-codegen classes
        are returned by the lookup helpers, that the generated C++ code
        contains the expected load expression followed by
        ``extension_device``, and that the compiled function matches the
        reference result, both with and without ``cpp_wrapper``.
        """
        torch.utils.rename_privateuse1_backend("extension_device")
        torch._register_device_module("extension_device", self.module)

        register_backend_for_device(
            "extension_device",
            ExtensionScheduling,
            ExtensionWrapperCodegen,
            ExtensionCppWrapperCodegen,
        )
        self.assertTrue(
            get_scheduling_for_device("extension_device") == ExtensionScheduling
        )
        self.assertTrue(
            get_wrapper_codegen_for_device("extension_device")
            == ExtensionWrapperCodegen
        )
        self.assertTrue(
            get_wrapper_codegen_for_device("extension_device", True)
            == ExtensionCppWrapperCodegen
        )

        # The custom op must not be reported as called until a tensor is
        # actually moved to / filled on the extension device.
        self.assertFalse(self.module.custom_op_called())
        device = self.module.custom_device()
        x = torch.empty(2, 16).to(device=device).fill_(1)
        self.assertTrue(self.module.custom_op_called())
        y = torch.empty(2, 16).to(device=device).fill_(2)
        z = torch.empty(2, 16).to(device=device).fill_(3)
        # 1 * 2 + 3 == 5 for every element.
        ref = torch.empty(2, 16).fill_(5)

        self.assertTrue(x.device == device)
        self.assertTrue(y.device == device)
        self.assertTrue(z.device == device)

        def fn(a, b, c):
            return a * b + c

        cpp.DEVICE_TO_ATEN["extension_device"] = "at::kPrivateUse1"

        # The load expression in the generated kernel depends on whether any
        # vector ISA is usable: a vectorized `loadu` when one is, a plain
        # indexed read otherwise. This does not depend on the cpp_wrapper
        # flag, so it is determined once up front.
        if codecache.valid_vec_isa_list():
            load_expr = "loadu"
        else:
            load_expr = " = in_ptr0[static_cast<long>(i0)];"

        for cpp_wrapper_flag in [True, False]:
            with config.patch({"cpp_wrapper": cpp_wrapper_flag}):
                metrics.reset()
                opt_fn = torch.compile()(fn)
                _, code = run_and_get_cpp_code(opt_fn, x, y, z)
                FileCheck().check("void").check(load_expr).check(
                    "extension_device"
                ).run(code)
                opt_fn(x, y, z)
                res = opt_fn(x, y, z)
                self.assertEqual(ref, res.to(device="cpu"))
if __name__ == "__main__":
    from torch._inductor.test_case import run_tests
    from torch.testing._internal.inductor_utils import HAS_CPU

    # cpp_extension doesn't work in fbcode right now; macOS is excluded too.
    platform_excluded = IS_MACOS or IS_FBCODE
    if HAS_CPU and not platform_excluded:
        run_tests(needs="filelock")