# mypy: allow-untyped-defs
import functools
import itertools
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union

from sympy import Expr, symbols

import torch._inductor.config as config
from torch import dtype as torch_dtype
from torch._inductor.codegen.common import get_device_op_overrides
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
from torch._inductor.scheduler import BaseSchedulerNode
from torch._inductor.utils import do_bench_using_profiling, OrderedSet, Placeholder
from torch.utils._sympy.value_ranges import ValueRanges

from .utils import DTYPE_TO_CUTLASS_TYPE


if TYPE_CHECKING:
    from .template import ArgInfo

from ...autotune_process import CUTLASSBenchmarkRequest
from ...ir import (
    Buffer,
    ChoiceCaller,
    CUTLASSTemplateBuffer,
    IRNode,
    Layout,
    PrimitiveInfoType,
    ShapeAsConstantBuffer,
    TensorBox,
)
from ...utils import sympy_product
from ...virtualized import V
from ..common import (
    CSEVariable,
    IndentedBuffer,
    Kernel,
    OpOverrides,
    WorkspaceArg,
    WorkspaceZeroMode,
)
from ..cpp_utils import CppPrinter, DTYPE_TO_CPP


if TYPE_CHECKING:
    from torch._inductor.codegen.cutlass.template import CUTLASSTemplate


log = logging.getLogger(__name__)

cexpr = CppPrinter().doprint


def _normalize_idx(index: int, total_length: int) -> int:
    return index if index >= 0 else index + total_length
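
# For illustration: _normalize_idx(-1, 3) == 2 and _normalize_idx(1, 3) == 1;
# negative indices wrap exactly like Python sequence indexing.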


ValidLayoutSymbols = Literal["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"]
ValidLayoutAttrs = Literal["size", "stride"]


@dataclass(frozen=True)
class LayoutArg:
    node: IRNode
    symbol: ValidLayoutSymbols
    attr: ValidLayoutAttrs
    dim: int

    def matches(self, node, attr, dim) -> bool:
        return self.node == node and self.attr == attr and self.dim == dim
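
# For illustration: LayoutArg(X, "M", "size", 0) records that the symbol "M"
# is bound to X's size along dim 0; matches(X, "size", 0) is then True no
# matter which symbol was bound.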


class CUTLASSKernel(Kernel):
    """
    Base class for CUTLASS-based kernels.
    """

    overrides = OpOverrides  # type: ignore[assignment]

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.layout_args: dict[str, list[LayoutArg]] = defaultdict(list)
        self.size_args: list[Union[Expr, int]] = []
        # Mapping from arg name to IRNode.
        self.named_nodes: dict[str, IRNode] = {}

    def find_symbol(
        self, node: IRNode, attr: ValidLayoutAttrs, dim: int
    ) -> Optional[str]:
        arg = self.find_layout_arg(node, attr, dim)
        return arg.symbol if arg else None

    def find_layout_arg(
        self, node: IRNode, attr: ValidLayoutAttrs, dim: int
    ) -> Optional[LayoutArg]:
        matches = [
            arg
            for arg in itertools.chain.from_iterable(self.layout_args.values())
            if arg.matches(node, attr, dim)
        ]
        if len(matches) >= 1:
            # Verify that all matches refer to the same node, attribute, and
            # dimension; since they come from the same node, whichever symbol
            # we use is fine. If this relationship changes at runtime, a guard
            # is triggered.
            first_match = matches[0]
            if not all(
                match.node == first_match.node
                and match.attr == first_match.attr
                and match.dim == first_match.dim
                for match in matches
            ):
                raise AssertionError("All matching layout args should be identical")
            return first_match
        return None

    def add_layout_arg(
        self, symbol: ValidLayoutSymbols, node: IRNode, attr: ValidLayoutAttrs, dim: int
    ):
        arg = LayoutArg(node, symbol, attr, dim)
        self.layout_args[symbol].append(arg)

    def init_layout_args(self) -> None:
        X = self.named_nodes["X"]
        W = self.named_nodes["W"]
        Y = self.named_nodes["Y"]
        Bias = self.named_nodes.get("Bias", None)
        x_mdim = _normalize_idx(-2, len(X.get_size()))
        x_kdim = _normalize_idx(-1, len(X.get_size()))
        w_kdim = _normalize_idx(-2, len(W.get_size()))
        w_ndim = _normalize_idx(-1, len(W.get_size()))
        y_mdim = _normalize_idx(-2, len(Y.get_size()))
        y_ndim = _normalize_idx(-1, len(Y.get_size()))
        self.add_layout_arg("M", X, "size", x_mdim)
        self.add_layout_arg("K", X, "size", x_kdim)
        self.add_layout_arg("K", W, "size", w_kdim)
        self.add_layout_arg("N", W, "size", w_ndim)
        self.add_layout_arg("M", Y, "size", y_mdim)
        self.add_layout_arg("N", Y, "size", y_ndim)
        if len(X.get_size()) > 2:
            self.add_layout_arg("B", X, "size", 0)

        lda_dim = self.find_ld_idx(X)
        ldb_dim = self.find_ld_idx(W)
        ldc_dim = self.find_ld_idx(Bias) if Bias else None
        ldd_dim = self.find_ld_idx(Y)
        self.add_layout_arg("lda", X, "stride", lda_dim)
        self.add_layout_arg("ldb", W, "stride", ldb_dim)
        if Bias is not None and ldc_dim is not None:
            self.add_layout_arg("ldc", Bias, "stride", ldc_dim)
        self.add_layout_arg("ldd", Y, "stride", ldd_dim)

    def get_layout_args(self) -> tuple[Union[Expr, int], ...]:
        X = self.named_nodes["X"]
        W = self.named_nodes["W"]
        Y = self.named_nodes["Y"]
        Bias = self.named_nodes.get("Bias", None)
        mdim = _normalize_idx(-2, len(X.get_size()))
        ndim = _normalize_idx(-1, len(W.get_size()))
        kdim = _normalize_idx(-1, len(X.get_size()))

        def get_ld(node) -> Union[Expr, int]:
            dim = self.find_ld_idx(node)
            return node.get_stride()[dim]

        M = X.get_size()[mdim]
        N = W.get_size()[ndim]
        K = X.get_size()[kdim]
        B = X.get_size()[0] if len(X.get_size()) > 2 else 1
        LDA = get_ld(X)
        LDB = get_ld(W)
        LDC = get_ld(Bias) if Bias else 0
        LDD = get_ld(Y)
        return (M, N, K, B, LDA, LDB, LDC, LDD)

    def get_dynamic_shape_args(self) -> list[Union[Expr, int]]:
        return [*self.get_layout_args(), *self.size_args]

    def get_offset_args(self) -> list[Expr]:
        return [node.get_layout().offset for node in self.named_nodes.values()]

    @staticmethod
    def find_ld_idx(node: IRNode) -> int:
        strides = node.get_stride()
        # Innermost-contiguous (row-major) case; this also covers 1D tensors,
        # whose only stride is 1.
        if V.graph.sizevars.statically_known_equals(strides[-1], 1):
            return _normalize_idx(-2, len(strides))

        # Otherwise the node must be column-major: stride[-2] == 1.
        assert V.graph.sizevars.statically_known_equals(strides[-2], 1), strides[-2]
        return _normalize_idx(-1, len(strides))
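
    # For illustration: a row-major (M, K) tensor has strides (K, 1), so its
    # leading dimension is stride[-2] == K; a column-major one has strides
    # (1, M), so its leading dimension is stride[-1] == M.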


class CUTLASSTemplateKernel(CUTLASSKernel):
    """
    Template kernels defined by CUTLASS in C++.
    """

    def __init__(
        self,
        kernel_name: str,
        runtime_arg_info: list["ArgInfo"],
        runtime_arg_values: list[Any],
        device_type: str = "cuda",  # type: ignore[assignment]
    ) -> None:
        """
        Initializes a new instance of the CUTLASSTemplateKernel class.

        Args:
            kernel_name (str): The name of the kernel.
        """
        super().__init__()
        self.kernel_name = kernel_name
        self.runtime_arg_info = runtime_arg_info
        self.runtime_arg_values = runtime_arg_values
        self.device_type = device_type
        self.device_codegen = get_device_op_overrides(self.device_type)
        self._EXTRA_CPP_ARGS = f"size_t* workspace_size, uint8_t* workspace, {self.device_codegen.cpp_stream_type()} stream"

    def check_not_null(self, node: IRNode) -> str:
        """
        Generates code to check that a node is not null.
        """
        if node is None:
            return ""

        size_str = self.size(node, 0, -1)
        name_str = self.arg_name(node)
        if name_str is None:
            return ""

        res = IndentedBuffer(initial_indent=2)
        res.tabwidth = 1
        res.splice(
            f"""
            {{
                if (!{name_str}) {{
                    int64_t {name_str}_size = {size_str};
                    if ({name_str}_size > 0) {{
                        throw std::runtime_error("input {name_str} is null but size is not 0!");
                    }}
                }}
            }}
            """
        )
        return res.getvalue()

    def get_signature(self) -> str:
        return self.signature

    def def_kernel(
        self,
        inputs: list[IRNode],
        outputs: list[IRNode],
        names_str: str = "",
        input_reorder: Optional[list[int]] = None,
    ) -> str:
        """
        Hook called from template code to generate the function definition and
        needed args.

        Args:
            inputs: List of input IRNodes
            outputs: List of output IRNodes
            names_str: Comma separated list of input + output argument names.
            input_reorder: The actual order of input nodes.
                e.g. the template might define its input arguments as
                [X, W, Bias], while the actual inputs passed into the template
                are [Bias, X, W]; in that case, `input_reorder` would be
                [2, 0, 1].
        """
        # NB: name order matters here, it's used to match up offsets
        names = [x.strip() for x in names_str.strip().split(",")]
        if len(inputs) + len(outputs) != len(names):
            raise RuntimeError(
                f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}"
            )

        if input_reorder is not None:
            assert len(inputs) == len(input_reorder)
        else:
            input_reorder = list(range(len(inputs)))

        for idx in input_reorder:
            name = names[idx]
            node = inputs[idx]
            if node is not None:
                self.named_nodes[name] = node
                self.args.input_buffers[node.get_name()] = name

        free_symbols: OrderedSet[Expr] = OrderedSet()
        for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs):
            if node is not None:
                # NB: named nodes must be populated in the order of names
                self.named_nodes[name] = node
                self.args.output_buffers[node.get_name()] = name

            if name not in (
                "X",
                "W",
                "Bias",
                "Y",
            ):  # we handle these symbolic shapes explicitly
                for expr in itertools.chain(node.get_size(), node.get_stride()):
                    if isinstance(expr, Expr):
                        for s in expr.free_symbols:
                            free_symbols.add(s)  # type: ignore[arg-type]

        arg_defs, *_ = self.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE)

        self.init_layout_args()
        size_vars = ["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"]
        size_vars.extend(str(s) for s in free_symbols)
        self.size_args.extend(free_symbols)
        size_args = [f"const int {s}" for s in size_vars]
        offset_args = [f"const int {name}_offset" for name in self.named_nodes.keys()]
        runtime_arg_decls = ",".join(
            [f"{arg.ty} {arg.name}" for arg in self.runtime_arg_info]
        )
        if runtime_arg_decls:
            runtime_arg_decls += ", "

        signature = (
            f"int {self.kernel_name}({', '.join(arg_defs + size_args + offset_args)}, "
            f"{runtime_arg_decls}{self._EXTRA_CPP_ARGS})"
        )
        self.signature = signature
        return signature
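
    # For illustration (hypothetical names and dtypes), def_kernel for a plain
    # GEMM might produce a signature along the lines of:
    #
    #   int cutlass_gemm(const cutlass::half_t* X, const cutlass::half_t* W,
    #                    cutlass::half_t* Y, const int M, const int N,
    #                    const int K, const int B, const int lda, const int ldb,
    #                    const int ldc, const int ldd, const int X_offset,
    #                    const int W_offset, const int Y_offset,
    #                    const uint8_t swizzle, size_t* workspace_size,
    #                    uint8_t* workspace, cudaStream_t stream)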

    def call_kernel(
        self,
        name: str,
        node: "CUTLASSTemplateBuffer",  # type: ignore[name-defined]
    ) -> None:
        """
        Generates code to call the kernel through V.graph.wrapper_code.
        Used from within torch._inductor.wrapper.PythonWrapperCodegen.

        name: Name of kernel function.
        node: The CUTLASSTemplateBuffer node which contains information about
              the kernel, its fused epilogue nodes, as well as all required
              inputs and outputs.
        """
        wrapper = V.graph.wrapper_code

        arg_types: list[Any]
        if V.graph.cpp_wrapper:
            # Make sure we initialize these kernels since they're exported as
            # C-style symbol names.
            assert isinstance(wrapper, CppWrapperCpu)
            wrapper.initialized_kernels[name] = self
            # We always originally initialize name with "KERNEL_NAME", so we
            # replace it with the real kernel name passed as an arg to this
            # function.
            self.signature = self.signature.replace(str(Placeholder.KERNEL_NAME), name)
            _, call_args, arg_types = self.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE)
        else:
            _, call_args, _, arg_types = self.args.python_argdefs()

        dynamic_shape_args = self.get_dynamic_shape_args()
        offset_args = self.get_offset_args()
        call_args.extend(dynamic_shape_args)  # type: ignore[arg-type]
        call_args.extend(offset_args)  # type: ignore[arg-type]
        for arg in self.runtime_arg_values:
            call_args.append(str(arg))
        arg_types.extend("const int" for _ in dynamic_shape_args)
        arg_types.extend("const int" for _ in offset_args)
        for arg in self.runtime_arg_info:
            arg_types.append(arg.ty)
        # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
        for i in range(len(call_args)):
            if V.graph.is_unspec_arg(call_args[i]):
                call_args[i] = call_args[i] + ".item()"
            elif isinstance(arg_types[i], torch_dtype):
                call_args[i] = (
                    call_args[i]
                    if V.graph.cpp_wrapper
                    else f"c_void_p({call_args[i]}.data_ptr())"
                )

        # The workspace_size ptr is NULL to mark that this call is not intended
        # for retrieving the workspace size; it should have already been
        # retrieved prior to this call.
        call_args.append("nullptr" if V.graph.cpp_wrapper else "None")
        if V.graph.cpp_wrapper:
            arg_types.append("size_t*")

        if node.get_workspace_size() > 0:
            ws = WorkspaceArg(
                count=node.get_workspace_size(),
                device=V.graph.get_current_device_or_throw(),
                zero_mode=WorkspaceZeroMode.UNINITIALIZED,
                outer_name=WorkspaceArg.unique_name(),
            )
            wrapper.generate_workspace_allocation(ws)
            workspace = str(ws.outer_name)
            call_args.append(
                workspace
                if V.graph.cpp_wrapper
                else f"c_void_p({workspace}.data_ptr())"
            )
        else:
            ws = None
            call_args.append("nullptr" if V.graph.cpp_wrapper else "None")
        if V.graph.cpp_wrapper:
            arg_types.append("uint8_t*")

        wrapper.generate_kernel_call(
            name,
            call_args,
            triton=False,
            arg_types=arg_types,
        )
        if ws:
            wrapper.generate_workspace_deallocation(ws)
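
    # For illustration (hypothetical buffer names), the Python-wrapper path
    # emits a call roughly like:
    #
    #   cutlass_gemm(c_void_p(buf0.data_ptr()), c_void_p(buf1.data_ptr()),
    #                c_void_p(buf2.data_ptr()), M, N, K, B, lda, ldb, ldc, ldd,
    #                0, 0, 0, ..., None, c_void_p(workspace_0.data_ptr()))
    #
    # with the trailing None standing in for the workspace_size pointer.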

    def dtype(self, node: IRNode) -> Optional[str]:
        """
        Generates code which represents dtype of a given node.
        """
        if node is None:
            return "void"
        return DTYPE_TO_CPP.get(node.get_layout().dtype)

    def cutlass_dtype(self, node: IRNode, default_dtype="void") -> Optional[str]:
        # Helper method, called into from CUTLASSGemmTemplate
        if node is None:
            return default_dtype
        from torch._inductor.codegen.cutlass.template import CUTLASSTemplate

        return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]

    def max_valid_index(self, node: IRNode, default=-1):
        # Helper method, called into from CUTLASSGemmTemplate
        if node is None:
            return default
        max_valid_offset = 0
        for i in range(len(node.get_size())):
            max_valid_offset += (node.get_size()[i] - 1) * node.get_stride()[i]
        return max_valid_offset
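
    # For illustration: a tensor with size (2, 3) and stride (3, 1) has
    # max_valid_index == (2 - 1) * 3 + (3 - 1) * 1 == 5, the offset of its
    # last element.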

    def ptr(self, node: IRNode) -> str:
        """
        Generates code which represents pointer of a given node.
        """
        if node is None:
            return "nullptr"
        arg_name = self.arg_name(node)
        if arg_name is None:
            return "nullptr"
        return f"{arg_name} + {arg_name}_offset"

    def size(
        self,
        node: IRNode,
        start_index: int,
        end_index: Optional[int] = None,
        default_value: int = 0,
    ) -> str:
        """
        Hook called from template code to get the size of an arg.
        Generates code which represents the product of the sizes of a given
        node over [start_index, end_index] (inclusive).
        If node is None, returns default_value.

        TODO: Will add needed args to pass it in if it is dynamic.
        """
        if node is None:
            return str(default_value)

        start_index = _normalize_idx(start_index, len(node.get_size()))
        if end_index is None:
            end_index = start_index
        end_index = _normalize_idx(end_index, len(node.get_size()))
        sizes = [
            self.find_symbol(node, "size", dim=i) or node.get_size()[i]
            for i in range(start_index, end_index + 1)
        ]
        if len(sizes) == 0:
            return str(default_value)

        sizes = [symbols(v) if isinstance(v, str) else v for v in sizes]
        val = sympy_product(sizes)
        return str(val)
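
    # For illustration: with X of size (B, M, K) whose dims were registered in
    # init_layout_args, self.size(X, 0, -1) renders the product of all three
    # sizes (a sympy product such as B*K*M), while self.size(X, -1) renders
    # only the innermost size, K.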

    def stride(self, node: IRNode, index: int, default_value: int = 0) -> str:
        """
        Hook called from template code to get the stride of an arg.
        Generates code which represents stride of a given node at index.
        If node is None, returns default_value.

        TODO: Will add needed args to pass it in if it is dynamic.
        """
        if node is None:
            return str(default_value)

        index = _normalize_idx(index, len(node.get_size()))
        if index < 0:
            return str(default_value)

        stride = node.get_stride()[index]
        if V.graph.sizevars.statically_known_leq(stride, 1):
            return str(stride)
        return self.find_symbol(node, "stride", dim=index) or str(stride)
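
    # For illustration: for a row-major X with strides (K, 1), stride(X, -1)
    # returns "1" directly, while stride(X, -2) returns the registered symbol
    # "lda" (falling back to the literal stride when no symbol is bound).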

    def batch_stride(self, node: IRNode, default_value: int = 0) -> str:
        """
        Hook called from template code to get the batch stride of an arg.
        Returns 0 if batch dim is not present.

        This method assumes that batch stride is the largest stride.
        """
        if node is None:
            return str(default_value)

        if len(node.get_size()) < 3:
            return str(default_value)

        batch_stride = node.get_stride()[0]
        if V.graph.sizevars.statically_known_leq(batch_stride, 1):
            return str(batch_stride)

        return "{}*{}".format(
            self.find_symbol(node, "size", dim=1) or node.get_size()[1],
            self.find_symbol(node, "size", dim=2) or node.get_size()[2],
        )
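
    # For illustration: a contiguous (B, M, K) input X has batch stride M * K;
    # with the symbols registered in init_layout_args this renders as "M*K".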

    def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str:
        """
        Hook called from template code to get the row or column stride of an arg.
        This is required by some CUTLASS 2.X APIs.
        If the node is in row_major, it returns stride[-2].
        If the node is in column_major, it returns stride[-1].

        TODO: Will add needed args to pass it in if it is dynamic.
        """
        if node is None or len(node.get_stride()) < 2:
            return str(default_value)

        stride0 = node.get_stride()[-1]
        stride1 = node.get_stride()[-2]
        if stride0 == 1:
            return cexpr(self.rename_indexing(stride1))
        elif stride1 == 1:
            return cexpr(self.rename_indexing(stride0))
        else:
            raise RuntimeError(
                f"At least 1 stride should be 1. Strides: {node.get_stride()=}"
            )
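
    # For illustration: a row-major (M, N) tensor with strides (N, 1) yields
    # its row stride N; a column-major one with strides (1, M) yields its
    # column stride M.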

    def load(self, name: str, index: Expr, mode: Any = None) -> CSEVariable:
        """
        Mock load function for memory planning to optimize allocations properly.
        """
        return self.create_cse_var(name, bounds=ValueRanges.unknown())

    def store(self, name: str, index: Expr, value: Any, mode: Any = None) -> None:
        """
        Mock store function for memory planning to optimize allocations properly.
        """
        self.store_buffer_names.add(name)


class CUTLASSTemplateCaller(ChoiceCaller):
    """
    CUTLASSTemplateCaller

    This class represents a caller for CUTLASS template kernels. It is a subclass of ChoiceCaller.

    Attributes:
        name (str): The name of the caller.
        category (str): The category of the caller.
        bmreq (CUTLASSBenchmarkRequest): The benchmark request for the caller.
        template_buffer (CUTLASSTemplateBuffer): The template buffer for the caller.
    """

    def __init__(
        self,
        name: str,
        category: str,
        input_nodes: list[Buffer],
        layout: Layout,
        make_kernel_render: Callable[
            [CUTLASSTemplateBuffer, Optional[list[BaseSchedulerNode]]],
            tuple[CUTLASSTemplateKernel, functools.partial[str]],
        ],
        bmreq: CUTLASSBenchmarkRequest,
        supports_epilogue_fusion: bool,
        template: "CUTLASSTemplate",  # type: ignore[name-defined]
        info_kwargs: Optional[
            dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]
        ],  # type: ignore[type-arg]
        description: str,
    ) -> None:
        super().__init__(name, input_nodes, layout, description)
        self.category = category
        self.make_kernel_render = make_kernel_render
        self.bmreq = bmreq
        self.supports_epilogue_fusion = supports_epilogue_fusion
        self.template = template
        self.info_kwargs = info_kwargs

    def precompile(self) -> None:
        assert self.bmreq is not None
        self.bmreq.precompile()

    def benchmark(self, *args, out) -> float:
        assert self.bmreq is not None
        if config.profile_bandwidth_with_do_bench_using_profiling:
            algo = self.bmreq.make_run_fn(*args, out=out)
            return do_bench_using_profiling(algo)
        return self.bmreq.benchmark(*args, out=out)

    def __str__(self) -> str:
        return f"CUTLASSTemplateCaller(source_file={self.bmreq.source_file})"

    def call_name(self) -> str:
        return f"cutlass_template_kernels.{self.name}"

    def kernel_hash_key(self) -> str:
        """
        Return kernel hash key that does not depend on swizzle.
        """
        return "-".join(
            [
                self.category,
                self.bmreq.hash_key,
            ]
        )

    def hash_key(self) -> str:
        """
        Return a hash key that, unlike kernel_hash_key, also depends on swizzle.
        """
        swizzle_str: str = (
            str(self.info_kwargs.get("swizzle"))
            if isinstance(self.info_kwargs, dict)
            else "None"
        )
        return "-".join(
            [
                self.category,
                self.bmreq.hash_key,
                swizzle_str,
            ]
        )

    def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]:
        """
        Information returned here is logged to the autotune log file when that is enabled.

        In general, we should avoid calling this function as it is expensive to compute,
        and can add up very fast.
        """
        if self.info_kwargs is not None and "op" in self.info_kwargs:
            op: Any = self.info_kwargs["op"]
            return {
                "backend": "CUDA",
                "op_type": type(op).__name__,
                "op_conf_name": str(op.configuration_name()),
                "op_arch": str(op.arch),
                "tile_shape": str(op.tile_description.tile_shape),
                "epilogue_schedule": str(op.epilogue_schedule),
                "kernel_schedule": str(op.kernel_schedule),
                "element_accumulator": str(op.accumulator_type()),
                "op_name": str(op.procedural_name()),
                "instruction_shape": str(
                    op.tile_description.math_instruction.instruction_shape
                ),
                "swizzle": str(self.info_kwargs["swizzle"]),
            }
        else:
            return {"backend": "CUDA", "op_type": "unknown"}

    def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]:
        self.bmreq.update_workspace_size()
        return TensorBox.create(
            CUTLASSTemplateBuffer(
                layout=self.layout,
                inputs=self.input_nodes,
                make_kernel_render=self.make_kernel_render,
                workspace_size=self.bmreq.workspace_size,
                supports_epilogue_fusion=self.supports_epilogue_fusion,
                template=self.template,
            )
        )