[torchgen] Generate wrapper functions under custom namespaces (#81744)

Summary:
A follow-up to #81581. Before these two PRs, if an operator with a custom kernel namespace was added to `native_functions.yaml` (or any other yaml consumed by `torchgen`), the custom kernel was recognized in files such as `NativeFunctions.h` and `RegisterCPU.cpp`, but the backend-specific wrappers were still generated under the hardcoded `at` namespace. This PR changes that behavior by generating the wrapper functions under the custom namespace.

For example, if the entries in the yaml file look like:

```
- func: op_1(Tensor(a) self) -> Tensor(a)
  dispatch:
    CPU: at::op_1_kernel # ATen kernel

- func: op_2(Tensor(a) self) -> Tensor(a)
  dispatch:
    CPU: custom::op_2_kernel # custom kernel
```

We generate the following code for `CPUFunctions_inl.h` and `RegisterCPU.cpp`:

`CPUFunctions_inl.h`:
```
namespace at {
namespace cpu {
TORCH_API at::Tensor & op_1(const at::Tensor & self);
} // namespace cpu
} // namespace at

namespace custom {
namespace cpu {
TORCH_API at::Tensor & op_2(const at::Tensor & self);
} // namespace cpu
} // namespace custom

```

Notice the difference between `at::cpu` and `custom::cpu`.
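
The grouping that yields the two blocks above can be sketched in Python. This is a minimal illustration under assumptions, not the torchgen implementation: `render_declarations` and the shape of its input (a list of `(namespace, declaration)` pairs) are hypothetical names invented for this example.

```python
from collections import defaultdict


def render_declarations(decls, dispatch_key):
    """Group (namespace, declaration) pairs and emit one block per namespace.

    Hypothetical helper for illustration only; not part of torchgen.
    """
    grouped = defaultdict(list)
    for namespace, decl in decls:
        grouped[namespace].append(decl)

    inner = dispatch_key.lower()  # e.g. "CPU" -> "cpu"
    lines = []
    for namespace, items in grouped.items():  # dicts keep insertion order
        lines.append(f"namespace {namespace} {{")
        lines.append(f"namespace {inner} {{")
        lines.extend(items)
        lines.append(f"}} // namespace {inner}")
        lines.append(f"}} // namespace {namespace}")
        lines.append("")
    return "\n".join(lines)


header = render_declarations(
    [
        ("at", "TORCH_API at::Tensor & op_1(const at::Tensor & self);"),
        ("custom", "TORCH_API at::Tensor & op_2(const at::Tensor & self);"),
    ],
    dispatch_key="CPU",
)
print(header)
```

Each top-level namespace gets its own `<dispatch_key>` sub-namespace, which is why `op_2` ends up in `custom::cpu` rather than `at::cpu`.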

The definitions of these wrappers are in `RegisterCPU.cpp`.

`RegisterCPU.cpp`:
```
#include "CPUFunctions.h"

namespace at {

namespace {
at::Tensor & wrapper_op_1(const at::Tensor & self) {
  // No device check
  // DeviceGuard omitted
  return at::native::op_1_kernel(self);
}
} // anonymous namespace

TORCH_LIBRARY_IMPL(aten, CPU, m) {
m.impl("op_1", TORCH_FN(wrapper_op_1));
}

namespace cpu {
at::Tensor & op_1(const at::Tensor & self) {
  return wrapper_op_1(self);
}
} // namespace cpu
} // namespace at

namespace custom {

namespace {
at::Tensor & wrapper_op_2(const at::Tensor & self) {
  // No device check
  // DeviceGuard omitted
  return custom::native::op_2_kernel(self);
}
} // anonymous namespace

TORCH_LIBRARY_IMPL(aten, CPU, m) {
m.impl("op_2", TORCH_FN(wrapper_op_2));
}

namespace cpu {
at::Tensor & op_2(const at::Tensor & self) {
  return wrapper_op_2(self);
}
} // namespace cpu
} // namespace custom

```

The benefit of this change is that it unifies all the namespaces derived from custom ops. In the example above, these are:

1. `custom::native` for kernels
2. `custom::<dispatch_key>` e.g., `custom::cpu` for wrappers

The custom operator then has nothing to do with `at::native`, `at::cpu`, etc.
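
As a rough illustration of how both namespaces follow from a single dispatch entry, here is a short Python sketch. The helpers `split_kernel` and `derived_namespaces` are hypothetical and exist only for this example; the fallback to `at` for entries without a `::` prefix is an assumption, not a statement about torchgen's parsing.

```python
def split_kernel(dispatch_value, default_namespace="at"):
    """Split 'custom::op_2_kernel' into (namespace, kernel_name).

    Hypothetical helper. An entry without '::' falls back to
    default_namespace (an assumption made for this sketch).
    """
    namespace, sep, name = dispatch_value.rpartition("::")
    if not sep:
        return default_namespace, dispatch_value
    return namespace, name


def derived_namespaces(dispatch_value, dispatch_key):
    """Return the kernel symbol and wrapper namespace for one dispatch entry."""
    namespace, kernel = split_kernel(dispatch_value)
    return {
        "kernel": f"{namespace}::native::{kernel}",
        "wrapper_namespace": f"{namespace}::{dispatch_key.lower()}",
    }


result = derived_namespaces("custom::op_2_kernel", "CPU")
print(result)
```

Both derived names share the `custom` prefix, which is the unification described above.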

Test Plan: This is hard to test in its current form. I will refactor this logic and abstract out some layers so that it becomes testable; that will happen in follow-up PRs.

Differential Revision: D37972772

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81744
Approved by: https://github.com/bdhirsh
Mengwei Liu
2022-08-04 07:48:44 +00:00
committed by PyTorch MergeBot
parent cda8635a5e
commit 406ce692ca
8 changed files with 259 additions and 128 deletions

@@ -422,6 +422,8 @@ def gen_dispatcher_registrations(
grouped_native_functions,
)
)
newline = "\n"
ns_helper = NamespaceHelper(namespace_str="at")
deferred_dispatch_registrations = ""
static_init_dispatch_registrations = ""
if eager_registration:
@@ -453,8 +455,6 @@ TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
f"Register{dispatch_key}.cpp",
"RegisterDispatchKey.cpp",
lambda: {
"static_init_dispatch_registrations": static_init_dispatch_registrations,
"deferred_dispatch_registrations": deferred_dispatch_registrations,
"extra_cuda_headers": "",
"external_backend_headers": external_backend_headers_str,
"ops_headers": "#include <ATen/Functions.h>"
@@ -465,21 +465,31 @@ TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
"dispatch_headers": dest.gen_registration_headers(
backend_index, per_operator_headers=per_operator_headers, rocm=False
),
"dispatch_helpers": dest.gen_registration_helpers(backend_index),
"dispatch_namespaced_definitions": "",
"dispatch_anonymous_definitions": list(
concatMap(
dest.RegisterDispatchKey(
backend_index,
Target.ANONYMOUS_DEFINITION,
selector,
rocm=False,
class_method_name=f"{class_name}",
skip_dispatcher_op_registration=False,
"dispatch_definitions": fm.substitute_with_template(
"RegisterDispatchDefinitions.ini",
lambda: {
"ns_prologue": ns_helper.prologue,
"ns_epilogue": ns_helper.epilogue,
"static_init_dispatch_registrations": static_init_dispatch_registrations,
"deferred_dispatch_registrations": deferred_dispatch_registrations,
"dispatch_helpers": dest.gen_registration_helpers(backend_index),
"dispatch_namespace": dispatch_key.lower(),
"dispatch_namespaced_definitions": "",
"dispatch_anonymous_definitions": list(
concatMap(
dest.RegisterDispatchKey(
backend_index,
Target.ANONYMOUS_DEFINITION,
selector,
rocm=False,
class_method_name=f"{class_name}",
skip_dispatcher_op_registration=False,
),
grouped_native_functions,
)
),
grouped_native_functions,
)
),
},
).split(newline),
},
)
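
To illustrate what `ns_prologue`, `ns_epilogue`, and the trailing `.split(newline)` contribute in the hunk above, here is a minimal Python sketch. `MiniNamespaceHelper` and `render_definitions` are hypothetical stand-ins written for this example. They only approximate the role of torchgen's `NamespaceHelper` and of the template substitution, and their exact output format is an assumption.

```python
class MiniNamespaceHelper:
    """Hypothetical stand-in that opens and closes nested C++ namespaces."""

    def __init__(self, namespace_str):
        # "a::b" -> ["a", "b"]; a plain "at" yields a single level.
        self.parts = namespace_str.split("::")

    @property
    def prologue(self):
        return "\n".join(f"namespace {p} {{" for p in self.parts)

    @property
    def epilogue(self):
        # Close in reverse order so the innermost namespace ends first.
        return "\n".join(f"}} // namespace {p}" for p in reversed(self.parts))


def render_definitions(ns_helper, dispatch_namespace, body_lines):
    """Wrap body_lines in the namespaces and return the result as a list of lines."""
    text = "\n".join(
        [
            ns_helper.prologue,
            f"namespace {dispatch_namespace} {{",
            *body_lines,
            f"}} // namespace {dispatch_namespace}",
            ns_helper.epilogue,
        ]
    )
    # Mirrors the `.split(newline)` in the diff: rendered text -> list of lines.
    return text.split("\n")


lines = render_definitions(
    MiniNamespaceHelper("custom"),
    "cpu",
    ["at::Tensor & op_2(const at::Tensor & self);"],
)
print("\n".join(lines))
```

Passing `"at"` instead of `"custom"` produces the same shape as the `at::cpu` block, which suggests why the namespace is carried by a helper object rather than hardcoded in the template.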