Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-27 09:04:53 +08:00

Compare commits (2 commits): ciflow/tru... → flex_flash

| Author | SHA1 | Date |
|---|---|---|
|  | d9007ea76c |  |
|  | 4bc383b405 |  |

test/inductor/test_cutedsl_template.py (new file, 318 lines)
@@ -0,0 +1,318 @@
# Owner(s): ["module: inductor"]
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from torch._inductor.test_case import TestCase
|
||||
|
||||
|
||||
try:
|
||||
import cutlass # noqa: F401
|
||||
import cutlass.cute as cute # noqa: F401
|
||||
|
||||
HAS_CUTLASS = True
|
||||
except ImportError:
|
||||
HAS_CUTLASS = False
|
||||
|
||||
if HAS_CUTLASS:
|
||||
from torch._inductor.codegen.cutedsl.cutedsl_kernel import CuteDSLTemplateKernel
|
||||
from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
|
||||
from torch._inductor.select_algorithm import PartialRender
|
||||
|
||||
CUTEDSL_ADD_TEMPLATE = r"""
|
||||
{{gen_defines()}}
|
||||
|
||||
@cute.kernel
|
||||
def {{kernel_name}}_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
bidx, _, _ = cute.arch.block_idx()
|
||||
bdim, _, _ = cute.arch.block_dim()
|
||||
|
||||
thread_idx = bidx * bdim + tidx
|
||||
m, n = gA.shape
|
||||
|
||||
if thread_idx < m * n:
|
||||
mi = thread_idx // n
|
||||
ni = thread_idx % n
|
||||
|
||||
if mi < m and ni < n:
|
||||
gC[mi, ni] = gA[mi, ni] + gB[mi, ni]
|
||||
|
||||
@cute.jit
|
||||
def {{kernel_name}}_jit(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor):
|
||||
{{gen_defines()}}
|
||||
m, n = mA.shape
|
||||
total_threads = m * n
|
||||
num_blocks = (total_threads + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK
|
||||
|
||||
kernel = {{kernel_name}}_kernel(mA, mB, mC)
|
||||
kernel.launch(
|
||||
grid=[num_blocks, 1, 1],
|
||||
block=[THREADS_PER_BLOCK, 1, 1]
|
||||
)
|
||||
|
||||
{{def_kernel("input_a", "input_b", "output_c")}}
|
||||
cute_a = from_dlpack(input_a)
|
||||
cute_b = from_dlpack(input_b)
|
||||
cute_c = from_dlpack(output_c)
|
||||
|
||||
{{kernel_name}}_jit(cute_a, cute_b, cute_c)
|
||||
return output_c
|
||||
"""
|
||||
|
||||
|
||||
@unittest.skipUnless(HAS_CUTLASS, "requires cutlass")
|
||||
class TestCuteDSLTemplate(TestCase):
|
||||
"""Test cases for CuteDSL template functionality."""
|
||||
|
||||
def test_gen_imports(self):
|
||||
kernel = CuteDSLTemplateKernel(
|
||||
kernel_name="test_kernel",
|
||||
input_nodes=[],
|
||||
output_node=None,
|
||||
)
|
||||
|
||||
imports = kernel.gen_imports()
|
||||
|
||||
self.assertIn("import torch", imports)
|
||||
self.assertIn("import cutlass", imports)
|
||||
self.assertIn("import cutlass.cute as cute", imports)
|
||||
self.assertIn("from cutlass.cute.runtime import from_dlpack", imports)
|
||||
self.assertIsInstance(imports, str)
|
||||
|
||||
lines = imports.strip().split("\n")
|
||||
self.assertEqual(len(lines), 4)
|
||||
|
||||
def test_render_includes_imports(self):
|
||||
template_source = """@cute.kernel
|
||||
def {{kernel_name}}_kernel():
|
||||
pass
|
||||
|
||||
{{def_kernel("input", "output")}}
|
||||
return output"""
|
||||
|
||||
mock_template = MagicMock()
|
||||
mock_template.render = MagicMock(return_value=template_source)
|
||||
|
||||
kernel = CuteDSLTemplateKernel(
|
||||
kernel_name="test_kernel",
|
||||
input_nodes=[],
|
||||
output_node=None,
|
||||
)
|
||||
|
||||
result = kernel.render(mock_template)
|
||||
self.assertIsInstance(result, PartialRender)
|
||||
|
||||
rendered_code = result._code
|
||||
|
||||
# The imports might have leading whitespace, so strip it
|
||||
rendered_code_stripped = rendered_code.lstrip()
|
||||
|
||||
self.assertTrue(
|
||||
rendered_code_stripped.startswith("import torch"),
|
||||
f"Code should start with 'import torch', got: {rendered_code_stripped[:50]}",
|
||||
)
|
||||
self.assertIn("import cutlass", rendered_code)
|
||||
self.assertIn("import cutlass.cute as cute", rendered_code)
|
||||
self.assertIn("from cutlass.cute.runtime import from_dlpack", rendered_code)
|
||||
self.assertIn("@cute.kernel", rendered_code)
|
||||
|
||||
def test_template_env_contains_hooks(self):
|
||||
kernel = CuteDSLTemplateKernel(
|
||||
kernel_name="test_kernel",
|
||||
input_nodes=[],
|
||||
output_node=None,
|
||||
)
|
||||
|
||||
captured_env = {}
|
||||
|
||||
def mock_render(**kwargs):
|
||||
captured_env.update(kwargs)
|
||||
return "rendered"
|
||||
|
||||
mock_template = MagicMock()
|
||||
mock_template.render = mock_render
|
||||
|
||||
kernel.render(mock_template)
|
||||
|
||||
self.assertIn("def_kernel", captured_env)
|
||||
self.assertIn("kernel_name", captured_env)
|
||||
self.assertTrue(callable(captured_env["def_kernel"]))
|
||||
|
||||
def test_multiple_templates_unique_names(self):
|
||||
# Clean registry first
|
||||
test_name = f"unique_test_{id(self)}"
|
||||
if test_name in CuteDSLTemplate.all_templates:
|
||||
del CuteDSLTemplate.all_templates[test_name]
|
||||
|
||||
_ = CuteDSLTemplate(
|
||||
name=test_name,
|
||||
source="template1",
|
||||
)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
_ = CuteDSLTemplate(
|
||||
name=test_name,
|
||||
source="template2",
|
||||
)
|
||||
|
||||
def test_indented_buffer_usage(self):
|
||||
kernel = CuteDSLTemplateKernel(
|
||||
kernel_name="test_kernel",
|
||||
input_nodes=[],
|
||||
output_node=None,
|
||||
)
|
||||
|
||||
imports = kernel.gen_imports()
|
||||
|
||||
lines = imports.strip().split("\n")
|
||||
for line in lines:
|
||||
if line:
|
||||
self.assertFalse(
|
||||
line.startswith(" "), f"Line should not be indented: '{line}'"
|
||||
)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
def test_cutedsl_add_e2e(self):
|
||||
"""End-to-end test with CuteDSL template including code generation verification."""
|
||||
from torch._inductor.ir import TensorBox
|
||||
from torch._inductor.lowering import lowerings
|
||||
from torch._inductor.utils import run_and_get_code
|
||||
|
||||
template = CuteDSLTemplate(
|
||||
name="test_add_e2e",
|
||||
source=CUTEDSL_ADD_TEMPLATE,
|
||||
)
|
||||
|
||||
def cutedsl_add_lowering(a: TensorBox, b: TensorBox) -> TensorBox:
|
||||
choices = []
|
||||
error = template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=[a, b],
|
||||
layout=a.get_layout(),
|
||||
THREADS_PER_BLOCK=256,
|
||||
)
|
||||
|
||||
if error or not choices:
|
||||
default_lowering = lowerings[torch.ops.aten.add.Tensor]
|
||||
return default_lowering(a, b)
|
||||
|
||||
# Use the single choice directly (no autotuning)
|
||||
return choices[0].output_node()
|
||||
|
||||
with patch.dict(lowerings, {torch.ops.aten.add.Tensor: cutedsl_add_lowering}):
|
||||
# Test function
|
||||
def test_add(x, y):
|
||||
return x + y
|
||||
|
||||
device = "cuda"
|
||||
x = torch.randn(128, 4, device=device, dtype=torch.float32)
|
||||
y = torch.randn(128, 4, device=device, dtype=torch.float32)
|
||||
|
||||
# Compile and get generated code
|
||||
compiled_fn = torch.compile(test_add, backend="inductor")
|
||||
result, (code,) = run_and_get_code(compiled_fn, x, y)
|
||||
|
||||
# Verify CuteDSL code is present
|
||||
self.assertIn(
|
||||
"cute", code.lower(), "CuteDSL code should be in generated code"
|
||||
)
|
||||
# Verify parameter generation worked
|
||||
self.assertIn(
|
||||
"THREADS_PER_BLOCK", code, "Parameter should be in generated code"
|
||||
)
|
||||
|
||||
# Verify correctness
|
||||
expected = x + y
|
||||
self.assertTrue(torch.allclose(result, expected, atol=1e-5))
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
def test_cutedsl_add_e2e_autotune(self):
|
||||
"""E2E test with multiple CuteDSL template variants for autotuning."""
|
||||
from torch._inductor.ir import TensorBox
|
||||
from torch._inductor.lowering import lowerings
|
||||
from torch._inductor.select_algorithm import autotune_select_algorithm
|
||||
|
||||
template = CuteDSLTemplate(
|
||||
name="test_add_autotune",
|
||||
source=CUTEDSL_ADD_TEMPLATE,
|
||||
)
|
||||
|
||||
def cutedsl_add_lowering(a: TensorBox, b: TensorBox) -> TensorBox:
|
||||
choices = []
|
||||
|
||||
# Add multiple variants with different thread counts for autotuning
|
||||
thread_variants = [128, 256, 512]
|
||||
for threads in thread_variants:
|
||||
error = template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=[a, b],
|
||||
layout=a.get_layout(),
|
||||
THREADS_PER_BLOCK=threads,
|
||||
)
|
||||
if error:
|
||||
# Skip this variant if it fails
|
||||
continue
|
||||
|
||||
if not choices:
|
||||
default_lowering = lowerings[torch.ops.aten.add.Tensor]
|
||||
return default_lowering(a, b)
|
||||
|
||||
# Use autotuning to select the best variant
|
||||
return autotune_select_algorithm(
|
||||
"cutedsl_add_autotune",
|
||||
choices,
|
||||
[a, b],
|
||||
a.get_layout(),
|
||||
)
|
||||
|
||||
with patch.dict(lowerings, {torch.ops.aten.add.Tensor: cutedsl_add_lowering}):
|
||||
# Test function
|
||||
def test_add(x, y):
|
||||
return x + y
|
||||
|
||||
device = "cuda"
|
||||
x = torch.randn(128, 128, device=device, dtype=torch.float32)
|
||||
y = torch.randn(128, 128, device=device, dtype=torch.float32)
|
||||
|
||||
# Compile and run
|
||||
compiled_fn = torch.compile(test_add, backend="inductor")
|
||||
result = compiled_fn(x, y)
|
||||
|
||||
# Verify correctness
|
||||
expected = x + y
|
||||
self.assertTrue(torch.allclose(result, expected, atol=1e-5))
|
||||
|
||||
def test_gen_defines(self):
|
||||
"""Test that gen_defines correctly generates CuteDSL parameter definitions."""
|
||||
kernel = CuteDSLTemplateKernel(
|
||||
kernel_name="test_kernel",
|
||||
input_nodes=[],
|
||||
output_node=None,
|
||||
)
|
||||
|
||||
# Test integer parameters
|
||||
params = kernel.gen_defines(
|
||||
THREADS_PER_BLOCK=256,
|
||||
BLOCK_SIZE=128,
|
||||
ENABLE_FEATURE=True,
|
||||
)
|
||||
|
||||
expected_lines = [
|
||||
"THREADS_PER_BLOCK: cutlass.Constexpr = 256",
|
||||
"BLOCK_SIZE: cutlass.Constexpr = 128",
|
||||
"ENABLE_FEATURE: cutlass.Constexpr = True",
|
||||
]
|
||||
|
||||
for expected_line in expected_lines:
|
||||
self.assertIn(expected_line, params)
|
||||
|
||||
# Test float parameters
|
||||
params_float = kernel.gen_defines(SCALE_FACTOR=1.5)
|
||||
self.assertIn("SCALE_FACTOR: cutlass.Constexpr = 1.5", params_float)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._inductor.test_case import run_tests
|
||||
|
||||
run_tests()
|
||||
torch/_inductor/async_compile.py
@@ -569,6 +569,45 @@ class AsyncCompile:
            )
        return LambdaFuture(get_result)

    def cutedsl(self, kernel_name: str, source_code: str):
        """
        Compile CuteDSL (CUTLASS Python DSL) kernels.

        Args:
            kernel_name: Name of the kernel to be defined
            source_code: Source code of the CuteDSL kernel, as a string

        Note:
            CuteDSL currently requires source files for compilation, so we
            use PyCodeCache to write the source code to a file and load it.
        """
        from torch._inductor.codegen.cutedsl.cutedsl_kernel import (
            CuteDSLKernelWrapper,
            MAIN_SUFFIX,
        )

        kernel_code_log.info("CuteDSL Kernel:\n%s", source_code)

        def task():
            key, path = torch._inductor.codecache.PyCodeCache.write(source_code)
            mod = torch._inductor.codecache.PyCodeCache.load_by_key_path(key, path)

            # Find our specially named entry-point function
            main_func_name = f"{kernel_name}_{MAIN_SUFFIX}"
            if not hasattr(mod, main_func_name):
                available = [name for name in dir(mod) if callable(getattr(mod, name))]
                raise RuntimeError(
                    f"Could not find CuteDSL main kernel function '{main_func_name}'. Available callables: {available}"
                )

            return CuteDSLKernelWrapper(getattr(mod, main_func_name), kernel_path=path)

        if get_compile_threads() <= 1:
            return task()
        else:
            future = self.submit(task)
            return LambdaFuture(lambda: future.result())

    def wait(self, scope: dict[str, Any]) -> None:
        if get_compile_threads() > 1:
            with dynamo_timed(
@@ -11,6 +11,7 @@ from ..scheduler import (
|
||||
SchedulerNode,
|
||||
)
|
||||
from .cuda.cuda_cpp_scheduling import CUDACPPScheduling
|
||||
from .cutedsl.cutedsl_scheduling import CuteDSLScheduling
|
||||
from .rocm.rocm_cpp_scheduling import ROCmCPPScheduling
|
||||
from .triton import TritonScheduling
|
||||
|
||||
@@ -44,6 +45,7 @@ class CUDACombinedScheduling(BaseScheduling):
|
||||
self._triton_scheduling = TritonScheduling(scheduler)
|
||||
self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler)
|
||||
self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler)
|
||||
self._cutedsl_scheduling = CuteDSLScheduling(scheduler)
|
||||
|
||||
def get_backend_features(self, device: torch.device) -> OrderedSet[BackendFeature]:
|
||||
return self._triton_scheduling.get_backend_features(device)
|
||||
@@ -53,6 +55,8 @@ class CUDACombinedScheduling(BaseScheduling):
|
||||
return self._cuda_cpp_scheduling
|
||||
if self._rocm_cpp_scheduling.is_rocm_cpp_template(node):
|
||||
return self._rocm_cpp_scheduling
|
||||
if self._cutedsl_scheduling.is_cutedsl_template(node):
|
||||
return self._cutedsl_scheduling
|
||||
return self._triton_scheduling
|
||||
|
||||
def can_fuse_vertical(
|
||||
@@ -64,6 +68,11 @@ class CUDACombinedScheduling(BaseScheduling):
|
||||
node1
|
||||
) or self._cuda_cpp_scheduling.is_cuda_cpp_template(node2):
|
||||
return False
|
||||
# CuteDSL doesn't support vertical fusion currently
|
||||
elif self._cutedsl_scheduling.is_cutedsl_template(
|
||||
node1
|
||||
) or self._cutedsl_scheduling.is_cutedsl_template(node2):
|
||||
return False
|
||||
return self._triton_scheduling.can_fuse_vertical(node1, node2)
|
||||
|
||||
def can_fuse_horizontal(
|
||||
@@ -74,6 +83,10 @@ class CUDACombinedScheduling(BaseScheduling):
|
||||
return self._cuda_cpp_scheduling.can_fuse_horizontal(
|
||||
node1, node2
|
||||
) # always False at the moment
|
||||
if self._cutedsl_scheduling.is_cutedsl_template(node):
|
||||
return self._cutedsl_scheduling.can_fuse_horizontal(
|
||||
node1, node2
|
||||
) # always False at the moment
|
||||
return self._triton_scheduling.can_fuse_horizontal(node1, node2)
|
||||
|
||||
def group_fn(
|
||||
@@ -98,6 +111,13 @@ class CUDACombinedScheduling(BaseScheduling):
|
||||
return self._rocm_cpp_scheduling.codegen_template(
|
||||
template_node, epilogue_nodes, prologue_nodes
|
||||
)
|
||||
elif self._cutedsl_scheduling.is_cutedsl_template(template_node):
|
||||
# TODO remove this when we add epilogue support
|
||||
assert not epilogue_nodes
|
||||
assert not prologue_nodes
|
||||
return self._cutedsl_scheduling.codegen_template(
|
||||
template_node, epilogue_nodes, prologue_nodes
|
||||
)
|
||||
else:
|
||||
return self._triton_scheduling.codegen_template(
|
||||
template_node, epilogue_nodes, prologue_nodes
|
||||
|
||||
torch/_inductor/codegen/cutedsl/README.md (new file, 101 lines)
@@ -0,0 +1,101 @@
# CuteDSL Template System

## Quick Start

Writing a CuteDSL template:

```python
from torch._inductor.codegen.cutedsl import CuteDSLTemplate

template_source = """
@cute.kernel
def {{kernel_name}}_kernel(A, B, C):
    # Your CUTLASS kernel logic here
    pass

{{def_kernel("A", "B", "C")}}
    # Call the kernel
    {{kernel_name}}_kernel(A, B, C)
    return C
"""

my_template = CuteDSLTemplate(
    name="my_gemm",
    source=template_source,
)
```
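Once defined, a template is consumed from an Inductor lowering: `maybe_append_choice` turns it into a `ChoiceCaller`, and the choice's `output_node()` becomes the lowering result. The sketch below is adapted from the end-to-end test in `test_cutedsl_template.py`; `add_template` stands in for a `CuteDSLTemplate` built as in the Quick Start and is illustrative, not a shipped helper.

```python
import torch
from unittest.mock import patch
from torch._inductor.ir import TensorBox
from torch._inductor.lowering import lowerings

# Keep a handle to the stock lowering so the fallback does not recurse
fallback_add = lowerings[torch.ops.aten.add.Tensor]


def cutedsl_add_lowering(a: TensorBox, b: TensorBox) -> TensorBox:
    choices = []
    error = add_template.maybe_append_choice(  # add_template: a CuteDSLTemplate as above
        choices,
        input_nodes=[a, b],
        layout=a.get_layout(),
        THREADS_PER_BLOCK=256,
    )
    if error or not choices:
        return fallback_add(a, b)
    # A single choice can be used directly; pass several variants to
    # autotune_select_algorithm(...) to pick the fastest one instead.
    return choices[0].output_node()


# Route aten.add through the CuteDSL template while compiling
with patch.dict(lowerings, {torch.ops.aten.add.Tensor: cutedsl_add_lowering}):
    compiled = torch.compile(lambda x, y: x + y, backend="inductor")
    compiled(torch.randn(128, 4, device="cuda"), torch.randn(128, 4, device="cuda"))
```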
## Architecture

- **[CuteDSLTemplate](cutedsl_template.py#L39)**: Template definition and registration. Generates ChoiceCallers for autotuning.
- **[CuteDSLTemplateKernel](cutedsl_kernel.py#L61)**: Handles code generation, provides template hooks (`def_kernel`), and manages kernel arguments.
- **[CuteDSLScheduling](cutedsl_scheduling.py#L28)**: Integrates with Inductor's scheduler and handles kernel compilation via [`async_compile.cutedsl()`](../../async_compile.py#L756).
- **[CuteDSLTemplateBuffer](../../ir.py)**: IR node representing a CuteDSL template operation in the graph.

### Compilation Process

CuteDSL requires source files for compilation (it cannot compile from strings directly). The process:

1. **[CuteDSLScheduling](cutedsl_scheduling.py#L59)** generates the kernel code string and calls [`async_compile.cutedsl()`](../../async_compile.py#L756)
2. **[async_compile.cutedsl()](../../async_compile.py#L756)** uses [`PyCodeCache.write()`](../../codecache.py) to write the source to a temporary `.py` file
3. **[PyCodeCache](../../codecache.py)** loads the module from disk, enabling CUTLASS compilation
4. The compiled kernel is wrapped in **[CuteDSLKernelWrapper](cutedsl_kernel.py#L22)** to provide a `.run()` interface
5. The generated Python file is cached via PyCodeCache, but CUTLASS compilation runs every time (no kernel-level caching yet)
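In code, the path above amounts to roughly the following (a simplified sketch; the real `async_compile.cutedsl()` also reports missing entry points and can offload the task to a compile-thread pool):

```python
from torch._inductor.codecache import PyCodeCache
from torch._inductor.codegen.cutedsl.cutedsl_kernel import (
    CuteDSLKernelWrapper,
    MAIN_SUFFIX,
)


def compile_cutedsl(kernel_name: str, source_code: str) -> CuteDSLKernelWrapper:
    # CuteDSL needs a real file on disk, so persist the generated module first
    key, path = PyCodeCache.write(source_code)
    mod = PyCodeCache.load_by_key_path(key, path)
    # The def_kernel hook emitted the entry point as `<kernel_name>_main`
    main_func = getattr(mod, f"{kernel_name}_{MAIN_SUFFIX}")
    return CuteDSLKernelWrapper(main_func, kernel_path=path)
```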
**Debug tip**: Use `TORCH_LOGS="kernel_code"` to see the generated kernel source and file path during compilation.
## Writing Templates

Templates use Jinja2 syntax with these available hooks:

- `{{kernel_name}}` - Unique kernel identifier
- `{{def_kernel(args...)}}` - Generates the kernel function signature and argument handling
- `{{input_nodes}}` - List of input buffers
- `{{output_node}}` - Output buffer
- `{{gen_defines()}}` - Generates autotunable parameter definitions with proper CuteDSL typing
## Autotunable Parameters

CuteDSL templates support autotunable parameters similar to Triton's `tl.constexpr` system:

```python
template_source = r"""
{{gen_defines()}}

@cute.kernel
def {{kernel_name}}_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor):
    threads_per_block = THREADS_PER_BLOCK  # Uses autotuned value
    block_size = BLOCK_SIZE
    # ... kernel implementation
"""

# Pass parameters when generating template choices
template.maybe_append_choice(
    choices,
    input_nodes=[a, b],
    layout=layout,
    THREADS_PER_BLOCK=256,  # cutlass.Constexpr = 256
    BLOCK_SIZE=128,         # cutlass.Constexpr = 128
    SCALE_FACTOR=1.5,       # cutlass.Constexpr = 1.5
)
```
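Each keyword argument is emitted as a `cutlass.Constexpr` assignment at the top of the rendered kernel, so for the call above `{{gen_defines()}}` expands to:

```python
THREADS_PER_BLOCK: cutlass.Constexpr = 256
BLOCK_SIZE: cutlass.Constexpr = 128
SCALE_FACTOR: cutlass.Constexpr = 1.5
```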
Templates must:
1. Define a `@cute.kernel` decorated function
2. Use `{{def_kernel()}}` to create the entry point
3. Return the output tensor
4. Use `{{gen_defines()}}` for autotunable parameters

See [test_cutedsl_template.py](../../../../test/inductor/test_cutedsl_template.py) for complete examples.

## Current Limitations / TODOs

- **No fusion support**: `can_fuse_vertical` and `can_fuse_horizontal` return `False`
- **Subgraph management**: Bodies and masks are not fully implemented
- **File-based compilation**: Requires writing to disk (uses PyCodeCache)
- **Missing epilogue/prologue**: No support for fused operations yet
- **Fixed kernel suffix**: Uses a hardcoded `_main` suffix
- **No CUTLASS kernel caching**: Only PyCodeCache caching works; CUTLASS compilation runs every time (major perf issue)

Note: Requires the CUTLASS Python package (`pip install nvidia-cutlass`)
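A quick way to check that the package is importable (the same guard the tests use):

```python
try:
    import cutlass  # noqa: F401
    import cutlass.cute as cute  # noqa: F401

    HAS_CUTLASS = True
except ImportError:
    HAS_CUTLASS = False
```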
torch/_inductor/codegen/cutedsl/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
# mypy: allow-untyped-defs
from .cutedsl_template import CuteDSLTemplate, CuteDSLTemplateCaller


__all__ = [
    "CuteDSLTemplate",
    "CuteDSLTemplateCaller",
]
torch/_inductor/codegen/cutedsl/cutedsl_kernel.py (new file, 228 lines)
@@ -0,0 +1,228 @@
# mypy: allow-untyped-defs
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch._inductor.codegen.common import IndentedBuffer, Kernel
|
||||
from torch._inductor.ir import Buffer
|
||||
from torch._inductor.select_algorithm import PartialRender
|
||||
from torch._inductor.utils import OrderedSet
|
||||
from torch._inductor.virtualized import V
|
||||
|
||||
|
||||
# TODO: we mark the 'main' entry-point kernel with this suffix. We have three of these; we should probably just auto-generate this.
|
||||
MAIN_SUFFIX = "main"
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code")
|
||||
|
||||
|
||||
class CuteDSLKernelWrapper:
|
||||
"""Wrapper to provide .run() interface for CuteDSL kernels"""
|
||||
|
||||
def __init__(
|
||||
self, kernel_fn: Callable[..., Any], kernel_path: Optional[str] = None
|
||||
):
|
||||
self.kernel_fn = kernel_fn
|
||||
self.kernel_path = kernel_path
|
||||
kernel_code_log.info("CuteDSL kernel path: %s", kernel_path)
|
||||
|
||||
def run(self, *args, stream=None, **kwargs):
|
||||
"""
|
||||
Execute the CuteDSL kernel.
|
||||
|
||||
Args:
|
||||
*args: Arguments to pass to the kernel function
|
||||
stream: TODO: CUDA stream (handled internally by CuteDSL, so ignored)
|
||||
**kwargs: Additional keyword arguments for the kernel
|
||||
|
||||
Returns:
|
||||
Result of the kernel execution
|
||||
"""
|
||||
return self.kernel_fn(*args, **kwargs)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CuteDSLSubgraphInfo:
|
||||
"""Minimal subgraph info for CuteDSL kernels."""
|
||||
|
||||
body: IndentedBuffer
|
||||
template_mask: Optional[str] = None
|
||||
template_out: Optional[str] = None
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
field.name: getattr(self, field.name) for field in dataclasses.fields(self)
|
||||
}
|
||||
|
||||
|
||||
class CuteDSLTemplateKernel(Kernel):
|
||||
"""
|
||||
Template kernel implementation for CuteDSL (CUTLASS Python DSL).
|
||||
Handles code generation and argument management for CuteDSL CUDA kernels.
|
||||
Provides CuteDSL-specific functionality for tensor conversion and kernel configuration.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_name: str,
|
||||
input_nodes: list[Buffer],
|
||||
output_node: Buffer,
|
||||
) -> None:
|
||||
# Call parent Kernel constructor
|
||||
super().__init__()
|
||||
self.kernel_name = kernel_name
|
||||
self.input_nodes = input_nodes
|
||||
self.output_node = output_node
|
||||
|
||||
# TODO Subgraph management for template processing
|
||||
self.subgraph_bodies: dict[str, CuteDSLSubgraphInfo] = {}
|
||||
|
||||
# Template attributes
|
||||
self.body: IndentedBuffer = IndentedBuffer()
|
||||
self.template_mask: Optional[str] = None
|
||||
self.template_out: Optional[str] = None
|
||||
self.template_indices: Optional[list[Any]] = None
|
||||
self.render_hooks: dict[str, Any] = {}
|
||||
|
||||
# TODO Additional attributes needed by template system
|
||||
self.prologue_fused_inputs: OrderedSet[str] = OrderedSet()
|
||||
self.prologue_fused_inputs_preserve_zero: OrderedSet[str] = OrderedSet()
|
||||
self.named_input_nodes: dict[str, Buffer] = {}
|
||||
|
||||
# Create named input nodes mapping
|
||||
for i, input_node in enumerate(input_nodes):
|
||||
node_name = getattr(input_node, "name", f"input_{i}")
|
||||
self.named_input_nodes[node_name] = input_node
|
||||
|
||||
def gen_imports(self) -> str:
|
||||
"""Generate common imports for CuteDSL templates."""
|
||||
imports = IndentedBuffer()
|
||||
imports.splice(
|
||||
"""
|
||||
import torch
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
"""
|
||||
)
|
||||
return imports.getvalue()
|
||||
|
||||
def gen_defines(self, **kwargs) -> str:
|
||||
"""Generate CuteDSL parameter definitions from kwargs, similar to Triton's gen_defines."""
|
||||
params = IndentedBuffer()
|
||||
for name, val in kwargs.items():
|
||||
params.writeline(f"{name}: cutlass.Constexpr = {val}")
|
||||
return params.getvalue()
|
||||
|
||||
def render(self, template, **kwargs):
|
||||
"""Render the kernel using the template, returning PartialRender object with hooks."""
|
||||
# Available {{}} hooks for jinja rendering
|
||||
template_env = {
|
||||
"def_kernel": self.def_kernel,
|
||||
"gen_defines": lambda: self.gen_defines(**kwargs),
|
||||
}
|
||||
|
||||
# Render the template with the environment and provided kwargs
|
||||
rendered_code = template.render(
|
||||
kernel_name=self.kernel_name,
|
||||
input_nodes=self.input_nodes,
|
||||
output_node=self.output_node,
|
||||
**template_env,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Always prepend the common imports
|
||||
imports = self.gen_imports()
|
||||
full_code = imports + rendered_code
|
||||
|
||||
return PartialRender(full_code, self.render_hooks)
|
||||
|
||||
def __enter__(self):
|
||||
"""TODO: Context manager entry - doesn't set anything yet"""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""TODO: Context manager exit - doesn't set anything yet"""
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_subgraph_body(self, body_name: str):
|
||||
"""Set the active subgraph body for template processing."""
|
||||
assert all(
|
||||
hasattr(self, field.name)
|
||||
for field in dataclasses.fields(CuteDSLSubgraphInfo)
|
||||
)
|
||||
old_state = {
|
||||
key.name: getattr(self, key.name)
|
||||
for key in dataclasses.fields(CuteDSLSubgraphInfo)
|
||||
}
|
||||
|
||||
# Auto-create subgraph if it doesn't exist (for kernels without epilogue fusion)
|
||||
if body_name not in self.subgraph_bodies:
|
||||
self.subgraph_bodies[body_name] = CuteDSLSubgraphInfo(
|
||||
body=IndentedBuffer(),
|
||||
template_mask=None,
|
||||
template_out=None,
|
||||
)
|
||||
|
||||
subgraph = self.subgraph_bodies[body_name]
|
||||
for key, value in subgraph.to_dict().items():
|
||||
setattr(self, key, value)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Save current state back to subgraph
|
||||
self.subgraph_bodies[body_name] = CuteDSLSubgraphInfo(
|
||||
**{
|
||||
key.name: getattr(self, key.name)
|
||||
for key in dataclasses.fields(CuteDSLSubgraphInfo)
|
||||
}
|
||||
)
|
||||
# Restore old state
|
||||
for key, value in old_state.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def create_subgraph_body(self, body_name: str):
|
||||
"""Create a new subgraph body for template processing."""
|
||||
assert body_name not in self.subgraph_bodies, (
|
||||
f"Subgraph body '{body_name}' already exists"
|
||||
)
|
||||
self.subgraph_bodies[body_name] = CuteDSLSubgraphInfo(
|
||||
body=IndentedBuffer(),
|
||||
template_mask=None,
|
||||
template_out=None,
|
||||
)
|
||||
with self.set_subgraph_body(body_name):
|
||||
yield
|
||||
|
||||
def def_kernel(self, *argnames):
|
||||
"""Define kernel function signature for CuteDSL templates."""
|
||||
# Populate all the kernel args
|
||||
for i, input_node in enumerate(self.input_nodes):
|
||||
self.args.input(input_node.get_name())
|
||||
|
||||
if self.output_node:
|
||||
self.args.output(self.output_node.get_name())
|
||||
|
||||
def hook():
|
||||
code = IndentedBuffer()
|
||||
code.writeline(f"# Kernel function signature: {self.kernel_name}")
|
||||
code.writeline(
|
||||
f"def {self.kernel_name}_{MAIN_SUFFIX}({', '.join(argnames)}):"
|
||||
)
|
||||
return code.getvalue()
|
||||
|
||||
assert "<DEF_KERNEL>" not in self.render_hooks
|
||||
self.render_hooks["<DEF_KERNEL>"] = hook
|
||||
return "<DEF_KERNEL>"
|
||||
|
||||
def call_kernel(self, name: str, node=None):
|
||||
"""Call the kernel function. Simplified version of TritonTemplateKernel.call_kernel."""
|
||||
wrapper = V.graph.wrapper_code
|
||||
_, call_args, _, arg_types = self.args.python_argdefs()
|
||||
# TODO: the `triton` flag should really be swapped with a `python` flag
|
||||
wrapper.generate_kernel_call(name, call_args, triton=True, arg_types=arg_types)
|
||||
torch/_inductor/codegen/cutedsl/cutedsl_scheduling.py (new file, 141 lines)
@@ -0,0 +1,141 @@
# mypy: allow-untyped-defs
|
||||
import hashlib
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from torch._inductor.utils import Placeholder
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from ... import config
|
||||
from ...codecache import code_hash, get_path
|
||||
from ...ir import CuteDSLTemplateBuffer
|
||||
from ...scheduler import (
|
||||
BaseSchedulerNode,
|
||||
BaseScheduling,
|
||||
FusedSchedulerNode,
|
||||
SchedulerNode,
|
||||
)
|
||||
from ...select_algorithm import PartialRender
|
||||
from ...utils import get_fused_kernel_name, get_kernel_metadata
|
||||
from ...virtualized import V
|
||||
from ..common import BackendFeature, IndentedBuffer
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CuteDSLScheduling(BaseScheduling):
|
||||
"""
|
||||
Scheduling implementation for CuteDSL (CUTLASS Python DSL) kernels.
|
||||
This class is intended to be used in combination with other schedulers,
|
||||
and delegated to by CUDACombinedScheduling.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_backend_features(cls, device) -> OrderedSet[BackendFeature]:
|
||||
return OrderedSet()
|
||||
|
||||
@staticmethod
|
||||
def is_cutedsl_template(node: BaseSchedulerNode) -> bool:
|
||||
"""Check if a node is a CuteDSL template."""
|
||||
return isinstance(node, SchedulerNode) and isinstance(
|
||||
node.node, CuteDSLTemplateBuffer
|
||||
)
|
||||
|
||||
def is_cutedsl_fused_template(self, node: BaseSchedulerNode) -> bool:
|
||||
"""Check if a node is a fused CuteDSL template."""
|
||||
return isinstance(node, FusedSchedulerNode) and self.is_cutedsl_template(node)
|
||||
|
||||
def can_fuse_vertical(
|
||||
self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
|
||||
) -> bool:
|
||||
"""
|
||||
TODO CuteDSL doesn't support vertical fusion yet.
|
||||
This could be extended in the future for epilogue fusion.
|
||||
"""
|
||||
return False
|
||||
|
||||
def define_kernel(self, src_code_str: str, node_schedule) -> str:
|
||||
"""Produce the kernel string
|
||||
Args:
|
||||
src_code_str: The finalized kernel code string
|
||||
node_schedule: List of nodes in the schedule
|
||||
|
||||
Note:
|
||||
This is a little weird since async_compile.cutedsl() has to write the string to
|
||||
a file in order to cute compile it. Feels bad to have two...
|
||||
"""
|
||||
wrapper = V.graph.wrapper_code
|
||||
|
||||
# Use the string as the key for caching
|
||||
if src_code_str in wrapper.src_to_kernel:
|
||||
kernel_name = wrapper.src_to_kernel[src_code_str]
|
||||
else:
|
||||
fused_name = (
|
||||
get_fused_kernel_name(node_schedule, config.triton.descriptive_names)
|
||||
if config.triton.descriptive_names
|
||||
else ""
|
||||
)
|
||||
|
||||
kernel_hash = hashlib.sha256(src_code_str.encode("utf-8")).hexdigest()[:8]
|
||||
if fused_name == "fused":
|
||||
kernel_name = f"cutedsl_{kernel_hash}"
|
||||
else:
|
||||
kernel_name = f"cutedsl_{fused_name}_{kernel_hash}"
|
||||
wrapper.src_to_kernel[src_code_str] = kernel_name
|
||||
src_code_str = src_code_str.replace(
|
||||
str(Placeholder.KERNEL_NAME), kernel_name
|
||||
)
|
||||
|
||||
_, _, kernel_path = get_path(code_hash(src_code_str), "py")
|
||||
|
||||
compile_wrapper = IndentedBuffer()
|
||||
compile_wrapper.writeline(f"async_compile.cutedsl({kernel_name!r}, r'''")
|
||||
compile_wrapper.splice(src_code_str, strip=True)
|
||||
compile_wrapper.writeline("''')")
|
||||
|
||||
metadata_comment = f"# kernel path: {kernel_path}"
|
||||
origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
|
||||
metadata_comment += "\n" + origins + "\n" + detailed_origins
|
||||
wrapper.define_kernel(
|
||||
kernel_name, compile_wrapper.getvalue(), metadata_comment
|
||||
)
|
||||
return kernel_name
|
||||
|
||||
def codegen_template(
|
||||
self,
|
||||
template_node: BaseSchedulerNode,
|
||||
epilogue_nodes: Sequence[BaseSchedulerNode],
|
||||
prologue_nodes: Sequence[BaseSchedulerNode],
|
||||
):
|
||||
"""
|
||||
Codegen a CuteDSL template. Currently doesn't support fusion.
|
||||
"""
|
||||
assert self.is_cutedsl_template(template_node), (
|
||||
"Template node passed to CuteDSLScheduling.codegen_template must be a "
|
||||
"SchedulerNode that wraps a CuteDSLTemplateBuffer"
|
||||
)
|
||||
# TODO remove when supported
|
||||
assert not epilogue_nodes, "CuteDSL doesn't support epilogue fusion yet"
|
||||
assert not prologue_nodes, "CuteDSL doesn't support prologue fusion yet"
|
||||
|
||||
template_node = cast(SchedulerNode, template_node)
|
||||
ctb: CuteDSLTemplateBuffer = cast(CuteDSLTemplateBuffer, template_node.node)
|
||||
|
||||
kernel, render = ctb.make_kernel_render(ctb) # type: ignore[misc]
|
||||
with kernel:
|
||||
template_node.mark_run()
|
||||
src_code = render()
|
||||
# Finalize PartialRender if needed
|
||||
if isinstance(src_code, PartialRender):
|
||||
src_code_str = src_code.finalize_all()
|
||||
else:
|
||||
src_code_str = src_code
|
||||
|
||||
with V.set_kernel_handler(kernel):
|
||||
node_schedule = [template_node]
|
||||
kernel_name = self.define_kernel(src_code_str, node_schedule)
|
||||
kernel.call_kernel(kernel_name, ctb)
|
||||
V.graph.removed_buffers |= kernel.removed_buffers
|
||||
self.free_buffers_in_scheduler()
|
||||
torch/_inductor/codegen/cutedsl/cutedsl_template.py (new file, 228 lines)
@@ -0,0 +1,228 @@
# mypy: allow-untyped-defs
|
||||
import functools
|
||||
import itertools
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._inductor.codecache import PyCodeCache
|
||||
from torch._inductor.ir import ShapeAsConstantBuffer
|
||||
from torch._inductor.select_algorithm import PartialRender
|
||||
from torch._inductor.utils import Placeholder
|
||||
from torch._logging import getArtifactLogger
|
||||
|
||||
from ...autotune_process import BenchmarkRequest, GPUDeviceBenchmarkMixin, TensorMeta
|
||||
from ...ir import Buffer, ChoiceCaller, CuteDSLTemplateBuffer, Layout, TensorBox
|
||||
from ..common import KernelTemplate
|
||||
from .cutedsl_kernel import CuteDSLTemplateKernel
|
||||
|
||||
|
||||
log = getArtifactLogger(__name__, "output_code")
|
||||
|
||||
|
||||
class CuteDSLBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
|
||||
"""Benchmark request for CuteDSL (CUTLASS Python DSL) kernels."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_name: str,
|
||||
input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
|
||||
output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
|
||||
extra_args: tuple[Any, ...],
|
||||
source_code: PartialRender,
|
||||
) -> None:
|
||||
super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
|
||||
|
||||
finalized_code = source_code.finalize_all()
|
||||
self.module_cache_key, self.module_path = PyCodeCache.write(finalized_code)
|
||||
|
||||
def make_run_fn(
|
||||
self, *input_tensors: torch.Tensor, out: torch.Tensor
|
||||
) -> Callable[[], None]:
|
||||
"""
|
||||
Create a function to run the CuteDSL kernel with the given input and output tensors.
|
||||
Similar to TritonBenchmarkRequest.make_run_fn but for CuteDSL kernels.
|
||||
"""
|
||||
mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
|
||||
|
||||
# Logic replicated from async_compile.cutedsl()
|
||||
from .cutedsl_kernel import MAIN_SUFFIX
|
||||
|
||||
main_func_name = f"{self.kernel_name}_{MAIN_SUFFIX}"
|
||||
|
||||
if not hasattr(mod, main_func_name):
|
||||
available = [name for name in dir(mod) if callable(getattr(mod, name))]
|
||||
raise RuntimeError(
|
||||
f"Could not find CuteDSL main kernel function '{main_func_name}'. Available callables: {available}"
|
||||
)
|
||||
|
||||
kernel_func = getattr(mod, main_func_name)
|
||||
|
||||
def run_kernel():
|
||||
return kernel_func(*input_tensors, out)
|
||||
|
||||
return run_kernel
|
||||
|
||||
def cleanup_run_fn(self) -> None:
|
||||
"""Clean up any resources used by the kernel."""
|
||||
|
||||
|
||||
class CuteDSLTemplate(KernelTemplate):
|
||||
"""Template for generating CuteDSL (CUTLASS Python DSL) kernels."""
|
||||
|
||||
kernel_type: type[Any] = CuteDSLTemplateKernel
|
||||
index_counter = itertools.count()
|
||||
all_templates: dict[str, "CuteDSLTemplate"] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
source: str,
|
||||
subgraph_fn: Optional[Any] = None,
|
||||
mask_fn: Optional[Any] = None,
|
||||
) -> None:
|
||||
super().__init__(name)
|
||||
self.source = source
|
||||
self.subgraph_fn = subgraph_fn
|
||||
self.mask_fn = mask_fn
|
||||
self.template = CuteDSLTemplate._template_from_string(source)
|
||||
assert name not in self.all_templates, f"duplicate template name, {name}"
|
||||
CuteDSLTemplate.all_templates[name] = self
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def _template_from_string(source: str) -> Any:
|
||||
return KernelTemplate._template_from_string(source)
|
||||
|
||||
def maybe_append_choice(
|
||||
self, choices: list[Any], **kwargs: Any
|
||||
) -> Optional[NotImplementedError]:
|
||||
"""
|
||||
Maybe generates a new ChoiceCaller and appends it into existing choices.
|
||||
Returns None if success, otherwise returns the error.
|
||||
"""
|
||||
try:
|
||||
choices.append(self.generate(**kwargs))
|
||||
return None
|
||||
except NotImplementedError as e:
|
||||
log.debug("CuteDSL template choice generation failed: %s", e)
|
||||
return e
|
||||
except Exception as e:
|
||||
log.debug("CuteDSL template choice generation error: %s", e)
|
||||
return NotImplementedError(f"CuteDSL template failed: {e}")
|
||||
|
||||
def generate(self, **kwargs: Any) -> ChoiceCaller:
|
||||
"""Generate the CuteDSL kernel caller."""
|
||||
input_nodes = kwargs.pop("input_nodes")
|
||||
layout = kwargs.pop("layout")
|
||||
|
||||
kernel_name = f"cutedsl_{self.name}_{next(self.index_counter)}"
|
||||
|
||||
if self.template is None:
|
||||
raise RuntimeError("Template compilation failed (Jinja2 required)")
|
||||
|
||||
self.output_node: Buffer = Buffer(name="buf_out", layout=layout)
|
||||
|
||||
kernel = self.kernel_type(
|
||||
kernel_name=kernel_name,
|
||||
input_nodes=input_nodes,
|
||||
output_node=self.output_node,
|
||||
)
|
||||
|
||||
code = kernel.render(self.template, **kwargs)
|
||||
|
||||
log.debug("Generated CuteDSL Code:\n%s", code)
|
||||
|
||||
bmreq = CuteDSLBenchmarkRequest(
|
||||
kernel_name=kernel_name,
|
||||
input_tensor_meta=TensorMeta.from_irnodes(input_nodes),
|
||||
output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
|
||||
extra_args=tuple(),
|
||||
source_code=code,
|
||||
)
|
||||
|
||||
def make_kernel_render(out_node, hint_override: Optional[int] = None):
|
||||
render_kernel = self.kernel_type(
|
||||
kernel_name=str(Placeholder.KERNEL_NAME),
|
||||
input_nodes=input_nodes,
|
||||
output_node=out_node,
|
||||
)
|
||||
|
||||
def render():
|
||||
return render_kernel.render(self.template, **kwargs)
|
||||
|
||||
return render_kernel, render
|
||||
|
||||
return CuteDSLTemplateCaller(
|
||||
name=kernel_name,
|
||||
input_nodes=input_nodes,
|
||||
layout=layout,
|
||||
make_kernel_render=make_kernel_render,
|
||||
bmreq=bmreq,
|
||||
template=self,
|
||||
)
|
||||
|
||||
|
||||
class CuteDSLTemplateCaller(ChoiceCaller):
|
||||
"""Caller for CuteDSL templates that integrates with the autotuning system."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
input_nodes: list[Buffer],
|
||||
layout: Layout,
|
||||
make_kernel_render: Any,
|
||||
bmreq: CuteDSLBenchmarkRequest,
|
||||
template: "CuteDSLTemplate",
|
||||
):
|
||||
super().__init__(
|
||||
name=name,
|
||||
input_nodes=input_nodes,
|
||||
layout=layout,
|
||||
description=f"CuteDSL template {name}",
|
||||
)
|
||||
self.make_kernel_render = make_kernel_render
|
||||
self.bmreq = bmreq
|
||||
self.template = template
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"CuteDSLTemplateCaller({self.name})"
|
||||
|
||||
def benchmark(self, *args, out) -> float:
|
||||
"""Benchmark the kernel execution."""
|
||||
return self.bmreq.benchmark(*args, out=out)
|
||||
|
||||
def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]:
|
||||
"""Create the output node for this template choice."""
|
||||
return TensorBox.create(
|
||||
CuteDSLTemplateBuffer(
|
||||
layout=self.layout,
|
||||
inputs=self.input_nodes,
|
||||
make_kernel_render=self.make_kernel_render,
|
||||
template=self.template,
|
||||
)
|
||||
)
|
||||
|
||||
def call_name(self) -> str:
|
||||
"""Return the kernel call name."""
|
||||
return self.name
|
||||
|
||||
def to_callable(self) -> Any:
|
||||
"""Return callable that can execute this kernel."""
|
||||
return self.make_kernel_render
|
||||
|
||||
def hash_key(self) -> str:
|
||||
"""Return unique hash key for this choice."""
|
||||
return "-".join(
|
||||
[
|
||||
self.name.rsplit("_", 1)[0],
|
||||
self.bmreq.module_cache_key,
|
||||
]
|
||||
)
|
||||
|
||||
def info_dict(self) -> dict[str, Any]:
|
||||
"""Return information about this kernel."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"backend": "CuteDSL",
|
||||
"template": self.template.name,
|
||||
}
|
||||
@@ -5094,6 +5094,23 @@ class CppTemplateBuffer(TemplateBuffer):
|
||||
return super().get_layout()
|
||||
|
||||
|
||||
class CuteDSLTemplateBuffer(TemplateBuffer):
|
||||
"""
|
||||
Buffer for CuteDSL (CUTLASS Python DSL) template kernels.
|
||||
Similar to other template buffers but specialized for CuteDSL operations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layout: Layout,
|
||||
inputs: Sequence[IRNode],
|
||||
make_kernel_render: Callable[_P, _T],
|
||||
template: Any,
|
||||
) -> None:
|
||||
super().__init__(layout, inputs, make_kernel_render)
|
||||
self.template = template
|
||||
|
||||
|
||||
def is_node_sequence(
|
||||
nodes: Sequence[Union[IRNode, Sequence[IRNode]]],
|
||||
) -> TypeIs[Sequence[IRNode]]:
|
||||
|
||||
@@ -39,6 +39,10 @@ from .common import (
|
||||
)
|
||||
from .flex_cpu import lower_cpu
|
||||
from .flex_decoding import _use_flex_decoding, create_flex_decoding_kernel
|
||||
from .flex_flash_attention import (
|
||||
_use_flex_flash_attention,
|
||||
create_flex_flash_attention_kernel,
|
||||
)
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -437,6 +441,19 @@ def flex_attention(
|
||||
score_mod_other_buffers,
|
||||
mask_mod_other_buffers,
|
||||
)
|
||||
if _use_flex_flash_attention(subgraph, mask_graph, kernel_options):
|
||||
return create_flex_flash_attention_kernel(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
block_mask,
|
||||
scale,
|
||||
kernel_options,
|
||||
subgraph_buffer,
|
||||
mask_graph_buffer,
|
||||
score_mod_other_buffers,
|
||||
mask_mod_other_buffers,
|
||||
)
|
||||
|
||||
(
|
||||
query,
|
||||
|
||||
torch/_inductor/kernel/flex/flex_flash_attention.py (new file, 126 lines)
@@ -0,0 +1,126 @@
# mypy: allow-untyped-defs
|
||||
"""Call into flash-attention 4 for flexattention"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from ...ir import FallbackKernel, ShapeAsConstantBuffer, Subgraph, TensorBox
|
||||
from .common import SubgraphResults
|
||||
|
||||
|
||||
aten = torch.ops.aten
|
||||
prims = torch.ops.prims
|
||||
|
||||
try:
|
||||
from flash_attn.cute import flash_attn_func # type: ignore[import-not-found]
|
||||
|
||||
CUTE_AVAILABLE = True
|
||||
except ImportError:
|
||||
flash_attn_func = None
|
||||
CUTE_AVAILABLE = False
|
||||
|
||||
|
||||
def is_trivial_graph(graph_module: GraphModule, is_score_graph: bool):
|
||||
"""Check if the flex graphs are trivial"""
|
||||
graph = graph_module.graph
|
||||
nodes = list(graph.nodes)
|
||||
# Check if it's just placeholder -> output
|
||||
placeholders = [n for n in nodes if n.op == "placeholder"]
|
||||
output = [n for n in nodes if n.op == "output"]
|
||||
assert len(output) == 1, "Got graph w/ multiple outputs"
|
||||
output_val = output[0].args[0]
|
||||
if is_score_graph:
|
||||
return len(placeholders) == 5 and output_val == placeholders[0]
|
||||
# mask mod graph is empty if we have 4 inputs and full_default output
|
||||
return len(placeholders) == 4 and output_val.target == torch.ops.aten.full.default
|
||||
|
||||
|
||||
def _use_flex_flash_attention(
|
||||
subgraph: Subgraph, mask_graph: Subgraph, kernel_options: dict[str, Any]
|
||||
) -> bool:
|
||||
"""Determine if we can use flex flash attention for the given inputs."""
|
||||
if not CUTE_AVAILABLE:
|
||||
return False
|
||||
if kernel_options.get("disable_flash", False):
|
||||
return False
|
||||
if is_trivial_graph(subgraph.graph_module, True) and is_trivial_graph(
|
||||
mask_graph.graph_module, False
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@torch.library.custom_op("flex_flash_attn::flash_attn_fwd", mutates_args=())
|
||||
def flash_attention_forward_kernel(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
causal: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Minimal flash attention forward kernel using CUTE implementation."""
|
||||
if not CUTE_AVAILABLE:
|
||||
raise RuntimeError("CUTE flash attention not available")
|
||||
assert flash_attn_func is not None
|
||||
|
||||
q_transposed = query.transpose(1, 2)
|
||||
k_transposed = key.transpose(1, 2)
|
||||
v_transposed = value.transpose(1, 2)
|
||||
|
||||
output, lse = flash_attn_func(
|
||||
q_transposed,
|
||||
k_transposed,
|
||||
v_transposed,
|
||||
softmax_scale=scale,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
return output.transpose(1, 2), lse
|
||||
|
||||
|
||||
@torch.library.register_fake("flex_flash_attn::flash_attn_fwd") # type: ignore[misc]
|
||||
def flex_flash_attn_fwd_fake(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
causal: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Fake implementation for the custom op."""
|
||||
batch_size, num_heads, seqlen_q, head_dim = query.shape
|
||||
|
||||
out = query.new_empty(batch_size, seqlen_q, num_heads, head_dim).transpose(1, 2)
|
||||
lse = query.new_empty(batch_size, num_heads, seqlen_q, dtype=torch.float32)
|
||||
|
||||
return out, lse
|
||||
|
||||
|
||||
def create_flex_flash_attention_kernel(
|
||||
query: TensorBox,
|
||||
key: TensorBox,
|
||||
value: TensorBox,
|
||||
block_mask: tuple[Any, ...],
|
||||
scale: float,
|
||||
kernel_options: dict[str, Any],
|
||||
subgraph_buffer: SubgraphResults,
|
||||
mask_graph_buffer: SubgraphResults,
|
||||
score_mod_other_buffers: list[TensorBox],
|
||||
mask_mod_other_buffers: list[TensorBox],
|
||||
) -> tuple[TensorBox | ShapeAsConstantBuffer, TensorBox | ShapeAsConstantBuffer]:
|
||||
"""Create a flex flash attention kernel."""
|
||||
if not CUTE_AVAILABLE:
|
||||
raise RuntimeError("CUTE flash attention not available")
|
||||
|
||||
outputs = FallbackKernel.create(
|
||||
torch.ops.flex_flash_attn.flash_attn_fwd.default,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
scale=scale,
|
||||
causal=False,
|
||||
)
|
||||
assert isinstance(outputs, (tuple, list))
|
||||
return TensorBox.create(outputs[0]), TensorBox.create(outputs[1])
|
||||
@@ -198,6 +198,9 @@ class FlexKernelOptions(TypedDict, total=False):
|
||||
waves_per_eu: NotRequired[int]
|
||||
"""ROCm-specific waves per execution unit."""
|
||||
|
||||
disable_flash: NotRequired[bool]
|
||||
""" If True, we will not attempt to run the cute-dsl flash attention kernel"""
|
||||
|
||||
|
||||
class _ModificationType(Enum):
|
||||
"""Enum for the type of modification function.
|
||||
|
||||