Add cutedsl template support to compile (#160108)

## Summary
Still figuring out what actually writing a template should look like, but this lands a lot of the base infra

<img width="1267" height="262" alt="Screenshot 2025-08-16 at 10 22 12 PM" src="https://github.com/user-attachments/assets/229f8bfa-0cb4-4fb1-8530-f535e569d350" />

Test code:

```Python
#!/usr/bin/env python3
"""
Fixed CuteDSL template test with proper def_kernel usage.
"""

import torch
from torch._inductor.lowering import lowerings
from torch._inductor.ir import TensorBox
from torch._inductor.select_algorithm import autotune_select_algorithm
from torch._inductor.codegen.cutedsl import CuteDSLTemplate

def create_fixed_cutedsl_template():
    """Create a properly structured CuteDSL template."""

    # Part 1: Imports and kernel definition
    template_part1 = r"""
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.kernel
def {{kernel_name}}_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor):
    # Get thread and block indices
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdim, _, _ = cute.arch.block_dim()

    thread_idx = bidx * bdim + tidx
    m, n = gA.shape

    if thread_idx < m * n:
        mi = thread_idx // n
        ni = thread_idx % n

        if mi < m and ni < n:
            a_val = gA[mi, ni]
            b_val = gB[mi, ni]
            gC[mi, ni] = a_val + b_val
"""

    # Part 2: JIT wrapper function
    template_part2 = r"""
@cute.jit
def {{kernel_name}}_jit(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor):
    m, n = mA.shape
    total_threads = m * n
    threads_per_block = 256
    num_blocks = (total_threads + threads_per_block - 1) // threads_per_block

    kernel = {{kernel_name}}_kernel(mA, mB, mC)
    kernel.launch(
        grid=[num_blocks, 1, 1],
        block=[threads_per_block, 1, 1]
    )
"""

    # Part 3: Main kernel function
    template_part3 = r"""
{{def_kernel("input_a", "input_b", "output_c")}}
    cute_a = from_dlpack(input_a, assumed_align=16)
    cute_b = from_dlpack(input_b, assumed_align=16)
    cute_c = from_dlpack(output_c, assumed_align=16)

    # Launch kernel
    {{kernel_name}}_jit(cute_a, cute_b, cute_c)

    return output_c
"""

    # Combine all parts
    template = CuteDSLTemplate(
        name="fixed_add",
        source=template_part1 + template_part2 + template_part3
    )

    return template

def fixed_cutedsl_lowering(a: TensorBox, b: TensorBox) -> TensorBox:
    """Fixed CuteDSL lowering."""
    print(f"[FIXED] CuteDSL lowering: {a.get_size()} + {b.get_size()}")

    template = create_fixed_cutedsl_template()
    choices = []

    error = template.maybe_append_choice(
        choices,
        input_nodes=[a.data, b.data],
        layout=a.get_layout()
    )

    if error or not choices:
        print(f"[FIXED] Falling back: {error}")
        default_lowering = lowerings[torch.ops.aten.add.Tensor]
        return default_lowering(a, b)

    print(f"[FIXED] Using CuteDSL with {len(choices)} choices")

    result = autotune_select_algorithm(
        "fixed_cutedsl_add",
        choices,
        [a, b],
        a.get_layout(),
    )

    return result

def test_fixed_cutedsl():
    """Test the fixed CuteDSL template."""
    print("=" * 50)
    print("Fixed CuteDSL Template Test")
    print("=" * 50)

    original = lowerings.get(torch.ops.aten.add.Tensor, None)

    try:
        lowerings[torch.ops.aten.add.Tensor] = fixed_cutedsl_lowering

        def test_add(x, y):
            return x + y

        device = "cuda" if torch.cuda.is_available() else "cpu"
        x = torch.randn(128, 4, device=device, dtype=torch.float32)
        y = torch.randn(128, 4, device=device, dtype=torch.float32)

        print(f"[FIXED] Testing with {x.shape} tensors on {device}")

        compiled_fn = torch.compile(test_add, backend="inductor")
        result = compiled_fn(x, y)

        # Verify correctness
        expected = x + y
        if torch.allclose(result, expected, atol=1e-5):
            print(" [FIXED] Results match!")
            return True
        else:
            print(" [FIXED] Results don't match!")
            return False

    except Exception as e:
        print(f" [FIXED] Failed: {e}")
        import traceback
        traceback.print_exc()
        return False

    finally:
        if original:
            lowerings[torch.ops.aten.add.Tensor] = original
        else:
            lowerings.pop(torch.ops.aten.add.Tensor, None)

if __name__ == "__main__":
    success = test_fixed_cutedsl()
    print("🎉 Fixed test completed!" if success else "💥 Fixed test failed!")

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160108
Approved by: https://github.com/mlazos
Commit 3c6efd1380 by drisspg, 2025-08-16 22:20:45 -07:00 (parent d18007a1d0), committed by PyTorch MergeBot.
10 changed files with 1108 additions and 1 deletions

test/inductor/test_cutedsl_template.py
@@ -0,0 +1,319 @@
# Owner(s): ["module: inductor"]
import unittest
from unittest.mock import MagicMock, patch

import torch
from torch._inductor.test_case import TestCase


try:
    import cutlass  # noqa: F401
    import cutlass.cute as cute  # noqa: F401

    HAS_CUTLASS = True
except ImportError:
    HAS_CUTLASS = False

if HAS_CUTLASS:
    from torch._inductor.codegen.cutedsl.cutedsl_kernel import CuteDSLTemplateKernel
    from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
    from torch._inductor.select_algorithm import PartialRender

CUTEDSL_ADD_TEMPLATE = r"""
{{gen_defines()}}

@cute.kernel
def {{kernel_name}}_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdim, _, _ = cute.arch.block_dim()

    thread_idx = bidx * bdim + tidx
    m, n = gA.shape

    if thread_idx < m * n:
        mi = thread_idx // n
        ni = thread_idx % n
        if mi < m and ni < n:
            gC[mi, ni] = gA[mi, ni] + gB[mi, ni]


@cute.jit
def {{kernel_name}}_jit(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor, stream):
    {{gen_defines()}}
    m, n = mA.shape
    total_threads = m * n
    num_blocks = (total_threads + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK

    kernel = {{kernel_name}}_kernel(mA, mB, mC)
    kernel.launch(
        grid=[num_blocks, 1, 1],
        block=[THREADS_PER_BLOCK, 1, 1],
        stream=stream
    )


{{def_kernel("input_a", "input_b", "output_c")}}
    cute_a = from_dlpack(input_a)
    cute_b = from_dlpack(input_b)
    cute_c = from_dlpack(output_c)

    {{kernel_name}}_jit(cute_a, cute_b, cute_c, cuda.CUstream(stream))
    return output_c
"""


@unittest.skipUnless(HAS_CUTLASS, "requires cutlass")
class TestCuteDSLTemplate(TestCase):
    """Test cases for CuteDSL template functionality."""

    def test_gen_imports(self):
        kernel = CuteDSLTemplateKernel(
            kernel_name="test_kernel",
            input_nodes=[],
            output_node=None,
        )
        imports = kernel.gen_imports()

        self.assertIn("import torch", imports)
        self.assertIn("import cutlass", imports)
        self.assertIn("import cutlass.cute as cute", imports)
        self.assertIn("from cutlass.cute.runtime import from_dlpack", imports)

        self.assertIsInstance(imports, str)
        lines = imports.strip().split("\n")
        self.assertEqual(len(lines), 5)

    def test_render_includes_imports(self):
        template_source = """@cute.kernel
def {{kernel_name}}_kernel():
    pass

{{def_kernel("input", "output")}}
    return output"""
        mock_template = MagicMock()
        mock_template.render = MagicMock(return_value=template_source)

        kernel = CuteDSLTemplateKernel(
            kernel_name="test_kernel",
            input_nodes=[],
            output_node=None,
        )
        result = kernel.render(mock_template)
        self.assertIsInstance(result, PartialRender)
        rendered_code = result._code

        # The imports might have leading whitespace, so strip it
        rendered_code_stripped = rendered_code.lstrip()
        self.assertTrue(
            rendered_code_stripped.startswith("import torch"),
            f"Code should start with 'import torch', got: {rendered_code_stripped[:50]}",
        )
        self.assertIn("import cutlass", rendered_code)
        self.assertIn("import cutlass.cute as cute", rendered_code)
        self.assertIn("from cutlass.cute.runtime import from_dlpack", rendered_code)
        self.assertIn("@cute.kernel", rendered_code)

    def test_template_env_contains_hooks(self):
        kernel = CuteDSLTemplateKernel(
            kernel_name="test_kernel",
            input_nodes=[],
            output_node=None,
        )
        captured_env = {}

        def mock_render(**kwargs):
            captured_env.update(kwargs)
            return "rendered"

        mock_template = MagicMock()
        mock_template.render = mock_render

        kernel.render(mock_template)
        self.assertIn("def_kernel", captured_env)
        self.assertIn("kernel_name", captured_env)
        self.assertTrue(callable(captured_env["def_kernel"]))

    def test_multiple_templates_unique_names(self):
        # Clean registry first
        test_name = f"unique_test_{id(self)}"
        if test_name in CuteDSLTemplate.all_templates:
            del CuteDSLTemplate.all_templates[test_name]

        _ = CuteDSLTemplate(
            name=test_name,
            source="template1",
        )
        with self.assertRaises(AssertionError):
            _ = CuteDSLTemplate(
                name=test_name,
                source="template2",
            )

    def test_indented_buffer_usage(self):
        kernel = CuteDSLTemplateKernel(
            kernel_name="test_kernel",
            input_nodes=[],
            output_node=None,
        )
        imports = kernel.gen_imports()
        lines = imports.strip().split("\n")
        for line in lines:
            if line:
                self.assertFalse(
                    line.startswith(" "), f"Line should not be indented: '{line}'"
                )

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_cutedsl_add_e2e(self):
        """End-to-end test with CuteDSL template including code generation verification."""
        from torch._inductor.ir import TensorBox
        from torch._inductor.lowering import lowerings
        from torch._inductor.utils import run_and_get_code

        template = CuteDSLTemplate(
            name="test_add_e2e",
            source=CUTEDSL_ADD_TEMPLATE,
        )

        def cutedsl_add_lowering(a: TensorBox, b: TensorBox) -> TensorBox:
            choices = []
            error = template.maybe_append_choice(
                choices,
                input_nodes=[a, b],
                layout=a.get_layout(),
                THREADS_PER_BLOCK=256,
            )
            if error or not choices:
                default_lowering = lowerings[torch.ops.aten.add.Tensor]
                return default_lowering(a, b)
            # Use the single choice directly (no autotuning)
            return choices[0].output_node()

        with patch.dict(lowerings, {torch.ops.aten.add.Tensor: cutedsl_add_lowering}):
            # Test function
            def test_add(x, y):
                return x + y

            device = "cuda"
            x = torch.randn(128, 4, device=device, dtype=torch.float32)
            y = torch.randn(128, 4, device=device, dtype=torch.float32)

            # Compile and get generated code
            compiled_fn = torch.compile(test_add, backend="inductor")
            result, (code,) = run_and_get_code(compiled_fn, x, y)

            # Verify CuteDSL code is present
            self.assertIn(
                "cute", code.lower(), "CuteDSL code should be in generated code"
            )
            # Verify parameter generation worked
            self.assertIn(
                "THREADS_PER_BLOCK", code, "Parameter should be in generated code"
            )

            # Verify correctness
            expected = x + y
            self.assertTrue(torch.allclose(result, expected, atol=1e-5))

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_cutedsl_add_e2e_autotune(self):
        """E2E test with multiple CuteDSL template variants for autotuning."""
        from torch._inductor.ir import TensorBox
        from torch._inductor.lowering import lowerings
        from torch._inductor.select_algorithm import autotune_select_algorithm

        template = CuteDSLTemplate(
            name="test_add_autotune",
            source=CUTEDSL_ADD_TEMPLATE,
        )

        def cutedsl_add_lowering(a: TensorBox, b: TensorBox) -> TensorBox:
            choices = []
            # Add multiple variants with different thread counts for autotuning
            thread_variants = [128, 256, 512]
            for threads in thread_variants:
                error = template.maybe_append_choice(
                    choices,
                    input_nodes=[a, b],
                    layout=a.get_layout(),
                    THREADS_PER_BLOCK=threads,
                )
                if error:
                    # Skip this variant if it fails
                    continue
            if not choices:
                default_lowering = lowerings[torch.ops.aten.add.Tensor]
                return default_lowering(a, b)
            # Use autotuning to select the best variant
            return autotune_select_algorithm(
                "cutedsl_add_autotune",
                choices,
                [a, b],
                a.get_layout(),
            )

        with patch.dict(lowerings, {torch.ops.aten.add.Tensor: cutedsl_add_lowering}):
            # Test function
            def test_add(x, y):
                return x + y

            device = "cuda"
            x = torch.randn(128, 128, device=device, dtype=torch.float32)
            y = torch.randn(128, 128, device=device, dtype=torch.float32)

            # Compile and run
            compiled_fn = torch.compile(test_add, backend="inductor")
            result = compiled_fn(x, y)

            # Verify correctness
            expected = x + y
            self.assertTrue(torch.allclose(result, expected, atol=1e-5))

    def test_gen_defines(self):
        """Test that gen_defines correctly generates CuteDSL parameter definitions."""
        kernel = CuteDSLTemplateKernel(
            kernel_name="test_kernel",
            input_nodes=[],
            output_node=None,
        )

        # Test integer parameters
        params = kernel.gen_defines(
            THREADS_PER_BLOCK=256,
            BLOCK_SIZE=128,
            ENABLE_FEATURE=True,
        )
        expected_lines = [
            "THREADS_PER_BLOCK: cutlass.Constexpr = 256",
            "BLOCK_SIZE: cutlass.Constexpr = 128",
            "ENABLE_FEATURE: cutlass.Constexpr = True",
        ]
        for expected_line in expected_lines:
            self.assertIn(expected_line, params)

        # Test float parameters
        params_float = kernel.gen_defines(SCALE_FACTOR=1.5)
        self.assertIn("SCALE_FACTOR: cutlass.Constexpr = 1.5", params_float)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    run_tests()

torch/_inductor/async_compile.py
@@ -569,6 +569,45 @@ class AsyncCompile:
        )
        return LambdaFuture(get_result)

    def cutedsl(self, kernel_name: str, source_code: str):
        """
        Compile CuteDSL (CUTLASS Python DSL) kernels.

        Args:
            kernel_name: Name of the kernel to be defined
            source_code: Source code of the CuteDSL kernel, as a string

        Note:
            CuteDSL currently requires source files to do its compilation, therefore we
            use the PyCodeCache to write the source code to a file and load it.
        """
        from torch._inductor.codegen.cutedsl.cutedsl_kernel import (
            CuteDSLKernelWrapper,
            MAIN_SUFFIX,
        )

        kernel_code_log.info("CuteDSL Kernel:\n%s", source_code)

        def task():
            key, path = torch._inductor.codecache.PyCodeCache.write(source_code)
            mod = torch._inductor.codecache.PyCodeCache.load_by_key_path(key, path)
            # Find our special entry-point function by name
            main_func_name = f"{kernel_name}_{MAIN_SUFFIX}"
            if not hasattr(mod, main_func_name):
                available = [name for name in dir(mod) if callable(getattr(mod, name))]
                raise RuntimeError(
                    f"Could not find CuteDSL main kernel function '{main_func_name}'. Available callables: {available}"
                )
            return CuteDSLKernelWrapper(getattr(mod, main_func_name), kernel_path=path)

        if get_compile_threads() <= 1:
            return task()
        else:
            future = self.submit(task)
            return LambdaFuture(lambda: future.result())

    def wait(self, scope: dict[str, Any]) -> None:
        if get_compile_threads() > 1:
            with dynamo_timed(

torch/_inductor/autotune_process.py
@@ -44,7 +44,7 @@ from torch.utils._ordered_set import OrderedSet
if TYPE_CHECKING:
    from types import ModuleType

-    from torch._inductor.select_algorithm import TritonTemplateCaller
+    from torch._inductor.select_algorithm import PartialRender, TritonTemplateCaller

from . import config
from .runtime.benchmarking import benchmarker

@@ -876,6 +876,55 @@ class CppBenchmarkRequest(CPUDeviceBenchmarkMixin, BenchmarkRequest):
        return f"{self.kernel_name=}"


class CuteDSLBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
    """Benchmark request for CuteDSL (CUTLASS Python DSL) kernels."""

    def __init__(
        self,
        kernel_name: str,
        input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
        output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
        extra_args: tuple[Any, ...],
        source_code: PartialRender,
    ) -> None:
        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
        finalized_code = source_code.finalize_all()
        self.module_cache_key, self.module_path = PyCodeCache.write(finalized_code)

    def make_run_fn(
        self, *input_tensors: torch.Tensor, out: torch.Tensor
    ) -> Callable[[], None]:
        """
        Create a function to run the CuteDSL kernel with the given input and output tensors.
        Similar to TritonBenchmarkRequest.make_run_fn but for CuteDSL kernels.
        """
        mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)

        # Logic replicated from async_compile
        from .codegen.cutedsl.cutedsl_kernel import MAIN_SUFFIX

        main_func_name = f"{self.kernel_name}_{MAIN_SUFFIX}"
        if not hasattr(mod, main_func_name):
            available = [name for name in dir(mod) if callable(getattr(mod, name))]
            raise RuntimeError(
                f"Could not find CuteDSL main kernel function '{main_func_name}'. Available callables: {available}"
            )
        kernel_func = getattr(mod, main_func_name)

        def run_kernel():
            device_interface = get_interface_for_device("cuda")
            stream = device_interface.get_raw_stream(out.device.index)
            return kernel_func(*input_tensors, out, stream=stream)

        return run_kernel

    def cleanup_run_fn(self) -> None:
        """Clean up any resources used by the kernel."""


@functools.cache
def get_tuning_process_pool() -> TuningProcessPool:
    pool = TuningProcessPool()

torch/_inductor/codegen/cuda_combined_scheduling.py
@@ -11,6 +11,7 @@ from ..scheduler import (
    SchedulerNode,
)
from .cuda.cuda_cpp_scheduling import CUDACPPScheduling
from .cutedsl.cutedsl_scheduling import CuteDSLScheduling
from .rocm.rocm_cpp_scheduling import ROCmCPPScheduling
from .triton import TritonScheduling

@@ -44,6 +45,7 @@ class CUDACombinedScheduling(BaseScheduling):
        self._triton_scheduling = TritonScheduling(scheduler)
        self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler)
        self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler)
        self._cutedsl_scheduling = CuteDSLScheduling(scheduler)

    def get_backend_features(self, device: torch.device) -> OrderedSet[BackendFeature]:
        return self._triton_scheduling.get_backend_features(device)

@@ -53,6 +55,8 @@ class CUDACombinedScheduling(BaseScheduling):
            return self._cuda_cpp_scheduling
        if self._rocm_cpp_scheduling.is_rocm_cpp_template(node):
            return self._rocm_cpp_scheduling
        if self._cutedsl_scheduling.is_cutedsl_template(node):
            return self._cutedsl_scheduling
        return self._triton_scheduling

    def can_fuse_vertical(

@@ -64,6 +68,11 @@ class CUDACombinedScheduling(BaseScheduling):
            node1
        ) or self._cuda_cpp_scheduling.is_cuda_cpp_template(node2):
            return False
        # CuteDSL doesn't support vertical fusion currently
        elif self._cutedsl_scheduling.is_cutedsl_template(
            node1
        ) or self._cutedsl_scheduling.is_cutedsl_template(node2):
            return False
        return self._triton_scheduling.can_fuse_vertical(node1, node2)

    def can_fuse_horizontal(

@@ -74,6 +83,10 @@ class CUDACombinedScheduling(BaseScheduling):
                return self._cuda_cpp_scheduling.can_fuse_horizontal(
                    node1, node2
                )  # always False at the moment
            if self._cutedsl_scheduling.is_cutedsl_template(node):
                return self._cutedsl_scheduling.can_fuse_horizontal(
                    node1, node2
                )  # always False at the moment
        return self._triton_scheduling.can_fuse_horizontal(node1, node2)

    def group_fn(

@@ -98,6 +111,13 @@ class CUDACombinedScheduling(BaseScheduling):
            return self._rocm_cpp_scheduling.codegen_template(
                template_node, epilogue_nodes, prologue_nodes
            )
        elif self._cutedsl_scheduling.is_cutedsl_template(template_node):
            # TODO remove this when we add epilogue support
            assert not epilogue_nodes
            assert not prologue_nodes
            return self._cutedsl_scheduling.codegen_template(
                template_node, epilogue_nodes, prologue_nodes
            )
        else:
            return self._triton_scheduling.codegen_template(
                template_node, epilogue_nodes, prologue_nodes

torch/_inductor/codegen/cutedsl/README.md
@@ -0,0 +1,101 @@
# CuteDSL Template System

## Quick Start

Writing a CuteDSL template:

```python
from torch._inductor.codegen.cutedsl import CuteDSLTemplate

template_source = """
@cute.kernel
def {{kernel_name}}_kernel(A, B, C):
    # Your CUTLASS kernel logic here
    pass

{{def_kernel("A", "B", "C")}}
    # Call the kernel
    {{kernel_name}}_kernel(A, B, C)
    return C
"""

my_template = CuteDSLTemplate(
    name="my_gemm",
    source=template_source,
)
```
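
Once registered, a template is typically used from a lowering: generate choices with `maybe_append_choice` and hand them to the autotuner. A minimal sketch, following the end-to-end tests in this PR (the lowering below is illustrative, not a fixed API surface; `my_template` is the template from the Quick Start):

```python
import torch
from torch._inductor.ir import TensorBox
from torch._inductor.lowering import lowerings
from torch._inductor.select_algorithm import autotune_select_algorithm

def my_lowering(a: TensorBox, b: TensorBox) -> TensorBox:
    choices = []
    # Returns None on success, or the error if choice generation failed
    error = my_template.maybe_append_choice(
        choices,
        input_nodes=[a, b],
        layout=a.get_layout(),
    )
    if error or not choices:
        # Fall back to the default lowering when the template can't be used
        return lowerings[torch.ops.aten.add.Tensor](a, b)
    # Benchmark the generated choices and pick the fastest
    return autotune_select_algorithm("my_kernel", choices, [a, b], a.get_layout())
```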
## Architecture

- **[CuteDSLTemplate](cutedsl_template.py#L39)**: Template definition and registration. Generates ChoiceCallers for autotuning.
- **[CuteDSLTemplateKernel](cutedsl_kernel.py#L61)**: Handles code generation, provides template hooks (`def_kernel`), manages args.
- **[CuteDSLScheduling](cutedsl_scheduling.py#L28)**: Integrates with Inductor's scheduler, handles kernel compilation via [`async_compile.cutedsl()`](../../async_compile.py#L756).
- **[CuteDSLTemplateBuffer](../../ir.py)**: IR node representing a CuteDSL template operation in the graph.

### Compilation Process

CuteDSL requires source files for compilation (it cannot compile from strings directly). The process:

1. **[CuteDSLScheduling](cutedsl_scheduling.py#L59)** generates the kernel code string and calls [`async_compile.cutedsl()`](../../async_compile.py#L756)
2. **[async_compile.cutedsl()](../../async_compile.py#L756)** uses [`PyCodeCache.write()`](../../codecache.py) to write the source to a temporary `.py` file
3. **[PyCodeCache](../../codecache.py)** loads the module from disk, enabling CUTLASS compilation
4. The compiled kernel is wrapped in **[CuteDSLKernelWrapper](cutedsl_kernel.py#L22)** to provide a `.run()` interface
5. The generated Python file is cached via PyCodeCache, but CUTLASS compilation runs every time (no kernel-level caching yet)

**Debug tip**: Use `TORCH_LOGS="kernel_code"` to see the generated kernel source and file path during compilation.

## Writing Templates

Templates use Jinja2 syntax with these available hooks:

- `{{kernel_name}}` - Unique kernel identifier
- `{{def_kernel(args...)}}` - Generates the kernel function signature and argument handling
- `{{input_nodes}}` - List of input buffers
- `{{output_node}}` - Output buffer
- `{{gen_defines()}}` - Generates autotunable parameter definitions with proper CuteDSL typing

## Autotunable Parameters

CuteDSL templates support autotunable parameters similar to Triton's `tl.constexpr` system:

```python
template_source = r"""
{{gen_defines()}}

@cute.kernel
def {{kernel_name}}_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor):
    threads_per_block = THREADS_PER_BLOCK  # Uses autotuned value
    block_size = BLOCK_SIZE
    # ... kernel implementation
"""

# Pass parameters when generating template choices
template.maybe_append_choice(
    choices,
    input_nodes=[a, b],
    layout=layout,
    THREADS_PER_BLOCK=256,  # cutlass.Constexpr = 256
    BLOCK_SIZE=128,         # cutlass.Constexpr = 128
    SCALE_FACTOR=1.5,       # cutlass.Constexpr = 1.5
)
```
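
For reference, `{{gen_defines()}}` expands each keyword argument into a `cutlass.Constexpr` assignment in the generated source, so the call above renders roughly as:

```python
THREADS_PER_BLOCK: cutlass.Constexpr = 256
BLOCK_SIZE: cutlass.Constexpr = 128
SCALE_FACTOR: cutlass.Constexpr = 1.5
```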
Templates must:

1. Define a `@cute.kernel` decorated function
2. Use `{{def_kernel()}}` to create the entry point
3. Return the output tensor
4. Use `{{gen_defines()}}` for autotunable parameters

See [test_cutedsl_template.py](../../../../test/inductor/test_cutedsl_template.py) for complete examples.

## Current Limitations / TODOs

- **No fusion support**: `can_fuse_vertical` and `can_fuse_horizontal` return False
- **Subgraph management**: Bodies and masks not fully implemented
- **File-based compilation**: Requires writing to disk (uses PyCodeCache)
- **Missing epilogue/prologue**: No support for fused operations yet
- **Fixed kernel suffix**: Uses a hardcoded "_main" suffix
- **No CUTLASS kernel caching**: Only PyCodeCache caching works; CUTLASS compilation runs every time (major perf issue)

Note: Requires the CUTLASS Python package (`pip install nvidia-cutlass`)
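
Since the package is an optional dependency, callers typically guard on its availability before using these templates, as the tests in this PR do:

```python
try:
    import cutlass  # noqa: F401
    import cutlass.cute as cute  # noqa: F401

    HAS_CUTLASS = True
except ImportError:
    HAS_CUTLASS = False
```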

torch/_inductor/codegen/cutedsl/__init__.py
@@ -0,0 +1,8 @@
# mypy: allow-untyped-defs
from .cutedsl_template import CuteDSLTemplate, CuteDSLTemplateCaller


__all__ = [
    "CuteDSLTemplate",
    "CuteDSLTemplateCaller",
]

torch/_inductor/codegen/cutedsl/cutedsl_kernel.py
@@ -0,0 +1,222 @@
# mypy: allow-untyped-defs
import contextlib
import dataclasses
import logging
from typing import Any, Callable, Optional

import torch
from torch._inductor.codegen.common import IndentedBuffer, Kernel
from torch._inductor.ir import Buffer
from torch._inductor.select_algorithm import PartialRender
from torch._inductor.utils import OrderedSet
from torch._inductor.virtualized import V


# TODO setting the 'main' kernel w/ this suffix. We have 3; should probably just auto-generate this
MAIN_SUFFIX = "main"

log = logging.getLogger(__name__)
kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code")


class CuteDSLKernelWrapper:
    """Wrapper to provide a .run() interface for CuteDSL kernels"""

    def __init__(
        self, kernel_fn: Callable[..., Any], kernel_path: Optional[str] = None
    ):
        self.kernel_fn = kernel_fn
        self.kernel_path = kernel_path
        kernel_code_log.info("CuteDSL kernel path: %s", kernel_path)

    def run(self, *args, stream=None, **kwargs):
        """
        Execute the CuteDSL kernel.

        Args:
            *args: Arguments to pass to the kernel function
            stream: CUDA stream to pass to the kernel function
            **kwargs: Additional keyword arguments for the kernel

        Returns:
            Result of the kernel execution
        """
        return self.kernel_fn(*args, stream=stream, **kwargs)


@dataclasses.dataclass
class CuteDSLSubgraphInfo:
    """Minimal subgraph info for CuteDSL kernels."""

    body: IndentedBuffer
    template_mask: Optional[str] = None
    template_out: Optional[str] = None

    def to_dict(self):
        return {
            field.name: getattr(self, field.name) for field in dataclasses.fields(self)
        }


class CuteDSLTemplateKernel(Kernel):
    """
    Template kernel implementation for CuteDSL (CUTLASS Python DSL).
    Handles code generation and argument management for CuteDSL CUDA kernels.
    Provides CuteDSL-specific functionality for tensor conversion and kernel configuration.
    """

    def __init__(
        self,
        kernel_name: str,
        input_nodes: list[Buffer],
        output_node: Buffer,
    ) -> None:
        # Call parent Kernel constructor
        super().__init__()
        self.kernel_name = kernel_name
        self.input_nodes = input_nodes
        self.output_node = output_node

        # TODO Subgraph management for template processing
        self.subgraph_bodies: dict[str, CuteDSLSubgraphInfo] = {}

        # Template attributes
        self.body: IndentedBuffer = IndentedBuffer()
        self.template_mask: Optional[str] = None
        self.template_out: Optional[str] = None
        self.template_indices: Optional[list[Any]] = None
        self.render_hooks: dict[str, Any] = {}

        # TODO Additional attributes needed by the template system
        self.prologue_fused_inputs: OrderedSet[str] = OrderedSet()
        self.prologue_fused_inputs_preserve_zero: OrderedSet[str] = OrderedSet()
        self.named_input_nodes: dict[str, Buffer] = {}

        # Create named input nodes mapping
        for i, input_node in enumerate(input_nodes):
            node_name = getattr(input_node, "name", f"input_{i}")
            self.named_input_nodes[node_name] = input_node

    def gen_imports(self) -> str:
        """Generate common imports for CuteDSL templates."""
        imports = IndentedBuffer()
        imports.splice(
            """
            import torch
            import cutlass
            import cutlass.cute as cute
            from cutlass.cute.runtime import from_dlpack
            import cuda.bindings.driver as cuda
            """
        )
        return imports.getvalue()

    def gen_defines(self, **kwargs) -> str:
        """Generate CuteDSL parameter definitions from kwargs, similar to Triton's gen_defines."""
        params = IndentedBuffer()
        for name, val in kwargs.items():
            params.writeline(f"{name}: cutlass.Constexpr = {val}")
        return params.getvalue()

    def render(self, template, **kwargs):
        """Render the kernel using the template, returning a PartialRender object with hooks."""
        # Available {{...}} hooks for Jinja rendering
        template_env = {
            "def_kernel": self.def_kernel,
            "gen_defines": lambda: self.gen_defines(**kwargs),
        }
        # Render the template with the environment and provided kwargs
        rendered_code = template.render(
            kernel_name=self.kernel_name,
            input_nodes=self.input_nodes,
            output_node=self.output_node,
            **template_env,
            **kwargs,
        )
        # Always prepend the common imports
        imports = self.gen_imports()
        full_code = imports + rendered_code
        return PartialRender(full_code, self.render_hooks)

    @contextlib.contextmanager
    def set_subgraph_body(self, body_name: str):
        """Set the active subgraph body for template processing."""
        assert all(
            hasattr(self, field.name)
            for field in dataclasses.fields(CuteDSLSubgraphInfo)
        )
        old_state = {
            key.name: getattr(self, key.name)
            for key in dataclasses.fields(CuteDSLSubgraphInfo)
        }

        if body_name not in self.subgraph_bodies:
            self.subgraph_bodies[body_name] = CuteDSLSubgraphInfo(
                body=IndentedBuffer(),
                template_mask=None,
                template_out=None,
            )
        subgraph = self.subgraph_bodies[body_name]
        for key, value in subgraph.to_dict().items():
            setattr(self, key, value)
        try:
            yield
        finally:
            # Save current state back to the subgraph
            self.subgraph_bodies[body_name] = CuteDSLSubgraphInfo(
                **{
                    key.name: getattr(self, key.name)
                    for key in dataclasses.fields(CuteDSLSubgraphInfo)
                }
            )
            # Restore old state
            for key, value in old_state.items():
                setattr(self, key, value)

    @contextlib.contextmanager
    def create_subgraph_body(self, body_name: str):
        """Create a new subgraph body for template processing."""
        assert body_name not in self.subgraph_bodies, (
            f"Subgraph body '{body_name}' already exists"
        )
        self.subgraph_bodies[body_name] = CuteDSLSubgraphInfo(
            body=IndentedBuffer(),
            template_mask=None,
            template_out=None,
        )
        with self.set_subgraph_body(body_name):
            yield

    def def_kernel(self, *argnames):
        """Define the kernel function signature for CuteDSL templates."""
        # Populate all the kernel args
        for i, input_node in enumerate(self.input_nodes):
            self.args.input(input_node.get_name())
        if self.output_node:
            self.args.output(self.output_node.get_name())

        def hook():
            code = IndentedBuffer()
            code.writeline(f"# Kernel function signature: {self.kernel_name}")
            params = list(argnames) + ["stream"]
            code.writeline(
                f"def {self.kernel_name}_{MAIN_SUFFIX}({', '.join(params)}):"
            )
            return code.getvalue()

        assert "<DEF_KERNEL>" not in self.render_hooks
        self.render_hooks["<DEF_KERNEL>"] = hook
        return "<DEF_KERNEL>"

    def call_kernel(self, name: str, node=None):
        """Call the kernel function. Simplified version of TritonTemplateKernel.call_kernel."""
        wrapper = V.graph.wrapper_code
        _, call_args, _, arg_types = self.args.python_argdefs()
        # TODO triton should really be swapped w/ `python`
        wrapper.generate_kernel_call(name, call_args, triton=True, arg_types=arg_types)

torch/_inductor/codegen/cutedsl/cutedsl_scheduling.py
@@ -0,0 +1,140 @@
# mypy: allow-untyped-defs
import hashlib
import logging
from collections.abc import Sequence
from typing import cast

from torch._inductor.utils import Placeholder
from torch.utils._ordered_set import OrderedSet

from ... import config
from ...codecache import code_hash, get_path
from ...ir import CuteDSLTemplateBuffer
from ...scheduler import (
    BaseSchedulerNode,
    BaseScheduling,
    FusedSchedulerNode,
    SchedulerNode,
)
from ...select_algorithm import PartialRender
from ...utils import get_fused_kernel_name, get_kernel_metadata
from ...virtualized import V
from ..common import BackendFeature, IndentedBuffer


log = logging.getLogger(__name__)


class CuteDSLScheduling(BaseScheduling):
    """
    Scheduling implementation for CuteDSL (CUTLASS Python DSL) kernels.
    This class is intended to be used in combination with other schedulers,
    and delegated to by CUDACombinedScheduling.
    """

    @classmethod
    def get_backend_features(cls, device) -> OrderedSet[BackendFeature]:
        return OrderedSet()

    @staticmethod
    def is_cutedsl_template(node: BaseSchedulerNode) -> bool:
        """Check if a node is a CuteDSL template."""
        return isinstance(node, SchedulerNode) and isinstance(
            node.node, CuteDSLTemplateBuffer
        )

    def is_cutedsl_fused_template(self, node: BaseSchedulerNode) -> bool:
        """Check if a node is a fused CuteDSL template."""
        return isinstance(node, FusedSchedulerNode) and self.is_cutedsl_template(node)

    def can_fuse_vertical(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
        """
        TODO CuteDSL doesn't support vertical fusion yet.
        This could be extended in the future for epilogue fusion.
        """
        return False

    def define_kernel(self, src_code_str: str, node_schedule) -> str:
        """Produce the kernel string

        Args:
            src_code_str: The finalized kernel code string
            node_schedule: List of nodes in the schedule

        Note:
            This is a little weird since async_compile.cutedsl() has to write the string to
            a file in order to cute compile it. Feels bad to have two...
        """
        wrapper = V.graph.wrapper_code
        # Use the string as the key for caching
        if src_code_str in wrapper.src_to_kernel:
            kernel_name = wrapper.src_to_kernel[src_code_str]
        else:
            fused_name = (
                get_fused_kernel_name(node_schedule, config.triton.descriptive_names)
                if config.triton.descriptive_names
                else ""
            )
            kernel_hash = hashlib.sha256(src_code_str.encode("utf-8")).hexdigest()[:8]
            if fused_name == "fused":
                kernel_name = f"cutedsl_{kernel_hash}"
            else:
                kernel_name = f"cutedsl_{fused_name}_{kernel_hash}"
            wrapper.src_to_kernel[src_code_str] = kernel_name

            src_code_str = src_code_str.replace(
                str(Placeholder.KERNEL_NAME), kernel_name
            )
            _, _, kernel_path = get_path(code_hash(src_code_str), "py")

            compile_wrapper = IndentedBuffer()
            compile_wrapper.writeline(f"async_compile.cutedsl({kernel_name!r}, r'''")
            compile_wrapper.splice(src_code_str, strip=True)
            compile_wrapper.writeline("''')")

            metadata_comment = f"# kernel path: {kernel_path}"
            origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
            metadata_comment += "\n" + origins + "\n" + detailed_origins
            wrapper.define_kernel(
                kernel_name, compile_wrapper.getvalue(), metadata_comment
            )
        return kernel_name

    def codegen_template(
        self,
        template_node: BaseSchedulerNode,
        epilogue_nodes: Sequence[BaseSchedulerNode],
        prologue_nodes: Sequence[BaseSchedulerNode],
    ):
        """
        Codegen a CuteDSL template. Currently doesn't support fusion.
        """
        assert self.is_cutedsl_template(template_node), (
            "Template node passed to CuteDSLScheduling.codegen_template must be a "
            "SchedulerNode that wraps a CuteDSLTemplateBuffer"
        )
        # TODO remove when supported
        assert not epilogue_nodes, "CuteDSL doesn't support epilogue fusion yet"
        assert not prologue_nodes, "CuteDSL doesn't support prologue fusion yet"

        template_node = cast(SchedulerNode, template_node)
        ctb: CuteDSLTemplateBuffer = cast(CuteDSLTemplateBuffer, template_node.node)
        kernel, render = ctb.make_kernel_render(ctb)  # type: ignore[misc]

        template_node.mark_run()
        src_code = render()

        # Finalize PartialRender if needed
        if isinstance(src_code, PartialRender):
            src_code_str = src_code.finalize_all()
        else:
            src_code_str = src_code

        with V.set_kernel_handler(kernel):
            node_schedule = [template_node]
            kernel_name = self.define_kernel(src_code_str, node_schedule)

        kernel.call_kernel(kernel_name, ctb)

        V.graph.removed_buffers |= kernel.removed_buffers
        self.free_buffers_in_scheduler()

torch/_inductor/codegen/cutedsl/cutedsl_template.py
@@ -0,0 +1,178 @@
# mypy: allow-untyped-defs
import functools
import itertools
from typing import Any, Optional, Union

from torch._inductor.ir import ShapeAsConstantBuffer
from torch._inductor.utils import Placeholder
from torch._logging import getArtifactLogger

from ...autotune_process import CuteDSLBenchmarkRequest, TensorMeta
from ...ir import Buffer, ChoiceCaller, CuteDSLTemplateBuffer, Layout, TensorBox
from ..common import KernelTemplate
from .cutedsl_kernel import CuteDSLTemplateKernel


log = getArtifactLogger(__name__, "output_code")


class CuteDSLTemplate(KernelTemplate):
    """Template for generating CuteDSL (CUTLASS Python DSL) kernels."""

    kernel_type: type[Any] = CuteDSLTemplateKernel
    index_counter = itertools.count()
    all_templates: dict[str, "CuteDSLTemplate"] = {}

    def __init__(
        self,
        name: str,
        source: str,
        subgraph_fn: Optional[Any] = None,
        mask_fn: Optional[Any] = None,
    ) -> None:
        super().__init__(name)
        self.source = source
        self.subgraph_fn = subgraph_fn
        self.mask_fn = mask_fn
        self.template = CuteDSLTemplate._template_from_string(source)
        assert name not in self.all_templates, f"duplicate template name, {name}"
        CuteDSLTemplate.all_templates[name] = self

    @staticmethod
    @functools.lru_cache(None)
    def _template_from_string(source: str) -> Any:
        return KernelTemplate._template_from_string(source)

    def maybe_append_choice(
        self, choices: list[Any], **kwargs: Any
    ) -> Optional[NotImplementedError]:
        """
        Maybe generates a new ChoiceCaller and appends it into existing choices.
        Returns None on success, otherwise returns the error.
        """
        try:
            choices.append(self.generate(**kwargs))
            return None
        except NotImplementedError as e:
            log.debug("CuteDSL template choice generation failed: %s", e)
            return e
        except Exception as e:
            log.debug("CuteDSL template choice generation error: %s", e)
            return NotImplementedError(f"CuteDSL template failed: {e}")

    def generate(self, **kwargs: Any) -> ChoiceCaller:
        """Generate the CuteDSL kernel caller."""
        input_nodes = kwargs.pop("input_nodes")
        layout = kwargs.pop("layout")
        kernel_name = f"cutedsl_{self.name}_{next(self.index_counter)}"

        if self.template is None:
            raise RuntimeError("Template compilation failed (Jinja2 required)")

        self.output_node: Buffer = Buffer(name="buf_out", layout=layout)
        kernel = self.kernel_type(
            kernel_name=kernel_name,
            input_nodes=input_nodes,
            output_node=self.output_node,
        )
        code = kernel.render(self.template, **kwargs)
        log.debug("Generated CuteDSL Code:\n%s", code)

        bmreq = CuteDSLBenchmarkRequest(
            kernel_name=kernel_name,
            input_tensor_meta=TensorMeta.from_irnodes(input_nodes),
            output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
            extra_args=tuple(),
            source_code=code,
        )

        def make_kernel_render(out_node, hint_override: Optional[int] = None):
            render_kernel = self.kernel_type(
                kernel_name=str(Placeholder.KERNEL_NAME),
                input_nodes=input_nodes,
                output_node=out_node,
            )

            def render():
                return render_kernel.render(self.template, **kwargs)

            return render_kernel, render

        return CuteDSLTemplateCaller(
            name=kernel_name,
            input_nodes=input_nodes,
            layout=layout,
            make_kernel_render=make_kernel_render,
            bmreq=bmreq,
            template=self,
        )


class CuteDSLTemplateCaller(ChoiceCaller):
    """Caller for CuteDSL templates that integrates with the autotuning system."""

    def __init__(
        self,
        name: str,
        input_nodes: list[Buffer],
        layout: Layout,
        make_kernel_render: Any,
        bmreq: CuteDSLBenchmarkRequest,
        template: "CuteDSLTemplate",
    ):
        super().__init__(
            name=name,
            input_nodes=input_nodes,
            layout=layout,
            description=f"CuteDSL template {name}",
        )
        self.make_kernel_render = make_kernel_render
        self.bmreq = bmreq
        self.template = template

    def __str__(self) -> str:
        return f"CuteDSLTemplateCaller({self.name})"

    def benchmark(self, *args, out) -> float:
        """Benchmark the kernel execution."""
        return self.bmreq.benchmark(*args, out=out)

    def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]:
        """Create the output node for this template choice."""
        return TensorBox.create(
            CuteDSLTemplateBuffer(
                layout=self.layout,
                inputs=self.input_nodes,
                make_kernel_render=self.make_kernel_render,
                template=self.template,
            )
        )

    def call_name(self) -> str:
        """Return the kernel call name."""
        return self.name

    def to_callable(self) -> Any:
        """Return a callable that can execute this kernel."""
        return self.make_kernel_render

    def hash_key(self) -> str:
        """Return a unique hash key for this choice."""
        return "-".join(
            [
                self.name.rsplit("_", 1)[0],
                self.bmreq.module_cache_key,
            ]
        )

    def info_dict(self) -> dict[str, Any]:
        """Return information about this kernel."""
        return {
            "name": self.name,
            "backend": "CuteDSL",
            "template": self.template.name,
        }

torch/_inductor/ir.py
@@ -5132,6 +5132,37 @@ class CppTemplateBuffer(TemplateBuffer):
        return super().get_layout()


class CuteDSLTemplateBuffer(TemplateBuffer):
    """
    Buffer for CuteDSL (CUTLASS Python DSL) template kernels.
    Similar to other template buffers but specialized for CuteDSL operations.
    """

    def __init__(
        self,
        layout: Layout,
        inputs: Sequence[IRNode],
        make_kernel_render: Callable[_P, _T],
        template: Any,
        mutated_inputs: Optional[Iterable[IRNode]] = None,
    ) -> None:
        super().__init__(layout, inputs, make_kernel_render)
        self.template = template
        self.mutated_inputs = mutated_inputs
        self.outputs: list[Buffer] = [self]
        if mutated_inputs is not None:
            assert isinstance(self.inputs[0], IRNode), type(self.inputs[0])
            device = self.inputs[0].get_device()
            self.outputs += [
                MutationOutput(NoneLayout(device=device), buf, self)
                for buf in mutated_inputs
            ]

    def get_outputs(self) -> list[Buffer]:
        return self.outputs


def is_node_sequence(
    nodes: Sequence[Union[IRNode, Sequence[IRNode]]],
) -> TypeIs[Sequence[IRNode]]: