Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
[Triton] [Inductor] Enable Epilogue Subtiling in the blackwell ws template (#163145)
Summary: Enables support for epilogue subtiling in the Blackwell ws template. This requires the ability to call `store_output` twice in the same kernel and to reuse the same tensor descriptor across allocations.

Test Plan: Tested with test_max_autotune.py on a Blackwell server.

Differential Revision: D82610077

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163145
Approved by: https://github.com/eellison
Committed by: PyTorch MergeBot
Parent: 124dd364e9
Commit: 0390798dad
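To make the change concrete before the diff, here is a minimal, self-contained sketch of the two mechanisms the patch adds: a counter that gives every `store_output` call its own subgraph name, and a prologue cache that emits each tensor-descriptor block-size constant only once across stores. The class and method names in the sketch (`StoreOutputSketch`, `subgraph_name`, `define_block_constexpr`, `prologue_lines`) are hypothetical; only `store_output_ctr`, `prologue_cache`, and the assertion message mirror the diff below.

# Minimal sketch (hypothetical names, not Inductor's actual classes) of the two
# patterns the patch relies on: per-call subgraph names so store_output can be
# invoked more than once, and a prologue cache so a tl.constexpr block-size
# definition is written only once even when several stores share a descriptor.
import itertools


class StoreOutputSketch:
    def __init__(self) -> None:
        # Counter that makes each <STORE_OUTPUT_i> body name unique
        # (the "_i" naming scheme here is an assumption).
        self.store_output_ctr = itertools.count()
        # Maps a constexpr name (e.g. "XBLOCK") to the value already emitted
        # into the kernel prologue.
        self.prologue_cache: dict[str, str] = {}
        self.prologue_lines: list[str] = []

    def subgraph_name(self) -> str:
        return f"<STORE_OUTPUT_{next(self.store_output_ctr)}>"

    def define_block_constexpr(self, block_name: str, block_size: str) -> None:
        if block_name in self.prologue_cache:
            # A later store must agree with what the first store emitted.
            assert self.prologue_cache[block_name] == block_size, (
                f"Constant {block_name} must be used for all stores"
            )
        else:
            self.prologue_cache[block_name] = block_size
            self.prologue_lines.append(f"{block_name}: tl.constexpr = {block_size}")


k = StoreOutputSketch()
print(k.subgraph_name())  # <STORE_OUTPUT_0>
print(k.subgraph_name())  # <STORE_OUTPUT_1>
k.define_block_constexpr("XBLOCK", "128")
k.define_block_constexpr("XBLOCK", "128")  # deduplicated: no second prologue line
print(k.prologue_lines)   # ['XBLOCK: tl.constexpr = 128']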
@@ -107,6 +107,8 @@ if TYPE_CHECKING:
     from torch._inductor.codegen.simd import IterationRangesRoot
 
+    from .codegen.common import CSE
+
 
 class KernelNamespace:
     pass
@@ -261,13 +263,14 @@ class SubgraphInfo:
     loads: IndentedBuffer = dataclasses.field(default_factory=IndentedBuffer)
     stores: IndentedBuffer = dataclasses.field(default_factory=IndentedBuffer)
     ops_handler: Optional[V.WrapperHandler] = None  # type: ignore[name-defined]
+    cse: Optional["CSE[Any]"] = None
 
     # only copied over if not None
     range_trees: Optional[list["IterationRangesRoot"]] = None
     numels: Optional[dict[str, sympy.Expr]] = None
 
     def __post_init__(self):
-        self.only_copy_if_non_none_fields = ("range_trees", "numels")
+        self.only_copy_if_non_none_fields = ("range_trees", "numels", "cse")
 
     def to_dict(self):
         return {
@@ -557,12 +560,10 @@ class TritonTemplateKernel(TritonKernel):
             setattr(self, key, value)
 
     @contextlib.contextmanager
-    def create_subgraph_body(self, body_name: str):
+    def create_subgraph_body(self, body_name: str, clear_cse: bool = False):
         assert body_name not in self.subgraph_bodies
         self.subgraph_bodies[body_name] = SubgraphInfo(
-            IndentedBuffer(),
-            None,
-            None,
+            IndentedBuffer(), None, None, cse=self.cse.clone() if clear_cse else None
         )
         with self.set_subgraph_body(body_name):
             yield
@@ -1071,7 +1072,13 @@ class TritonTemplateKernel(TritonKernel):
             # XBLOCK/YBLOCK and xoffset/yoffset. We append XBLOCK/YBLOCK
             # to the top of the kernel so we can safely extract the tensor
             # descriptor construction to the top of the kernel.
-            self.defines += f"{block_name}: tl.constexpr = {block_size}\n"
+            if block_name in self.prologue_cache:
+                assert self.prologue_cache[block_name] == block_size, (
+                    f"Constant {block_name} must be used for all stores"
+                )
+            else:
+                self.prologue_cache[block_name] = block_size
+                self.prologue.writeline(f"{block_name}: tl.constexpr = {block_size}")
         else:
             block_name = block_size
         line0 = f"{offset_name} = {texpr(tma_index)}"
@@ -1124,7 +1131,10 @@ class TritonTemplateKernel(TritonKernel):
             block_indexing (bool): Are the input indices presented as offsets for creating the block (e.g.
                 inputs to TMA) or are they tensors that should be passed in directly.
         """
-        with self.create_subgraph_body("<STORE_OUTPUT>"):
+        subgraph_name = self._get_store_output_subgraph_name(
+            next(self.store_output_ctr)
+        )
+        with self.create_subgraph_body(subgraph_name, clear_cse=True):
             assert isinstance(indices, (list, tuple))
             assert isinstance(val, str)
             assert isinstance(mask, (str, type(None)))
@@ -1300,13 +1310,14 @@ class TritonTemplateKernel(TritonKernel):
             self.codegen_body()
 
             def hook():
-                # more stuff might have been added since the codegen_body above
-                self.codegen_body()
-                self.cse.invalidate(OrderedSet())
+                with self.set_subgraph_body(subgraph_name):
+                    # more stuff might have been added since the codegen_body above
+                    self.codegen_body()
+                    self.cse.invalidate(OrderedSet())
 
-                return textwrap.indent(self.body.getvalue(), " " * indent_width).strip()
+                    return textwrap.indent(self.body.getvalue(), " " * indent_width).strip()
 
-            return self._register_hook("<STORE_OUTPUT>", hook)
+            return self._register_hook(subgraph_name, hook)
 
     def _register_hook(
         self,
@@ -1812,8 +1823,7 @@ class TritonTemplate(KernelTemplate):
 
         try:
             template = kernel.render(self.template, kwargs, caching_enabled)
-            with kernel.set_subgraph_body("<STORE_OUTPUT>"):
-                code = template.finalize_all()
+            code = template.finalize_all()
         except ZeroDivisionError:
             # TODO(nmacchioni): fix sympy division by zero
             return None
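As a rough illustration of why the explicit `set_subgraph_body("<STORE_OUTPUT>")` wrapper around `finalize_all()` can be dropped in the last hunk: each registered hook now re-enters its own subgraph body before rendering, so finalization behaves the same whether the template stored its output once or several times. The sketch below uses hypothetical classes (`HookRegistrySketch`) and a deliberately simplified notion of a "body"; it is not Inductor's API.

# Rough sketch (hypothetical classes, not Inductor's API): every store_output
# call gets its own body and a hook that re-enters that body, so the caller no
# longer needs to wrap finalization in a specific subgraph context.
import contextlib
import itertools
from typing import Callable


class HookRegistrySketch:
    def __init__(self) -> None:
        self.bodies: dict[str, list[str]] = {}
        self.hooks: dict[str, Callable[[], str]] = {}
        self.store_output_ctr = itertools.count()

    @contextlib.contextmanager
    def set_subgraph_body(self, name: str):
        # In this sketch, "switching bodies" just ensures the body exists.
        self.bodies.setdefault(name, [])
        yield

    def store_output(self, line: str) -> str:
        name = f"<STORE_OUTPUT_{next(self.store_output_ctr)}>"
        with self.set_subgraph_body(name):
            self.bodies[name].append(line)

        def hook() -> str:
            # The hook switches to its own body, mirroring the new
            # `with self.set_subgraph_body(subgraph_name):` in the diff.
            with self.set_subgraph_body(name):
                return "\n".join(self.bodies[name])

        self.hooks[name] = hook
        return name  # placeholder the rendered template later resolves

    def finalize_all(self) -> dict[str, str]:
        # No per-caller subgraph wrapper needed: each hook is self-contained.
        return {name: hook() for name, hook in self.hooks.items()}


r = HookRegistrySketch()
r.store_output("tl.store(out_ptr0, acc0)")
r.store_output("tl.store(out_ptr1, acc1)")  # a second store is now allowed
print(r.finalize_all())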