[AOTI] Refine the C shim autogen mechanism (#125589)

Summary: Based on the discussions in https://github.com/pytorch/pytorch/pull/120513. Instead of auto-generating C shim fallback functions for thousands of ops, we maintain a list of fallback ops based on torch/_inductor/lowering.py and only generate C shim functions for those ops. At torchgen time, we regenerate the C shim files and compare the header contents against the existing C shim headers. If there is any change, compilation fails with a prompt on how to proceed. This keeps the ABI-compatible C shim layer small enough to maintain in the long run.
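A minimal sketch of the header-consistency check described above, assuming hypothetical names (check_c_shim_header, the file-handling details, and the error wording are illustrative, not the actual torchgen implementation):

import os

def check_c_shim_header(existing_header_path: str, generated_header_text: str) -> None:
    # Illustrative only: compare freshly generated header text against the
    # checked-in copy and fail with instructions if they differ.
    if not os.path.exists(existing_header_path):
        return
    with open(existing_header_path) as f:
        existing_header_text = f.read()
    if existing_header_text != generated_header_text:
        raise RuntimeError(
            f"Generated content differs from {existing_header_path}. "
            "If the change to the fallback op list is intentional, regenerate "
            "and check in the C shim headers; otherwise revert the change."
        )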

Differential Revision: [D57004046](https://our.internmc.facebook.com/intern/diff/D57004046)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125589
Approved by: https://github.com/frank-wei, https://github.com/chenyang78, https://github.com/albanD, https://github.com/ezyang
Bin Bao
2024-05-08 13:38:28 -07:00
committed by PyTorch MergeBot
parent 0bde9c08ef
commit ed48ea9997
10 changed files with 501 additions and 60 deletions


@@ -15,18 +15,12 @@ from torchgen.model import (
FunctionSchema,
ListType,
NativeFunction,
NativeFunctionsGroup,
OptionalType,
Type,
)
from torchgen.utils import mapMaybe
def returns_are_all_tensor(schema: FunctionSchema) -> bool:
return len(schema.returns) != 0 and all(
ret.type.is_tensor_like() for ret in schema.returns
)
base_type_to_c_type = {
BaseTy.Tensor: "AtenTensorHandle",
BaseTy.bool: "int32_t", # Use int to pass bool
@@ -301,52 +295,63 @@ def gen_static_dispatch_backend_call(
f: NativeFunction,
backend_index: BackendIndex,
) -> str:
assert backend_index.has_kernel(f)
sig = DispatcherSignature.from_schema(f.func)
cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"
def get_backend_index_for_aoti(
f: NativeFunction,
func: NativeFunction,
func_group: Optional[NativeFunctionsGroup],
dispatch_key: DispatchKey,
backend_indices: Dict[DispatchKey, BackendIndex],
) -> Optional[BackendIndex]:
if "pointwise" in f.tags:
# TODO: No need to generate C shim for Inductor lowered ops.
# Only skip pointwise kernels for now, and we can add more tags later.
return None
backend_index = None
if backend_indices[dispatch_key].has_kernel(f):
if backend_indices[dispatch_key].has_kernel(func) or (
func.structured_delegate is not None
and func_group is not None
and backend_indices[dispatch_key].has_kernel(func_group)
):
backend_index = backend_indices[dispatch_key]
elif backend_indices[DispatchKey.CompositeExplicitAutograd].has_kernel(f):
elif backend_indices[DispatchKey.CompositeExplicitAutograd].has_kernel(func):
# We need to create C shim wrappers for CompositeExplicitAutograd kernels
backend_index = backend_indices[DispatchKey.CompositeExplicitAutograd]
elif backend_indices[DispatchKey.CompositeExplicitAutogradNonFunctional].has_kernel(
f
func
):
# We need to create C shim wrappers for CompositeExplicitAutogradNonFunctional kernels
backend_index = backend_indices[
DispatchKey.CompositeExplicitAutogradNonFunctional
]
return backend_index
def get_fallback_op_name(func: NativeFunction) -> str:
return (
f"{func.namespace}.{func.func.name.name}.{func.func.name.overload_name}"
if func.func.name.overload_name
else f"{func.namespace}.{func.func.name.name}"
)
def gen_c_shim(
f: NativeFunction,
func: NativeFunction,
func_group: Optional[NativeFunctionsGroup],
dispatch_key: DispatchKey,
backend_indices: Dict[DispatchKey, BackendIndex],
header: bool,
) -> Optional[str]:
backend_index = get_backend_index_for_aoti(f, dispatch_key, backend_indices)
backend_index = get_backend_index_for_aoti(
func, func_group, dispatch_key, backend_indices
)
if backend_index is None:
return None
schema = f.func
schema = func.func
device = dispatch_key.lower()
backend_call = gen_static_dispatch_backend_call(
f,
func,
backend_index,
)
@@ -366,34 +371,56 @@ def gen_c_shim(
@dataclass(frozen=True)
class ShimGenerator:
func_group_mapping: Dict[str, Optional[NativeFunctionsGroup]]
dispatch_key: DispatchKey
backend_indices: Dict[DispatchKey, BackendIndex]
header: bool # True to generate .h and False to generate .cpp
@method_with_native_function
def __call__(self, f: NativeFunction) -> Optional[str]:
result = gen_c_shim(f, self.dispatch_key, self.backend_indices, self.header)
def __call__(
self,
func: NativeFunction,
) -> Optional[str]:
result = gen_c_shim(
func,
self.func_group_mapping.get(get_fallback_op_name(func), None),
self.dispatch_key,
self.backend_indices,
self.header,
)
return result
def gen_aoti_c_shim(
native_functions: Sequence[NativeFunction],
native_functions: Sequence[Tuple[NativeFunction, Optional[NativeFunctionsGroup]]],
dispatch_key: DispatchKey,
backend_indices: Dict[DispatchKey, BackendIndex],
header: bool,
includes: str = "",
) -> str:
func_group_mapping = {
get_fallback_op_name(func): func_group for func, func_group in native_functions
}
body = "\n".join(
list(
mapMaybe(
ShimGenerator(dispatch_key, backend_indices, header),
native_functions,
ShimGenerator(
func_group_mapping, dispatch_key, backend_indices, header
),
[func for func, _ in native_functions],
)
)
)
device = dispatch_key.lower()
warning = (
"// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND."
)
if header:
return f"""
{warning}
#pragma once
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
@@ -407,13 +434,14 @@ extern "C" {{
#ifdef __cplusplus
}} // extern "C"
#endif
"""
else:
device = dispatch_key.lower()
return f"""
#include <torch/csrc/inductor/aoti_torch/utils.h>
{warning}
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{device}.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/{str(dispatch_key)}Functions.h>
@@ -425,6 +453,4 @@ extern "C" {{
using namespace torch::aot_inductor;
{body}
"""
{body}"""