[torchgen] Generate wrapper functions under custom namespaces (#81744)

Summary:
A follow up of #81581. Before these 2 PRs, if an operator with custom kernel namespace is added to `native_functions.yaml` (or any other yaml consumed by `torchgen`), although we are able to recognize the custom kernel in files such as `NativeFunctions.h` and `RegisterCPU.cpp`, we still generate backend specific wrappers under the hardcoded `at` namespace. This changes the behavior, by generating wrapper functions under custom namespaces.

For example, if the entries in yaml file looks like:

```
 - func: op_1(Tensor(a) self) -> Tensor(a)
  dispatch:
    CPU: at::op_1_kernel # ATen kernel

- func: op_2(Tensor(a) self) -> Tensor(a)
  dispatch:
    CPU: custom::op_2_kernel # custom kernel
```

We generate the following code for `CPUFunctions_inl.h` and `RegisterCPU.cpp`:

`CPUFunctions_inl.h`:
```
namespace at {
namespace cpu {
TORCH_API at::Tensor & op_1(const at::Tensor & self);
} // namespace cpu
} // namespace at

namespace custom {
namespace cpu {
TORCH_API at::Tensor & op_2(const at::Tensor & self);
} // namespace cpu
} // namespace custom

```

Notice the difference between `at::cpu` and `custom::cpu`.

Then the definition for these can be found in `RegisterCPU.cpp`.

`RegisterCPU.cpp`:
```
#include "CPUFunctions.h"

namespace at {

namespace {
at::Tensor & wrapper_op_1(const at::Tensor & self) {
    // No device check
  // DeviceGuard omitted
  return at::native::op_1_kernel(self);
}
} // anonymous namespace

TORCH_LIBRARY_IMPL(aten, CPU, m) {
m.impl("op_1", TORCH_FN(wrapper_op_1));
}

namespace cpu {
at::Tensor & op_1(at::Tensor & self) {
  return wrapper_op_1(self);
}
} // namespace cpu
} // namespace at

namespace custom {

namespace {
at::Tensor & wrapper_op_2(const at::Tensor & self) {
    // No device check
  // DeviceGuard omitted
  return at::native::op_2_kernel(self);
}
} // anonymous namespace

TORCH_LIBRARY_IMPL(aten, CPU, m) {
m.impl("op_2", TORCH_FN(wrapper_op_2));
}

namespace cpu {
at::Tensor & op_2(at::Tensor & self) {
  return wrapper_op_2(self);
}
} // namespace cpu
} // namespace custom

```

The benefit for this change is that it unifies all the namespaces derived from custom ops. In the example above, there are:

1. `custom::native` for kernels
2. `custom::<dispatch_key>` e.g., `custom::cpu` for wrappers

This customized operator will have nothing to do with `at::native`, `at::cpu` etc.

Test Plan: This is very hard to test. I will refactor this logic, abstract out some layers so it's testable. Will do it in coming PRs

Differential Revision: D37972772

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81744
Approved by: https://github.com/bdhirsh
This commit is contained in:
Mengwei Liu
2022-08-04 07:48:44 +00:00
committed by PyTorch MergeBot
parent cda8635a5e
commit 406ce692ca
8 changed files with 259 additions and 128 deletions

View File

@ -176,6 +176,25 @@ class FileManager:
with open(filename, "w") as f:
f.write(contents)
# Read from template file and replace pattern with callable (type could be dict or str).
def substitute_with_template(
self, template_fn: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]
) -> str:
template_path = os.path.join(self.template_dir, template_fn)
env = env_callable()
if isinstance(env, dict):
# TODO: Update the comment reference to the correct location
if "generated_comment" not in env:
comment = "@" + "generated by torchgen/gen.py"
comment += " from {}".format(os.path.basename(template_path))
env["generated_comment"] = comment
template = _read_template(template_path)
return template.substitute(env)
elif isinstance(env, str):
return env
else:
assert_never(env)
def write_with_template(
self,
filename: str,
@ -186,19 +205,11 @@ class FileManager:
assert filename not in self.filenames, "duplicate file write {filename}"
self.filenames.add(filename)
if not self.dry_run:
env = env_callable()
if isinstance(env, dict):
# TODO: Update the comment reference to the correct location
if "generated_comment" not in env:
comment = "@" + "generated by torchgen/gen.py"
comment += " from {}".format(os.path.basename(template_fn))
env["generated_comment"] = comment
template = _read_template(os.path.join(self.template_dir, template_fn))
self._write_if_changed(filename, template.substitute(env))
elif isinstance(env, str):
self._write_if_changed(filename, env)
else:
assert_never(env)
substitute_out = self.substitute_with_template(
template_fn=template_fn,
env_callable=env_callable,
)
self._write_if_changed(filename=filename, contents=substitute_out)
def write(
self,