mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-31 12:15:03 +08:00)

Compare commits: 2 commits, ciflow/tru...flex_flash
| Author | SHA1 | Date |
|---|---|---|
| | d9007ea76c | |
| | 4bc383b405 | |

test/inductor/test_cutedsl_template.py | 318 lines | Normal file
| @@ -0,0 +1,318 @@ | ||||
| # Owner(s): ["module: inductor"] | ||||
| import unittest | ||||
| from unittest.mock import MagicMock, patch | ||||
|  | ||||
| import torch | ||||
| from torch._inductor.test_case import TestCase | ||||
|  | ||||
|  | ||||
| try: | ||||
|     import cutlass  # noqa: F401 | ||||
|     import cutlass.cute as cute  # noqa: F401 | ||||
|  | ||||
|     HAS_CUTLASS = True | ||||
| except ImportError: | ||||
|     HAS_CUTLASS = False | ||||
|  | ||||
| if HAS_CUTLASS: | ||||
|     from torch._inductor.codegen.cutedsl.cutedsl_kernel import CuteDSLTemplateKernel | ||||
|     from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate | ||||
|     from torch._inductor.select_algorithm import PartialRender | ||||
|  | ||||
| CUTEDSL_ADD_TEMPLATE = r""" | ||||
| {{gen_defines()}} | ||||
|  | ||||
| @cute.kernel | ||||
| def {{kernel_name}}_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor): | ||||
|     tidx, _, _ = cute.arch.thread_idx() | ||||
|     bidx, _, _ = cute.arch.block_idx() | ||||
|     bdim, _, _ = cute.arch.block_dim() | ||||
|  | ||||
|     thread_idx = bidx * bdim + tidx | ||||
|     m, n = gA.shape | ||||
|  | ||||
|     if thread_idx < m * n: | ||||
|         mi = thread_idx // n | ||||
|         ni = thread_idx % n | ||||
|  | ||||
|         if mi < m and ni < n: | ||||
|             gC[mi, ni] = gA[mi, ni] + gB[mi, ni] | ||||
|  | ||||
| @cute.jit | ||||
| def {{kernel_name}}_jit(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor): | ||||
|     {{gen_defines()}} | ||||
|     m, n = mA.shape | ||||
|     total_threads = m * n | ||||
|     num_blocks = (total_threads + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK | ||||
|  | ||||
|     kernel = {{kernel_name}}_kernel(mA, mB, mC) | ||||
|     kernel.launch( | ||||
|         grid=[num_blocks, 1, 1], | ||||
|         block=[THREADS_PER_BLOCK, 1, 1] | ||||
|     ) | ||||
|  | ||||
| {{def_kernel("input_a", "input_b", "output_c")}} | ||||
|     cute_a = from_dlpack(input_a) | ||||
|     cute_b = from_dlpack(input_b) | ||||
|     cute_c = from_dlpack(output_c) | ||||
|  | ||||
|     {{kernel_name}}_jit(cute_a, cute_b, cute_c) | ||||
|     return output_c | ||||
| """ | ||||
|  | ||||
|  | ||||
| @unittest.skipUnless(HAS_CUTLASS, "requires cutlass") | ||||
| class TestCuteDSLTemplate(TestCase): | ||||
|     """Test cases for CuteDSL template functionality.""" | ||||
|  | ||||
|     def test_gen_imports(self): | ||||
|         kernel = CuteDSLTemplateKernel( | ||||
|             kernel_name="test_kernel", | ||||
|             input_nodes=[], | ||||
|             output_node=None, | ||||
|         ) | ||||
|  | ||||
|         imports = kernel.gen_imports() | ||||
|  | ||||
|         self.assertIn("import torch", imports) | ||||
|         self.assertIn("import cutlass", imports) | ||||
|         self.assertIn("import cutlass.cute as cute", imports) | ||||
|         self.assertIn("from cutlass.cute.runtime import from_dlpack", imports) | ||||
|         self.assertIsInstance(imports, str) | ||||
|  | ||||
|         lines = imports.strip().split("\n") | ||||
|         self.assertEqual(len(lines), 4) | ||||
|  | ||||
|     def test_render_includes_imports(self): | ||||
|         template_source = """@cute.kernel | ||||
| def {{kernel_name}}_kernel(): | ||||
|     pass | ||||
|  | ||||
| {{def_kernel("input", "output")}} | ||||
|     return output""" | ||||
|  | ||||
|         mock_template = MagicMock() | ||||
|         mock_template.render = MagicMock(return_value=template_source) | ||||
|  | ||||
|         kernel = CuteDSLTemplateKernel( | ||||
|             kernel_name="test_kernel", | ||||
|             input_nodes=[], | ||||
|             output_node=None, | ||||
|         ) | ||||
|  | ||||
|         result = kernel.render(mock_template) | ||||
|         self.assertIsInstance(result, PartialRender) | ||||
|  | ||||
|         rendered_code = result._code | ||||
|  | ||||
|         # The imports might have leading whitespace, so strip it | ||||
|         rendered_code_stripped = rendered_code.lstrip() | ||||
|  | ||||
|         self.assertTrue( | ||||
|             rendered_code_stripped.startswith("import torch"), | ||||
|             f"Code should start with 'import torch', got: {rendered_code_stripped[:50]}", | ||||
|         ) | ||||
|         self.assertIn("import cutlass", rendered_code) | ||||
|         self.assertIn("import cutlass.cute as cute", rendered_code) | ||||
|         self.assertIn("from cutlass.cute.runtime import from_dlpack", rendered_code) | ||||
|         self.assertIn("@cute.kernel", rendered_code) | ||||
|  | ||||
|     def test_template_env_contains_hooks(self): | ||||
|         kernel = CuteDSLTemplateKernel( | ||||
|             kernel_name="test_kernel", | ||||
|             input_nodes=[], | ||||
|             output_node=None, | ||||
|         ) | ||||
|  | ||||
|         captured_env = {} | ||||
|  | ||||
|         def mock_render(**kwargs): | ||||
|             captured_env.update(kwargs) | ||||
|             return "rendered" | ||||
|  | ||||
|         mock_template = MagicMock() | ||||
|         mock_template.render = mock_render | ||||
|  | ||||
|         kernel.render(mock_template) | ||||
|  | ||||
|         self.assertIn("def_kernel", captured_env) | ||||
|         self.assertIn("kernel_name", captured_env) | ||||
|         self.assertTrue(callable(captured_env["def_kernel"])) | ||||
|  | ||||
|     def test_multiple_templates_unique_names(self): | ||||
|         # Clean registry first | ||||
|         test_name = f"unique_test_{id(self)}" | ||||
|         if test_name in CuteDSLTemplate.all_templates: | ||||
|             del CuteDSLTemplate.all_templates[test_name] | ||||
|  | ||||
|         _ = CuteDSLTemplate( | ||||
|             name=test_name, | ||||
|             source="template1", | ||||
|         ) | ||||
|  | ||||
|         with self.assertRaises(AssertionError): | ||||
|             _ = CuteDSLTemplate( | ||||
|                 name=test_name, | ||||
|                 source="template2", | ||||
|             ) | ||||
|  | ||||
|     def test_indented_buffer_usage(self): | ||||
|         kernel = CuteDSLTemplateKernel( | ||||
|             kernel_name="test_kernel", | ||||
|             input_nodes=[], | ||||
|             output_node=None, | ||||
|         ) | ||||
|  | ||||
|         imports = kernel.gen_imports() | ||||
|  | ||||
|         lines = imports.strip().split("\n") | ||||
|         for line in lines: | ||||
|             if line: | ||||
|                 self.assertFalse( | ||||
|                     line.startswith(" "), f"Line should not be indented: '{line}'" | ||||
|                 ) | ||||
|  | ||||
|     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") | ||||
|     def test_cutedsl_add_e2e(self): | ||||
|         """End-to-end test with CuteDSL template including code generation verification.""" | ||||
|         from torch._inductor.ir import TensorBox | ||||
|         from torch._inductor.lowering import lowerings | ||||
|         from torch._inductor.utils import run_and_get_code | ||||
|  | ||||
|         template = CuteDSLTemplate( | ||||
|             name="test_add_e2e", | ||||
|             source=CUTEDSL_ADD_TEMPLATE, | ||||
|         ) | ||||
|  | ||||
|         def cutedsl_add_lowering(a: TensorBox, b: TensorBox) -> TensorBox: | ||||
|             choices = [] | ||||
|             error = template.maybe_append_choice( | ||||
|                 choices, | ||||
|                 input_nodes=[a, b], | ||||
|                 layout=a.get_layout(), | ||||
|                 THREADS_PER_BLOCK=256, | ||||
|             ) | ||||
|  | ||||
|             if error or not choices: | ||||
|                 default_lowering = lowerings[torch.ops.aten.add.Tensor] | ||||
|                 return default_lowering(a, b) | ||||
|  | ||||
|             # Use the single choice directly (no autotuning) | ||||
|             return choices[0].output_node() | ||||
|  | ||||
|         with patch.dict(lowerings, {torch.ops.aten.add.Tensor: cutedsl_add_lowering}): | ||||
|             # Test function | ||||
|             def test_add(x, y): | ||||
|                 return x + y | ||||
|  | ||||
|             device = "cuda" | ||||
|             x = torch.randn(128, 4, device=device, dtype=torch.float32) | ||||
|             y = torch.randn(128, 4, device=device, dtype=torch.float32) | ||||
|  | ||||
|             # Compile and get generated code | ||||
|             compiled_fn = torch.compile(test_add, backend="inductor") | ||||
|             result, (code,) = run_and_get_code(compiled_fn, x, y) | ||||
|  | ||||
|             # Verify CuteDSL code is present | ||||
|             self.assertIn( | ||||
|                 "cute", code.lower(), "CuteDSL code should be in generated code" | ||||
|             ) | ||||
|             # Verify parameter generation worked | ||||
|             self.assertIn( | ||||
|                 "THREADS_PER_BLOCK", code, "Parameter should be in generated code" | ||||
|             ) | ||||
|  | ||||
|             # Verify correctness | ||||
|             expected = x + y | ||||
|             self.assertTrue(torch.allclose(result, expected, atol=1e-5)) | ||||
|  | ||||
|     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") | ||||
|     def test_cutedsl_add_e2e_autotune(self): | ||||
|         """E2E test with multiple CuteDSL template variants for autotuning.""" | ||||
|         from torch._inductor.ir import TensorBox | ||||
|         from torch._inductor.lowering import lowerings | ||||
|         from torch._inductor.select_algorithm import autotune_select_algorithm | ||||
|  | ||||
|         template = CuteDSLTemplate( | ||||
|             name="test_add_autotune", | ||||
|             source=CUTEDSL_ADD_TEMPLATE, | ||||
|         ) | ||||
|  | ||||
|         def cutedsl_add_lowering(a: TensorBox, b: TensorBox) -> TensorBox: | ||||
|             choices = [] | ||||
|  | ||||
|             # Add multiple variants with different thread counts for autotuning | ||||
|             thread_variants = [128, 256, 512] | ||||
|             for threads in thread_variants: | ||||
|                 error = template.maybe_append_choice( | ||||
|                     choices, | ||||
|                     input_nodes=[a, b], | ||||
|                     layout=a.get_layout(), | ||||
|                     THREADS_PER_BLOCK=threads, | ||||
|                 ) | ||||
|                 if error: | ||||
|                     # Skip this variant if it fails | ||||
|                     continue | ||||
|  | ||||
|             if not choices: | ||||
|                 default_lowering = lowerings[torch.ops.aten.add.Tensor] | ||||
|                 return default_lowering(a, b) | ||||
|  | ||||
|             # Use autotuning to select the best variant | ||||
|             return autotune_select_algorithm( | ||||
|                 "cutedsl_add_autotune", | ||||
|                 choices, | ||||
|                 [a, b], | ||||
|                 a.get_layout(), | ||||
|             ) | ||||
|  | ||||
|         with patch.dict(lowerings, {torch.ops.aten.add.Tensor: cutedsl_add_lowering}): | ||||
|             # Test function | ||||
|             def test_add(x, y): | ||||
|                 return x + y | ||||
|  | ||||
|             device = "cuda" | ||||
|             x = torch.randn(128, 128, device=device, dtype=torch.float32) | ||||
|             y = torch.randn(128, 128, device=device, dtype=torch.float32) | ||||
|  | ||||
|             # Compile and run | ||||
|             compiled_fn = torch.compile(test_add, backend="inductor") | ||||
|             result = compiled_fn(x, y) | ||||
|  | ||||
|             # Verify correctness | ||||
|             expected = x + y | ||||
|             self.assertTrue(torch.allclose(result, expected, atol=1e-5)) | ||||
|  | ||||
|     def test_gen_defines(self): | ||||
|         """Test that gen_defines correctly generates CuteDSL parameter definitions.""" | ||||
|         kernel = CuteDSLTemplateKernel( | ||||
|             kernel_name="test_kernel", | ||||
|             input_nodes=[], | ||||
|             output_node=None, | ||||
|         ) | ||||
|  | ||||
|         # Test integer parameters | ||||
|         params = kernel.gen_defines( | ||||
|             THREADS_PER_BLOCK=256, | ||||
|             BLOCK_SIZE=128, | ||||
|             ENABLE_FEATURE=True, | ||||
|         ) | ||||
|  | ||||
|         expected_lines = [ | ||||
|             "THREADS_PER_BLOCK: cutlass.Constexpr = 256", | ||||
|             "BLOCK_SIZE: cutlass.Constexpr = 128", | ||||
|             "ENABLE_FEATURE: cutlass.Constexpr = True", | ||||
|         ] | ||||
|  | ||||
|         for expected_line in expected_lines: | ||||
|             self.assertIn(expected_line, params) | ||||
|  | ||||
|         # Test float parameters | ||||
|         params_float = kernel.gen_defines(SCALE_FACTOR=1.5) | ||||
|         self.assertIn("SCALE_FACTOR: cutlass.Constexpr = 1.5", params_float) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     from torch._inductor.test_case import run_tests | ||||
|  | ||||
|     run_tests() | ||||
torch/_inductor/async_compile.py
| @@ -569,6 +569,45 @@ class AsyncCompile: | ||||
|             ) | ||||
|             return LambdaFuture(get_result) | ||||
|  | ||||
|     def cutedsl(self, kernel_name: str, source_code: str): | ||||
|         """ | ||||
|         Compile CuteDSL (CUTLASS Python DSL) kernels. | ||||
|  | ||||
|         Args: | ||||
|             kernel_name: Name of the kernel to be defined | ||||
|             source_code: Source code of the CuteDSL kernel, as a string | ||||
|  | ||||
|         Note: | ||||
|             CuteDSL currently requires source files for compilation, so we | ||||
|             use PyCodeCache to write the source code to a file and load it. | ||||
|         """ | ||||
|         from torch._inductor.codegen.cutedsl.cutedsl_kernel import ( | ||||
|             CuteDSLKernelWrapper, | ||||
|             MAIN_SUFFIX, | ||||
|         ) | ||||
|  | ||||
|         kernel_code_log.info("CuteDSL Kernel:\n%s", source_code) | ||||
|  | ||||
|         def task(): | ||||
|             key, path = torch._inductor.codecache.PyCodeCache.write(source_code) | ||||
|             mod = torch._inductor.codecache.PyCodeCache.load_by_key_path(key, path) | ||||
|  | ||||
|             # Find our specially named entry point function | ||||
|             main_func_name = f"{kernel_name}_{MAIN_SUFFIX}" | ||||
|             if not hasattr(mod, main_func_name): | ||||
|                 available = [name for name in dir(mod) if callable(getattr(mod, name))] | ||||
|                 raise RuntimeError( | ||||
|                     f"Could not find CuteDSL main kernel function '{main_func_name}'. Available callables: {available}" | ||||
|                 ) | ||||
|  | ||||
|             return CuteDSLKernelWrapper(getattr(mod, main_func_name), kernel_path=path) | ||||
|  | ||||
|         if get_compile_threads() <= 1: | ||||
|             return task() | ||||
|         else: | ||||
|             future = self.submit(task) | ||||
|             return LambdaFuture(lambda: future.result()) | ||||
|  | ||||
|     def wait(self, scope: dict[str, Any]) -> None: | ||||
|         if get_compile_threads() > 1: | ||||
|             with dynamo_timed( | ||||
|  | ||||
torch/_inductor/codegen/cuda_combined_scheduling.py
| @@ -11,6 +11,7 @@ from ..scheduler import ( | ||||
|     SchedulerNode, | ||||
| ) | ||||
| from .cuda.cuda_cpp_scheduling import CUDACPPScheduling | ||||
| from .cutedsl.cutedsl_scheduling import CuteDSLScheduling | ||||
| from .rocm.rocm_cpp_scheduling import ROCmCPPScheduling | ||||
| from .triton import TritonScheduling | ||||
|  | ||||
| @@ -44,6 +45,7 @@ class CUDACombinedScheduling(BaseScheduling): | ||||
|         self._triton_scheduling = TritonScheduling(scheduler) | ||||
|         self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler) | ||||
|         self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler) | ||||
|         self._cutedsl_scheduling = CuteDSLScheduling(scheduler) | ||||
|  | ||||
|     def get_backend_features(self, device: torch.device) -> OrderedSet[BackendFeature]: | ||||
|         return self._triton_scheduling.get_backend_features(device) | ||||
| @@ -53,6 +55,8 @@ class CUDACombinedScheduling(BaseScheduling): | ||||
|             return self._cuda_cpp_scheduling | ||||
|         if self._rocm_cpp_scheduling.is_rocm_cpp_template(node): | ||||
|             return self._rocm_cpp_scheduling | ||||
|         if self._cutedsl_scheduling.is_cutedsl_template(node): | ||||
|             return self._cutedsl_scheduling | ||||
|         return self._triton_scheduling | ||||
|  | ||||
|     def can_fuse_vertical( | ||||
| @@ -64,6 +68,11 @@ class CUDACombinedScheduling(BaseScheduling): | ||||
|             node1 | ||||
|         ) or self._cuda_cpp_scheduling.is_cuda_cpp_template(node2): | ||||
|             return False | ||||
|         # CuteDSL doesn't support vertical fusion currently | ||||
|         elif self._cutedsl_scheduling.is_cutedsl_template( | ||||
|             node1 | ||||
|         ) or self._cutedsl_scheduling.is_cutedsl_template(node2): | ||||
|             return False | ||||
|         return self._triton_scheduling.can_fuse_vertical(node1, node2) | ||||
|  | ||||
|     def can_fuse_horizontal( | ||||
| @@ -74,6 +83,10 @@ class CUDACombinedScheduling(BaseScheduling): | ||||
|                 return self._cuda_cpp_scheduling.can_fuse_horizontal( | ||||
|                     node1, node2 | ||||
|                 )  # always False at the moment | ||||
|             if self._cutedsl_scheduling.is_cutedsl_template(node): | ||||
|                 return self._cutedsl_scheduling.can_fuse_horizontal( | ||||
|                     node1, node2 | ||||
|                 )  # always False at the moment | ||||
|         return self._triton_scheduling.can_fuse_horizontal(node1, node2) | ||||
|  | ||||
|     def group_fn( | ||||
| @@ -98,6 +111,13 @@ class CUDACombinedScheduling(BaseScheduling): | ||||
|             return self._rocm_cpp_scheduling.codegen_template( | ||||
|                 template_node, epilogue_nodes, prologue_nodes | ||||
|             ) | ||||
|         elif self._cutedsl_scheduling.is_cutedsl_template(template_node): | ||||
|             # TODO remove this when we add epilogue support | ||||
|             assert not epilogue_nodes | ||||
|             assert not prologue_nodes | ||||
|             return self._cutedsl_scheduling.codegen_template( | ||||
|                 template_node, epilogue_nodes, prologue_nodes | ||||
|             ) | ||||
|         else: | ||||
|             return self._triton_scheduling.codegen_template( | ||||
|                 template_node, epilogue_nodes, prologue_nodes | ||||
|  | ||||
torch/_inductor/codegen/cutedsl/README.md | 101 lines | Normal file
| @@ -0,0 +1,101 @@ | ||||
| # CuteDSL Template System | ||||
|  | ||||
| ## Quick Start | ||||
|  | ||||
| Writing a CuteDSL template: | ||||
|  | ||||
| ```python | ||||
| from torch._inductor.codegen.cutedsl import CuteDSLTemplate | ||||
|  | ||||
| template_source = """ | ||||
| @cute.kernel | ||||
| def {{kernel_name}}_kernel(A, B, C): | ||||
|     # Your CUTLASS kernel logic here | ||||
|     pass | ||||
|  | ||||
| {{def_kernel("A", "B", "C")}} | ||||
|     # Call the kernel | ||||
|     {{kernel_name}}_kernel(A, B, C) | ||||
|     return C | ||||
| """ | ||||
|  | ||||
| my_template = CuteDSLTemplate( | ||||
|     name="my_gemm", | ||||
|     source=template_source, | ||||
| ) | ||||
| ``` | ||||
|  | ||||
| ## Architecture | ||||
|  | ||||
| - **[CuteDSLTemplate](cutedsl_template.py#L39)**: Template definition and registration. Generates ChoiceCallers for autotuning. | ||||
| - **[CuteDSLTemplateKernel](cutedsl_kernel.py#L61)**: Handles code generation, provides template hooks (`def_kernel`), manages args. | ||||
| - **[CuteDSLScheduling](cutedsl_scheduling.py#L28)**: Integrates with Inductor's scheduler, handles kernel compilation via [`async_compile.cutedsl()`](../../async_compile.py#L756). | ||||
| - **[CuteDSLTemplateBuffer](../../ir.py)**: IR node representing a CuteDSL template operation in the graph. | ||||
|  | ||||
| ### Compilation Process | ||||
|  | ||||
| CuteDSL requires source files for compilation (cannot compile from strings directly). The process: | ||||
|  | ||||
| 1. **[CuteDSLScheduling](cutedsl_scheduling.py#L59)** generates the kernel code string and calls [`async_compile.cutedsl()`](../../async_compile.py#L756) | ||||
| 2. **[async_compile.cutedsl()](../../async_compile.py#L756)** uses [`PyCodeCache.write()`](../../codecache.py) to write source to a temporary `.py` file | ||||
| 3. **[PyCodeCache](../../codecache.py)** loads the module from disk, enabling CUTLASS compilation | ||||
| 4. The compiled kernel is wrapped in **[CuteDSLKernelWrapper](cutedsl_kernel.py#L22)** to provide a `.run()` interface | ||||
| 5. The generated Python file is cached via PyCodeCache, but CUTLASS compilation runs every time (no kernel-level caching yet) | ||||
|  | ||||
| **Debug tip**: Use `TORCH_LOGS="kernel_code"` to see the generated kernel source and file path during compilation. | ||||
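|  | ||||
| A condensed sketch of that flow, mirroring `async_compile.cutedsl()` (the helper name here is hypothetical and error handling is omitted): | ||||
|  | ||||
| ```python | ||||
| from torch._inductor.codecache import PyCodeCache | ||||
|  | ||||
| def compile_cutedsl_from_source(kernel_name: str, source_code: str): | ||||
|     # CuteDSL cannot compile an in-memory string, so write the generated | ||||
|     # source to a real .py file and import it back as a module. | ||||
|     key, path = PyCodeCache.write(source_code) | ||||
|     mod = PyCodeCache.load_by_key_path(key, path) | ||||
|     # The template system names the entry point "<kernel_name>_main". | ||||
|     return getattr(mod, f"{kernel_name}_main") | ||||
| ``` | ||||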
|  | ||||
| ## Writing Templates | ||||
|  | ||||
| Templates use Jinja2 syntax with these available hooks: | ||||
|  | ||||
| - `{{kernel_name}}` - Unique kernel identifier | ||||
| - `{{def_kernel(args...)}}` - Generates kernel function signature and argument handling | ||||
| - `{{input_nodes}}` - List of input buffers | ||||
| - `{{output_node}}` - Output buffer | ||||
| - `{{gen_defines()}}` - Generates autotunable parameter definitions with proper CuteDSL typing | ||||
|  | ||||
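| The `{{def_kernel(...)}}` hook above is resolved lazily: `render()` returns a `PartialRender` whose code still contains a `<DEF_KERNEL>` placeholder, and finalization replaces each placeholder with the output of its registered hook. A standalone sketch of that substitution step (a plain-Python mock, not Inductor's actual `PartialRender` class): | ||||
|  | ||||
| ```python | ||||
| code = "import torch\n<DEF_KERNEL>\n    return output_c\n" | ||||
| hooks = {"<DEF_KERNEL>": lambda: "def my_kernel_main(input_a, input_b, output_c):"} | ||||
|  | ||||
| # Finalize: run each hook and splice its output over its placeholder. | ||||
| for placeholder, hook in hooks.items(): | ||||
|     code = code.replace(placeholder, hook()) | ||||
|  | ||||
| print(code)  # the placeholder is now a real function signature | ||||
| ``` | ||||
|  | ||||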
| ## Autotunable Parameters | ||||
|  | ||||
| CuteDSL templates support autotunable parameters similar to Triton's `tl.constexpr` system: | ||||
|  | ||||
| ```python | ||||
| template_source = r""" | ||||
| {{gen_defines()}} | ||||
|  | ||||
| @cute.kernel | ||||
| def {{kernel_name}}_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor): | ||||
|     threads_per_block = THREADS_PER_BLOCK  # Uses autotuned value | ||||
|     block_size = BLOCK_SIZE | ||||
|     # ... kernel implementation | ||||
| """ | ||||
|  | ||||
| # Pass parameters when generating template choices | ||||
| template.maybe_append_choice( | ||||
|     choices, | ||||
|     input_nodes=[a, b], | ||||
|     layout=layout, | ||||
|     THREADS_PER_BLOCK=256,    # cutlass.Constexpr = 256 | ||||
|     BLOCK_SIZE=128,           # cutlass.Constexpr = 128 | ||||
|     SCALE_FACTOR=1.5,         # cutlass.Constexpr = 1.5 | ||||
| ) | ||||
| ``` | ||||
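|  | ||||
| To autotune across several parameter values, append one choice per variant and let Inductor benchmark them, as in the tests (a sketch; `a`, `b`, and `layout` are assumed to come from your lowering): | ||||
|  | ||||
| ```python | ||||
| from torch._inductor.select_algorithm import autotune_select_algorithm | ||||
|  | ||||
| choices = [] | ||||
| for threads in (128, 256, 512): | ||||
|     # maybe_append_choice returns None on success, or the error on failure. | ||||
|     template.maybe_append_choice( | ||||
|         choices, input_nodes=[a, b], layout=layout, THREADS_PER_BLOCK=threads | ||||
|     ) | ||||
|  | ||||
| # Benchmarks each appended ChoiceCaller and returns the fastest one's output node. | ||||
| result = autotune_select_algorithm("my_add_autotune", choices, [a, b], layout) | ||||
| ``` | ||||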
|  | ||||
| Templates must: | ||||
| 1. Define a `@cute.kernel` decorated function | ||||
| 2. Use `{{def_kernel()}}` to create the entry point | ||||
| 3. Return the output tensor | ||||
| 4. Use `{{gen_defines()}}` for autotunable parameters | ||||
|  | ||||
| See [test_cutedsl_template.py](../../../../test/inductor/test_cutedsl_template.py) for complete examples. | ||||
|  | ||||
| ## Current Limitations / TODOs | ||||
|  | ||||
| - **No fusion support**: `can_fuse_vertical` and `can_fuse_horizontal` return False | ||||
| - **Subgraph management**: Bodies and masks not fully implemented | ||||
| - **File-based compilation**: Requires writing to disk (uses PyCodeCache) | ||||
| - **Missing epilogue/prologue**: No support for fused operations yet | ||||
| - **Fixed kernel suffix**: Uses hardcoded "_main" suffix | ||||
| - **No CUTLASS kernel caching**: only the generated Python file is cached (via PyCodeCache); CUTLASS compilation reruns every time (major perf issue) | ||||
|  | ||||
|  | ||||
| Note: Requires CUTLASS Python package (`pip install nvidia-cutlass`) | ||||
torch/_inductor/codegen/cutedsl/__init__.py | 8 lines | Normal file
| @@ -0,0 +1,8 @@ | ||||
| # mypy: allow-untyped-defs | ||||
| from .cutedsl_template import CuteDSLTemplate, CuteDSLTemplateCaller | ||||
|  | ||||
|  | ||||
| __all__ = [ | ||||
|     "CuteDSLTemplate", | ||||
|     "CuteDSLTemplateCaller", | ||||
| ] | ||||
torch/_inductor/codegen/cutedsl/cutedsl_kernel.py | 228 lines | Normal file
| @@ -0,0 +1,228 @@ | ||||
| # mypy: allow-untyped-defs | ||||
| import contextlib | ||||
| import dataclasses | ||||
| import logging | ||||
| from typing import Any, Callable, Optional | ||||
|  | ||||
| import torch | ||||
| from torch._inductor.codegen.common import IndentedBuffer, Kernel | ||||
| from torch._inductor.ir import Buffer | ||||
| from torch._inductor.select_algorithm import PartialRender | ||||
| from torch._inductor.utils import OrderedSet | ||||
| from torch._inductor.virtualized import V | ||||
|  | ||||
|  | ||||
| # TODO: the 'main' kernel is marked with this suffix. We have 3; we should probably just auto-generate this. | ||||
| MAIN_SUFFIX = "main" | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
| kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") | ||||
|  | ||||
|  | ||||
| class CuteDSLKernelWrapper: | ||||
|     """Wrapper to provide .run() interface for CuteDSL kernels""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, kernel_fn: Callable[..., Any], kernel_path: Optional[str] = None | ||||
|     ): | ||||
|         self.kernel_fn = kernel_fn | ||||
|         self.kernel_path = kernel_path | ||||
|         kernel_code_log.info("CuteDSL kernel path: %s", kernel_path) | ||||
|  | ||||
|     def run(self, *args, stream=None, **kwargs): | ||||
|         """ | ||||
|         Execute the CuteDSL kernel. | ||||
|  | ||||
|         Args: | ||||
|             *args: Arguments to pass to the kernel function | ||||
|             stream: CUDA stream (TODO: currently ignored; CuteDSL handles streams internally) | ||||
|             **kwargs: Additional keyword arguments for the kernel | ||||
|  | ||||
|         Returns: | ||||
|             Result of the kernel execution | ||||
|         """ | ||||
|         return self.kernel_fn(*args, **kwargs) | ||||
|  | ||||
|  | ||||
| @dataclasses.dataclass | ||||
| class CuteDSLSubgraphInfo: | ||||
|     """Minimal subgraph info for CuteDSL kernels.""" | ||||
|  | ||||
|     body: IndentedBuffer | ||||
|     template_mask: Optional[str] = None | ||||
|     template_out: Optional[str] = None | ||||
|  | ||||
|     def to_dict(self): | ||||
|         return { | ||||
|             field.name: getattr(self, field.name) for field in dataclasses.fields(self) | ||||
|         } | ||||
|  | ||||
|  | ||||
| class CuteDSLTemplateKernel(Kernel): | ||||
|     """ | ||||
|     Template kernel implementation for CuteDSL (CUTLASS Python DSL). | ||||
|     Handles code generation and argument management for CuteDSL CUDA kernels. | ||||
|     Provides CuteDSL-specific functionality for tensor conversion and kernel configuration. | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         kernel_name: str, | ||||
|         input_nodes: list[Buffer], | ||||
|         output_node: Buffer, | ||||
|     ) -> None: | ||||
|         # Call parent Kernel constructor | ||||
|         super().__init__() | ||||
|         self.kernel_name = kernel_name | ||||
|         self.input_nodes = input_nodes | ||||
|         self.output_node = output_node | ||||
|  | ||||
|         # TODO Subgraph management for template processing | ||||
|         self.subgraph_bodies: dict[str, CuteDSLSubgraphInfo] = {} | ||||
|  | ||||
|         # Template attributes | ||||
|         self.body: IndentedBuffer = IndentedBuffer() | ||||
|         self.template_mask: Optional[str] = None | ||||
|         self.template_out: Optional[str] = None | ||||
|         self.template_indices: Optional[list[Any]] = None | ||||
|         self.render_hooks: dict[str, Any] = {} | ||||
|  | ||||
|         # TODO Additional attributes needed by template system | ||||
|         self.prologue_fused_inputs: OrderedSet[str] = OrderedSet() | ||||
|         self.prologue_fused_inputs_preserve_zero: OrderedSet[str] = OrderedSet() | ||||
|         self.named_input_nodes: dict[str, Buffer] = {} | ||||
|  | ||||
|         # Create named input nodes mapping | ||||
|         for i, input_node in enumerate(input_nodes): | ||||
|             node_name = getattr(input_node, "name", f"input_{i}") | ||||
|             self.named_input_nodes[node_name] = input_node | ||||
|  | ||||
|     def gen_imports(self) -> str: | ||||
|         """Generate common imports for CuteDSL templates.""" | ||||
|         imports = IndentedBuffer() | ||||
|         imports.splice( | ||||
|             """ | ||||
|             import torch | ||||
|             import cutlass | ||||
|             import cutlass.cute as cute | ||||
|             from cutlass.cute.runtime import from_dlpack | ||||
|             """ | ||||
|         ) | ||||
|         return imports.getvalue() | ||||
|  | ||||
|     def gen_defines(self, **kwargs) -> str: | ||||
|         """Generate CuteDSL parameter definitions from kwargs, similar to Triton's gen_defines.""" | ||||
|         params = IndentedBuffer() | ||||
|         for name, val in kwargs.items(): | ||||
|             params.writeline(f"{name}: cutlass.Constexpr = {val}") | ||||
|         return params.getvalue() | ||||
|  | ||||
|     def render(self, template, **kwargs): | ||||
|         """Render the kernel using the template, returning PartialRender object with hooks.""" | ||||
|         # Available {{}} hooks for jinja rendering | ||||
|         template_env = { | ||||
|             "def_kernel": self.def_kernel, | ||||
|             "gen_defines": lambda: self.gen_defines(**kwargs), | ||||
|         } | ||||
|  | ||||
|         # Render the template with the environment and provided kwargs | ||||
|         rendered_code = template.render( | ||||
|             kernel_name=self.kernel_name, | ||||
|             input_nodes=self.input_nodes, | ||||
|             output_node=self.output_node, | ||||
|             **template_env, | ||||
|             **kwargs, | ||||
|         ) | ||||
|  | ||||
|         # Always prepend the common imports | ||||
|         imports = self.gen_imports() | ||||
|         full_code = imports + rendered_code | ||||
|  | ||||
|         return PartialRender(full_code, self.render_hooks) | ||||
|  | ||||
|     def __enter__(self): | ||||
|         """TODO: Context manager entry - doesn't set anything yet""" | ||||
|         return self | ||||
|  | ||||
|     def __exit__(self, exc_type, exc_val, exc_tb): | ||||
|         """TODO: Context manager exit - doesn't set anything yet""" | ||||
|  | ||||
|     @contextlib.contextmanager | ||||
|     def set_subgraph_body(self, body_name: str): | ||||
|         """Set the active subgraph body for template processing.""" | ||||
|         assert all( | ||||
|             hasattr(self, field.name) | ||||
|             for field in dataclasses.fields(CuteDSLSubgraphInfo) | ||||
|         ) | ||||
|         old_state = { | ||||
|             key.name: getattr(self, key.name) | ||||
|             for key in dataclasses.fields(CuteDSLSubgraphInfo) | ||||
|         } | ||||
|  | ||||
|         # Auto-create subgraph if it doesn't exist (for kernels without epilogue fusion) | ||||
|         if body_name not in self.subgraph_bodies: | ||||
|             self.subgraph_bodies[body_name] = CuteDSLSubgraphInfo( | ||||
|                 body=IndentedBuffer(), | ||||
|                 template_mask=None, | ||||
|                 template_out=None, | ||||
|             ) | ||||
|  | ||||
|         subgraph = self.subgraph_bodies[body_name] | ||||
|         for key, value in subgraph.to_dict().items(): | ||||
|             setattr(self, key, value) | ||||
|  | ||||
|         try: | ||||
|             yield | ||||
|         finally: | ||||
|             # Save current state back to subgraph | ||||
|             self.subgraph_bodies[body_name] = CuteDSLSubgraphInfo( | ||||
|                 **{ | ||||
|                     key.name: getattr(self, key.name) | ||||
|                     for key in dataclasses.fields(CuteDSLSubgraphInfo) | ||||
|                 } | ||||
|             ) | ||||
|             # Restore old state | ||||
|             for key, value in old_state.items(): | ||||
|                 setattr(self, key, value) | ||||
|  | ||||
|     @contextlib.contextmanager | ||||
|     def create_subgraph_body(self, body_name: str): | ||||
|         """Create a new subgraph body for template processing.""" | ||||
|         assert body_name not in self.subgraph_bodies, ( | ||||
|             f"Subgraph body '{body_name}' already exists" | ||||
|         ) | ||||
|         self.subgraph_bodies[body_name] = CuteDSLSubgraphInfo( | ||||
|             body=IndentedBuffer(), | ||||
|             template_mask=None, | ||||
|             template_out=None, | ||||
|         ) | ||||
|         with self.set_subgraph_body(body_name): | ||||
|             yield | ||||
|  | ||||
|     def def_kernel(self, *argnames): | ||||
|         """Define kernel function signature for CuteDSL templates.""" | ||||
|         # Populate all the kernel args | ||||
|         for i, input_node in enumerate(self.input_nodes): | ||||
|             self.args.input(input_node.get_name()) | ||||
|  | ||||
|         if self.output_node: | ||||
|             self.args.output(self.output_node.get_name()) | ||||
|  | ||||
|         def hook(): | ||||
|             code = IndentedBuffer() | ||||
|             code.writeline(f"# Kernel function signature: {self.kernel_name}") | ||||
|             code.writeline( | ||||
|                 f"def {self.kernel_name}_{MAIN_SUFFIX}({', '.join(argnames)}):" | ||||
|             ) | ||||
|             return code.getvalue() | ||||
|  | ||||
|         assert "<DEF_KERNEL>" not in self.render_hooks | ||||
|         self.render_hooks["<DEF_KERNEL>"] = hook | ||||
|         return "<DEF_KERNEL>" | ||||
|  | ||||
|     def call_kernel(self, name: str, node=None): | ||||
|         """Call the kernel function. Simplified version of TritonTemplateKernel.call_kernel.""" | ||||
|         wrapper = V.graph.wrapper_code | ||||
|         _, call_args, _, arg_types = self.args.python_argdefs() | ||||
|         # TODO triton should really be swapped w/ `python` | ||||
|         wrapper.generate_kernel_call(name, call_args, triton=True, arg_types=arg_types) | ||||
torch/_inductor/codegen/cutedsl/cutedsl_scheduling.py | 141 lines | Normal file
| @@ -0,0 +1,141 @@ | ||||
| # mypy: allow-untyped-defs | ||||
| import hashlib | ||||
| import logging | ||||
| from collections.abc import Sequence | ||||
| from typing import cast | ||||
|  | ||||
| from torch._inductor.utils import Placeholder | ||||
| from torch.utils._ordered_set import OrderedSet | ||||
|  | ||||
| from ... import config | ||||
| from ...codecache import code_hash, get_path | ||||
| from ...ir import CuteDSLTemplateBuffer | ||||
| from ...scheduler import ( | ||||
|     BaseSchedulerNode, | ||||
|     BaseScheduling, | ||||
|     FusedSchedulerNode, | ||||
|     SchedulerNode, | ||||
| ) | ||||
| from ...select_algorithm import PartialRender | ||||
| from ...utils import get_fused_kernel_name, get_kernel_metadata | ||||
| from ...virtualized import V | ||||
| from ..common import BackendFeature, IndentedBuffer | ||||
|  | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| class CuteDSLScheduling(BaseScheduling): | ||||
|     """ | ||||
|     Scheduling implementation for CuteDSL (CUTLASS Python DSL) kernels. | ||||
|     This class is intended to be used in combination with other schedulers, | ||||
|     and delegated to by CUDACombinedScheduling. | ||||
|     """ | ||||
|  | ||||
|     @classmethod | ||||
|     def get_backend_features(cls, device) -> OrderedSet[BackendFeature]: | ||||
|         return OrderedSet() | ||||
|  | ||||
|     @staticmethod | ||||
|     def is_cutedsl_template(node: BaseSchedulerNode) -> bool: | ||||
|         """Check if a node is a CuteDSL template.""" | ||||
|         return isinstance(node, SchedulerNode) and isinstance( | ||||
|             node.node, CuteDSLTemplateBuffer | ||||
|         ) | ||||
|  | ||||
|     def is_cutedsl_fused_template(self, node: BaseSchedulerNode) -> bool: | ||||
|         """Check if a node is a fused CuteDSL template.""" | ||||
|         return isinstance(node, FusedSchedulerNode) and self.is_cutedsl_template(node) | ||||
|  | ||||
|     def can_fuse_vertical( | ||||
|         self, node1: BaseSchedulerNode, node2: BaseSchedulerNode | ||||
|     ) -> bool: | ||||
|         """ | ||||
|         TODO CuteDSL doesn't support vertical fusion yet. | ||||
|         This could be extended in the future for epilogue fusion. | ||||
|         """ | ||||
|         return False | ||||
|  | ||||
|     def define_kernel(self, src_code_str: str, node_schedule) -> str: | ||||
|         """Define the kernel in the wrapper and return its name. | ||||
|         Args: | ||||
|             src_code_str: The finalized kernel code string | ||||
|             node_schedule: List of nodes in the schedule | ||||
|  | ||||
|         Note: | ||||
|             This is a little awkward since async_compile.cutedsl() also has to write | ||||
|             the string to a file in order to compile it with CuteDSL, so the source | ||||
|             effectively gets written to disk twice. | ||||
|         """ | ||||
|         wrapper = V.graph.wrapper_code | ||||
|  | ||||
|         # Use the string as the key for caching | ||||
|         if src_code_str in wrapper.src_to_kernel: | ||||
|             kernel_name = wrapper.src_to_kernel[src_code_str] | ||||
|         else: | ||||
|             fused_name = ( | ||||
|                 get_fused_kernel_name(node_schedule, config.triton.descriptive_names) | ||||
|                 if config.triton.descriptive_names | ||||
|                 else "" | ||||
|             ) | ||||
|  | ||||
|             kernel_hash = hashlib.sha256(src_code_str.encode("utf-8")).hexdigest()[:8] | ||||
|             if fused_name == "fused": | ||||
|                 kernel_name = f"cutedsl_{kernel_hash}" | ||||
|             else: | ||||
|                 kernel_name = f"cutedsl_{fused_name}_{kernel_hash}" | ||||
|             wrapper.src_to_kernel[src_code_str] = kernel_name | ||||
|             src_code_str = src_code_str.replace( | ||||
|                 str(Placeholder.KERNEL_NAME), kernel_name | ||||
|             ) | ||||
|  | ||||
|             _, _, kernel_path = get_path(code_hash(src_code_str), "py") | ||||
|  | ||||
|             compile_wrapper = IndentedBuffer() | ||||
|             compile_wrapper.writeline(f"async_compile.cutedsl({kernel_name!r}, r'''") | ||||
|             compile_wrapper.splice(src_code_str, strip=True) | ||||
|             compile_wrapper.writeline("''')") | ||||
|  | ||||
|             metadata_comment = f"# kernel path: {kernel_path}" | ||||
|             origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) | ||||
|             metadata_comment += "\n" + origins + "\n" + detailed_origins | ||||
|             wrapper.define_kernel( | ||||
|                 kernel_name, compile_wrapper.getvalue(), metadata_comment | ||||
|             ) | ||||
|         return kernel_name | ||||
|  | ||||
|     def codegen_template( | ||||
|         self, | ||||
|         template_node: BaseSchedulerNode, | ||||
|         epilogue_nodes: Sequence[BaseSchedulerNode], | ||||
|         prologue_nodes: Sequence[BaseSchedulerNode], | ||||
|     ): | ||||
|         """ | ||||
|         Codegen a CuteDSL template. Currently doesn't support fusion. | ||||
|         """ | ||||
|         assert self.is_cutedsl_template(template_node), ( | ||||
|             "Template node passed to CuteDSLScheduling.codegen_template must be a " | ||||
|             "SchedulerNode that wraps a CuteDSLTemplateBuffer" | ||||
|         ) | ||||
|         # TODO remove when supported | ||||
|         assert not epilogue_nodes, "CuteDSL doesn't support epilogue fusion yet" | ||||
|         assert not prologue_nodes, "CuteDSL doesn't support prologue fusion yet" | ||||
|  | ||||
|         template_node = cast(SchedulerNode, template_node) | ||||
|         ctb: CuteDSLTemplateBuffer = cast(CuteDSLTemplateBuffer, template_node.node) | ||||
|  | ||||
|         kernel, render = ctb.make_kernel_render(ctb)  # type: ignore[misc] | ||||
|         with kernel: | ||||
|             template_node.mark_run() | ||||
|             src_code = render() | ||||
|             # Finalize PartialRender if needed | ||||
|             if isinstance(src_code, PartialRender): | ||||
|                 src_code_str = src_code.finalize_all() | ||||
|             else: | ||||
|                 src_code_str = src_code | ||||
|  | ||||
|         with V.set_kernel_handler(kernel): | ||||
|             node_schedule = [template_node] | ||||
|             kernel_name = self.define_kernel(src_code_str, node_schedule) | ||||
|         kernel.call_kernel(kernel_name, ctb) | ||||
|         V.graph.removed_buffers |= kernel.removed_buffers | ||||
|         self.free_buffers_in_scheduler() | ||||
torch/_inductor/codegen/cutedsl/cutedsl_template.py | 228 lines | Normal file
| @@ -0,0 +1,228 @@ | ||||
| # mypy: allow-untyped-defs | ||||
| import functools | ||||
| import itertools | ||||
| from typing import Any, Callable, Optional, Union | ||||
|  | ||||
| import torch | ||||
| from torch._inductor.codecache import PyCodeCache | ||||
| from torch._inductor.ir import ShapeAsConstantBuffer | ||||
| from torch._inductor.select_algorithm import PartialRender | ||||
| from torch._inductor.utils import Placeholder | ||||
| from torch._logging import getArtifactLogger | ||||
|  | ||||
| from ...autotune_process import BenchmarkRequest, GPUDeviceBenchmarkMixin, TensorMeta | ||||
| from ...ir import Buffer, ChoiceCaller, CuteDSLTemplateBuffer, Layout, TensorBox | ||||
| from ..common import KernelTemplate | ||||
| from .cutedsl_kernel import CuteDSLTemplateKernel | ||||
|  | ||||
|  | ||||
| log = getArtifactLogger(__name__, "output_code") | ||||
|  | ||||
|  | ||||
| class CuteDSLBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest): | ||||
|     """Benchmark request for CuteDSL (CUTLASS Python DSL) kernels.""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         kernel_name: str, | ||||
|         input_tensor_meta: Union[TensorMeta, list[TensorMeta]], | ||||
|         output_tensor_meta: Union[TensorMeta, list[TensorMeta]], | ||||
|         extra_args: tuple[Any, ...], | ||||
|         source_code: PartialRender, | ||||
|     ) -> None: | ||||
|         super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) | ||||
|  | ||||
|         finalized_code = source_code.finalize_all() | ||||
|         self.module_cache_key, self.module_path = PyCodeCache.write(finalized_code) | ||||
|  | ||||
|     def make_run_fn( | ||||
|         self, *input_tensors: torch.Tensor, out: torch.Tensor | ||||
|     ) -> Callable[[], None]: | ||||
|         """ | ||||
|         Create a function to run the CuteDSL kernel with the given input and output tensors. | ||||
|         Similar to TritonBenchmarkRequest.make_run_fn but for CuteDSL kernels. | ||||
|         """ | ||||
|         mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path) | ||||
|  | ||||
|         # Logic replicated from async_compile.cutedsl() | ||||
|         from .cutedsl_kernel import MAIN_SUFFIX | ||||
|  | ||||
|         main_func_name = f"{self.kernel_name}_{MAIN_SUFFIX}" | ||||
|  | ||||
|         if not hasattr(mod, main_func_name): | ||||
|             available = [name for name in dir(mod) if callable(getattr(mod, name))] | ||||
|             raise RuntimeError( | ||||
|                 f"Could not find CuteDSL main kernel function '{main_func_name}'. Available callables: {available}" | ||||
|             ) | ||||
|  | ||||
|         kernel_func = getattr(mod, main_func_name) | ||||
|  | ||||
|         def run_kernel(): | ||||
|             return kernel_func(*input_tensors, out) | ||||
|  | ||||
|         return run_kernel | ||||
|  | ||||
|     def cleanup_run_fn(self) -> None: | ||||
|         """Clean up any resources used by the kernel.""" | ||||
|  | ||||
|  | ||||
| class CuteDSLTemplate(KernelTemplate): | ||||
|     """Template for generating CuteDSL (CUTLASS Python DSL) kernels.""" | ||||
|  | ||||
|     kernel_type: type[Any] = CuteDSLTemplateKernel | ||||
|     index_counter = itertools.count() | ||||
|     all_templates: dict[str, "CuteDSLTemplate"] = {} | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         name: str, | ||||
|         source: str, | ||||
|         subgraph_fn: Optional[Any] = None, | ||||
|         mask_fn: Optional[Any] = None, | ||||
|     ) -> None: | ||||
|         super().__init__(name) | ||||
|         self.source = source | ||||
|         self.subgraph_fn = subgraph_fn | ||||
|         self.mask_fn = mask_fn | ||||
|         self.template = CuteDSLTemplate._template_from_string(source) | ||||
|         assert name not in self.all_templates, f"duplicate template name, {name}" | ||||
|         CuteDSLTemplate.all_templates[name] = self | ||||
|  | ||||
|     @staticmethod | ||||
|     @functools.lru_cache(None) | ||||
|     def _template_from_string(source: str) -> Any: | ||||
|         return KernelTemplate._template_from_string(source) | ||||
|  | ||||
|     def maybe_append_choice( | ||||
|         self, choices: list[Any], **kwargs: Any | ||||
|     ) -> Optional[NotImplementedError]: | ||||
|         """ | ||||
|         Maybe generates a new ChoiceCaller and appends it into existing choices. | ||||
|         Returns None if success, otherwise returns the error. | ||||
|         """ | ||||
|         try: | ||||
|             choices.append(self.generate(**kwargs)) | ||||
|             return None | ||||
|         except NotImplementedError as e: | ||||
|             log.debug("CuteDSL template choice generation failed: %s", e) | ||||
|             return e | ||||
|         except Exception as e: | ||||
|             log.debug("CuteDSL template choice generation error: %s", e) | ||||
|             return NotImplementedError(f"CuteDSL template failed: {e}") | ||||
|  | ||||
|     def generate(self, **kwargs: Any) -> ChoiceCaller: | ||||
|         """Generate the CuteDSL kernel caller.""" | ||||
|         input_nodes = kwargs.pop("input_nodes") | ||||
|         layout = kwargs.pop("layout") | ||||
|  | ||||
|         kernel_name = f"cutedsl_{self.name}_{next(self.index_counter)}" | ||||
|  | ||||
|         if self.template is None: | ||||
|             raise RuntimeError("Template compilation failed (Jinja2 required)") | ||||
|  | ||||
|         self.output_node: Buffer = Buffer(name="buf_out", layout=layout) | ||||
|  | ||||
|         kernel = self.kernel_type( | ||||
|             kernel_name=kernel_name, | ||||
|             input_nodes=input_nodes, | ||||
|             output_node=self.output_node, | ||||
|         ) | ||||
|  | ||||
|         code = kernel.render(self.template, **kwargs) | ||||
|  | ||||
|         log.debug("Generated CuteDSL Code:\n%s", code) | ||||
|  | ||||
|         bmreq = CuteDSLBenchmarkRequest( | ||||
|             kernel_name=kernel_name, | ||||
|             input_tensor_meta=TensorMeta.from_irnodes(input_nodes), | ||||
|             output_tensor_meta=TensorMeta.from_irnodes(self.output_node), | ||||
|             extra_args=tuple(), | ||||
|             source_code=code, | ||||
|         ) | ||||
|  | ||||
|         def make_kernel_render(out_node, hint_override: Optional[int] = None): | ||||
|             render_kernel = self.kernel_type( | ||||
|                 kernel_name=str(Placeholder.KERNEL_NAME), | ||||
|                 input_nodes=input_nodes, | ||||
|                 output_node=out_node, | ||||
|             ) | ||||
|  | ||||
|             def render(): | ||||
|                 return render_kernel.render(self.template, **kwargs) | ||||
|  | ||||
|             return render_kernel, render | ||||
|  | ||||
|         return CuteDSLTemplateCaller( | ||||
|             name=kernel_name, | ||||
|             input_nodes=input_nodes, | ||||
|             layout=layout, | ||||
|             make_kernel_render=make_kernel_render, | ||||
|             bmreq=bmreq, | ||||
|             template=self, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class CuteDSLTemplateCaller(ChoiceCaller): | ||||
|     """Caller for CuteDSL templates that integrates with the autotuning system.""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         name: str, | ||||
|         input_nodes: list[Buffer], | ||||
|         layout: Layout, | ||||
|         make_kernel_render: Any, | ||||
|         bmreq: CuteDSLBenchmarkRequest, | ||||
|         template: "CuteDSLTemplate", | ||||
|     ): | ||||
|         super().__init__( | ||||
|             name=name, | ||||
|             input_nodes=input_nodes, | ||||
|             layout=layout, | ||||
|             description=f"CuteDSL template {name}", | ||||
|         ) | ||||
|         self.make_kernel_render = make_kernel_render | ||||
|         self.bmreq = bmreq | ||||
|         self.template = template | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|         return f"CuteDSLTemplateCaller({self.name})" | ||||
|  | ||||
|     def benchmark(self, *args, out) -> float: | ||||
|         """Benchmark the kernel execution.""" | ||||
|         return self.bmreq.benchmark(*args, out=out) | ||||
|  | ||||
|     def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]: | ||||
|         """Create the output node for this template choice.""" | ||||
|         return TensorBox.create( | ||||
|             CuteDSLTemplateBuffer( | ||||
|                 layout=self.layout, | ||||
|                 inputs=self.input_nodes, | ||||
|                 make_kernel_render=self.make_kernel_render, | ||||
|                 template=self.template, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     def call_name(self) -> str: | ||||
|         """Return the kernel call name.""" | ||||
|         return self.name | ||||
|  | ||||
|     def to_callable(self) -> Any: | ||||
|         """Return callable that can execute this kernel.""" | ||||
|         return self.make_kernel_render | ||||
|  | ||||
|     def hash_key(self) -> str: | ||||
|         """Return unique hash key for this choice.""" | ||||
|         return "-".join( | ||||
|             [ | ||||
|                 self.name.rsplit("_", 1)[0], | ||||
|                 self.bmreq.module_cache_key, | ||||
|             ] | ||||
|         ) | ||||
|  | ||||
|     def info_dict(self) -> dict[str, Any]: | ||||
|         """Return information about this kernel.""" | ||||
|         return { | ||||
|             "name": self.name, | ||||
|             "backend": "CuteDSL", | ||||
|             "template": self.template.name, | ||||
|         } | ||||
torch/_inductor/ir.py
| @@ -5094,6 +5094,23 @@ class CppTemplateBuffer(TemplateBuffer): | ||||
|             return super().get_layout() | ||||
|  | ||||
|  | ||||
| class CuteDSLTemplateBuffer(TemplateBuffer): | ||||
|     """ | ||||
|     Buffer for CuteDSL (CUTLASS Python DSL) template kernels. | ||||
|     Similar to other template buffers but specialized for CuteDSL operations. | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         layout: Layout, | ||||
|         inputs: Sequence[IRNode], | ||||
|         make_kernel_render: Callable[_P, _T], | ||||
|         template: Any, | ||||
|     ) -> None: | ||||
|         super().__init__(layout, inputs, make_kernel_render) | ||||
|         self.template = template | ||||
|  | ||||
|  | ||||
| def is_node_sequence( | ||||
|     nodes: Sequence[Union[IRNode, Sequence[IRNode]]], | ||||
| ) -> TypeIs[Sequence[IRNode]]: | ||||
|  | ||||
torch/_inductor/kernel/flex/flex_attention.py
| @@ -39,6 +39,10 @@ from .common import ( | ||||
| ) | ||||
| from .flex_cpu import lower_cpu | ||||
| from .flex_decoding import _use_flex_decoding, create_flex_decoding_kernel | ||||
| from .flex_flash_attention import ( | ||||
|     _use_flex_flash_attention, | ||||
|     create_flex_flash_attention_kernel, | ||||
| ) | ||||
|  | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
| @@ -437,6 +441,19 @@ def flex_attention( | ||||
|             score_mod_other_buffers, | ||||
|             mask_mod_other_buffers, | ||||
|         ) | ||||
|     if _use_flex_flash_attention(subgraph, mask_graph, kernel_options): | ||||
|         return create_flex_flash_attention_kernel( | ||||
|             query, | ||||
|             key, | ||||
|             value, | ||||
|             block_mask, | ||||
|             scale, | ||||
|             kernel_options, | ||||
|             subgraph_buffer, | ||||
|             mask_graph_buffer, | ||||
|             score_mod_other_buffers, | ||||
|             mask_mod_other_buffers, | ||||
|         ) | ||||
|  | ||||
|     ( | ||||
|         query, | ||||
|  | ||||
torch/_inductor/kernel/flex/flex_flash_attention.py | 126 lines | Normal file
| @@ -0,0 +1,126 @@ | ||||
| # mypy: allow-untyped-defs | ||||
| """Call into flash-attention 4 for flexattention""" | ||||
|  | ||||
| from typing import Any | ||||
|  | ||||
| import torch | ||||
| from torch.fx import GraphModule | ||||
|  | ||||
| from ...ir import FallbackKernel, ShapeAsConstantBuffer, Subgraph, TensorBox | ||||
| from .common import SubgraphResults | ||||
|  | ||||
|  | ||||
| aten = torch.ops.aten | ||||
| prims = torch.ops.prims | ||||
|  | ||||
| try: | ||||
|     from flash_attn.cute import flash_attn_func  # type: ignore[import-not-found] | ||||
|  | ||||
|     CUTE_AVAILABLE = True | ||||
| except ImportError: | ||||
|     flash_attn_func = None | ||||
|     CUTE_AVAILABLE = False | ||||
|  | ||||
|  | ||||
| def is_trivial_graph(graph_module: GraphModule, is_score_graph: bool): | ||||
|     """Check if the flex graphs are trivial""" | ||||
|     graph = graph_module.graph | ||||
|     nodes = list(graph.nodes) | ||||
|     # Check if it's just placeholder -> output | ||||
|     placeholders = [n for n in nodes if n.op == "placeholder"] | ||||
|     output = [n for n in nodes if n.op == "output"] | ||||
|     assert len(output) == 1, "Got graph w/ multiple outputs" | ||||
|     output_val = output[0].args[0] | ||||
|     if is_score_graph: | ||||
|         return len(placeholders) == 5 and output_val == placeholders[0] | ||||
|     # mask mod graph is empty if we have 4 inputs and full_default output | ||||
|     return len(placeholders) == 4 and output_val.target == torch.ops.aten.full.default | ||||
|  | ||||
|  | ||||
| def _use_flex_flash_attention( | ||||
|     subgraph: Subgraph, mask_graph: Subgraph, kernel_options: dict[str, Any] | ||||
| ) -> bool: | ||||
|     """Determine if we can use flex flash attention for the given inputs.""" | ||||
|     if not CUTE_AVAILABLE: | ||||
|         return False | ||||
|     if kernel_options.get("disable_flash", False): | ||||
|         return False | ||||
|     if is_trivial_graph(subgraph.graph_module, True) and is_trivial_graph( | ||||
|         mask_graph.graph_module, False | ||||
|     ): | ||||
|         return True | ||||
|  | ||||
|     return False | ||||
|  | ||||
|  | ||||
| @torch.library.custom_op("flex_flash_attn::flash_attn_fwd", mutates_args=()) | ||||
| def flash_attention_forward_kernel( | ||||
|     query: torch.Tensor, | ||||
|     key: torch.Tensor, | ||||
|     value: torch.Tensor, | ||||
|     scale: float, | ||||
|     causal: bool = False, | ||||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||||
|     """Minimal flash attention forward kernel using CUTE implementation.""" | ||||
|     if not CUTE_AVAILABLE: | ||||
|         raise RuntimeError("CUTE flash attention not available") | ||||
|     assert flash_attn_func is not None | ||||
|  | ||||
|     q_transposed = query.transpose(1, 2) | ||||
|     k_transposed = key.transpose(1, 2) | ||||
|     v_transposed = value.transpose(1, 2) | ||||
|  | ||||
|     output, lse = flash_attn_func( | ||||
|         q_transposed, | ||||
|         k_transposed, | ||||
|         v_transposed, | ||||
|         softmax_scale=scale, | ||||
|         causal=causal, | ||||
|     ) | ||||
|  | ||||
|     return output.transpose(1, 2), lse | ||||
|  | ||||
|  | ||||
| @torch.library.register_fake("flex_flash_attn::flash_attn_fwd")  # type: ignore[misc] | ||||
| def flex_flash_attn_fwd_fake( | ||||
|     query: torch.Tensor, | ||||
|     key: torch.Tensor, | ||||
|     value: torch.Tensor, | ||||
|     scale: float, | ||||
|     causal: bool = False, | ||||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||||
|     """Fake implementation for the custom op.""" | ||||
|     batch_size, num_heads, seqlen_q, head_dim = query.shape | ||||
|  | ||||
|     out = query.new_empty(batch_size, seqlen_q, num_heads, head_dim).transpose(1, 2) | ||||
|     lse = query.new_empty(batch_size, num_heads, seqlen_q, dtype=torch.float32) | ||||
|  | ||||
|     return out, lse | ||||
|  | ||||
|  | ||||
| def create_flex_flash_attention_kernel( | ||||
|     query: TensorBox, | ||||
|     key: TensorBox, | ||||
|     value: TensorBox, | ||||
|     block_mask: tuple[Any, ...], | ||||
|     scale: float, | ||||
|     kernel_options: dict[str, Any], | ||||
|     subgraph_buffer: SubgraphResults, | ||||
|     mask_graph_buffer: SubgraphResults, | ||||
|     score_mod_other_buffers: list[TensorBox], | ||||
|     mask_mod_other_buffers: list[TensorBox], | ||||
| ) -> tuple[TensorBox | ShapeAsConstantBuffer, TensorBox | ShapeAsConstantBuffer]: | ||||
|     """Create a flex flash attention kernel.""" | ||||
|     if not CUTE_AVAILABLE: | ||||
|         raise RuntimeError("CUTE flash attention not available") | ||||
|  | ||||
|     outputs = FallbackKernel.create( | ||||
|         torch.ops.flex_flash_attn.flash_attn_fwd.default, | ||||
|         query, | ||||
|         key, | ||||
|         value, | ||||
|         scale=scale, | ||||
|         causal=False, | ||||
|     ) | ||||
|     assert isinstance(outputs, (tuple, list)) | ||||
|     return TensorBox.create(outputs[0]), TensorBox.create(outputs[1]) | ||||
torch/nn/attention/flex_attention.py
| @@ -198,6 +198,9 @@ class FlexKernelOptions(TypedDict, total=False): | ||||
|     waves_per_eu: NotRequired[int] | ||||
|     """ROCm-specific waves per execution unit.""" | ||||
|  | ||||
|     disable_flash: NotRequired[bool] | ||||
|     """If True, do not attempt to use the CuteDSL flash attention kernel.""" | ||||
|  | ||||
|  | ||||
| class _ModificationType(Enum): | ||||
|     """Enum for the type of modification function. | ||||
|  | ||||
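The new `disable_flash` option is consumed by `_use_flex_flash_attention()` via `kernel_options.get("disable_flash", False)`. A usage sketch with the public `flex_attention` API (shapes are illustrative):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)

# Skip the CuteDSL flash kernel and fall back to the Triton templates.
compiled = torch.compile(flex_attention)
out = compiled(q, k, v, kernel_options={"disable_flash": True})
```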