Files
pytorch/test/inductor/test_extension_backend.py
Jez Ng c254901bdb Have Triton custom extension test use privateuseone device (#137611)
The original PR #122396 used the CPU device since at that point in time
there was no actual Triton CPU backend. After #133408, this is no longer
the case, so we now have multiple backends getting registered for the
CPU. The test still works in OSS but fails internally due to different
test runners initializing the backends in a different order.

This PR doesn't actually end up fixing the test internally because
cpp_extension -- needed to implement the privateuseone device -- isn't
supported there, so we skip the test internally instead. However, it still
makes the OSS test independent of backend initialization order.

Differential Revision: [D63838169](https://our.internmc.facebook.com/intern/diff/D63838169/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137611
Approved by: https://github.com/henrylhtsang
2024-10-11 21:27:29 +00:00

162 lines
5.2 KiB
Python

# Owner(s): ["module: inductor"]
import os
import sys
import unittest
import torch
import torch._dynamo
import torch.utils.cpp_extension
from torch._C import FileCheck
try:
from extension_backends.cpp.extension_codegen_backend import ( # @manual=fbcode//caffe2/test/inductor/extension_backends:extension_codegen_backend # noqa: B950
ExtensionCppWrapperCodegen,
ExtensionScheduling,
ExtensionWrapperCodegen,
)
except ImportError:
from .extension_backends.cpp.extension_codegen_backend import (
ExtensionCppWrapperCodegen,
ExtensionScheduling,
ExtensionWrapperCodegen,
)
import torch._inductor.config as config
from torch._inductor import cpu_vec_isa, metrics
from torch._inductor.codegen import cpp_utils
from torch._inductor.codegen.common import (
get_scheduling_for_device,
get_wrapper_codegen_for_device,
register_backend_for_device,
)
from torch.testing._internal.common_utils import IS_FBCODE, IS_MACOS
# Import the shared inductor test module: try the package-relative import
# first and fall back to a top-level import if that raises ImportError.
# If importing it raises unittest.SkipTest, exit with status 0 when this file
# is executed as a script; otherwise re-raise the SkipTest.
try:
    try:
        from . import test_torchinductor
    except ImportError:
        import test_torchinductor  # @manual=fbcode//caffe2/test/inductor:test_inductor-library
except unittest.SkipTest:
    if __name__ == "__main__":
        sys.exit(0)
    raise

# Reuse the code-capture helper and the TestCase base class from
# test_torchinductor.
run_and_get_cpp_code = test_torchinductor.run_and_get_cpp_code
TestCase = test_torchinductor.TestCase
class BaseExtensionBackendTests(TestCase):
    """Shared fixture that builds and loads the ``extension_device`` C++ extension.

    ``setUpClass`` compiles ``extension_backends/cpp/extension_device.cpp``
    (located next to this file) with ``torch.utils.cpp_extension.load`` and
    stores the loaded module on ``cls.module``. Each test resets dynamo state
    before and after running, and runs with the working directory set to this
    file's directory.
    """

    # Loaded C++ extension module; populated by setUpClass.
    module = None

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Clear the cpp_extension build root before building the extension.
        torch.testing._internal.common_utils.remove_cpp_extensions_build_root()
        source_dir = os.path.dirname(os.path.abspath(__file__))
        source_file = os.path.join(
            source_dir, "extension_backends/cpp/extension_device.cpp"
        )
        cls.module = torch.utils.cpp_extension.load(
            name="extension_device",
            sources=[source_file],
            extra_cflags=["-g"],
            verbose=True,
        )

    @classmethod
    def tearDownClass(cls):
        # NOTE(review): ``_stack`` is not created in this class; presumably the
        # base TestCase's setUpClass creates it — confirm, and check whether
        # the base tearDownClass already closes it.
        cls._stack.close()
        try:
            super().tearDownClass()
        finally:
            # Remove the extension build output even if the base teardown fails.
            torch.testing._internal.common_utils.remove_cpp_extensions_build_root()

    def setUp(self):
        torch._dynamo.reset()
        super().setUp()
        # Validate before touching the cwd: when setUp raises, unittest does not
        # call tearDown, so failing after os.chdir would leave every later test
        # running in this file's directory.
        self.assertIsNotNone(self.module)
        # cpp extensions use relative paths. Those paths are relative to
        # this file, so we'll change the working directory temporarily
        self.old_working_dir = os.getcwd()
        os.chdir(os.path.dirname(os.path.abspath(__file__)))

    def tearDown(self):
        try:
            super().tearDown()
            torch._dynamo.reset()
        finally:
            # Always return to the working directory saved in setUp, even if
            # the base tearDown or the dynamo reset raises.
            os.chdir(self.old_working_dir)
@unittest.skipIf(IS_FBCODE, "cpp_extension doesn't work in fbcode right now")
class ExtensionBackendTests(BaseExtensionBackendTests):
    """End-to-end test of registering an out-of-tree Inductor device backend.

    Renames the PrivateUse1 backend to ``extension_device``, registers the
    extension's scheduling and wrapper-codegen classes for it, then compiles a
    small pointwise function with ``cpp_wrapper`` on and off. For each mode it
    checks the generated code and the numerical result.
    """

    def test_open_device_registration(self):
        torch.utils.rename_privateuse1_backend("extension_device")
        torch._register_device_module("extension_device", self.module)
        register_backend_for_device(
            "extension_device",
            ExtensionScheduling,
            ExtensionWrapperCodegen,
            ExtensionCppWrapperCodegen,
        )
        # assertIs/assertEqual report the mismatching values on failure;
        # assertTrue(a == b) only reports "False is not true".
        self.assertIs(
            get_scheduling_for_device("extension_device"), ExtensionScheduling
        )
        self.assertIs(
            get_wrapper_codegen_for_device("extension_device"),
            ExtensionWrapperCodegen,
        )
        self.assertIs(
            get_wrapper_codegen_for_device("extension_device", True),
            ExtensionCppWrapperCodegen,
        )

        # The extension's custom-op flag is unset before any tensor is placed
        # on the device and set after the first one is.
        self.assertFalse(self.module.custom_op_called())
        device = self.module.custom_device()
        x = torch.empty(2, 16).to(device=device).fill_(1)
        self.assertTrue(self.module.custom_op_called())
        y = torch.empty(2, 16).to(device=device).fill_(2)
        z = torch.empty(2, 16).to(device=device).fill_(3)
        # Expected fn(x, y, z) on CPU: 1 * 2 + 3 == 5 everywhere.
        ref = torch.empty(2, 16).fill_(5)
        for tensor in (x, y, z):
            self.assertEqual(tensor.device, device)

        def fn(a, b, c):
            return a * b + c

        # Device-name -> ATen device-type mapping used when emitting C++.
        cpp_utils.DEVICE_TO_ATEN["extension_device"] = "at::kPrivateUse1"
        for cpp_wrapper_flag in (True, False):
            with config.patch({"cpp_wrapper": cpp_wrapper_flag}):
                metrics.reset()
                opt_fn = torch.compile()(fn)
                _, code = run_and_get_cpp_code(opt_fn, x, y, z)
                # The expected load pattern depends on whether a vector ISA
                # is available.
                if cpu_vec_isa.valid_vec_isa_list():
                    load_expr = "loadu"
                else:
                    load_expr = " = in_ptr0[static_cast<long>(i0)];"
                FileCheck().check("void").check(load_expr).check(
                    "extension_device"
                ).run(code)
                opt_fn(x, y, z)
                res = opt_fn(x, y, z)
                self.assertEqual(ref, res.to(device="cpu"))
if __name__ == "__main__":
    from torch._inductor.test_case import run_tests
    from torch.testing._internal.inductor_utils import HAS_CPU

    # Run only when a CPU is available, and never on macOS or in fbcode
    # (cpp_extension doesn't work in fbcode right now).
    unsupported_platform = IS_MACOS or IS_FBCODE
    if HAS_CPU and not unsupported_platform:
        run_tests(needs="filelock")