[Cutlass] Implement memory planning for EVT (#153177)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153177
Approved by: https://github.com/henrylhtsang
ghstack dependencies: #153196, #150907
This commit is contained in:
Michael Lazos
2025-05-08 15:08:26 -07:00
committed by PyTorch MergeBot
parent a3154ca34a
commit 9fa07340fd
5 changed files with 98 additions and 47 deletions

View File

@ -2416,10 +2416,14 @@ class CSEProxy(DefaultHandler):
"""
from ..bounds import ValueRangeAnalysis
from ..select_algorithm import TritonTemplateKernel
from .cuda.cuda_kernel import CUDATemplateKernel
if isinstance(V.kernel, TritonTemplateKernel):
return ValueRanges.unknown()
if isinstance(V.kernel, CUDATemplateKernel):
return ValueRanges.unknown()
fx_node = V.interpreter.current_node
if fx_node.target == name and self.kernel.node_to_bounds is not None:
assert isinstance(self.kernel.node_to_bounds, dict)

View File

@ -3,6 +3,10 @@ import logging
from collections.abc import Sequence
from typing import cast
from torch._inductor.codegen.cuda.cutlass_python_evt import (
CutlassEVTCodegen,
MockCutlassHandler,
)
from torch.utils._ordered_set import OrderedSet
from ...._dynamo.utils import counters
@ -138,6 +142,18 @@ class CUDACPPScheduling(BaseScheduling):
with kernel:
for node in [template_node, *epilogue_nodes]:
node.mark_run()
# typically there is a codegen pass which runs after mark_run
# for this kernel we've already generated the C++ code, but we still
# need to let the kernel know about loads/stores that occur in the fused
# kernel for memory planning to properly optimize allocations
ctb.emulate_store_fn()
for node in epilogue_ir_nodes:
with V.set_ops_handler(MockCutlassHandler(V.get_ops_handler())):
assert isinstance(
node, ComputedBuffer
) # Not sure why we need to do this again
node.get_store_function()(CutlassEVTCodegen.get_index_vars(node))
src_code = render()
with V.set_kernel_handler(kernel):

View File

@ -539,6 +539,12 @@ class CUDATemplateKernel(CUDAKernel):
f"At least 1 stride should be 1. Strides: {node.get_stride()=}"
)
def store(self, name: str, index: Expr, value: Any, mode: Any = None) -> None:
"""
Mock store function for memory planning to optimize allocations properly.
"""
self.store_buffer_names.add(name)
class CUDATemplateCaller(ChoiceCaller):
"""

View File

@ -9,7 +9,7 @@ import sympy
import torch
import torch._inductor.virtualized as virtualized
from torch._inductor.ir import ComputedBuffer, Pointwise
from torch._inductor.ops_handler import DefaultHandler
from torch._inductor.ops_handler import DefaultHandler, WrapperHandler
from torch._inductor.scheduler import BaseSchedulerNode
from torch._inductor.utils import IndentedBuffer, OrderedSet
from torch._inductor.virtualized import OpsValue
@ -20,6 +20,69 @@ from ...virtualized import V
_ACCUMULATOR_ALIAS = "accum"
class CutlassEVTOpsMixIn:
@staticmethod
def _infix_bin_op(op: str, a: str, b: str) -> str:
return f"{a} {op} {b}"
@staticmethod
def _prefix_bin_op(op: str, a: str, b: str) -> str:
return f"{op}({a}, {b})"
@staticmethod
def _prefix_un_op(op: str, a: str) -> str:
return f"{op}({a})"
@staticmethod
def to_dtype(
x: str,
dtype: Any,
src_dtype: Optional[torch.dtype] = None,
use_compute_types: bool = False,
) -> str:
return x
@staticmethod
def constant(value: Any, dtype: Any) -> str:
raise NotImplementedError
@staticmethod
def mul(x0: str, x1: str) -> str:
return CutlassEVTOpsMixIn._infix_bin_op("*", x0, x1)
@staticmethod
def truediv(x0: str, x1: str) -> str:
return CutlassEVTOpsMixIn._infix_bin_op("/", x0, x1)
@staticmethod
def ge(x0: str, x1: str) -> str:
raise NotImplementedError
@staticmethod
def add(x0: str, x1: str) -> str:
return CutlassEVTOpsMixIn._infix_bin_op("+", x0, x1)
@staticmethod
def relu(x0: str) -> str:
return CutlassEVTOpsMixIn._prefix_un_op("relu", x0)
@staticmethod
def sigmoid(x0: str) -> str:
return CutlassEVTOpsMixIn._prefix_un_op("sigmoid", x0)
@staticmethod
def sub(x0: str, x1: str) -> str:
return CutlassEVTOpsMixIn._infix_bin_op("-", x0, x1)
@staticmethod
def tanh(x0: str) -> str:
return CutlassEVTOpsMixIn._prefix_un_op("tanh", x0)
class MockCutlassHandler(CutlassEVTOpsMixIn, WrapperHandler):
"""Passthrough handler for cutlass ops, used for running epilogue nodes for memory planning"""
class _AssignmentFormatter(DefaultHandler):
def __init__(self, parent_handler: "CutlassEVTCodegen"):
self.parent_handler = parent_handler
@ -39,7 +102,7 @@ class _AssignmentFormatter(DefaultHandler):
raise NotImplementedError(name)
class CutlassEVTCodegen:
class CutlassEVTCodegen(CutlassEVTOpsMixIn):
"""
Notes:
* Used by CUTLASSGemmTemplate.
@ -90,7 +153,7 @@ class CutlassEVTCodegen:
node = s_node.node
assert isinstance(node, ComputedBuffer)
with codegen.set_cur_node(node):
index_vars = CutlassEVTCodegen._get_index_vars(node)
index_vars = CutlassEVTCodegen.get_index_vars(node)
node.get_store_function()(index_vars)
return (
@ -157,52 +220,19 @@ class CutlassEVTCodegen:
self.store_name_to_value[name] = value_to_write
return None
def to_dtype(
self,
x: str,
dtype: Any,
src_dtype: Optional[torch.dtype] = None,
use_compute_types: bool = False,
) -> str:
return x
def constant(self, value: Any, dtype: Any) -> str:
raise NotImplementedError
def mul(self, x0: str, x1: str) -> str:
return self._infix_bin_op("*", x0, x1)
def truediv(self, x0: str, x1: str) -> str:
raise NotImplementedError
def ge(self, x0: str, x1: str) -> str:
raise NotImplementedError
def add(self, x0: str, x1: str) -> str:
return self._infix_bin_op("+", x0, x1)
def relu(self, x0: str) -> str:
return self._prefix_un_op("relu", x0)
def sigmoid(self, x0: str) -> str:
return self._prefix_un_op("sigmoid", x0)
def sub(self, x0: str, x1: str) -> str:
raise NotImplementedError
def _get_cur_node(self) -> ComputedBuffer:
assert self.cur_node
return self.cur_node
@staticmethod
def _get_index_vars(node: ComputedBuffer) -> Sequence[sympy.Expr]:
def get_index_vars(node: ComputedBuffer) -> Sequence[sympy.Expr]:
data = node.data
# TODO mlazos: relax this, cutlass supports reductions and other ops
assert isinstance(data, Pointwise)
return data._index(data.ranges)
def _get_current_index_vars(self) -> Sequence[sympy.Expr]:
return self._get_index_vars(self._get_cur_node())
return self.get_index_vars(self._get_cur_node())
def _check_indexing(self, name: str, index: sympy.Expr) -> None:
# We only support indexing that matches the layout today because
@ -242,12 +272,3 @@ class CutlassEVTCodegen:
def _tmp_var(self) -> str:
return f"tmp_{next(self.var_counter)}"
def _infix_bin_op(self, op: str, a: str, b: str) -> str:
return f"{a} {op} {b}"
def _prefix_bin_op(self, op: str, a: str, b: str) -> str:
return f"{op}({a}, {b})"
def _prefix_un_op(self, op: str, a: str) -> str:
return f"{op}({a})"

View File

@ -4775,6 +4775,10 @@ class CUDATemplateBuffer(TemplateBuffer):
def get_workspace_size(self): # type: ignore[no-untyped-def]
return self.workspace_size if self.workspace_size is not None else 0
def emulate_store_fn(self) -> None:
for output in self.get_outputs():
ops.store(output.get_name(), None, None)
class CppTemplateBuffer(TemplateBuffer):
def __init__(self, layout, inputs, make_kernel_render, template, choice) -> None: # type: ignore[no-untyped-def]