mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE][cutlass backend] BE changes post cutlass_cppgen name change (#164589)
Differential Revision: D83809105 Handle reviews from https://github.com/pytorch/pytorch/pull/164159 Pull Request resolved: https://github.com/pytorch/pytorch/pull/164589 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
2164b66121
commit
96181d6f76
@ -255,10 +255,7 @@ class TestCutlassBackend(TestCase):
|
||||
|
||||
self.assertTrue(try_import_cutlass())
|
||||
|
||||
if config.is_fbcode():
|
||||
import cutlass_cppgen # noqa: F401
|
||||
else:
|
||||
import cutlass_cppgen as python_cutlass # noqa: F401
|
||||
import cutlass_cppgen # type: ignore[import-not-found] # noqa: F401
|
||||
import cutlass_library # noqa: F401
|
||||
|
||||
def test_cutlass_key(self):
|
||||
|
@ -4,7 +4,6 @@ import unittest
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
import torch._inductor.config as config
|
||||
from torch._dynamo.test_case import TestCase
|
||||
from torch._inductor.codegen.cuda.cutlass_utils import (
|
||||
torch_dtype_to_cutlass_type,
|
||||
@ -27,18 +26,14 @@ if try_import_cutlass():
|
||||
|
||||
LayoutType = cutlass_lib.LayoutType
|
||||
DataType = cutlass_lib.DataType
|
||||
from cutlass_cppgen.backend.evt.ir.tensor import Tensor as CutlassTensor
|
||||
|
||||
from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
|
||||
_render_argument_type,
|
||||
_trace,
|
||||
trace,
|
||||
)
|
||||
|
||||
if config.is_fbcode():
|
||||
import cutlass_cppgen as python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401
|
||||
else:
|
||||
import cutlass_cppgen as python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401
|
||||
CutlassTensor = python_cutlass.backend.evt.ir.tensor.Tensor
|
||||
|
||||
BIAS_CODE = """def example_epilogue(accum, C, aux, bias):
|
||||
F = accum + C + aux
|
||||
E = relu(F) + bias
|
||||
|
@ -3,7 +3,6 @@ from typing import Any, Union
|
||||
|
||||
from sympy import Expr
|
||||
|
||||
import torch._inductor.config as config
|
||||
from torch._inductor.ir import (
|
||||
ComputedBuffer,
|
||||
InputBuffer,
|
||||
@ -29,6 +28,27 @@ if try_import_cutlass():
|
||||
import textwrap
|
||||
from typing import Union
|
||||
|
||||
from cutlass_cppgen.backend.c_types import ( # type: ignore[import-not-found]
|
||||
EmptyByte,
|
||||
)
|
||||
from cutlass_cppgen.backend.epilogue import ( # type: ignore[import-not-found]
|
||||
dtype2ctype,
|
||||
)
|
||||
from cutlass_cppgen.backend.evt import ( # type: ignore[import-not-found]
|
||||
EpilogueFunctorVisitor,
|
||||
)
|
||||
from cutlass_cppgen.backend.evt.backend.emitter_base import ( # type: ignore[import-not-found]
|
||||
FusionCallbacks,
|
||||
)
|
||||
from cutlass_cppgen.backend.evt.backend.sm90_emitter import ( # type: ignore[import-not-found]
|
||||
CollectiveEpilogue,
|
||||
)
|
||||
from cutlass_cppgen.backend.evt.frontend import ( # type: ignore[import-not-found]
|
||||
PythonASTFrontend,
|
||||
)
|
||||
from cutlass_cppgen.backend.evt.ir.tensor import ( # type: ignore[import-not-found]
|
||||
Tensor as CutlassTensor,
|
||||
)
|
||||
from cutlass_library import (
|
||||
DataType,
|
||||
EpilogueScheduleType,
|
||||
@ -36,15 +56,10 @@ if try_import_cutlass():
|
||||
TileDescription,
|
||||
)
|
||||
|
||||
if config.is_fbcode():
|
||||
import cutlass_cppgen as python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401
|
||||
else:
|
||||
import cutlass_cppgen as python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401
|
||||
|
||||
from torch._inductor.codegen.cuda import cuda_env
|
||||
from torch._inductor.utils import IndentedBuffer
|
||||
|
||||
_CUTLASS_C_DTYPES = OrderedSet(python_cutlass.backend.epilogue.dtype2ctype.values()) # type: ignore[var-annotated]
|
||||
_CUTLASS_C_DTYPES = OrderedSet(dtype2ctype.values()) # type: ignore[var-annotated]
|
||||
|
||||
class EVTArgRenames:
|
||||
"""Handles mapping buffer names to variable names in the cpp kernel signature and body"""
|
||||
@ -67,10 +82,10 @@ if try_import_cutlass():
|
||||
var_name_to_buffer_name: dict[str, str],
|
||||
name_to_buffer: dict[str, Buffer],
|
||||
size_hint_fn: Callable[[Union[Expr, int]], int],
|
||||
) -> dict[str, python_cutlass.backend.evt.ir.tensor.Tensor]:
|
||||
) -> dict[str, CutlassTensor]:
|
||||
def cutlass_tensor_from_buffer(
|
||||
buffer: Buffer,
|
||||
) -> python_cutlass.backend.evt.ir.tensor.Tensor:
|
||||
) -> CutlassTensor:
|
||||
shape = buffer.get_layout().size
|
||||
stride = buffer.get_layout().stride
|
||||
shape = tuple(size_hint_fn(x) for x in shape)
|
||||
@ -85,7 +100,7 @@ if try_import_cutlass():
|
||||
non-contiguous layout, received stride: {stride} and shape: {shape}"
|
||||
)
|
||||
|
||||
return python_cutlass.backend.evt.ir.tensor.Tensor(
|
||||
return CutlassTensor(
|
||||
shape=shape,
|
||||
layout_tag=(
|
||||
LayoutType.RowMajor if is_row_major else LayoutType.ColumnMajor
|
||||
@ -100,7 +115,7 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
|
||||
|
||||
def trace(
|
||||
fn_src: str,
|
||||
example_tensors: dict[str, python_cutlass.backend.evt.ir.tensor.Tensor],
|
||||
example_tensors: dict[str, CutlassTensor],
|
||||
accum_type: DataType,
|
||||
output_type: DataType,
|
||||
tile_description: TileDescription,
|
||||
@ -112,22 +127,14 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
|
||||
cuda_arch = int(cuda_env.get_cuda_arch()) # type: ignore[arg-type]
|
||||
assert cuda_arch >= 90, "Only SM90+ is supported for EVT"
|
||||
epilogue_functor = _trace(fn_src, example_tensors, cuda_arch, **kwargs)
|
||||
visitor = python_cutlass.backend.evt.EpilogueFunctorVisitor(
|
||||
cuda_arch, epilogue_functor
|
||||
)
|
||||
fusion_callbacks = (
|
||||
python_cutlass.backend.evt.backend.emitter_base.FusionCallbacks(
|
||||
visitor.graph, cuda_arch, emit_CD=False
|
||||
)
|
||||
)
|
||||
collective_epilogue = (
|
||||
python_cutlass.backend.evt.backend.sm90_emitter.CollectiveEpilogue(
|
||||
tile_description,
|
||||
epilogue_schedule,
|
||||
accum_type,
|
||||
output_type,
|
||||
fusion_callbacks,
|
||||
)
|
||||
visitor = EpilogueFunctorVisitor(cuda_arch, epilogue_functor)
|
||||
fusion_callbacks = FusionCallbacks(visitor.graph, cuda_arch, emit_CD=False)
|
||||
collective_epilogue = CollectiveEpilogue(
|
||||
tile_description,
|
||||
epilogue_schedule,
|
||||
accum_type,
|
||||
output_type,
|
||||
fusion_callbacks,
|
||||
)
|
||||
evt_name, evt_code = collective_epilogue.emit()
|
||||
evt_args, arg_renames = _render_argument_type(
|
||||
@ -141,18 +148,18 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
|
||||
# The reason for this is that inspect.getsource does not work with functions defined at runtime via exec/eval
|
||||
def _trace(
|
||||
fn_src: str,
|
||||
example_tensors: dict[str, python_cutlass.backend.evt.ir.tensor.Tensor],
|
||||
example_tensors: dict[str, CutlassTensor],
|
||||
cc: int,
|
||||
**kwargs: Any,
|
||||
) -> EpilogueFunctor:
|
||||
class EpilogueFunctor(python_cutlass.backend.evt.frontend.PythonASTFrontend):
|
||||
class EpilogueFunctor(PythonASTFrontend):
|
||||
def __init__(self, cc: int, **kwargs: Any):
|
||||
self.source = textwrap.dedent(fn_src)
|
||||
super().__init__(cc, **kwargs)
|
||||
|
||||
def parse(
|
||||
self,
|
||||
example_inputs: dict[str, python_cutlass.backend.evt.ir.tensor.Tensor],
|
||||
example_inputs: dict[str, CutlassTensor],
|
||||
) -> None:
|
||||
self.example_inputs = example_inputs
|
||||
self.ast = ast.parse(self.source)
|
||||
@ -173,9 +180,10 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
|
||||
|
||||
# Fragile, but this is the only way to guarantee t is expected type because t is a local class
|
||||
def is_nested_visitor_type(t: type) -> bool:
|
||||
return ".".join([t.__module__, t.__qualname__]) in {
|
||||
"cutlass_cppgen.backend.c_types.visitor_factory.<locals>.VisitorType",
|
||||
}
|
||||
return (
|
||||
".".join([t.__module__, t.__qualname__])
|
||||
== "cutlass_cppgen.backend.c_types.visitor_factory.<locals>.VisitorType"
|
||||
)
|
||||
|
||||
buffer = IndentedBuffer()
|
||||
with buffer.set_tabwidth(2):
|
||||
@ -233,9 +241,10 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
|
||||
# Today, arguments are either a pointer to the
|
||||
# node's memory, a stride tuple, the datatype
|
||||
# Once again, need to check for local class type for stride tuple
|
||||
if str(arg_ty) in {
|
||||
"<class 'cutlass_cppgen.backend.c_types.tuple_factory_.<locals>.TupleType'>",
|
||||
}:
|
||||
if (
|
||||
str(arg_ty)
|
||||
== "<class 'cutlass_cppgen.backend.c_types.tuple_factory_.<locals>.TupleType'>"
|
||||
):
|
||||
DEFAULT_STRIDE_LEN = 3
|
||||
assert len(node.get_layout().stride) <= DEFAULT_STRIDE_LEN
|
||||
stride = [size_hint_fn(x) for x in node.get_layout().stride]
|
||||
@ -260,7 +269,7 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
|
||||
arg_ty in _CUTLASS_C_DTYPES
|
||||
): # Assumption: this is the element dtype, this holds for all cutlass ir nodes currently
|
||||
return f"{CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}(0)"
|
||||
elif issubclass(arg_ty, python_cutlass.backend.c_types.EmptyByte):
|
||||
elif issubclass(arg_ty, EmptyByte):
|
||||
return "{}"
|
||||
|
||||
raise NotImplementedError(f"Unsupported arg type: {arg_ty}")
|
||||
|
@ -9,6 +9,7 @@ import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
import sympy
|
||||
|
||||
@ -40,23 +41,20 @@ def move_cutlass_compiled_cache() -> None:
|
||||
if not try_import_cutlass.cache_info().currsize > 0:
|
||||
return
|
||||
|
||||
if config.is_fbcode():
|
||||
import cutlass_cppgen as python_cutlass # type: ignore[import-not-found]
|
||||
else:
|
||||
import cutlass_cppgen as python_cutlass # type: ignore[import-not-found] # noqa: F401
|
||||
import cutlass_cppgen # type: ignore[import-not-found]
|
||||
|
||||
# Check if the CACHE_FILE attribute exists in python_cutlass and if the file exists
|
||||
if not hasattr(python_cutlass, "CACHE_FILE") or not os.path.exists(
|
||||
python_cutlass.CACHE_FILE
|
||||
# Check if the CACHE_FILE attribute exists in cutlass_cppgen and if the file exists
|
||||
if not hasattr(cutlass_cppgen, "CACHE_FILE") or not os.path.exists(
|
||||
cutlass_cppgen.CACHE_FILE
|
||||
):
|
||||
return
|
||||
|
||||
try:
|
||||
filename = os.path.basename(python_cutlass.CACHE_FILE)
|
||||
shutil.move(python_cutlass.CACHE_FILE, os.path.join(cache_dir(), filename))
|
||||
filename = os.path.basename(cutlass_cppgen.CACHE_FILE)
|
||||
shutil.move(cutlass_cppgen.CACHE_FILE, os.path.join(cache_dir(), filename))
|
||||
log.debug("Moved CUTLASS compiled cache file to %s", cache_dir())
|
||||
except OSError as e:
|
||||
log.warning("Failed to move CUTLASS compiled cache file: %s", str(e))
|
||||
log.warning("Failed to move CUTLASS compiled cache file: %s", e)
|
||||
|
||||
|
||||
def _rename_cutlass_import(content: str, cutlass_modules: list[str]) -> str:
|
||||
@ -158,7 +156,7 @@ def try_import_cutlass() -> bool:
|
||||
)
|
||||
|
||||
try:
|
||||
import cutlass_cppgen # noqa: F401, F811
|
||||
import cutlass_cppgen # type: ignore[import-not-found] # noqa: F401, F811
|
||||
import cutlass_library.generator # noqa: F401
|
||||
import cutlass_library.library # noqa: F401
|
||||
import cutlass_library.manifest # noqa: F401
|
||||
@ -422,7 +420,7 @@ def get_max_alignment(inductor_layout: Layout) -> int:
|
||||
size = inductor_layout.size
|
||||
offset = inductor_layout.offset
|
||||
|
||||
def is_static_int(number):
|
||||
def is_static_int(number: object) -> TypeIs[int | sympy.Integer]:
|
||||
return isinstance(number, (int | sympy.Integer))
|
||||
|
||||
def a_factor_of(x, alignment):
|
||||
|
Reference in New Issue
Block a user