Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[AOTI] Refine the C shim autogen mechanism (#125589)
Summary: Based on the discussions in https://github.com/pytorch/pytorch/pull/120513. Instead of auto-generating C shim fallback ops for thousands of ops, we maintain a list of fallback ops based on torch/_inductor/lowering.py and only generate C shim functions for those ops. At torchgen time, we regenerate the C shim files and compare the header file contents against the existing C shim headers. If there is any change, compilation fails with a prompt explaining how to proceed. This ensures the ABI-compatible C shim layer stays small enough to maintain in the long run.

Differential Revision: [D57004046](https://our.internmc.facebook.com/intern/diff/D57004046)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125589
Approved by: https://github.com/frank-wei, https://github.com/chenyang78, https://github.com/albanD, https://github.com/ezyang
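The header-comparison step described above can be illustrated with a short sketch. This is not the actual torchgen implementation; the function name and parameters below are illustrative assumptions. It shows the intended behavior: regenerate the shim header in memory, compare it against the checked-in copy, and fail loudly with instructions when they diverge.

    # Hypothetical sketch of the "regenerate and compare" check; not the real
    # torchgen code. `generated` is the freshly generated header text and
    # `checked_in` is the path of the committed C shim header.
    from pathlib import Path


    def check_shim_header_unchanged(generated: str, checked_in: Path) -> None:
        """Fail if the regenerated C shim header differs from the checked-in one."""
        if not checked_in.exists() or checked_in.read_text() != generated:
            raise RuntimeError(
                f"Generated C shim header differs from {checked_in}. The "
                "ABI-compatible shim layer is intentionally kept small and stable; "
                "if the change to the fallback op list is intended, update the "
                "checked-in header and review the ABI implications before committing."
            )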
Committed by: PyTorch MergeBot
Parent: 0bde9c08ef
Commit: ed48ea9997
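The core behavioral change in the diff below is how get_backend_index_for_aoti resolves which kernel a C shim should dispatch to. A simplified sketch of that lookup order, using plain strings and dicts in place of torchgen's NativeFunction/BackendIndex/DispatchKey types (these stand-ins are assumptions, not the real API):

    from typing import Dict, Optional, Set


    def resolve_backend(
        op: str,
        structured_group: Optional[str],
        dispatch_key: str,
        kernels: Dict[str, Set[str]],  # dispatch key -> ops/groups with a kernel
    ) -> Optional[str]:
        # 1. Prefer a kernel registered directly under the requested dispatch key,
        #    or one registered for the op's structured NativeFunctionsGroup (the
        #    real code also requires the op to have a structured_delegate).
        if op in kernels.get(dispatch_key, set()) or (
            structured_group is not None
            and structured_group in kernels.get(dispatch_key, set())
        ):
            return dispatch_key
        # 2. Otherwise fall back to CompositeExplicitAutograd, ...
        if op in kernels.get("CompositeExplicitAutograd", set()):
            return "CompositeExplicitAutograd"
        # 3. ... then CompositeExplicitAutogradNonFunctional.
        if op in kernels.get("CompositeExplicitAutogradNonFunctional", set()):
            return "CompositeExplicitAutogradNonFunctional"
        return None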
@@ -15,18 +15,12 @@ from torchgen.model import (
     FunctionSchema,
     ListType,
     NativeFunction,
     NativeFunctionsGroup,
     OptionalType,
     Type,
 )
 from torchgen.utils import mapMaybe
 
 
-def returns_are_all_tensor(schema: FunctionSchema) -> bool:
-    return len(schema.returns) != 0 and all(
-        ret.type.is_tensor_like() for ret in schema.returns
-    )
-
-
 base_type_to_c_type = {
     BaseTy.Tensor: "AtenTensorHandle",
     BaseTy.bool: "int32_t",  # Use int to pass bool
@@ -301,52 +295,63 @@ def gen_static_dispatch_backend_call(
     f: NativeFunction,
     backend_index: BackendIndex,
 ) -> str:
     assert backend_index.has_kernel(f)
     sig = DispatcherSignature.from_schema(f.func)
     cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
     return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"
 
 
 def get_backend_index_for_aoti(
-    f: NativeFunction,
+    func: NativeFunction,
+    func_group: Optional[NativeFunctionsGroup],
     dispatch_key: DispatchKey,
     backend_indices: Dict[DispatchKey, BackendIndex],
 ) -> Optional[BackendIndex]:
-    if "pointwise" in f.tags:
-        # TODO: No need to generate C shim for Inductor lowered ops.
-        # Only skip pointwise kernels for now, and we can add more tags later.
-        return None
-
     backend_index = None
-    if backend_indices[dispatch_key].has_kernel(f):
+    if backend_indices[dispatch_key].has_kernel(func) or (
+        func.structured_delegate is not None
+        and func_group is not None
+        and backend_indices[dispatch_key].has_kernel(func_group)
+    ):
         backend_index = backend_indices[dispatch_key]
-    elif backend_indices[DispatchKey.CompositeExplicitAutograd].has_kernel(f):
+    elif backend_indices[DispatchKey.CompositeExplicitAutograd].has_kernel(func):
         # We need to create C shim wrappers for CompositeExplicitAutograd kernels
         backend_index = backend_indices[DispatchKey.CompositeExplicitAutograd]
     elif backend_indices[DispatchKey.CompositeExplicitAutogradNonFunctional].has_kernel(
-        f
+        func
     ):
         # We need to create C shim wrappers for CompositeExplicitAutogradNonFunctional kernels
         backend_index = backend_indices[
             DispatchKey.CompositeExplicitAutogradNonFunctional
         ]
 
     return backend_index
 
 
+def get_fallback_op_name(func: NativeFunction) -> str:
+    return (
+        f"{func.namespace}.{func.func.name.name}.{func.func.name.overload_name}"
+        if func.func.name.overload_name
+        else f"{func.namespace}.{func.func.name.name}"
+    )
+
+
 def gen_c_shim(
-    f: NativeFunction,
+    func: NativeFunction,
+    func_group: Optional[NativeFunctionsGroup],
     dispatch_key: DispatchKey,
     backend_indices: Dict[DispatchKey, BackendIndex],
     header: bool,
 ) -> Optional[str]:
-    backend_index = get_backend_index_for_aoti(f, dispatch_key, backend_indices)
+    backend_index = get_backend_index_for_aoti(
+        func, func_group, dispatch_key, backend_indices
+    )
     if backend_index is None:
         return None
 
-    schema = f.func
+    schema = func.func
     device = dispatch_key.lower()
     backend_call = gen_static_dispatch_backend_call(
-        f,
+        func,
         backend_index,
     )
@@ -366,34 +371,56 @@ def gen_c_shim(
 
 @dataclass(frozen=True)
 class ShimGenerator:
+    func_group_mapping: Dict[str, Optional[NativeFunctionsGroup]]
     dispatch_key: DispatchKey
     backend_indices: Dict[DispatchKey, BackendIndex]
     header: bool  # True to generate .h and False to generate .cpp
 
     @method_with_native_function
-    def __call__(self, f: NativeFunction) -> Optional[str]:
-        result = gen_c_shim(f, self.dispatch_key, self.backend_indices, self.header)
+    def __call__(
+        self,
+        func: NativeFunction,
+    ) -> Optional[str]:
+        result = gen_c_shim(
+            func,
+            self.func_group_mapping.get(get_fallback_op_name(func), None),
+            self.dispatch_key,
+            self.backend_indices,
+            self.header,
+        )
         return result
 
 
 def gen_aoti_c_shim(
-    native_functions: Sequence[NativeFunction],
+    native_functions: Sequence[Tuple[NativeFunction, Optional[NativeFunctionsGroup]]],
     dispatch_key: DispatchKey,
     backend_indices: Dict[DispatchKey, BackendIndex],
     header: bool,
     includes: str = "",
 ) -> str:
+    func_group_mapping = {
+        get_fallback_op_name(func): func_group for func, func_group in native_functions
+    }
     body = "\n".join(
         list(
             mapMaybe(
-                ShimGenerator(dispatch_key, backend_indices, header),
-                native_functions,
+                ShimGenerator(
+                    func_group_mapping, dispatch_key, backend_indices, header
+                ),
+                [func for func, _ in native_functions],
             )
         )
     )
-    device = dispatch_key.lower()
 
     warning = (
         "// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND."
     )
 
     if header:
         return f"""
 {warning}
 
 #pragma once
 
 #include <torch/csrc/inductor/aoti_torch/c/shim.h>
@@ -407,13 +434,14 @@ extern "C" {{
 #ifdef __cplusplus
 }} // extern "C"
 #endif
 
 """
 
     else:
+        device = dispatch_key.lower()
         return f"""
-#include <torch/csrc/inductor/aoti_torch/utils.h>
 {warning}
 
 #include <torch/csrc/inductor/aoti_torch/generated/c_shim_{device}.h>
+#include <torch/csrc/inductor/aoti_torch/utils.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/{str(dispatch_key)}Functions.h>
@@ -425,6 +453,4 @@ extern "C" {{
 
 using namespace torch::aot_inductor;
 
-{body}
-
-"""
+{body}"""
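For reference, the naming scheme used by get_fallback_op_name in the diff above can be reproduced with a standalone sketch that takes (namespace, base name, overload name) strings instead of a real NativeFunction. This helper is hypothetical, not part of torchgen:

    def fallback_op_name(namespace: str, name: str, overload: str = "") -> str:
        # "namespace.name.overload" when an overload name exists, else "namespace.name"
        return f"{namespace}.{name}.{overload}" if overload else f"{namespace}.{name}"


    assert fallback_op_name("aten", "add", "Tensor") == "aten.add.Tensor"
    assert fallback_op_name("aten", "relu") == "aten.relu"

Strings of this form are the keys of func_group_mapping, which ShimGenerator.__call__ uses to look up the NativeFunctionsGroup for each fallback op.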