[AOTI] Introduce an extensibility mechanism for the c shim codegen to make it easy to produce c shims for out-of-tree OP kernels as well. Add c_shim for XPU. (#136742)
### Motivation

The current c shim codegen only produces C wrappers for ops registered in `aten/src/ATen/native/native_functions.yaml`. For the same backend, some out-of-tree ops are not registered in that file but are registered externally, for example in `third_party/torch-xpu-ops/yaml/native_functions.yaml`. In that case, the existing codegen cannot extend the already-generated in-tree c shims with shims for those out-of-tree ops.

### Design

To extend a backend's c shim with additional out-of-tree ops, this PR adds a boolean option `--aoti-extend` indicating that the codegen should extend the c shim from out-of-tree. The generated c shim is stored in the `extend` subdirectory, for example:

```
torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h
torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.cpp
torch/include/torch/csrc/inductor/aoti_torch/generated/extend/c_shim_xpu.h
torch/include/torch/csrc/inductor/aoti_torch/generated/extend/c_shim_xpu.cpp
```

Example usage:

```
python -m torchgen.gen --source-path third_party/torch-xpu-ops/yaml/ --xpu --aoti-extend --update-aoti-c-shim
```

- `--xpu`: generate the c shim for XPU.
- `--aoti-extend`: treat the out-of-tree ops (defined in `third_party/torch-xpu-ops/yaml/native_functions.yaml`) as an extension of the in-tree ops (defined in `aten/src/ATen/native/native_functions.yaml`).
- `--update-aoti-c-shim`: always generate `c_shim_xpu.h` for the extended c shim.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136742
Approved by: https://github.com/EikanWang, https://github.com/desertfire
ghstack dependencies: #139025
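An out-of-tree backend's build step might drive this codegen programmatically. Below is a minimal sketch assuming the invocation quoted above; the `generate_extend_shim` wrapper is hypothetical, and note that the argparse diff further down registers the long flag as `--extend-aoti-c-shim` (the `--aoti-extend` spelling is kept here as the description writes it):

```python
# Sketch: run the extend codegen from an out-of-tree build step (hypothetical wrapper).
import subprocess
import sys

def generate_extend_shim(source_path: str) -> None:
    # Mirrors the example invocation from the PR description above.
    subprocess.check_call(
        [
            sys.executable, "-m", "torchgen.gen",
            "--source-path", source_path,
            "--xpu",                 # generate the c shim for XPU
            "--aoti-extend",         # extend the in-tree c shim with out-of-tree ops
            "--update-aoti-c-shim",  # regenerate the extend c_shim_xpu.h
        ]
    )

if __name__ == "__main__":
    generate_extend_shim("third_party/torch-xpu-ops/yaml/")
```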
Committed by: PyTorch MergeBot
Parent: 929a647363
Commit: 191971e01d
.gitignore (vendored): 1 line changed
```diff
@@ -88,6 +88,7 @@ torch/csrc/cudnn/cuDNN.cpp
 torch/csrc/generated
 torch/csrc/generic/TensorMethods.cpp
 torch/csrc/inductor/aoti_torch/generated/*.cpp
+torch/csrc/inductor/aoti_torch/generated/extend/*
 torch/csrc/jit/generated/*
 torch/csrc/jit/fuser/config.h
 torch/csrc/nn/THCUNN.cpp
```
caffe2/CMakeLists.txt

```diff
@@ -325,6 +325,10 @@ set(GENERATED_CXX_TORCH_CUDA
     "${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/generated/c_shim_cuda.cpp"
 )
 
+set(GENERATED_CXX_TORCH_XPU
+    "${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/generated/c_shim_xpu.cpp"
+)
+
 set(TORCH_GENERATED_CODE
     ${GENERATED_CXX_TORCH}
     ${GENERATED_H_TORCH}
@@ -1042,6 +1046,7 @@ elseif(USE_CUDA)
 endif()
 
 if(USE_XPU)
+  list(APPEND Caffe2_XPU_SRCS ${GENERATED_CXX_TORCH_XPU})
   add_library(torch_xpu ${Caffe2_XPU_SRCS})
   torch_compile_options(torch_xpu) # see cmake/public/utils.cmake
   target_compile_definitions(torch_xpu PRIVATE USE_XPU)
```
torch/_inductor/codegen/cpp_wrapper_cpu.py

```diff
@@ -212,6 +212,11 @@ class CppWrapperCpu(PythonWrapperCodegen):
                 }}
                 """
             )
+            extend_aoti_path = (
+                f"torch/csrc/inductor/aoti_torch/generated/extend/c_shim_{self.device}.h"
+            )
+            if os.path.exists(extend_aoti_path):
+                self.header.splice(f"#include <{extend_aoti_path}>")
 
         enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [
             "linux",
```
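The effect of this change is that the emitted C++ wrapper gains an include of the extension header only when that header actually exists on disk, so builds without the extend codegen are unaffected. A minimal sketch of the same conditional-include idea, with a hypothetical helper name:

```python
# Sketch: reproduce the conditional extend-header include (hypothetical helper).
import os

def wrapper_includes(device: str) -> list[str]:
    includes = [f"#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{device}.h>"]
    extend = f"torch/csrc/inductor/aoti_torch/generated/extend/c_shim_{device}.h"
    if os.path.exists(extend):  # only present when the extend codegen has run
        includes.append(f"#include <{extend}>")
    return includes

print("\n".join(wrapper_includes("xpu")))
```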
torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h (new file, 56 lines)

```cpp
// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
// See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details

#pragma once

#include <torch/csrc/inductor/aoti_torch/c/shim.h>

#ifdef __cplusplus
extern "C" {
#endif

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_addbmm(AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_addmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_addmv(AtenTensorHandle self, AtenTensorHandle mat, AtenTensorHandle vec, double beta, double alpha, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_bmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_cholesky_solve(AtenTensorHandle self, AtenTensorHandle input2, int32_t upper, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_convolution(AtenTensorHandle input, AtenTensorHandle weight, AtenTensorHandle* bias, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t transposed, const int64_t* output_padding, int64_t output_padding_len_, int64_t groups, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_convolution_backward(AtenTensorHandle grad_output, AtenTensorHandle input, AtenTensorHandle weight, const int64_t** bias_sizes, int64_t bias_sizes_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t transposed, const int64_t* output_padding, int64_t output_padding_len_, int64_t groups, const int32_t* output_mask, int64_t output_mask_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_cummax(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_cummin(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_index_put(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle values, int32_t accumulate, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_kthvalue(AtenTensorHandle self, int64_t k, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_logcumsumexp(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_masked_scatter(AtenTensorHandle self, AtenTensorHandle mask, AtenTensorHandle source, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_masked_scatter_backward(AtenTensorHandle grad_output, AtenTensorHandle mask, const int64_t* sizes, int64_t sizes_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_polar(AtenTensorHandle abs, AtenTensorHandle angle, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_rand(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_rand_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randint(int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randint_generator(int64_t high, const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randint_low(int64_t low, int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randint_low_out(AtenTensorHandle out, int64_t low, int64_t high, const int64_t* size, int64_t size_len_);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randn(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randn_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randperm(int64_t n, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_reshape(AtenTensorHandle self, const int64_t* shape, int64_t shape_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_slice_Tensor(AtenTensorHandle self, int64_t dim, int64_t* start, int64_t* end, int64_t step, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_uniform(AtenTensorHandle self, double from, double to, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_view_dtype(AtenTensorHandle self, int32_t dtype, AtenTensorHandle* ret0);

#ifdef __cplusplus
} // extern "C"
#endif
```
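Each declaration above is a plain C entry point in the dispatcher-free calling convention AOTInductor uses: tensors travel as `AtenTensorHandle`, optional values as pointers, and every shim returns an `AOTITorchError`. As a rough illustration, here is a Python sketch of how a wrapper codegen might emit a call to one of these shims; it assumes the generated C++ wraps calls in inductor's error-check macro (`AOTI_TORCH_ERROR_CODE_CHECK`), and the emission helper itself is hypothetical:

```python
# Sketch: emit a C++ call to a shim declared above (simplified emission logic).
def emit_shim_call(device: str, op: str, args: list[str]) -> str:
    call = f"aoti_torch_{device}_{op}({', '.join(args)})"
    # Generated AOTI code conventionally checks the AOTITorchError return value.
    return f"AOTI_TORCH_ERROR_CODE_CHECK({call});"

print(emit_shim_call("xpu", "mm_out", ["buf1", "arg0_1", "arg1_1"]))
# -> AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_xpu_mm_out(buf1, arg0_1, arg1_1));
```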
torchgen/gen.py

```diff
@@ -2201,6 +2201,8 @@ def gen_source_files(
     per_operator_headers: bool,
     skip_dispatcher_op_registration: bool,
     update_aoti_c_shim: bool,
+    aoti_backends: set[DispatchKey],
+    extend_aoti_c_shim: bool,
 ) -> None:
     extra_cuda_headers = """\
 #include <c10/cuda/CUDAGuard.h>
@@ -2366,7 +2368,7 @@ def gen_source_files(
                     structured_func_group_dict[func.structured_delegate] = func_group
                     break
 
-        if dispatch_key in (DispatchKey.CPU, DispatchKey.CUDA):
+        if dispatch_key in aoti_backends:
             fallbacks = {}
             for func in native_functions:
                 op_name = get_fallback_op_name(func)
@@ -2384,6 +2386,7 @@ def gen_source_files(
                     dispatch_key,
                     backend_indices,
                     header=True,
+                    extend_aoti_c_shim=extend_aoti_c_shim,
                     includes="",
                 )
                 if update_aoti_c_shim:
@@ -2424,7 +2427,11 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
             headers = []
             for func in fallback_native_functions:
                 header = get_header_for_aoti(
-                    func, structured_func_group_dict, dispatch_key, backend_indices
+                    func,
+                    structured_func_group_dict,
+                    dispatch_key,
+                    backend_indices,
+                    extend_aoti_c_shim=extend_aoti_c_shim,
                 )
                 if header is not None:
                     headers.append(header)
@@ -2442,6 +2449,7 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
                         dispatch_key,
                         backend_indices,
                         header=False,
+                        extend_aoti_c_shim=extend_aoti_c_shim,
                         includes=headers_for_aoti() + "\n" + extra_headers,
                     ),
                 )
@@ -2837,6 +2845,16 @@ def main() -> None:
         help="Update AOTInductor C shim after adding an entry to inductor_fallback_ops in torchgen/aoti/fallback_ops.py. "
         "WARNING: Do not use this unless you are sure what you are doing!!!",
     )
+    parser.add_argument(
+        "--extend-aoti-c-shim",
+        action="store_true",
+        help="This Flag indicates the generation of c shims for out-of-tree ATen ops,"
+        "which is an extension to the In-tree ATen op c shims. This flag needs to be combined with"
+        "---source-path=<out-of-tree native_functions.yaml>"
+        "--aoti-install-dir=<in-tree aoti_install_dir>/extend"
+        " default is torch/csrc/inductor/aoti_torch/generated/extend"
+        "WARNING: Do not use this unless you are sure what you are doing!!!",
+    )
 
     options = parser.parse_args()
 
@@ -2906,6 +2924,7 @@ def main() -> None:
     Path(core_install_dir).mkdir(parents=True, exist_ok=True)
     ops_install_dir = f"{options.install_dir}/ops"
     Path(ops_install_dir).mkdir(parents=True, exist_ok=True)
+
     aoti_install_dir = f"{options.aoti_install_dir}"
     Path(aoti_install_dir).mkdir(parents=True, exist_ok=True)
 
@@ -2930,11 +2949,18 @@ def main() -> None:
         DispatchKey.CompositeExplicitAutogradNonFunctional,
         DispatchKey.Meta,
     }
 
+    aoti_backends = {
+        DispatchKey.CPU,
+        DispatchKey.CUDA,
+    }
+
     if options.mps:
         functions_keys.add(DispatchKey.MPS)
 
     if options.xpu:
         functions_keys.add(DispatchKey.XPU)
+        aoti_backends.add(DispatchKey.XPU)
 
     if options.backend_whitelist:
         dispatch_keys = [
@@ -2975,6 +3001,8 @@ def main() -> None:
             per_operator_headers=options.per_operator_headers,
             skip_dispatcher_op_registration=options.skip_dispatcher_op_registration,
             update_aoti_c_shim=options.update_aoti_c_shim,
+            aoti_backends=aoti_backends,
+            extend_aoti_c_shim=options.extend_aoti_c_shim,
         )
 
     if "headers" in options.generate:
```
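The gen.py changes boil down to replacing the hard-coded `(DispatchKey.CPU, DispatchKey.CUDA)` pair with a mutable `aoti_backends` set that `--xpu` extends. A minimal self-contained sketch of that gating, using a stand-in enum rather than torchgen's real `DispatchKey`:

```python
# Sketch: the aoti_backends gating introduced above (stand-in DispatchKey enum).
from enum import Enum, auto

class DispatchKey(Enum):  # stand-in for torchgen.model.DispatchKey
    CPU = auto()
    CUDA = auto()
    XPU = auto()

def pick_aoti_backends(xpu: bool) -> set[DispatchKey]:
    aoti_backends = {DispatchKey.CPU, DispatchKey.CUDA}
    if xpu:
        aoti_backends.add(DispatchKey.XPU)  # --xpu opts XPU into shim generation
    return aoti_backends

# gen_source_files then generates c shims for every key in the set:
for key in pick_aoti_backends(xpu=True):
    print(f"would generate c_shim_{key.name.lower()}.h/.cpp")
```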
torchgen/gen_aoti_c_shim.py

```diff
@@ -319,6 +319,7 @@ def get_backend_index_for_aoti(
     func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
     dispatch_key: DispatchKey,
     backend_indices: dict[DispatchKey, BackendIndex],
+    extend_aoti_c_shim: bool,
 ) -> BackendIndex | None:
     backend_index = None
     if backend_indices[dispatch_key].has_kernel(func) or (
@@ -329,18 +330,24 @@ def get_backend_index_for_aoti(
         )
     ):
         backend_index = backend_indices[dispatch_key]
-    elif backend_indices[DispatchKey.CompositeExplicitAutograd].has_kernel(func):
-        # We need to create C shim wrappers for CompositeExplicitAutograd kernels
-        backend_index = backend_indices[DispatchKey.CompositeExplicitAutograd]
-    elif backend_indices[DispatchKey.CompositeExplicitAutogradNonFunctional].has_kernel(
-        func
-    ):
-        # We need to create C shim wrappers for CompositeExplicitAutogradNonFunctional kernels
-        backend_index = backend_indices[
-            DispatchKey.CompositeExplicitAutogradNonFunctional
-        ]
-    elif backend_indices[DispatchKey.CompositeImplicitAutograd].has_kernel(func):
-        backend_index = backend_indices[DispatchKey.CompositeImplicitAutograd]
+    else:
+        # for the extend out-of-tree kernels, we don't need to
+        # duplicatly create C shim wrappers for other dispatch keys
+        if extend_aoti_c_shim:
+            return backend_index
+
+        elif backend_indices[DispatchKey.CompositeExplicitAutograd].has_kernel(func):
+            # We need to create C shim wrappers for CompositeExplicitAutograd kernels
+            backend_index = backend_indices[DispatchKey.CompositeExplicitAutograd]
+        elif backend_indices[
+            DispatchKey.CompositeExplicitAutogradNonFunctional
+        ].has_kernel(func):
+            # We need to create C shim wrappers for CompositeExplicitAutogradNonFunctional kernels
+            backend_index = backend_indices[
+                DispatchKey.CompositeExplicitAutogradNonFunctional
+            ]
+        elif backend_indices[DispatchKey.CompositeImplicitAutograd].has_kernel(func):
+            backend_index = backend_indices[DispatchKey.CompositeImplicitAutograd]
 
     return backend_index
@@ -350,9 +357,10 @@ def get_header_for_aoti(
     func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
     dispatch_key: DispatchKey,
     backend_indices: dict[DispatchKey, BackendIndex],
+    extend_aoti_c_shim: bool,
 ) -> str | None:
     backend_index = get_backend_index_for_aoti(
-        func, func_group_mapping, dispatch_key, backend_indices
+        func, func_group_mapping, dispatch_key, backend_indices, extend_aoti_c_shim
     )
     return (
         None
@@ -375,9 +383,10 @@ def gen_c_shim(
     dispatch_key: DispatchKey,
     backend_indices: dict[DispatchKey, BackendIndex],
     header: bool,
+    extend_aoti_c_shim: bool,
 ) -> str | None:
     backend_index = get_backend_index_for_aoti(
-        func, func_group_mapping, dispatch_key, backend_indices
+        func, func_group_mapping, dispatch_key, backend_indices, extend_aoti_c_shim
     )
     if backend_index is None:
         return None
@@ -409,6 +418,7 @@ class ShimGenerator:
     dispatch_key: DispatchKey
     backend_indices: dict[DispatchKey, BackendIndex]
     header: bool  # True to generate .h and False to generate .cpp
+    extend_aoti_c_shim: bool
 
     @method_with_native_function
     def __call__(
@@ -421,6 +431,7 @@ class ShimGenerator:
             self.dispatch_key,
             self.backend_indices,
             self.header,
+            self.extend_aoti_c_shim,
         )
         return result
 
@@ -431,20 +442,24 @@ def gen_aoti_c_shim(
     dispatch_key: DispatchKey,
     backend_indices: dict[DispatchKey, BackendIndex],
     header: bool,
+    extend_aoti_c_shim: bool,
     includes: str = "",
 ) -> str:
     body = "\n".join(
         list(
             mapMaybe(
                 ShimGenerator(
-                    func_group_mapping, dispatch_key, backend_indices, header
+                    func_group_mapping,
+                    dispatch_key,
+                    backend_indices,
+                    header,
+                    extend_aoti_c_shim,
                 ),
                 native_functions,
             )
         )
     )
     device = dispatch_key.lower()
 
     warning = """
 // WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
 // See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details"""
@@ -469,10 +484,13 @@ extern "C" {{
 """
 
     else:
+        c_shim_include = (
+            f"#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{device}.h>"
+        )
         return f"""
 {warning}
 
-#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{device}.h>
+#include <torch/csrc/inductor/aoti_torch/generated/{"extend/" if extend_aoti_c_shim else ""}c_shim_{device}.h>
 #include <torch/csrc/inductor/aoti_torch/utils.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
```
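Taken together, the `extend_aoti_c_shim` plumbing changes shim selection in one important way: in extend mode, an op only gets a shim if the target backend itself registers a kernel for it, because the Composite* fallbacks have already been shimmed in-tree. A simplified sketch of that selection rule (the real code consults `BackendIndex` objects rather than a plain dict):

```python
# Sketch: backend selection with the out-of-tree short-circuit (simplified).
def select_backend(
    has_kernel: dict[str, bool], dispatch_key: str, extend: bool
) -> str | None:
    if has_kernel.get(dispatch_key):
        return dispatch_key  # the backend registers the kernel itself
    if extend:
        # Extend mode: don't duplicate shims for Composite* kernels,
        # which the in-tree codegen has already covered.
        return None
    for composite in (
        "CompositeExplicitAutograd",
        "CompositeExplicitAutogradNonFunctional",
        "CompositeImplicitAutograd",
    ):
        if has_kernel.get(composite):
            return composite
    return None

# An op with only a Composite kernel gets no out-of-tree shim:
assert select_backend({"CompositeExplicitAutograd": True}, "XPU", extend=True) is None
```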