[inductor] Add shape to load_input in matmul templates (#162513)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162513
Approved by: https://github.com/eellison
ghstack dependencies: #162426

Committed by: PyTorch MergeBot
Parent: f17c5e0789
Commit: f654cff566
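
The change threads an explicit symbolic shape through template code generation: load_input gains an index_shape parameter, the kernel records it in a template_out_shape attribute, and later indexing calls receive that shape, rather than the output variable's name, as copy_shape. Below is a minimal sketch of the bookkeeping, not the real TritonTemplateKernel API; the class and method names here are illustrative:

    from typing import Optional, Union

    class ShapeBookkeeping:
        # Mirrors the two attributes in the diff below: the output variable's
        # name, and (new) its symbolic shape.
        template_out: Optional[str] = None
        template_out_shape: Optional[Union[str, tuple[str, ...]]] = None

        def record(self, out_var: str, shape: Optional[tuple[str, ...]] = None) -> None:
            self.template_out = out_var
            # Same fallback as the diff: use the explicit shape when the
            # template supplies one, otherwise fall back to the variable name.
            self.template_out_shape = shape if shape else out_var

    bk = ShapeBookkeeping()
    bk.record("xindex", ("XBLOCK",))   # template passed an explicit shape
    assert bk.template_out_shape == ("XBLOCK",)
    bk.record("xindex")                # no shape given: old behavior
    assert bk.template_out_shape == "xindex"
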
@@ -254,7 +254,7 @@ class PartialRender:
 class SubgraphInfo:
     body: IndentedBuffer
     template_mask: Optional[str] = None
     template_out: Optional[str] = None
+    template_out_shape: Optional[Union[str, tuple[str]]] = None
     compute: IndentedBuffer = dataclasses.field(default_factory=IndentedBuffer)
     indexing_code: IndentedBuffer = dataclasses.field(default_factory=IndentedBuffer)
     loads: IndentedBuffer = dataclasses.field(default_factory=IndentedBuffer)
@@ -445,7 +445,7 @@ class TritonTemplateKernel(TritonKernel):
         self.loads: IndentedBuffer = FakeIndentedBuffer()
         self.stores: IndentedBuffer = FakeIndentedBuffer()
         self.template_mask: Optional[str] = None
         self.template_out: Optional[str] = None
+        self.template_out_shape: Optional[Union[str, tuple[str]]] = None
         self.ops_handler: Optional[V.WrapperHandler] = None  # type: ignore[name-defined]

         # When caching is enabled, the generated code is not dependent on the input nodes names, or
@@ -841,6 +841,7 @@ class TritonTemplateKernel(TritonKernel):
         mask: Optional[str] = None,
         other: Optional[Union[float, int]] = 0.0,
         indent_width: int = 4,
+        index_shape: Optional[tuple[str]] = None,
     ):
         """Loads an input and applies any necessary preprocessing or masking.

@@ -918,7 +919,7 @@ class TritonTemplateKernel(TritonKernel):
             # We are using "None" for clarity in output code, but
             # we could alternatively emit `xmask = tl.full([xindex.shape], True, tl.int1)`
             self.template_mask = mask if mask is not None else "None"
             self.template_out = "xindex"
+            self.template_out_shape = index_shape if index_shape else "xindex"
             self.template_indices = indices
             self.named_input_nodes[input_name].data.freeze_layout()
             self.cse.invalidate(OrderedSet())
@@ -981,7 +982,7 @@ class TritonTemplateKernel(TritonKernel):
         else:
             out_indexing = self.indexing(
                 output_index,
-                copy_shape=self.template_out,
+                copy_shape=self.template_out_shape,
                 override_mask=self.template_mask,
             )
             from .codegen.triton import IndexingOptions
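
Inside indexing(), copy_shape tells codegen what shape to broadcast the computed index (and mask) to. Previously it received self.template_out, i.e. the name of an existing variable whose shape the emitted code would borrow; now it receives the recorded shape, which may be an explicit tuple of symbolic sizes. A hypothetical helper illustrating the difference in emitted Triton code; broadcast_index and the exact output strings are illustrative, not Inductor's actual implementation:

    from typing import Union

    def broadcast_index(index_expr: str, copy_shape: Union[str, tuple[str, ...]]) -> str:
        if isinstance(copy_shape, tuple):
            # New path: an explicit symbolic shape such as ("XBLOCK", "YBLOCK").
            return f"tl.broadcast_to({index_expr}, [{', '.join(copy_shape)}])"
        # Old path: borrow the runtime shape of a named variable.
        return f"tl.broadcast_to({index_expr}, {copy_shape}.shape)"

    print(broadcast_index("xindex", "acc"))
    # tl.broadcast_to(xindex, acc.shape)
    print(broadcast_index("xindex", ("XBLOCK", "YBLOCK")))
    # tl.broadcast_to(xindex, [XBLOCK, YBLOCK])
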
@@ -1020,7 +1021,7 @@ class TritonTemplateKernel(TritonKernel):
         val: str,
         mask: Optional[str] = None,
         indent_width: int = 4,
-        val_shape: Optional[list[str]] = None,
+        val_shape: Optional[tuple[str]] = None,
     ):
         """Stores the final output and appends any epilogue fusions if the buffer hasn't been optimized away.

@@ -1059,7 +1060,7 @@ class TritonTemplateKernel(TritonKernel):
             "xindex"
         )
         self.template_mask = mask
         self.template_out = val
+        self.template_out_shape = val_shape if val_shape else val
         self.template_indices = indices
         output_index = self.output_node.get_layout().make_indexer()(index_symbols)
         output_index = self.rename_indexing(output_index)
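
On the store side, val_shape's type changes from list[str] to tuple[str], matching the new index_shape parameter, so a template's epilogue can record the accumulator's block shape explicitly, e.g. store_output(("idx_m", "idx_n"), "acc", mask="mask", val_shape=("BLOCK_M", "BLOCK_N")) (a hypothetical call; the argument values are placeholders). One plausible reason for preferring a tuple, beyond consistency with index_shape, is that tuples are immutable and hashable, which matters if the recorded shape ever ends up in hashed or cached state; this is a guess at the motivation, not stated in the commit:

    shape_t = ("BLOCK_M", "BLOCK_N")
    print(hash(shape_t))              # fine: tuples are hashable
    try:
        hash(["BLOCK_M", "BLOCK_N"])  # lists are not
    except TypeError as err:
        print(err)                    # unhashable type: 'list'
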
@@ -1209,7 +1210,7 @@ class TritonTemplateKernel(TritonKernel):
            dense_indexing=False,
            # We pass template_out as the shape to broadcast the indexing to as
            # the mask might be broadcast to the output shape
-           copy_shape=self.template_out,
+           copy_shape=self.template_out_shape,
            override_mask=self.template_mask,
            block_ptr=block_ptr,
            tma_compatibility_checker=tma_compatibility_checker,