Compare commits

...

4 Commits

Author SHA1 Message Date
0d87736a9d [Cutlass] Include fp8 headers in aoti cpp wrapper
ghstack-source-id: 9e05cf0e9a4189851d357a87cc7b4397502f6fc8
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155173
2025-06-10 11:14:14 -07:00
bd673511f8 [Cutlass] allow filtering by fast_accum
ghstack-source-id: 15baa28ced00b73116ab7ce68bf8be5f580ad7db
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155195
2025-06-09 12:41:31 -07:00
9663fa8ff9 [Cutlass] EVT dynamic shapes support
ghstack-source-id: 760373aafb0488cd44dbfbc480eb393d0504a0bb
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154835

xx

xx
2025-06-05 02:41:37 -07:00
5fea0376fd [Cutlass] fp8 dynamic shapes test
ghstack-source-id: 748c8c17abcccf124c091efcbe98ce019e6149c9
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154829

xx

xx

xx
2025-06-05 02:41:36 -07:00
10 changed files with 299 additions and 58 deletions

View File

@ -446,21 +446,26 @@ class TestCutlassBackend(TestCase):
Main test for mm.
"""
class MyModel(torch.nn.Module):
def forward(self, a, b):
return a @ b
model = MyModel().cuda()
# M, N, K
shapes = [
(128, 128, 16),
(1024, 1024, 256),
]
shapes = shapes[0:1] if not dynamic else shapes
# M, N, K
shapes = shapes if dynamic else shapes[0:1]
class MyModel(torch.nn.Module):
def forward(self, a, b):
return a @ b
model = MyModel().cuda()
inputs = [
(torch.randn(M, K).cuda().to(dtype), torch.randn(K, N).cuda().to(dtype))
for (M, N, K) in shapes
]
dynamic_shapes = (
{
"a": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC},
@ -483,11 +488,100 @@ class TestCutlassBackend(TestCase):
model, inputs, dynamic_shapes=dynamic_shapes
)
else:
compiled_model = torch.compile(model, dynamic=dynamic)
compiled_model = torch.compile(model, dynamic=True)
actual = [compiled_model(*input) for input in inputs]
torch.testing.assert_close(actual, expected)
@unittest.skipIf(not SM90OrLater, "need sm_90")
@parametrize("dynamic", (False, True))
@parametrize("use_aoti", (False, True))
@parametrize("dtype", (torch.float8_e4m3fn,))
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_max_autotune_cutlass_backend_fp8_scaled_mm(
self,
dynamic: bool,
max_autotune_gemm_backends: str = "CUTLASS",
use_aoti: bool = False,
dtype: torch.dtype = torch.float16,
):
"""
Main test for mm.
"""
# M, N, K
shapes = [
(128, 128, 16),
(1024, 1024, 256),
]
# M, N, K
shapes = shapes if dynamic else shapes[0:1]
inputs = []
for shape in shapes:
M, N, K = shape
output_dtype = torch.bfloat16
device = "cuda"
x = torch.randn(M, K, dtype=output_dtype, device=device)
w = torch.randn(N, K, dtype=output_dtype, device=device)
# quantize weight (prior to inference)
w_fp8, w_inverse_scale = _quantize_rowwise(w, dtype)
w_t_fp8 = w_fp8.t()
w_inverse_scale = w_inverse_scale.t() # scale_b should be (1, N)
# quantize input x
x_fp8, x_inverse_scale = _quantize_rowwise(x, dtype)
inputs.append((x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale))
class MyModel(torch.nn.Module):
def forward(self, x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale):
y = torch._scaled_mm(
x_fp8,
w_t_fp8,
x_inverse_scale,
w_inverse_scale,
None,
out_dtype=torch.bfloat16,
use_fast_accum=False,
)
return y
dynamic_shapes = (
{
"x_fp8": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC},
"x_inverse_scale": {0: Dim.DYNAMIC, 1: 1},
"w_t_fp8": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC},
"w_inverse_scale": {0: 1, 1: Dim.DYNAMIC},
}
if dynamic
else None
)
model = MyModel().cuda()
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": max_autotune_gemm_backends,
"cuda.cutlass_max_profiling_configs": 2,
"benchmark_epilogue_fusion": False, # EVT doesn't support benchmark fusion yet
"cuda.cutlass_tma_only": True,
}
), dynamo_config.patch({"error_on_recompile": dynamic}):
expected = [model(*input) for input in inputs]
if use_aoti:
actual = AOTIRunnerUtil.run_multiple(
model, inputs, dynamic_shapes=dynamic_shapes
)
else:
compiled_model = torch.compile(model, dynamic=True)
actual = [compiled_model(*input) for input in inputs]
torch.testing.assert_close(actual, expected, rtol=1e-2, atol=0.05)
@unittest.skipIf(not SM90OrLater, "need sm_90")
@parametrize("dynamic", (False, True))
@parametrize("use_aoti", (False, True))
@ -1052,6 +1146,93 @@ class TestCutlassBackend(TestCase):
cuda_template_count += 1
assert cuda_template_count > 0, "No CUDATemplateCaller choices"
@unittest.skipIf(not SM90OrLater, "need sm_90")
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_cutlass_backend_fp8_scaled_mm_fast_accum_filtering(
self,
):
float8_dtype = torch.float8_e4m3fn
# Only bf16 output type is supported for row-wise scaling, not fp32
output_dtype: torch.dtype = torch.bfloat16
device = "cuda"
M, K, N = 128, 128, 128 # Matmul Y = X [M, K] x W [N, K]
x = torch.randn(M, K, dtype=output_dtype, device=device)
w = torch.randn(N, K, dtype=output_dtype, device=device)
bias = None
# quantize weight (prior to inference)
w_fp8, w_inverse_scale = _quantize_rowwise(w, float8_dtype)
w_t_fp8 = w_fp8.t()
w_inverse_scale = w_inverse_scale.t() # scale_b should be (1, N)
# quantize input x
x_fp8, x_inverse_scale = _quantize_rowwise(x, float8_dtype)
def linear(
x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias, use_fast_accum
):
y = torch._scaled_mm(
x_fp8,
w_t_fp8,
x_inverse_scale,
w_inverse_scale,
bias,
out_dtype=output_dtype,
use_fast_accum=use_fast_accum,
)
return y
linear_compiled = torch.compile(linear, backend="inductor")
def select_no_algorithm(*args, **kwargs):
raise NoValidChoicesError
def run_test(use_fast_accum):
with fresh_inductor_cache():
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTLASS",
"cuda.cutlass_max_profiling_configs": 2,
}
):
with mock.patch(
"torch._inductor.kernel.mm.autotune_select_algorithm",
wraps=select_no_algorithm,
) as sa:
with self.assertRaisesRegex(
InductorError, r".*NoValidChoicesError.*"
):
linear_compiled(
x_fp8,
x_inverse_scale,
w_t_fp8,
w_inverse_scale,
bias,
use_fast_accum,
)
args, _ = sa.call_args
_, choices, _, _ = args
cuda_template_count = 0
for choice in choices:
if isinstance(choice, CUDATemplateCaller):
choice_info = choice.info_dict()
op_conf_name = choice_info.get("op_conf_name", "")
assert isinstance(op_conf_name, str)
if use_fast_accum:
assert (
"fastaccum" in op_conf_name
), "Only fastaccum Kernels should have been allowed"
else:
assert (
"fastaccum" not in op_conf_name
), "fastaccum Kernels should have been filtered"
cuda_template_count += 1
assert cuda_template_count > 0, "No CUDATemplateCaller choices"
run_test(True)
run_test(False)
@unittest.skipIf(not SM90OrLater, "need sm_90")
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_cutlass_backend_shape_coverage_mm(
@ -1566,7 +1747,10 @@ class TestCutlassBackend(TestCase):
@unittest.skipIf(not SM90OrLater, "need sm_90")
@use_evt_config
@evt_all_ops
def test_evt_multi_output(self, op):
@parametrize(
"dynamic", (False, True)
) # To not drastically increase test time we only test dynamic on this test
def test_evt_multi_output(self, op, dynamic):
class TestModel(torch.nn.Module):
def forward(self, a, b, extra_args):
acc = a @ b
@ -1577,18 +1761,24 @@ class TestCutlassBackend(TestCase):
M = 1024
N = 512
a = torch.ones(M, N).cuda().half()
b = torch.ones(N, N).cuda().half()
extra_args = gen_args(op, (M, N))
model = TestModel().cuda()
shapes = [(512, 512)] if not dynamic else [(1024, 64), (128, 256)]
for i, shape in enumerate(shapes):
M, N = shape
a = torch.ones(M, N).cuda().half()
b = torch.ones(N, N).cuda().half()
extra_args = gen_args(op, (M, N))
model = TestModel().cuda()
result = torch.compile(model)(a, b, extra_args)
ref_result = model(a, b, extra_args)
result = torch.compile(model)(a, b, extra_args)
ref_result = model(a, b, extra_args)
self.assertEqual(
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 2
)
torch.testing.assert_close(result, ref_result)
self.assertEqual(
torch._dynamo.utils.counters["inductor"][
"cuda_epilogue_fusion_counter"
],
2 * (i + 1),
)
torch.testing.assert_close(result, ref_result)
@unittest.skipIf(not SM90OrLater, "need sm_90")
@use_evt_config
@ -1648,9 +1838,9 @@ class TestCutlassBackend(TestCase):
"shape",
(
(
16,
16,
32,
512,
128,
64,
),
),
)
@ -1720,9 +1910,9 @@ class TestCutlassBackend(TestCase):
"shape",
(
(
16,
16,
32,
512,
128,
64,
),
),
)

View File

@ -347,7 +347,9 @@ return tmp_1, D""",
)
buffer_renames = {"buf0": "buf0", "buf1": "buf1", "acc": "buf0"}
name_to_buffer = {"buf0": row_major_buf0, "buf1": col_major_buf1}
result = create_example_tensors(buffer_renames, name_to_buffer)
result = create_example_tensors(
buffer_renames, name_to_buffer, lambda x: int(x)
)
self.assertEqual(result["acc"].shape, (3, 4, 1))
self.assertEqual(result["acc"].stride, (4, 1, 0))
self.assertEqual(
@ -370,7 +372,9 @@ return tmp_1, D""",
self.assertExpectedInline(
_render_argument_type(
epilogue_functor, _create_mock_buffer_name_map(EXAMPLE_TENSORS)
epilogue_functor,
_create_mock_buffer_name_map(EXAMPLE_TENSORS),
lambda x: int(x),
),
"""\
{ /* thread */
@ -425,7 +429,9 @@ def fn(accum, bias):
self.assertExpectedInline(
_render_argument_type(
epilogue_functor, _create_mock_buffer_name_map(example_tensors)
epilogue_functor,
_create_mock_buffer_name_map(example_tensors),
lambda x: int(x),
),
"""\
{ /* thread */
@ -450,6 +456,7 @@ def fn(accum, bias):
MockTileDescription(),
EpilogueScheduleType.ScheduleAuto,
_create_mock_buffer_name_map(EXAMPLE_TENSORS),
lambda x: x, # static shapes
)
self.assertExpectedInline(
code,

View File

@ -12,7 +12,7 @@ import torch._inductor.config as config
from torch import dtype as torch_dtype
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
from torch._inductor.scheduler import BaseSchedulerNode
from torch._inductor.utils import do_bench_using_profiling, Placeholder
from torch._inductor.utils import do_bench_using_profiling, OrderedSet, Placeholder
from torch.utils._sympy.value_ranges import ValueRanges
from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE
@ -81,6 +81,7 @@ class CUDAKernel(Kernel):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.layout_args: dict[str, list[LayoutArg]] = defaultdict(list)
self.size_args: list[Union[Expr, int]] = []
# Mapping from arg name to IRNode.
self.named_nodes: dict[str, IRNode] = {}
@ -172,6 +173,9 @@ class CUDAKernel(Kernel):
LDD = get_ld(Y)
return (M, N, K, B, LDA, LDB, LDC, LDD)
def get_dynamic_shape_args(self) -> list[Union[Expr, int]]:
return [*self.get_layout_args(), *self.size_args]
@staticmethod
def find_ld_idx(node: IRNode) -> int:
strides = node.get_stride()
@ -257,6 +261,7 @@ class CUDATemplateKernel(CUDAKernel):
e.g. The template might have input argument defined as [X, W, Bias],
and the actual input passed into this template could be [Bias, X, W].
In this case, the `input_reorder` would be [2, 0, 1].
additional_size_args: Additional size arguments for epilogue inputs
"""
names = [x.strip() for x in names_str.strip().split(",")]
if len(inputs) + len(outputs) != len(names):
@ -276,17 +281,30 @@ class CUDATemplateKernel(CUDAKernel):
self.named_nodes[name] = node
self.args.input_buffers[node.get_name()] = name
free_symbols: OrderedSet[Expr] = OrderedSet()
for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs):
if node is not None:
self.named_nodes[name] = node
self.args.output_buffers[node.get_name()] = name
if name not in (
"X",
"W",
"Bias",
"Y",
): # we handle these symbolic shapes explicitly
for expr in itertools.chain(node.get_size(), node.get_stride()):
if isinstance(expr, Expr):
for s in expr.free_symbols:
free_symbols.add(s) # type: ignore[arg-type]
arg_defs, *_ = self.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE)
self.init_layout_args()
size_args = [
f"const int {s}" for s in ("M", "N", "K", "B", "lda", "ldb", "ldc", "ldd")
]
size_vars = ["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"]
size_vars.extend(str(s) for s in free_symbols)
self.size_args.extend(free_symbols)
size_args = [f"const int {s}" for s in size_vars]
runtime_arg_decls = ",".join(
[f"{arg.ty} {arg.name}" for arg in self.runtime_arg_info]
@ -326,11 +344,11 @@ class CUDATemplateKernel(CUDAKernel):
else:
_, call_args, _, arg_types = self.args.python_argdefs()
layout_args = self.get_layout_args()
call_args.extend(layout_args) # type: ignore[arg-type]
dynamic_shape_args = self.get_dynamic_shape_args()
call_args.extend(dynamic_shape_args) # type: ignore[arg-type]
for arg in self.runtime_arg_values:
call_args.append(arg)
arg_types.extend("int" for a in layout_args)
arg_types.extend("int" for _ in dynamic_shape_args)
for arg in self.runtime_arg_info:
arg_types.append(arg.ty)
# dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar

View File

@ -116,7 +116,7 @@ class CUDATemplate(KernelTemplate):
expected_args,
)
V.graph.sizevars.size_hints(map(sympy.expand, call_args[len(expected_args) :]))
size_args = V.graph.sizevars.size_hints(kernel.get_layout_args())
size_args = V.graph.sizevars.size_hints(kernel.get_dynamic_shape_args())
extra_args = tuple(list(size_args) + self.get_runtime_arg_values(**kwargs))
kernel_hash = hashlib.sha256(code.encode("utf-8")).hexdigest()[:8]

View File

@ -1,4 +1,6 @@
from typing import Any, Union
from typing import Any, Callable, Union
from sympy import Expr
from torch._inductor.ir import (
ComputedBuffer,
@ -61,18 +63,13 @@ if try_import_cutlass():
def create_example_tensors(
var_name_to_buffer_name: dict[str, str],
name_to_buffer: dict[str, Buffer],
size_hint_fn: Callable[[Union[Expr, int]], int],
) -> dict[str, CutlassTensor]:
def cutlass_tensor_from_buffer(buffer: Buffer) -> CutlassTensor:
shape = buffer.get_layout().size
stride = buffer.get_layout().stride
assert all(isinstance(x, int) or x.is_integer for x in shape), (
f"{buffer.get_name()}'s shape {shape} contains symints which aren't supported for cutlass EVT"
)
assert all(isinstance(x, int) or x.is_integer for x in stride), (
f"{buffer.get_name()}'s stride {stride} contains symints which aren't supported for cutlass EVT"
)
shape = tuple(int(x) for x in shape)
stride = tuple(int(x) for x in stride)
shape = tuple(size_hint_fn(x) for x in shape)
stride = tuple(size_hint_fn(x) for x in stride)
is_row_major = is_contiguous_strides_for_shape(stride, shape)
is_column_major = is_contiguous_strides_for_shape(stride[::-1], shape[::-1])
@ -104,6 +101,7 @@ non-contiguous layout, recieved stride: {stride} and shape: {shape}"
tile_description: TileDescription,
epilogue_schedule: EpilogueScheduleType,
name_to_buffer: dict[str, Buffer],
size_hint_fn: Callable[[Union[Expr, int]], int],
**kwargs: dict[str, Any],
) -> tuple[str, str, str]:
cuda_arch = int(cuda_env.get_cuda_arch()) # type: ignore[arg-type]
@ -119,7 +117,7 @@ non-contiguous layout, recieved stride: {stride} and shape: {shape}"
fusion_callbacks,
)
evt_name, evt_code = collective_epilogue.emit()
evt_args = _render_argument_type(epilogue_functor, name_to_buffer)
evt_args = _render_argument_type(epilogue_functor, name_to_buffer, size_hint_fn)
return evt_name, evt_args, evt_code
# Based off of
@ -147,6 +145,7 @@ non-contiguous layout, recieved stride: {stride} and shape: {shape}"
def _render_argument_type(
epilogue_functor: EpilogueFunctor,
name_to_buffer: dict[str, Buffer],
size_hint_fn: Callable[[Union[Expr, int]], int],
) -> str:
epilogue_thread_type = epilogue_functor.epilogue_thread_type
@ -165,7 +164,10 @@ non-contiguous layout, recieved stride: {stride} and shape: {shape}"
buffer.writeline(f"{{}}, /* {name} */")
else:
fields = [
(fname, _get_arg_from_node(ty, name_to_buffer[name]))
(
fname,
_get_arg_from_node(ty, name_to_buffer[name], size_hint_fn),
)
for fname, ty in t._fields_
]
field_strs = [
@ -197,7 +199,9 @@ non-contiguous layout, recieved stride: {stride} and shape: {shape}"
return buffer.getvalue()
def _get_arg_from_node(arg_ty: type, node: Buffer) -> str:
def _get_arg_from_node(
arg_ty: type, node: Buffer, size_hint_fn: Callable[[Union[Expr, int]], int]
) -> str:
from ..cuda_template import CUTLASSTemplate
# Today, arguments are either a pointer to the
@ -209,7 +213,7 @@ non-contiguous layout, recieved stride: {stride} and shape: {shape}"
):
DEFAULT_STRIDE_LEN = 3
assert len(node.get_layout().stride) <= DEFAULT_STRIDE_LEN
stride = [int(x) for x in node.get_layout().stride]
stride = [size_hint_fn(x) for x in node.get_layout().stride]
for _ in range(DEFAULT_STRIDE_LEN - len(stride)):
stride.append(0)

View File

@ -314,7 +314,7 @@ DTYPE_TO_CUTLASS_TYPE = {
**DTYPE_TO_CPP,
torch.float16: "__half",
torch.bfloat16: "__nv_bfloat16",
torch.float8_e4m3fn: "cutlass::float_e4m3_t",
torch.float8_e4m3fn: "__nv_fp8_e4m3",
}

View File

@ -422,6 +422,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
alpha: float,
beta: float,
input_reorder: Optional[list[int]] = None,
use_fast_accum: Optional[bool] = None,
) -> None:
"""
Args:
@ -437,6 +438,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
)
self.alpha = alpha
self.beta = beta
self.use_fast_accum = use_fast_accum
assert 2 <= len(input_nodes) <= 5
assert self._are_inputs_layout_compatible(
[node.get_layout() for node in input_nodes]
@ -453,6 +455,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
alpha: Union[float, int] = 1,
beta: Union[float, int] = 0,
input_reorder: Optional[list[int]] = None,
use_fast_accum: Optional[bool] = None,
**extra_kwargs,
) -> None:
raise NotImplementedError
@ -559,6 +562,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
self.maybe_append_choice(
choices, description=description, op=op, swizzle=swizzle
)
if len(ops) == 0:
input_layouts = [node.get_layout() for node in input_nodes]
input_strides = [node.get_stride() for node in input_nodes]
@ -873,6 +877,11 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
# TODO: update epilogue functor according to epilogues.
op.element_epilogue = op.accumulator_type()
if self.use_fast_accum is not None:
is_op_fast_accum = "fastaccum" in op.configuration_name()
if self.use_fast_accum ^ is_op_fast_accum:
return None
# Set bias layout and alignment.
status = self._set_bias_layout_and_alignment(op)
if not status:
@ -1243,8 +1252,11 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
alpha: float,
beta: float,
input_reorder: Optional[list[int]] = None,
use_fast_accum: Optional[bool] = None,
):
super().__init__(input_nodes, layout, alpha, beta, input_reorder)
super().__init__(
input_nodes, layout, alpha, beta, input_reorder, use_fast_accum
)
@staticmethod
def add_cutlass_gemm_choices(
@ -1254,10 +1266,16 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
alpha: Union[float, int] = 1,
beta: Union[float, int] = 0,
input_reorder: Optional[list[int]] = None,
use_fast_accum: Optional[bool] = None,
**extra_kwargs,
) -> None:
template = CUTLASS3xGemmTemplate(
input_nodes, layout, alpha, beta, input_reorder
input_nodes,
layout,
alpha,
beta,
input_reorder,
use_fast_accum,
)
template._add_cutlass_gemm_choices(
choices, layout, input_nodes, alpha, beta, input_reorder, **extra_kwargs
@ -1390,6 +1408,7 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
examples = create_example_tensors(
var_name_to_buffer_name,
name_to_buffer, # type: ignore[arg-type]
V.graph.sizevars.size_hint,
)
evt_name, evt_args, evt_code = trace(
evt_py_code,
@ -1399,6 +1418,7 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
op.tile_description, # type: ignore[attr-defined]
op.epilogue_schedule, # type: ignore[attr-defined]
{k: name_to_buffer[v] for k, v in var_name_to_buffer_name.items()}, # type: ignore[arg-type,misc]
V.graph.sizevars.size_hint,
)
return (
@ -1624,6 +1644,7 @@ class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate):
alpha: Union[float, int] = 1,
beta: Union[float, int] = 0,
input_reorder: Optional[list[int]] = None,
use_fast_accum: Optional[bool] = False,
**extra_kwargs,
) -> None:
template = CUTLASS2xGemmTemplate(

View File

@ -1158,12 +1158,12 @@ def tuned_scaled_mm(
)
if is_nonzero and use_cutlass_template(layout, m, n, k):
if use_fast_accum:
log.warning(
"use_fast_accum=True is not supported by cutlass template, skipping cutlass choices"
)
else:
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, input_nodes) # type: ignore[arg-type]
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
choices,
layout,
input_nodes, # type: ignore[arg-type]
use_fast_accum=use_fast_accum, # type: ignore[arg-type]
)
if is_nonzero and use_ck_gemm_template(layout, m, n, k):
CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes)

View File

@ -574,7 +574,7 @@ class SizeVarAllocator:
def size_hints(
self,
exprs: Iterable[Expr],
exprs: Iterable[Union[Expr, int]],
*,
fallback: Optional[int] = None,
) -> tuple[int, ...]:

View File

@ -12,6 +12,7 @@
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#endif
namespace torch::aot_inductor {