[inductor] Add shape to load_input in matmul templates (#162513)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162513
Approved by: https://github.com/eellison
ghstack dependencies: #162426
This commit is contained in:
Isuru Fernando
2025-09-10 20:06:51 +00:00
committed by PyTorch MergeBot
parent f17c5e0789
commit f654cff566
3 changed files with 31 additions and 17 deletions

View File

@ -254,7 +254,7 @@ class PartialRender:
class SubgraphInfo:
body: IndentedBuffer
template_mask: Optional[str] = None
template_out: Optional[str] = None
template_out_shape: Optional[Union[str, tuple[str]]] = None
compute: IndentedBuffer = dataclasses.field(default_factory=IndentedBuffer)
indexing_code: IndentedBuffer = dataclasses.field(default_factory=IndentedBuffer)
loads: IndentedBuffer = dataclasses.field(default_factory=IndentedBuffer)
@ -445,7 +445,7 @@ class TritonTemplateKernel(TritonKernel):
self.loads: IndentedBuffer = FakeIndentedBuffer()
self.stores: IndentedBuffer = FakeIndentedBuffer()
self.template_mask: Optional[str] = None
self.template_out: Optional[str] = None
self.template_out_shape: Optional[Union[str, tuple[str]]] = None
self.ops_handler: Optional[V.WrapperHandler] = None # type: ignore[name-defined]
# When caching is enabled, the generated code is not dependent on the input nodes names, or
@ -841,6 +841,7 @@ class TritonTemplateKernel(TritonKernel):
mask: Optional[str] = None,
other: Optional[Union[float, int]] = 0.0,
indent_width: int = 4,
index_shape: Optional[tuple[str]] = None,
):
"""Loads an input and applies any necessary preprocessing or masking.
@ -918,7 +919,7 @@ class TritonTemplateKernel(TritonKernel):
# We are using "None" for clarity in output code, but
# we could alternatively emit `xmask = tl.full([xindex.shape], True, tl.int1)`
self.template_mask = mask if mask is not None else "None"
self.template_out = "xindex"
self.template_out_shape = index_shape if index_shape else "xindex"
self.template_indices = indices
self.named_input_nodes[input_name].data.freeze_layout()
self.cse.invalidate(OrderedSet())
@ -981,7 +982,7 @@ class TritonTemplateKernel(TritonKernel):
else:
out_indexing = self.indexing(
output_index,
copy_shape=self.template_out,
copy_shape=self.template_out_shape,
override_mask=self.template_mask,
)
from .codegen.triton import IndexingOptions
@ -1020,7 +1021,7 @@ class TritonTemplateKernel(TritonKernel):
val: str,
mask: Optional[str] = None,
indent_width: int = 4,
val_shape: Optional[list[str]] = None,
val_shape: Optional[tuple[str]] = None,
):
"""Stores the final output and appends any epilogue fusions if the buffer hasn't been optimized away.
@ -1059,7 +1060,7 @@ class TritonTemplateKernel(TritonKernel):
"xindex"
)
self.template_mask = mask
self.template_out = val
self.template_out_shape = val_shape if val_shape else val
self.template_indices = indices
output_index = self.output_node.get_layout().make_indexer()(index_symbols)
output_index = self.rename_indexing(output_index)
@ -1209,7 +1210,7 @@ class TritonTemplateKernel(TritonKernel):
dense_indexing=False,
# We pass template_out as the shape to broadcast the indexing to as
# the mask might be broadcast to the output shape
copy_shape=self.template_out,
copy_shape=self.template_out_shape,
override_mask=self.template_mask,
block_ptr=block_ptr,
tma_compatibility_checker=tma_compatibility_checker,