[AOTI] Introduce an extensibility mechanism for the c shim codegen to make it easy to produce c shims for out-of-tree OP kernels as well. Add c_shim for XPU. (#136742)

[AOTI] Introduce an extensibility mechanism for the c shim codegen to make it easy to produce c shims for out-of-tree OP kernels as well. Add c shim for XPU.

### Motivation
The current c shim codegen only produces C wrappers for ops registered in `aten/src/ATen/native/native_functions.yaml`. For a given backend, however, some ops may be registered out of tree instead, for example in `third_party/torch-xpu-ops/yaml/native_functions.yaml`. In that case the existing codegen cannot extend the c shims already produced for the in-tree ops with c shims for the out-of-tree ops.

### Design
The goal is to extend the c shim of a backend with additional out-of-tree ops.
This PR adds a boolean option `--extend-aoti-c-shim` to indicate that the codegen is extending the c shim with out-of-tree ops.
The generated extension c shim is stored in the `extend` subdirectory, for example:
```
torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h
torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.cpp
torch/include/torch/csrc/inductor/aoti_torch/generated/extend/c_shim_xpu.h
torch/include/torch/csrc/inductor/aoti_torch/generated/extend/c_shim_xpu.cpp
```
example usage:
`python -m torchgen.gen --source-path third_party/torch-xpu-ops/yaml/ --xpu --extend-aoti-c-shim --update-aoti-c-shim`
`--xpu`: generate the c shim for XPU.
`--extend-aoti-c-shim`: the ops being processed are out-of-tree ops (defined in `third_party/torch-xpu-ops/yaml/native_functions.yaml`) that extend the in-tree ops (defined in `aten/src/ATen/native/native_functions.yaml`).
`--update-aoti-c-shim`: always generate `c_shim_xpu.h` for the extended c shim.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136742
Approved by: https://github.com/EikanWang, https://github.com/desertfire
ghstack dependencies: #139025
This commit is contained in:
xinan.lin
2024-11-08 18:04:35 -08:00
committed by PyTorch MergeBot
parent 929a647363
commit 191971e01d
6 changed files with 131 additions and 18 deletions

1
.gitignore vendored
View File

@ -88,6 +88,7 @@ torch/csrc/cudnn/cuDNN.cpp
torch/csrc/generated
torch/csrc/generic/TensorMethods.cpp
torch/csrc/inductor/aoti_torch/generated/*.cpp
torch/csrc/inductor/aoti_torch/generated/extend/*
torch/csrc/jit/generated/*
torch/csrc/jit/fuser/config.h
torch/csrc/nn/THCUNN.cpp

View File

@ -325,6 +325,10 @@ set(GENERATED_CXX_TORCH_CUDA
"${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/generated/c_shim_cuda.cpp"
)
set(GENERATED_CXX_TORCH_XPU
"${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/generated/c_shim_xpu.cpp"
)
set(TORCH_GENERATED_CODE
${GENERATED_CXX_TORCH}
${GENERATED_H_TORCH}
@ -1042,6 +1046,7 @@ elseif(USE_CUDA)
endif()
if(USE_XPU)
list(APPEND Caffe2_XPU_SRCS ${GENERATED_CXX_TORCH_XPU})
add_library(torch_xpu ${Caffe2_XPU_SRCS})
torch_compile_options(torch_xpu) # see cmake/public/utils.cmake
target_compile_definitions(torch_xpu PRIVATE USE_XPU)

View File

@ -212,6 +212,11 @@ class CppWrapperCpu(PythonWrapperCodegen):
}}
"""
)
extend_aoti_path = (
f"torch/csrc/inductor/aoti_torch/generated/extend/c_shim_{self.device}.h"
)
if os.path.exists(extend_aoti_path):
self.header.splice(f"#include <{extend_aoti_path}>")
enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [
"linux",

View File

@ -0,0 +1,56 @@
// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
// See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details
#pragma once
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#ifdef __cplusplus
extern "C" {
#endif
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_addbmm(AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_addmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_addmv(AtenTensorHandle self, AtenTensorHandle mat, AtenTensorHandle vec, double beta, double alpha, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_bmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_cholesky_solve(AtenTensorHandle self, AtenTensorHandle input2, int32_t upper, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_convolution(AtenTensorHandle input, AtenTensorHandle weight, AtenTensorHandle* bias, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t transposed, const int64_t* output_padding, int64_t output_padding_len_, int64_t groups, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_convolution_backward(AtenTensorHandle grad_output, AtenTensorHandle input, AtenTensorHandle weight, const int64_t** bias_sizes, int64_t bias_sizes_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t transposed, const int64_t* output_padding, int64_t output_padding_len_, int64_t groups, const int32_t* output_mask, int64_t output_mask_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_cummax(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_cummin(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_index_put(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle values, int32_t accumulate, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_kthvalue(AtenTensorHandle self, int64_t k, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_logcumsumexp(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_masked_scatter(AtenTensorHandle self, AtenTensorHandle mask, AtenTensorHandle source, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_masked_scatter_backward(AtenTensorHandle grad_output, AtenTensorHandle mask, const int64_t* sizes, int64_t sizes_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_polar(AtenTensorHandle abs, AtenTensorHandle angle, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_rand(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_rand_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randint(int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randint_generator(int64_t high, const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randint_low(int64_t low, int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randint_low_out(AtenTensorHandle out, int64_t low, int64_t high, const int64_t* size, int64_t size_len_);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randn(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randn_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randperm(int64_t n, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_reshape(AtenTensorHandle self, const int64_t* shape, int64_t shape_len_, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_slice_Tensor(AtenTensorHandle self, int64_t dim, int64_t* start, int64_t* end, int64_t step, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_uniform(AtenTensorHandle self, double from, double to, AtenGeneratorHandle* generator, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_view_dtype(AtenTensorHandle self, int32_t dtype, AtenTensorHandle* ret0);
#ifdef __cplusplus
} // extern "C"
#endif

View File

@ -2201,6 +2201,8 @@ def gen_source_files(
per_operator_headers: bool,
skip_dispatcher_op_registration: bool,
update_aoti_c_shim: bool,
aoti_backends: set[DispatchKey],
extend_aoti_c_shim: bool,
) -> None:
extra_cuda_headers = """\
#include <c10/cuda/CUDAGuard.h>
@ -2366,7 +2368,7 @@ def gen_source_files(
structured_func_group_dict[func.structured_delegate] = func_group
break
if dispatch_key in (DispatchKey.CPU, DispatchKey.CUDA):
if dispatch_key in aoti_backends:
fallbacks = {}
for func in native_functions:
op_name = get_fallback_op_name(func)
@ -2384,6 +2386,7 @@ def gen_source_files(
dispatch_key,
backend_indices,
header=True,
extend_aoti_c_shim=extend_aoti_c_shim,
includes="",
)
if update_aoti_c_shim:
@ -2424,7 +2427,11 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
headers = []
for func in fallback_native_functions:
header = get_header_for_aoti(
func, structured_func_group_dict, dispatch_key, backend_indices
func,
structured_func_group_dict,
dispatch_key,
backend_indices,
extend_aoti_c_shim=extend_aoti_c_shim,
)
if header is not None:
headers.append(header)
@ -2442,6 +2449,7 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
dispatch_key,
backend_indices,
header=False,
extend_aoti_c_shim=extend_aoti_c_shim,
includes=headers_for_aoti() + "\n" + extra_headers,
),
)
@ -2837,6 +2845,16 @@ def main() -> None:
help="Update AOTInductor C shim after adding an entry to inductor_fallback_ops in torchgen/aoti/fallback_ops.py. "
"WARNING: Do not use this unless you are sure what you are doing!!!",
)
parser.add_argument(
"--extend-aoti-c-shim",
action="store_true",
help="This Flag indicates the generation of c shims for out-of-tree ATen ops,"
"which is an extension to the In-tree ATen op c shims. This flag needs to be combined with"
"---source-path=<out-of-tree native_functions.yaml>"
"--aoti-install-dir=<in-tree aoti_install_dir>/extend"
" default is torch/csrc/inductor/aoti_torch/generated/extend"
"WARNING: Do not use this unless you are sure what you are doing!!!",
)
options = parser.parse_args()
@ -2906,6 +2924,7 @@ def main() -> None:
Path(core_install_dir).mkdir(parents=True, exist_ok=True)
ops_install_dir = f"{options.install_dir}/ops"
Path(ops_install_dir).mkdir(parents=True, exist_ok=True)
aoti_install_dir = f"{options.aoti_install_dir}"
Path(aoti_install_dir).mkdir(parents=True, exist_ok=True)
@ -2930,11 +2949,18 @@ def main() -> None:
DispatchKey.CompositeExplicitAutogradNonFunctional,
DispatchKey.Meta,
}
aoti_backends = {
DispatchKey.CPU,
DispatchKey.CUDA,
}
if options.mps:
functions_keys.add(DispatchKey.MPS)
if options.xpu:
functions_keys.add(DispatchKey.XPU)
aoti_backends.add(DispatchKey.XPU)
if options.backend_whitelist:
dispatch_keys = [
@ -2975,6 +3001,8 @@ def main() -> None:
per_operator_headers=options.per_operator_headers,
skip_dispatcher_op_registration=options.skip_dispatcher_op_registration,
update_aoti_c_shim=options.update_aoti_c_shim,
aoti_backends=aoti_backends,
extend_aoti_c_shim=options.extend_aoti_c_shim,
)
if "headers" in options.generate:

View File

@ -319,6 +319,7 @@ def get_backend_index_for_aoti(
func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
dispatch_key: DispatchKey,
backend_indices: dict[DispatchKey, BackendIndex],
extend_aoti_c_shim: bool,
) -> BackendIndex | None:
backend_index = None
if backend_indices[dispatch_key].has_kernel(func) or (
@ -329,18 +330,24 @@ def get_backend_index_for_aoti(
)
):
backend_index = backend_indices[dispatch_key]
elif backend_indices[DispatchKey.CompositeExplicitAutograd].has_kernel(func):
# We need to create C shim wrappers for CompositeExplicitAutograd kernels
backend_index = backend_indices[DispatchKey.CompositeExplicitAutograd]
elif backend_indices[DispatchKey.CompositeExplicitAutogradNonFunctional].has_kernel(
func
):
# We need to create C shim wrappers for CompositeExplicitAutogradNonFunctional kernels
backend_index = backend_indices[
else:
# for the extend out-of-tree kernels, we don't need to
# duplicatly create C shim wrappers for other dispatch keys
if extend_aoti_c_shim:
return backend_index
elif backend_indices[DispatchKey.CompositeExplicitAutograd].has_kernel(func):
# We need to create C shim wrappers for CompositeExplicitAutograd kernels
backend_index = backend_indices[DispatchKey.CompositeExplicitAutograd]
elif backend_indices[
DispatchKey.CompositeExplicitAutogradNonFunctional
]
elif backend_indices[DispatchKey.CompositeImplicitAutograd].has_kernel(func):
backend_index = backend_indices[DispatchKey.CompositeImplicitAutograd]
].has_kernel(func):
# We need to create C shim wrappers for CompositeExplicitAutogradNonFunctional kernels
backend_index = backend_indices[
DispatchKey.CompositeExplicitAutogradNonFunctional
]
elif backend_indices[DispatchKey.CompositeImplicitAutograd].has_kernel(func):
backend_index = backend_indices[DispatchKey.CompositeImplicitAutograd]
return backend_index
@ -350,9 +357,10 @@ def get_header_for_aoti(
func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
dispatch_key: DispatchKey,
backend_indices: dict[DispatchKey, BackendIndex],
extend_aoti_c_shim: bool,
) -> str | None:
backend_index = get_backend_index_for_aoti(
func, func_group_mapping, dispatch_key, backend_indices
func, func_group_mapping, dispatch_key, backend_indices, extend_aoti_c_shim
)
return (
None
@ -375,9 +383,10 @@ def gen_c_shim(
dispatch_key: DispatchKey,
backend_indices: dict[DispatchKey, BackendIndex],
header: bool,
extend_aoti_c_shim: bool,
) -> str | None:
backend_index = get_backend_index_for_aoti(
func, func_group_mapping, dispatch_key, backend_indices
func, func_group_mapping, dispatch_key, backend_indices, extend_aoti_c_shim
)
if backend_index is None:
return None
@ -409,6 +418,7 @@ class ShimGenerator:
dispatch_key: DispatchKey
backend_indices: dict[DispatchKey, BackendIndex]
header: bool # True to generate .h and False to generate .cpp
extend_aoti_c_shim: bool
@method_with_native_function
def __call__(
@ -421,6 +431,7 @@ class ShimGenerator:
self.dispatch_key,
self.backend_indices,
self.header,
self.extend_aoti_c_shim,
)
return result
@ -431,20 +442,24 @@ def gen_aoti_c_shim(
dispatch_key: DispatchKey,
backend_indices: dict[DispatchKey, BackendIndex],
header: bool,
extend_aoti_c_shim: bool,
includes: str = "",
) -> str:
body = "\n".join(
list(
mapMaybe(
ShimGenerator(
func_group_mapping, dispatch_key, backend_indices, header
func_group_mapping,
dispatch_key,
backend_indices,
header,
extend_aoti_c_shim,
),
native_functions,
)
)
)
device = dispatch_key.lower()
warning = """
// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
// See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details"""
@ -469,10 +484,13 @@ extern "C" {{
"""
else:
c_shim_include = (
f"#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{device}.h>"
)
return f"""
{warning}
#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{device}.h>
#include <torch/csrc/inductor/aoti_torch/generated/{"extend/" if extend_aoti_c_shim else ""}c_shim_{device}.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#ifndef AT_PER_OPERATOR_HEADERS