[torchgen] Generate wrapper functions under custom namespaces (#81744)
Summary: A follow-up to #81581. Before these two PRs, if an operator with a custom kernel namespace was added to `native_functions.yaml` (or any other yaml consumed by `torchgen`), the custom kernel was recognized in files such as `NativeFunctions.h` and `RegisterCPU.cpp`, but the backend-specific wrappers were still generated under the hardcoded `at` namespace. This PR changes that behavior by generating the wrapper functions under the custom namespaces.

For example, if the entries in the yaml file look like:

```yaml
- func: op_1(Tensor(a) self) -> Tensor(a)
  dispatch:
    CPU: at::op_1_kernel # ATen kernel

- func: op_2(Tensor(a) self) -> Tensor(a)
  dispatch:
    CPU: custom::op_2_kernel # custom kernel
```

we generate the following code for `CPUFunctions_inl.h` and `RegisterCPU.cpp`.

`CPUFunctions_inl.h`:

```cpp
namespace at {
namespace cpu {
TORCH_API at::Tensor & op_1(const at::Tensor & self);
} // namespace cpu
} // namespace at

namespace custom {
namespace cpu {
TORCH_API at::Tensor & op_2(const at::Tensor & self);
} // namespace cpu
} // namespace custom
```

Notice the difference between `at::cpu` and `custom::cpu`. The definitions for these wrappers can then be found in `RegisterCPU.cpp`.

`RegisterCPU.cpp`:

```cpp
#include "CPUFunctions.h"

namespace at {
namespace {
at::Tensor & wrapper_op_1(const at::Tensor & self) {
    // No device check
    // DeviceGuard omitted
    return at::native::op_1_kernel(self);
}
} // anonymous namespace

TORCH_LIBRARY_IMPL(aten, CPU, m) {
    m.impl("op_1", TORCH_FN(wrapper_op_1));
}

namespace cpu {
at::Tensor & op_1(at::Tensor & self) {
    return wrapper_op_1(self);
}
} // namespace cpu
} // namespace at

namespace custom {
namespace {
at::Tensor & wrapper_op_2(const at::Tensor & self) {
    // No device check
    // DeviceGuard omitted
    return custom::native::op_2_kernel(self);
}
} // anonymous namespace

TORCH_LIBRARY_IMPL(aten, CPU, m) {
    m.impl("op_2", TORCH_FN(wrapper_op_2));
}

namespace cpu {
at::Tensor & op_2(at::Tensor & self) {
    return wrapper_op_2(self);
}
} // namespace cpu
} // namespace custom
```

The benefit of this change is that it unifies all the namespaces derived from custom ops. In the example above, there are:

1. `custom::native` for kernels
2. `custom::<dispatch_key>`, e.g. `custom::cpu`, for wrappers

This customized operator has nothing to do with `at::native`, `at::cpu`, etc.

Test Plan: This is very hard to test directly. I will refactor this logic and abstract out some layers so it is testable in coming PRs.

Differential Revision: D37972772

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81744
Approved by: https://github.com/bdhirsh
Committed by: PyTorch MergeBot
Parent: cda8635a5e
Commit: 406ce692ca
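To make the namespace handling described in the commit message concrete, here is a minimal Python sketch of how a kernel's namespace can drive the namespace the wrapper is emitted under. `NamespaceHelper` below is a simplified stand-in for the torchgen helper of the same name that appears in the diff further down; `gen_wrapper_decl` and its arguments are hypothetical and purely illustrative, not the actual torchgen implementation.

```python
# Minimal sketch (not the actual torchgen code) of deriving the wrapper's
# namespace from the kernel's namespace, mirroring the yaml example above.

class NamespaceHelper:
    """Builds nested C++ namespace prologue/epilogue strings for e.g. "custom::cpu"."""

    def __init__(self, namespace_str: str) -> None:
        self.namespaces = namespace_str.split("::")

    @property
    def prologue(self) -> str:
        return "\n".join(f"namespace {ns} {{" for ns in self.namespaces)

    @property
    def epilogue(self) -> str:
        return "\n".join(f"}} // namespace {ns}" for ns in reversed(self.namespaces))


def gen_wrapper_decl(op_name: str, kernel: str, dispatch_key: str = "cpu") -> str:
    # "custom::op_2_kernel" -> namespace "custom"; a bare or "at::"-qualified
    # kernel name falls back to the default "at" namespace.
    kernel_ns = kernel.rsplit("::", 1)[0] if "::" in kernel else "at"
    ns = NamespaceHelper(f"{kernel_ns}::{dispatch_key}")
    decl = f"TORCH_API at::Tensor & {op_name}(const at::Tensor & self);"
    return f"{ns.prologue}\n{decl}\n{ns.epilogue}"


if __name__ == "__main__":
    print(gen_wrapper_decl("op_1", "at::op_1_kernel"))      # declared under at::cpu
    print(gen_wrapper_decl("op_2", "custom::op_2_kernel"))  # declared under custom::cpu
```

Running the sketch reproduces the shape of the `CPUFunctions_inl.h` example above: `op_1` is declared under `at::cpu`, `op_2` under `custom::cpu`.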
```diff
@@ -422,6 +422,8 @@ def gen_dispatcher_registrations(
             grouped_native_functions,
         )
     )
+    newline = "\n"
+    ns_helper = NamespaceHelper(namespace_str="at")
     deferred_dispatch_registrations = ""
     static_init_dispatch_registrations = ""
     if eager_registration:
@@ -453,8 +455,6 @@ TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
         f"Register{dispatch_key}.cpp",
         "RegisterDispatchKey.cpp",
         lambda: {
-            "static_init_dispatch_registrations": static_init_dispatch_registrations,
-            "deferred_dispatch_registrations": deferred_dispatch_registrations,
             "extra_cuda_headers": "",
             "external_backend_headers": external_backend_headers_str,
             "ops_headers": "#include <ATen/Functions.h>"
@@ -465,21 +465,31 @@ TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
             "dispatch_headers": dest.gen_registration_headers(
                 backend_index, per_operator_headers=per_operator_headers, rocm=False
             ),
-            "dispatch_helpers": dest.gen_registration_helpers(backend_index),
-            "dispatch_namespaced_definitions": "",
-            "dispatch_anonymous_definitions": list(
-                concatMap(
-                    dest.RegisterDispatchKey(
-                        backend_index,
-                        Target.ANONYMOUS_DEFINITION,
-                        selector,
-                        rocm=False,
-                        class_method_name=f"{class_name}",
-                        skip_dispatcher_op_registration=False,
+            "dispatch_definitions": fm.substitute_with_template(
+                "RegisterDispatchDefinitions.ini",
+                lambda: {
+                    "ns_prologue": ns_helper.prologue,
+                    "ns_epilogue": ns_helper.epilogue,
+                    "static_init_dispatch_registrations": static_init_dispatch_registrations,
+                    "deferred_dispatch_registrations": deferred_dispatch_registrations,
+                    "dispatch_helpers": dest.gen_registration_helpers(backend_index),
+                    "dispatch_namespace": dispatch_key.lower(),
+                    "dispatch_namespaced_definitions": "",
+                    "dispatch_anonymous_definitions": list(
+                        concatMap(
+                            dest.RegisterDispatchKey(
+                                backend_index,
+                                Target.ANONYMOUS_DEFINITION,
+                                selector,
+                                rocm=False,
+                                class_method_name=f"{class_name}",
+                                skip_dispatcher_op_registration=False,
+                            ),
+                            grouped_native_functions,
+                        )
                     ),
-                    grouped_native_functions,
-                )
-            ),
+                },
+            ).split(newline),
         },
     )
```
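The two-level substitution introduced in the diff (an inner `RegisterDispatchDefinitions.ini` template rendered via `substitute_with_template`, split on `newline`, and fed into the outer `RegisterDispatchKey.cpp` template) can be sketched as follows. The `substitute` helper and the `INNER`/`OUTER` templates are toy stand-ins assuming a template engine that joins list-valued placeholders with newlines; they are not torchgen's `FileManager`/`CodeTemplate` API.

```python
# Toy illustration of nested template substitution, assuming a simplified engine.
from string import Template
from typing import Dict, List, Union


def substitute(template: str, env: Dict[str, Union[str, List[str]]]) -> str:
    # Join list values with newlines, mimicking how list-valued placeholders
    # are expanded before substitution.
    flat = {k: "\n".join(v) if isinstance(v, list) else v for k, v in env.items()}
    return Template(template).substitute(flat)


# Stand-in for the per-namespace body template ("RegisterDispatchDefinitions.ini").
INNER = """$ns_prologue
$dispatch_anonymous_definitions
$static_init_dispatch_registrations
$ns_epilogue"""

# Stand-in for the outer file template ("RegisterDispatchKey.cpp").
OUTER = """// generated RegisterCPU.cpp (simplified)
$dispatch_definitions"""

newline = "\n"
inner_body = substitute(
    INNER,
    {
        "ns_prologue": "namespace custom {",
        "ns_epilogue": "} // namespace custom",
        "dispatch_anonymous_definitions": ["// anonymous wrapper definitions ..."],
        "static_init_dispatch_registrations": "// TORCH_LIBRARY_IMPL(...) ...",
    },
)

# Mirrors `.split(newline)` in the diff: the outer template receives the
# rendered per-namespace body as a list of lines.
print(substitute(OUTER, {"dispatch_definitions": inner_body.split(newline)}))
```

Rendering the namespace prologue/epilogue inside the inner template is what allows the same outer file template to host definitions for `at` as well as custom kernel namespaces.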