mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Cutlass] Implement memory planning for EVT (#153177)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153177 Approved by: https://github.com/henrylhtsang ghstack dependencies: #153196, #150907
This commit is contained in:
committed by
PyTorch MergeBot
parent
a3154ca34a
commit
9fa07340fd
@ -2416,10 +2416,14 @@ class CSEProxy(DefaultHandler):
|
||||
"""
|
||||
from ..bounds import ValueRangeAnalysis
|
||||
from ..select_algorithm import TritonTemplateKernel
|
||||
from .cuda.cuda_kernel import CUDATemplateKernel
|
||||
|
||||
if isinstance(V.kernel, TritonTemplateKernel):
|
||||
return ValueRanges.unknown()
|
||||
|
||||
if isinstance(V.kernel, CUDATemplateKernel):
|
||||
return ValueRanges.unknown()
|
||||
|
||||
fx_node = V.interpreter.current_node
|
||||
if fx_node.target == name and self.kernel.node_to_bounds is not None:
|
||||
assert isinstance(self.kernel.node_to_bounds, dict)
|
||||
|
@ -3,6 +3,10 @@ import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from torch._inductor.codegen.cuda.cutlass_python_evt import (
|
||||
CutlassEVTCodegen,
|
||||
MockCutlassHandler,
|
||||
)
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from ...._dynamo.utils import counters
|
||||
@ -138,6 +142,18 @@ class CUDACPPScheduling(BaseScheduling):
|
||||
with kernel:
|
||||
for node in [template_node, *epilogue_nodes]:
|
||||
node.mark_run()
|
||||
|
||||
# typically there is a codegen pass which runs after mark_run
|
||||
# for this kernel we've already generated the C++ code, but we still
|
||||
# need to let the kernel know about loads/stores that occur in the fused
|
||||
# kernel for memory planning to properly optimize allocations
|
||||
ctb.emulate_store_fn()
|
||||
for node in epilogue_ir_nodes:
|
||||
with V.set_ops_handler(MockCutlassHandler(V.get_ops_handler())):
|
||||
assert isinstance(
|
||||
node, ComputedBuffer
|
||||
) # Not sure why we need to do this again
|
||||
node.get_store_function()(CutlassEVTCodegen.get_index_vars(node))
|
||||
src_code = render()
|
||||
|
||||
with V.set_kernel_handler(kernel):
|
||||
|
@ -539,6 +539,12 @@ class CUDATemplateKernel(CUDAKernel):
|
||||
f"At least 1 stride should be 1. Strides: {node.get_stride()=}"
|
||||
)
|
||||
|
||||
def store(self, name: str, index: Expr, value: Any, mode: Any = None) -> None:
    """
    Mock store function for memory planning to optimize allocations properly.

    No code is generated here; the only effect is that ``name`` is added to
    ``self.store_buffer_names``.

    Args:
        name: Name of the buffer being stored to.
        index: Unused by this mock.
        value: Unused by this mock.
        mode: Unused by this mock.
    """
    # Only the buffer name is tracked; index/value/mode are intentionally ignored.
    self.store_buffer_names.add(name)
|
||||
|
||||
|
||||
class CUDATemplateCaller(ChoiceCaller):
|
||||
"""
|
||||
|
@ -9,7 +9,7 @@ import sympy
|
||||
import torch
|
||||
import torch._inductor.virtualized as virtualized
|
||||
from torch._inductor.ir import ComputedBuffer, Pointwise
|
||||
from torch._inductor.ops_handler import DefaultHandler
|
||||
from torch._inductor.ops_handler import DefaultHandler, WrapperHandler
|
||||
from torch._inductor.scheduler import BaseSchedulerNode
|
||||
from torch._inductor.utils import IndentedBuffer, OrderedSet
|
||||
from torch._inductor.virtualized import OpsValue
|
||||
@ -20,6 +20,69 @@ from ...virtualized import V
|
||||
_ACCUMULATOR_ALIAS = "accum"
|
||||
|
||||
|
||||
class CutlassEVTOpsMixIn:
    """
    String-building implementations of the ops shared by the cutlass EVT handlers.

    Every method is a ``staticmethod`` that receives already-formatted operand
    strings and returns the formatted expression string for the op. Both
    ``CutlassEVTCodegen`` and ``MockCutlassHandler`` inherit from this class.
    """

    @staticmethod
    def _infix_bin_op(op: str, a: str, b: str) -> str:
        """Format a binary op in infix form: ``a op b``."""
        return f"{a} {op} {b}"

    @staticmethod
    def _prefix_bin_op(op: str, a: str, b: str) -> str:
        """Format a binary op in call form: ``op(a, b)``."""
        return f"{op}({a}, {b})"

    @staticmethod
    def _prefix_un_op(op: str, a: str) -> str:
        """Format a unary op in call form: ``op(a)``."""
        return f"{op}({a})"

    @staticmethod
    def to_dtype(
        x: str,
        dtype: Any,
        src_dtype: Optional[torch.dtype] = None,
        use_compute_types: bool = False,
    ) -> str:
        """Return ``x`` unchanged; the dtype arguments are ignored."""
        return x

    @staticmethod
    def constant(value: Any, dtype: Any) -> str:
        """Not supported: always raises ``NotImplementedError``."""
        raise NotImplementedError

    @staticmethod
    def mul(x0: str, x1: str) -> str:
        """Return ``x0 * x1``."""
        return CutlassEVTOpsMixIn._infix_bin_op("*", x0, x1)

    @staticmethod
    def truediv(x0: str, x1: str) -> str:
        """Return ``x0 / x1``."""
        return CutlassEVTOpsMixIn._infix_bin_op("/", x0, x1)

    @staticmethod
    def ge(x0: str, x1: str) -> str:
        """Not supported: always raises ``NotImplementedError``."""
        raise NotImplementedError

    @staticmethod
    def add(x0: str, x1: str) -> str:
        """Return ``x0 + x1``."""
        return CutlassEVTOpsMixIn._infix_bin_op("+", x0, x1)

    @staticmethod
    def relu(x0: str) -> str:
        """Return ``relu(x0)``."""
        return CutlassEVTOpsMixIn._prefix_un_op("relu", x0)

    @staticmethod
    def sigmoid(x0: str) -> str:
        """Return ``sigmoid(x0)``."""
        return CutlassEVTOpsMixIn._prefix_un_op("sigmoid", x0)

    @staticmethod
    def sub(x0: str, x1: str) -> str:
        """Return ``x0 - x1``."""
        return CutlassEVTOpsMixIn._infix_bin_op("-", x0, x1)

    @staticmethod
    def tanh(x0: str) -> str:
        """Return ``tanh(x0)``."""
        return CutlassEVTOpsMixIn._prefix_un_op("tanh", x0)
||||
|
||||
|
||||
class MockCutlassHandler(CutlassEVTOpsMixIn, WrapperHandler):
    """Passthrough handler for cutlass ops, used for running epilogue nodes for memory planning.

    ``CutlassEVTOpsMixIn`` is listed before ``WrapperHandler`` so that the
    mixin's op implementations come first in the method resolution order.
    """
|
||||
|
||||
|
||||
class _AssignmentFormatter(DefaultHandler):
|
||||
def __init__(self, parent_handler: "CutlassEVTCodegen"):
|
||||
self.parent_handler = parent_handler
|
||||
@ -39,7 +102,7 @@ class _AssignmentFormatter(DefaultHandler):
|
||||
raise NotImplementedError(name)
|
||||
|
||||
|
||||
class CutlassEVTCodegen:
|
||||
class CutlassEVTCodegen(CutlassEVTOpsMixIn):
|
||||
"""
|
||||
Notes:
|
||||
* Used by CUTLASSGemmTemplate.
|
||||
@ -90,7 +153,7 @@ class CutlassEVTCodegen:
|
||||
node = s_node.node
|
||||
assert isinstance(node, ComputedBuffer)
|
||||
with codegen.set_cur_node(node):
|
||||
index_vars = CutlassEVTCodegen._get_index_vars(node)
|
||||
index_vars = CutlassEVTCodegen.get_index_vars(node)
|
||||
node.get_store_function()(index_vars)
|
||||
|
||||
return (
|
||||
@ -157,52 +220,19 @@ class CutlassEVTCodegen:
|
||||
self.store_name_to_value[name] = value_to_write
|
||||
return None
|
||||
|
||||
def to_dtype(
|
||||
self,
|
||||
x: str,
|
||||
dtype: Any,
|
||||
src_dtype: Optional[torch.dtype] = None,
|
||||
use_compute_types: bool = False,
|
||||
) -> str:
|
||||
return x
|
||||
|
||||
def constant(self, value: Any, dtype: Any) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def mul(self, x0: str, x1: str) -> str:
|
||||
return self._infix_bin_op("*", x0, x1)
|
||||
|
||||
def truediv(self, x0: str, x1: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def ge(self, x0: str, x1: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def add(self, x0: str, x1: str) -> str:
|
||||
return self._infix_bin_op("+", x0, x1)
|
||||
|
||||
def relu(self, x0: str) -> str:
|
||||
return self._prefix_un_op("relu", x0)
|
||||
|
||||
def sigmoid(self, x0: str) -> str:
|
||||
return self._prefix_un_op("sigmoid", x0)
|
||||
|
||||
def sub(self, x0: str, x1: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_cur_node(self) -> ComputedBuffer:
|
||||
assert self.cur_node
|
||||
return self.cur_node
|
||||
|
||||
@staticmethod
def get_index_vars(node: ComputedBuffer) -> Sequence[sympy.Expr]:
    """
    Return the index expressions for ``node``'s pointwise data.

    Args:
        node: Buffer whose ``data`` must be a ``Pointwise``.

    Returns:
        The result of ``data._index(data.ranges)``.

    Raises:
        AssertionError: If ``node.data`` is not a ``Pointwise``.
    """
    data = node.data
    # TODO mlazos: relax this, cutlass supports reductions and other ops
    assert isinstance(data, Pointwise)
    return data._index(data.ranges)
|
||||
|
||||
def _get_current_index_vars(self) -> Sequence[sympy.Expr]:
|
||||
return self._get_index_vars(self._get_cur_node())
|
||||
return self.get_index_vars(self._get_cur_node())
|
||||
|
||||
def _check_indexing(self, name: str, index: sympy.Expr) -> None:
|
||||
# We only support indexing that matches the layout today because
|
||||
@ -242,12 +272,3 @@ class CutlassEVTCodegen:
|
||||
|
||||
def _tmp_var(self) -> str:
|
||||
return f"tmp_{next(self.var_counter)}"
|
||||
|
||||
def _infix_bin_op(self, op: str, a: str, b: str) -> str:
|
||||
return f"{a} {op} {b}"
|
||||
|
||||
def _prefix_bin_op(self, op: str, a: str, b: str) -> str:
|
||||
return f"{op}({a}, {b})"
|
||||
|
||||
def _prefix_un_op(self, op: str, a: str) -> str:
|
||||
return f"{op}({a})"
|
||||
|
@ -4775,6 +4775,10 @@ class CUDATemplateBuffer(TemplateBuffer):
|
||||
def get_workspace_size(self):  # type: ignore[no-untyped-def]
    """Return ``self.workspace_size``, or ``0`` when it is ``None``."""
    size = self.workspace_size
    return 0 if size is None else size
|
||||
|
||||
def emulate_store_fn(self) -> None:
    """
    Issue an ``ops.store`` for every output of this buffer.

    Each call passes the output's name, with ``None`` for both the index and
    the value.
    """
    for output in self.get_outputs():
        # index and value are placeholders here; only the buffer name is passed
        ops.store(output.get_name(), None, None)
|
||||
|
||||
|
||||
class CppTemplateBuffer(TemplateBuffer):
|
||||
def __init__(self, layout, inputs, make_kernel_render, template, choice) -> None: # type: ignore[no-untyped-def]
|
||||
|
Reference in New Issue
Block a user