pytorch/test/inductor/test_cutedsl_template.py
drisspg 30edac5da6 Updates to CuTe DSL template renderer (#161117)
# Summary
This adds a few more render functions for template writers, specifically `get_output` and `modification`. The motivation becomes clearer in the next PR in this stack.
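
For a sense of what this enables, here is a template fragment in the style of the add-kernel template tested below. `get_output()` resolves to the kernel's output buffer argument; `modification()` (exercised in the next PR) splices a compiled subgraph, such as a score_mod, into the kernel body. Treat this as a sketch, not the final template surface:

```py
# Hypothetical fragment illustrating the new hooks.
TEMPLATE_FRAGMENT = r"""
{{def_kernel("input_a", "input_b")}}
    cute_a = from_dlpack(input_a)
    cute_b = from_dlpack(input_b)
    cute_c = from_dlpack({{get_output()}})  # new: resolves to the output buffer arg
    {{kernel_name}}_jit(cute_a, cute_b, cute_c, cuda.CUstream(stream))
    return {{get_output()}}
"""
```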


The majority of the new code is the `OpOverrides` implementation for the CuTe DSL. It is a lot to test, and most of the actual testing so far has been via score_mods fed to the flash_attention template at the next layer of this stack.

Below is a bunch of score mods (that Claude and I came up with) that exercise the actual ops:
```py
import math

import torch

def causal_mask(score, b, h, q_idx, kv_idx):
    """Causal attention mask."""
    return torch.where(q_idx >= kv_idx, score, float("-inf"))

def relative_bias(score, b, h, token_q, token_kv):
    """Relative position bias."""
    return score + torch.abs(token_q - token_kv)

def relative_bias_v2(score, b, h, token_q, token_kv):
    """Relative position bias with factor of 2."""
    return score + 2 * torch.abs(token_q - token_kv)

def times_two(score, b, h, q_idx, kv_idx):
    """Simple score modification that doubles the score."""
    return score * 2

def alibi_bias(score, b, h, q_idx, kv_idx):
    """ALiBi (Attention with Linear Biases) - used in some modern models."""
    # Different slopes for different heads
    slope = 2 ** (-8 * (h + 1) / 8)  # Simplified version
    return score - slope * torch.abs(q_idx - kv_idx)

def sliding_window(score, b, h, q_idx, kv_idx, window_size=256):
    """Sliding window attention - only attend to nearby tokens."""
    return torch.where(
        torch.abs(q_idx - kv_idx) <= window_size,
        score,
        float("-inf")
    )

def block_diagonal(score, b, h, q_idx, kv_idx, block_size=64):
    """Block diagonal attention pattern."""
    q_block = q_idx // block_size
    kv_block = kv_idx // block_size
    return torch.where(q_block == kv_block, score, float("-inf"))

def additive_bias(score, b, h, q_idx, kv_idx):
    """Test simple addition with position-based bias."""
    return score + (q_idx + kv_idx) * 0.01

def multiplicative_decay(score, b, h, q_idx, kv_idx):
    """Test multiplication with distance-based decay."""
    distance = torch.abs(q_idx - kv_idx)
    return score * torch.exp(-0.1 * distance)

def sine_wave_bias(score, b, h, q_idx, kv_idx):
    """Test trigonometric functions."""
    return score + 0.1 * torch.sin(2 * math.pi * (q_idx - kv_idx) / 64)

def log_distance_penalty(score, b, h, q_idx, kv_idx):
    """Test logarithmic operations."""
    distance = torch.abs(q_idx - kv_idx).float()
    return score - torch.log(1 + distance)

def alternating_mask(score, b, h, q_idx, kv_idx):
    """Test with alternating pattern - good for branch prediction."""
    return torch.where((q_idx + kv_idx) % 2 == 0, score, float("-inf"))

def head_specific_pattern(score, b, h, q_idx, kv_idx):
    """Different behavior per attention head."""
    even_head = h % 2 == 0
    causal = q_idx >= kv_idx
    return torch.where(even_head & causal, score, float("-inf"))

def sparse_strided(score, b, h, q_idx, kv_idx, stride=4):
    """Sparse attention with strided pattern."""
    return torch.where(
        (kv_idx % stride == 0) | (q_idx == kv_idx),
        score,
        float("-inf")
    )

def causal_with_global(score, b, h, q_idx, kv_idx):
    """Causal mask but first few tokens are globally attended."""
    is_causal = q_idx >= kv_idx
    is_global = kv_idx < 4
    return torch.where(is_causal | is_global, score, float("-inf"))

def dilated_attention(score, b, h, q_idx, kv_idx, dilation_rate=2):
    """Dilated attention pattern - exponentially increasing gaps."""
    distance = torch.abs(q_idx - kv_idx)
    is_attended = (distance == 0) | ((distance > 0) & ((distance & (distance - 1)) == 0))
    return torch.where(is_attended, score, float("-inf"))

```
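
For a sense of how these are driven: each mod is passed as the `score_mod` to `flex_attention` under `torch.compile`, which is what routes it through the CuTe DSL op overrides at the next layer of the stack. A minimal sketch (shapes match the test config below; `causal_mask` is from the list above):

```py
import torch
from torch.nn.attention.flex_attention import flex_attention

# batch=4, heads=32, seq=8192, dim=128, matching the test suite config below
q, k, v = (
    torch.randn(4, 32, 8192, 128, device="cuda", dtype=torch.bfloat16)
    for _ in range(3)
)

compiled = torch.compile(flex_attention)
out = compiled(q, k, v, score_mod=causal_mask)
```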

Example outputs:
```
[Test Suite]
Config: batch=4, heads=32, seq_q=8192, seq_kv=8192, dim=128

[Test 1: none]
[No score_mod, flash='enabled'] Found flash_attncute: True
[No score_mod, flash='disabled'] Found flash_attncute: False
✓ Outputs match between flash enabled/disabled
✓ Output matches eager SDPA (rtol=0.001, atol=0.001)

[Test 2: causal]
[With score_mod, flash='enabled'] Found flash_attncute: True
[With score_mod, flash='disabled'] Found flash_attncute: False
✗ Outputs differ between flash modes: Tensor-likes are not close!

Mismatched elements: 17879 / 134217728 (0.0%)
Greatest absolute difference: 0.0078125 at index (0, 15, 15, 60) (up to 0.001 allowed)
Greatest relative difference: 2.5 at index (3, 22, 153, 126) (up to 0.001 allowed)

[Test 3: rel_bias]
[With score_mod, flash='enabled'] Found flash_attncute: True
[With score_mod, flash='disabled'] Found flash_attncute: False
✗ Outputs differ between flash modes: Tensor-likes are not close!

Mismatched elements: 12836 / 134217728 (0.0%)
Greatest absolute difference: 0.015625 at index (0, 3, 2775, 84) (up to 0.001 allowed)
Greatest relative difference: 11.8125 at index (3, 28, 4095, 76) (up to 0.001 allowed)

[Test 4: rel_bias_v2]
```

These runs are in bfloat16, and there are no major differences. The list of pointwise ops exercised here isn't exhaustive, but the coverage is fairly broad.
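
For reference, the pass/fail lines in the log above come from a comparison along these lines (tensor names are illustrative; tolerances are the rtol/atol=1e-3 from the log):

```py
# Compare flash-enabled vs. flash-disabled outputs; the reported mismatches
# are within typical bfloat16 rounding for fused kernels.
torch.testing.assert_close(out_flash_enabled, out_flash_disabled, rtol=1e-3, atol=1e-3)
```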

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161117
Approved by: https://github.com/mlazos
2025-08-27 23:01:31 +00:00


# Owner(s): ["module: inductor"]
import unittest
from unittest.mock import MagicMock, patch

from expecttest import assert_expected_inline

import torch
from torch._inductor.test_case import TestCase
from torch._inductor.virtualized import V
from torch.testing._internal.inductor_utils import MockGraphHandler


try:
    import cutlass  # noqa: F401
    import cutlass.cute as cute  # noqa: F401

    HAS_CUTLASS = True
except ImportError:
    HAS_CUTLASS = False

if HAS_CUTLASS:
    from torch._inductor.codegen.cutedsl.cutedsl_kernel import CuteDSLTemplateKernel
    from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
    from torch._inductor.select_algorithm import PartialRender

CUTEDSL_ADD_TEMPLATE = r"""
{{gen_defines()}}

@cute.kernel
def {{kernel_name}}_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdim, _, _ = cute.arch.block_dim()
    thread_idx = bidx * bdim + tidx
    m, n = gA.shape
    if thread_idx < m * n:
        mi = thread_idx // n
        ni = thread_idx % n
        if mi < m and ni < n:
            gC[mi, ni] = gA[mi, ni] + gB[mi, ni]

@cute.jit
def {{kernel_name}}_jit(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor, stream):
    {{gen_defines()}}
    m, n = mA.shape
    total_threads = m * n
    num_blocks = (total_threads + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK
    kernel = {{kernel_name}}_kernel(mA, mB, mC)
    kernel.launch(
        grid=[num_blocks, 1, 1],
        block=[THREADS_PER_BLOCK, 1, 1],
        stream=stream
    )

{{def_kernel("input_a", "input_b")}}
    cute_a = from_dlpack(input_a)
    cute_b = from_dlpack(input_b)
    cute_c = from_dlpack({{get_output()}})
    {{kernel_name}}_jit(cute_a, cute_b, cute_c, cuda.CUstream(stream))
    return {{get_output()}}
"""


@unittest.skipUnless(HAS_CUTLASS, "requires cutlass")
class TestCuteDSLTemplate(TestCase):
    """Test cases for CuteDSL template functionality."""

    def test_gen_imports(self):
        kernel = CuteDSLTemplateKernel(
            kernel_name="test_kernel",
            input_nodes=[],
            output_node=None,
        )
        imports = kernel.gen_imports()

        self.assertIn("import torch", imports)
        self.assertIn("import cutlass", imports)
        self.assertIn("import cutlass.cute as cute", imports)
        self.assertIn("from cutlass.cute.runtime import from_dlpack", imports)

        self.assertIsInstance(imports, str)
        lines = imports.strip().split("\n")
        self.assertEqual(len(lines), 7)

    def test_render_includes_imports(self):
        template_source = """@cute.kernel
def {{kernel_name}}_kernel():
    pass

{{def_kernel("input", "output")}}
    return output"""

        mock_template = MagicMock()
        mock_template.render = MagicMock(return_value=template_source)

        kernel = CuteDSLTemplateKernel(
            kernel_name="test_kernel",
            input_nodes=[],
            output_node=None,
        )

        result = kernel.render(mock_template)
        self.assertIsInstance(result, PartialRender)
        rendered_code = result._code

        # The imports might have leading whitespace, so strip it
        rendered_code_stripped = rendered_code.lstrip()
        self.assertTrue(
            rendered_code_stripped.startswith("import torch"),
            f"Code should start with 'import torch', got: {rendered_code_stripped[:50]}",
        )
        self.assertIn("import cutlass", rendered_code)
        self.assertIn("import cutlass.cute as cute", rendered_code)
        self.assertIn("from cutlass.cute.runtime import from_dlpack", rendered_code)
        self.assertIn("@cute.kernel", rendered_code)

    def test_template_env_contains_hooks(self):
        kernel = CuteDSLTemplateKernel(
            kernel_name="test_kernel",
            input_nodes=[],
            output_node=None,
        )

        captured_env = {}

        def mock_render(**kwargs):
            captured_env.update(kwargs)
            return "rendered"

        mock_template = MagicMock()
        mock_template.render = mock_render

        kernel.render(mock_template)

        self.assertIn("def_kernel", captured_env)
        self.assertIn("kernel_name", captured_env)
        self.assertTrue(callable(captured_env["def_kernel"]))

    def test_multiple_templates_unique_names(self):
        # Clean registry first
        test_name = f"unique_test_{id(self)}"
        if test_name in CuteDSLTemplate.all_templates:
            del CuteDSLTemplate.all_templates[test_name]

        _ = CuteDSLTemplate(
            name=test_name,
            source="template1",
        )

        with self.assertRaises(AssertionError):
            _ = CuteDSLTemplate(
                name=test_name,
                source="template2",
            )

    def test_indented_buffer_usage(self):
        kernel = CuteDSLTemplateKernel(
            kernel_name="test_kernel",
            input_nodes=[],
            output_node=None,
        )
        imports = kernel.gen_imports()
        lines = imports.strip().split("\n")
        for line in lines:
            if line:
                self.assertFalse(
                    line.startswith(" "), f"Line should not be indented: '{line}'"
                )

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_cutedsl_add_e2e(self):
        """End-to-end test with CuteDSL template including code generation verification."""
        from torch._inductor.ir import TensorBox
        from torch._inductor.lowering import lowerings
        from torch._inductor.utils import run_and_get_code

        template = CuteDSLTemplate(
            name="test_add_e2e",
            source=CUTEDSL_ADD_TEMPLATE,
        )

        def cutedsl_add_lowering(a: TensorBox, b: TensorBox) -> TensorBox:
            choices = []
            error = template.maybe_append_choice(
                choices,
                input_nodes=[a, b],
                layout=a.get_layout(),
                THREADS_PER_BLOCK=256,
            )
            if error or not choices:
                default_lowering = lowerings[torch.ops.aten.add.Tensor]
                return default_lowering(a, b)
            # Use the single choice directly (no autotuning)
            return choices[0].output_node()

        with patch.dict(lowerings, {torch.ops.aten.add.Tensor: cutedsl_add_lowering}):
            # Test function
            def test_add(x, y):
                return x + y

            device = "cuda"
            x = torch.randn(128, 4, device=device, dtype=torch.float32)
            y = torch.randn(128, 4, device=device, dtype=torch.float32)

            # Compile and get generated code
            compiled_fn = torch.compile(test_add, backend="inductor")
            result, (code,) = run_and_get_code(compiled_fn, x, y)

            # Verify CuteDSL code is present
            self.assertIn(
                "cute", code.lower(), "CuteDSL code should be in generated code"
            )
            # Verify parameter generation worked
            self.assertIn(
                "THREADS_PER_BLOCK", code, "Parameter should be in generated code"
            )

            # Verify correctness
            expected = x + y
            self.assertTrue(torch.allclose(result, expected, atol=1e-5))

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_cutedsl_add_e2e_autotune(self):
        """E2E test with multiple CuteDSL template variants for autotuning."""
        from torch._inductor.ir import TensorBox
        from torch._inductor.lowering import lowerings
        from torch._inductor.select_algorithm import autotune_select_algorithm

        template = CuteDSLTemplate(
            name="test_add_autotune",
            source=CUTEDSL_ADD_TEMPLATE,
        )

        def cutedsl_add_lowering(a: TensorBox, b: TensorBox) -> TensorBox:
            choices = []
            # Add multiple variants with different thread counts for autotuning
            thread_variants = [128, 256, 512]
            for threads in thread_variants:
                error = template.maybe_append_choice(
                    choices,
                    input_nodes=[a, b],
                    layout=a.get_layout(),
                    THREADS_PER_BLOCK=threads,
                )
                if error:
                    # Skip this variant if it fails
                    continue

            if not choices:
                default_lowering = lowerings[torch.ops.aten.add.Tensor]
                return default_lowering(a, b)

            # Use autotuning to select the best variant
            return autotune_select_algorithm(
                "cutedsl_add_autotune",
                choices,
                [a, b],
                a.get_layout(),
            )

        with patch.dict(lowerings, {torch.ops.aten.add.Tensor: cutedsl_add_lowering}):
            # Test function
            def test_add(x, y):
                return x + y

            device = "cuda"
            x = torch.randn(128, 128, device=device, dtype=torch.float32)
            y = torch.randn(128, 128, device=device, dtype=torch.float32)

            # Compile and run
            compiled_fn = torch.compile(test_add, backend="inductor")
            result = compiled_fn(x, y)

            # Verify correctness
            expected = x + y
            self.assertTrue(torch.allclose(result, expected, atol=1e-5))

    def test_gen_defines(self):
        """Test that gen_defines correctly generates CuteDSL parameter definitions."""
        kernel = CuteDSLTemplateKernel(
            kernel_name="test_kernel",
            input_nodes=[],
            output_node=None,
        )

        # Test integer parameters
        params = kernel.gen_defines(
            THREADS_PER_BLOCK=256,
            BLOCK_SIZE=128,
            ENABLE_FEATURE=True,
        )
        assert_expected_inline(
            params,
            """\
THREADS_PER_BLOCK: cutlass.Constexpr = 256
BLOCK_SIZE: cutlass.Constexpr = 128
ENABLE_FEATURE: cutlass.Constexpr = True
""",
        )

        params_float = kernel.gen_defines(SCALE_FACTOR=1.5)
        assert_expected_inline(
            params_float,
            """\
SCALE_FACTOR: cutlass.Constexpr = 1.5
""",
        )

    def test_template_aliasing(self):
        """Test that template variables are correctly aliased to function arguments."""
        from torch._inductor.ir import Buffer

        mock_input1 = MagicMock(spec=Buffer)
        mock_input1.get_name.return_value = "buf_input1"
        mock_input2 = MagicMock(spec=Buffer)
        mock_input2.get_name.return_value = "buf_input2"
        mock_output = MagicMock(spec=Buffer)
        mock_output.get_name.return_value = "buf_output"

        mock_graph = MockGraphHandler()
        with V.set_graph_handler(mock_graph):
            kernel = CuteDSLTemplateKernel(
                kernel_name="test_aliasing",
                input_nodes=[mock_input1, mock_input2],
                output_node=mock_output,
            )

            def_kernel_hook = kernel.def_kernel("custom_a", "custom_b")
            self.assertEqual(def_kernel_hook, "<DEF_KERNEL>")
            self.assertIn("<DEF_KERNEL>", kernel.render_hooks)

            hook_fn = kernel.render_hooks["<DEF_KERNEL>"]
            generated_code = hook_fn()

            # Check that the generated code contains the expected aliasing statements
            self.assertIn("custom_a = arg_custom_a", generated_code)
            self.assertIn("custom_b = arg_custom_b", generated_code)

    def test_get_output_hook(self):
        """Test the get_output() template hook."""
        from torch._inductor.ir import Buffer

        mock_output = MagicMock(spec=Buffer)
        mock_output.get_name.return_value = "buf_test_output"

        mock_graph = MockGraphHandler()
        with V.set_graph_handler(mock_graph):
            kernel = CuteDSLTemplateKernel(
                kernel_name="test_output",
                input_nodes=[],
                output_node=mock_output,
            )

            with self.assertRaises(ValueError):
                # error if no output buffer
                result = kernel.get_output()

            kernel.args.output_buffers["buf_test_output"] = "arg_buf_test_output"
            result = kernel.get_output()
            self.assertEqual(result, "arg_buf_test_output")

    def test_modification_subgraph(self):
        """Test the modification() method and subgraph processing."""
        from torch._inductor.ir import Buffer

        mock_subgraph1 = MagicMock(spec=Buffer)
        mock_subgraph2 = MagicMock(spec=Buffer)
        subgraphs = [mock_subgraph1, mock_subgraph2]

        mock_output = MagicMock(spec=Buffer)
        mock_output.get_name.return_value = "buf_output"

        kernel = CuteDSLTemplateKernel(
            kernel_name="test_modification",
            input_nodes=[],
            output_node=mock_output,
            subgraphs=subgraphs,
        )

        result = kernel._get_subgraph(0)
        self.assertEqual(result, mock_subgraph1)

        result = kernel._get_subgraph(1)
        self.assertEqual(result, mock_subgraph2)

        with self.assertRaises(AssertionError):
            kernel._get_subgraph(2)

    def test_cutedsl_op_overrides(self):
        """Test the new CuteDSLOpOverrides class."""
        import torch
        from torch._inductor.codegen.common import CSEVariable
        from torch._inductor.codegen.cutedsl.cutedsl_op_overrides import (
            CuteDSLOpOverrides,
        )
        from torch.utils._sympy.value_ranges import ValueRanges

        mock_cse_a = MagicMock(spec=CSEVariable)
        mock_cse_a.__str__.return_value = "tensor_a"
        mock_cse_a.dtype = torch.float32
        mock_cse_a.bounds = ValueRanges.unknown()

        mock_cse_b = MagicMock(spec=CSEVariable)
        mock_cse_b.__str__.return_value = "tensor_b"
        mock_cse_b.dtype = torch.float32
        mock_cse_b.bounds = ValueRanges.unknown()

        mock_graph = MockGraphHandler()
        with V.set_graph_handler(mock_graph):
            kernel = CuteDSLTemplateKernel(
                kernel_name="test_ops",
                input_nodes=[],
                output_node=None,
            )
            with V.set_kernel_handler(kernel):
                result = CuteDSLOpOverrides.add(mock_cse_a, mock_cse_b)
                self.assertIsInstance(result, CSEVariable)

                result = CuteDSLOpOverrides.mul(mock_cse_a, mock_cse_b)
                self.assertIsInstance(result, CSEVariable)

                result = CuteDSLOpOverrides.truediv(mock_cse_a, mock_cse_b)
                self.assertIsInstance(result, CSEVariable)

                result = CuteDSLOpOverrides.exp(mock_cse_a)
                self.assertIsInstance(result, CSEVariable)

                result = CuteDSLOpOverrides.sqrt(mock_cse_a)
                self.assertIsInstance(result, CSEVariable)

                with self.assertRaises(NotImplementedError):
                    result = CuteDSLOpOverrides.maximum(mock_cse_a, mock_cse_b)
                    result = CuteDSLOpOverrides.minimum(mock_cse_a, mock_cse_b)

                scalar_result = CuteDSLOpOverrides._ensure_tensor_ssa("5.0", mock_cse_a)
                self.assertEqual(scalar_result, "cute.full_like(tensor_a, 5.0)")

                tensor_result = CuteDSLOpOverrides._ensure_tensor_ssa(mock_cse_a, mock_cse_b)
                self.assertEqual(tensor_result, "tensor_a")

    def test_cse_integration(self):
        """Test CSE (Common Subexpression Elimination) integration."""
        from torch._inductor.codegen.common import CSE

        mock_graph = MockGraphHandler()
        with V.set_graph_handler(mock_graph):
            kernel = CuteDSLTemplateKernel(
                kernel_name="test_cse",
                input_nodes=[],
                output_node=None,
            )
            self.assertIsInstance(kernel.cse, CSE)
            self.assertEqual(kernel.cse.name_prefix, "tmp")

            with V.set_kernel_handler(kernel):
                test_expr = "x"
                var = kernel.cse.generate(kernel.body, test_expr, dtype=None)
                self.assertTrue(str(var).startswith("tmp"))


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    run_tests()