Compare commits

...

1 Commit

Author SHA1 Message Date
d76a1d38f2 Initial commit 2025-08-06 13:30:16 -07:00
2 changed files with 13 additions and 4 deletions

View File

@@ -67,9 +67,12 @@ if try_import_cutlass():
name_to_buffer: dict[str, Buffer],
size_hint_fn: Callable[[Union[Expr, int]], int],
) -> dict[str, python_cutlass.backend.evt.ir.tensor.Tensor]:
key = None
def cutlass_tensor_from_buffer(
buffer: Buffer,
) -> python_cutlass.backend.evt.ir.tensor.Tensor:
nonlocal key
shape = buffer.get_layout().size
stride = buffer.get_layout().stride
shape = tuple(size_hint_fn(x) for x in shape)
@@ -84,6 +87,10 @@ if try_import_cutlass():
non-contiguous layout, received stride: {stride} and shape: {shape}"
)
print(
f"{key}: {shape}, {LayoutType.RowMajor if is_row_major else LayoutType.ColumnMajor}, {torch_dtype_to_cutlass_type(buffer.get_layout().dtype)}"
)
return python_cutlass.backend.evt.ir.tensor.Tensor(
shape=shape,
layout_tag=(
@@ -92,10 +99,11 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
element=torch_dtype_to_cutlass_type(buffer.get_layout().dtype),
)
return {
key: cutlass_tensor_from_buffer(name_to_buffer[name])
for key, name in var_name_to_buffer_name.items()
}
out = dict()
for key, name in var_name_to_buffer_name.items():
out[key] = cutlass_tensor_from_buffer(name_to_buffer[name])
return out
def trace(
fn_src: str,

View File

@@ -1502,6 +1502,7 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
acc_dtype = torch_dtype_to_cutlass_type(accumulator_dtype)
output_dtype = torch_dtype_to_cutlass_type(output_dtype)
print(evt_py_code)
examples = create_example_tensors(
var_name_to_buffer_name,
name_to_buffer, # type: ignore[arg-type]