Updates to CuTe DSL template renderer (#161117)

# Summary
This adds a few more render functions available to template writers, specifically `get_output` and `modification`. The motivation becomes clearer in the next PR in this stack.

<img width="1645" height="364" alt="Screenshot 2025-08-21 at 1 48 50 PM" src="https://github.com/user-attachments/assets/2d508fda-4273-43ef-9edf-086e592e9249" />
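
To make the two hooks concrete, here is a minimal, hypothetical template sketch (not code from this PR). The `def_kernel`/`get_output` part mirrors the updated add-template in the test diff below; the `modification(...)` call and its names (`post_mod_scores`, `tSrS`, `b_idx`, ...) are illustrative stand-ins for how a score_mod subgraph would be inlined, and in the real flash template the inlined code operates on CuTe TensorSSA values inside the jitted kernel body.

```py
# Hypothetical CuTe DSL template sketch (illustrative only, not from this PR).
#   {{get_output()}}      resolves to the kernel's output-buffer argument
#   {{modification(...)}} inlines a pointwise subgraph (e.g. a score_mod)
EXAMPLE_TEMPLATE = r"""
{{def_kernel("input_a", "input_b")}}
    cute_a = from_dlpack(input_a)
    cute_b = from_dlpack(input_b)
    cute_out = from_dlpack({{get_output()}})

    # Subgraph 0 is a score_mod-style pointwise graph; the keyword arguments
    # substitute fixed inputs for the subgraph's loads.
    {{modification(0, "post_mod_scores", score="tSrS", b="b_idx", h="h_idx", m="q_idx", n="kv_idx")}}

    {{kernel_name}}_jit(cute_a, cute_b, cute_out, cuda.CUstream(stream))
    return {{get_output()}}
"""
```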

The majority of the new code is the OpOverrides for the CuTe DSL. It is a lot to test, and most of the actual testing I have been doing is via score_mods passed to flash_attention at the next layer of this stack.

Here is a bunch of score mods that Claude and I came up with, which exercise the actual ops:
```py
import math

import torch

def causal_mask(score, b, h, q_idx, kv_idx):
    """Causal attention mask."""
    return torch.where(q_idx >= kv_idx, score, float("-inf"))

def relative_bias(score, b, h, token_q, token_kv):
    """Relative position bias."""
    return score + torch.abs(token_q - token_kv)

def relative_bias_v2(score, b, h, token_q, token_kv):
    """Relative position bias with factor of 2."""
    return score + 2 * torch.abs(token_q - token_kv)

def times_two(score, b, h, q_idx, kv_idx):
    """Simple score modification that doubles the score."""
    return score * 2

def alibi_bias(score, b, h, q_idx, kv_idx):
    """ALiBi (Attention with Linear Biases) - used in some modern models."""
    # Different slopes for different heads
    slope = 2 ** (-8 * (h + 1) / 8)  # Simplified version
    return score - slope * torch.abs(q_idx - kv_idx)

def sliding_window(score, b, h, q_idx, kv_idx, window_size=256):
    """Sliding window attention - only attend to nearby tokens."""
    return torch.where(
        torch.abs(q_idx - kv_idx) <= window_size,
        score,
        float("-inf")
    )

def block_diagonal(score, b, h, q_idx, kv_idx, block_size=64):
    """Block diagonal attention pattern."""
    q_block = q_idx // block_size
    kv_block = kv_idx // block_size
    return torch.where(q_block == kv_block, score, float("-inf"))

def additive_bias(score, b, h, q_idx, kv_idx):
    """Test simple addition with position-based bias."""
    return score + (q_idx + kv_idx) * 0.01

def multiplicative_decay(score, b, h, q_idx, kv_idx):
    """Test multiplication with distance-based decay."""
    distance = torch.abs(q_idx - kv_idx)
    return score * torch.exp(-0.1 * distance)

def sine_wave_bias(score, b, h, q_idx, kv_idx):
    """Test trigonometric functions."""
    return score + 0.1 * torch.sin(2 * math.pi * (q_idx - kv_idx) / 64)

def log_distance_penalty(score, b, h, q_idx, kv_idx):
    """Test logarithmic operations."""
    distance = torch.abs(q_idx - kv_idx).float()
    return score - torch.log(1 + distance)

def alternating_mask(score, b, h, q_idx, kv_idx):
    """Test with alternating pattern - good for branch prediction."""
    return torch.where((q_idx + kv_idx) % 2 == 0, score, float("-inf"))

def head_specific_pattern(score, b, h, q_idx, kv_idx):
    """Different behavior per attention head."""
    even_head = h % 2 == 0
    causal = q_idx >= kv_idx
    return torch.where(even_head & causal, score, float("-inf"))

def sparse_strided(score, b, h, q_idx, kv_idx, stride=4):
    """Sparse attention with strided pattern."""
    return torch.where(
        (kv_idx % stride == 0) | (q_idx == kv_idx),
        score,
        float("-inf")
    )

def causal_with_global(score, b, h, q_idx, kv_idx):
    """Causal mask but first few tokens are globally attended."""
    is_causal = q_idx >= kv_idx
    is_global = kv_idx < 4
    return torch.where(is_causal | is_global, score, float("-inf"))

def dilated_attention(score, b, h, q_idx, kv_idx, dilation_rate=2):
    """Dilated attention pattern - exponentially increasing gaps."""
    distance = torch.abs(q_idx - kv_idx)
    is_attended = (distance == 0) | ((distance > 0) & ((distance & (distance - 1)) == 0))
    return torch.where(is_attended, score, float("-inf"))

```
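
These score mods are exercised through flex_attention under torch.compile, roughly as in the sketch below, with shapes matching the config in the output that follows. The actual test harness, including the knob that toggles the CuTe DSL flash path on and off, lands in the next PR, so treat this as an assumed-shape illustration rather than the real script.

```py
import torch
from torch.nn.attention.flex_attention import flex_attention

# Shapes mirror the reported config: batch=4, heads=32, seq_q=seq_kv=8192, dim=128.
q, k, v = (
    torch.randn(4, 32, 8192, 128, device="cuda", dtype=torch.bfloat16)
    for _ in range(3)
)

# Compiled path (where the CuTe DSL flash template can be selected) vs. the
# uncompiled flex_attention math as a reference point.
compiled_flex = torch.compile(flex_attention)
out_compiled = compiled_flex(q, k, v, score_mod=causal_mask)
out_reference = flex_attention(q, k, v, score_mod=causal_mask)
```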

Example outputs:
```
[Test Suite]
Config: batch=4, heads=32, seq_q=8192, seq_kv=8192, dim=128

[Test 1: none]
[No score_mod, flash='enabled'] Found flash_attncute: True
[No score_mod, flash='disabled'] Found flash_attncute: False
✓ Outputs match between flash enabled/disabled
✓ Output matches eager SDPA (rtol=0.001, atol=0.001)

[Test 2: causal]
[With score_mod, flash='enabled'] Found flash_attncute: True
[With score_mod, flash='disabled'] Found flash_attncute: False
✗ Outputs differ between flash modes: Tensor-likes are not close!

Mismatched elements: 17879 / 134217728 (0.0%)
Greatest absolute difference: 0.0078125 at index (0, 15, 15, 60) (up to 0.001 allowed)
Greatest relative difference: 2.5 at index (3, 22, 153, 126) (up to 0.001 allowed)

[Test 3: rel_bias]
[With score_mod, flash='enabled'] Found flash_attncute: True
[With score_mod, flash='disabled'] Found flash_attncute: False
✗ Outputs differ between flash modes: Tensor-likes are not close!

Mismatched elements: 12836 / 134217728 (0.0%)
Greatest absolute difference: 0.015625 at index (0, 3, 2775, 84) (up to 0.001 allowed)
Greatest relative difference: 11.8125 at index (3, 28, 4095, 76) (up to 0.001 allowed)

[Test 4: rel_bias_v2]
```

These runs are in bfloat16, so the small mismatches above are expected precision differences rather than real divergences. The list of pointwise ops exercised here isn't exhaustive, but the coverage is fairly broad.
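
For reference, the ✓/✗ lines above are the kind of report you get from `torch.testing.assert_close` at the listed tolerances; below is a sketch of a comparison helper in that spirit (hypothetical, not part of this PR).

```py
import torch

def compare_flash_modes(out_flash: torch.Tensor, out_fallback: torch.Tensor) -> None:
    """Hypothetical helper that produces a report like the one above."""
    try:
        # Same tolerances as in the log; assert_close's error message carries the
        # mismatch count and the greatest absolute/relative differences.
        torch.testing.assert_close(out_flash, out_fallback, rtol=1e-3, atol=1e-3)
        print("✓ Outputs match between flash enabled/disabled")
    except AssertionError as err:
        print(f"✗ Outputs differ between flash modes: {err}")
```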

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161117
Approved by: https://github.com/mlazos
Author: drisspg
Date: 2025-08-27 15:31:42 -07:00
Committed by: PyTorch MergeBot
Commit: 30edac5da6 (parent 12c0cf3fab)
6 changed files with 818 additions and 70 deletions

View File

@@ -2,8 +2,12 @@
import unittest
from unittest.mock import MagicMock, patch
from expecttest import assert_expected_inline
import torch
from torch._inductor.test_case import TestCase
from torch._inductor.virtualized import V
from torch.testing._internal.inductor_utils import MockGraphHandler
try:
@@ -19,6 +23,7 @@ if HAS_CUTLASS:
from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
from torch._inductor.select_algorithm import PartialRender
CUTEDSL_ADD_TEMPLATE = r"""
{{gen_defines()}}
@@ -52,13 +57,13 @@ def {{kernel_name}}_jit(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor, strea
stream=stream
)
{{def_kernel("input_a", "input_b", "output_c")}}
{{def_kernel("input_a", "input_b")}}
cute_a = from_dlpack(input_a)
cute_b = from_dlpack(input_b)
cute_c = from_dlpack(output_c)
cute_c = from_dlpack({{get_output()}})
{{kernel_name}}_jit(cute_a, cute_b, cute_c, cuda.CUstream(stream))
return output_c
return {{get_output()}}
"""
@@ -82,7 +87,7 @@ class TestCuteDSLTemplate(TestCase):
self.assertIsInstance(imports, str)
lines = imports.strip().split("\n")
self.assertEqual(len(lines), 5)
self.assertEqual(len(lines), 7)
def test_render_includes_imports(self):
template_source = """@cute.kernel
@@ -299,18 +304,178 @@ def {{kernel_name}}_kernel():
ENABLE_FEATURE=True,
)
expected_lines = [
"THREADS_PER_BLOCK: cutlass.Constexpr = 256",
"BLOCK_SIZE: cutlass.Constexpr = 128",
"ENABLE_FEATURE: cutlass.Constexpr = True",
]
assert_expected_inline(
params,
"""\
THREADS_PER_BLOCK: cutlass.Constexpr = 256
BLOCK_SIZE: cutlass.Constexpr = 128
ENABLE_FEATURE: cutlass.Constexpr = True
""",
)
for expected_line in expected_lines:
self.assertIn(expected_line, params)
# Test float parameters
params_float = kernel.gen_defines(SCALE_FACTOR=1.5)
self.assertIn("SCALE_FACTOR: cutlass.Constexpr = 1.5", params_float)
assert_expected_inline(
params_float,
"""\
SCALE_FACTOR: cutlass.Constexpr = 1.5
""",
)
def test_template_aliasing(self):
"""Test that template variables are correctly aliased to function arguments."""
from torch._inductor.ir import Buffer
mock_input1 = MagicMock(spec=Buffer)
mock_input1.get_name.return_value = "buf_input1"
mock_input2 = MagicMock(spec=Buffer)
mock_input2.get_name.return_value = "buf_input2"
mock_output = MagicMock(spec=Buffer)
mock_output.get_name.return_value = "buf_output"
mock_graph = MockGraphHandler()
with V.set_graph_handler(mock_graph):
kernel = CuteDSLTemplateKernel(
kernel_name="test_aliasing",
input_nodes=[mock_input1, mock_input2],
output_node=mock_output,
)
def_kernel_hook = kernel.def_kernel("custom_a", "custom_b")
self.assertEqual(def_kernel_hook, "<DEF_KERNEL>")
self.assertIn("<DEF_KERNEL>", kernel.render_hooks)
hook_fn = kernel.render_hooks["<DEF_KERNEL>"]
generated_code = hook_fn()
# Check that the generated code contains the expected aliasing statements
self.assertIn("custom_a = arg_custom_a", generated_code)
self.assertIn("custom_b = arg_custom_b", generated_code)
def test_get_output_hook(self):
"""Test the get_output() template hook."""
from torch._inductor.ir import Buffer
mock_output = MagicMock(spec=Buffer)
mock_output.get_name.return_value = "buf_test_output"
mock_graph = MockGraphHandler()
with V.set_graph_handler(mock_graph):
kernel = CuteDSLTemplateKernel(
kernel_name="test_output",
input_nodes=[],
output_node=mock_output,
)
with self.assertRaises(ValueError):
# error if no output buffer
result = kernel.get_output()
kernel.args.output_buffers["buf_test_output"] = "arg_buf_test_output"
result = kernel.get_output()
self.assertEqual(result, "arg_buf_test_output")
def test_modification_subgraph(self):
"""Test the modification() method and subgraph processing."""
from torch._inductor.ir import Buffer
mock_subgraph1 = MagicMock(spec=Buffer)
mock_subgraph2 = MagicMock(spec=Buffer)
subgraphs = [mock_subgraph1, mock_subgraph2]
mock_output = MagicMock(spec=Buffer)
mock_output.get_name.return_value = "buf_output"
kernel = CuteDSLTemplateKernel(
kernel_name="test_modification",
input_nodes=[],
output_node=mock_output,
subgraphs=subgraphs,
)
result = kernel._get_subgraph(0)
self.assertEqual(result, mock_subgraph1)
result = kernel._get_subgraph(1)
self.assertEqual(result, mock_subgraph2)
with self.assertRaises(AssertionError):
kernel._get_subgraph(2)
def test_cutedsl_op_overrides(self):
"""Test the new CuteDSLOpOverrides class."""
import torch
from torch._inductor.codegen.common import CSEVariable
from torch._inductor.codegen.cutedsl.cutedsl_op_overrides import (
CuteDSLOpOverrides,
)
from torch.utils._sympy.value_ranges import ValueRanges
mock_cse_a = MagicMock(spec=CSEVariable)
mock_cse_a.__str__.return_value = "tensor_a"
mock_cse_a.dtype = torch.float32
mock_cse_a.bounds = ValueRanges.unknown()
mock_cse_b = MagicMock(spec=CSEVariable)
mock_cse_b.__str__.return_value = "tensor_b"
mock_cse_b.dtype = torch.float32
mock_cse_b.bounds = ValueRanges.unknown()
mock_graph = MockGraphHandler()
with V.set_graph_handler(mock_graph):
kernel = CuteDSLTemplateKernel(
kernel_name="test_ops",
input_nodes=[],
output_node=None,
)
with V.set_kernel_handler(kernel):
result = CuteDSLOpOverrides.add(mock_cse_a, mock_cse_b)
self.assertIsInstance(result, CSEVariable)
result = CuteDSLOpOverrides.mul(mock_cse_a, mock_cse_b)
self.assertIsInstance(result, CSEVariable)
result = CuteDSLOpOverrides.truediv(mock_cse_a, mock_cse_b)
self.assertIsInstance(result, CSEVariable)
result = CuteDSLOpOverrides.exp(mock_cse_a)
self.assertIsInstance(result, CSEVariable)
result = CuteDSLOpOverrides.sqrt(mock_cse_a)
self.assertIsInstance(result, CSEVariable)
with self.assertRaises(NotImplementedError):
result = CuteDSLOpOverrides.maximum(mock_cse_a, mock_cse_b)
result = CuteDSLOpOverrides.minimum(mock_cse_a, mock_cse_b)
scalar_result = CuteDSLOpOverrides._ensure_tensor_ssa("5.0", mock_cse_a)
self.assertEqual(scalar_result, "cute.full_like(tensor_a, 5.0)")
tensor_result = CuteDSLOpOverrides._ensure_tensor_ssa(mock_cse_a, mock_cse_b)
self.assertEqual(tensor_result, "tensor_a")
def test_cse_integration(self):
"""Test CSE (Common Subexpression Elimination) integration."""
from torch._inductor.codegen.common import CSE
mock_graph = MockGraphHandler()
with V.set_graph_handler(mock_graph):
kernel = CuteDSLTemplateKernel(
kernel_name="test_cse",
input_nodes=[],
output_node=None,
)
self.assertIsInstance(kernel.cse, CSE)
self.assertEqual(kernel.cse.name_prefix, "tmp")
with V.set_kernel_handler(kernel):
test_expr = "x"
var = kernel.cse.generate(kernel.body, test_expr, dtype=None)
self.assertTrue(str(var).startswith("tmp"))
if __name__ == "__main__":

View File

@@ -10,12 +10,15 @@ from torch._inductor.codegen.cuda.cutlass_utils import (
torch_dtype_to_cutlass_type,
try_import_cutlass,
)
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import ComputedBuffer, FixedLayout, PermuteView, Pointwise
from torch._inductor.scheduler import BaseSchedulerNode
from torch._inductor.utils import OrderedSet
from torch.testing._internal.common_cuda import SM90OrLater
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA_AND_TRITON
from torch.testing._internal.inductor_utils import (
HAS_CPU,
HAS_CUDA_AND_TRITON,
MockGraphHandler,
)
if try_import_cutlass():
@@ -105,17 +108,6 @@ class MockComputedBuffer(ComputedBuffer):
return 1
class MockGraphHandler(GraphLowering):
def __init__(self, name_to_buffer):
import torch._inductor.sizevars
self.sizevars = torch._inductor.sizevars.SizeVarAllocator()
self.name_to_buffer = name_to_buffer
self.graph_inputs = dict()
self.mutated_buffers = OrderedSet()
self.constants = dict()
class TestCutlassEVT(TestCase):
@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")

View File

@@ -2,19 +2,31 @@
import contextlib
import dataclasses
import logging
import textwrap
from typing import Any, Callable, Optional
import sympy
import torch
from torch._inductor.codegen.common import IndentedBuffer, Kernel
from torch._inductor.ir import Buffer
from torch._inductor.select_algorithm import PartialRender
from torch._inductor.codegen.common import (
CSE,
CSEVariable,
IndentedBuffer,
Kernel,
ValueRanges,
)
from torch._inductor.ir import Buffer, ComputedBuffer, InputBuffer
from torch._inductor.ops_handler import StoreMode
from torch._inductor.utils import OrderedSet
from torch._inductor.virtualized import V
from .cutedsl_op_overrides import CuteDSLOpOverrides
# TODO setting the 'main' kernel w/ this suffix. We have 3 should probably just auto generate this
MAIN_SUFFIX = "main"
log = logging.getLogger(__name__)
kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code")
@@ -70,14 +82,14 @@ class CuteDSLTemplateKernel(Kernel):
kernel_name: str,
input_nodes: list[Buffer],
output_node: Buffer,
subgraphs: Optional[list[Buffer]] = None,
) -> None:
# Call parent Kernel constructor
super().__init__()
self.kernel_name = kernel_name
self.input_nodes = input_nodes
self.output_node = output_node
# TODO Subgraph management for template processing
self.subgraphs = subgraphs
self.subgraph_bodies: dict[str, CuteDSLSubgraphInfo] = {}
# Template attributes
@@ -97,6 +109,8 @@ class CuteDSLTemplateKernel(Kernel):
node_name = getattr(input_node, "name", f"input_{i}")
self.named_input_nodes[node_name] = input_node
self.cse = CSE(name_prefix="tmp")
def gen_imports(self) -> str:
"""Generate common imports for CuteDSL templates."""
imports = IndentedBuffer()
@@ -107,6 +121,8 @@ class CuteDSLTemplateKernel(Kernel):
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
import cuda.bindings.driver as cuda
from cutlass._mlir.dialects import math as mlir_math
import operator
"""
)
return imports.getvalue()
@@ -119,11 +135,15 @@ class CuteDSLTemplateKernel(Kernel):
return params.getvalue()
def render(self, template, **kwargs):
from torch._inductor.select_algorithm import PartialRender
"""Render the kernel using the template, returning PartialRender object with hooks."""
# Available {{}} hooks for jinja rendering
template_env = {
"def_kernel": self.def_kernel,
"gen_defines": lambda: self.gen_defines(**kwargs),
"get_output": self.get_output,
"modification": self.modification,
}
# Render the template with the environment and provided kwargs
@@ -194,29 +214,203 @@
def def_kernel(self, *argnames):
"""Define kernel function signature for CuteDSL templates."""
# Populate all the kernel args
renames = IndentedBuffer(initial_indent=1)
for i, input_node in enumerate(self.input_nodes):
self.args.input(input_node.get_name())
buf_name = input_node.get_name()
self.args.input(buf_name)
# Template aliasing: converts template variables (e.g., "input_a") to function args (e.g., "arg_input_a")
# and generates rename statements so template code can use the original names
if i < len(argnames):
template_name = argnames[i]
arg_name = f"arg_{template_name}"
self.args.input_buffers[buf_name] = arg_name
renames.writeline(f"{template_name} = {arg_name}")
if self.output_node:
self.args.output(self.output_node.get_name())
def hook():
# Deferred execution: arg definitions must be collected after template processing adds all args
arg_defs, *_ = self.args.python_argdefs()
code = IndentedBuffer()
code.writeline(f"# Kernel function signature: {self.kernel_name}")
params = list(argnames) + ["stream"]
params = [x.full_name() for x in arg_defs] + ["stream"]
code.writeline(
f"def {self.kernel_name}_{MAIN_SUFFIX}({', '.join(params)}):"
)
with code.indent():
code.splice(renames.getvalue())
return code.getvalue()
assert "<DEF_KERNEL>" not in self.render_hooks
# Placeholder-based rendering: hook will be called when template encounters "<DEF_KERNEL>"
self.render_hooks["<DEF_KERNEL>"] = hook
return "<DEF_KERNEL>"
def get_output(self):
"""Get the actual argument name for the output buffer."""
assert self.output_node, "Output node must exist to get output buffer name"
buf_name = self.output_node.get_name()
output = self.args.output_buffers.get(buf_name, None)
if output is None:
raise ValueError(f"Output buffer '{buf_name}' not found in args")
return output
def call_kernel(self, name: str, node=None):
"""Call the kernel function. Simplified version of TritonTemplateKernel.call_kernel."""
wrapper = V.graph.wrapper_code
_, call_args, _, arg_types = self.args.python_argdefs()
# TODO triton should really be swapped w/ `python`
wrapper.generate_kernel_call(name, call_args, triton=True, arg_types=arg_types)
def _get_subgraph(self, subgraph_number: int):
"""Get subgraph by number for modification processing."""
assert isinstance(subgraph_number, int)
assert isinstance(self.subgraphs, list)
assert subgraph_number < len(self.subgraphs), (
f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}"
)
assert self.body.getvalue() == "", (
"Body should be clear before adding a modification"
)
return self.subgraphs[subgraph_number]
def modification(
self,
subgraph_number: int,
output_name: Optional[str],
mask: Optional[str] = None,
**fixed_inputs,
) -> str:
"""Generate CuteDSL code for a subgraph modification."""
# Find unique name to avoid collisions between multiple modifications of same subgraph
num = 0
while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies:
num += 1
with self.create_subgraph_body(f"mod_{subgraph_number}_{num}"):
subgraph = self._get_subgraph(subgraph_number)
modification_handler = ModificationWrapperCuteDSL(
self, subgraph_number, fixed_inputs, mask
)
with V.set_kernel_handler(self), V.set_ops_handler(modification_handler):
assert isinstance(subgraph, (ComputedBuffer, list)), (
f"Expected ComputedBuffer or List[ComputedBuffer], got {type(subgraph)}"
)
if isinstance(subgraph, list):
raise NotImplementedError(
"Scatter graphs are not supported for CuteDSL"
)
if isinstance(subgraph.data, InputBuffer):
# grad_score_mod can be InputBuffers
out = subgraph.data.make_loader()(())
else:
# Inline a pointwise lowering into the template
out = subgraph.data.inner_fn(())
if output_name is not None:
assert out is not None, (
f"Expected computation result for named output {output_name}"
)
self.body.writeline(f"{output_name} = {out.value}")
else:
# Side-effect only: no output assignment (currently only for scatter operations)
raise NotImplementedError(
"Side-effect only modifications not yet supported for CuteDSL"
)
return self.body.getvalue()
class ModificationWrapperCuteDSL(V.WrapperHandler): # type: ignore[name-defined]
"""
Wrapper handler that enables CuteDSL code generation during subgraph modifications.
This class sits between the PyTorch IR and CuteDSL code generation, providing:
1. Operation substitution: converts PyTorch ops to CuteDSL equivalents via CuteDSLOpOverrides
2. Placeholder handling: resolves fixed_inputs during template processing
3. Limited operation support: currently restricted to pointwise operations
"""
def __init__(
self,
kernel,
subgraph_number: int,
fixed_inputs: dict[str, Any],
mask: Optional[str],
):
cutedsl_ops = CuteDSLOpOverrides()
super().__init__(cutedsl_ops)
self.name = f"CuteDSLPlaceholderSubstitution_{subgraph_number}"
self.kernel = kernel
self.fixed_inputs = fixed_inputs
self.mask = mask
def _get_input_dtype(self, name: str) -> torch.dtype:
"""Get the dtype for an input from the kernel's named_input_nodes."""
if name in self.kernel.named_input_nodes:
return self.kernel.named_input_nodes[name].dtype
# TODO: Fallback for common dimension names - should be replaced with proper dtype tracking
return torch.float32 if name not in ("b", "h", "m", "n") else torch.int32
def load(self, name: str, index: sympy.Expr):
"""Handle loading from tensor or fixed(template args) input for CuteDSL."""
if name not in self.fixed_inputs:
raise NotImplementedError(
"Tensor loading not yet supported for CuteDSL - only fixed input substitution"
)
value = self.fixed_inputs[name]
dtype = self._get_input_dtype(name)
# ensure CSE wrapping
return self.kernel.cse.generate(
self.kernel.body, value, bounds=ValueRanges.unknown(), dtype=dtype
)
def indirect_indexing(self, index_var: str, size, check, wrap_neg=True):
"""Convert index variable to symbolic form."""
raise NotImplementedError("Indirect indexing not supported")
def store(
self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
) -> str:
raise NotImplementedError(
"Store operations not supported - CuteDSL limited to read-only operations"
)
def _add_kernel_input(self, name: str):
"""Add name as input to kernel and return input ref."""
return self.kernel.args.input(name)
def _process_indexing(self, index):
"""Process and rename indexing, adding symbols as kernel inputs."""
# Convert sympy expression to string representation for CuteDSL
return str(index) # Simplified for now
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
try:
return getattr(self._inner, name)(*args, **kwargs)
except NotImplementedError as e:
bar = "=" * 80
msg = textwrap.dedent(f"""
{bar}
UNSUPPORTED CUTEDSL OPERATION: '{name}'
{bar}
This operation is not yet implemented in Inductor.
Please open an issue at: https://github.com/pytorch/pytorch/issues
with the following information:
Operation: {name}
Args: {args!r}
Kwargs: {kwargs!r}
Title your issue: [CuteDSL] Missing operation: {name}
{bar}
""").strip()
raise NotImplementedError(msg) from e

View File

@@ -0,0 +1,358 @@
# mypy: allow-untyped-defs
"""
CuteDSL-specific operation overrides for pointwise operations.
This module provides CuteDSL implementations of common operations used in
template kernels, particularly for flex attention modifications.
"""
import math
from typing import Optional, Union
import sympy
import torch
from torch._inductor.codegen.common import CSEVariable, OpOverrides
from torch._inductor.virtualized import OpsValue, V
from torch.utils._sympy.value_ranges import ValueRanges
CuteDSLArg = Union[CSEVariable, str]
def upcast_compute_type(dtype: torch.dtype) -> torch.dtype:
"""Maybe upcast [b]float16 to float32"""
if dtype in (torch.float16, torch.bfloat16):
return torch.float32
return dtype
class CuteDSLOpOverrides(OpOverrides):
"""
CuteDSL-specific operation overrides that generate code using CuteDSL syntax.
CuteDSL TensorSSA objects have built-in operator overloads (__add__, __mul__, etc.)
and math functions (cute.math.exp, cute.math.sqrt, etc.)
"""
TORCH_TO_CUTE_DTYPE = {
torch.float16: "cutlass.Float16",
torch.bfloat16: "cutlass.BFloat16",
torch.float32: "cutlass.Float32",
torch.float64: "cutlass.Float64",
torch.int8: "cutlass.Int8",
torch.int16: "cutlass.Int16",
torch.int32: "cutlass.Int32",
torch.int64: "cutlass.Int64",
torch.bool: "cutlass.Boolean",
torch.float8_e4m3fn: "cutlass.Float8E4M3FN",
torch.float8_e5m2: "cutlass.Float8E5M2",
}
# Math constants
LOG2_E = 1.4426950408889634 # 1/ln(2) for converting natural exp to base-2 exp
@staticmethod
def _ensure_tensor_ssa(arg: CuteDSLArg, template_tensor: CuteDSLArg) -> str:
"""
Convert scalar arguments to TensorSSA using cute.full_like if needed.
Args:
arg: The argument to check (CSEVariable for tensors, str for scalars, or OpsValue wrapper)
template_tensor: A tensor argument to use as template for full_like
Returns:
String representation suitable for CuteDSL operations
"""
if isinstance(arg, CSEVariable):
return str(arg)
if isinstance(arg, OpsValue) and isinstance(arg.value, CSEVariable):
return str(arg.value)
if isinstance(template_tensor, CSEVariable):
return f"cute.full_like({template_tensor}, {arg})"
return str(arg)
@staticmethod
def _extract_dtype_and_bounds(
*args: CuteDSLArg,
) -> tuple[Optional[torch.dtype], ValueRanges[sympy.Expr]]:
"""Extract dtype and bounds from CSEVariable arguments."""
for arg in args:
if isinstance(arg, CSEVariable):
return arg.dtype, arg.bounds
return None, ValueRanges.unknown()
@staticmethod
def _apply_binary_op(a: CuteDSLArg, b: CuteDSLArg, op_format: str) -> CuteDSLArg:
"""
Apply a binary operation with automatic scalar-to-tensor conversion.
CuteDSL requires both operands to be TensorSSA objects for tensor operations.
This helper automatically converts scalar arguments to TensorSSA using
cute.full_like when at least one argument is a tensor (CSEVariable).
Args:
a: First operand (CSEVariable for tensors, str for scalars)
b: Second operand (CSEVariable for tensors, str for scalars)
op_format: Format string with {a} and {b} placeholders for the operation
Returns:
CSEVariable if at least one operand is a CSEVariable, otherwise string
"""
tensor_arg = (
a
if isinstance(a, CSEVariable)
else b
if isinstance(b, CSEVariable)
else None
)
if tensor_arg is not None:
a_ssa = CuteDSLOpOverrides._ensure_tensor_ssa(a, tensor_arg)
b_ssa = CuteDSLOpOverrides._ensure_tensor_ssa(b, tensor_arg)
result_expr = op_format.format(a=a_ssa, b=b_ssa)
dtype, bounds = CuteDSLOpOverrides._extract_dtype_and_bounds(a, b)
# Create and return CSEVariable using CSE generation for caching
return V.kernel.cse.generate(
V.kernel.body, result_expr, bounds=bounds, dtype=dtype
)
return op_format.format(a=a, b=b)
@staticmethod
def _apply_unary_op(x: CuteDSLArg, op_format: str) -> CuteDSLArg:
"""
Apply a unary operation, returning CSEVariable if input is CSEVariable.
Args:
x: Input operand (CSEVariable for tensors, str for scalars)
op_format: Format string with {x} placeholder for the operation
Returns:
CSEVariable if input is a CSEVariable, otherwise string
"""
if isinstance(x, CSEVariable):
result_expr = op_format.format(x=str(x))
return V.kernel.cse.generate(
V.kernel.body, result_expr, bounds=x.bounds, dtype=x.dtype
)
return op_format.format(x=x)
@staticmethod
def constant(value: Union[bool, float, int], dtype: torch.dtype) -> str:
"""Generate CuteDSL constant representation."""
if value == float("-inf"):
return "float('-inf')"
elif value == float("inf"):
return "float('inf')"
elif math.isnan(value):
return "float('nan')"
return repr(value)
@staticmethod
def add(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} + {b})")
@staticmethod
def mul(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} * {b})")
@staticmethod
def sub(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} - {b})")
@staticmethod
def truediv(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} / {b})")
@staticmethod
def mod(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} % {b})")
@staticmethod
def remainder(a, b):
return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} % {b})")
@staticmethod
def exp(x: CuteDSLArg) -> CuteDSLArg:
"""Exponential using CuteDSL cute.math.exp function."""
return CuteDSLOpOverrides._apply_unary_op(
x, f"cute.math.exp2({{x}} * {CuteDSLOpOverrides.LOG2_E})"
)
@staticmethod
def sqrt(x: CuteDSLArg) -> CuteDSLArg:
"""Square root using CuteDSL cute.math.sqrt function."""
return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.sqrt({x})")
@staticmethod
def log(x: CuteDSLArg) -> CuteDSLArg:
"""Natural logarithm using CuteDSL cute.math.log function."""
return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.log({x})")
@staticmethod
def cos(x: CuteDSLArg) -> CuteDSLArg:
"""Cosine using CuteDSL cute.math.cos function."""
return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.cos({x})")
@staticmethod
def sin(x: CuteDSLArg) -> CuteDSLArg:
"""Sine using CuteDSL cute.math.sin function."""
return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.sin({x})")
@staticmethod
def erf(x: CuteDSLArg) -> CuteDSLArg:
"""Error function using CuteDSL cute.math.erf function."""
return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.erf({x})")
@staticmethod
def maximum(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
raise NotImplementedError("TODO: maximum is not supported yet for TensorSSA")
@staticmethod
def minimum(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
raise NotImplementedError("TODO: minimum is not supported yet for TensorSSA")
@staticmethod
def where(
condition: CuteDSLArg,
a: CuteDSLArg,
b: CuteDSLArg,
) -> CuteDSLArg:
"""Conditional selection - handles both CSEVariable and string inputs."""
# Find a tensor argument to use as template for full_like
# Priority: use 'a' if it's a tensor, else use 'b', else condition
tensor_arg = (
a
if isinstance(a, CSEVariable)
else (
b
if isinstance(b, CSEVariable)
else condition
if isinstance(condition, CSEVariable)
else None
)
)
if tensor_arg is not None:
a_ssa = CuteDSLOpOverrides._ensure_tensor_ssa(a, tensor_arg)
b_ssa = CuteDSLOpOverrides._ensure_tensor_ssa(b, tensor_arg)
result_expr = f"cute.where({condition}, {a_ssa}, {b_ssa})"
dtype, bounds = CuteDSLOpOverrides._extract_dtype_and_bounds(
a, b, condition
)
return V.kernel.cse.generate(
V.kernel.body, result_expr, bounds=bounds, dtype=dtype
)
return f"cute.where({condition}, {a}, {b})"
@staticmethod
def pow(a: CuteDSLArg, b: CuteDSLArg):
return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} ** {b})")
@staticmethod
def abs(x: CuteDSLArg) -> CuteDSLArg:
"""Absolute value using CuteDSL cute.math.abs function."""
if isinstance(x, CSEVariable):
x_dtype = x.dtype
elif isinstance(x, OpsValue) and isinstance(x.value, CSEVariable):
x_dtype = x.value.dtype
else:
x_dtype = torch.float32
abs_op = (
"mlir_math.absf"
if x_dtype in (torch.float16, torch.bfloat16, torch.float32)
else "mlir_math.absi"
)
return CuteDSLOpOverrides._apply_unary_op(
x, f"cute.TensorSSA({abs_op}({{x}}), {{x}}.shape, {{x}}.dtype)"
)
@staticmethod
def neg(x: CuteDSLArg) -> CuteDSLArg:
"""Negation using CuteDSL TensorSSA __neg__ operator."""
# TODO: See https://github.com/NVIDIA/cutlass/issues/2584
return CuteDSLOpOverrides._apply_unary_op(
x, "cute.TensorSSA(-{x}, {x}.shape, {x}.dtype)"
)
@staticmethod
def to_dtype(
x: CuteDSLArg, dtype: torch.dtype, src_dtype=None, use_compute_types=True
) -> CuteDSLArg:
"""Type conversion using CuteDSL TensorSSA.to(Type[Numeric]).
Maps torch dtypes to cutlass.cute.typing numeric types and emits
`{x}.to(cute.typing.<Type>)`.
Raises NotImplementedError for unsigned integer and unsupported dtypes.
"""
# Always convert up from bf16 and fp16 TODO on configuring
dtype = upcast_compute_type(dtype)
cute_type = CuteDSLOpOverrides.TORCH_TO_CUTE_DTYPE.get(dtype)
if cute_type is None:
raise NotImplementedError(
f"CuteDSL dtype cast not implemented for torch dtype: {dtype}"
)
if isinstance(x, CSEVariable):
result_expr = f"{str(x)}.to({cute_type})"
return V.kernel.cse.generate(
V.kernel.body, result_expr, bounds=x.bounds, dtype=dtype
)
return f"{x}.to({cute_type})"
@staticmethod
def tanh(x0: CuteDSLArg) -> CuteDSLArg:
"""Hyperbolic tangent using CuteDSL cute.math.tanh function."""
return CuteDSLOpOverrides._apply_unary_op(x0, "cute.math.tanh({x})")
# Logical operations
@staticmethod
def logical_and(x0: CuteDSLArg, x1: CuteDSLArg) -> CuteDSLArg:
return CuteDSLOpOverrides._apply_binary_op(x0, x1, "({a} and {b})")
@staticmethod
def logical_or(x0: CuteDSLArg, x1: CuteDSLArg) -> CuteDSLArg:
return CuteDSLOpOverrides._apply_binary_op(x0, x1, "({a} or {b})")
@staticmethod
def logical_not(a):
"""Logical NOT."""
return CuteDSLOpOverrides._apply_unary_op(a, "({x} == 0)")
# Comparison operations
@staticmethod
def eq(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.eq({a}, {b})")
@staticmethod
def ne(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.ne({a}, {b})")
@staticmethod
def lt(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.lt({a}, {b})")
@staticmethod
def le(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.le({a}, {b})")
@staticmethod
def gt(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.gt({a}, {b})")
@staticmethod
def ge(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.ge({a}, {b})")

View File

@@ -1,14 +1,17 @@
# mypy: allow-untyped-defs
import functools
import itertools
from collections.abc import Iterable
from typing import Any, Optional, Union
from unittest.mock import patch
from torch._inductor.ir import ShapeAsConstantBuffer
from torch._inductor.utils import Placeholder
from torch._inductor.virtualized import V
from torch._logging import getArtifactLogger
from ...autotune_process import CuteDSLBenchmarkRequest, TensorMeta
from ...ir import Buffer, ChoiceCaller, CuteDSLTemplateBuffer, Layout, TensorBox
from ...ir import Buffer, ChoiceCaller, CuteDSLTemplateBuffer, IRNode, Layout, TensorBox
from ..common import KernelTemplate
from .cutedsl_kernel import CuteDSLTemplateKernel
@@ -64,6 +67,8 @@ class CuteDSLTemplate(KernelTemplate):
"""Generate the CuteDSL kernel caller."""
input_nodes = kwargs.pop("input_nodes")
layout = kwargs.pop("layout")
mutated_inputs = kwargs.pop("mutated_inputs", None)
subgraphs = kwargs.pop("subgraphs", None)
kernel_name = f"cutedsl_{self.name}_{next(self.index_counter)}"
@@ -71,45 +76,57 @@
raise RuntimeError("Template compilation failed (Jinja2 required)")
self.output_node: Buffer = Buffer(name="buf_out", layout=layout)
kernel = self.kernel_type(
kernel_name=kernel_name,
input_nodes=input_nodes,
output_node=self.output_node,
)
code = kernel.render(self.template, **kwargs)
log.debug("Generated CuteDSL Code:\n%s", code)
bmreq = CuteDSLBenchmarkRequest(
kernel_name=kernel_name,
input_tensor_meta=TensorMeta.from_irnodes(input_nodes),
output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
extra_args=tuple(),
source_code=code,
)
def make_kernel_render(out_node, hint_override: Optional[int] = None):
render_kernel = self.kernel_type(
kernel_name=str(Placeholder.KERNEL_NAME),
# Patch V.graph.get_dtype to handle the fake buf_out buffer
with patch.object(
V.graph, "get_dtype", KernelTemplate._fake_get_dtype(self.output_node)
):
kernel = self.kernel_type(
kernel_name=kernel_name,
input_nodes=input_nodes,
output_node=out_node,
output_node=self.output_node,
subgraphs=subgraphs,
)
code = kernel.render(self.template, **kwargs)
log.debug("Generated CuteDSL Code:\n%s", code)
bmreq = CuteDSLBenchmarkRequest(
kernel_name=kernel_name,
input_tensor_meta=TensorMeta.from_irnodes(input_nodes),
output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
extra_args=tuple(),
source_code=code,
)
def render():
return render_kernel.render(self.template, **kwargs)
def make_kernel_render(out_node, hint_override: Optional[int] = None):
"""
Factory function that creates a kernel renderer for the final output.
return render_kernel, render
This closure captures the current template and parameters, but allows
the output node to be specified later. This is used during the final
kernel selection phase when the actual output buffer is available.
"""
render_kernel = self.kernel_type(
kernel_name=str(Placeholder.KERNEL_NAME),
input_nodes=input_nodes,
output_node=out_node,
subgraphs=subgraphs,
)
return CuteDSLTemplateCaller(
name=kernel_name,
input_nodes=input_nodes,
layout=layout,
make_kernel_render=make_kernel_render,
bmreq=bmreq,
template=self,
)
def render():
return render_kernel.render(self.template, **kwargs)
return render_kernel, render
return CuteDSLTemplateCaller(
name=kernel_name,
input_nodes=input_nodes,
layout=layout,
make_kernel_render=make_kernel_render,
bmreq=bmreq,
template=self,
mutated_inputs=mutated_inputs,
)
class CuteDSLTemplateCaller(ChoiceCaller):
@@ -123,6 +140,7 @@ class CuteDSLTemplateCaller(ChoiceCaller):
make_kernel_render: Any,
bmreq: CuteDSLBenchmarkRequest,
template: "CuteDSLTemplate",
mutated_inputs: Optional[Iterable[IRNode]] = None,
):
super().__init__(
name=name,
@@ -133,6 +151,7 @@ class CuteDSLTemplateCaller(ChoiceCaller):
self.make_kernel_render = make_kernel_render
self.bmreq = bmreq
self.template = template
self.mutated_inputs = mutated_inputs
def __str__(self) -> str:
return f"CuteDSLTemplateCaller({self.name})"
@@ -149,6 +168,7 @@ class CuteDSLTemplateCaller(ChoiceCaller):
inputs=self.input_nodes,
make_kernel_render=self.make_kernel_render,
template=self.template,
mutated_inputs=self.mutated_inputs,
)
)

View File

@@ -13,6 +13,7 @@ import torch._inductor.async_compile # noqa: F401 required to warm up AsyncComp
from torch.fx.experimental.proxy_tensor import make_fx
from torch._inductor.graph import GraphLowering
from torch._inductor.compile_fx import shape_env_from_inputs
from torch._inductor.utils import OrderedSet
from torch._inductor.codecache import CppCodeCache
from torch._inductor.custom_graph_pass import CustomGraphModulePass
from torch._inductor.codegen.common import (
@@ -306,6 +307,24 @@ def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
inverse_scale = scale.reciprocal()
return x_fp8, inverse_scale
class MockGraphHandler(GraphLowering):
"""Minimal mock graph handler for testing virtualized context."""
def __init__(self, name_to_buffer=None):
import torch._inductor.sizevars
self.sizevars = torch._inductor.sizevars.SizeVarAllocator()
self.name_to_buffer = name_to_buffer or {}
self.graph_inputs = {}
self.mutated_buffers = OrderedSet()
self.removed_buffers = OrderedSet()
self.constants = {}
self.scheduler = None
def get_dtype(self, buffer_name: str) -> torch.dtype: # noqa: ARG002
"""Return default dtype for any buffer (for testing)."""
return torch.float32
@contextlib.contextmanager
def patch_inductor_backend(
device: str,