mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[executorch] Let custom ops registration code only import ATen headers (#107064)
Summary: Basically we generate `CustomOpsNativeFunctions.h` for registering custom ops into PyTorch JIT runtime. This header needs to hookup with the C++ kernel implementation of all the custom ops. For this reason it should include ATen headers instead of Executorch headers. This PR changes it. Test Plan: Rely on existing CI jobs Differential Revision: D48282828 Pull Request resolved: https://github.com/pytorch/pytorch/pull/107064 Approved by: https://github.com/kirklandsign
This commit is contained in:
committed by
PyTorch MergeBot
parent
f26aa2dcd9
commit
ddd2f682b9
@ -426,6 +426,10 @@ def gen_headers(
|
||||
backend_indices=backend_indices,
|
||||
native_function_decl_gen=dest.compute_native_function_declaration,
|
||||
),
|
||||
"headers": [
|
||||
"#include <ATen/ATen.h>",
|
||||
"#include <torch/torch.h>",
|
||||
],
|
||||
},
|
||||
)
|
||||
aten_headers.append('#include "CustomOpsNativeFunctions.h"')
|
||||
@ -444,16 +448,26 @@ def gen_headers(
|
||||
),
|
||||
},
|
||||
)
|
||||
headers = {
|
||||
"headers": [
|
||||
"#include <executorch/runtime/core/exec_aten/exec_aten.h> // at::Tensor etc.",
|
||||
"#include <executorch/codegen/macros.h> // TORCH_API",
|
||||
"#include <executorch/runtime/kernel/kernel_runtime_context.h>",
|
||||
],
|
||||
}
|
||||
if use_aten_lib:
|
||||
cpu_fm.write(
|
||||
"NativeFunctions.h",
|
||||
lambda: {
|
||||
"nativeFunctions_declarations": get_native_function_declarations(
|
||||
grouped_native_functions=native_functions,
|
||||
backend_indices=backend_indices,
|
||||
native_function_decl_gen=dest.compute_native_function_declaration,
|
||||
),
|
||||
},
|
||||
lambda: dict(
|
||||
{
|
||||
"nativeFunctions_declarations": get_native_function_declarations(
|
||||
grouped_native_functions=native_functions,
|
||||
backend_indices=backend_indices,
|
||||
native_function_decl_gen=dest.compute_native_function_declaration,
|
||||
),
|
||||
},
|
||||
**headers,
|
||||
),
|
||||
)
|
||||
else:
|
||||
ns_grouped_kernels = get_ns_grouped_kernels(
|
||||
@ -463,11 +477,14 @@ def gen_headers(
|
||||
)
|
||||
cpu_fm.write(
|
||||
"NativeFunctions.h",
|
||||
lambda: {
|
||||
"nativeFunctions_declarations": get_native_function_declarations_from_ns_grouped_kernels(
|
||||
ns_grouped_kernels=ns_grouped_kernels,
|
||||
),
|
||||
},
|
||||
lambda: dict(
|
||||
{
|
||||
"nativeFunctions_declarations": get_native_function_declarations_from_ns_grouped_kernels(
|
||||
ns_grouped_kernels=ns_grouped_kernels,
|
||||
),
|
||||
},
|
||||
**headers,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user