## Summary

Still figuring out what actually writing a template should look like, but this lands a lot of the base infra.

<img width="1267" height="262" alt="Screenshot 2025-08-16 at 10 22 12 PM" src="https://github.com/user-attachments/assets/229f8bfa-0cb4-4fb1-8530-f535e569d350" />

Test code:

```python
#!/usr/bin/env python3
"""
Fixed CuteDSL template test with proper def_kernel usage.
"""
import torch
import torch._inductor.config as config
from torch._inductor.lowering import lowerings
from torch._inductor.ir import TensorBox
from torch._inductor.select_algorithm import autotune_select_algorithm
from torch._inductor.codegen.cutedsl import CuteDSLTemplate


def create_fixed_cutedsl_template():
    """Create a properly structured CuteDSL template."""

    def cutedsl_grid(M, N, meta):
        return (1,)

    # Part 1: Imports and kernel definition
    template_part1 = r"""
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.kernel
def {{kernel_name}}_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor):
    # Get thread and block indices
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdim, _, _ = cute.arch.block_dim()

    thread_idx = bidx * bdim + tidx

    m, n = gA.shape

    if thread_idx < m * n:
        mi = thread_idx // n
        ni = thread_idx % n

        if mi < m and ni < n:
            a_val = gA[mi, ni]
            b_val = gB[mi, ni]
            result = a_val + b_val
            gC[mi, ni] = result
"""

    # Part 2: JIT wrapper function
    template_part2 = r"""
@cute.jit
def {{kernel_name}}_jit(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor):
    m, n = mA.shape
    total_threads = m * n
    threads_per_block = 256
    num_blocks = (total_threads + threads_per_block - 1) // threads_per_block

    kernel = {{kernel_name}}_kernel(mA, mB, mC)
    kernel.launch(
        grid=[num_blocks, 1, 1],
        block=[threads_per_block, 1, 1]
    )
"""

    # Part 3: Main kernel function
    template_part3 = r"""
{{def_kernel("input_a", "input_b", "output_c")}}
    cute_a = from_dlpack(input_a, assumed_align=16)
    cute_b = from_dlpack(input_b, assumed_align=16)
    cute_c = from_dlpack(output_c, assumed_align=16)

    # Launch kernel
    {{kernel_name}}_jit(cute_a, cute_b, cute_c)

    return output_c
"""

    # Combine all parts
    template = CuteDSLTemplate(
        name="fixed_add",
        grid=cutedsl_grid,
        source=template_part1 + template_part2 + template_part3
    )
    return template


def fixed_cutedsl_lowering(a: TensorBox, b: TensorBox) -> TensorBox:
    """Fixed CuteDSL lowering."""
    print(f"[FIXED] CuteDSL lowering: {a.get_size()} + {b.get_size()}")

    template = create_fixed_cutedsl_template()
    choices = []

    error = template.maybe_append_choice(
        choices,
        input_nodes=[a.data, b.data],
        layout=a.get_layout()
    )

    if error or not choices:
        print(f"[FIXED] Falling back: {error}")
        default_lowering = lowerings[torch.ops.aten.add.Tensor]
        return default_lowering(a, b)

    print(f"[FIXED] Using CuteDSL with {len(choices)} choices")
    result = autotune_select_algorithm(
        "fixed_cutedsl_add",
        choices,
        [a, b],
        a.get_layout(),
    )
    return result


def test_fixed_cutedsl():
    """Test the fixed CuteDSL template."""
    print("=" * 50)
    print("Fixed CuteDSL Template Test")
    print("=" * 50)

    original = lowerings.get(torch.ops.aten.add.Tensor, None)

    try:
        lowerings[torch.ops.aten.add.Tensor] = fixed_cutedsl_lowering

        def test_add(x, y):
            return x + y

        device = "cuda" if torch.cuda.is_available() else "cpu"
        x = torch.randn(128, 4, device=device, dtype=torch.float32)
        y = torch.randn(128, 4, device=device, dtype=torch.float32)

        print(f"[FIXED] Testing with {x.shape} tensors on {device}")

        compiled_fn = torch.compile(test_add, backend="inductor")
        result = compiled_fn(x, y)

        # Verify correctness
        expected = x + y
        if torch.allclose(result, expected, atol=1e-5):
            print("✅ [FIXED] Results match!")
            return True
        else:
            print("❌ [FIXED] Results don't match!")
            return False

    except Exception as e:
        print(f"❌ [FIXED] Failed: {e}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        if original:
            lowerings[torch.ops.aten.add.Tensor] = original
        else:
            lowerings.pop(torch.ops.aten.add.Tensor, None)


if __name__ == "__main__":
    success = test_fixed_cutedsl()
    print("🎉 Fixed test completed!" if success else "💥 Fixed test failed!")
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160108
Approved by: https://github.com/mlazos
# CuteDSL Template System

## Quick Start

Writing a CuteDSL template:

```python
from torch._inductor.codegen.cutedsl import CuteDSLTemplate

template_source = """
@cute.kernel
def {{kernel_name}}_kernel(A, B, C):
    # Your CUTLASS kernel logic here
    pass

{{def_kernel("A", "B", "C")}}
    # Call the kernel
    {{kernel_name}}_kernel(A, B, C)
    return C
"""

my_template = CuteDSLTemplate(
    name="my_gemm",
    source=template_source,
)
```
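To consume the template from a lowering, append it as an autotuning choice and hand the candidates to the autotuner. This is a trimmed-down version of the test code in the summary above; `a` and `b` stand in for the lowering's `TensorBox` inputs:

```python
from torch._inductor.select_algorithm import autotune_select_algorithm

# Collect candidate kernels; maybe_append_choice returns an error (or None)
# and appends a ChoiceCaller to `choices` on success.
choices = []
error = my_template.maybe_append_choice(
    choices,
    input_nodes=[a.data, b.data],
    layout=a.get_layout(),
)
if error or not choices:
    # Fall back to a default lowering if the template can't be used.
    raise RuntimeError(f"CuteDSL choice unavailable: {error}")

# Benchmark the choices and return the output node of the winning kernel.
result = autotune_select_algorithm("my_gemm", choices, [a, b], a.get_layout())
```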
## Architecture

- `CuteDSLTemplate`: Template definition and registration. Generates `ChoiceCaller`s for autotuning.
- `CuteDSLTemplateKernel`: Handles code generation, provides template hooks (`def_kernel`), manages args.
- `CuteDSLScheduling`: Integrates with Inductor's scheduler; handles kernel compilation via `async_compile.cutedsl()`.
- `CuteDSLTemplateBuffer`: IR node representing a CuteDSL template operation in the graph.
## Compilation Process

CuteDSL requires source files for compilation (it cannot compile from strings directly). The process:

- `CuteDSLScheduling` generates the kernel code string and calls `async_compile.cutedsl()`
- `async_compile.cutedsl()` uses `PyCodeCache.write()` to write the source to a temporary `.py` file
- `PyCodeCache` loads the module from disk, enabling CUTLASS compilation
- The compiled kernel is wrapped in `CuteDSLKernelWrapper` to provide a `.run()` interface
- The generated Python file is cached via `PyCodeCache`, but CUTLASS compilation runs every time (no kernel-level caching yet)

Debug tip: use `TORCH_LOGS="kernel_code"` to see the generated kernel source and file path during compilation.
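The disk round-trip can be pictured with a small self-contained sketch. This is illustrative only: the real implementation lives in `async_compile.cutedsl()` and `PyCodeCache`, and `load_kernel_module` here is a hypothetical helper:

```python
import importlib.util
import tempfile


def load_kernel_module(source: str):
    # CuteDSL needs real source files, so the generated code string is
    # first written to a .py file on disk...
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write(source)
        path = f.name

    # ...and then imported as a module, which gives CUTLASS inspectable
    # source to compile from.
    spec = importlib.util.spec_from_file_location("generated_cutedsl", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
```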
## Writing Templates

Templates use Jinja2 syntax with these available hooks:

- `{{kernel_name}}` - Unique kernel identifier
- `{{def_kernel(args...)}}` - Generates the kernel function signature and argument handling
- `{{input_nodes}}` - List of input buffers
- `{{output_node}}` - Output buffer
- `{{gen_defines()}}` - Generates autotunable parameter definitions with proper CuteDSL typing
## Autotunable Parameters

CuteDSL templates support autotunable parameters, similar to Triton's `tl.constexpr` system:

```python
template_source = r"""
{{gen_defines()}}

@cute.kernel
def {{kernel_name}}_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor):
    threads_per_block = THREADS_PER_BLOCK  # Uses autotuned value
    block_size = BLOCK_SIZE
    # ... kernel implementation
"""

# Pass parameters when generating template choices
template.maybe_append_choice(
    choices,
    input_nodes=[a, b],
    layout=layout,
    THREADS_PER_BLOCK=256,  # cutlass.Constexpr = 256
    BLOCK_SIZE=128,         # cutlass.Constexpr = 128
    SCALE_FACTOR=1.5,       # cutlass.Constexpr = 1.5
)
```
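Going by the comments above, `{{gen_defines()}}` should then render to module-level constants along these lines (the exact formatting is up to the template kernel's `gen_defines` implementation, and `cutlass` is already imported in the generated module):

```python
THREADS_PER_BLOCK: cutlass.Constexpr = 256
BLOCK_SIZE: cutlass.Constexpr = 128
SCALE_FACTOR: cutlass.Constexpr = 1.5
```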
Templates must:

- Define a `@cute.kernel` decorated function
- Use `{{def_kernel()}}` to create the entry point
- Return the output tensor
- Use `{{gen_defines()}}` for autotunable parameters

The sketch below combines these requirements; see `test_cutedsl_template.py` for complete, tested examples.
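A minimal complete template (an elementwise copy) satisfying all four requirements, structured like the add example in the summary above. It assumes `THREADS_PER_BLOCK` is passed through `maybe_append_choice` as in the previous section; treat it as a sketch, not a tested kernel:

```python
minimal_source = r"""
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

{{gen_defines()}}

@cute.kernel
def {{kernel_name}}_kernel(gIn: cute.Tensor, gOut: cute.Tensor):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdim, _, _ = cute.arch.block_dim()
    idx = bidx * bdim + tidx
    m, n = gIn.shape
    if idx < m * n:
        gOut[idx // n, idx % n] = gIn[idx // n, idx % n]

@cute.jit
def {{kernel_name}}_jit(mIn: cute.Tensor, mOut: cute.Tensor):
    m, n = mIn.shape
    blocks = (m * n + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK
    kernel = {{kernel_name}}_kernel(mIn, mOut)
    kernel.launch(
        grid=[blocks, 1, 1],
        block=[THREADS_PER_BLOCK, 1, 1]
    )

{{def_kernel("input_a", "output_c")}}
    cute_in = from_dlpack(input_a, assumed_align=16)
    cute_out = from_dlpack(output_c, assumed_align=16)
    {{kernel_name}}_jit(cute_in, cute_out)
    return output_c
"""
```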
## Current Limitations / TODOs

- No fusion support: `can_fuse_vertical` and `can_fuse_horizontal` return `False`
- Subgraph management: bodies and masks are not fully implemented
- File-based compilation: requires writing to disk (uses `PyCodeCache`)
- Missing epilogue/prologue: no support for fused operations yet
- Fixed kernel suffix: uses a hardcoded `"_main"` suffix
- No CUTLASS kernel caching: only `PyCodeCache` caching works; CUTLASS compilation runs every time (major perf issue)
Note: Requires the CUTLASS Python package (`pip install nvidia-cutlass`).