Compare commits

...

2 Commits

4a99eee0d7 [Cutlass-EVT] Fix buffer size issues
ghstack-source-id: ef63de21b4f4717afa29c95be5f2c25c0eb48dfa
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161335
2025-08-25 11:16:19 -07:00
e79ce51cbd [Cutlass] Fix regression from f7ad69f
ghstack-source-id: d50b70e33734407431a9259b75c3234c19441aab
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161398
2025-08-25 11:16:19 -07:00
2 changed files with 39 additions and 36 deletions

View File

@@ -345,29 +345,31 @@ return tmp_1, D""",
         from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import (
            create_example_tensors,
         )
+        from torch._inductor.virtualized import V
 
-        row_major_buf0 = MockComputedBuffer(
-            "buf0", None, torch.float32, (3, 4, 1), (4, 1, 0)
-        )
-        col_major_buf1 = MockComputedBuffer(
-            "buf1", None, torch.float32, (3, 2, 1), (1, 3, 0)
-        )
-        buffer_renames = {"buf0": "buf0", "buf1": "buf1", "acc": "buf0"}
-        name_to_buffer = {"buf0": row_major_buf0, "buf1": col_major_buf1}
-        result = create_example_tensors(
-            buffer_renames, name_to_buffer, lambda x: int(x)
-        )
-        self.assertEqual(result["acc"].shape, (3, 4, 1))
-        self.assertEqual(result["acc"].stride, (4, 1, 0))
-        self.assertEqual(
-            result["acc"].element, torch_dtype_to_cutlass_type(torch.float32)
-        )
+        with V.set_graph_handler(MockGraphHandler({})):
+            row_major_buf0 = MockComputedBuffer(
+                "buf0", None, torch.float32, (3, 4, 1), (4, 1, 0)
+            )
+            col_major_buf1 = MockComputedBuffer(
+                "buf1", None, torch.float32, (3, 2, 1), (1, 3, 0)
+            )
+            buffer_renames = {"buf0": "buf0", "buf1": "buf1", "acc": "buf0"}
+            name_to_buffer = {"buf0": row_major_buf0, "buf1": col_major_buf1}
+            result = create_example_tensors(
+                buffer_renames, name_to_buffer, lambda x: int(x)
+            )
+            self.assertEqual(result["acc"].shape, (3, 4, 1))
+            self.assertEqual(result["acc"].stride, (4, 1, 0))
+            self.assertEqual(
+                result["acc"].element, torch_dtype_to_cutlass_type(torch.float32)
+            )
-        self.assertEqual(result["buf1"].shape, (3, 2, 1))
-        self.assertEqual(result["buf1"].stride, (1, 3, 0))
-        self.assertEqual(
-            result["buf1"].element, torch_dtype_to_cutlass_type(torch.float32)
-        )
+            self.assertEqual(result["buf1"].shape, (3, 2, 1))
+            self.assertEqual(result["buf1"].stride, (1, 3, 0))
+            self.assertEqual(
+                result["buf1"].element, torch_dtype_to_cutlass_type(torch.float32)
+            )
 
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @unittest.skipIf(not try_import_cutlass(), "requires cutlass")

View File

@@ -1168,6 +1168,10 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
             op = self.swap_XW(op)
             should_swap_xw = True
 
+        name_to_buffer = {node.get_name(): node for node in self.input_nodes}
+        # handle the fake output buffer during lowering
+        name_to_buffer[Y.get_name()] = Y  # type: ignore[assignment]
+
         if epilogue_nodes or is_scaled_mm:
             if epilogue_nodes:
                 (
@@ -1179,12 +1183,15 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
                     Y.get_name(), epilogue_nodes, V.kernel.removed_buffers
                 )
+                # TODO: mlazos remove this by returning buffer metadata from
+                # ir_to_evt_python code
+                for name, buf in (
+                    V.graph.name_to_buffer | V.graph.graph_inputs
+                ).items():
+                    if name not in name_to_buffer:
+                        name_to_buffer[name] = buf  # type: ignore[assignment]
                 D_output_name = var_name_to_buffer_name["D"]
-                name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs
-                for name in V.graph.constants.keys():
-                    name_to_buffer[name] = V.graph.add_tensor_constant(
-                        V.graph.constants[name], name
-                    )
                 D_output_buffer = name_to_buffer[D_output_name]
                 Y = D_output_buffer  # type: ignore[assignment]
                 # Interestingly, I don't think the rest of the layout matters here since we
@@ -1229,6 +1236,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
                 op,
                 evt_py_code,
                 var_name_to_buffer_name,
+                name_to_buffer,
                 Y.get_dtype(),
                 acc_dtype,
             )
@@ -1327,6 +1335,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
         op: GemmOperation,
         evt_py_code: str,
         buffer_renames: dict[str, str],
+        name_to_buffer: dict[str, Buffer],
         output_dtype: torch.dtype,
         accumulator_dtype: torch.dtype,
     ) -> tuple[str, str, str, EVTArgRenames]:  # type: ignore[name-defined] # noqa: F821
@@ -1488,23 +1497,15 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
         op: GemmOperation,
         evt_py_code: str,
         var_name_to_buffer_name: dict[str, str],
+        name_to_buffer: dict[str, Buffer],
         output_dtype: torch.dtype,
         accumulator_dtype: torch.dtype,
     ) -> tuple[str, str, str, EVTArgRenames]:
         from .cutlass_lib_extensions.evt_extensions import create_example_tensors, trace
 
-        name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs
-        for name in V.graph.constants.keys():
-            name_to_buffer[name] = V.graph.add_tensor_constant(
-                V.graph.constants[name], name
-            )
-        # handle the fake output buffer during lowering
-        name_to_buffer[self.output_node.get_name()] = self.output_node  # type: ignore[assignment]
         acc_dtype = torch_dtype_to_cutlass_type(accumulator_dtype)
         output_dtype = torch_dtype_to_cutlass_type(output_dtype)
         examples = create_example_tensors(
             var_name_to_buffer_name,
             name_to_buffer,  # type: ignore[arg-type]