mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[pt2_provenance_tracking] add support for cpp kernel (#149185)
Summary: As title. Add inductor cpp kernel to post grad graph node mapping & UT. Context: Raised as a feature request for AOTI CPU case. https://fb.workplace.com/groups/1028545332188949/permalink/1169020841474730/ Differential Revision: D71181284 Pull Request resolved: https://github.com/pytorch/pytorch/pull/149185 Approved by: https://github.com/jingsh
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							7869196482
						
					
				
				
					commit
					b8f91bcb14
				
			| @ -5,12 +5,14 @@ import logging | ||||
| import re | ||||
| import shutil | ||||
| import tempfile | ||||
| import unittest | ||||
| from pathlib import Path | ||||
|  | ||||
| import torch | ||||
| from torch._inductor import config | ||||
| from torch._inductor.debug import create_node_mapping | ||||
| from torch._inductor.test_case import run_tests, TestCase | ||||
| from torch.testing._internal.inductor_utils import HAS_GPU | ||||
| from torch.testing._internal.triton_utils import requires_cuda | ||||
|  | ||||
|  | ||||
| @ -31,7 +33,6 @@ class Model(torch.nn.Module): | ||||
|         return z | ||||
|  | ||||
|  | ||||
| @requires_cuda | ||||
| @config.patch("trace.enabled", True) | ||||
| class TestProvenanceTracingArtifact(TestCase): | ||||
|     """ | ||||
| @ -39,6 +40,7 @@ class TestProvenanceTracingArtifact(TestCase): | ||||
|     corresponding "inductor triton kernel node" is expected. | ||||
|     """ | ||||
|  | ||||
|     @requires_cuda | ||||
|     def _check_provenance_tracing_artifact(self, filepath): | ||||
|         self.assertTrue(filepath.is_dir()) | ||||
|         filename = Path(filepath) / "inductor_triton_kernel_to_post_grad_nodes.json" | ||||
| @ -113,6 +115,7 @@ class TestProvenanceTracingArtifact(TestCase): | ||||
|         ] | ||||
|         self.assertEqual(sorted(actual_data.items()), sorted(expected_data)) | ||||
|  | ||||
|     @requires_cuda | ||||
|     def test_triton_kernel_to_post_grad_tracing(self): | ||||
|         a = torch.randn(10, 20, device="cuda") | ||||
|         b = torch.randn(20, 30, device="cuda") | ||||
| @ -146,6 +149,60 @@ class TestProvenanceTracingArtifact(TestCase): | ||||
|             finally: | ||||
|                 shutil.rmtree(filepath) | ||||
|  | ||||
|     @unittest.skipIf(HAS_GPU, "the test is only for cpu") | ||||
|     def test_triton_kernel_to_post_grad_tracing_cpu(self): | ||||
|         a = torch.randn(10, 20, device="cpu") | ||||
|         b = torch.randn(20, 30, device="cpu") | ||||
|         c = torch.randn(10, 30, device="cpu") | ||||
|         example_inputs = (a, b, c) | ||||
|  | ||||
|         model = Model() | ||||
|         ep = torch.export._trace._export(model, example_inputs) | ||||
|         gm = ep.module() | ||||
|         filepath = None | ||||
|  | ||||
|         for backend in ["aot_inductor", "inductor"]: | ||||
|             try: | ||||
|                 with config.patch( | ||||
|                     { | ||||
|                         "trace.debug_dir": tempfile.mkdtemp(), | ||||
|                         "force_disable_caches": True, | ||||
|                     } | ||||
|                 ): | ||||
|                     with self.assertLogs( | ||||
|                         logging.getLogger("torch._inductor.debug"), | ||||
|                         level=logging.WARNING, | ||||
|                     ) as cm: | ||||
|                         if backend == "aot_inductor": | ||||
|                             so_path = torch._inductor.aot_compile(gm, example_inputs) | ||||
|                             optimized = AOTIRunnerUtil.load("cpu", so_path) | ||||
|                             optimized(*example_inputs) | ||||
|                         else: | ||||
|                             compiled = torch.compile(gm, backend=backend) | ||||
|                             compiled(*example_inputs) | ||||
|                     self.assertEqual(len(cm.output), 1) | ||||
|                     m = re.match(r"WARNING.* debug trace: (.*)", cm.output[0]) | ||||
|                     self.assertTrue(m) | ||||
|                     filepath = Path(m.group(1)) | ||||
|                     filename = ( | ||||
|                         Path(filepath) | ||||
|                         / "inductor_triton_kernel_to_post_grad_nodes.json" | ||||
|                     ) | ||||
|                     with open(filename) as f: | ||||
|                         actual_data = json.load(f) | ||||
|                     # check the inductor kernel to post grad nodes mapping is expected for cpu | ||||
|                     expected_data = { | ||||
|                         "cpp_fused_mul_0": ["mul"], | ||||
|                         "cpp_fused_gelu_1": ["mul_3", "mul_1", "add", "erf", "mul_2"], | ||||
|                     } | ||||
|                     self.assertEqual( | ||||
|                         sorted(actual_data.items()), sorted(expected_data.items()) | ||||
|                     ) | ||||
|  | ||||
|             finally: | ||||
|                 if filepath: | ||||
|                     shutil.rmtree(filepath) | ||||
|  | ||||
|  | ||||
| class TestProvenanceTracingNodeMapping(TestCase): | ||||
|     def test_create_node_mapping(self): | ||||
|  | ||||
| @ -43,6 +43,7 @@ from ..utils import ( | ||||
|     is_welford_reduction, | ||||
|     parallel_num_threads, | ||||
|     Placeholder, | ||||
|     set_kernel_post_grad_provenance_tracing, | ||||
|     sympy_index_symbol, | ||||
|     sympy_index_symbol_with_prefix, | ||||
|     sympy_product, | ||||
| @ -5150,6 +5151,10 @@ class CppScheduling(BaseScheduling): | ||||
|             else "" | ||||
|         ) | ||||
|         kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()]) | ||||
|         # below add provenance tracing info for cpu CppKernel types | ||||
|         if config.trace.enabled: | ||||
|             set_kernel_post_grad_provenance_tracing(nodes, kernel_name) | ||||
|  | ||||
|         kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel" | ||||
|         src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_decl_name) | ||||
|         src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name) | ||||
|  | ||||
		Reference in New Issue
	
	Block a user