This PR adds the Cpp template infrastructure and the initial FP32 GEMM template. See RFC https://github.com/pytorch/pytorch/issues/125683 for more background info.

1. Cpp template infrastructure

   Similar template abstractions as the CUTLASS template, i.e., `CppTemplate`, `CppTemplateKernel`, and `CppTemplateBuffer`, plus the `CppMicroGemm` micro-kernel abstraction that can be used by Cpp GEMM templates.

2. Initial FP32 GEMM template

   This involves a GEMM template implementation, `CppPackedGemmTemplate`, that supports GEMM with a constant weight (`B`). It requires `N` to be a multiple of the register blocking size, while the `M` (batch dim) of `A` can be static or dynamic. The `B` matrix is prepacked. This is a typical setting for inference workloads.

   The template handles the thread decomposition (via `thread_blocking`) and cache blocking (via `cache_blocking`). It then invokes `CppMicroGemm`, which handles register blocking, instruction selection, and other CPU-architecture-specific optimizations (see the loop-structure sketch after the tables below). A `CppMicroGemmFP32Vec` micro-kernel implementation is provided for FP32 matmuls, implemented with the ATen vec abstraction.

3. Correctness and performance

   The changes have been validated with FP32 inference on the three benchmark suites (torchbench, huggingface and timm_models) with both static and dynamic shapes. Since this is an initial implementation, we are still working on further performance improvements in follow-up PRs, including optimizations in the kernels as well as fusions. Perf gains over the ATen kernels (which are implemented with MKL) are observed only for a selective number of models. The gains are more pronounced with dynamic shapes, since MKL only supports packed GEMM for static shapes. Details below.

Static shapes

| Benchmark | torchbench | huggingface | timm_models |
| --- | --- | --- | --- |
| Multi-threaded (baseline) | 1.47x | 1.36x | 1.91x |
| Multi-threaded (max-autotune) | 1.47x | 1.36x | 1.92x |
| Single-threaded (baseline) | 1.56x | 1.19x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.52x |

Key models being sped up: drq: 1.14x, soft_act: 1.12x, cait_m36_384: 1.18x.

Dynamic shapes

| Benchmark | torchbench | huggingface | timm_models |
| --- | --- | --- | --- |
| Multi-threaded (baseline) | 1.43x | 1.28x | 1.85x |
| Multi-threaded (max-autotune) | 1.47x | 1.28x | 1.85x |
| Single-threaded (baseline) | 1.55x | 1.20x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.53x |

Key models being sped up: BERT_pytorch: 1.22x, pyhpc_turbulent: 1.13x, soft_actor_critic: 1.77x, BlenderbotForCausalLM: 1.09x, cait_m36_384: 1.17x.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124021
Approved by: https://github.com/jansel
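To make the blocking hierarchy concrete, here is a minimal Python sketch of the loop structure only: cache blocking tiles `M`/`N`/`K` so each block's working set stays cache-resident, and a register-blocked micro-kernel (playing the role that `CppMicroGemmFP32Vec` plays in the generated code) computes small `Mr x Nr` tiles. All names and block sizes here (`blocked_gemm`, `micro_gemm`, `Mc`, `Nc`, `Kc`, `Mr`, `Nr`) are illustrative assumptions, not the template's actual code: the real template emits C++, prepacks `B`, additionally splits the blocks across threads via `thread_blocking`, and vectorizes the micro-kernel with ATen vec.

```python
# Illustrative sketch only: mirrors the cache-/register-blocking loop
# structure described above, not the actual generated template code.
# Usage: A is M x K, B is K x N, C is M x N and zero-initialized
# (plain lists of lists).

def micro_gemm(A, B, C, m0, m1, n0, n1, k0, k1):
    # Register-blocked micro-kernel: accumulates one small tile of C.
    # The generated C++ (cf. CppMicroGemmFP32Vec) keeps this tile in
    # vector registers and vectorizes over j with the ATen vec abstraction.
    for i in range(m0, m1):
        for j in range(n0, n1):
            acc = C[i][j]
            for k in range(k0, k1):
                acc += A[i][k] * B[k][j]
            C[i][j] = acc


def blocked_gemm(A, B, C, M, N, K, Mc=64, Nc=64, Kc=64, Mr=4, Nr=16):
    # Cache blocking (cf. cache_blocking): tile M/N/K so the Mc x Kc slice
    # of A and the Kc x Nc slice of B stay cache-resident across the inner
    # loops; each (i, j, k) triple is visited exactly once.
    for mc in range(0, M, Mc):
        for nc in range(0, N, Nc):
            for kc in range(0, K, Kc):
                # Register blocking: hand Mr x Nr tiles to the micro-kernel.
                for mr in range(mc, min(mc + Mc, M), Mr):
                    for nr in range(nc, min(nc + Nc, N), Nr):
                        micro_gemm(
                            A, B, C,
                            mr, min(mr + Mr, M),
                            nr, min(nr + Nr, N),
                            kc, min(kc + Kc, K),
                        )
```

In the generated C++ the micro-kernel tile stays in registers and the k-loop is unrolled; the Python version only mirrors the iteration order and where each level of blocking applies.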
117 lines
3.5 KiB
Python
import functools
import itertools
import logging
import sys
from typing import List, Optional
from unittest.mock import patch

import sympy

from .. import codecache, config, ir
from ..autotune_process import CppBenchmarkRequest, TensorMeta
from ..utils import IndentedBuffer, Placeholder, unique
from ..virtualized import V
from .common import KernelTemplate
from .cpp_template_kernel import CppTemplateCaller, CppTemplateKernel

log = logging.getLogger(__name__)


class CppTemplate(KernelTemplate):
    # Monotonic counter used to give each generated template kernel a
    # unique hash name.
    index_counter = itertools.count()

    def __init__(
        self,
        name: str,
        input_nodes,
        layout: ir.Layout,
    ):
        super().__init__(name)
        self.input_nodes = input_nodes
        self.output_node: ir.Buffer = ir.Buffer("buf_out", layout)
        self.layout = layout

    def generate(self, **kwargs):
        kernel_name = f"cpp_{self.name}"
        with patch.object(
            V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
        ), CppTemplateKernel(
            kernel_name=kernel_name,
        ) as kernel:
            code = self.render(kernel=kernel, **kwargs)
            _, call_args, _ = kernel.args.python_argdefs()
            log.debug("Generated Code:\n%s", code)
            log.debug(
                "Args: cpp_argdefs: %s, python_argdefs: %s",
                kernel.args.cpp_argdefs(),
                kernel.args.python_argdefs(),
            )

        # The rendered kernel is expected to take the (deduplicated) input
        # buffers first, then the output buffer; any remaining call args are
        # size variables resolved to concrete hints below.
        expected_args = list(
            unique(input_node.get_name() for input_node in self.input_nodes)
        )
        expected_args.extend([self.output_node.get_name()])
        assert list(call_args)[: len(expected_args)] == expected_args, (
            call_args,
            expected_args,
        )
        extra_args = V.graph.sizevars.size_hints(
            map(sympy.expand, call_args[len(expected_args) :])
        )

        kernel_hash_name = f"cpp_{self.name}_{next(self.index_counter)}"

        # Create the BenchmarkRequest for CPP
        bmreq = CppBenchmarkRequest(
            kernel_name=kernel_name,
            input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
            output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
            extra_args=extra_args,
            source_code=code,
        )

        def make_kernel_render(
            template_node: ir.CppTemplateBuffer,
            epilogue_nodes: Optional[List[ir.IRNode]] = None,
        ):
            kernel = CppTemplateKernel(
                kernel_name=str(Placeholder.KERNEL_NAME),
            )
            render = functools.partial(
                self.render,
                kernel=kernel,
                template_buffer_node=template_node,
                epilogue_nodes=epilogue_nodes,
                **kwargs,
            )
            return kernel, render

        return CppTemplateCaller(
            kernel_hash_name,
            self.name,
            self.input_nodes,
            self.output_node.get_layout(),
            make_kernel_render,
            bmreq,
            self,
        )

    def header(self) -> IndentedBuffer:
        res = IndentedBuffer()
        res.writeline(codecache.cpp_prefix())
        res.splice(
            """
            #include "c10/util/Unroll.h"
            """
        )
        enable_kernel_profile = (
            config.cpp.enable_kernel_profile and sys.platform == "linux"
        )
        if enable_kernel_profile:
            res.writelines(["#include <ATen/record_function.h>"])
        return res

    def render(self, **kwargs) -> str:
        raise NotImplementedError