[torchgen] Generate wrapper functions under custom namespaces (#81744)

Summary:
A follow up of #81581. Before these 2 PRs, if an operator with custom kernel namespace is added to `native_functions.yaml` (or any other yaml consumed by `torchgen`), although we are able to recognize the custom kernel in files such as `NativeFunctions.h` and `RegisterCPU.cpp`, we still generate backend specific wrappers under the hardcoded `at` namespace. This changes the behavior, by generating wrapper functions under custom namespaces.

For example, if the entries in yaml file looks like:

```
 - func: op_1(Tensor(a) self) -> Tensor(a)
  dispatch:
    CPU: at::op_1_kernel # ATen kernel

- func: op_2(Tensor(a) self) -> Tensor(a)
  dispatch:
    CPU: custom::op_2_kernel # custom kernel
```

We generate the following code for `CPUFunctions_inl.h` and `RegisterCPU.cpp`:

`CPUFunctions_inl.h`:
```
namespace at {
namespace cpu {
TORCH_API at::Tensor & op_1(const at::Tensor & self);
} // namespace cpu
} // namespace at

namespace custom {
namespace cpu {
TORCH_API at::Tensor & op_2(const at::Tensor & self);
} // namespace cpu
} // namespace custom

```

Notice the difference between `at::cpu` and `custom::cpu`.

Then the definition for these can be found in `RegisterCPU.cpp`.

`RegisterCPU.cpp`:
```
#include "CPUFunctions.h"

namespace at {

namespace {
at::Tensor & wrapper_op_1(const at::Tensor & self) {
    // No device check
  // DeviceGuard omitted
  return at::native::op_1_kernel(self);
}
} // anonymous namespace

TORCH_LIBRARY_IMPL(aten, CPU, m) {
m.impl("op_1", TORCH_FN(wrapper_op_1));
}

namespace cpu {
at::Tensor & op_1(at::Tensor & self) {
  return wrapper_op_1(self);
}
} // namespace cpu
} // namespace at

namespace custom {

namespace {
at::Tensor & wrapper_op_2(const at::Tensor & self) {
    // No device check
  // DeviceGuard omitted
  return at::native::op_2_kernel(self);
}
} // anonymous namespace

TORCH_LIBRARY_IMPL(aten, CPU, m) {
m.impl("op_2", TORCH_FN(wrapper_op_2));
}

namespace cpu {
at::Tensor & op_2(at::Tensor & self) {
  return wrapper_op_2(self);
}
} // namespace cpu
} // namespace custom

```

The benefit for this change is that it unifies all the namespaces derived from custom ops. In the example above, there are:

1. `custom::native` for kernels
2. `custom::<dispatch_key>` e.g., `custom::cpu` for wrappers

This customized operator will have nothing to do with `at::native`, `at::cpu` etc.

Test Plan: This is very hard to test. I will refactor this logic, abstract out some layers so it's testable. Will do it in coming PRs

Differential Revision: D37972772

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81744
Approved by: https://github.com/bdhirsh
This commit is contained in:
Mengwei Liu
2022-08-04 07:48:44 +00:00
committed by PyTorch MergeBot
parent cda8635a5e
commit 406ce692ca
8 changed files with 259 additions and 128 deletions

View File

@ -1877,6 +1877,7 @@ test_suite(
"aten/src/ATen/templates/LazyIr.h",
"aten/src/ATen/templates/LazyNonNativeIr.h",
"aten/src/ATen/templates/RegisterDispatchKey.cpp",
"aten/src/ATen/templates/RegisterDispatchDefinitions.ini",
"aten/src/ATen/native/native_functions.yaml",
"aten/src/ATen/native/tags.yaml",
"aten/src/ATen/native/ts_native_functions.yaml",

View File

@ -18,10 +18,5 @@
${DispatchKeyFunctions_inl_includes}
namespace at {
namespace ${dispatch_namespace} {
${dispatch_namespaced_declarations}
} // namespace ${dispatch_namespace}
} // namespace at

View File

@ -0,0 +1,24 @@
${ns_prologue}
// NB: TORCH_LIBRARY_IMPL must be in an anonymous namespace to avoid
// ambiguity with conflicting identifiers that may have been defined in
// at namespace already.
namespace {
${dispatch_helpers}
${dispatch_anonymous_definitions}
${static_init_dispatch_registrations}
} // anonymous namespace
${deferred_dispatch_registrations}
namespace ${dispatch_namespace} {
${dispatch_namespaced_definitions}
} // namespace ${dispatch_namespace}
${ns_epilogue}

View File

@ -50,28 +50,5 @@ $external_backend_headers
$dispatch_headers
$ops_headers
namespace at {
// NB: TORCH_LIBRARY_IMPL must be in an anonymous namespace to avoid
// ambiguity with conflicting identifiers that may have been defined in
// at namespace already.
namespace {
${dispatch_helpers}
${dispatch_anonymous_definitions}
${static_init_dispatch_registrations}
} // anonymous namespace
${deferred_dispatch_registrations}
namespace ${dispatch_namespace} {
${dispatch_namespaced_definitions}
} // namespace ${dispatch_namespace}
} // namespace at
// See template file RegisterDispatchDefinitions.ini
$dispatch_definitions

View File

@ -92,6 +92,7 @@ def define_targets(rules):
":LazyIr.h",
":LazyNonNativeIr.h",
":RegisterDispatchKey.cpp",
":RegisterDispatchDefinitions.ini",
":native_functions.yaml",
":shape_inference.h",
":tags.yaml",

View File

@ -15,6 +15,7 @@ import torchgen.api.meta as meta
import torchgen.api.native as native
import torchgen.api.structured as structured
import torchgen.dest as dest
from torchgen.api import cpp
from torchgen.api.translate import translate
from torchgen.api.types import (
@ -1408,6 +1409,168 @@ def get_native_function_declarations(
return declarations
def get_kernel_namespace(
*, f: Union[NativeFunction, NativeFunctionsGroup], backend_idx: BackendIndex
) -> str:
backend_metadata = backend_idx.get_kernel(f)
assert not backend_metadata or "::native" in backend_metadata.cpp_namespace, (
f"The kernel for function {f.func.name if isinstance(f, NativeFunction) else f.functional.func.name} "
f"with dispatch key {backend_idx.dispatch_key}"
f" has a namespace {backend_metadata.cpp_namespace} and it's not ending with '::native'."
)
return (
backend_metadata.cpp_namespace if backend_metadata else DEFAULT_KERNEL_NAMESPACE
)
# Return native function definitions grouped by dispatch key and custom namespace.
# Used in RegisterDispatchKey.cpp and etc.
def get_native_function_definitions(
*,
fm: FileManager,
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
dispatch_key: DispatchKey,
backend_idx: BackendIndex,
selector: SelectiveBuilder,
rocm: bool,
skip_dispatcher_op_registration: bool,
gen_dispatch_helpers: bool,
) -> List[str]:
definitions: List[str] = []
ns_definitions: Dict[str, List[str]] = defaultdict(list)
anonymous_definitions: Dict[str, List[str]] = defaultdict(list)
registrations: Dict[str, Dict[str, List[str]]] = defaultdict(dict)
newline = "\n"
ns_gen = dest.RegisterDispatchKey(
backend_idx,
Target.NAMESPACED_DEFINITION,
selector,
rocm=rocm,
class_method_name=None,
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
)
anonymous_gen = dest.RegisterDispatchKey(
backend_idx,
Target.ANONYMOUS_DEFINITION,
selector,
rocm=rocm,
class_method_name=None,
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
)
reg_gen = dest.RegisterDispatchKey(
backend_idx,
Target.REGISTRATION,
selector,
rocm=rocm,
class_method_name=None,
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
)
for f in grouped_native_functions:
kernel_namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
"::native", ""
)
ns_definitions[kernel_namespace].extend(
ns_gen(f),
)
anonymous_definitions[kernel_namespace].extend(
anonymous_gen(f),
)
namespace = (
f.namespace if isinstance(f, NativeFunction) else f.functional.namespace
)
if namespace not in registrations[kernel_namespace]:
registrations[kernel_namespace] = defaultdict(list)
registrations[kernel_namespace][namespace].extend(
reg_gen(f),
)
for kernel_namespace in ns_definitions:
if len(ns_definitions[kernel_namespace]) == 0:
continue
ns_helper = NamespaceHelper(namespace_str=kernel_namespace)
registration_body = ""
for namespace in registrations[kernel_namespace]:
if not registrations[kernel_namespace][namespace]:
continue
registration_body += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
{newline.join(registrations[kernel_namespace][namespace])}
}};"""
definitions.extend(
fm.substitute_with_template(
"RegisterDispatchDefinitions.ini",
lambda: {
"ns_prologue": ns_helper.prologue,
"ns_epilogue": ns_helper.epilogue,
"dispatch_helpers": dest.gen_registration_helpers(backend_idx)
if gen_dispatch_helpers
else [],
"dispatch_anonymous_definitions": anonymous_definitions[
kernel_namespace
],
"static_init_dispatch_registrations": ""
if skip_dispatcher_op_registration
else registration_body,
"deferred_dispatch_registrations": "",
"dispatch_namespace": dispatch_key.lower(),
"dispatch_namespaced_definitions": ns_definitions[kernel_namespace],
},
).split(newline)
)
return definitions
# Return native function declarations grouped by dispatch key and custom namespace.
# Used in CPUFunctions_inl.h and etc.
def get_namespaced_declaration(
*,
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
dispatch_key: DispatchKey,
backend_idx: BackendIndex,
selector: SelectiveBuilder,
rocm: bool,
) -> List[str]:
declarations: List[str] = []
ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
newline = "\n"
func = dest.RegisterDispatchKey(
backend_idx,
Target.NAMESPACED_DECLARATION,
selector,
rocm=rocm,
class_method_name=None,
skip_dispatcher_op_registration=False,
)
for f in grouped_native_functions:
namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
"native", dispatch_key.lower()
)
ns_grouped_kernels[namespace].extend(
func(f),
)
for namespace, kernels in ns_grouped_kernels.items():
if len(kernels) == 0:
continue
ns_helper = NamespaceHelper(
namespace_str=namespace, entity_name="", max_level=3
)
ordered_kernels = list(OrderedDict.fromkeys(kernels))
declarations.extend(
f"""
{ns_helper.prologue}
{newline.join(ordered_kernels)}
{ns_helper.epilogue}
""".split(
newline
)
)
return declarations
# Return native function schema registration code for aten and other namespaces.
def get_native_function_schema_registrations(
*,
@ -1550,18 +1713,12 @@ def gen_aggregated_headers(
lambda: {
"DispatchKeyFunctions_inl_includes": [],
"dispatch_namespace": dispatch_key.lower(),
"dispatch_namespaced_declarations": list(
concatMap(
dest.RegisterDispatchKey(
backend_indices[dispatch_key],
Target.NAMESPACED_DECLARATION,
selector,
rocm=rocm,
class_method_name=None,
skip_dispatcher_op_registration=False,
),
grouped_native_functions,
)
"dispatch_namespaced_declarations": get_namespaced_declaration(
grouped_native_functions=grouped_native_functions,
dispatch_key=dispatch_key,
backend_idx=backend_indices[dispatch_key],
selector=selector,
rocm=rocm,
),
},
)
@ -1998,33 +2155,17 @@ def gen_source_files(
)
ns_grouped_native_functions[namespace].append(grouped_native_function)
static_init_dispatch_registrations = ""
for namespace, functions in ns_grouped_native_functions.items():
dispatch_registrations_body = (
""
if skip_dispatcher_op_registration
else "\n".join(
list(
concatMap(
dest.RegisterDispatchKey(
backend_index,
Target.REGISTRATION,
selector,
rocm=rocm,
class_method_name=None,
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
),
functions,
)
)
)
)
static_init_dispatch_registrations += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
{dispatch_registrations_body}
}};"""
dispatch_namespace = str(dispatch_key).lower()
dispatch_definitions = get_native_function_definitions(
fm=fm,
grouped_native_functions=grouped_native_functions,
dispatch_key=dispatch_key,
backend_idx=backend_index,
selector=selector,
rocm=rocm,
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
gen_dispatch_helpers=True,
)
fm.write_with_template(
f"Register{dispatch_key}.cpp",
"RegisterDispatchKey.cpp",
@ -2037,37 +2178,8 @@ TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
backend_index, per_operator_headers, rocm
),
"ops_headers": operator_headers(),
"DispatchKey": dispatch_key,
"dispatch_namespace": dispatch_key.lower(),
"dispatch_helpers": dest.gen_registration_helpers(backend_index),
"dispatch_namespaced_definitions": list(
concatMap(
dest.RegisterDispatchKey(
backend_index,
Target.NAMESPACED_DEFINITION,
selector,
rocm=rocm,
class_method_name=None,
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
),
grouped_native_functions,
)
),
"dispatch_anonymous_definitions": list(
concatMap(
dest.RegisterDispatchKey(
backend_index,
Target.ANONYMOUS_DEFINITION,
selector,
rocm=rocm,
class_method_name=None,
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
),
grouped_native_functions,
)
),
"static_init_dispatch_registrations": static_init_dispatch_registrations,
"deferred_dispatch_registrations": "",
"dispatch_helpers": "",
"dispatch_definitions": dispatch_definitions,
},
)

View File

@ -422,6 +422,8 @@ def gen_dispatcher_registrations(
grouped_native_functions,
)
)
newline = "\n"
ns_helper = NamespaceHelper(namespace_str="at")
deferred_dispatch_registrations = ""
static_init_dispatch_registrations = ""
if eager_registration:
@ -453,8 +455,6 @@ TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
f"Register{dispatch_key}.cpp",
"RegisterDispatchKey.cpp",
lambda: {
"static_init_dispatch_registrations": static_init_dispatch_registrations,
"deferred_dispatch_registrations": deferred_dispatch_registrations,
"extra_cuda_headers": "",
"external_backend_headers": external_backend_headers_str,
"ops_headers": "#include <ATen/Functions.h>"
@ -465,21 +465,31 @@ TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
"dispatch_headers": dest.gen_registration_headers(
backend_index, per_operator_headers=per_operator_headers, rocm=False
),
"dispatch_helpers": dest.gen_registration_helpers(backend_index),
"dispatch_namespaced_definitions": "",
"dispatch_anonymous_definitions": list(
concatMap(
dest.RegisterDispatchKey(
backend_index,
Target.ANONYMOUS_DEFINITION,
selector,
rocm=False,
class_method_name=f"{class_name}",
skip_dispatcher_op_registration=False,
"dispatch_definitions": fm.substitute_with_template(
"RegisterDispatchDefinitions.ini",
lambda: {
"ns_prologue": ns_helper.prologue,
"ns_epilogue": ns_helper.epilogue,
"static_init_dispatch_registrations": static_init_dispatch_registrations,
"deferred_dispatch_registrations": deferred_dispatch_registrations,
"dispatch_helpers": dest.gen_registration_helpers(backend_index),
"dispatch_namespace": dispatch_key.lower(),
"dispatch_namespaced_definitions": "",
"dispatch_anonymous_definitions": list(
concatMap(
dest.RegisterDispatchKey(
backend_index,
Target.ANONYMOUS_DEFINITION,
selector,
rocm=False,
class_method_name=f"{class_name}",
skip_dispatcher_op_registration=False,
),
grouped_native_functions,
)
),
grouped_native_functions,
)
),
},
).split(newline),
},
)

View File

@ -176,6 +176,25 @@ class FileManager:
with open(filename, "w") as f:
f.write(contents)
# Read from template file and replace pattern with callable (type could be dict or str).
def substitute_with_template(
self, template_fn: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]
) -> str:
template_path = os.path.join(self.template_dir, template_fn)
env = env_callable()
if isinstance(env, dict):
# TODO: Update the comment reference to the correct location
if "generated_comment" not in env:
comment = "@" + "generated by torchgen/gen.py"
comment += " from {}".format(os.path.basename(template_path))
env["generated_comment"] = comment
template = _read_template(template_path)
return template.substitute(env)
elif isinstance(env, str):
return env
else:
assert_never(env)
def write_with_template(
self,
filename: str,
@ -186,19 +205,11 @@ class FileManager:
assert filename not in self.filenames, "duplicate file write {filename}"
self.filenames.add(filename)
if not self.dry_run:
env = env_callable()
if isinstance(env, dict):
# TODO: Update the comment reference to the correct location
if "generated_comment" not in env:
comment = "@" + "generated by torchgen/gen.py"
comment += " from {}".format(os.path.basename(template_fn))
env["generated_comment"] = comment
template = _read_template(os.path.join(self.template_dir, template_fn))
self._write_if_changed(filename, template.substitute(env))
elif isinstance(env, str):
self._write_if_changed(filename, env)
else:
assert_never(env)
substitute_out = self.substitute_with_template(
template_fn=template_fn,
env_callable=env_callable,
)
self._write_if_changed(filename=filename, contents=substitute_out)
def write(
self,