[BE][cutlass backend] BE changes post cutlass_cppgen name change (#164589)

Differential Revision: D83809105

Handle reviews from https://github.com/pytorch/pytorch/pull/164159

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164589
Approved by: https://github.com/Skylion007
This commit is contained in:
Henry Tsang
2025-10-06 17:22:08 +00:00
committed by PyTorch MergeBot
parent 2164b66121
commit 96181d6f76
4 changed files with 59 additions and 60 deletions

View File

@ -255,10 +255,7 @@ class TestCutlassBackend(TestCase):
self.assertTrue(try_import_cutlass())
if config.is_fbcode():
import cutlass_cppgen # noqa: F401
else:
import cutlass_cppgen as python_cutlass # noqa: F401
import cutlass_cppgen # type: ignore[import-not-found] # noqa: F401
import cutlass_library # noqa: F401
def test_cutlass_key(self):

View File

@ -4,7 +4,6 @@ import unittest
import sympy
import torch
import torch._inductor.config as config
from torch._dynamo.test_case import TestCase
from torch._inductor.codegen.cuda.cutlass_utils import (
torch_dtype_to_cutlass_type,
@ -27,18 +26,14 @@ if try_import_cutlass():
LayoutType = cutlass_lib.LayoutType
DataType = cutlass_lib.DataType
from cutlass_cppgen.backend.evt.ir.tensor import Tensor as CutlassTensor
from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
_render_argument_type,
_trace,
trace,
)
if config.is_fbcode():
import cutlass_cppgen as python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401
else:
import cutlass_cppgen as python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401
CutlassTensor = python_cutlass.backend.evt.ir.tensor.Tensor
BIAS_CODE = """def example_epilogue(accum, C, aux, bias):
F = accum + C + aux
E = relu(F) + bias

View File

@ -3,7 +3,6 @@ from typing import Any, Union
from sympy import Expr
import torch._inductor.config as config
from torch._inductor.ir import (
ComputedBuffer,
InputBuffer,
@ -29,6 +28,27 @@ if try_import_cutlass():
import textwrap
from typing import Union
from cutlass_cppgen.backend.c_types import ( # type: ignore[import-not-found]
EmptyByte,
)
from cutlass_cppgen.backend.epilogue import ( # type: ignore[import-not-found]
dtype2ctype,
)
from cutlass_cppgen.backend.evt import ( # type: ignore[import-not-found]
EpilogueFunctorVisitor,
)
from cutlass_cppgen.backend.evt.backend.emitter_base import ( # type: ignore[import-not-found]
FusionCallbacks,
)
from cutlass_cppgen.backend.evt.backend.sm90_emitter import ( # type: ignore[import-not-found]
CollectiveEpilogue,
)
from cutlass_cppgen.backend.evt.frontend import ( # type: ignore[import-not-found]
PythonASTFrontend,
)
from cutlass_cppgen.backend.evt.ir.tensor import ( # type: ignore[import-not-found]
Tensor as CutlassTensor,
)
from cutlass_library import (
DataType,
EpilogueScheduleType,
@ -36,15 +56,10 @@ if try_import_cutlass():
TileDescription,
)
if config.is_fbcode():
import cutlass_cppgen as python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401
else:
import cutlass_cppgen as python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401
from torch._inductor.codegen.cuda import cuda_env
from torch._inductor.utils import IndentedBuffer
_CUTLASS_C_DTYPES = OrderedSet(python_cutlass.backend.epilogue.dtype2ctype.values()) # type: ignore[var-annotated]
_CUTLASS_C_DTYPES = OrderedSet(dtype2ctype.values()) # type: ignore[var-annotated]
class EVTArgRenames:
"""Handles mapping buffer names to variable names in the cpp kernel signature and body"""
@ -67,10 +82,10 @@ if try_import_cutlass():
var_name_to_buffer_name: dict[str, str],
name_to_buffer: dict[str, Buffer],
size_hint_fn: Callable[[Union[Expr, int]], int],
) -> dict[str, python_cutlass.backend.evt.ir.tensor.Tensor]:
) -> dict[str, CutlassTensor]:
def cutlass_tensor_from_buffer(
buffer: Buffer,
) -> python_cutlass.backend.evt.ir.tensor.Tensor:
) -> CutlassTensor:
shape = buffer.get_layout().size
stride = buffer.get_layout().stride
shape = tuple(size_hint_fn(x) for x in shape)
@ -85,7 +100,7 @@ if try_import_cutlass():
non-contiguous layout, received stride: {stride} and shape: {shape}"
)
return python_cutlass.backend.evt.ir.tensor.Tensor(
return CutlassTensor(
shape=shape,
layout_tag=(
LayoutType.RowMajor if is_row_major else LayoutType.ColumnMajor
@ -100,7 +115,7 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
def trace(
fn_src: str,
example_tensors: dict[str, python_cutlass.backend.evt.ir.tensor.Tensor],
example_tensors: dict[str, CutlassTensor],
accum_type: DataType,
output_type: DataType,
tile_description: TileDescription,
@ -112,22 +127,14 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
cuda_arch = int(cuda_env.get_cuda_arch()) # type: ignore[arg-type]
assert cuda_arch >= 90, "Only SM90+ is supported for EVT"
epilogue_functor = _trace(fn_src, example_tensors, cuda_arch, **kwargs)
visitor = python_cutlass.backend.evt.EpilogueFunctorVisitor(
cuda_arch, epilogue_functor
)
fusion_callbacks = (
python_cutlass.backend.evt.backend.emitter_base.FusionCallbacks(
visitor.graph, cuda_arch, emit_CD=False
)
)
collective_epilogue = (
python_cutlass.backend.evt.backend.sm90_emitter.CollectiveEpilogue(
tile_description,
epilogue_schedule,
accum_type,
output_type,
fusion_callbacks,
)
visitor = EpilogueFunctorVisitor(cuda_arch, epilogue_functor)
fusion_callbacks = FusionCallbacks(visitor.graph, cuda_arch, emit_CD=False)
collective_epilogue = CollectiveEpilogue(
tile_description,
epilogue_schedule,
accum_type,
output_type,
fusion_callbacks,
)
evt_name, evt_code = collective_epilogue.emit()
evt_args, arg_renames = _render_argument_type(
@ -141,18 +148,18 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
# The reason for this is that inspect.getsource does not work with functions defined at runtime via exec/eval
def _trace(
fn_src: str,
example_tensors: dict[str, python_cutlass.backend.evt.ir.tensor.Tensor],
example_tensors: dict[str, CutlassTensor],
cc: int,
**kwargs: Any,
) -> EpilogueFunctor:
class EpilogueFunctor(python_cutlass.backend.evt.frontend.PythonASTFrontend):
class EpilogueFunctor(PythonASTFrontend):
def __init__(self, cc: int, **kwargs: Any):
self.source = textwrap.dedent(fn_src)
super().__init__(cc, **kwargs)
def parse(
self,
example_inputs: dict[str, python_cutlass.backend.evt.ir.tensor.Tensor],
example_inputs: dict[str, CutlassTensor],
) -> None:
self.example_inputs = example_inputs
self.ast = ast.parse(self.source)
@ -173,9 +180,10 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
# Fragile, but this is the only way to guarantee t is expected type because t is a local class
def is_nested_visitor_type(t: type) -> bool:
return ".".join([t.__module__, t.__qualname__]) in {
"cutlass_cppgen.backend.c_types.visitor_factory.<locals>.VisitorType",
}
return (
".".join([t.__module__, t.__qualname__])
== "cutlass_cppgen.backend.c_types.visitor_factory.<locals>.VisitorType"
)
buffer = IndentedBuffer()
with buffer.set_tabwidth(2):
@ -233,9 +241,10 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
# Today, arguments are either a pointer to the
# node's memory, a stride tuple, the datatype
# Once again, need to check for local class type for stride tuple
if str(arg_ty) in {
"<class 'cutlass_cppgen.backend.c_types.tuple_factory_.<locals>.TupleType'>",
}:
if (
str(arg_ty)
== "<class 'cutlass_cppgen.backend.c_types.tuple_factory_.<locals>.TupleType'>"
):
DEFAULT_STRIDE_LEN = 3
assert len(node.get_layout().stride) <= DEFAULT_STRIDE_LEN
stride = [size_hint_fn(x) for x in node.get_layout().stride]
@ -260,7 +269,7 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
arg_ty in _CUTLASS_C_DTYPES
): # Assumption: this is the element dtype, this holds for all cutlass ir nodes currently
return f"{CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}(0)"
elif issubclass(arg_ty, python_cutlass.backend.c_types.EmptyByte):
elif issubclass(arg_ty, EmptyByte):
return "{}"
raise NotImplementedError(f"Unsupported arg type: {arg_ty}")

View File

@ -9,6 +9,7 @@ import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional
from typing_extensions import TypeIs
import sympy
@ -40,23 +41,20 @@ def move_cutlass_compiled_cache() -> None:
if not try_import_cutlass.cache_info().currsize > 0:
return
if config.is_fbcode():
import cutlass_cppgen as python_cutlass # type: ignore[import-not-found]
else:
import cutlass_cppgen as python_cutlass # type: ignore[import-not-found] # noqa: F401
import cutlass_cppgen # type: ignore[import-not-found]
# Check if the CACHE_FILE attribute exists in python_cutlass and if the file exists
if not hasattr(python_cutlass, "CACHE_FILE") or not os.path.exists(
python_cutlass.CACHE_FILE
# Check if the CACHE_FILE attribute exists in cutlass_cppgen and if the file exists
if not hasattr(cutlass_cppgen, "CACHE_FILE") or not os.path.exists(
cutlass_cppgen.CACHE_FILE
):
return
try:
filename = os.path.basename(python_cutlass.CACHE_FILE)
shutil.move(python_cutlass.CACHE_FILE, os.path.join(cache_dir(), filename))
filename = os.path.basename(cutlass_cppgen.CACHE_FILE)
shutil.move(cutlass_cppgen.CACHE_FILE, os.path.join(cache_dir(), filename))
log.debug("Moved CUTLASS compiled cache file to %s", cache_dir())
except OSError as e:
log.warning("Failed to move CUTLASS compiled cache file: %s", str(e))
log.warning("Failed to move CUTLASS compiled cache file: %s", e)
def _rename_cutlass_import(content: str, cutlass_modules: list[str]) -> str:
@ -158,7 +156,7 @@ def try_import_cutlass() -> bool:
)
try:
import cutlass_cppgen # noqa: F401, F811
import cutlass_cppgen # type: ignore[import-not-found] # noqa: F401, F811
import cutlass_library.generator # noqa: F401
import cutlass_library.library # noqa: F401
import cutlass_library.manifest # noqa: F401
@ -422,7 +420,7 @@ def get_max_alignment(inductor_layout: Layout) -> int:
size = inductor_layout.size
offset = inductor_layout.offset
def is_static_int(number):
def is_static_int(number: object) -> TypeIs[int | sympy.Integer]:
return isinstance(number, (int | sympy.Integer))
def a_factor_of(x, alignment):