Compare commits

...

1 Commit

Author SHA1 Message Date
d76a1d38f2 Initial commit 2025-08-06 13:30:16 -07:00
2 changed files with 13 additions and 4 deletions

View File

@@ -67,9 +67,12 @@ if try_import_cutlass():
name_to_buffer: dict[str, Buffer], name_to_buffer: dict[str, Buffer],
size_hint_fn: Callable[[Union[Expr, int]], int], size_hint_fn: Callable[[Union[Expr, int]], int],
) -> dict[str, python_cutlass.backend.evt.ir.tensor.Tensor]: ) -> dict[str, python_cutlass.backend.evt.ir.tensor.Tensor]:
key = None
def cutlass_tensor_from_buffer( def cutlass_tensor_from_buffer(
buffer: Buffer, buffer: Buffer,
) -> python_cutlass.backend.evt.ir.tensor.Tensor: ) -> python_cutlass.backend.evt.ir.tensor.Tensor:
nonlocal key
shape = buffer.get_layout().size shape = buffer.get_layout().size
stride = buffer.get_layout().stride stride = buffer.get_layout().stride
shape = tuple(size_hint_fn(x) for x in shape) shape = tuple(size_hint_fn(x) for x in shape)
@@ -84,6 +87,10 @@
non-contiguous layout, received stride: {stride} and shape: {shape}" non-contiguous layout, received stride: {stride} and shape: {shape}"
) )
print(
f"{key}: {shape}, {LayoutType.RowMajor if is_row_major else LayoutType.ColumnMajor}, {torch_dtype_to_cutlass_type(buffer.get_layout().dtype)}"
)
return python_cutlass.backend.evt.ir.tensor.Tensor( return python_cutlass.backend.evt.ir.tensor.Tensor(
shape=shape, shape=shape,
layout_tag=( layout_tag=(
@@ -92,10 +99,11 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
element=torch_dtype_to_cutlass_type(buffer.get_layout().dtype), element=torch_dtype_to_cutlass_type(buffer.get_layout().dtype),
) )
return { out = dict()
key: cutlass_tensor_from_buffer(name_to_buffer[name]) for key, name in var_name_to_buffer_name.items():
for key, name in var_name_to_buffer_name.items() out[key] = cutlass_tensor_from_buffer(name_to_buffer[name])
}
return out
def trace( def trace(
fn_src: str, fn_src: str,

View File

@@ -1502,6 +1502,7 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
acc_dtype = torch_dtype_to_cutlass_type(accumulator_dtype) acc_dtype = torch_dtype_to_cutlass_type(accumulator_dtype)
output_dtype = torch_dtype_to_cutlass_type(output_dtype) output_dtype = torch_dtype_to_cutlass_type(output_dtype)
print(evt_py_code)
examples = create_example_tensors( examples = create_example_tensors(
var_name_to_buffer_name, var_name_to_buffer_name,
name_to_buffer, # type: ignore[arg-type] name_to_buffer, # type: ignore[arg-type]