mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
This reverts commit 7b6d036c05bd782f5e59bdb353f9e47865e9db50. Reverted https://github.com/pytorch/pytorch/pull/126545 on behalf of https://github.com/DanilBaibak due to Broken trunk ([comment](https://github.com/pytorch/pytorch/pull/124021#issuecomment-2126568331))
117 lines
3.5 KiB
Python
117 lines
3.5 KiB
Python
import functools
|
|
import itertools
|
|
import logging
|
|
|
|
import sys
|
|
from typing import List, Optional
|
|
from unittest.mock import patch
|
|
|
|
import sympy
|
|
|
|
from .. import codecache, config, ir
|
|
from ..autotune_process import CppBenchmarkRequest, TensorMeta
|
|
from ..utils import IndentedBuffer, Placeholder, unique
|
|
from ..virtualized import V
|
|
from .common import KernelTemplate
|
|
from .cpp_template_kernel import CppTemplateCaller, CppTemplateKernel
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)
|
|
|
|
|
|
class CppTemplate(KernelTemplate):
    """Base class for C++ kernel templates.

    Subclasses provide the kernel source by overriding :meth:`render`.
    :meth:`generate` performs one rendering and packages the result as a
    ``CppTemplateCaller`` together with a ``CppBenchmarkRequest`` built from
    the rendered code.
    """

    # Class-level counter; generate() pulls the next value from it to suffix
    # the kernel hash name.
    index_counter = itertools.count()

    def __init__(
        self,
        name: str,
        input_nodes,
        layout: ir.Layout,
    ):
        """Store the template's inputs and create its output buffer node.

        Args:
            name: Template name; forwarded to ``KernelTemplate.__init__`` and
                later read back as ``self.name`` when naming kernels.
            input_nodes: Input nodes of the template, stored as given.
            layout: Layout of the output; also used to construct the
                ``"buf_out"`` output buffer node.
        """
        super().__init__(name)
        self.input_nodes = input_nodes
        self.layout = layout
        self.output_node: ir.Buffer = ir.Buffer("buf_out", layout)

    def generate(self, **kwargs):
        """Render this template and wrap the result in a ``CppTemplateCaller``.

        Args:
            **kwargs: Forwarded unchanged to ``CppTemplateKernel.render``, both
                for the rendering done here and for the deferred rendering
                prepared by the ``make_kernel_render`` callback.

        Returns:
            A ``CppTemplateCaller`` constructed from the kernel hash name, the
            template name, the input nodes, the output node's layout, the
            ``make_kernel_render`` callback, the benchmark request and this
            template.

        Raises:
            AssertionError: If the leading call arguments reported by the
                kernel are not the unique input names followed by the output
                name.
        """
        kernel_name = f"cpp_{self.name}"

        # For the duration of the rendering, V.graph.get_dtype is replaced by
        # the callable returned from _fake_get_dtype(self.output_node).
        # NOTE(review): presumably this lets the "buf_out" placeholder's dtype
        # be resolved while rendering -- confirm against KernelTemplate.
        with patch.object(
            V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
        ):
            with CppTemplateKernel(kernel_name=kernel_name) as kernel:
                code = kernel.render(self, **kwargs)
                _unused_head, call_args, _unused_tail = kernel.args.python_argdefs()
                log.debug("Generated Code:\n%s", code)
                log.debug(
                    "Args: cpp_argdefs: %s, python_argdefs: %s",
                    kernel.args.cpp_argdefs(),
                    kernel.args.python_argdefs(),
                )

        # The call arguments must start with the de-duplicated input names
        # followed by the output name; whatever follows is handled below.
        expected_args = [
            *unique(node.get_name() for node in self.input_nodes),
            self.output_node.get_name(),
        ]
        num_expected = len(expected_args)
        assert list(call_args)[:num_expected] == expected_args, (
            call_args,
            expected_args,
        )

        # Remaining call arguments are expanded with sympy and turned into
        # size hints.
        trailing_args = call_args[num_expected:]
        extra_args = V.graph.sizevars.size_hints(map(sympy.expand, trailing_args))

        kernel_hash_name = f"cpp_{self.name}_{next(self.index_counter)}"

        # Create the BenchmarkRequest for CPP
        input_meta = TensorMeta.from_irnodes(self.input_nodes)
        output_meta = TensorMeta.from_irnodes(self.output_node)
        bmreq = CppBenchmarkRequest(
            kernel_name=kernel_name,
            input_tensor_meta=input_meta,
            output_tensor_meta=output_meta,
            extra_args=extra_args,
            source_code=code,
        )

        def make_kernel_render(
            template_node: ir.CppTemplateBuffer,
            epilogue_nodes: Optional[List[ir.IRNode]] = None,
        ):
            """Return a fresh kernel and a zero-argument-ready render partial.

            The kernel is named with ``Placeholder.KERNEL_NAME``; the partial
            binds this template, the given template buffer node, the epilogue
            nodes and the ``kwargs`` captured from ``generate``.
            """
            render_kernel = CppTemplateKernel(
                kernel_name=str(Placeholder.KERNEL_NAME),
            )
            return render_kernel, functools.partial(
                render_kernel.render,
                self,
                template_buffer_node=template_node,
                epilogue_nodes=epilogue_nodes,
                **kwargs,
            )

        return CppTemplateCaller(
            kernel_hash_name,
            self.name,
            self.input_nodes,
            self.output_node.get_layout(),
            make_kernel_render,
            bmreq,
            self,
        )

    def header(self) -> IndentedBuffer:
        """Build the preamble emitted ahead of the template's C++ code.

        Returns:
            An ``IndentedBuffer`` holding ``codecache.cpp_prefix()``, the
            ``c10/util/Unroll.h`` include and, when
            ``config.cpp.enable_kernel_profile`` is set and the platform is
            Linux, the ``ATen/record_function.h`` include.
        """
        buf = IndentedBuffer()
        buf.writeline(codecache.cpp_prefix())
        buf.splice(
            """
                #include "c10/util/Unroll.h"
            """
        )
        # The profiling include is only added on Linux.
        if config.cpp.enable_kernel_profile and sys.platform == "linux":
            buf.writelines(["#include <ATen/record_function.h>"])
        return buf

    def render(self, **kwargs) -> str:
        """Produce the kernel source; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
|