[AOTI] Use torchgen to generate C shim functions (#120513)

Summary: The current C shim layer manually implements a C interface for a handful of ops, which obviously won't scale if we want to extend it to cover all aten ops. This new torchgen script automatically generates C shim interfaces for the CPU and CUDA backends. The interface follows the same parameter-passing rules as the current C shim layer, such as:

* Use plain C data types to pass parameters
* Use AtenTensorHandle to pass at::Tensor
* Use a pointer type to pass an optional parameter
* Use pointer+length to pass a list
* Use device_type+device_index to pass a device
* When a parameter is a pointer to a pointer, e.g. AtenTensorHandle**, the script generates either a list of optional values or an optional list of values, depending on the op schema

https://gist.github.com/desertfire/83701532b126c6d34dae6ba68a1b074a shows an example of the generated torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.cpp file. The current version doesn't generate C shim wrappers for all aten ops and, on the other hand, probably generates more wrappers than needed, but it should serve as a good basis.
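
To make the conventions concrete, here is a sketch of what a generated wrapper looks like for a made-up op "foo" that takes a tensor, an optional Scalar, and an int[]. The op and its backend kernel are hypothetical (so this exact code won't compile), but the shape of the wrapper follows the template emitted by gen_aoti_c_shim.py:

// Illustrative sketch only: "foo" and at::cuda::foo are made up; real
// wrappers are produced by torchgen (see the gist above for actual output).
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>

using namespace torch::aot_inductor;

AOTITorchError aoti_torch_cuda_foo(
    AtenTensorHandle self,    // at::Tensor passed as an opaque handle
    double* alpha,            // optional Scalar passed as a nullable pointer
    const int64_t* sizes,     // int[] passed as a pointer ...
    int64_t sizes_len_,       // ... plus an explicit length
    AtenTensorHandle* ret0) { // result written out through a pointer argument
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto tmp_result = at::cuda::foo( // hypothetical statically dispatched kernel
        *tensor_handle_to_tensor_pointer(self),
        pointer_to_optional<c10::Scalar>(alpha),
        pointer_to_list<int64_t>(sizes, sizes_len_));
    *ret0 = new_tensor_handle(std::move(tmp_result));
  });
}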

This PR by itself doesn't change AOTI codegen and thus won't introduce any forward-compatibility (FC) breakage. The actual wrapper codegen changes will come in another PR, guarded by a version-control flag to avoid FC breakage.
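
For context on that follow-up codegen work, a caller (e.g., AOTI-generated wrapper C++) would invoke such a shim roughly as below. This is a hypothetical sketch reusing the made-up aoti_torch_cuda_foo from above; only the error-code and tensor-handle ownership conventions come from the existing C shim layer:

// Hypothetical caller-side sketch; the shim entry point is made up, but the
// error-checking and handle-ownership conventions are the existing ones.
#include <torch/csrc/inductor/aoti_runtime/utils.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>

void run_foo(AtenTensorHandle self) {
  double alpha = 2.0;        // optional Scalar: pass its address; nullptr means "not set"
  int64_t sizes[] = {2, 3};  // int[] passed as pointer + length
  AtenTensorHandle ret0 = nullptr;
  AOTI_TORCH_ERROR_CODE_CHECK(
      aoti_torch_cuda_foo(self, &alpha, sizes, /*sizes_len_=*/2, &ret0));
  // ret0 now owns a new tensor; release it through the shim when done.
  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ret0));
}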

Differential Revision: [D54258087](https://our.internmc.facebook.com/intern/diff/D54258087)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120513
Approved by: https://github.com/jansel
Author: Bin Bao
Date: 2024-03-04 08:31:50 -08:00
Committed by: PyTorch MergeBot
Parent: ffe45a8188
Commit: bd19d6d822
6 changed files with 611 additions and 8 deletions

.gitignore

@@ -86,6 +86,7 @@ torch/csrc/api/include/torch/version.h
torch/csrc/cudnn/cuDNN.cpp
torch/csrc/generated
torch/csrc/generic/TensorMethods.cpp
torch/csrc/inductor/aoti_torch/generated/*
torch/csrc/jit/generated/*
torch/csrc/jit/fuser/config.h
torch/csrc/nn/THCUNN.cpp

caffe2/CMakeLists.txt

@@ -368,6 +368,7 @@ if(NOT INTERN_DISABLE_AUTOGRAD AND NOT BUILD_LITE_INTERPRETER)
"${TORCH_SRC_DIR}/csrc/autograd/generated/TraceType_4.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/ADInplaceOrViewType_0.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/ADInplaceOrViewType_1.cpp"
"${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/generated/c_shim_cpu.cpp"
)
if(BUILD_LAZY_TS_BACKEND)
list(APPEND GENERATED_CXX_TORCH
@@ -422,12 +423,17 @@ set(GENERATED_TESTING_PYTHON
"${TORCH_SRC_DIR}/testing/_internal/generated/annotated_fn_args.py"
)
set(GENERATED_CXX_TORCH_CUDA
"${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/generated/c_shim_cuda.cpp"
)
set(TORCH_GENERATED_CODE
${GENERATED_CXX_TORCH}
${GENERATED_H_TORCH}
${GENERATED_CXX_PYTHON}
${GENERATED_H_PYTHON}
${GENERATED_TESTING_PYTHON}
${GENERATED_CXX_TORCH_CUDA}
)
set(GEN_PER_OPERATOR_FLAG)
@@ -970,6 +976,7 @@ endif()
# Compile exposed libraries.
if(USE_ROCM)
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
list(APPEND Caffe2_HIP_SRCS ${GENERATED_CXX_TORCH_CUDA})
hip_add_library(torch_hip ${Caffe2_HIP_SRCS})
if(USE_FLASH_ATTENTION)
target_link_libraries(torch_hip PRIVATE __caffe2_oort)
@@ -988,6 +995,7 @@ if(USE_ROCM)
endif()
elseif(USE_CUDA)
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
list(APPEND Caffe2_GPU_SRCS ${GENERATED_CXX_TORCH_CUDA})
if(CUDA_SEPARABLE_COMPILATION)
# Separate compilation fails when kernels using `thrust::sort_by_key`
# are linked with the rest of CUDA code. Workaround by linking them separately.

setup.py

@@ -1250,6 +1250,7 @@ def main():
"include/torch/csrc/inductor/aoti_runtime/*.h",
"include/torch/csrc/inductor/aoti_torch/*.h",
"include/torch/csrc/inductor/aoti_torch/c/*.h",
"include/torch/csrc/inductor/aoti_torch/generated/*.h",
"include/torch/csrc/jit/*.h",
"include/torch/csrc/jit/backends/*.h",
"include/torch/csrc/jit/generated/*.h",

torch/csrc/inductor/aoti_torch/utils.h

@@ -1,7 +1,13 @@
#pragma once
#include <ATen/Tensor.h>
#include <ATen/core/List.h>
#include <c10/core/DeviceType.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Logging.h>
#include <c10/util/Optional.h>
#include <c10/util/OptionalArrayRef.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
@@ -18,6 +24,8 @@
return AOTI_TORCH_SUCCESS;
namespace torch::aot_inductor {
// utility functions to convert a pointer to an optional value
template <class T>
inline c10::optional<T> pointer_to_optional(T* ptr) {
return ptr ? c10::make_optional(*ptr) : c10::nullopt;
@@ -34,4 +42,101 @@ inline c10::optional<at::Tensor> pointer_to_optional(AtenTensorHandle* ptr) {
: c10::nullopt;
}
template <>
inline c10::optional<at::Tensor> pointer_to_optional(
const AtenTensorHandle* ptr) {
return ptr ? c10::make_optional(*tensor_handle_to_tensor_pointer(*ptr))
: c10::nullopt;
}
inline c10::optional<c10::Device> pointer_to_optional_device(
int32_t* device_type,
int32_t device_index) {
return device_type ? c10::make_optional(c10::Device(
static_cast<c10::DeviceType>(*device_type),
static_cast<c10::DeviceIndex>(device_index)))
: c10::nullopt;
}
// utility functions to convert a pointer to a list
template <typename T>
struct is_optional : std::false_type {};
template <typename T>
struct is_optional<c10::optional<T>> : std::true_type {};
template <class T>
inline c10::ArrayRef<T> pointer_to_list(T* ptr, int64_t len) {
return c10::ArrayRef<T>(ptr, len);
}
template <
class T,
class U,
typename = std::enable_if_t<!std::is_same_v<T, U>>,
typename = std::enable_if_t<!is_optional<T>::value>>
inline std::vector<T> pointer_to_list(U* ptr, int64_t len) {
// std::vector<T> will be implicitly converted to c10::ArrayRef<T> at the call
// site
std::vector<T> result;
result.reserve(len);
for (int64_t i = 0; i < len; i++) {
result.emplace_back(T(ptr[i]));
}
return result;
}
template <class T, class U, typename = std::enable_if_t<is_optional<T>::value>>
inline std::vector<T> pointer_to_list(U** ptr, int64_t len) {
// Here U** denotes a list of optional arguments
// std::vector<T> will be implicitly converted to c10::ArrayRef<T> at the call
// site
std::vector<T> result;
result.reserve(len);
for (int64_t i = 0; i < len; i++) {
result.emplace_back(pointer_to_optional(ptr[i]));
}
return result;
}
template <>
inline std::vector<at::Tensor> pointer_to_list(
const AtenTensorHandle* ptr,
int64_t len) {
std::vector<at::Tensor> result;
result.reserve(len);
for (int64_t i = 0; i < len; i++) {
result.emplace_back(*tensor_handle_to_tensor_pointer(*ptr));
}
return result;
}
template <>
inline std::vector<c10::optional<at::Tensor>> pointer_to_list(
const AtenTensorHandle** ptr,
int64_t len) {
std::vector<c10::optional<at::Tensor>> result;
result.reserve(len);
for (int64_t i = 0; i < len; i++) {
result.emplace_back(pointer_to_optional<at::Tensor>(ptr[i]));
}
return result;
}
template <int N>
inline std::array<bool, N> pointer_to_list(const int32_t* ptr) {
std::array<bool, N> result;
std::copy(ptr, ptr + N, result.begin());
return result;
}
// utility functions to convert a pointer to a list of optional values
template <class T, class U>
inline c10::optional<c10::ArrayRef<T>> pointer_to_optional_list(
U** ptr,
int64_t len) {
return ptr
? c10::make_optional<c10::ArrayRef<T>>(pointer_to_list<T>(*ptr, len))
: c10::nullopt;
}
} // namespace torch::aot_inductor
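
As a quick illustration of how these helpers map the C-style arguments back to ATen types (the values and variable names below are made up; gen_aoti_c_shim.py further down emits the actual call sites):

// Standalone sketch exercising the conversion helpers from utils.h above.
#include <torch/csrc/inductor/aoti_torch/utils.h>

using namespace torch::aot_inductor;

void demo() {
  // Optional value: a nullable pointer.
  int64_t dim = 1;
  c10::optional<int64_t> opt_dim = pointer_to_optional(&dim);             // contains 1
  c10::optional<int64_t> no_dim = pointer_to_optional<int64_t>(nullptr);  // nullopt

  // List: pointer + length.
  int64_t dims[] = {2, 3, 4};
  c10::ArrayRef<int64_t> sizes = pointer_to_list<int64_t>(dims, 3);

  // Optional list (e.g. int[]?): pointer to pointer, where nullptr means nullopt.
  int64_t* dims_ptr = dims;
  c10::optional<c10::ArrayRef<int64_t>> opt_sizes =
      pointer_to_optional_list<int64_t>(&dims_ptr, 3);
  c10::optional<c10::ArrayRef<int64_t>> no_sizes =
      pointer_to_optional_list<int64_t>(static_cast<int64_t**>(nullptr), 0);
}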

torchgen/gen.py

@@ -44,6 +44,11 @@ from torchgen.context import (
with_native_function,
with_native_function_and_indices,
)
from torchgen.gen_aoti_c_shim import (
gen_aoti_c_shim,
gen_static_dispatch_backend_call_signature,
get_backend_index_for_aoti,
)
from torchgen.gen_functionalization_type import (
gen_functionalization_definition,
gen_functionalization_registration,
@@ -416,14 +421,7 @@ def generate_static_dispatch_backend_call(
f: NativeFunction,
backend_index: BackendIndex,
) -> str:
cpp_sigs = CppSignatureGroup.from_native_function(
f, method=False, fallback_binding=False
)
if sig.symint and f.func.has_symint():
cpp_sig = cpp_sigs.symint_signature
else:
cpp_sig = cpp_sigs.signature
assert cpp_sig is not None
cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
name = cpp_sig.name()
exprs = translate_args(sig, cpp_sig)
backend_metadata = backend_index.get_kernel(f)
@@ -2181,6 +2179,7 @@ def gen_source_files(
selector: SelectiveBuilder,
static_dispatch_idx: List[BackendIndex],
backend_indices: Dict[DispatchKey, BackendIndex],
aoti_fm: FileManager,
core_fm: FileManager,
cpu_fm: FileManager,
cpu_vec_fm: FileManager,
@@ -2350,6 +2349,60 @@ def gen_source_files(
else:
raise AssertionError(f"unrecognized {dispatch_key} for ufunc")
if dispatch_key in (DispatchKey.CPU, DispatchKey.CUDA):
def get_header(
f: NativeFunction,
) -> Optional[str]:
backend_index = get_backend_index_for_aoti(
f, dispatch_key, backend_indices
)
return (
None
if backend_index is None
else f"#include <ATen/ops/{f.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
)
def headers_for_aoti() -> str:
headers = []
for g in grouped_native_functions:
if isinstance(g, NativeFunctionsGroup):
for f in g.functions():
# some variants are registered in the backend, but some are registered as CompositeExplicitAutograd
header = get_header(f)
if header is not None:
headers.append(header)
else:
header = get_header(g)
if header is not None:
headers.append(header)
return "\n".join(sorted(set(headers)))
extra_headers = (
extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else ""
)
aoti_fm.write(
f"c_shim_{dispatch_key.lower()}.h",
lambda: gen_aoti_c_shim(
native_functions,
dispatch_key,
backend_indices,
header=True,
includes="",
),
)
aoti_fm.write(
f"c_shim_{dispatch_key.lower()}.cpp",
lambda: gen_aoti_c_shim(
native_functions,
dispatch_key,
backend_indices,
header=False,
includes=headers_for_aoti() + "\n" + extra_headers,
),
)
del fm
# BackendSelect is generated specially
@@ -2783,6 +2836,9 @@ def main() -> None:
cpu_vec_fm = make_file_manager(options=options)
cuda_fm = make_file_manager(options=options)
ops_fm = make_file_manager(options=options, install_dir=ops_install_dir)
aoti_fm = make_file_manager(
options=options, install_dir="torch/csrc/inductor/aoti_torch/generated"
)
# Only a limited set of dispatch keys get CPUFunctions.h headers generated
# for them; this is the set
@@ -2825,6 +2881,7 @@ def main() -> None:
selector=selector,
static_dispatch_idx=static_dispatch_idx,
backend_indices=backend_indices,
aoti_fm=aoti_fm,
core_fm=core_fm,
cpu_fm=cpu_fm,
cpu_vec_fm=cpu_vec_fm,

torchgen/gen_aoti_c_shim.py (new file, 431 lines)

@@ -0,0 +1,431 @@
import textwrap
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple, Union
from torchgen.api.types import DispatcherSignature
from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
from torchgen.context import method_with_native_function
from torchgen.model import (
Argument,
BackendIndex,
BaseTy,
BaseType,
DispatchKey,
FunctionSchema,
ListType,
NativeFunction,
OptionalType,
Type,
)
from torchgen.utils import mapMaybe
def returns_are_all_tensor(schema: FunctionSchema) -> bool:
return len(schema.returns) != 0 and all(
ret.type.is_tensor_like() for ret in schema.returns
)
base_type_to_c_type = {
BaseTy.Tensor: "AtenTensorHandle",
BaseTy.bool: "int32_t", # Use int to pass bool
BaseTy.int: "int64_t",
BaseTy.SymInt: "int64_t", # Inductor-generated code won't see a SymInt
BaseTy.Scalar: "double", # Use double to pass both integer and floating point
BaseTy.float: "double", # TODO: how about other floating point types?
BaseTy.str: "const char*",
BaseTy.DeviceIndex: "int32_t",
BaseTy.Layout: "int32_t", # Represent enum as int
BaseTy.MemoryFormat: "int32_t", # Represent enum as int
BaseTy.ScalarType: "int32_t", # Represent enum as int
}
base_type_to_aten_type = {
BaseTy.Tensor: "at::Tensor",
BaseTy.bool: "bool",
BaseTy.int: "int64_t",
BaseTy.SymInt: "c10::SymInt",
BaseTy.Scalar: "c10::Scalar",
BaseTy.float: "double",
BaseTy.str: "c10::string_view",
BaseTy.DeviceIndex: "c10::DeviceIndex",
BaseTy.Layout: "c10::Layout",
BaseTy.MemoryFormat: "c10::MemoryFormat",
BaseTy.ScalarType: "c10::ScalarType",
}
base_type_to_callsite_expr = {
BaseTy.Tensor: "*tensor_handle_to_tensor_pointer",
BaseTy.bool: "",
BaseTy.int: "",
BaseTy.SymInt: "",
BaseTy.Scalar: "",
BaseTy.float: "",
BaseTy.str: "",
BaseTy.DeviceIndex: "static_cast<c10::DeviceIndex>",
BaseTy.Layout: "static_cast<c10::Layout>",
BaseTy.MemoryFormat: "static_cast<c10::MemoryFormat>",
BaseTy.ScalarType: "static_cast<c10::ScalarType>",
}
# convert args to C types, names in declarations, and expressions in function bodies
def convert_arg_type_and_name(typ: Type, name: str) -> Tuple[List[str], List[str], List[str], List[str]]: # type: ignore[return]
if isinstance(typ, BaseType):
if typ.name in base_type_to_c_type:
return (
[base_type_to_c_type[typ.name]],
[name],
[base_type_to_aten_type[typ.name]],
[
f"{base_type_to_callsite_expr[typ.name]}({name})"
if base_type_to_callsite_expr[typ.name]
else name
],
)
elif typ.name == BaseTy.Device:
return (
["int32_t", "int32_t"],
[name, name + "_index_"],
["c10::Device"],
[
f"c10::Device(static_cast<c10::DeviceType>({name}), static_cast<c10::DeviceIndex>({name}_index_))"
],
)
else:
# TODO: BaseTy.Dimname, BaseTy.Generator, etc.
raise NotImplementedError(f"TODO: add support for arg type {repr(typ)}")
elif isinstance(typ, OptionalType):
c_types, names, aten_types, callsite_exprs = convert_arg_type_and_name(
typ.elem, name
)
j = 0 # index for names
new_aten_types = []
new_callsite_exprs = []
for i, aten_type in enumerate(aten_types):
# Use pointer to denote optional type
c_types[j] = c_types[j] + "*"
if aten_type.startswith("c10::ArrayRef<"):
# ArrayRef is passed as pointer + size, but no need to add "*" to the size argument
new_aten_types.append(f"c10::optional<{aten_type}>")
base_type = aten_type[len("c10::ArrayRef<") : -1]
new_callsite_exprs.append(
f"pointer_to_optional_list<{base_type}>({names[j]}, {names[j+1]})"
)
j += 2
elif aten_type == "c10::Device":
# Device is passed as device_type + device_index
new_aten_types.append("c10::optional<c10::Device>")
new_callsite_exprs.append(
f"pointer_to_optional_device({names[j]}, {names[j+1]})"
)
j += 2
else:
new_aten_types.append(f"c10::optional<{aten_type}>")
new_callsite_exprs.append(
f"pointer_to_optional<{aten_type}>({names[j]})"
)
j += 1
return (
c_types,
names,
new_aten_types,
new_callsite_exprs,
)
elif isinstance(typ, ListType):
# Need to explicitly pass the list as pointer + length
c_types, names, aten_types, _ = convert_arg_type_and_name(typ.elem, name)
assert len(c_types) == 1, "ListType with unsupported element type " + repr(typ)
# The list content should never be modified
c_types[0] = f"const {c_types[0]}*"
c_types.append("int64_t")
name = names[0]
names.append(name + "_len_")
atype = aten_types[0]
callsite_exprs = []
if atype == "bool":
# no converter from std::vector<bool> to c10::ArrayRef<bool>
# construct std::array<bool, N> instead
assert typ.size is not None
callsite_exprs.append(f"pointer_to_list<{typ.size}>({name})")
elif atype == "c10::optional<at::Tensor>":
# convert from std::vector<c10::optional<at::Tensor>> to c10::List<c10::optional<at::Tensor>>
callsite_exprs.append(
f"c10::List<{atype}>(c10::ArrayRef<{atype}>(pointer_to_list<{atype}>({name}, {name}_len_)))"
)
else:
callsite_exprs.append(f"pointer_to_list<{atype}>({name}, {name}_len_)")
aten_types = [f"c10::ArrayRef<{t}>" for t in aten_types]
return (
c_types,
names,
aten_types,
callsite_exprs,
)
def zip_type_and_name(types: List[str], names: List[str]) -> List[str]:
return [typ + " " + name for typ, name in zip(types, names)]
# Generate argument declarations and callsite expressions
def gen_arguments(flat_arguments: Sequence[Argument]) -> Tuple[List[str], List[str]]:
types = []
new_names = []
callsite_exprs = []
for arg in flat_arguments:
new_types, names, _, new_callsite_exprs = convert_arg_type_and_name(
arg.type, arg.name
)
types.extend(new_types)
new_names.extend(names)
callsite_exprs.extend(new_callsite_exprs)
return zip_type_and_name(types, new_names), callsite_exprs
# Return values are passed out as pointer arguments because all the C shim functions
# are expected to return AOTITorchError.
# Generate returns as declarations and callsite expressions
def gen_returns(schema: FunctionSchema) -> Tuple[List[str], List[str]]:
types = []
names = []
for idx, ret in enumerate(schema.returns):
names.append(f"ret{idx}")
if isinstance(ret.type, BaseType) and ret.type.name in base_type_to_c_type:
types.append(base_type_to_c_type[ret.type.name] + "*")
else:
raise NotImplementedError(
f"TODO: add support for return type {repr(ret.type)}"
)
def convert_return(typ: BaseType, val: str) -> str:
if typ.name == BaseTy.Tensor:
return f"new_tensor_handle(std::move({val}));"
elif typ.name == BaseTy.SymInt:
return f"{val}.expect_int()"
elif typ.name == BaseTy.Scalar:
return f"{val}.toDouble()"
else:
return val
ret_pointer_can_be_null = False
unambiguous_name = schema.name.unambiguous_name()
for name in ["_scaled_dot_product_flash_attention"]:
if name in unambiguous_name:
ret_pointer_can_be_null = True
break
callsite_exprs: List[str] = []
for idx, ret in enumerate(schema.returns):
tmp = "tmp_result" if len(names) == 1 else f"std::get<{idx}>(tmp_result)"
assert isinstance(ret.type, BaseType)
rval = convert_return(ret.type, tmp)
if ret_pointer_can_be_null:
callsite_exprs.append(f"if ({names[idx]}) {{ *{names[idx]} = {rval}; }}")
else:
callsite_exprs.append(f"*{names[idx]} = {rval};")
return zip_type_and_name(types, names), callsite_exprs
# gen.py generates header first and then src, so caching the result here to avoid duplicate work
declaration_definition_cache: Dict[Tuple[str, str, str], Tuple[str, str]] = {}
def gen_declaration_and_definition(
schema: FunctionSchema, device: str, backend_call: str
) -> Tuple[str, str]:
func_name = schema.name.unambiguous_name()
global declaration_definition_cache
if (func_name, device, backend_call) in declaration_definition_cache:
return declaration_definition_cache[(func_name, device, backend_call)]
if schema.is_out_fn():
# out_variant has out arguments in the front, and it's ok to ignore return value
# because C shim functions only return AOTITorchError
# Somehow at::native out-variant functions have out arguments in the back
args, callsite_exprs = gen_arguments(
[*schema.arguments.flat_non_out, *schema.arguments.out]
if "at::native" in backend_call
else [*schema.arguments.out, *schema.arguments.flat_non_out],
)
ret_assignments: List[str] = []
else:
args, callsite_exprs = gen_arguments(schema.arguments.flat_all)
ret_declarations, ret_assignments = gen_returns(schema)
args.extend(ret_declarations)
declaration = f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"
tmp_result = "auto tmp_result = " if ret_assignments else ""
ret_assignments_str = "\n" + "\n".join(ret_assignments) if ret_assignments else ""
definition = f"""
{declaration} {{
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({{
{tmp_result}{backend_call}(
{textwrap.indent(', '.join(callsite_exprs), " ")}
);{textwrap.indent(ret_assignments_str, " ")}
}});
}}
"""
declaration_definition_cache[(func_name, device, backend_call)] = (
declaration,
definition,
)
return declaration, definition
def gen_static_dispatch_backend_call_signature(
sig: Union[CppSignature, DispatcherSignature],
f: NativeFunction,
) -> CppSignature:
sig = DispatcherSignature.from_schema(f.func)
cpp_sigs = CppSignatureGroup.from_native_function(
f, method=False, fallback_binding=False
)
if sig.symint and f.func.has_symint():
cpp_sig = cpp_sigs.symint_signature
else:
cpp_sig = cpp_sigs.signature
assert cpp_sig is not None
return cpp_sig
def gen_static_dispatch_backend_call(
f: NativeFunction,
backend_index: BackendIndex,
) -> str:
assert backend_index.has_kernel(f)
sig = DispatcherSignature.from_schema(f.func)
cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"
def get_backend_index_for_aoti(
f: NativeFunction,
dispatch_key: DispatchKey,
backend_indices: Dict[DispatchKey, BackendIndex],
) -> Optional[BackendIndex]:
if "pointwise" in f.tags:
# TODO: No need to generate C shim for Inductor lowered ops.
# Only skip pointwise kernels for now, and we can add more tags later.
return None
backend_index = None
if backend_indices[dispatch_key].has_kernel(f):
backend_index = backend_indices[dispatch_key]
elif backend_indices[DispatchKey.CompositeExplicitAutograd].has_kernel(f):
# We need to create C shim wrappers for CompositeExplicitAutograd kernels
backend_index = backend_indices[DispatchKey.CompositeExplicitAutograd]
elif backend_indices[DispatchKey.CompositeExplicitAutogradNonFunctional].has_kernel(
f
):
# We need to create C shim wrappers for CompositeExplicitAutogradNonFunctional kernels
backend_index = backend_indices[
DispatchKey.CompositeExplicitAutogradNonFunctional
]
return backend_index
def gen_c_shim(
f: NativeFunction,
dispatch_key: DispatchKey,
backend_indices: Dict[DispatchKey, BackendIndex],
header: bool,
) -> Optional[str]:
backend_index = get_backend_index_for_aoti(f, dispatch_key, backend_indices)
if backend_index is None:
return None
schema = f.func
device = dispatch_key.lower()
backend_call = gen_static_dispatch_backend_call(
f,
backend_index,
)
try:
if header:
declaration, _ = gen_declaration_and_definition(
schema, device, backend_call
)
return f"AOTI_TORCH_EXPORT {declaration};"
else:
_, definition = gen_declaration_and_definition(schema, device, backend_call)
return definition
except NotImplementedError:
return None
@dataclass(frozen=True)
class ShimGenerator:
dispatch_key: DispatchKey
backend_indices: Dict[DispatchKey, BackendIndex]
header: bool # True to generate .h and False to generate .cpp
@method_with_native_function
def __call__(self, f: NativeFunction) -> Optional[str]:
result = gen_c_shim(f, self.dispatch_key, self.backend_indices, self.header)
return result
def gen_aoti_c_shim(
native_functions: Sequence[NativeFunction],
dispatch_key: DispatchKey,
backend_indices: Dict[DispatchKey, BackendIndex],
header: bool,
includes: str = "",
) -> str:
body = "\n".join(
list(
mapMaybe(
ShimGenerator(dispatch_key, backend_indices, header),
native_functions,
)
)
)
if header:
return f"""
#pragma once
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#ifdef __cplusplus
extern "C" {{
#endif
{body}
#ifdef __cplusplus
}} // extern "C"
#endif
"""
else:
device = dispatch_key.lower()
return f"""
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{device}.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/{str(dispatch_key)}Functions.h>
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#else
{includes}
#endif
using namespace torch::aot_inductor;
{body}
"""