mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[AOTI] Use torchgen to generate C shim functions (#120513)
Summary: The current C shim layer manually implements a C interface for a handful of ops. Obviously that's not scalable if we want to extend it to cover all aten ops. This new torchgen script automatically generates C shim interfaces for CPU and CUDA backends. The interface follows the same parameter passing rules as the current C shim layer, such as * Use plain C data types to pass parameters * Use AtenTensorHandle to pass at::Tensor * Use pointer type to pass optional parameter * Use pointer+length to pass list * Use device_type+device_index to pass device * When a parameter is a pointer of pointer, e.g. AtenTensorHandle**, the script generates either a list of optional values or an optional list of values https://gist.github.com/desertfire/83701532b126c6d34dae6ba68a1b074a is an example of the generated torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.cpp file. The current version doesn't generate C shim wrappers for all aten ops, and probably generates more wrappers than needed on the other hand, but it should serve as a good basis. This PR by itself won't change AOTI codegen and thus won't introduce any FC breakage. The actual wrapper codegen changes will come in another PR with some version control flag to avoid FC breakage. Differential Revision: [D54258087](https://our.internmc.facebook.com/intern/diff/D54258087) Pull Request resolved: https://github.com/pytorch/pytorch/pull/120513 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
ffe45a8188
commit
bd19d6d822
1
.gitignore
vendored
1
.gitignore
vendored
@ -86,6 +86,7 @@ torch/csrc/api/include/torch/version.h
|
||||
torch/csrc/cudnn/cuDNN.cpp
|
||||
torch/csrc/generated
|
||||
torch/csrc/generic/TensorMethods.cpp
|
||||
torch/csrc/inductor/aoti_torch/generated/*
|
||||
torch/csrc/jit/generated/*
|
||||
torch/csrc/jit/fuser/config.h
|
||||
torch/csrc/nn/THCUNN.cpp
|
||||
|
@ -368,6 +368,7 @@ if(NOT INTERN_DISABLE_AUTOGRAD AND NOT BUILD_LITE_INTERPRETER)
|
||||
"${TORCH_SRC_DIR}/csrc/autograd/generated/TraceType_4.cpp"
|
||||
"${TORCH_SRC_DIR}/csrc/autograd/generated/ADInplaceOrViewType_0.cpp"
|
||||
"${TORCH_SRC_DIR}/csrc/autograd/generated/ADInplaceOrViewType_1.cpp"
|
||||
"${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/generated/c_shim_cpu.cpp"
|
||||
)
|
||||
if(BUILD_LAZY_TS_BACKEND)
|
||||
list(APPEND GENERATED_CXX_TORCH
|
||||
@ -422,12 +423,17 @@ set(GENERATED_TESTING_PYTHON
|
||||
"${TORCH_SRC_DIR}/testing/_internal/generated/annotated_fn_args.py"
|
||||
)
|
||||
|
||||
set(GENERATED_CXX_TORCH_CUDA
|
||||
"${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/generated/c_shim_cuda.cpp"
|
||||
)
|
||||
|
||||
set(TORCH_GENERATED_CODE
|
||||
${GENERATED_CXX_TORCH}
|
||||
${GENERATED_H_TORCH}
|
||||
${GENERATED_CXX_PYTHON}
|
||||
${GENERATED_H_PYTHON}
|
||||
${GENERATED_TESTING_PYTHON}
|
||||
${GENERATED_CXX_TORCH_CUDA}
|
||||
)
|
||||
|
||||
set(GEN_PER_OPERATOR_FLAG)
|
||||
@ -970,6 +976,7 @@ endif()
|
||||
# Compile exposed libraries.
|
||||
if(USE_ROCM)
|
||||
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
|
||||
list(APPEND Caffe2_HIP_SRCS ${GENERATED_CXX_TORCH_CUDA})
|
||||
hip_add_library(torch_hip ${Caffe2_HIP_SRCS})
|
||||
if(USE_FLASH_ATTENTION)
|
||||
target_link_libraries(torch_hip PRIVATE __caffe2_oort)
|
||||
@ -988,6 +995,7 @@ if(USE_ROCM)
|
||||
endif()
|
||||
elseif(USE_CUDA)
|
||||
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
|
||||
list(APPEND Caffe2_GPU_SRCS ${GENERATED_CXX_TORCH_CUDA})
|
||||
if(CUDA_SEPARABLE_COMPILATION)
|
||||
# Separate compilation fails when kernels using `thrust::sort_by_key`
|
||||
# are linked with the rest of CUDA code. Workaround by linking them separately.
|
||||
|
1
setup.py
1
setup.py
@ -1250,6 +1250,7 @@ def main():
|
||||
"include/torch/csrc/inductor/aoti_runtime/*.h",
|
||||
"include/torch/csrc/inductor/aoti_torch/*.h",
|
||||
"include/torch/csrc/inductor/aoti_torch/c/*.h",
|
||||
"include/torch/csrc/inductor/aoti_torch/generated/*.h",
|
||||
"include/torch/csrc/jit/*.h",
|
||||
"include/torch/csrc/jit/backends/*.h",
|
||||
"include/torch/csrc/jit/generated/*.h",
|
||||
|
@ -1,7 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/core/List.h>
|
||||
#include <c10/core/DeviceType.h>
|
||||
#include <c10/core/SymIntArrayRef.h>
|
||||
#include <c10/util/ArrayRef.h>
|
||||
#include <c10/util/Logging.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <c10/util/OptionalArrayRef.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
|
||||
|
||||
@ -18,6 +24,8 @@
|
||||
return AOTI_TORCH_SUCCESS;
|
||||
|
||||
namespace torch::aot_inductor {
|
||||
|
||||
// utility functions to convert a pointer to an optional value
|
||||
template <class T>
|
||||
inline c10::optional<T> pointer_to_optional(T* ptr) {
|
||||
return ptr ? c10::make_optional(*ptr) : c10::nullopt;
|
||||
@ -34,4 +42,101 @@ inline c10::optional<at::Tensor> pointer_to_optional(AtenTensorHandle* ptr) {
|
||||
: c10::nullopt;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline c10::optional<at::Tensor> pointer_to_optional(
|
||||
const AtenTensorHandle* ptr) {
|
||||
return ptr ? c10::make_optional(*tensor_handle_to_tensor_pointer(*ptr))
|
||||
: c10::nullopt;
|
||||
}
|
||||
|
||||
inline c10::optional<c10::Device> pointer_to_optional_device(
|
||||
int32_t* device_type,
|
||||
int32_t device_index) {
|
||||
return device_type ? c10::make_optional(c10::Device(
|
||||
static_cast<c10::DeviceType>(*device_type),
|
||||
static_cast<c10::DeviceIndex>(device_index)))
|
||||
: c10::nullopt;
|
||||
}
|
||||
|
||||
// utility functions to convert a pointer to a list
|
||||
template <typename T>
|
||||
struct is_optional : std::false_type {};
|
||||
template <typename T>
|
||||
struct is_optional<c10::optional<T>> : std::true_type {};
|
||||
|
||||
template <class T>
|
||||
inline c10::ArrayRef<T> pointer_to_list(T* ptr, int64_t len) {
|
||||
return c10::ArrayRef<T>(ptr, len);
|
||||
}
|
||||
|
||||
template <
|
||||
class T,
|
||||
class U,
|
||||
typename = std::enable_if_t<!std::is_same_v<T, U>>,
|
||||
typename = std::enable_if_t<!is_optional<T>::value>>
|
||||
inline std::vector<T> pointer_to_list(U* ptr, int64_t len) {
|
||||
// std::vector<T> will be implicitly converted to c10::ArrayRef<T> at the call
|
||||
// site
|
||||
std::vector<T> result;
|
||||
result.reserve(len);
|
||||
for (int64_t i = 0; i < len; i++) {
|
||||
result.emplace_back(T(ptr[i]));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class T, class U, typename = std::enable_if_t<is_optional<T>::value>>
|
||||
inline std::vector<T> pointer_to_list(U** ptr, int64_t len) {
|
||||
// Here U** denotes a list of optional arguments
|
||||
// std::vector<T> will be implicitly converted to c10::ArrayRef<T> at the call
|
||||
// site
|
||||
std::vector<T> result;
|
||||
result.reserve(len);
|
||||
for (int64_t i = 0; i < len; i++) {
|
||||
result.emplace_back(pointer_to_optional(ptr[i]));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<at::Tensor> pointer_to_list(
|
||||
const AtenTensorHandle* ptr,
|
||||
int64_t len) {
|
||||
std::vector<at::Tensor> result;
|
||||
result.reserve(len);
|
||||
for (int64_t i = 0; i < len; i++) {
|
||||
result.emplace_back(*tensor_handle_to_tensor_pointer(*ptr));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<c10::optional<at::Tensor>> pointer_to_list(
|
||||
const AtenTensorHandle** ptr,
|
||||
int64_t len) {
|
||||
std::vector<c10::optional<at::Tensor>> result;
|
||||
result.reserve(len);
|
||||
for (int64_t i = 0; i < len; i++) {
|
||||
result.emplace_back(pointer_to_optional<at::Tensor>(ptr[i]));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <int N>
|
||||
inline std::array<bool, N> pointer_to_list(const int32_t* ptr) {
|
||||
std::array<bool, N> result;
|
||||
std::copy(ptr, ptr + N, result.begin());
|
||||
return result;
|
||||
}
|
||||
|
||||
// utility functions to convert a pointer to a list of optional values
|
||||
template <class T, class U>
|
||||
inline c10::optional<c10::ArrayRef<T>> pointer_to_optional_list(
|
||||
U** ptr,
|
||||
int64_t len) {
|
||||
return ptr
|
||||
? c10::make_optional<c10::ArrayRef<T>>(pointer_to_list<T>(*ptr, len))
|
||||
: c10::nullopt;
|
||||
}
|
||||
|
||||
} // namespace torch::aot_inductor
|
||||
|
@ -44,6 +44,11 @@ from torchgen.context import (
|
||||
with_native_function,
|
||||
with_native_function_and_indices,
|
||||
)
|
||||
from torchgen.gen_aoti_c_shim import (
|
||||
gen_aoti_c_shim,
|
||||
gen_static_dispatch_backend_call_signature,
|
||||
get_backend_index_for_aoti,
|
||||
)
|
||||
from torchgen.gen_functionalization_type import (
|
||||
gen_functionalization_definition,
|
||||
gen_functionalization_registration,
|
||||
@ -416,14 +421,7 @@ def generate_static_dispatch_backend_call(
|
||||
f: NativeFunction,
|
||||
backend_index: BackendIndex,
|
||||
) -> str:
|
||||
cpp_sigs = CppSignatureGroup.from_native_function(
|
||||
f, method=False, fallback_binding=False
|
||||
)
|
||||
if sig.symint and f.func.has_symint():
|
||||
cpp_sig = cpp_sigs.symint_signature
|
||||
else:
|
||||
cpp_sig = cpp_sigs.signature
|
||||
assert cpp_sig is not None
|
||||
cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
|
||||
name = cpp_sig.name()
|
||||
exprs = translate_args(sig, cpp_sig)
|
||||
backend_metadata = backend_index.get_kernel(f)
|
||||
@ -2181,6 +2179,7 @@ def gen_source_files(
|
||||
selector: SelectiveBuilder,
|
||||
static_dispatch_idx: List[BackendIndex],
|
||||
backend_indices: Dict[DispatchKey, BackendIndex],
|
||||
aoti_fm: FileManager,
|
||||
core_fm: FileManager,
|
||||
cpu_fm: FileManager,
|
||||
cpu_vec_fm: FileManager,
|
||||
@ -2350,6 +2349,60 @@ def gen_source_files(
|
||||
else:
|
||||
raise AssertionError(f"unrecognized {dispatch_key} for ufunc")
|
||||
|
||||
if dispatch_key in (DispatchKey.CPU, DispatchKey.CUDA):
|
||||
|
||||
def get_header(
|
||||
f: NativeFunction,
|
||||
) -> Optional[str]:
|
||||
backend_index = get_backend_index_for_aoti(
|
||||
f, dispatch_key, backend_indices
|
||||
)
|
||||
return (
|
||||
None
|
||||
if backend_index is None
|
||||
else f"#include <ATen/ops/{f.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
|
||||
)
|
||||
|
||||
def headers_for_aoti() -> str:
|
||||
headers = []
|
||||
for g in grouped_native_functions:
|
||||
if isinstance(g, NativeFunctionsGroup):
|
||||
for f in g.functions():
|
||||
# some variants are registered in the backend, but some are registered as CompositeExplicitAutograd
|
||||
header = get_header(f)
|
||||
if header is not None:
|
||||
headers.append(header)
|
||||
else:
|
||||
header = get_header(g)
|
||||
if header is not None:
|
||||
headers.append(header)
|
||||
return "\n".join(sorted(set(headers)))
|
||||
|
||||
extra_headers = (
|
||||
extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else ""
|
||||
)
|
||||
|
||||
aoti_fm.write(
|
||||
f"c_shim_{dispatch_key.lower()}.h",
|
||||
lambda: gen_aoti_c_shim(
|
||||
native_functions,
|
||||
dispatch_key,
|
||||
backend_indices,
|
||||
header=True,
|
||||
includes="",
|
||||
),
|
||||
)
|
||||
aoti_fm.write(
|
||||
f"c_shim_{dispatch_key.lower()}.cpp",
|
||||
lambda: gen_aoti_c_shim(
|
||||
native_functions,
|
||||
dispatch_key,
|
||||
backend_indices,
|
||||
header=False,
|
||||
includes=headers_for_aoti() + "\n" + extra_headers,
|
||||
),
|
||||
)
|
||||
|
||||
del fm
|
||||
|
||||
# BackendSelect is generated specially
|
||||
@ -2783,6 +2836,9 @@ def main() -> None:
|
||||
cpu_vec_fm = make_file_manager(options=options)
|
||||
cuda_fm = make_file_manager(options=options)
|
||||
ops_fm = make_file_manager(options=options, install_dir=ops_install_dir)
|
||||
aoti_fm = make_file_manager(
|
||||
options=options, install_dir="torch/csrc/inductor/aoti_torch/generated"
|
||||
)
|
||||
|
||||
# Only a limited set of dispatch keys get CPUFunctions.h headers generated
|
||||
# for them; this is the set
|
||||
@ -2825,6 +2881,7 @@ def main() -> None:
|
||||
selector=selector,
|
||||
static_dispatch_idx=static_dispatch_idx,
|
||||
backend_indices=backend_indices,
|
||||
aoti_fm=aoti_fm,
|
||||
core_fm=core_fm,
|
||||
cpu_fm=cpu_fm,
|
||||
cpu_vec_fm=cpu_vec_fm,
|
||||
|
431
torchgen/gen_aoti_c_shim.py
Normal file
431
torchgen/gen_aoti_c_shim.py
Normal file
@ -0,0 +1,431 @@
|
||||
import textwrap
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from torchgen.api.types import DispatcherSignature
|
||||
from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
|
||||
|
||||
from torchgen.context import method_with_native_function
|
||||
from torchgen.model import (
|
||||
Argument,
|
||||
BackendIndex,
|
||||
BaseTy,
|
||||
BaseType,
|
||||
DispatchKey,
|
||||
FunctionSchema,
|
||||
ListType,
|
||||
NativeFunction,
|
||||
OptionalType,
|
||||
Type,
|
||||
)
|
||||
from torchgen.utils import mapMaybe
|
||||
|
||||
|
||||
def returns_are_all_tensor(schema: FunctionSchema) -> bool:
|
||||
return len(schema.returns) != 0 and all(
|
||||
ret.type.is_tensor_like() for ret in schema.returns
|
||||
)
|
||||
|
||||
|
||||
base_type_to_c_type = {
|
||||
BaseTy.Tensor: "AtenTensorHandle",
|
||||
BaseTy.bool: "int32_t", # Use int to pass bool
|
||||
BaseTy.int: "int64_t",
|
||||
BaseTy.SymInt: "int64_t", # Inductor-generated code won't see a SymInt
|
||||
BaseTy.Scalar: "double", # Use double to pass both integer and floating point
|
||||
BaseTy.float: "double", # TODO: how about other floating point types?
|
||||
BaseTy.str: "const char*",
|
||||
BaseTy.DeviceIndex: "int32_t",
|
||||
BaseTy.Layout: "int32_t", # Represent enum as int
|
||||
BaseTy.MemoryFormat: "int32_t", # Represent enum as int
|
||||
BaseTy.ScalarType: "int32_t", # Represent enum as int
|
||||
}
|
||||
|
||||
base_type_to_aten_type = {
|
||||
BaseTy.Tensor: "at::Tensor",
|
||||
BaseTy.bool: "bool",
|
||||
BaseTy.int: "int64_t",
|
||||
BaseTy.SymInt: "c10::SymInt",
|
||||
BaseTy.Scalar: "c10::Scalar",
|
||||
BaseTy.float: "double",
|
||||
BaseTy.str: "c10::string_view",
|
||||
BaseTy.DeviceIndex: "c10::DeviceIndex",
|
||||
BaseTy.Layout: "c10::Layout",
|
||||
BaseTy.MemoryFormat: "c10::MemoryFormat",
|
||||
BaseTy.ScalarType: "c10::ScalarType",
|
||||
}
|
||||
|
||||
base_type_to_callsite_expr = {
|
||||
BaseTy.Tensor: "*tensor_handle_to_tensor_pointer",
|
||||
BaseTy.bool: "",
|
||||
BaseTy.int: "",
|
||||
BaseTy.SymInt: "",
|
||||
BaseTy.Scalar: "",
|
||||
BaseTy.float: "",
|
||||
BaseTy.str: "",
|
||||
BaseTy.DeviceIndex: "static_cast<c10::DeviceIndex>",
|
||||
BaseTy.Layout: "static_cast<c10::Layout>",
|
||||
BaseTy.MemoryFormat: "static_cast<c10::MemoryFormat>",
|
||||
BaseTy.ScalarType: "static_cast<c10::ScalarType>",
|
||||
}
|
||||
|
||||
|
||||
# convert args to C types, names in declarations, and expressions in function bodies
|
||||
def convert_arg_type_and_name(typ: Type, name: str) -> Tuple[List[str], List[str], List[str], List[str]]: # type: ignore[return]
|
||||
if isinstance(typ, BaseType):
|
||||
if typ.name in base_type_to_c_type:
|
||||
return (
|
||||
[base_type_to_c_type[typ.name]],
|
||||
[name],
|
||||
[base_type_to_aten_type[typ.name]],
|
||||
[
|
||||
f"{base_type_to_callsite_expr[typ.name]}({name})"
|
||||
if base_type_to_callsite_expr[typ.name]
|
||||
else name
|
||||
],
|
||||
)
|
||||
elif typ.name == BaseTy.Device:
|
||||
return (
|
||||
["int32_t", "int32_t"],
|
||||
[name, name + "_index_"],
|
||||
["c10::Device"],
|
||||
[
|
||||
f"c10::Device(static_cast<c10::DeviceType>({name}), static_cast<c10::DeviceIndex>({name}_index_))"
|
||||
],
|
||||
)
|
||||
else:
|
||||
# TODO: BaseTy.Dimname, BaseTy.Generator, etc.
|
||||
raise NotImplementedError(f"TODO: add support for arg type {repr(typ)}")
|
||||
elif isinstance(typ, OptionalType):
|
||||
c_types, names, aten_types, callsite_exprs = convert_arg_type_and_name(
|
||||
typ.elem, name
|
||||
)
|
||||
j = 0 # index for names
|
||||
new_aten_types = []
|
||||
new_callsite_exprs = []
|
||||
for i, aten_type in enumerate(aten_types):
|
||||
# Use pointer to denote optional type
|
||||
c_types[j] = c_types[j] + "*"
|
||||
if aten_type.startswith("c10::ArrayRef<"):
|
||||
# ArrayRef is passed as pointer + size, but no need to add "*" to the size argument
|
||||
new_aten_types.append(f"c10::optional<{aten_type}>")
|
||||
base_type = aten_type[len("c10::ArrayRef<") : -1]
|
||||
new_callsite_exprs.append(
|
||||
f"pointer_to_optional_list<{base_type}>({names[j]}, {names[j+1]})"
|
||||
)
|
||||
j += 2
|
||||
elif aten_type == "c10::Device":
|
||||
# Device is passed as device_type + device_index
|
||||
new_aten_types.append("c10::optional<c10::Device>")
|
||||
new_callsite_exprs.append(
|
||||
f"pointer_to_optional_device({names[j]}, {names[j+1]})"
|
||||
)
|
||||
j += 2
|
||||
else:
|
||||
new_aten_types.append(f"c10::optional<{aten_type}>")
|
||||
new_callsite_exprs.append(
|
||||
f"pointer_to_optional<{aten_type}>({names[j]})"
|
||||
)
|
||||
j += 1
|
||||
|
||||
return (
|
||||
c_types,
|
||||
names,
|
||||
new_aten_types,
|
||||
new_callsite_exprs,
|
||||
)
|
||||
elif isinstance(typ, ListType):
|
||||
# Need to explictly pass the list as pointer + length
|
||||
c_types, names, aten_types, _ = convert_arg_type_and_name(typ.elem, name)
|
||||
assert len(c_types) == 1, "ListType with unsupported element type " + repr(typ)
|
||||
|
||||
# The list content should never be modified
|
||||
c_types[0] = f"const {c_types[0]}*"
|
||||
c_types.append("int64_t")
|
||||
name = names[0]
|
||||
names.append(name + "_len_")
|
||||
|
||||
atype = aten_types[0]
|
||||
callsite_exprs = []
|
||||
if atype == "bool":
|
||||
# no converter from std::vector<bool> to c10::ArrayRef<bool>
|
||||
# construct std::array<bool, N> instead
|
||||
assert typ.size is not None
|
||||
callsite_exprs.append(f"pointer_to_list<{typ.size}>({name})")
|
||||
elif atype == "c10::optional<at::Tensor>":
|
||||
# convert from std::vector<c10::optional<at::Tensor>> to c10::List<c10::optional<at::Tensor>>
|
||||
callsite_exprs.append(
|
||||
f"c10::List<{atype}>(c10::ArrayRef<{atype}>(pointer_to_list<{atype}>({name}, {name}_len_)))"
|
||||
)
|
||||
else:
|
||||
callsite_exprs.append(f"pointer_to_list<{atype}>({name}, {name}_len_)")
|
||||
|
||||
aten_types = [f"c10::ArrayRef<{t}>" for t in aten_types]
|
||||
return (
|
||||
c_types,
|
||||
names,
|
||||
aten_types,
|
||||
callsite_exprs,
|
||||
)
|
||||
|
||||
|
||||
def zip_type_and_name(types: List[str], names: List[str]) -> List[str]:
|
||||
return [typ + " " + name for typ, name in zip(types, names)]
|
||||
|
||||
|
||||
# Generate argument declarations and callsite expressions
|
||||
def gen_arguments(flat_arguments: Sequence[Argument]) -> Tuple[List[str], List[str]]:
|
||||
types = []
|
||||
new_names = []
|
||||
callsite_exprs = []
|
||||
for arg in flat_arguments:
|
||||
new_types, names, _, new_callsite_exprs = convert_arg_type_and_name(
|
||||
arg.type, arg.name
|
||||
)
|
||||
types.extend(new_types)
|
||||
new_names.extend(names)
|
||||
callsite_exprs.extend(new_callsite_exprs)
|
||||
return zip_type_and_name(types, new_names), callsite_exprs
|
||||
|
||||
|
||||
# Return values are passed out as pointer arguments because all the C shim functions
|
||||
# are expected to return AOTITorchError.
|
||||
# Generate returns as declarations and callsite expressions
|
||||
def gen_returns(schema: FunctionSchema) -> Tuple[List[str], List[str]]:
|
||||
types = []
|
||||
names = []
|
||||
for idx, ret in enumerate(schema.returns):
|
||||
names.append(f"ret{idx}")
|
||||
if isinstance(ret.type, BaseType) and ret.type.name in base_type_to_c_type:
|
||||
types.append(base_type_to_c_type[ret.type.name] + "*")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"TODO: add support for return type {repr(ret.type)}"
|
||||
)
|
||||
|
||||
def convert_return(typ: BaseType, val: str) -> str:
|
||||
if typ.name == BaseTy.Tensor:
|
||||
return f"new_tensor_handle(std::move({val}));"
|
||||
elif typ.name == BaseTy.SymInt:
|
||||
return f"{val}.expect_int()"
|
||||
elif typ.name == BaseTy.Scalar:
|
||||
return f"{val}.toDouble()"
|
||||
else:
|
||||
return val
|
||||
|
||||
ret_pointer_can_be_null = False
|
||||
unambiguous_name = schema.name.unambiguous_name()
|
||||
for name in ["_scaled_dot_product_flash_attention"]:
|
||||
if name in unambiguous_name:
|
||||
ret_pointer_can_be_null = True
|
||||
break
|
||||
|
||||
callsite_exprs: List[str] = []
|
||||
for idx, ret in enumerate(schema.returns):
|
||||
tmp = "tmp_result" if len(names) == 1 else f"std::get<{idx}>(tmp_result)"
|
||||
assert isinstance(ret.type, BaseType)
|
||||
rval = convert_return(ret.type, tmp)
|
||||
if ret_pointer_can_be_null:
|
||||
callsite_exprs.append(f"if ({names[idx]}) {{ *{names[idx]} = {rval}; }}")
|
||||
else:
|
||||
callsite_exprs.append(f"*{names[idx]} = {rval};")
|
||||
|
||||
return zip_type_and_name(types, names), callsite_exprs
|
||||
|
||||
|
||||
# gen.py generates header first and then src, so caching the result here to avoid duplicate work
|
||||
declaration_definition_cache: Dict[Tuple[str, str, str], Tuple[str, str]] = {}
|
||||
|
||||
|
||||
def gen_declaration_and_definition(
|
||||
schema: FunctionSchema, device: str, backend_call: str
|
||||
) -> Tuple[str, str]:
|
||||
func_name = schema.name.unambiguous_name()
|
||||
|
||||
global declaration_definition_cache
|
||||
if (func_name, device, backend_call) in declaration_definition_cache:
|
||||
return declaration_definition_cache[(func_name, device, backend_call)]
|
||||
|
||||
if schema.is_out_fn():
|
||||
# out_variant has out arguments in the front, and it's ok to ignore return value
|
||||
# because C shim functions only return AOTITorchError
|
||||
# Somehow at::native out-variant functions have out arguments in the back
|
||||
args, callsite_exprs = gen_arguments(
|
||||
[*schema.arguments.flat_non_out, *schema.arguments.out]
|
||||
if "at::native" in backend_call
|
||||
else [*schema.arguments.out, *schema.arguments.flat_non_out],
|
||||
)
|
||||
ret_assignments: List[str] = []
|
||||
else:
|
||||
args, callsite_exprs = gen_arguments(schema.arguments.flat_all)
|
||||
ret_declarations, ret_assignments = gen_returns(schema)
|
||||
args.extend(ret_declarations)
|
||||
|
||||
declaration = f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"
|
||||
|
||||
tmp_result = "auto tmp_result = " if ret_assignments else ""
|
||||
ret_assignments_str = "\n" + "\n".join(ret_assignments) if ret_assignments else ""
|
||||
definition = f"""
|
||||
{declaration} {{
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({{
|
||||
{tmp_result}{backend_call}(
|
||||
{textwrap.indent(', '.join(callsite_exprs), " ")}
|
||||
);{textwrap.indent(ret_assignments_str, " ")}
|
||||
}});
|
||||
}}
|
||||
"""
|
||||
declaration_definition_cache[(func_name, device, backend_call)] = (
|
||||
declaration,
|
||||
definition,
|
||||
)
|
||||
return declaration, definition
|
||||
|
||||
|
||||
def gen_static_dispatch_backend_call_signature(
|
||||
sig: Union[CppSignature, DispatcherSignature],
|
||||
f: NativeFunction,
|
||||
) -> CppSignature:
|
||||
sig = DispatcherSignature.from_schema(f.func)
|
||||
cpp_sigs = CppSignatureGroup.from_native_function(
|
||||
f, method=False, fallback_binding=False
|
||||
)
|
||||
if sig.symint and f.func.has_symint():
|
||||
cpp_sig = cpp_sigs.symint_signature
|
||||
else:
|
||||
cpp_sig = cpp_sigs.signature
|
||||
assert cpp_sig is not None
|
||||
return cpp_sig
|
||||
|
||||
|
||||
def gen_static_dispatch_backend_call(
|
||||
f: NativeFunction,
|
||||
backend_index: BackendIndex,
|
||||
) -> str:
|
||||
assert backend_index.has_kernel(f)
|
||||
sig = DispatcherSignature.from_schema(f.func)
|
||||
cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
|
||||
return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"
|
||||
|
||||
|
||||
def get_backend_index_for_aoti(
|
||||
f: NativeFunction,
|
||||
dispatch_key: DispatchKey,
|
||||
backend_indices: Dict[DispatchKey, BackendIndex],
|
||||
) -> Optional[BackendIndex]:
|
||||
if "pointwise" in f.tags:
|
||||
# TODO: No need to generate C shim for Inductor lowered ops.
|
||||
# Only skip pointwise kernels for now, and we can add more tags later.
|
||||
return None
|
||||
|
||||
backend_index = None
|
||||
if backend_indices[dispatch_key].has_kernel(f):
|
||||
backend_index = backend_indices[dispatch_key]
|
||||
elif backend_indices[DispatchKey.CompositeExplicitAutograd].has_kernel(f):
|
||||
# We need to create C shim wrappers for CompositeExplicitAutograd kernels
|
||||
backend_index = backend_indices[DispatchKey.CompositeExplicitAutograd]
|
||||
elif backend_indices[DispatchKey.CompositeExplicitAutogradNonFunctional].has_kernel(
|
||||
f
|
||||
):
|
||||
# We need to create C shim wrappers for CompositeExplicitAutogradNonFunctional kernels
|
||||
backend_index = backend_indices[
|
||||
DispatchKey.CompositeExplicitAutogradNonFunctional
|
||||
]
|
||||
return backend_index
|
||||
|
||||
|
||||
def gen_c_shim(
|
||||
f: NativeFunction,
|
||||
dispatch_key: DispatchKey,
|
||||
backend_indices: Dict[DispatchKey, BackendIndex],
|
||||
header: bool,
|
||||
) -> Optional[str]:
|
||||
backend_index = get_backend_index_for_aoti(f, dispatch_key, backend_indices)
|
||||
if backend_index is None:
|
||||
return None
|
||||
|
||||
schema = f.func
|
||||
device = dispatch_key.lower()
|
||||
backend_call = gen_static_dispatch_backend_call(
|
||||
f,
|
||||
backend_index,
|
||||
)
|
||||
|
||||
try:
|
||||
if header:
|
||||
declaration, _ = gen_declaration_and_definition(
|
||||
schema, device, backend_call
|
||||
)
|
||||
return f"AOTI_TORCH_EXPORT {declaration};"
|
||||
else:
|
||||
_, definition = gen_declaration_and_definition(schema, device, backend_call)
|
||||
return definition
|
||||
|
||||
except NotImplementedError:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ShimGenerator:
|
||||
dispatch_key: DispatchKey
|
||||
backend_indices: Dict[DispatchKey, BackendIndex]
|
||||
header: bool # True to generate .h and False to generate .cpp
|
||||
|
||||
@method_with_native_function
|
||||
def __call__(self, f: NativeFunction) -> Optional[str]:
|
||||
result = gen_c_shim(f, self.dispatch_key, self.backend_indices, self.header)
|
||||
return result
|
||||
|
||||
|
||||
def gen_aoti_c_shim(
|
||||
native_functions: Sequence[NativeFunction],
|
||||
dispatch_key: DispatchKey,
|
||||
backend_indices: Dict[DispatchKey, BackendIndex],
|
||||
header: bool,
|
||||
includes: str = "",
|
||||
) -> str:
|
||||
body = "\n".join(
|
||||
list(
|
||||
mapMaybe(
|
||||
ShimGenerator(dispatch_key, backend_indices, header),
|
||||
native_functions,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if header:
|
||||
return f"""
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {{
|
||||
#endif
|
||||
|
||||
{body}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}} // extern "C"
|
||||
#endif
|
||||
|
||||
"""
|
||||
else:
|
||||
device = dispatch_key.lower()
|
||||
return f"""
|
||||
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/utils.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{device}.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/{str(dispatch_key)}Functions.h>
|
||||
#include <ATen/CompositeExplicitAutogradFunctions.h>
|
||||
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
|
||||
#else
|
||||
{includes}
|
||||
#endif
|
||||
|
||||
using namespace torch::aot_inductor;
|
||||
|
||||
{body}
|
||||
|
||||
"""
|
Reference in New Issue
Block a user