Compare commits


3 Commits

27a20fa048  [Cutlass] Support bias arg for fp8 GEMM
    ghstack-source-id: 96e7fe551979aaf59491c6b6199c8cdcfe166230
    Pull Request resolved: https://github.com/pytorch/pytorch/pull/154761
    2025-05-30 21:36:51 -07:00

fb0defa780  [Cutlass] Cleanup gemm_template evt handling
    ghstack-source-id: d398bb7e2762405e92d0a31da4f70f2d4f87de38
    Pull Request resolved: https://github.com/pytorch/pytorch/pull/154775
    2025-05-30 21:36:51 -07:00

af65ce5622  Revert "[cutlass backend] Cache config generation locally and remotely (#154686)"
    This reverts commit a4b0023f3b3c28735ac7d4adede6343746c8c946.
    2025-05-30 21:36:51 -07:00
6 changed files with 24 additions and 141 deletions

test/inductor/test_cutlass_backend.py (View File)

@@ -1654,7 +1654,7 @@ class TestCutlassBackend(TestCase):
             ),
         ),
     )
-    @parametrize("has_bias", (False,))
+    @parametrize("has_bias", (False, True))
     @parametrize("use_fast_accum", (False,))
     def test_fp8_rowwise_scaling(
         self,
@@ -1726,7 +1726,7 @@ class TestCutlassBackend(TestCase):
             ),
         ),
     )
-    @parametrize("has_bias", (False,))
+    @parametrize("has_bias", (False, True))
     @parametrize("use_fast_accum", (False,))
     def test_fp8_tensorwise_scaling(
         self,
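
Note: parametrizing has_bias over (False, True) turns on coverage for the fp8 GEMM path with a bias operand in both tests. A minimal sketch of the kind of call this enables, with made-up shapes and tensorwise scaling (torch._scaled_mm is a plausible eager-mode analogue of what the test compiles; the exact parameters live in the test body):

    import torch

    M, K, N = 128, 64, 96  # illustrative shapes
    x = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
    w = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()  # column-major operand
    scale_a = torch.tensor(1.0, device="cuda")
    scale_b = torch.tensor(1.0, device="cuda")
    bias = torch.randn(N, device="cuda", dtype=torch.bfloat16)

    # bias is the operand the Cutlass backend now has to carry through its epilogue
    y = torch._scaled_mm(x, w, scale_a, scale_b, bias=bias, out_dtype=torch.bfloat16)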

torch/_inductor/codegen/cuda/cuda_kernel.py (View File)

@@ -242,7 +242,6 @@ class CUDATemplateKernel(CUDAKernel):
         self,
         inputs: list[IRNode],
         outputs: list[IRNode],
-        epilogue_inputs: list[IRNode],
         names_str: str = "",
         input_reorder: Optional[list[int]] = None,
     ) -> str:
@@ -260,7 +259,7 @@ class CUDATemplateKernel(CUDAKernel):
         In this case, the `input_reorder` would be [2, 0, 1].
         """
         names = [x.strip() for x in names_str.strip().split(",")]
-        if len(inputs) + len(epilogue_inputs) + len(outputs) != len(names):
+        if len(inputs) + len(outputs) != len(names):
             raise RuntimeError(
                 f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}"
             )
@@ -277,13 +276,6 @@
                 self.named_nodes[name] = node
                 self.args.input_buffers[node.get_name()] = name

-        for epilogue_input in epilogue_inputs:
-            if epilogue_input is not None:
-                self.named_nodes[epilogue_input.get_name()] = epilogue_input
-                self.args.input_buffers[epilogue_input.get_name()] = (
-                    epilogue_input.get_name()
-                )
-
         for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs):
             if node is not None:
                 self.named_nodes[name] = node
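
With the epilogue_inputs parameter removed, names_str must now match exactly len(inputs) + len(outputs), and epilogue operands are folded into the regular inputs list by the caller (see the gemm_template changes below). The input_reorder semantics described in the docstring are unchanged; a toy illustration with made-up names:

    # Template argument order (bias, X, W) vs. IR input order [X, W, bias]:
    # names[i] corresponds to inputs[input_reorder[i]], hence [2, 0, 1].
    names = ["bias", "X", "W"]
    ir_inputs = ["X", "W", "bias"]
    input_reorder = [2, 0, 1]
    assert [ir_inputs[i] for i in input_reorder] == names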

torch/_inductor/codegen/cuda/cutlass_cache.py (deleted) (View File)

@@ -1,93 +0,0 @@
-# mypy: allow-untyped-defs
-import functools
-import hashlib
-import json
-import logging
-import os
-import time
-from typing import Any, Optional
-
-import torch._inductor.config as config
-from torch._inductor.codecache import cutlass_key
-from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch, get_cuda_version
-from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer
-from torch._inductor.runtime.cache_dir_utils import cache_dir
-from torch._inductor.utils import clear_on_fresh_inductor_cache
-
-log = logging.getLogger(__name__)
-
-CONFIG_PREFIX: str = "configs"
-
-
-def get_config_request_key(
-    arch: str,
-    cuda_version: str,
-    instantiation_level: str,
-) -> str:
-    """
-    Return a key for the full ops, based on cutlass key, arch, cuda version, and instantiation level.
-    """
-    hash_target = "-".join(
-        [
-            cutlass_key().decode(),
-            arch,
-            cuda_version,
-            instantiation_level,
-        ]
-    )
-    return hashlib.sha256(hash_target.encode("utf-8")).hexdigest()[0:8]
-
-
-def _generate_config_filename(request_key: str) -> str:
-    """
-    Generate a filename for the full ops.
-    """
-    return f"{CONFIG_PREFIX}_{request_key}.json"
-
-
-@clear_on_fresh_inductor_cache
-@functools.lru_cache(None)
-def maybe_fetch_ops() -> Optional[list[Any]]:
-    """
-    Fetch ops from databases.
-    """
-    if config.force_disable_caches:
-        return None
-
-    # setup
-    arch: str = get_cuda_arch()
-    # get_cuda_version might return "12.4.0" or "12.4"
-    # but we want to use "12.4"
-    version: str = ".".join(get_cuda_version().split(".")[:2])
-    instantiation_level: str = config.cuda.cutlass_instantiation_level
-
-    # filename and filepath
-    request_key: str = get_config_request_key(arch, version, instantiation_level)
-    filename: str = _generate_config_filename(request_key)
-    filepath: str = os.path.join(cache_dir(), filename)
-
-    # try fetch
-    serialized_ops: Optional[list[str]] = None
-    start_time = time.time()
-    if os.path.isfile(filepath):
-        # locally
-        with open(filepath) as f:
-            serialized_ops = json.load(f)
-    elif config.is_fbcode():
-        from torch._inductor.fb.cutlass_remote_cache import (
-            maybe_fetch_cutlass_configs_from_remote,
-        )
-
-        # from remote
-        serialized_ops = maybe_fetch_cutlass_configs_from_remote(filepath)
-
-    if serialized_ops is None:
-        return None
-
-    # deserialize
-    serializer = get_cutlass_operation_serializer()
-    full_ops = [serializer.deserialize(x) for x in serialized_ops]  # type: ignore[union-attr]
-    log.info("Loaded ops from %s cache in %.3fs", filename, time.time() - start_time)
-    return full_ops
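
For reference, the reverted module keyed its on-disk config cache by a short SHA-256 digest of the CUTLASS source key, GPU arch, CUDA version, and instantiation level. A standalone sketch of the naming scheme with made-up values:

    import hashlib

    # Hypothetical inputs; cutlass_key() normally hashes the bundled CUTLASS sources.
    hash_target = "-".join(["deadbeef", "90", "12.4", "3332"])
    request_key = hashlib.sha256(hash_target.encode("utf-8")).hexdigest()[0:8]
    print(f"configs_{request_key}.json")  # the filename looked up under cache_dir()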

torch/_inductor/codegen/cuda/cutlass_python_evt.py (View File)

@@ -226,7 +226,7 @@ class CutlassEVTCodegen(CutlassEVTOpsMixIn):
         return dict(self.var_name_to_buffer_name)

     def get_reads(self) -> list[str]:
-        return list(self.reads)
+        return list(self.reads.difference(self.store_name_to_value.keys()))

     def get_writes(self) -> list[str]:
         return list(self.store_name_to_value.keys())
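
The get_reads() change stops reporting buffers that the epilogue itself stores: a value written and then re-read inside the fused epilogue is an intermediate, not an external input. A toy illustration with assumed buffer names:

    reads = {"buf0", "buf1", "tmp"}
    store_name_to_value = {"tmp": None, "buf2": None}  # "tmp" is produced by the epilogue
    external_reads = reads.difference(store_name_to_value.keys())
    assert external_reads == {"buf0", "buf1"}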

torch/_inductor/codegen/cuda/gemm_template.py (View File)

@@ -10,7 +10,6 @@ from typing import Any, Optional, Union

 import torch
 import torch.utils._pytree as pytree
-from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops
 from torch._inductor.scheduler import BaseSchedulerNode
 from torch._inductor.select_algorithm import create_inputs_key
 from torch._inductor.utils import clear_on_fresh_inductor_cache
@@ -26,7 +25,7 @@ from ...ir import (
     Layout,
     ReinterpretView,
 )
-from ...utils import is_dynamic, OrderedSet, Placeholder
+from ...utils import is_dynamic, Placeholder
 from ...virtualized import V
 from ..common import IndentedBuffer
 from . import cutlass_utils
@@ -437,7 +436,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
         )
         self.alpha = alpha
         self.beta = beta
-        assert len(input_nodes) == 2 or len(input_nodes) == 3 or len(input_nodes) == 4
+        assert len(input_nodes) >= 2 and len(input_nodes) <= 5
         assert self._are_inputs_layout_compatible(
             [node.get_layout() for node in input_nodes]
         )
@@ -931,14 +930,8 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
             log.debug("Using cached ops for %s", self.cache_key)
             return self.filtered_ops_cache[self.cache_key]

-        maybe_ops = maybe_fetch_ops()
-        if maybe_ops is None:
-            log.debug("Cannot fetch ops from cache, generating ops from scratch")
-            full_ops = cutlass_utils.gen_ops()
-            ops = pytree.tree_flatten(full_ops)[0]
-        else:
-            log.debug("Using cached ops from cache")
-            ops = maybe_ops
+        full_ops = cutlass_utils.gen_ops()
+        ops = pytree.tree_flatten(full_ops)[0]

         res: dict[str, cutlass_gemm_op.GemmOperation] = {}
         start_time = time.time()
@@ -1049,8 +1042,9 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):

             # to make op mutable without affecting others
             op = copy.deepcopy(op)
-            if Bias is not None:
-                assert Bias.get_layout().dtype == X.get_layout().dtype
+            is_scaled_mm = len(self.input_nodes) in (4, 5)
+            if Bias is not None and not is_scaled_mm:
+                assert Bias.get_dtype() == X.get_dtype()
                 # This might have been set to void during filtering, when the assumption was still that there's no C
                 # operand
                 op.C.element = op.A.element
@ -1071,37 +1065,32 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
op = self.swap_XW(op)
should_swap_xw = True
is_scaled_mm = len(self.input_nodes) == 4
if epilogue_nodes or is_scaled_mm:
if epilogue_nodes:
(
evt_read_names,
evt_write_names,
input_names,
output_names,
var_name_to_buffer_name,
evt_py_code,
) = CutlassEVTCodegen.ir_to_evt_python_code(
Y.get_name(), epilogue_nodes, V.kernel.removed_buffers
)
D_output_name = var_name_to_buffer_name["D"]
name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs
D_output_buffer = name_to_buffer[D_output_name]
D_dtype = D_output_buffer.get_dtype()
Y = D_output_buffer # type: ignore[assignment]
# Interestingly, I don't think the rest of the layout matters here since we
# use the properties of the Y buffer to fill in D's properties in the epilogue
# args. This is needed though because it defines types expected in the epilogue args.
op.D.element = cutlass_utils.torch_dtype_to_cutlass_type(
D_output_buffer.get_layout().dtype
D_output_buffer.get_dtype()
)
read_names = OrderedSet(evt_read_names) - OrderedSet(evt_write_names)
write_names = OrderedSet(evt_write_names)
assert write_names, "There should be at least one write"
assert output_names, "There should be at least one write"
input_names = list(read_names)
output_names = list(write_names)
epilogue_inputs = [name_to_buffer[name] for name in input_names]
epilogue_outputs = [name_to_buffer[name] for name in output_names]
outputs = [name_to_buffer[name] for name in output_names]
else: # Scaled MM, we read the two scale matrices and write a single output
(
evt_read_names,
@@ -1115,9 +1104,8 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
                 input_names = list(evt_read_names)
                 output_names = []  # We only need Y
-                D_dtype = Y.get_layout().dtype
                 epilogue_inputs = [self.input_nodes[2], self.input_nodes[3]]
-                epilogue_outputs = []
+                outputs = []

             acc_dtype = cutlass_utils.get_accumulator_dtype(
                 [X.get_dtype(), W.get_dtype()]
@@ -1128,15 +1116,15 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
                 op,
                 evt_py_code,
                 var_name_to_buffer_name,
-                D_dtype,
+                Y.get_dtype(),
                 acc_dtype,
             )

             inputs = [
                 X,
                 W,
-                Bias,
                 *epilogue_inputs,  # type: ignore[list-item]
+                Bias,
                 Y,
                 *extra_inputs,
             ]
@@ -1145,15 +1133,13 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
             )
         else:
             evt_name = None
-            epilogue_inputs = []
-            epilogue_outputs = [Y]
+            outputs = [Y]
             evt_args = f"{{ElementComputeEpilogue({self.alpha}), ElementComputeEpilogue({self.beta})}}"
             evt_code = ""

         kernel_call_signature = kernel.def_kernel(
             inputs=inputs,  # type: ignore[arg-type]
-            outputs=epilogue_outputs,  # type: ignore[arg-type]
-            epilogue_inputs=[],
+            outputs=outputs,  # type: ignore[arg-type]
             names_str=names_str,
             input_reorder=input_reorder,
         )
@@ -1315,7 +1301,7 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
         Returns:
             bool: True if layouts are GEMM compatible, otherwise False.
         """
-        assert len(layouts) == 2 or len(layouts) == 3 or len(layouts) == 4
+        assert len(layouts) >= 2 and len(layouts) <= 5
         # Check if A and B are compatible
         A_layout, B_layout = layouts[:2]
         if len(A_layout.size) < 1:
@@ -1495,7 +1481,7 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
         self,
         op: "cutlass_gemm_op.GemmOperation" = None,  # type: ignore[name-defined]  # noqa: F821
     ) -> tuple[Optional[Buffer], list[Optional[Buffer]], list[str]]:
-        Bias = None if len(self.input_nodes) in (2, 4) else self.input_nodes[2]
+        Bias = self.input_nodes[2] if len(self.input_nodes) == 3 else None
         inputs: list[Optional[Buffer]] = []
         names: list[str] = []
         return (Bias, inputs, names)
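
Taken together, the widened asserts (2 to 5 input nodes) and the new Bias selection imply the following input_nodes layouts; this mapping is inferred from the asserts and the scaled-mm handling above, not from a documented contract:

    # [X, W]                          plain GEMM
    # [X, W, Bias]                    GEMM with bias as the C operand
    # [X, W, scale_a, scale_b]        scaled (fp8) MM without bias
    # [X, W, scale_a, scale_b, Bias]  scaled (fp8) MM with bias (new in this stack)
    def extract_bias(input_nodes):
        # Mirrors the updated selection: only the 3-input form carries a
        # classic C-operand bias; a scaled-mm bias is routed through the EVT.
        return input_nodes[2] if len(input_nodes) == 3 else None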

torch/_inductor/codegen/cuda/serialization.py (View File)

@@ -1,6 +1,5 @@
 # mypy: allow-untyped-defs
 import enum
-import functools
 import json
 from enum import Enum
 from typing import Optional
@@ -459,7 +458,6 @@ class CUTLASSOperationSerializer:
         return enum_class[json_dict["name"]]


-@functools.lru_cache(1)
 def get_cutlass_operation_serializer() -> Optional[CUTLASSOperationSerializer]:
     if not try_import_cutlass():
         return None
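
Dropping @functools.lru_cache(1) as part of the revert means each call now builds a fresh serializer instead of returning a memoized instance; any caller sensitive to construction cost would need to cache the result itself. A sketch of the observable difference, assuming CUTLASS imports successfully and no other caching layer exists:

    s1 = get_cutlass_operation_serializer()
    s2 = get_cutlass_operation_serializer()
    # Before the revert, s1 is s2 held; now each call returns a new object.
    assert (s1 is None) or (s1 is not s2)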