## Summary

Still figuring out what actually writing a template should look like, but this lands a lot of the base infra.

<img width="1267" height="262" alt="Screenshot 2025-08-16 at 10 22 12 PM" src="https://github.com/user-attachments/assets/229f8bfa-0cb4-4fb1-8530-f535e569d350" />

Test code:

```Python
#!/usr/bin/env python3
"""Fixed CuteDSL template test with proper def_kernel usage."""

import torch
import torch._inductor.config as config  # noqa: F401
from torch._inductor.lowering import lowerings
from torch._inductor.ir import TensorBox
from torch._inductor.select_algorithm import autotune_select_algorithm
from torch._inductor.codegen.cutedsl import CuteDSLTemplate


def create_fixed_cutedsl_template():
    """Create a properly structured CuteDSL template."""

    def cutedsl_grid(M, N, meta):
        return (1,)

    # Part 1: imports and kernel definition
    template_part1 = r"""
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.kernel
def {{kernel_name}}_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor):
    # Get thread and block indices
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdim, _, _ = cute.arch.block_dim()

    thread_idx = bidx * bdim + tidx
    m, n = gA.shape

    if thread_idx < m * n:
        mi = thread_idx // n
        ni = thread_idx % n

        if mi < m and ni < n:
            a_val = gA[mi, ni]
            b_val = gB[mi, ni]
            gC[mi, ni] = a_val + b_val
"""

    # Part 2: JIT wrapper function
    template_part2 = r"""
@cute.jit
def {{kernel_name}}_jit(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor):
    m, n = mA.shape
    total_threads = m * n
    threads_per_block = 256
    num_blocks = (total_threads + threads_per_block - 1) // threads_per_block

    kernel = {{kernel_name}}_kernel(mA, mB, mC)
    kernel.launch(
        grid=[num_blocks, 1, 1],
        block=[threads_per_block, 1, 1],
    )
"""

    # Part 3: main kernel function
    template_part3 = r"""
{{def_kernel("input_a", "input_b", "output_c")}}
    cute_a = from_dlpack(input_a, assumed_align=16)
    cute_b = from_dlpack(input_b, assumed_align=16)
    cute_c = from_dlpack(output_c, assumed_align=16)

    # Launch kernel
    {{kernel_name}}_jit(cute_a, cute_b, cute_c)
    return output_c
"""

    # Combine all parts
    template = CuteDSLTemplate(
        name="fixed_add",
        grid=cutedsl_grid,
        source=template_part1 + template_part2 + template_part3,
    )
    return template


def fixed_cutedsl_lowering(a: TensorBox, b: TensorBox) -> TensorBox:
    """Fixed CuteDSL lowering."""
    print(f"[FIXED] CuteDSL lowering: {a.get_size()} + {b.get_size()}")

    template = create_fixed_cutedsl_template()
    choices = []
    error = template.maybe_append_choice(
        choices,
        input_nodes=[a.data, b.data],
        layout=a.get_layout(),
    )

    if error or not choices:
        print(f"[FIXED] Falling back: {error}")
        default_lowering = lowerings[torch.ops.aten.add.Tensor]
        return default_lowering(a, b)

    print(f"[FIXED] Using CuteDSL with {len(choices)} choices")
    result = autotune_select_algorithm(
        "fixed_cutedsl_add",
        choices,
        [a, b],
        a.get_layout(),
    )
    return result


def test_fixed_cutedsl():
    """Test the fixed CuteDSL template."""
    print("=" * 50)
    print("Fixed CuteDSL Template Test")
    print("=" * 50)

    original = lowerings.get(torch.ops.aten.add.Tensor, None)
    try:
        lowerings[torch.ops.aten.add.Tensor] = fixed_cutedsl_lowering

        def test_add(x, y):
            return x + y

        device = "cuda" if torch.cuda.is_available() else "cpu"
        x = torch.randn(128, 4, device=device, dtype=torch.float32)
        y = torch.randn(128, 4, device=device, dtype=torch.float32)

        print(f"[FIXED] Testing with {x.shape} tensors on {device}")

        compiled_fn = torch.compile(test_add, backend="inductor")
        result = compiled_fn(x, y)

        # Verify correctness
        expected = x + y
        if torch.allclose(result, expected, atol=1e-5):
            print("✅ [FIXED] Results match!")
            return True
        else:
            print("❌ [FIXED] Results don't match!")
            return False
    except Exception as e:
        print(f"❌ [FIXED] Failed: {e}")
        import traceback

        traceback.print_exc()
        return False
    finally:
        if original:
            lowerings[torch.ops.aten.add.Tensor] = original
        else:
            lowerings.pop(torch.ops.aten.add.Tensor, None)


if __name__ == "__main__":
    success = test_fixed_cutedsl()
    print("🎉 Fixed test completed!" if success else "💥 Fixed test failed!")
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160108
Approved by: https://github.com/mlazos
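
For quick reference, the lowering flow the script above exercises boils down to a few calls. This is a minimal distillation reusing only names defined in the script (no API beyond what the test already uses); `a` and `b` are the inductor `TensorBox` inputs passed to the lowering:

```Python
# Minimal sketch of the template lowering flow from the script above.
template = create_fixed_cutedsl_template()

choices = []
error = template.maybe_append_choice(
    choices,
    input_nodes=[a.data, b.data],
    layout=a.get_layout(),
)

if error or not choices:
    # Any template failure falls back to the stock aten.add lowering.
    result = lowerings[torch.ops.aten.add.Tensor](a, b)
else:
    # Otherwise the autotuner benchmarks the rendered choices and picks one.
    result = autotune_select_algorithm(
        "fixed_cutedsl_add", choices, [a, b], a.get_layout()
    )
```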
# Owner(s): ["module: inductor"]
import unittest
from unittest.mock import MagicMock, patch

import torch

from torch._inductor.test_case import TestCase


try:
    import cutlass  # noqa: F401
    import cutlass.cute as cute  # noqa: F401

    HAS_CUTLASS = True
except ImportError:
    HAS_CUTLASS = False

if HAS_CUTLASS:
    from torch._inductor.codegen.cutedsl.cutedsl_kernel import CuteDSLTemplateKernel
    from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
    from torch._inductor.select_algorithm import PartialRender
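

# The template below is rendered with the hooks that CuteDSLTemplateKernel
# exposes to the template environment: {{kernel_name}} is replaced with the
# generated kernel's name, {{gen_defines()}} expands the keyword arguments
# passed to maybe_append_choice into `NAME: cutlass.Constexpr = value` lines
# (see test_gen_defines), and {{def_kernel(...)}} emits the entry-point
# signature for the named inputs/outputs.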
CUTEDSL_ADD_TEMPLATE = r"""
{{gen_defines()}}

@cute.kernel
def {{kernel_name}}_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdim, _, _ = cute.arch.block_dim()

    thread_idx = bidx * bdim + tidx
    m, n = gA.shape

    if thread_idx < m * n:
        mi = thread_idx // n
        ni = thread_idx % n

        if mi < m and ni < n:
            gC[mi, ni] = gA[mi, ni] + gB[mi, ni]


@cute.jit
def {{kernel_name}}_jit(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor, stream):
    {{gen_defines()}}
    m, n = mA.shape
    total_threads = m * n
    num_blocks = (total_threads + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK

    kernel = {{kernel_name}}_kernel(mA, mB, mC)
    kernel.launch(
        grid=[num_blocks, 1, 1],
        block=[THREADS_PER_BLOCK, 1, 1],
        stream=stream,
    )


{{def_kernel("input_a", "input_b", "output_c")}}
    cute_a = from_dlpack(input_a)
    cute_b = from_dlpack(input_b)
    cute_c = from_dlpack(output_c)

    {{kernel_name}}_jit(cute_a, cute_b, cute_c, cuda.CUstream(stream))
    return output_c
"""


@unittest.skipUnless(HAS_CUTLASS, "requires cutlass")
class TestCuteDSLTemplate(TestCase):
    """Test cases for CuteDSL template functionality."""

    def test_gen_imports(self):
        """gen_imports should emit exactly the expected preamble lines."""
        kernel = CuteDSLTemplateKernel(
            kernel_name="test_kernel",
            input_nodes=[],
            output_node=None,
        )

        imports = kernel.gen_imports()

        self.assertIn("import torch", imports)
        self.assertIn("import cutlass", imports)
        self.assertIn("import cutlass.cute as cute", imports)
        self.assertIn("from cutlass.cute.runtime import from_dlpack", imports)
        self.assertIsInstance(imports, str)

        lines = imports.strip().split("\n")
        self.assertEqual(len(lines), 5)

    def test_render_includes_imports(self):
        """Rendering a template should prepend the generated import preamble."""
        template_source = """@cute.kernel
def {{kernel_name}}_kernel():
    pass

{{def_kernel("input", "output")}}
    return output"""

        mock_template = MagicMock()
        mock_template.render = MagicMock(return_value=template_source)

        kernel = CuteDSLTemplateKernel(
            kernel_name="test_kernel",
            input_nodes=[],
            output_node=None,
        )

        result = kernel.render(mock_template)
        self.assertIsInstance(result, PartialRender)

        rendered_code = result._code

        # The imports might have leading whitespace, so strip it
        rendered_code_stripped = rendered_code.lstrip()

        self.assertTrue(
            rendered_code_stripped.startswith("import torch"),
            f"Code should start with 'import torch', got: {rendered_code_stripped[:50]}",
        )
        self.assertIn("import cutlass", rendered_code)
        self.assertIn("import cutlass.cute as cute", rendered_code)
        self.assertIn("from cutlass.cute.runtime import from_dlpack", rendered_code)
        self.assertIn("@cute.kernel", rendered_code)

    def test_template_env_contains_hooks(self):
        """The render environment should expose the def_kernel/kernel_name hooks."""
        kernel = CuteDSLTemplateKernel(
            kernel_name="test_kernel",
            input_nodes=[],
            output_node=None,
        )

        captured_env = {}

        def mock_render(**kwargs):
            captured_env.update(kwargs)
            return "rendered"

        mock_template = MagicMock()
        mock_template.render = mock_render

        kernel.render(mock_template)

        self.assertIn("def_kernel", captured_env)
        self.assertIn("kernel_name", captured_env)
        self.assertTrue(callable(captured_env["def_kernel"]))

    def test_multiple_templates_unique_names(self):
        """Registering two templates under the same name should fail."""
        # Clean registry first
        test_name = f"unique_test_{id(self)}"
        if test_name in CuteDSLTemplate.all_templates:
            del CuteDSLTemplate.all_templates[test_name]

        _ = CuteDSLTemplate(
            name=test_name,
            source="template1",
        )

        with self.assertRaises(AssertionError):
            _ = CuteDSLTemplate(
                name=test_name,
                source="template2",
            )

    def test_indented_buffer_usage(self):
        """The import preamble should be emitted without leading indentation."""
        kernel = CuteDSLTemplateKernel(
            kernel_name="test_kernel",
            input_nodes=[],
            output_node=None,
        )

        imports = kernel.gen_imports()

        lines = imports.strip().split("\n")
        for line in lines:
            if line:
                self.assertFalse(
                    line.startswith(" "), f"Line should not be indented: '{line}'"
                )
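
    # The two e2e tests below temporarily swap the aten.add lowering via
    # patch.dict, so torch.compile(..., backend="inductor") is routed through
    # the CuteDSL template path for the duration of the test.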
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_cutedsl_add_e2e(self):
        """End-to-end test with CuteDSL template including code generation verification."""
        from torch._inductor.ir import TensorBox
        from torch._inductor.lowering import lowerings
        from torch._inductor.utils import run_and_get_code

        template = CuteDSLTemplate(
            name="test_add_e2e",
            source=CUTEDSL_ADD_TEMPLATE,
        )

        def cutedsl_add_lowering(a: TensorBox, b: TensorBox) -> TensorBox:
            choices = []
            error = template.maybe_append_choice(
                choices,
                input_nodes=[a, b],
                layout=a.get_layout(),
                THREADS_PER_BLOCK=256,
            )

            if error or not choices:
                default_lowering = lowerings[torch.ops.aten.add.Tensor]
                return default_lowering(a, b)

            # Use the single choice directly (no autotuning)
            return choices[0].output_node()

        with patch.dict(lowerings, {torch.ops.aten.add.Tensor: cutedsl_add_lowering}):
            # Test function
            def test_add(x, y):
                return x + y

            device = "cuda"
            x = torch.randn(128, 4, device=device, dtype=torch.float32)
            y = torch.randn(128, 4, device=device, dtype=torch.float32)

            # Compile and get generated code
            compiled_fn = torch.compile(test_add, backend="inductor")
            result, (code,) = run_and_get_code(compiled_fn, x, y)

            # Verify CuteDSL code is present
            self.assertIn(
                "cute", code.lower(), "CuteDSL code should be in generated code"
            )
            # Verify parameter generation worked
            self.assertIn(
                "THREADS_PER_BLOCK", code, "Parameter should be in generated code"
            )

            # Verify correctness
            expected = x + y
            self.assertTrue(torch.allclose(result, expected, atol=1e-5))

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_cutedsl_add_e2e_autotune(self):
        """E2E test with multiple CuteDSL template variants for autotuning."""
        from torch._inductor.ir import TensorBox
        from torch._inductor.lowering import lowerings
        from torch._inductor.select_algorithm import autotune_select_algorithm

        template = CuteDSLTemplate(
            name="test_add_autotune",
            source=CUTEDSL_ADD_TEMPLATE,
        )

        def cutedsl_add_lowering(a: TensorBox, b: TensorBox) -> TensorBox:
            choices = []

            # Add multiple variants with different thread counts for autotuning
            thread_variants = [128, 256, 512]
            for threads in thread_variants:
                error = template.maybe_append_choice(
                    choices,
                    input_nodes=[a, b],
                    layout=a.get_layout(),
                    THREADS_PER_BLOCK=threads,
                )
                if error:
                    # Skip this variant if it fails
                    continue

            if not choices:
                default_lowering = lowerings[torch.ops.aten.add.Tensor]
                return default_lowering(a, b)

            # Use autotuning to select the best variant
            return autotune_select_algorithm(
                "cutedsl_add_autotune",
                choices,
                [a, b],
                a.get_layout(),
            )

        with patch.dict(lowerings, {torch.ops.aten.add.Tensor: cutedsl_add_lowering}):
            # Test function
            def test_add(x, y):
                return x + y

            device = "cuda"
            x = torch.randn(128, 128, device=device, dtype=torch.float32)
            y = torch.randn(128, 128, device=device, dtype=torch.float32)

            # Compile and run
            compiled_fn = torch.compile(test_add, backend="inductor")
            result = compiled_fn(x, y)

            # Verify correctness
            expected = x + y
            self.assertTrue(torch.allclose(result, expected, atol=1e-5))

    def test_gen_defines(self):
        """Test that gen_defines correctly generates CuteDSL parameter definitions."""
        kernel = CuteDSLTemplateKernel(
            kernel_name="test_kernel",
            input_nodes=[],
            output_node=None,
        )

        # Test integer parameters
        params = kernel.gen_defines(
            THREADS_PER_BLOCK=256,
            BLOCK_SIZE=128,
            ENABLE_FEATURE=True,
        )

        expected_lines = [
            "THREADS_PER_BLOCK: cutlass.Constexpr = 256",
            "BLOCK_SIZE: cutlass.Constexpr = 128",
            "ENABLE_FEATURE: cutlass.Constexpr = True",
        ]

        for expected_line in expected_lines:
            self.assertIn(expected_line, params)

        # Test float parameters
        params_float = kernel.gen_defines(SCALE_FACTOR=1.5)
        self.assertIn("SCALE_FACTOR: cutlass.Constexpr = 1.5", params_float)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    run_tests()