mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	The test checks for the substring "loadu" in the generated code. On AVX systems the relevant line is `auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(i0))`, whereas on non-AVX systems it is `auto tmp0 = in_ptr0[static_cast<long>(i0)];`. Which form is emitted depends on whether `codecache.valid_vec_isa_list()` is non-empty; see torch/_inductor/codegen/cpp.py:2639. Modify the test to account for both cases. Pull Request resolved: https://github.com/pytorch/pytorch/pull/117272 Approved by: https://github.com/jgong5, https://github.com/jansel
		
			
				
	
	
		
			168 lines
		
	
	
		
			5.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			168 lines
		
	
	
		
			5.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Owner(s): ["module: inductor"]
 | |
| import os
 | |
| import shutil
 | |
| import sys
 | |
| import unittest
 | |
| 
 | |
| import torch
 | |
| import torch._dynamo
 | |
| import torch.utils.cpp_extension
 | |
| from torch._C import FileCheck
 | |
| 
 | |
| try:
 | |
|     from extension_backends.cpp.extension_codegen_backend import (
 | |
|         ExtensionCppWrapperCodegen,
 | |
|         ExtensionScheduling,
 | |
|         ExtensionWrapperCodegen,
 | |
|     )
 | |
| except ImportError:
 | |
|     from .extension_backends.cpp.extension_codegen_backend import (
 | |
|         ExtensionCppWrapperCodegen,
 | |
|         ExtensionScheduling,
 | |
|         ExtensionWrapperCodegen,
 | |
|     )
 | |
| 
 | |
| import torch._inductor.config as config
 | |
| from torch._inductor import codecache, metrics
 | |
| from torch._inductor.codegen import cpp
 | |
| from torch._inductor.codegen.common import (
 | |
|     get_scheduling_for_device,
 | |
|     get_wrapper_codegen_for_device,
 | |
|     register_backend_for_device,
 | |
| )
 | |
| from torch.testing._internal.common_utils import IS_FBCODE, IS_MACOS
 | |
| 
 | |
| try:
 | |
|     try:
 | |
|         from . import test_torchinductor
 | |
|     except ImportError:
 | |
|         import test_torchinductor
 | |
| except unittest.SkipTest:
 | |
|     if __name__ == "__main__":
 | |
|         sys.exit(0)
 | |
|     raise
 | |
| 
 | |
| 
 | |
| run_and_get_cpp_code = test_torchinductor.run_and_get_cpp_code
 | |
| TestCase = test_torchinductor.TestCase
 | |
| 
 | |
| 
 | |
def remove_build_path():
    """Delete torch's default C++-extension build directory, if it exists.

    Nothing is done on Windows. Elsewhere the directory reported by
    ``torch.utils.cpp_extension.get_default_build_root()`` is removed
    recursively; errors raised during removal are ignored.
    """
    on_windows = sys.platform == "win32"
    if not on_windows:
        build_root = torch.utils.cpp_extension.get_default_build_root()
        if os.path.exists(build_root):
            shutil.rmtree(build_root, ignore_errors=True)
    # On Windows the extensions build folder is deliberately left in place.
 | |
| 
 | |
| 
 | |
@unittest.skipIf(IS_FBCODE, "cpp_extension doesn't work in fbcode right now")
class ExtensionBackendTests(TestCase):
    """Check that Inductor can target an out-of-tree ``extension_device`` backend.

    The C++ extension in ``extension_backends/cpp/extension_device.cpp`` is
    built once per class. The single test registers it as a device module and
    as an Inductor codegen backend, then compiles and runs a small function on
    tensors that live on that device.
    """

    # Loaded extension module; populated by setUpClass.
    module = None

    @classmethod
    def setUpClass(cls):
        """Build and load the ``extension_device`` C++ extension."""
        super().setUpClass()

        # Build Extension (wipe the build directory first so the build is fresh)
        remove_build_path()
        source_dir = os.path.dirname(os.path.abspath(__file__))
        source_file = os.path.join(
            source_dir, "extension_backends/cpp/extension_device.cpp"
        )
        cls.module = torch.utils.cpp_extension.load(
            name="extension_device",
            sources=[source_file],
            extra_cflags=["-g"],
            verbose=True,
        )

    @classmethod
    def tearDownClass(cls):
        """Release class-level resources and delete the extension build directory."""
        try:
            # NOTE(review): `_stack` is not defined in this class; presumably it
            # is created by the base TestCase's setUpClass -- confirm.
            cls._stack.close()
            super().tearDownClass()
        finally:
            # Remove the build directory even if the teardown above raised.
            remove_build_path()

    def setUp(self):
        """Reset dynamo and switch the working directory to this file's folder."""
        torch._dynamo.reset()
        super().setUp()

        # Validate before touching process-wide state: unittest does not call
        # tearDown when setUp raises, so failing after os.chdir() would leave
        # the changed working directory in place for subsequent tests.
        self.assertIsNotNone(self.module)

        # cpp extensions use relative paths. Those paths are relative to
        # this file, so we'll change the working directory temporarily
        self.old_working_dir = os.getcwd()
        os.chdir(os.path.dirname(os.path.abspath(__file__)))

    def tearDown(self):
        """Reset dynamo and restore the working directory saved in setUp."""
        try:
            super().tearDown()
            torch._dynamo.reset()
        finally:
            # return the working directory (see setUp), even if the steps
            # above raised
            os.chdir(self.old_working_dir)

    def test_open_device_registration(self):
        """Register the extension backend, then compile and run ``a * b + c`` on it."""
        torch.utils.rename_privateuse1_backend("extension_device")
        torch._register_device_module("extension_device", self.module)

        register_backend_for_device(
            "extension_device",
            ExtensionScheduling,
            ExtensionWrapperCodegen,
            ExtensionCppWrapperCodegen,
        )
        # The lookups must hand back exactly the classes registered above.
        self.assertIs(
            get_scheduling_for_device("extension_device"), ExtensionScheduling
        )
        self.assertIs(
            get_wrapper_codegen_for_device("extension_device"),
            ExtensionWrapperCodegen,
        )
        self.assertIs(
            get_wrapper_codegen_for_device("extension_device", True),
            ExtensionCppWrapperCodegen,
        )

        # The extension's `custom_op_called` flag is expected to be unset
        # before any tensor is placed on the device and set afterwards.
        self.assertFalse(self.module.custom_op_called())
        device = self.module.custom_device()
        x = torch.empty(2, 16).to(device=device).fill_(1)
        self.assertTrue(self.module.custom_op_called())
        y = torch.empty(2, 16).to(device=device).fill_(2)
        z = torch.empty(2, 16).to(device=device).fill_(3)
        # Expected result on CPU: 1 * 2 + 3 == 5 for every element.
        ref = torch.empty(2, 16).fill_(5)

        self.assertEqual(x.device, device)
        self.assertEqual(y.device, device)
        self.assertEqual(z.device, device)

        def fn(a, b, c):
            return a * b + c

        # NOTE(review): presumably tells the C++ codegen which ATen device
        # constant to emit for this device name -- confirm against cpp.py.
        cpp.DEVICE_TO_ATEN["extension_device"] = "at::kPrivateUse1"
        for cpp_wrapper_flag in [True, False]:
            with config.patch({"cpp_wrapper": cpp_wrapper_flag}):
                metrics.reset()
                opt_fn = torch.compile()(fn)
                _, code = run_and_get_cpp_code(opt_fn, x, y, z)
                # With a usable vector ISA the generated load is a vectorized
                # `loadu`; otherwise it is a plain scalar indexing expression.
                if codecache.valid_vec_isa_list():
                    load_expr = "loadu"
                else:
                    load_expr = " = in_ptr0[static_cast<long>(i0)];"
                # FileCheck matches the three patterns in sequence.
                FileCheck().check("void").check(load_expr).check(
                    "extension_device"
                ).run(code)
                opt_fn(x, y, z)
                res = opt_fn(x, y, z)
                self.assertEqual(ref, res.to(device="cpu"))
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    from torch._inductor.test_case import run_tests
    from torch.testing._internal.inductor_utils import HAS_CPU

    # The suite needs a CPU backend and is not run on macOS or in fbcode
    # (cpp_extension doesn't work in fbcode right now).
    unsupported_platform = IS_MACOS or IS_FBCODE
    if HAS_CPU and not unsupported_platform:
        run_tests(needs="filelock")
 |