mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[torchgen] Generate wrapper functions under custom namespaces (#81744)
Summary: A follow up of #81581. Before these 2 PRs, if an operator with custom kernel namespace is added to `native_functions.yaml` (or any other yaml consumed by `torchgen`), although we are able to recognize the custom kernel in files such as `NativeFunctions.h` and `RegisterCPU.cpp`, we still generate backend specific wrappers under the hardcoded `at` namespace. This changes the behavior, by generating wrapper functions under custom namespaces. For example, if the entries in yaml file looks like: ``` - func: op_1(Tensor(a) self) -> Tensor(a) dispatch: CPU: at::op_1_kernel # ATen kernel - func: op_2(Tensor(a) self) -> Tensor(a) dispatch: CPU: custom::op_2_kernel # custom kernel ``` We generate the following code for `CPUFunctions_inl.h` and `RegisterCPU.cpp`: `CPUFunctions_inl.h`: ``` namespace at { namespace cpu { TORCH_API at::Tensor & op_1(const at::Tensor & self); } // namespace cpu } // namespace at namespace custom { namespace cpu { TORCH_API at::Tensor & op_2(const at::Tensor & self); } // namespace cpu } // namespace custom ``` Notice the difference between `at::cpu` and `custom::cpu`. Then the definition for these can be found in `RegisterCPU.cpp`. `RegisterCPU.cpp`: ``` #include "CPUFunctions.h" namespace at { namespace { at::Tensor & wrapper_op_1(const at::Tensor & self) { // No device check // DeviceGuard omitted return at::native::op_1_kernel(self); } } // anonymous namespace TORCH_LIBRARY_IMPL(aten, CPU, m) { m.impl("op_1", TORCH_FN(wrapper_op_1)); } namespace cpu { at::Tensor & op_1(at::Tensor & self) { return wrapper_op_1(self); } } // namespace cpu } // namespace at namespace custom { namespace { at::Tensor & wrapper_op_2(const at::Tensor & self) { // No device check // DeviceGuard omitted return at::native::op_2_kernel(self); } } // anonymous namespace TORCH_LIBRARY_IMPL(aten, CPU, m) { m.impl("op_2", TORCH_FN(wrapper_op_2)); } namespace cpu { at::Tensor & op_2(at::Tensor & self) { return wrapper_op_2(self); } } // namespace cpu } // namespace custom ``` The benefit for this change is that it unifies all the namespaces derived from custom ops. In the example above, there are: 1. `custom::native` for kernels 2. `custom::<dispatch_key>` e.g., `custom::cpu` for wrappers This customized operator will have nothing to do with `at::native`, `at::cpu` etc. Test Plan: This is very hard to test. I will refactor this logic, abstract out some layers so it's testable. Will do it in coming PRs Differential Revision: D37972772 Pull Request resolved: https://github.com/pytorch/pytorch/pull/81744 Approved by: https://github.com/bdhirsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
cda8635a5e
commit
406ce692ca
@ -1877,6 +1877,7 @@ test_suite(
|
||||
"aten/src/ATen/templates/LazyIr.h",
|
||||
"aten/src/ATen/templates/LazyNonNativeIr.h",
|
||||
"aten/src/ATen/templates/RegisterDispatchKey.cpp",
|
||||
"aten/src/ATen/templates/RegisterDispatchDefinitions.ini",
|
||||
"aten/src/ATen/native/native_functions.yaml",
|
||||
"aten/src/ATen/native/tags.yaml",
|
||||
"aten/src/ATen/native/ts_native_functions.yaml",
|
||||
|
@ -18,10 +18,5 @@
|
||||
|
||||
${DispatchKeyFunctions_inl_includes}
|
||||
|
||||
namespace at {
|
||||
namespace ${dispatch_namespace} {
|
||||
|
||||
${dispatch_namespaced_declarations}
|
||||
|
||||
} // namespace ${dispatch_namespace}
|
||||
} // namespace at
|
||||
|
24
aten/src/ATen/templates/RegisterDispatchDefinitions.ini
Normal file
24
aten/src/ATen/templates/RegisterDispatchDefinitions.ini
Normal file
@ -0,0 +1,24 @@
|
||||
${ns_prologue}
|
||||
|
||||
// NB: TORCH_LIBRARY_IMPL must be in an anonymous namespace to avoid
|
||||
// ambiguity with conflicting identifiers that may have been defined in
|
||||
// at namespace already.
|
||||
namespace {
|
||||
|
||||
${dispatch_helpers}
|
||||
|
||||
${dispatch_anonymous_definitions}
|
||||
|
||||
${static_init_dispatch_registrations}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
${deferred_dispatch_registrations}
|
||||
|
||||
namespace ${dispatch_namespace} {
|
||||
|
||||
${dispatch_namespaced_definitions}
|
||||
|
||||
} // namespace ${dispatch_namespace}
|
||||
|
||||
${ns_epilogue}
|
@ -50,28 +50,5 @@ $external_backend_headers
|
||||
$dispatch_headers
|
||||
$ops_headers
|
||||
|
||||
|
||||
namespace at {
|
||||
|
||||
// NB: TORCH_LIBRARY_IMPL must be in an anonymous namespace to avoid
|
||||
// ambiguity with conflicting identifiers that may have been defined in
|
||||
// at namespace already.
|
||||
namespace {
|
||||
|
||||
${dispatch_helpers}
|
||||
|
||||
${dispatch_anonymous_definitions}
|
||||
|
||||
${static_init_dispatch_registrations}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
${deferred_dispatch_registrations}
|
||||
|
||||
namespace ${dispatch_namespace} {
|
||||
|
||||
${dispatch_namespaced_definitions}
|
||||
|
||||
} // namespace ${dispatch_namespace}
|
||||
|
||||
} // namespace at
|
||||
// See template file RegisterDispatchDefinitions.ini
|
||||
$dispatch_definitions
|
||||
|
@ -92,6 +92,7 @@ def define_targets(rules):
|
||||
":LazyIr.h",
|
||||
":LazyNonNativeIr.h",
|
||||
":RegisterDispatchKey.cpp",
|
||||
":RegisterDispatchDefinitions.ini",
|
||||
":native_functions.yaml",
|
||||
":shape_inference.h",
|
||||
":tags.yaml",
|
||||
|
250
torchgen/gen.py
250
torchgen/gen.py
@ -15,6 +15,7 @@ import torchgen.api.meta as meta
|
||||
import torchgen.api.native as native
|
||||
import torchgen.api.structured as structured
|
||||
import torchgen.dest as dest
|
||||
|
||||
from torchgen.api import cpp
|
||||
from torchgen.api.translate import translate
|
||||
from torchgen.api.types import (
|
||||
@ -1408,6 +1409,168 @@ def get_native_function_declarations(
|
||||
return declarations
|
||||
|
||||
|
||||
def get_kernel_namespace(
|
||||
*, f: Union[NativeFunction, NativeFunctionsGroup], backend_idx: BackendIndex
|
||||
) -> str:
|
||||
backend_metadata = backend_idx.get_kernel(f)
|
||||
assert not backend_metadata or "::native" in backend_metadata.cpp_namespace, (
|
||||
f"The kernel for function {f.func.name if isinstance(f, NativeFunction) else f.functional.func.name} "
|
||||
f"with dispatch key {backend_idx.dispatch_key}"
|
||||
f" has a namespace {backend_metadata.cpp_namespace} and it's not ending with '::native'."
|
||||
)
|
||||
return (
|
||||
backend_metadata.cpp_namespace if backend_metadata else DEFAULT_KERNEL_NAMESPACE
|
||||
)
|
||||
|
||||
|
||||
# Return native function definitions grouped by dispatch key and custom namespace.
|
||||
# Used in RegisterDispatchKey.cpp and etc.
|
||||
def get_native_function_definitions(
|
||||
*,
|
||||
fm: FileManager,
|
||||
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
|
||||
dispatch_key: DispatchKey,
|
||||
backend_idx: BackendIndex,
|
||||
selector: SelectiveBuilder,
|
||||
rocm: bool,
|
||||
skip_dispatcher_op_registration: bool,
|
||||
gen_dispatch_helpers: bool,
|
||||
) -> List[str]:
|
||||
definitions: List[str] = []
|
||||
ns_definitions: Dict[str, List[str]] = defaultdict(list)
|
||||
anonymous_definitions: Dict[str, List[str]] = defaultdict(list)
|
||||
registrations: Dict[str, Dict[str, List[str]]] = defaultdict(dict)
|
||||
newline = "\n"
|
||||
ns_gen = dest.RegisterDispatchKey(
|
||||
backend_idx,
|
||||
Target.NAMESPACED_DEFINITION,
|
||||
selector,
|
||||
rocm=rocm,
|
||||
class_method_name=None,
|
||||
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
|
||||
)
|
||||
anonymous_gen = dest.RegisterDispatchKey(
|
||||
backend_idx,
|
||||
Target.ANONYMOUS_DEFINITION,
|
||||
selector,
|
||||
rocm=rocm,
|
||||
class_method_name=None,
|
||||
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
|
||||
)
|
||||
reg_gen = dest.RegisterDispatchKey(
|
||||
backend_idx,
|
||||
Target.REGISTRATION,
|
||||
selector,
|
||||
rocm=rocm,
|
||||
class_method_name=None,
|
||||
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
|
||||
)
|
||||
for f in grouped_native_functions:
|
||||
kernel_namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
|
||||
"::native", ""
|
||||
)
|
||||
|
||||
ns_definitions[kernel_namespace].extend(
|
||||
ns_gen(f),
|
||||
)
|
||||
anonymous_definitions[kernel_namespace].extend(
|
||||
anonymous_gen(f),
|
||||
)
|
||||
namespace = (
|
||||
f.namespace if isinstance(f, NativeFunction) else f.functional.namespace
|
||||
)
|
||||
if namespace not in registrations[kernel_namespace]:
|
||||
registrations[kernel_namespace] = defaultdict(list)
|
||||
registrations[kernel_namespace][namespace].extend(
|
||||
reg_gen(f),
|
||||
)
|
||||
|
||||
for kernel_namespace in ns_definitions:
|
||||
if len(ns_definitions[kernel_namespace]) == 0:
|
||||
continue
|
||||
ns_helper = NamespaceHelper(namespace_str=kernel_namespace)
|
||||
registration_body = ""
|
||||
for namespace in registrations[kernel_namespace]:
|
||||
if not registrations[kernel_namespace][namespace]:
|
||||
continue
|
||||
registration_body += f"""
|
||||
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
|
||||
{newline.join(registrations[kernel_namespace][namespace])}
|
||||
}};"""
|
||||
definitions.extend(
|
||||
fm.substitute_with_template(
|
||||
"RegisterDispatchDefinitions.ini",
|
||||
lambda: {
|
||||
"ns_prologue": ns_helper.prologue,
|
||||
"ns_epilogue": ns_helper.epilogue,
|
||||
"dispatch_helpers": dest.gen_registration_helpers(backend_idx)
|
||||
if gen_dispatch_helpers
|
||||
else [],
|
||||
"dispatch_anonymous_definitions": anonymous_definitions[
|
||||
kernel_namespace
|
||||
],
|
||||
"static_init_dispatch_registrations": ""
|
||||
if skip_dispatcher_op_registration
|
||||
else registration_body,
|
||||
"deferred_dispatch_registrations": "",
|
||||
"dispatch_namespace": dispatch_key.lower(),
|
||||
"dispatch_namespaced_definitions": ns_definitions[kernel_namespace],
|
||||
},
|
||||
).split(newline)
|
||||
)
|
||||
|
||||
return definitions
|
||||
|
||||
|
||||
# Return native function declarations grouped by dispatch key and custom namespace.
|
||||
# Used in CPUFunctions_inl.h and etc.
|
||||
def get_namespaced_declaration(
|
||||
*,
|
||||
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
|
||||
dispatch_key: DispatchKey,
|
||||
backend_idx: BackendIndex,
|
||||
selector: SelectiveBuilder,
|
||||
rocm: bool,
|
||||
) -> List[str]:
|
||||
declarations: List[str] = []
|
||||
ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
|
||||
newline = "\n"
|
||||
func = dest.RegisterDispatchKey(
|
||||
backend_idx,
|
||||
Target.NAMESPACED_DECLARATION,
|
||||
selector,
|
||||
rocm=rocm,
|
||||
class_method_name=None,
|
||||
skip_dispatcher_op_registration=False,
|
||||
)
|
||||
for f in grouped_native_functions:
|
||||
namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
|
||||
"native", dispatch_key.lower()
|
||||
)
|
||||
|
||||
ns_grouped_kernels[namespace].extend(
|
||||
func(f),
|
||||
)
|
||||
|
||||
for namespace, kernels in ns_grouped_kernels.items():
|
||||
if len(kernels) == 0:
|
||||
continue
|
||||
ns_helper = NamespaceHelper(
|
||||
namespace_str=namespace, entity_name="", max_level=3
|
||||
)
|
||||
ordered_kernels = list(OrderedDict.fromkeys(kernels))
|
||||
declarations.extend(
|
||||
f"""
|
||||
{ns_helper.prologue}
|
||||
{newline.join(ordered_kernels)}
|
||||
{ns_helper.epilogue}
|
||||
""".split(
|
||||
newline
|
||||
)
|
||||
)
|
||||
return declarations
|
||||
|
||||
|
||||
# Return native function schema registration code for aten and other namespaces.
|
||||
def get_native_function_schema_registrations(
|
||||
*,
|
||||
@ -1550,18 +1713,12 @@ def gen_aggregated_headers(
|
||||
lambda: {
|
||||
"DispatchKeyFunctions_inl_includes": [],
|
||||
"dispatch_namespace": dispatch_key.lower(),
|
||||
"dispatch_namespaced_declarations": list(
|
||||
concatMap(
|
||||
dest.RegisterDispatchKey(
|
||||
backend_indices[dispatch_key],
|
||||
Target.NAMESPACED_DECLARATION,
|
||||
selector,
|
||||
rocm=rocm,
|
||||
class_method_name=None,
|
||||
skip_dispatcher_op_registration=False,
|
||||
),
|
||||
grouped_native_functions,
|
||||
)
|
||||
"dispatch_namespaced_declarations": get_namespaced_declaration(
|
||||
grouped_native_functions=grouped_native_functions,
|
||||
dispatch_key=dispatch_key,
|
||||
backend_idx=backend_indices[dispatch_key],
|
||||
selector=selector,
|
||||
rocm=rocm,
|
||||
),
|
||||
},
|
||||
)
|
||||
@ -1998,33 +2155,17 @@ def gen_source_files(
|
||||
)
|
||||
ns_grouped_native_functions[namespace].append(grouped_native_function)
|
||||
|
||||
static_init_dispatch_registrations = ""
|
||||
for namespace, functions in ns_grouped_native_functions.items():
|
||||
dispatch_registrations_body = (
|
||||
""
|
||||
if skip_dispatcher_op_registration
|
||||
else "\n".join(
|
||||
list(
|
||||
concatMap(
|
||||
dest.RegisterDispatchKey(
|
||||
backend_index,
|
||||
Target.REGISTRATION,
|
||||
selector,
|
||||
rocm=rocm,
|
||||
class_method_name=None,
|
||||
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
|
||||
),
|
||||
functions,
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
static_init_dispatch_registrations += f"""
|
||||
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
|
||||
{dispatch_registrations_body}
|
||||
}};"""
|
||||
dispatch_namespace = str(dispatch_key).lower()
|
||||
dispatch_definitions = get_native_function_definitions(
|
||||
fm=fm,
|
||||
grouped_native_functions=grouped_native_functions,
|
||||
dispatch_key=dispatch_key,
|
||||
backend_idx=backend_index,
|
||||
selector=selector,
|
||||
rocm=rocm,
|
||||
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
|
||||
gen_dispatch_helpers=True,
|
||||
)
|
||||
fm.write_with_template(
|
||||
f"Register{dispatch_key}.cpp",
|
||||
"RegisterDispatchKey.cpp",
|
||||
@ -2037,37 +2178,8 @@ TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
|
||||
backend_index, per_operator_headers, rocm
|
||||
),
|
||||
"ops_headers": operator_headers(),
|
||||
"DispatchKey": dispatch_key,
|
||||
"dispatch_namespace": dispatch_key.lower(),
|
||||
"dispatch_helpers": dest.gen_registration_helpers(backend_index),
|
||||
"dispatch_namespaced_definitions": list(
|
||||
concatMap(
|
||||
dest.RegisterDispatchKey(
|
||||
backend_index,
|
||||
Target.NAMESPACED_DEFINITION,
|
||||
selector,
|
||||
rocm=rocm,
|
||||
class_method_name=None,
|
||||
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
|
||||
),
|
||||
grouped_native_functions,
|
||||
)
|
||||
),
|
||||
"dispatch_anonymous_definitions": list(
|
||||
concatMap(
|
||||
dest.RegisterDispatchKey(
|
||||
backend_index,
|
||||
Target.ANONYMOUS_DEFINITION,
|
||||
selector,
|
||||
rocm=rocm,
|
||||
class_method_name=None,
|
||||
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
|
||||
),
|
||||
grouped_native_functions,
|
||||
)
|
||||
),
|
||||
"static_init_dispatch_registrations": static_init_dispatch_registrations,
|
||||
"deferred_dispatch_registrations": "",
|
||||
"dispatch_helpers": "",
|
||||
"dispatch_definitions": dispatch_definitions,
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -422,6 +422,8 @@ def gen_dispatcher_registrations(
|
||||
grouped_native_functions,
|
||||
)
|
||||
)
|
||||
newline = "\n"
|
||||
ns_helper = NamespaceHelper(namespace_str="at")
|
||||
deferred_dispatch_registrations = ""
|
||||
static_init_dispatch_registrations = ""
|
||||
if eager_registration:
|
||||
@ -453,8 +455,6 @@ TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
|
||||
f"Register{dispatch_key}.cpp",
|
||||
"RegisterDispatchKey.cpp",
|
||||
lambda: {
|
||||
"static_init_dispatch_registrations": static_init_dispatch_registrations,
|
||||
"deferred_dispatch_registrations": deferred_dispatch_registrations,
|
||||
"extra_cuda_headers": "",
|
||||
"external_backend_headers": external_backend_headers_str,
|
||||
"ops_headers": "#include <ATen/Functions.h>"
|
||||
@ -465,21 +465,31 @@ TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
|
||||
"dispatch_headers": dest.gen_registration_headers(
|
||||
backend_index, per_operator_headers=per_operator_headers, rocm=False
|
||||
),
|
||||
"dispatch_helpers": dest.gen_registration_helpers(backend_index),
|
||||
"dispatch_namespaced_definitions": "",
|
||||
"dispatch_anonymous_definitions": list(
|
||||
concatMap(
|
||||
dest.RegisterDispatchKey(
|
||||
backend_index,
|
||||
Target.ANONYMOUS_DEFINITION,
|
||||
selector,
|
||||
rocm=False,
|
||||
class_method_name=f"{class_name}",
|
||||
skip_dispatcher_op_registration=False,
|
||||
"dispatch_definitions": fm.substitute_with_template(
|
||||
"RegisterDispatchDefinitions.ini",
|
||||
lambda: {
|
||||
"ns_prologue": ns_helper.prologue,
|
||||
"ns_epilogue": ns_helper.epilogue,
|
||||
"static_init_dispatch_registrations": static_init_dispatch_registrations,
|
||||
"deferred_dispatch_registrations": deferred_dispatch_registrations,
|
||||
"dispatch_helpers": dest.gen_registration_helpers(backend_index),
|
||||
"dispatch_namespace": dispatch_key.lower(),
|
||||
"dispatch_namespaced_definitions": "",
|
||||
"dispatch_anonymous_definitions": list(
|
||||
concatMap(
|
||||
dest.RegisterDispatchKey(
|
||||
backend_index,
|
||||
Target.ANONYMOUS_DEFINITION,
|
||||
selector,
|
||||
rocm=False,
|
||||
class_method_name=f"{class_name}",
|
||||
skip_dispatcher_op_registration=False,
|
||||
),
|
||||
grouped_native_functions,
|
||||
)
|
||||
),
|
||||
grouped_native_functions,
|
||||
)
|
||||
),
|
||||
},
|
||||
).split(newline),
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -176,6 +176,25 @@ class FileManager:
|
||||
with open(filename, "w") as f:
|
||||
f.write(contents)
|
||||
|
||||
# Read from template file and replace pattern with callable (type could be dict or str).
|
||||
def substitute_with_template(
|
||||
self, template_fn: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]
|
||||
) -> str:
|
||||
template_path = os.path.join(self.template_dir, template_fn)
|
||||
env = env_callable()
|
||||
if isinstance(env, dict):
|
||||
# TODO: Update the comment reference to the correct location
|
||||
if "generated_comment" not in env:
|
||||
comment = "@" + "generated by torchgen/gen.py"
|
||||
comment += " from {}".format(os.path.basename(template_path))
|
||||
env["generated_comment"] = comment
|
||||
template = _read_template(template_path)
|
||||
return template.substitute(env)
|
||||
elif isinstance(env, str):
|
||||
return env
|
||||
else:
|
||||
assert_never(env)
|
||||
|
||||
def write_with_template(
|
||||
self,
|
||||
filename: str,
|
||||
@ -186,19 +205,11 @@ class FileManager:
|
||||
assert filename not in self.filenames, "duplicate file write {filename}"
|
||||
self.filenames.add(filename)
|
||||
if not self.dry_run:
|
||||
env = env_callable()
|
||||
if isinstance(env, dict):
|
||||
# TODO: Update the comment reference to the correct location
|
||||
if "generated_comment" not in env:
|
||||
comment = "@" + "generated by torchgen/gen.py"
|
||||
comment += " from {}".format(os.path.basename(template_fn))
|
||||
env["generated_comment"] = comment
|
||||
template = _read_template(os.path.join(self.template_dir, template_fn))
|
||||
self._write_if_changed(filename, template.substitute(env))
|
||||
elif isinstance(env, str):
|
||||
self._write_if_changed(filename, env)
|
||||
else:
|
||||
assert_never(env)
|
||||
substitute_out = self.substitute_with_template(
|
||||
template_fn=template_fn,
|
||||
env_callable=env_callable,
|
||||
)
|
||||
self._write_if_changed(filename=filename, contents=substitute_out)
|
||||
|
||||
def write(
|
||||
self,
|
||||
|
Reference in New Issue
Block a user