Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00.
## Summary

Still figuring out what writing a template should actually look like, but this lands a lot of the base infrastructure.

<img width="1267" height="262" alt="Screenshot 2025-08-16 at 10 22 12 PM" src="https://github.com/user-attachments/assets/229f8bfa-0cb4-4fb1-8530-f535e569d350" />

Test code:
```Python #!/usr/bin/env python3 """ Fixed CuteDSL template test with proper def_kernel usage. """ import torch import torch._inductor.config as config from torch._inductor.lowering import lowerings from torch._inductor.ir import TensorBox from torch._inductor.select_algorithm import autotune_select_algorithm from torch._inductor.codegen.cutedsl import CuteDSLTemplate def create_fixed_cutedsl_template(): """Create a properly structured CuteDSL template.""" def cutedsl_grid(M, N, meta): return (1,) # Part 1: Imports and kernel definition template_part1 = r""" import torch import cutlass import cutlass.cute as cute from cutlass.cute.runtime import from_dlpack @cute.kernel def {{kernel_name}}_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor): # Get thread and block indices tidx, _, _ = cute.arch.thread_idx() bidx, _, _ = cute.arch.block_idx() bdim, _, _ = cute.arch.block_dim() thread_idx = bidx * bdim + tidx m, n = gA.shape if thread_idx < m * n: mi = thread_idx // n ni = thread_idx % n if mi < m and ni < n: a_val = gA[mi, ni] b_val = gB[mi, ni] result = a_val + b_val gC[mi, ni] = a_val + b_val """ # Part 2: JIT wrapper function template_part2 = r""" @cute.jit def {{kernel_name}}_jit(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor): m, n = mA.shape total_threads = m * n threads_per_block = 256 num_blocks = (total_threads + threads_per_block - 1) // threads_per_block kernel = {{kernel_name}}_kernel(mA, mB, mC) kernel.launch( grid=[num_blocks, 1, 1], block=[threads_per_block, 1, 1] ) """ # Part 3: Main kernel function template_part3 = r""" {{def_kernel("input_a", "input_b", "output_c")}} cute_a = from_dlpack(input_a, assumed_align=16) cute_b = from_dlpack(input_b, 
assumed_align=16) cute_c = from_dlpack(output_c, assumed_align=16) # Launch kernel {{kernel_name}}_jit(cute_a, cute_b, cute_c) return output_c """ # Combine all parts template = CuteDSLTemplate( name="fixed_add", grid=cutedsl_grid, source=template_part1 + template_part2 + template_part3 ) return template def fixed_cutedsl_lowering(a: TensorBox, b: TensorBox) -> TensorBox: """Fixed CuteDSL lowering.""" print(f"[FIXED] CuteDSL lowering: {a.get_size()} + {b.get_size()}") template = create_fixed_cutedsl_template() choices = [] error = template.maybe_append_choice( choices, input_nodes=[a.data, b.data], layout=a.get_layout() ) if error or not choices: print(f"[FIXED] Falling back: {error}") default_lowering = lowerings[torch.ops.aten.add.Tensor] return default_lowering(a, b) print(f"[FIXED] Using CuteDSL with {len(choices)} choices") result = autotune_select_algorithm( "fixed_cutedsl_add", choices, [a, b], a.get_layout(), ) return result def test_fixed_cutedsl(): """Test the fixed CuteDSL template.""" print("=" * 50) print("Fixed CuteDSL Template Test") print("=" * 50) original = lowerings.get(torch.ops.aten.add.Tensor, None) try: lowerings[torch.ops.aten.add.Tensor] = fixed_cutedsl_lowering def test_add(x, y): return x + y device = "cuda" if torch.cuda.is_available() else "cpu" x = torch.randn(128, 4, device=device, dtype=torch.float32) y = torch.randn(128, 4, device=device, dtype=torch.float32) print(f"[FIXED] Testing with {x.shape} tensors on {device}") compiled_fn = torch.compile(test_add, backend="inductor") result = compiled_fn(x, y) # Verify correctness expected = x + y if torch.allclose(result, expected, atol=1e-5): print("✅ [FIXED] Results match!") return True else: print("❌ [FIXED] Results don't match!") return False except Exception as e: print(f"❌ [FIXED] Failed: {e}") import traceback traceback.print_exc() return False finally: if original: lowerings[torch.ops.aten.add.Tensor] = original else: lowerings.pop(torch.ops.aten.add.Tensor, None) if __name__ == 
"__main__": success = test_fixed_cutedsl() print("🎉 Fixed test completed!" if success else "💥 Fixed test failed!")
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160108
Approved by: https://github.com/mlazos