Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Updates to CuTe DSL template renderer (#161117)
# Summary

This adds a few more render functions available to template writers, specifically `get_output` and `modification`. The reasons why become clearer in the next PR in this stack.

<img width="1645" height="364" alt="Screenshot 2025-08-21 at 1 48 50 PM" src="https://github.com/user-attachments/assets/2d508fda-4273-43ef-9edf-086e592e9249" />

The majority of the new code is in the OpOverrides for the CuTe DSL. There is a lot to test, and most of the actual testing I have been doing is via score_mods fed to flash_attention at the next layer of this stack. Below is a set of score mods that Claude and I came up with that exercise the actual ops.

```py
import math

import torch


def causal_mask(score, b, h, q_idx, kv_idx):
    """Causal attention mask."""
    return torch.where(q_idx >= kv_idx, score, float("-inf"))


def relative_bias(score, b, h, token_q, token_kv):
    """Relative position bias."""
    return score + torch.abs(token_q - token_kv)


def relative_bias_v2(score, b, h, token_q, token_kv):
    """Relative position bias with factor of 2."""
    return score + 2 * torch.abs(token_q - token_kv)


def times_two(score, b, h, q_idx, kv_idx):
    """Simple score modification that doubles the score."""
    return score * 2


def alibi_bias(score, b, h, q_idx, kv_idx):
    """ALiBi (Attention with Linear Biases) - used in some modern models."""
    # Different slopes for different heads
    slope = 2 ** (-8 * (h + 1) / 8)  # Simplified version
    return score - slope * torch.abs(q_idx - kv_idx)


def sliding_window(score, b, h, q_idx, kv_idx, window_size=256):
    """Sliding window attention - only attend to nearby tokens."""
    return torch.where(
        torch.abs(q_idx - kv_idx) <= window_size, score, float("-inf")
    )


def block_diagonal(score, b, h, q_idx, kv_idx, block_size=64):
    """Block diagonal attention pattern."""
    q_block = q_idx // block_size
    kv_block = kv_idx // block_size
    return torch.where(q_block == kv_block, score, float("-inf"))


def additive_bias(score, b, h, q_idx, kv_idx):
    """Test simple addition with position-based bias."""
    return score + (q_idx + kv_idx) * 0.01


def multiplicative_decay(score, b, h, q_idx, kv_idx):
    """Test multiplication with distance-based decay."""
    distance = torch.abs(q_idx - kv_idx)
    return score * torch.exp(-0.1 * distance)


def sine_wave_bias(score, b, h, q_idx, kv_idx):
    """Test trigonometric functions."""
    return score + 0.1 * torch.sin(2 * math.pi * (q_idx - kv_idx) / 64)


def log_distance_penalty(score, b, h, q_idx, kv_idx):
    """Test logarithmic operations."""
    distance = torch.abs(q_idx - kv_idx).float()
    return score - torch.log(1 + distance)


def alternating_mask(score, b, h, q_idx, kv_idx):
    """Test with alternating pattern - good for branch prediction."""
    return torch.where((q_idx + kv_idx) % 2 == 0, score, float("-inf"))


def head_specific_pattern(score, b, h, q_idx, kv_idx):
    """Different behavior per attention head."""
    even_head = h % 2 == 0
    causal = q_idx >= kv_idx
    return torch.where(even_head & causal, score, float("-inf"))


def sparse_strided(score, b, h, q_idx, kv_idx, stride=4):
    """Sparse attention with strided pattern."""
    return torch.where(
        (kv_idx % stride == 0) | (q_idx == kv_idx), score, float("-inf")
    )


def causal_with_global(score, b, h, q_idx, kv_idx):
    """Causal mask but first few tokens are globally attended."""
    is_causal = q_idx >= kv_idx
    is_global = kv_idx < 4
    return torch.where(is_causal | is_global, score, float("-inf"))


def dilated_attention(score, b, h, q_idx, kv_idx, dilation_rate=2):
    """Dilated attention pattern - exponentially increasing gaps."""
    distance = torch.abs(q_idx - kv_idx)
    is_attended = (distance == 0) | ((distance > 0) & ((distance & (distance - 1)) == 0))
    return torch.where(is_attended, score, float("-inf"))
```

Example outputs:

```
[Test Suite]
Config: batch=4, heads=32, seq_q=8192, seq_kv=8192, dim=128

[Test 1: none]
[No score_mod, flash='enabled'] Found flash_attncute: True
[No score_mod, flash='disabled'] Found flash_attncute: False
✓ Outputs match between flash enabled/disabled
✓ Output matches eager SDPA (rtol=0.001, atol=0.001)

[Test 2: causal]
[With score_mod, flash='enabled'] Found flash_attncute: True
[With score_mod, flash='disabled'] Found flash_attncute: False
✗ Outputs differ between flash modes: Tensor-likes are not close!
Mismatched elements: 17879 / 134217728 (0.0%)
Greatest absolute difference: 0.0078125 at index (0, 15, 15, 60) (up to 0.001 allowed)
Greatest relative difference: 2.5 at index (3, 22, 153, 126) (up to 0.001 allowed)

[Test 3: rel_bias]
[With score_mod, flash='enabled'] Found flash_attncute: True
[With score_mod, flash='disabled'] Found flash_attncute: False
✗ Outputs differ between flash modes: Tensor-likes are not close!
Mismatched elements: 12836 / 134217728 (0.0%)
Greatest absolute difference: 0.015625 at index (0, 3, 2775, 84) (up to 0.001 allowed)
Greatest relative difference: 11.8125 at index (3, 28, 4095, 76) (up to 0.001 allowed)

[Test 4: rel_bias_v2]
```

This is bfloat16, and there are no major differences. The list of pointwise ops here isn't exhaustive, but it covers the common cases fairly well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161117
Approved by: https://github.com/mlazos
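As a usage sketch (not part of this commit's diff): the score mods above are ordinary flex-attention `score_mod` callables, so one way to exercise the new op overrides end to end looks roughly like the code below. The `flex_attention` API and `torch.compile` are standard PyTorch; the shapes, device, and the expectation that Inductor would pick a CuTe DSL flash template for this call come from the next PR in this stack and are assumptions, not something this commit guarantees.

```py
import torch
from torch.nn.attention.flex_attention import flex_attention


def causal_mask(score, b, h, q_idx, kv_idx):
    # Mask out future positions by pushing their scores to -inf.
    return torch.where(q_idx >= kv_idx, score, float("-inf"))


# Illustrative shapes only: (batch, heads, seq_len, head_dim), bf16 on CUDA.
q, k, v = (
    torch.randn(2, 4, 1024, 64, device="cuda", dtype=torch.bfloat16)
    for _ in range(3)
)

# Compiling routes the call through Inductor, which is where template
# selection (including, per the follow-up PR, a CuTe DSL flash template)
# and the op overrides tested here come into play.
flex_compiled = torch.compile(flex_attention)
out = flex_compiled(q, k, v, score_mod=causal_mask)
print(out.shape)  # torch.Size([2, 4, 1024, 64])
```

The ✓/✗ lines in the example outputs above are essentially this kind of compiled call checked against eager SDPA under bf16 tolerances.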
Committed by: PyTorch MergeBot
Parent: 12c0cf3fab
Commit: 30edac5da6
@ -2,8 +2,12 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from expecttest import assert_expected_inline
|
||||
|
||||
import torch
|
||||
from torch._inductor.test_case import TestCase
|
||||
from torch._inductor.virtualized import V
|
||||
from torch.testing._internal.inductor_utils import MockGraphHandler
|
||||
|
||||
|
||||
try:
|
||||
@ -19,6 +23,7 @@ if HAS_CUTLASS:
|
||||
from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
|
||||
from torch._inductor.select_algorithm import PartialRender
|
||||
|
||||
|
||||
CUTEDSL_ADD_TEMPLATE = r"""
|
||||
{{gen_defines()}}
|
||||
|
||||
@ -52,13 +57,13 @@ def {{kernel_name}}_jit(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor, strea
|
||||
stream=stream
|
||||
)
|
||||
|
||||
{{def_kernel("input_a", "input_b", "output_c")}}
|
||||
{{def_kernel("input_a", "input_b")}}
|
||||
cute_a = from_dlpack(input_a)
|
||||
cute_b = from_dlpack(input_b)
|
||||
cute_c = from_dlpack(output_c)
|
||||
cute_c = from_dlpack({{get_output()}})
|
||||
|
||||
{{kernel_name}}_jit(cute_a, cute_b, cute_c, cuda.CUstream(stream))
|
||||
return output_c
|
||||
return {{get_output()}}
|
||||
"""
|
||||
|
||||
|
||||
@ -82,7 +87,7 @@ class TestCuteDSLTemplate(TestCase):
|
||||
self.assertIsInstance(imports, str)
|
||||
|
||||
lines = imports.strip().split("\n")
|
||||
self.assertEqual(len(lines), 5)
|
||||
self.assertEqual(len(lines), 7)
|
||||
|
||||
def test_render_includes_imports(self):
|
||||
template_source = """@cute.kernel
|
||||
@ -299,18 +304,178 @@ def {{kernel_name}}_kernel():
|
||||
ENABLE_FEATURE=True,
|
||||
)
|
||||
|
||||
expected_lines = [
|
||||
"THREADS_PER_BLOCK: cutlass.Constexpr = 256",
|
||||
"BLOCK_SIZE: cutlass.Constexpr = 128",
|
||||
"ENABLE_FEATURE: cutlass.Constexpr = True",
|
||||
]
|
||||
assert_expected_inline(
|
||||
params,
|
||||
"""\
|
||||
THREADS_PER_BLOCK: cutlass.Constexpr = 256
|
||||
BLOCK_SIZE: cutlass.Constexpr = 128
|
||||
ENABLE_FEATURE: cutlass.Constexpr = True
|
||||
""",
|
||||
)
|
||||
|
||||
for expected_line in expected_lines:
|
||||
self.assertIn(expected_line, params)
|
||||
|
||||
# Test float parameters
|
||||
params_float = kernel.gen_defines(SCALE_FACTOR=1.5)
|
||||
self.assertIn("SCALE_FACTOR: cutlass.Constexpr = 1.5", params_float)
|
||||
assert_expected_inline(
|
||||
params_float,
|
||||
"""\
|
||||
SCALE_FACTOR: cutlass.Constexpr = 1.5
|
||||
""",
|
||||
)
|
||||
|
||||
def test_template_aliasing(self):
|
||||
"""Test that template variables are correctly aliased to function arguments."""
|
||||
from torch._inductor.ir import Buffer
|
||||
|
||||
mock_input1 = MagicMock(spec=Buffer)
|
||||
mock_input1.get_name.return_value = "buf_input1"
|
||||
|
||||
mock_input2 = MagicMock(spec=Buffer)
|
||||
mock_input2.get_name.return_value = "buf_input2"
|
||||
|
||||
mock_output = MagicMock(spec=Buffer)
|
||||
mock_output.get_name.return_value = "buf_output"
|
||||
|
||||
mock_graph = MockGraphHandler()
|
||||
with V.set_graph_handler(mock_graph):
|
||||
kernel = CuteDSLTemplateKernel(
|
||||
kernel_name="test_aliasing",
|
||||
input_nodes=[mock_input1, mock_input2],
|
||||
output_node=mock_output,
|
||||
)
|
||||
|
||||
def_kernel_hook = kernel.def_kernel("custom_a", "custom_b")
|
||||
self.assertEqual(def_kernel_hook, "<DEF_KERNEL>")
|
||||
|
||||
self.assertIn("<DEF_KERNEL>", kernel.render_hooks)
|
||||
|
||||
hook_fn = kernel.render_hooks["<DEF_KERNEL>"]
|
||||
generated_code = hook_fn()
|
||||
|
||||
# Check that the generated code contains the expected aliasing statements
|
||||
self.assertIn("custom_a = arg_custom_a", generated_code)
|
||||
self.assertIn("custom_b = arg_custom_b", generated_code)
|
||||
|
||||
def test_get_output_hook(self):
|
||||
"""Test the get_output() template hook."""
|
||||
from torch._inductor.ir import Buffer
|
||||
|
||||
mock_output = MagicMock(spec=Buffer)
|
||||
mock_output.get_name.return_value = "buf_test_output"
|
||||
|
||||
mock_graph = MockGraphHandler()
|
||||
with V.set_graph_handler(mock_graph):
|
||||
kernel = CuteDSLTemplateKernel(
|
||||
kernel_name="test_output",
|
||||
input_nodes=[],
|
||||
output_node=mock_output,
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# error if no output buffer
|
||||
result = kernel.get_output()
|
||||
|
||||
kernel.args.output_buffers["buf_test_output"] = "arg_buf_test_output"
|
||||
result = kernel.get_output()
|
||||
self.assertEqual(result, "arg_buf_test_output")
|
||||
|
||||
def test_modification_subgraph(self):
|
||||
"""Test the modification() method and subgraph processing."""
|
||||
|
||||
from torch._inductor.ir import Buffer
|
||||
|
||||
mock_subgraph1 = MagicMock(spec=Buffer)
|
||||
mock_subgraph2 = MagicMock(spec=Buffer)
|
||||
subgraphs = [mock_subgraph1, mock_subgraph2]
|
||||
|
||||
mock_output = MagicMock(spec=Buffer)
|
||||
mock_output.get_name.return_value = "buf_output"
|
||||
|
||||
kernel = CuteDSLTemplateKernel(
|
||||
kernel_name="test_modification",
|
||||
input_nodes=[],
|
||||
output_node=mock_output,
|
||||
subgraphs=subgraphs,
|
||||
)
|
||||
|
||||
result = kernel._get_subgraph(0)
|
||||
self.assertEqual(result, mock_subgraph1)
|
||||
|
||||
result = kernel._get_subgraph(1)
|
||||
self.assertEqual(result, mock_subgraph2)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
kernel._get_subgraph(2)
|
||||
|
||||
def test_cutedsl_op_overrides(self):
|
||||
"""Test the new CuteDSLOpOverrides class."""
|
||||
import torch
|
||||
from torch._inductor.codegen.common import CSEVariable
|
||||
from torch._inductor.codegen.cutedsl.cutedsl_op_overrides import (
|
||||
CuteDSLOpOverrides,
|
||||
)
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
|
||||
mock_cse_a = MagicMock(spec=CSEVariable)
|
||||
mock_cse_a.__str__.return_value = "tensor_a"
|
||||
mock_cse_a.dtype = torch.float32
|
||||
mock_cse_a.bounds = ValueRanges.unknown()
|
||||
|
||||
mock_cse_b = MagicMock(spec=CSEVariable)
|
||||
mock_cse_b.__str__.return_value = "tensor_b"
|
||||
mock_cse_b.dtype = torch.float32
|
||||
mock_cse_b.bounds = ValueRanges.unknown()
|
||||
|
||||
mock_graph = MockGraphHandler()
|
||||
with V.set_graph_handler(mock_graph):
|
||||
kernel = CuteDSLTemplateKernel(
|
||||
kernel_name="test_ops",
|
||||
input_nodes=[],
|
||||
output_node=None,
|
||||
)
|
||||
with V.set_kernel_handler(kernel):
|
||||
result = CuteDSLOpOverrides.add(mock_cse_a, mock_cse_b)
|
||||
self.assertIsInstance(result, CSEVariable)
|
||||
|
||||
result = CuteDSLOpOverrides.mul(mock_cse_a, mock_cse_b)
|
||||
self.assertIsInstance(result, CSEVariable)
|
||||
|
||||
result = CuteDSLOpOverrides.truediv(mock_cse_a, mock_cse_b)
|
||||
self.assertIsInstance(result, CSEVariable)
|
||||
|
||||
result = CuteDSLOpOverrides.exp(mock_cse_a)
|
||||
self.assertIsInstance(result, CSEVariable)
|
||||
|
||||
result = CuteDSLOpOverrides.sqrt(mock_cse_a)
|
||||
self.assertIsInstance(result, CSEVariable)
|
||||
|
||||
with self.assertRaises(NotImplementedError):
|
||||
result = CuteDSLOpOverrides.maximum(mock_cse_a, mock_cse_b)
|
||||
result = CuteDSLOpOverrides.minimum(mock_cse_a, mock_cse_b)
|
||||
|
||||
scalar_result = CuteDSLOpOverrides._ensure_tensor_ssa("5.0", mock_cse_a)
|
||||
self.assertEqual(scalar_result, "cute.full_like(tensor_a, 5.0)")
|
||||
|
||||
tensor_result = CuteDSLOpOverrides._ensure_tensor_ssa(mock_cse_a, mock_cse_b)
|
||||
self.assertEqual(tensor_result, "tensor_a")
|
||||
|
||||
def test_cse_integration(self):
|
||||
"""Test CSE (Common Subexpression Elimination) integration."""
|
||||
from torch._inductor.codegen.common import CSE
|
||||
|
||||
mock_graph = MockGraphHandler()
|
||||
with V.set_graph_handler(mock_graph):
|
||||
kernel = CuteDSLTemplateKernel(
|
||||
kernel_name="test_cse",
|
||||
input_nodes=[],
|
||||
output_node=None,
|
||||
)
|
||||
|
||||
self.assertIsInstance(kernel.cse, CSE)
|
||||
self.assertEqual(kernel.cse.name_prefix, "tmp")
|
||||
|
||||
with V.set_kernel_handler(kernel):
|
||||
test_expr = "x"
|
||||
var = kernel.cse.generate(kernel.body, test_expr, dtype=None)
|
||||
self.assertTrue(str(var).startswith("tmp"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -10,12 +10,15 @@ from torch._inductor.codegen.cuda.cutlass_utils import (
|
||||
torch_dtype_to_cutlass_type,
|
||||
try_import_cutlass,
|
||||
)
|
||||
from torch._inductor.graph import GraphLowering
|
||||
from torch._inductor.ir import ComputedBuffer, FixedLayout, PermuteView, Pointwise
|
||||
from torch._inductor.scheduler import BaseSchedulerNode
|
||||
from torch._inductor.utils import OrderedSet
|
||||
from torch.testing._internal.common_cuda import SM90OrLater
|
||||
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA_AND_TRITON
|
||||
from torch.testing._internal.inductor_utils import (
|
||||
HAS_CPU,
|
||||
HAS_CUDA_AND_TRITON,
|
||||
MockGraphHandler,
|
||||
)
|
||||
|
||||
|
||||
if try_import_cutlass():
|
||||
@ -105,17 +108,6 @@ class MockComputedBuffer(ComputedBuffer):
|
||||
return 1
|
||||
|
||||
|
||||
class MockGraphHandler(GraphLowering):
|
||||
def __init__(self, name_to_buffer):
|
||||
import torch._inductor.sizevars
|
||||
|
||||
self.sizevars = torch._inductor.sizevars.SizeVarAllocator()
|
||||
self.name_to_buffer = name_to_buffer
|
||||
self.graph_inputs = dict()
|
||||
self.mutated_buffers = OrderedSet()
|
||||
self.constants = dict()
|
||||
|
||||
|
||||
class TestCutlassEVT(TestCase):
|
||||
@unittest.skipIf(not SM90OrLater, "need sm_90")
|
||||
@unittest.skipIf(not try_import_cutlass(), "requires cutlass")
|
||||
|
@ -2,19 +2,31 @@
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import logging
|
||||
import textwrap
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
from torch._inductor.codegen.common import IndentedBuffer, Kernel
|
||||
from torch._inductor.ir import Buffer
|
||||
from torch._inductor.select_algorithm import PartialRender
|
||||
from torch._inductor.codegen.common import (
|
||||
CSE,
|
||||
CSEVariable,
|
||||
IndentedBuffer,
|
||||
Kernel,
|
||||
ValueRanges,
|
||||
)
|
||||
from torch._inductor.ir import Buffer, ComputedBuffer, InputBuffer
|
||||
from torch._inductor.ops_handler import StoreMode
|
||||
from torch._inductor.utils import OrderedSet
|
||||
from torch._inductor.virtualized import V
|
||||
|
||||
from .cutedsl_op_overrides import CuteDSLOpOverrides
|
||||
|
||||
|
||||
# TODO: the 'main' kernel is tagged with this suffix. We have 3; we should probably just auto-generate this.
|
||||
MAIN_SUFFIX = "main"
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code")
|
||||
|
||||
@ -70,14 +82,14 @@ class CuteDSLTemplateKernel(Kernel):
|
||||
kernel_name: str,
|
||||
input_nodes: list[Buffer],
|
||||
output_node: Buffer,
|
||||
subgraphs: Optional[list[Buffer]] = None,
|
||||
) -> None:
|
||||
# Call parent Kernel constructor
|
||||
super().__init__()
|
||||
self.kernel_name = kernel_name
|
||||
self.input_nodes = input_nodes
|
||||
self.output_node = output_node
|
||||
|
||||
# TODO Subgraph management for template processing
|
||||
self.subgraphs = subgraphs
|
||||
self.subgraph_bodies: dict[str, CuteDSLSubgraphInfo] = {}
|
||||
|
||||
# Template attributes
|
||||
@ -97,6 +109,8 @@ class CuteDSLTemplateKernel(Kernel):
|
||||
node_name = getattr(input_node, "name", f"input_{i}")
|
||||
self.named_input_nodes[node_name] = input_node
|
||||
|
||||
self.cse = CSE(name_prefix="tmp")
|
||||
|
||||
def gen_imports(self) -> str:
|
||||
"""Generate common imports for CuteDSL templates."""
|
||||
imports = IndentedBuffer()
|
||||
@ -107,6 +121,8 @@ class CuteDSLTemplateKernel(Kernel):
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
import cuda.bindings.driver as cuda
|
||||
from cutlass._mlir.dialects import math as mlir_math
|
||||
import operator
|
||||
"""
|
||||
)
|
||||
return imports.getvalue()
|
||||
@ -119,11 +135,15 @@ class CuteDSLTemplateKernel(Kernel):
|
||||
return params.getvalue()
|
||||
|
||||
def render(self, template, **kwargs):
|
||||
from torch._inductor.select_algorithm import PartialRender
|
||||
|
||||
"""Render the kernel using the template, returning PartialRender object with hooks."""
|
||||
# Available {{}} hooks for jinja rendering
|
||||
template_env = {
|
||||
"def_kernel": self.def_kernel,
|
||||
"gen_defines": lambda: self.gen_defines(**kwargs),
|
||||
"get_output": self.get_output,
|
||||
"modification": self.modification,
|
||||
}
|
||||
|
||||
# Render the template with the environment and provided kwargs
|
||||
@ -194,29 +214,203 @@ class CuteDSLTemplateKernel(Kernel):
|
||||
|
||||
def def_kernel(self, *argnames):
|
||||
"""Define kernel function signature for CuteDSL templates."""
|
||||
# Populate all the kernel args
|
||||
renames = IndentedBuffer(initial_indent=1)
|
||||
|
||||
for i, input_node in enumerate(self.input_nodes):
|
||||
self.args.input(input_node.get_name())
|
||||
buf_name = input_node.get_name()
|
||||
self.args.input(buf_name)
|
||||
|
||||
# Template aliasing: converts template variables (e.g., "input_a") to function args (e.g., "arg_input_a")
|
||||
# and generates rename statements so template code can use the original names
|
||||
if i < len(argnames):
|
||||
template_name = argnames[i]
|
||||
arg_name = f"arg_{template_name}"
|
||||
self.args.input_buffers[buf_name] = arg_name
|
||||
renames.writeline(f"{template_name} = {arg_name}")
|
||||
|
||||
if self.output_node:
|
||||
self.args.output(self.output_node.get_name())
|
||||
|
||||
def hook():
|
||||
# Deferred execution: arg definitions must be collected after template processing adds all args
|
||||
arg_defs, *_ = self.args.python_argdefs()
|
||||
code = IndentedBuffer()
|
||||
code.writeline(f"# Kernel function signature: {self.kernel_name}")
|
||||
params = list(argnames) + ["stream"]
|
||||
params = [x.full_name() for x in arg_defs] + ["stream"]
|
||||
code.writeline(
|
||||
f"def {self.kernel_name}_{MAIN_SUFFIX}({', '.join(params)}):"
|
||||
)
|
||||
with code.indent():
|
||||
code.splice(renames.getvalue())
|
||||
return code.getvalue()
|
||||
|
||||
assert "<DEF_KERNEL>" not in self.render_hooks
|
||||
# Placeholder-based rendering: hook will be called when template encounters "<DEF_KERNEL>"
|
||||
self.render_hooks["<DEF_KERNEL>"] = hook
|
||||
return "<DEF_KERNEL>"
|
||||
|
||||
def get_output(self):
|
||||
"""Get the actual argument name for the output buffer."""
|
||||
assert self.output_node, "Output node must exist to get output buffer name"
|
||||
buf_name = self.output_node.get_name()
|
||||
output = self.args.output_buffers.get(buf_name, None)
|
||||
if output is None:
|
||||
raise ValueError(f"Output buffer '{buf_name}' not found in args")
|
||||
return output
|
||||
|
||||
def call_kernel(self, name: str, node=None):
|
||||
"""Call the kernel function. Simplified version of TritonTemplateKernel.call_kernel."""
|
||||
wrapper = V.graph.wrapper_code
|
||||
_, call_args, _, arg_types = self.args.python_argdefs()
|
||||
# TODO triton should really be swapped w/ `python`
|
||||
wrapper.generate_kernel_call(name, call_args, triton=True, arg_types=arg_types)
|
||||
|
||||
def _get_subgraph(self, subgraph_number: int):
|
||||
"""Get subgraph by number for modification processing."""
|
||||
assert isinstance(subgraph_number, int)
|
||||
assert isinstance(self.subgraphs, list)
|
||||
assert subgraph_number < len(self.subgraphs), (
|
||||
f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}"
|
||||
)
|
||||
assert self.body.getvalue() == "", (
|
||||
"Body should be clear before adding a modification"
|
||||
)
|
||||
return self.subgraphs[subgraph_number]
|
||||
|
||||
def modification(
|
||||
self,
|
||||
subgraph_number: int,
|
||||
output_name: Optional[str],
|
||||
mask: Optional[str] = None,
|
||||
**fixed_inputs,
|
||||
) -> str:
|
||||
"""Generate CuteDSL code for a subgraph modification."""
|
||||
# Find unique name to avoid collisions between multiple modifications of same subgraph
|
||||
num = 0
|
||||
while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies:
|
||||
num += 1
|
||||
|
||||
with self.create_subgraph_body(f"mod_{subgraph_number}_{num}"):
|
||||
subgraph = self._get_subgraph(subgraph_number)
|
||||
modification_handler = ModificationWrapperCuteDSL(
|
||||
self, subgraph_number, fixed_inputs, mask
|
||||
)
|
||||
with V.set_kernel_handler(self), V.set_ops_handler(modification_handler):
|
||||
assert isinstance(subgraph, (ComputedBuffer, list)), (
|
||||
f"Expected ComputedBuffer or List[ComputedBuffer], got {type(subgraph)}"
|
||||
)
|
||||
|
||||
if isinstance(subgraph, list):
|
||||
raise NotImplementedError(
|
||||
"Scatter graphs are not supported for CuteDSL"
|
||||
)
|
||||
|
||||
if isinstance(subgraph.data, InputBuffer):
|
||||
# grad_score_mod can be InputBuffers
|
||||
out = subgraph.data.make_loader()(())
|
||||
else:
|
||||
# Inline a pointwise lowering into the template
|
||||
out = subgraph.data.inner_fn(())
|
||||
|
||||
if output_name is not None:
|
||||
assert out is not None, (
|
||||
f"Expected computation result for named output {output_name}"
|
||||
)
|
||||
self.body.writeline(f"{output_name} = {out.value}")
|
||||
else:
|
||||
# Side-effect only: no output assignment (currently only for scatter operations)
|
||||
raise NotImplementedError(
|
||||
"Side-effect only modifications not yet supported for CuteDSL"
|
||||
)
|
||||
|
||||
return self.body.getvalue()
|
||||
|
||||
|
||||
class ModificationWrapperCuteDSL(V.WrapperHandler): # type: ignore[name-defined]
|
||||
"""
|
||||
Wrapper handler that enables CuteDSL code generation during subgraph modifications.
|
||||
|
||||
This class sits between the PyTorch IR and CuteDSL code generation, providing:
|
||||
1. Operation substitution: converts PyTorch ops to CuteDSL equivalents via CuteDSLOpOverrides
|
||||
2. Placeholder handling: resolves fixed_inputs during template processing
|
||||
3. Limited operation support: currently restricted to pointwise operations
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel,
|
||||
subgraph_number: int,
|
||||
fixed_inputs: dict[str, Any],
|
||||
mask: Optional[str],
|
||||
):
|
||||
cutedsl_ops = CuteDSLOpOverrides()
|
||||
super().__init__(cutedsl_ops)
|
||||
self.name = f"CuteDSLPlaceholderSubstitution_{subgraph_number}"
|
||||
self.kernel = kernel
|
||||
self.fixed_inputs = fixed_inputs
|
||||
self.mask = mask
|
||||
|
||||
def _get_input_dtype(self, name: str) -> torch.dtype:
|
||||
"""Get the dtype for an input from the kernel's named_input_nodes."""
|
||||
if name in self.kernel.named_input_nodes:
|
||||
return self.kernel.named_input_nodes[name].dtype
|
||||
# TODO: Fallback for common dimension names - should be replaced with proper dtype tracking
|
||||
return torch.float32 if name not in ("b", "h", "m", "n") else torch.int32
|
||||
|
||||
def load(self, name: str, index: sympy.Expr):
|
||||
"""Handle loading from tensor or fixed(template args) input for CuteDSL."""
|
||||
if name not in self.fixed_inputs:
|
||||
raise NotImplementedError(
|
||||
"Tensor loading not yet supported for CuteDSL - only fixed input substitution"
|
||||
)
|
||||
value = self.fixed_inputs[name]
|
||||
dtype = self._get_input_dtype(name)
|
||||
|
||||
# ensure CSE wrapping
|
||||
return self.kernel.cse.generate(
|
||||
self.kernel.body, value, bounds=ValueRanges.unknown(), dtype=dtype
|
||||
)
|
||||
|
||||
def indirect_indexing(self, index_var: str, size, check, wrap_neg=True):
|
||||
"""Convert index variable to symbolic form."""
|
||||
raise NotImplementedError("Indirect indexing not supported")
|
||||
|
||||
def store(
|
||||
self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
|
||||
) -> str:
|
||||
raise NotImplementedError(
|
||||
"Store operations not supported - CuteDSL limited to read-only operations"
|
||||
)
|
||||
|
||||
def _add_kernel_input(self, name: str):
|
||||
"""Add name as input to kernel and return input ref."""
|
||||
return self.kernel.args.input(name)
|
||||
|
||||
def _process_indexing(self, index):
|
||||
"""Process and rename indexing, adding symbols as kernel inputs."""
|
||||
# Convert sympy expression to string representation for CuteDSL
|
||||
return str(index) # Simplified for now
|
||||
|
||||
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
|
||||
try:
|
||||
return getattr(self._inner, name)(*args, **kwargs)
|
||||
except NotImplementedError as e:
|
||||
bar = "=" * 80
|
||||
msg = textwrap.dedent(f"""
|
||||
{bar}
|
||||
UNSUPPORTED CUTEDSL OPERATION: '{name}'
|
||||
{bar}
|
||||
This operation is not yet implemented in Inductor.
|
||||
|
||||
Please open an issue at: https://github.com/pytorch/pytorch/issues
|
||||
with the following information:
|
||||
|
||||
Operation: {name}
|
||||
Args: {args!r}
|
||||
Kwargs: {kwargs!r}
|
||||
|
||||
Title your issue: [CuteDSL] Missing operation: {name}
|
||||
{bar}
|
||||
""").strip()
|
||||
raise NotImplementedError(msg) from e
|
||||
|
358 torch/_inductor/codegen/cutedsl/cutedsl_op_overrides.py (new file)
@ -0,0 +1,358 @@
|
||||
# mypy: allow-untyped-defs
|
||||
"""
|
||||
CuteDSL-specific operation overrides for pointwise operations.
|
||||
|
||||
This module provides CuteDSL implementations of common operations used in
|
||||
template kernels, particularly for flex attention modifications.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional, Union
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
from torch._inductor.codegen.common import CSEVariable, OpOverrides
|
||||
from torch._inductor.virtualized import OpsValue, V
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
|
||||
|
||||
CuteDSLArg = Union[CSEVariable, str]
|
||||
|
||||
|
||||
def upcast_compute_type(dtype: torch.dtype) -> torch.dtype:
|
||||
"""Maybe upcast [b]float16 to float32"""
|
||||
if dtype in (torch.float16, torch.bfloat16):
|
||||
return torch.float32
|
||||
return dtype
|
||||
|
||||
|
||||
class CuteDSLOpOverrides(OpOverrides):
|
||||
"""
|
||||
CuteDSL-specific operation overrides that generate code using CuteDSL syntax.
|
||||
|
||||
CuteDSL TensorSSA objects have built-in operator overloads (__add__, __mul__, etc.)
|
||||
and math functions (cute.math.exp, cute.math.sqrt, etc.)
|
||||
"""
|
||||
|
||||
TORCH_TO_CUTE_DTYPE = {
|
||||
torch.float16: "cutlass.Float16",
|
||||
torch.bfloat16: "cutlass.BFloat16",
|
||||
torch.float32: "cutlass.Float32",
|
||||
torch.float64: "cutlass.Float64",
|
||||
torch.int8: "cutlass.Int8",
|
||||
torch.int16: "cutlass.Int16",
|
||||
torch.int32: "cutlass.Int32",
|
||||
torch.int64: "cutlass.Int64",
|
||||
torch.bool: "cutlass.Boolean",
|
||||
torch.float8_e4m3fn: "cutlass.Float8E4M3FN",
|
||||
torch.float8_e5m2: "cutlass.Float8E5M2",
|
||||
}
|
||||
|
||||
# Math constants
|
||||
LOG2_E = 1.4426950408889634 # 1/ln(2) for converting natural exp to base-2 exp
|
||||
|
||||
@staticmethod
|
||||
def _ensure_tensor_ssa(arg: CuteDSLArg, template_tensor: CuteDSLArg) -> str:
|
||||
"""
|
||||
Convert scalar arguments to TensorSSA using cute.full_like if needed.
|
||||
|
||||
Args:
|
||||
arg: The argument to check (CSEVariable for tensors, str for scalars, or OpsValue wrapper)
|
||||
template_tensor: A tensor argument to use as template for full_like
|
||||
|
||||
Returns:
|
||||
String representation suitable for CuteDSL operations
|
||||
"""
|
||||
if isinstance(arg, CSEVariable):
|
||||
return str(arg)
|
||||
|
||||
if isinstance(arg, OpsValue) and isinstance(arg.value, CSEVariable):
|
||||
return str(arg.value)
|
||||
|
||||
if isinstance(template_tensor, CSEVariable):
|
||||
return f"cute.full_like({template_tensor}, {arg})"
|
||||
|
||||
return str(arg)
|
||||
|
||||
@staticmethod
|
||||
def _extract_dtype_and_bounds(
|
||||
*args: CuteDSLArg,
|
||||
) -> tuple[Optional[torch.dtype], ValueRanges[sympy.Expr]]:
|
||||
"""Extract dtype and bounds from CSEVariable arguments."""
|
||||
for arg in args:
|
||||
if isinstance(arg, CSEVariable):
|
||||
return arg.dtype, arg.bounds
|
||||
return None, ValueRanges.unknown()
|
||||
|
||||
@staticmethod
|
||||
def _apply_binary_op(a: CuteDSLArg, b: CuteDSLArg, op_format: str) -> CuteDSLArg:
|
||||
"""
|
||||
Apply a binary operation with automatic scalar-to-tensor conversion.
|
||||
|
||||
CuteDSL requires both operands to be TensorSSA objects for tensor operations.
|
||||
This helper automatically converts scalar arguments to TensorSSA using
|
||||
cute.full_like when at least one argument is a tensor (CSEVariable).
|
||||
|
||||
Args:
|
||||
a: First operand (CSEVariable for tensors, str for scalars)
|
||||
b: Second operand (CSEVariable for tensors, str for scalars)
|
||||
op_format: Format string with {a} and {b} placeholders for the operation
|
||||
|
||||
Returns:
|
||||
CSEVariable if at least one operand is a CSEVariable, otherwise string
|
||||
"""
|
||||
tensor_arg = (
|
||||
a
|
||||
if isinstance(a, CSEVariable)
|
||||
else b
|
||||
if isinstance(b, CSEVariable)
|
||||
else None
|
||||
)
|
||||
if tensor_arg is not None:
|
||||
a_ssa = CuteDSLOpOverrides._ensure_tensor_ssa(a, tensor_arg)
|
||||
b_ssa = CuteDSLOpOverrides._ensure_tensor_ssa(b, tensor_arg)
|
||||
result_expr = op_format.format(a=a_ssa, b=b_ssa)
|
||||
|
||||
dtype, bounds = CuteDSLOpOverrides._extract_dtype_and_bounds(a, b)
|
||||
|
||||
# Create and return CSEVariable using CSE generation for caching
|
||||
return V.kernel.cse.generate(
|
||||
V.kernel.body, result_expr, bounds=bounds, dtype=dtype
|
||||
)
|
||||
|
||||
return op_format.format(a=a, b=b)
|
||||
|
||||
@staticmethod
|
||||
def _apply_unary_op(x: CuteDSLArg, op_format: str) -> CuteDSLArg:
|
||||
"""
|
||||
Apply a unary operation, returning CSEVariable if input is CSEVariable.
|
||||
|
||||
Args:
|
||||
x: Input operand (CSEVariable for tensors, str for scalars)
|
||||
op_format: Format string with {x} placeholder for the operation
|
||||
|
||||
Returns:
|
||||
CSEVariable if input is a CSEVariable, otherwise string
|
||||
"""
|
||||
if isinstance(x, CSEVariable):
|
||||
result_expr = op_format.format(x=str(x))
|
||||
return V.kernel.cse.generate(
|
||||
V.kernel.body, result_expr, bounds=x.bounds, dtype=x.dtype
|
||||
)
|
||||
|
||||
return op_format.format(x=x)
|
||||
|
||||
@staticmethod
|
||||
def constant(value: Union[bool, float, int], dtype: torch.dtype) -> str:
|
||||
"""Generate CuteDSL constant representation."""
|
||||
if value == float("-inf"):
|
||||
return "float('-inf')"
|
||||
elif value == float("inf"):
|
||||
return "float('inf')"
|
||||
elif math.isnan(value):
|
||||
return "float('nan')"
|
||||
return repr(value)
|
||||
|
||||
@staticmethod
|
||||
def add(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
|
||||
return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} + {b})")
|
||||
|
||||
@staticmethod
|
||||
def mul(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
|
||||
return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} * {b})")
|
||||
|
||||
@staticmethod
|
||||
def sub(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
|
||||
return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} - {b})")
|
||||
|
||||
@staticmethod
|
||||
def truediv(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
|
||||
return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} / {b})")
|
||||
|
||||
@staticmethod
|
||||
def mod(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
|
||||
return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} % {b})")
|
||||
|
||||
@staticmethod
|
||||
def remainder(a, b):
|
||||
return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} % {b})")
|
||||
|
||||
@staticmethod
|
||||
def exp(x: CuteDSLArg) -> CuteDSLArg:
|
||||
"""Exponential using CuteDSL cute.math.exp function."""
|
||||
return CuteDSLOpOverrides._apply_unary_op(
|
||||
x, f"cute.math.exp2({{x}} * {CuteDSLOpOverrides.LOG2_E})"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def sqrt(x: CuteDSLArg) -> CuteDSLArg:
|
||||
"""Square root using CuteDSL cute.math.sqrt function."""
|
||||
return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.sqrt({x})")
|
||||
|
||||
@staticmethod
|
||||
def log(x: CuteDSLArg) -> CuteDSLArg:
|
||||
"""Natural logarithm using CuteDSL cute.math.log function."""
|
||||
return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.log({x})")
|
||||
|
||||
@staticmethod
|
||||
def cos(x: CuteDSLArg) -> CuteDSLArg:
|
||||
"""Cosine using CuteDSL cute.math.cos function."""
|
||||
return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.cos({x})")
|
||||
|
||||
@staticmethod
|
||||
def sin(x: CuteDSLArg) -> CuteDSLArg:
|
||||
"""Sine using CuteDSL cute.math.sin function."""
|
||||
return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.sin({x})")
|
||||
|
||||
@staticmethod
|
||||
def erf(x: CuteDSLArg) -> CuteDSLArg:
|
||||
"""Error function using CuteDSL cute.math.erf function."""
|
||||
return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.erf({x})")
|
||||
|
||||
@staticmethod
|
||||
def maximum(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
|
||||
raise NotImplementedError("TODO: maximum is not supported yet for TensorSSA")
|
||||
|
||||
@staticmethod
|
||||
def minimum(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
|
||||
raise NotImplementedError("TODO: minimum is not supported yet for TensorSSA")
|
||||
|
||||
@staticmethod
|
||||
def where(
|
||||
condition: CuteDSLArg,
|
||||
a: CuteDSLArg,
|
||||
b: CuteDSLArg,
|
||||
) -> CuteDSLArg:
|
||||
"""Conditional selection - handles both CSEVariable and string inputs."""
|
||||
# Find a tensor argument to use as template for full_like
|
||||
# Priority: use 'a' if it's a tensor, else use 'b', else condition
|
||||
tensor_arg = (
|
||||
a
|
||||
if isinstance(a, CSEVariable)
|
||||
else (
|
||||
b
|
||||
if isinstance(b, CSEVariable)
|
||||
else condition
|
||||
if isinstance(condition, CSEVariable)
|
||||
else None
|
||||
)
|
||||
)
|
||||
|
||||
if tensor_arg is not None:
|
||||
a_ssa = CuteDSLOpOverrides._ensure_tensor_ssa(a, tensor_arg)
|
||||
b_ssa = CuteDSLOpOverrides._ensure_tensor_ssa(b, tensor_arg)
|
||||
result_expr = f"cute.where({condition}, {a_ssa}, {b_ssa})"
|
||||
|
||||
dtype, bounds = CuteDSLOpOverrides._extract_dtype_and_bounds(
|
||||
a, b, condition
|
||||
)
|
||||
|
||||
return V.kernel.cse.generate(
|
||||
V.kernel.body, result_expr, bounds=bounds, dtype=dtype
|
||||
)
|
||||
|
||||
return f"cute.where({condition}, {a}, {b})"
|
||||
|
||||
@staticmethod
|
||||
def pow(a: CuteDSLArg, b: CuteDSLArg):
|
||||
return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} ** {b})")
|
||||
|
||||
@staticmethod
|
||||
def abs(x: CuteDSLArg) -> CuteDSLArg:
|
||||
"""Absolute value using CuteDSL cute.math.abs function."""
|
||||
if isinstance(x, CSEVariable):
|
||||
x_dtype = x.dtype
|
||||
elif isinstance(x, OpsValue) and isinstance(x.value, CSEVariable):
|
||||
x_dtype = x.value.dtype
|
||||
else:
|
||||
x_dtype = torch.float32
|
||||
|
||||
abs_op = (
|
||||
"mlir_math.absf"
|
||||
if x_dtype in (torch.float16, torch.bfloat16, torch.float32)
|
||||
else "mlir_math.absi"
|
||||
)
|
||||
return CuteDSLOpOverrides._apply_unary_op(
|
||||
x, f"cute.TensorSSA({abs_op}({{x}}), {{x}}.shape, {{x}}.dtype)"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def neg(x: CuteDSLArg) -> CuteDSLArg:
|
||||
"""Negation using CuteDSL TensorSSA __neg__ operator."""
|
||||
# TODO: See https://github.com/NVIDIA/cutlass/issues/2584
|
||||
return CuteDSLOpOverrides._apply_unary_op(
|
||||
x, "cute.TensorSSA(-{x}, {x}.shape, {x}.dtype)"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def to_dtype(
|
||||
x: CuteDSLArg, dtype: torch.dtype, src_dtype=None, use_compute_types=True
|
||||
) -> CuteDSLArg:
|
||||
"""Type conversion using CuteDSL TensorSSA.to(Type[Numeric]).
|
||||
|
||||
Maps torch dtypes to cutlass.cute.typing numeric types and emits
|
||||
`{x}.to(cute.typing.<Type>)`.
|
||||
|
||||
Raises NotImplementedError for unsigned integer and unsupported dtypes.
|
||||
"""
|
||||
# Always upcast from bf16 and fp16; TODO: make this configurable
|
||||
dtype = upcast_compute_type(dtype)
|
||||
|
||||
cute_type = CuteDSLOpOverrides.TORCH_TO_CUTE_DTYPE.get(dtype)
|
||||
if cute_type is None:
|
||||
raise NotImplementedError(
|
||||
f"CuteDSL dtype cast not implemented for torch dtype: {dtype}"
|
||||
)
|
||||
|
||||
if isinstance(x, CSEVariable):
|
||||
result_expr = f"{str(x)}.to({cute_type})"
|
||||
return V.kernel.cse.generate(
|
||||
V.kernel.body, result_expr, bounds=x.bounds, dtype=dtype
|
||||
)
|
||||
|
||||
return f"{x}.to({cute_type})"
|
||||
|
||||
@staticmethod
|
||||
def tanh(x0: CuteDSLArg) -> CuteDSLArg:
|
||||
"""Hyperbolic tangent using CuteDSL cute.math.tanh function."""
|
||||
return CuteDSLOpOverrides._apply_unary_op(x0, "cute.math.tanh({x})")
|
||||
|
||||
# Logical operations
|
||||
@staticmethod
|
||||
def logical_and(x0: CuteDSLArg, x1: CuteDSLArg) -> CuteDSLArg:
|
||||
return CuteDSLOpOverrides._apply_binary_op(x0, x1, "({a} and {b})")
|
||||
|
||||
@staticmethod
|
||||
def logical_or(x0: CuteDSLArg, x1: CuteDSLArg) -> CuteDSLArg:
|
||||
return CuteDSLOpOverrides._apply_binary_op(x0, x1, "({a} or {b})")
|
||||
|
||||
@staticmethod
|
||||
def logical_not(a):
|
||||
"""Logical NOT."""
|
||||
return CuteDSLOpOverrides._apply_unary_op(a, "({x} == 0)")
|
||||
|
||||
# Comparison operations
|
||||
@staticmethod
|
||||
def eq(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
|
||||
return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.eq({a}, {b})")
|
||||
|
||||
@staticmethod
|
||||
def ne(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
|
||||
return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.ne({a}, {b})")
|
||||
|
||||
@staticmethod
|
||||
def lt(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
|
||||
return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.lt({a}, {b})")
|
||||
|
||||
@staticmethod
|
||||
def le(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
|
||||
return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.le({a}, {b})")
|
||||
|
||||
@staticmethod
|
||||
def gt(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
|
||||
return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.gt({a}, {b})")
|
||||
|
||||
@staticmethod
|
||||
def ge(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg:
|
||||
return CuteDSLOpOverrides._apply_binary_op(a, b, "operator.ge({a}, {b})")
|
@ -1,14 +1,17 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import functools
|
||||
import itertools
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Optional, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
from torch._inductor.ir import ShapeAsConstantBuffer
|
||||
from torch._inductor.utils import Placeholder
|
||||
from torch._inductor.virtualized import V
|
||||
from torch._logging import getArtifactLogger
|
||||
|
||||
from ...autotune_process import CuteDSLBenchmarkRequest, TensorMeta
|
||||
from ...ir import Buffer, ChoiceCaller, CuteDSLTemplateBuffer, Layout, TensorBox
|
||||
from ...ir import Buffer, ChoiceCaller, CuteDSLTemplateBuffer, IRNode, Layout, TensorBox
|
||||
from ..common import KernelTemplate
|
||||
from .cutedsl_kernel import CuteDSLTemplateKernel
|
||||
|
||||
@ -64,6 +67,8 @@ class CuteDSLTemplate(KernelTemplate):
|
||||
"""Generate the CuteDSL kernel caller."""
|
||||
input_nodes = kwargs.pop("input_nodes")
|
||||
layout = kwargs.pop("layout")
|
||||
mutated_inputs = kwargs.pop("mutated_inputs", None)
|
||||
subgraphs = kwargs.pop("subgraphs", None)
|
||||
|
||||
kernel_name = f"cutedsl_{self.name}_{next(self.index_counter)}"
|
||||
|
||||
@ -71,45 +76,57 @@ class CuteDSLTemplate(KernelTemplate):
|
||||
raise RuntimeError("Template compilation failed (Jinja2 required)")
|
||||
|
||||
self.output_node: Buffer = Buffer(name="buf_out", layout=layout)
|
||||
|
||||
kernel = self.kernel_type(
|
||||
kernel_name=kernel_name,
|
||||
input_nodes=input_nodes,
|
||||
output_node=self.output_node,
|
||||
)
|
||||
|
||||
code = kernel.render(self.template, **kwargs)
|
||||
|
||||
log.debug("Generated CuteDSL Code:\n%s", code)
|
||||
|
||||
bmreq = CuteDSLBenchmarkRequest(
|
||||
kernel_name=kernel_name,
|
||||
input_tensor_meta=TensorMeta.from_irnodes(input_nodes),
|
||||
output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
|
||||
extra_args=tuple(),
|
||||
source_code=code,
|
||||
)
|
||||
|
||||
def make_kernel_render(out_node, hint_override: Optional[int] = None):
|
||||
render_kernel = self.kernel_type(
|
||||
kernel_name=str(Placeholder.KERNEL_NAME),
|
||||
# Patch V.graph.get_dtype to handle the fake buf_out buffer
|
||||
with patch.object(
|
||||
V.graph, "get_dtype", KernelTemplate._fake_get_dtype(self.output_node)
|
||||
):
|
||||
kernel = self.kernel_type(
|
||||
kernel_name=kernel_name,
|
||||
input_nodes=input_nodes,
|
||||
output_node=out_node,
|
||||
output_node=self.output_node,
|
||||
subgraphs=subgraphs,
|
||||
)
|
||||
code = kernel.render(self.template, **kwargs)
|
||||
|
||||
log.debug("Generated CuteDSL Code:\n%s", code)
|
||||
|
||||
bmreq = CuteDSLBenchmarkRequest(
|
||||
kernel_name=kernel_name,
|
||||
input_tensor_meta=TensorMeta.from_irnodes(input_nodes),
|
||||
output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
|
||||
extra_args=tuple(),
|
||||
source_code=code,
|
||||
)
|
||||
|
||||
def render():
|
||||
return render_kernel.render(self.template, **kwargs)
|
||||
def make_kernel_render(out_node, hint_override: Optional[int] = None):
|
||||
"""
|
||||
Factory function that creates a kernel renderer for the final output.
|
||||
|
||||
return render_kernel, render
|
||||
This closure captures the current template and parameters, but allows
|
||||
the output node to be specified later. This is used during the final
|
||||
kernel selection phase when the actual output buffer is available.
|
||||
"""
|
||||
render_kernel = self.kernel_type(
|
||||
kernel_name=str(Placeholder.KERNEL_NAME),
|
||||
input_nodes=input_nodes,
|
||||
output_node=out_node,
|
||||
subgraphs=subgraphs,
|
||||
)
|
||||
|
||||
return CuteDSLTemplateCaller(
|
||||
name=kernel_name,
|
||||
input_nodes=input_nodes,
|
||||
layout=layout,
|
||||
make_kernel_render=make_kernel_render,
|
||||
bmreq=bmreq,
|
||||
template=self,
|
||||
)
|
||||
def render():
|
||||
return render_kernel.render(self.template, **kwargs)
|
||||
|
||||
return render_kernel, render
|
||||
|
||||
return CuteDSLTemplateCaller(
|
||||
name=kernel_name,
|
||||
input_nodes=input_nodes,
|
||||
layout=layout,
|
||||
make_kernel_render=make_kernel_render,
|
||||
bmreq=bmreq,
|
||||
template=self,
|
||||
mutated_inputs=mutated_inputs,
|
||||
)
|
||||
|
||||
|
||||
class CuteDSLTemplateCaller(ChoiceCaller):
|
||||
@ -123,6 +140,7 @@ class CuteDSLTemplateCaller(ChoiceCaller):
|
||||
make_kernel_render: Any,
|
||||
bmreq: CuteDSLBenchmarkRequest,
|
||||
template: "CuteDSLTemplate",
|
||||
mutated_inputs: Optional[Iterable[IRNode]] = None,
|
||||
):
|
||||
super().__init__(
|
||||
name=name,
|
||||
@ -133,6 +151,7 @@ class CuteDSLTemplateCaller(ChoiceCaller):
|
||||
self.make_kernel_render = make_kernel_render
|
||||
self.bmreq = bmreq
|
||||
self.template = template
|
||||
self.mutated_inputs = mutated_inputs
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"CuteDSLTemplateCaller({self.name})"
|
||||
@ -149,6 +168,7 @@ class CuteDSLTemplateCaller(ChoiceCaller):
|
||||
inputs=self.input_nodes,
|
||||
make_kernel_render=self.make_kernel_render,
|
||||
template=self.template,
|
||||
mutated_inputs=self.mutated_inputs,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -13,6 +13,7 @@ import torch._inductor.async_compile # noqa: F401 required to warm up AsyncComp
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._inductor.graph import GraphLowering
|
||||
from torch._inductor.compile_fx import shape_env_from_inputs
|
||||
from torch._inductor.utils import OrderedSet
|
||||
from torch._inductor.codecache import CppCodeCache
|
||||
from torch._inductor.custom_graph_pass import CustomGraphModulePass
|
||||
from torch._inductor.codegen.common import (
|
||||
@ -306,6 +307,24 @@ def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
|
||||
inverse_scale = scale.reciprocal()
|
||||
return x_fp8, inverse_scale
|
||||
|
||||
class MockGraphHandler(GraphLowering):
|
||||
"""Minimal mock graph handler for testing virtualized context."""
|
||||
|
||||
def __init__(self, name_to_buffer=None):
|
||||
import torch._inductor.sizevars
|
||||
|
||||
self.sizevars = torch._inductor.sizevars.SizeVarAllocator()
|
||||
self.name_to_buffer = name_to_buffer or {}
|
||||
self.graph_inputs = {}
|
||||
self.mutated_buffers = OrderedSet()
|
||||
self.removed_buffers = OrderedSet()
|
||||
self.constants = {}
|
||||
self.scheduler = None
|
||||
|
||||
def get_dtype(self, buffer_name: str) -> torch.dtype: # noqa: ARG002
|
||||
"""Return default dtype for any buffer (for testing)."""
|
||||
return torch.float32
|
||||
|
||||
@contextlib.contextmanager
|
||||
def patch_inductor_backend(
|
||||
device: str,
|
||||
|