mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-22 14:15:01 +08:00
As part of #125683, this PR adds epilogue fusion support for bf16/fp16 gemms. The key changes are as follows: 1. bf16 linear with epilogue fusion of some ops was originally supported via ATen oneDNN linear pointwise ops. In order to match the ATen op semantics, in-template epilogue support is added to the cpp gemm template so that we have: "gemm + in-template epilogues -> template buffer". If the template is chosen for codegen, the in-template epilogues will be concatenated with the out-of-template epilogues that are appended during scheduling. 2. Support bf16/fp16 legalization for `codegen_loop_bodies`, which is used to generate the epilogue loops. 3. We used to leverage the in-place buffer mechanism to handle the in-place buffers in the epilogue codegen, in particular for reusing the output buffers of the GEMM, the template and the epilogues. This is not correct, since the output buffer is an "output", not an "in-place", buffer of the template kernel itself. Now, we use a dedicated "aliases" dict to manage such buffer reuses, and the intermediate aliasing buffers are removed after codegen. 4. Add a `localize_buffer` method to `LocalBufferScope` to allow the replacement of a global buffer with a local one in the given inductor IR nodes. This helps the fused loops work on smaller-sized local buffers for better data locality. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126545 Approved by: https://github.com/jansel
120 lines
3.7 KiB
Python
120 lines
3.7 KiB
Python
# mypy: allow-untyped-defs
|
|
import functools
|
|
import itertools
|
|
import logging
|
|
|
|
import sys
|
|
from typing import Callable, List, Optional
|
|
from unittest.mock import patch
|
|
|
|
import sympy
|
|
|
|
from .. import codecache, config, ir
|
|
from ..autotune_process import CppBenchmarkRequest, TensorMeta
|
|
from ..utils import IndentedBuffer, Placeholder, unique
|
|
from ..virtualized import V
|
|
from .common import KernelTemplate
|
|
from .cpp_template_kernel import CppTemplateCaller, CppTemplateKernel
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
|
|
|
|
|
class CppTemplate(KernelTemplate):
    """Common scaffolding for C++ kernel templates.

    Subclasses provide the template body by overriding :meth:`render`.
    This class renders a template through ``CppTemplateKernel``, packages
    the result as a ``CppTemplateCaller`` (:meth:`generate`), and supplies
    the shared C++ header preamble (:meth:`header`).
    """

    # Class-level counter: each generated caller draws the next integer for
    # the numeric suffix of its hash name.
    index_counter = itertools.count()

    def __init__(
        self,
        name: str,
        input_nodes,
        layout: ir.Layout,
        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
    ):
        """Record the template inputs and create the output buffer.

        Args:
            name: Template name, forwarded to ``KernelTemplate.__init__``.
            input_nodes: IR nodes whose names form the leading call
                arguments checked in :meth:`generate`.
            layout: Layout of the ``"buf_out"`` output buffer.
            epilogue_creator: Optional callable taking a buffer and
                returning a pointwise node; stored unchanged.
        """
        super().__init__(name)
        self.layout = layout
        self.input_nodes = input_nodes
        self.epilogue_creator = epilogue_creator
        self.output_node: ir.Buffer = ir.Buffer("buf_out", layout)

    def generate(self, **kwargs):
        """Render the template once and wrap it in a ``CppTemplateCaller``.

        Args:
            **kwargs: Forwarded to ``CppTemplateKernel.render``, both for
                the render done here and for the partial built by the
                ``make_kernel_render`` callback handed to the caller.

        Returns:
            A ``CppTemplateCaller`` built from the hash name, the inputs,
            the output layout, the render factory, the benchmark request
            and this template.

        Raises:
            AssertionError: If the kernel's leading call arguments are not
                the unique input names followed by the output name.
        """
        kernel_name = f"cpp_{self.name}"

        # Building the patch objects up front has no effect until they are
        # entered; the kernel itself is still constructed inside the ``with``
        # so that it is created while both patches are active.
        dtype_patch = patch.object(
            V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
        )
        indexing_patch = patch.object(ir.FlexibleLayout, "allow_indexing", True)
        with dtype_patch, indexing_patch, CppTemplateKernel(
            kernel_name=kernel_name,
        ) as kernel:
            code = kernel.render(self, **kwargs)
            _, call_args, _, _ = kernel.args.python_argdefs()
            log.debug("Generated Code:\n%s", code)
            cpp_argdefs = kernel.args.cpp_argdefs()
            python_argdefs = kernel.args.python_argdefs()
            log.debug(
                "Args: cpp_argdefs: %s, python_argdefs: %s",
                cpp_argdefs,
                python_argdefs,
            )

        # The call arguments must start with every distinct input name,
        # followed by the output name; whatever comes after is treated as
        # extra (size) arguments.
        input_names = unique(node.get_name() for node in self.input_nodes)
        expected_args = [*input_names, self.output_node.get_name()]
        num_expected = len(expected_args)
        assert list(call_args)[:num_expected] == expected_args, (
            call_args,
            expected_args,
        )
        trailing_args = call_args[num_expected:]
        extra_args = V.graph.sizevars.size_hints(map(sympy.expand, trailing_args))

        kernel_hash_name = f"cpp_{self.name}_{next(self.index_counter)}"

        # Benchmark request for the C++ source rendered above.
        input_meta = TensorMeta.from_irnodes(self.input_nodes)
        output_meta = TensorMeta.from_irnodes(self.output_node)
        bmreq = CppBenchmarkRequest(
            kernel_name=kernel_name,
            input_tensor_meta=input_meta,
            output_tensor_meta=output_meta,
            extra_args=extra_args,
            source_code=code,
        )

        def make_kernel_render(
            template_node: ir.CppTemplateBuffer,
            epilogue_nodes: Optional[List[ir.IRNode]] = None,
        ):
            # A fresh kernel with a placeholder name, plus a render callable
            # bound to this template, the given nodes and the original kwargs.
            placeholder_kernel = CppTemplateKernel(
                kernel_name=str(Placeholder.KERNEL_NAME),
            )
            bound_render = functools.partial(
                placeholder_kernel.render,
                self,
                template_buffer_node=template_node,
                epilogue_nodes=epilogue_nodes,
                **kwargs,
            )
            return placeholder_kernel, bound_render

        return CppTemplateCaller(
            kernel_hash_name,
            self.name,
            self.input_nodes,
            self.output_node.get_layout(),
            make_kernel_render,
            bmreq,
            self,
        )

    def header(self) -> IndentedBuffer:
        """Build the preamble emitted ahead of the kernel source.

        Returns:
            An ``IndentedBuffer`` holding the ``codecache.cpp_prefix()`` line
            and the ``c10/util/Unroll.h`` include, plus the
            ``ATen/record_function.h`` include when
            ``config.cpp.enable_kernel_profile`` is set and the platform is
            Linux.
        """
        preamble = IndentedBuffer()
        preamble.writeline(codecache.cpp_prefix())
        preamble.splice(
            """
                #include "c10/util/Unroll.h"
            """
        )
        if config.cpp.enable_kernel_profile and sys.platform == "linux":
            preamble.writelines(["#include <ATen/record_function.h>"])
        return preamble

    def render(self, **kwargs) -> str:
        """Return the rendered template text.

        Subclasses must override this; the base implementation always
        raises ``NotImplementedError``.
        """
        raise NotImplementedError
|