drisspg
3c6efd1380
Add cutedsl template support to compile (#160108)
## Summary
Still figuring out what actually writing a template should look like, but this lands a lot of the base infra.
<img width="1267" height="262" alt="Screenshot 2025-08-16 at 10 22 12 PM" src="https://github.com/user-attachments/assets/229f8bfa-0cb4-4fb1-8530-f535e569d350" />
Test code:
```Python
#!/usr/bin/env python3
"""
Fixed CuteDSL template test with proper def_kernel usage.
"""
import torch
import torch._inductor.config as config
from torch._inductor.lowering import lowerings
from torch._inductor.ir import TensorBox
from torch._inductor.select_algorithm import autotune_select_algorithm
from torch._inductor.codegen.cutedsl import CuteDSLTemplate
def create_fixed_cutedsl_template():
    """Build the CuteDSL template used for an elementwise add of two tensors.

    The template source is the concatenation of three pieces:

    1. imports plus a ``@cute.kernel`` that adds one element per thread,
    2. a ``@cute.jit`` wrapper that sizes the launch and starts the kernel,
    3. the entry point emitted by ``def_kernel``, which wraps the inputs
       with ``from_dlpack`` and calls the JIT wrapper.

    Returns:
        The ``CuteDSLTemplate`` instance named ``"fixed_add"``.
    """

    def cutedsl_grid(M, N, meta):
        # Constant grid handed to the template; the generated JIT wrapper
        # computes its own launch grid from the tensor shape.
        return (1,)

    # Part 1: Imports and kernel definition.
    # The bodies below must be indented: the rendered template is executed
    # as Python source, so a flat body is a syntax error.
    template_part1 = r"""
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
@cute.kernel
def {{kernel_name}}_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor):
    # Get thread and block indices
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdim, _, _ = cute.arch.block_dim()
    thread_idx = bidx * bdim + tidx
    m, n = gA.shape
    if thread_idx < m * n:
        mi = thread_idx // n
        ni = thread_idx % n
        if mi < m and ni < n:
            a_val = gA[mi, ni]
            b_val = gB[mi, ni]
            result = a_val + b_val
            gC[mi, ni] = result
"""

    # Part 2: JIT wrapper function (ceil-divides the element count into
    # blocks of 256 threads and launches the kernel).
    template_part2 = r"""
@cute.jit
def {{kernel_name}}_jit(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor):
    m, n = mA.shape
    total_threads = m * n
    threads_per_block = 256
    num_blocks = (total_threads + threads_per_block - 1) // threads_per_block
    kernel = {{kernel_name}}_kernel(mA, mB, mC)
    kernel.launch(
        grid=[num_blocks, 1, 1],
        block=[threads_per_block, 1, 1]
    )
"""

    # Part 3: Main kernel function. ``def_kernel`` renders the ``def`` line,
    # so everything after it is the (indented) function body.
    template_part3 = r"""
{{def_kernel("input_a", "input_b", "output_c")}}
    cute_a = from_dlpack(input_a, assumed_align=16)
    cute_b = from_dlpack(input_b, assumed_align=16)
    cute_c = from_dlpack(output_c, assumed_align=16)
    # Launch kernel
    {{kernel_name}}_jit(cute_a, cute_b, cute_c)
    return output_c
"""

    # Combine all parts
    template = CuteDSLTemplate(
        name="fixed_add",
        grid=cutedsl_grid,
        source=template_part1 + template_part2 + template_part3,
    )
    return template
# Stock lowering for aten.add.Tensor, captured when this definition is
# evaluated. The fallback path must not look the entry up at call time: once
# fixed_cutedsl_lowering has been installed in ``lowerings`` under this key,
# a call-time lookup returns this very function and recurses without bound.
_DEFAULT_ADD_LOWERING = lowerings.get(torch.ops.aten.add.Tensor, None)


def fixed_cutedsl_lowering(a: TensorBox, b: TensorBox) -> TensorBox:
    """Lower ``a + b`` through the CuteDSL template, with a stock fallback.

    Asks the template to append a choice for the two inputs. If the template
    reports an error, or no choice was appended, the add is handed to the
    lowering that was registered for ``torch.ops.aten.add.Tensor`` before any
    override. Otherwise the collected choices go to
    ``autotune_select_algorithm``.

    Args:
        a: Left operand.
        b: Right operand.

    Returns:
        The node produced by the selected algorithm, or by the stock add
        lowering when the template cannot be used.
    """
    print(f"[FIXED] CuteDSL lowering: {a.get_size()} + {b.get_size()}")

    template = create_fixed_cutedsl_template()
    choices = []
    error = template.maybe_append_choice(
        choices,
        input_nodes=[a.data, b.data],
        layout=a.get_layout(),
    )

    if error or not choices:
        print(f"[FIXED] Falling back: {error}")
        default_lowering = _DEFAULT_ADD_LOWERING
        if default_lowering is None:
            # Nothing was registered at definition time; keep the original
            # call-time lookup (and its KeyError) for that case.
            default_lowering = lowerings[torch.ops.aten.add.Tensor]
        return default_lowering(a, b)

    print(f"[FIXED] Using CuteDSL with {len(choices)} choices")
    result = autotune_select_algorithm(
        "fixed_cutedsl_add",
        choices,
        [a, b],
        a.get_layout(),
    )
    return result
def test_fixed_cutedsl():
    """Run the CuteDSL add lowering end to end through ``torch.compile``.

    Temporarily installs ``fixed_cutedsl_lowering`` as the Inductor lowering
    for ``torch.ops.aten.add.Tensor``, compiles a small add function, and
    compares its output with eager ``x + y``. The previous ``lowerings``
    entry is restored (or the temporary one removed) before returning.

    Returns:
        True if the compiled result matches eager within ``atol=1e-5``;
        False on a mismatch or if any exception is raised, in which case the
        traceback is printed.
    """
    print("=" * 50)
    print("Fixed CuteDSL Template Test")
    print("=" * 50)

    add_op = torch.ops.aten.add.Tensor
    original = lowerings.get(add_op, None)

    try:
        lowerings[add_op] = fixed_cutedsl_lowering

        def test_add(x, y):
            return x + y

        device = "cuda" if torch.cuda.is_available() else "cpu"
        x = torch.randn(128, 4, device=device, dtype=torch.float32)
        y = torch.randn(128, 4, device=device, dtype=torch.float32)
        print(f"[FIXED] Testing with {x.shape} tensors on {device}")

        compiled_fn = torch.compile(test_add, backend="inductor")
        result = compiled_fn(x, y)

        # Verify correctness
        expected = x + y
        if torch.allclose(result, expected, atol=1e-5):
            print("✅ [FIXED] Results match!")
            return True
        print("❌ [FIXED] Results don't match!")
        return False
    except Exception as e:
        # Top-level boundary of a standalone script: report the failure and
        # signal it through the return value instead of propagating.
        print(f"❌ [FIXED] Failed: {e}")
        import traceback

        traceback.print_exc()
        return False
    finally:
        # ``None`` is the "no previous entry" sentinel from ``.get`` above,
        # so test identity rather than truthiness.
        if original is not None:
            lowerings[add_op] = original
        else:
            lowerings.pop(add_op, None)
if __name__ == "__main__":
    # Script entry point: run the check once and print a one-line verdict.
    passed = test_fixed_cutedsl()
    if passed:
        print("🎉 Fixed test completed!")
    else:
        print("💥 Fixed test failed!")
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160108
Approved by: https://github.com/mlazos
2025-08-18 04:37:15 +00:00
..
2025-07-17 12:08:33 +00:00
2024-12-13 22:13:12 +00:00
2025-07-17 12:08:33 +00:00
2024-12-13 22:13:12 +00:00
2025-06-21 18:33:38 +00:00
2025-08-18 04:25:05 +00:00
2025-06-24 04:53:54 +00:00
2025-08-13 23:42:24 +00:00
2025-03-29 01:39:13 +00:00
2025-07-25 02:37:30 +00:00
2025-08-15 18:29:50 +00:00
2025-07-25 02:56:34 +00:00
2025-08-18 01:41:08 +00:00
2025-08-16 20:44:40 +00:00
2025-08-14 20:55:59 +00:00
2025-08-08 22:22:48 +00:00
2025-08-15 07:26:28 +00:00
2025-07-22 22:25:44 +00:00
2025-08-16 09:15:58 +00:00
2025-08-15 04:59:35 +00:00
2025-08-11 22:48:10 +00:00
2025-08-18 04:37:15 +00:00
2025-04-02 20:56:43 +00:00
2025-01-27 18:12:39 +00:00
2025-08-12 21:59:04 +00:00
2025-03-29 01:39:13 +00:00
2025-07-29 03:26:09 +00:00
2025-07-29 03:26:09 +00:00
2025-08-12 18:07:41 +00:00
2025-08-16 04:48:58 +00:00
2025-06-14 11:27:04 +00:00
2025-08-06 02:26:10 +00:00
2025-08-07 01:17:55 +00:00
2025-08-07 13:09:33 +00:00
2025-07-29 03:26:09 +00:00
2025-02-22 03:44:53 +00:00
2025-07-29 03:26:09 +00:00
2025-07-05 17:48:27 +00:00
2025-08-02 05:16:01 +00:00
2025-07-26 01:22:17 +00:00
2025-08-03 20:53:58 +00:00
2025-06-17 17:51:40 +00:00
2024-11-04 18:30:29 +00:00
2025-08-07 15:23:06 +00:00
2025-05-12 18:30:52 +00:00
2025-08-11 12:00:13 +00:00
2025-08-08 17:41:22 +00:00
2025-06-10 18:33:09 +00:00
2025-04-26 18:10:58 +00:00
2025-02-07 06:06:18 +00:00
2025-07-30 19:30:55 +00:00
2025-08-12 20:14:18 +00:00
2025-07-09 11:02:23 +00:00
2025-06-04 14:38:13 +00:00
2025-01-04 10:47:51 +00:00
2024-11-22 20:54:55 +00:00
2025-02-26 23:57:59 +00:00
2025-08-04 20:37:39 +00:00
2025-07-09 11:24:27 +00:00
2025-07-19 06:51:57 +00:00
2025-04-10 21:02:14 +00:00
2025-04-25 20:15:04 +00:00
2025-01-04 14:17:20 +00:00
2025-07-09 11:02:23 +00:00
2025-07-09 11:02:23 +00:00
2025-07-09 11:02:23 +00:00
2025-01-04 10:47:51 +00:00
2024-12-18 23:02:30 +00:00
2025-08-08 17:41:22 +00:00
2025-08-13 21:00:59 +00:00
2025-08-05 03:44:01 +00:00
2025-07-09 11:02:23 +00:00
2025-07-25 20:21:36 +00:00
2025-06-14 03:37:38 +00:00
2025-08-05 18:57:35 +00:00
2025-08-13 12:28:29 +00:00
2025-08-04 20:37:39 +00:00
2025-01-23 00:31:39 +00:00
2025-08-07 22:37:15 +00:00
2024-12-18 23:02:30 +00:00
2025-05-25 17:36:14 +00:00
2025-08-10 07:05:52 +00:00
2024-12-18 23:02:30 +00:00
2025-02-08 00:55:20 +00:00
2025-01-04 10:47:51 +00:00
2025-07-09 11:02:23 +00:00
2025-07-09 11:02:23 +00:00
2025-08-13 05:50:15 +00:00
2025-07-09 11:02:23 +00:00
2025-07-10 06:34:46 +00:00
2025-08-07 23:43:53 +00:00
2025-02-28 00:47:03 +00:00
2025-06-04 14:38:13 +00:00
2025-08-05 18:57:35 +00:00
2025-08-04 20:37:39 +00:00
2025-08-04 20:37:39 +00:00
2025-08-09 02:21:22 +00:00
2025-08-04 20:37:39 +00:00
2025-08-04 20:37:39 +00:00
2024-12-18 23:02:30 +00:00
2025-01-22 04:48:28 +00:00
2025-08-04 20:37:39 +00:00
2024-12-18 23:02:30 +00:00
2025-07-09 11:02:23 +00:00
2025-07-09 11:02:23 +00:00
2025-08-07 13:09:33 +00:00
2025-07-09 11:02:23 +00:00
2025-07-21 21:44:49 +00:00
2025-07-31 17:58:02 +00:00
2025-07-09 11:02:23 +00:00
2024-12-18 23:02:30 +00:00
2025-01-04 10:47:51 +00:00
2025-07-17 08:57:34 +00:00
2024-12-18 23:02:30 +00:00
2025-04-03 23:50:13 +00:00
2024-12-18 23:02:30 +00:00
2025-01-25 00:58:03 +00:00
2024-12-18 23:02:30 +00:00
2025-08-15 16:19:25 +00:00
2025-08-15 00:11:55 +00:00
2025-06-12 14:42:32 +00:00
2024-12-18 23:02:30 +00:00
2025-01-04 10:47:51 +00:00
2024-12-06 21:45:18 +00:00
2025-08-08 22:22:48 +00:00
2025-08-15 16:19:25 +00:00
2025-01-04 10:47:51 +00:00
2025-08-15 16:52:43 +00:00
2025-04-27 09:56:42 +00:00
2025-07-25 23:49:46 +00:00
2024-12-18 23:02:30 +00:00
2025-07-25 02:39:41 +00:00
2025-03-18 16:09:39 +00:00
2025-07-09 11:02:23 +00:00
2025-08-07 02:38:45 +00:00
2025-05-07 22:46:05 +00:00
2024-12-27 07:58:44 +00:00
2025-07-09 11:02:23 +00:00
2025-07-09 11:02:23 +00:00
2024-12-18 23:02:30 +00:00
2025-07-02 23:12:29 +00:00
2025-07-09 11:02:23 +00:00
2025-08-16 09:15:58 +00:00
2025-06-08 17:30:31 +00:00
2025-07-11 03:21:47 +00:00
2025-07-09 11:02:23 +00:00
2025-07-17 01:27:44 +00:00
2025-07-09 11:02:23 +00:00
2025-08-10 18:35:42 +00:00
2025-07-09 11:02:23 +00:00
2025-08-12 20:52:25 +00:00
2025-02-25 03:47:40 +00:00
2025-08-14 17:06:27 +00:00
2025-07-29 17:40:49 +00:00
2025-07-25 02:56:34 +00:00
2025-08-14 15:09:16 +00:00
2025-06-04 01:58:52 +00:00
2025-07-09 11:02:23 +00:00
2025-01-04 10:47:51 +00:00
2024-12-18 23:02:30 +00:00
2025-07-09 11:02:23 +00:00
2025-07-09 11:02:23 +00:00
2025-05-30 19:18:43 +00:00
2025-07-09 11:02:23 +00:00
2025-07-13 09:30:57 +00:00
2025-01-26 03:37:20 +00:00
2025-08-10 18:35:42 +00:00
2025-07-15 08:10:05 +00:00
2025-08-14 08:55:31 +00:00
2025-01-04 14:17:20 +00:00
2025-07-20 23:49:18 +00:00
2025-07-09 11:02:23 +00:00
2025-08-10 18:35:42 +00:00
2025-02-05 19:40:10 +00:00
2024-12-12 01:18:34 +00:00
2025-06-25 18:09:04 +00:00
2025-08-16 00:54:32 +00:00
2024-12-18 23:02:30 +00:00
2025-07-09 11:02:23 +00:00
2025-02-04 19:07:04 +00:00
2025-08-14 02:22:39 +00:00