mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
[torchgen] Generate wrapper functions under custom namespaces (#81744)
Summary: A follow up of #81581. Before these 2 PRs, if an operator with custom kernel namespace is added to `native_functions.yaml` (or any other yaml consumed by `torchgen`), although we are able to recognize the custom kernel in files such as `NativeFunctions.h` and `RegisterCPU.cpp`, we still generate backend specific wrappers under the hardcoded `at` namespace. This changes the behavior, by generating wrapper functions under custom namespaces. For example, if the entries in yaml file looks like: ``` - func: op_1(Tensor(a) self) -> Tensor(a) dispatch: CPU: at::op_1_kernel # ATen kernel - func: op_2(Tensor(a) self) -> Tensor(a) dispatch: CPU: custom::op_2_kernel # custom kernel ``` We generate the following code for `CPUFunctions_inl.h` and `RegisterCPU.cpp`: `CPUFunctions_inl.h`: ``` namespace at { namespace cpu { TORCH_API at::Tensor & op_1(const at::Tensor & self); } // namespace cpu } // namespace at namespace custom { namespace cpu { TORCH_API at::Tensor & op_2(const at::Tensor & self); } // namespace cpu } // namespace custom ``` Notice the difference between `at::cpu` and `custom::cpu`. Then the definition for these can be found in `RegisterCPU.cpp`. `RegisterCPU.cpp`: ``` #include "CPUFunctions.h" namespace at { namespace { at::Tensor & wrapper_op_1(const at::Tensor & self) { // No device check // DeviceGuard omitted return at::native::op_1_kernel(self); } } // anonymous namespace TORCH_LIBRARY_IMPL(aten, CPU, m) { m.impl("op_1", TORCH_FN(wrapper_op_1)); } namespace cpu { at::Tensor & op_1(at::Tensor & self) { return wrapper_op_1(self); } } // namespace cpu } // namespace at namespace custom { namespace { at::Tensor & wrapper_op_2(const at::Tensor & self) { // No device check // DeviceGuard omitted return at::native::op_2_kernel(self); } } // anonymous namespace TORCH_LIBRARY_IMPL(aten, CPU, m) { m.impl("op_2", TORCH_FN(wrapper_op_2)); } namespace cpu { at::Tensor & op_2(at::Tensor & self) { return wrapper_op_2(self); } } // namespace cpu } // namespace custom ``` The benefit for this change is that it unifies all the namespaces derived from custom ops. In the example above, there are: 1. `custom::native` for kernels 2. `custom::<dispatch_key>` e.g., `custom::cpu` for wrappers This customized operator will have nothing to do with `at::native`, `at::cpu` etc. Test Plan: This is very hard to test. I will refactor this logic, abstract out some layers so it's testable. Will do it in coming PRs Differential Revision: D37972772 Pull Request resolved: https://github.com/pytorch/pytorch/pull/81744 Approved by: https://github.com/bdhirsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
cda8635a5e
commit
406ce692ca
@ -176,6 +176,25 @@ class FileManager:
|
||||
with open(filename, "w") as f:
|
||||
f.write(contents)
|
||||
|
||||
# Read from template file and replace pattern with callable (type could be dict or str).
|
||||
def substitute_with_template(
|
||||
self, template_fn: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]
|
||||
) -> str:
|
||||
template_path = os.path.join(self.template_dir, template_fn)
|
||||
env = env_callable()
|
||||
if isinstance(env, dict):
|
||||
# TODO: Update the comment reference to the correct location
|
||||
if "generated_comment" not in env:
|
||||
comment = "@" + "generated by torchgen/gen.py"
|
||||
comment += " from {}".format(os.path.basename(template_path))
|
||||
env["generated_comment"] = comment
|
||||
template = _read_template(template_path)
|
||||
return template.substitute(env)
|
||||
elif isinstance(env, str):
|
||||
return env
|
||||
else:
|
||||
assert_never(env)
|
||||
|
||||
def write_with_template(
|
||||
self,
|
||||
filename: str,
|
||||
@ -186,19 +205,11 @@ class FileManager:
|
||||
assert filename not in self.filenames, "duplicate file write {filename}"
|
||||
self.filenames.add(filename)
|
||||
if not self.dry_run:
|
||||
env = env_callable()
|
||||
if isinstance(env, dict):
|
||||
# TODO: Update the comment reference to the correct location
|
||||
if "generated_comment" not in env:
|
||||
comment = "@" + "generated by torchgen/gen.py"
|
||||
comment += " from {}".format(os.path.basename(template_fn))
|
||||
env["generated_comment"] = comment
|
||||
template = _read_template(os.path.join(self.template_dir, template_fn))
|
||||
self._write_if_changed(filename, template.substitute(env))
|
||||
elif isinstance(env, str):
|
||||
self._write_if_changed(filename, env)
|
||||
else:
|
||||
assert_never(env)
|
||||
substitute_out = self.substitute_with_template(
|
||||
template_fn=template_fn,
|
||||
env_callable=env_callable,
|
||||
)
|
||||
self._write_if_changed(filename=filename, contents=substitute_out)
|
||||
|
||||
def write(
|
||||
self,
|
||||
|
Reference in New Issue
Block a user