[torchgen] Support multiple namespace in NativeFunctions.h (#79733)

Summary:
This is a follow up to #78015. This PR
* introduces namespace logic for generating `NativeFunctions.h`.
* adds helper function to extract namespace from string
* relaxes the constraint on the levels we support for custom kernel namespace to 2

Test Plan:
Yaml entry:
```
- func: unsqueeze.out(Tensor(a) self, int dim, *, Tensor(a!) out) -> Tensor(a!)
  variants: function
  device_check: NoCheck
  dispatch:
    CPU: custom_1::custom_2::unsqueeze
```

Generated `NativeFunctions.h`:

```
namespace custom_1 {
namespace custom_2 {
namespace native {
    TORCH_API at::Tensor & unsqueeze(const at::Tensor & self, int64_t dim, at::Tensor & out);
} // namespace native
} // namespace custom_2
} // namespace custom_1

```

Differential Revision: D37198111

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79733
Approved by: https://github.com/bdhirsh
This commit is contained in:
Mengwei Liu
2022-07-08 21:56:49 +00:00
committed by PyTorch MergeBot
parent ff6655defb
commit 5c8a9803c8
8 changed files with 165 additions and 95 deletions

View File

@ -14,10 +14,4 @@
#include <vector>
${extra_includes}
namespace at {
namespace native {
${native_function_declarations}
} // namespace native
} // namespace at

View File

@ -30,10 +30,4 @@
${NativeFunctions_includes}
namespace at {
namespace native {
${NativeFunctions_declarations}
} // namespace native
} // namespace at

22
tools/test/test_utils.py Normal file
View File

@ -0,0 +1,22 @@
import unittest
from torchgen.utils import NamespaceHelper
class TestNamespaceHelper(unittest.TestCase):
def test_create_from_namespaced_tuple(self) -> None:
helper = NamespaceHelper.from_namespaced_entity("aten::add")
self.assertEqual(helper.entity_name, "add")
self.assertEqual(helper.get_cpp_namespace(), "aten")
def test_default_namespace(self) -> None:
helper = NamespaceHelper.from_namespaced_entity("add")
self.assertEqual(helper.entity_name, "add")
self.assertEqual(helper.get_cpp_namespace(), "")
self.assertEqual(helper.get_cpp_namespace("default"), "default")
def test_namespace_levels_more_than_max(self) -> None:
with self.assertRaises(AssertionError):
NamespaceHelper(
namespace_str="custom_1::custom_2", entity_name="", max_level=1
)

View File

@ -32,6 +32,7 @@ from torchgen.model import (
NativeFunctionsViewGroup,
ViewSchemaKind,
BaseOperatorName,
DEFAULT_KERNEL_NAMESPACE,
)
from torchgen.native_function_generation import (
pre_group_native_functions,
@ -64,6 +65,7 @@ from torchgen.utils import (
FileManager,
assert_never,
make_file_manager,
NamespaceHelper,
)
from torchgen.context import (
method_with_native_function,
@ -111,37 +113,6 @@ T = TypeVar("T")
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
class NamespaceHelper:
"""A helper for constructing the namespace open and close strings for a nested set of namespaces.
e.g. for namespace_str torch::lazy,
prologue:
namespace torch {
namespace lazy {
epilogue:
} // namespace lazy
} // namespace torch
"""
def __init__(self, namespace_str: str):
# cpp_namespace can be a colon joined string such as torch::lazy
cpp_namespaces = namespace_str.split("::")
self.prologue_ = "\n".join([f"namespace {n} {{" for n in cpp_namespaces])
self.epilogue_ = "\n".join(
[f"}} // namespace {n}" for n in reversed(cpp_namespaces)]
)
@property
def prologue(self) -> str:
return self.prologue_
@property
def epilogue(self) -> str:
return self.epilogue_
# A custom loader for YAML to let us also keep track of line numbers
# of each entry in the YAML file
class LineLoader(YamlLoader):
@ -1374,6 +1345,53 @@ def get_grouped_native_functions(
)
# Return native function declarations grouped by their namespaces.
def get_native_function_declarations(
*,
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
backend_indices: Dict[DispatchKey, BackendIndex],
) -> List[str]:
declarations: List[str] = []
ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
newline = "\n"
for f in grouped_native_functions:
native_function_namespaces = set()
for backend_idx in backend_indices.values():
backend_metadata = backend_idx.get_kernel(f)
namespace = (
backend_metadata.cpp_namespace
if backend_metadata
else DEFAULT_KERNEL_NAMESPACE
)
native_function_namespaces.add(namespace)
assert (
len(native_function_namespaces) == 1
), "Codegen only supports one namespace per operator."
ns_grouped_kernels[namespace].extend(
dest.compute_native_function_declaration(f, backend_idx)
)
for namespace, kernels in ns_grouped_kernels.items():
ns_helper = NamespaceHelper(
namespace_str=namespace,
entity_name="",
max_level=3,
)
# Convert to a set first to remove duplicate kernel names. Backends are
# allowed to repeat kernel names; only generate the declaration once!
ordered_kernels = list(OrderedDict.fromkeys(kernels))
declarations.extend(
f"""
{ns_helper.prologue}
{newline.join(ordered_kernels)}
{ns_helper.epilogue}
""".split(
newline
)
)
return declarations
def gen_aggregated_headers(
*,
native_functions: Sequence[NativeFunction],
@ -1450,27 +1468,15 @@ def gen_aggregated_headers(
),
},
)
declarations = get_native_function_declarations(
grouped_native_functions=grouped_native_functions,
backend_indices=backend_indices,
)
cpu_fm.write(
"NativeFunctions.h",
lambda: {
"NativeFunctions_includes": ["#include <ATen/NativeMetaFunctions.h>"],
"NativeFunctions_declarations": list(
concatMap(
# Convert to a set first to remove duplicate kernel names.
# Backends are allowed to repeat kernel names; only generate the declaration once!
lambda f: list(
OrderedDict.fromkeys(
concatMap(
lambda backend_idx: dest.compute_native_function_declaration(
f, backend_idx
),
backend_indices.values(),
)
)
),
grouped_native_functions,
)
),
"NativeFunctions_declarations": declarations,
},
)
@ -1598,7 +1604,9 @@ def gen_per_operator_headers(
),
},
)
declarations = get_native_function_declarations(
grouped_native_functions=grouped_functions, backend_indices=backend_indices
)
ops_fm.write_with_template(
f"{name}_native.h",
"NativeFunction.h",
@ -1606,23 +1614,7 @@ def gen_per_operator_headers(
"extra_includes": (
f"#include <ATen/ops/{name}_meta.h>" if is_structured else []
),
"native_function_declarations": list(
concatMap(
# Convert to a set first to remove duplicate kernel names.
# Backends are allowed to repeat kernel names; only generate the declaration once!
lambda f: list(
OrderedDict.fromkeys(
concatMap(
lambda backend_idx: dest.compute_native_function_declaration(
f, backend_idx
),
backend_indices.values(),
)
)
),
grouped_functions,
)
),
"native_function_declarations": declarations,
},
)

View File

@ -8,7 +8,6 @@ from typing import List, Dict, Union, Sequence, Optional
from torchgen.gen import (
get_grouped_native_functions,
parse_native_yaml,
NamespaceHelper,
)
from torchgen.model import (
BackendIndex,
@ -19,7 +18,14 @@ from torchgen.model import (
OperatorName,
)
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import Target, concatMap, context, YamlLoader, FileManager
from torchgen.utils import (
Target,
concatMap,
context,
YamlLoader,
FileManager,
NamespaceHelper,
)
from torchgen.context import native_function_manager
from torchgen.code_template import CodeTemplate
import torchgen.dest as dest

View File

@ -22,7 +22,6 @@ from torchgen.dest.lazy_ir import GenLazyIR, GenTSLazyIR
from torchgen.gen import (
get_grouped_native_functions,
parse_native_yaml,
NamespaceHelper,
)
from torchgen.api.lazy import setValueT
@ -33,7 +32,7 @@ from torchgen.model import (
OperatorName,
)
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import concatMap, YamlLoader, FileManager
from torchgen.utils import concatMap, YamlLoader, FileManager, NamespaceHelper
import torchgen.dest as dest
from .gen_backend_stubs import (
parse_backend_yaml,

View File

@ -6,7 +6,7 @@ from dataclasses import dataclass
from enum import auto, Enum
from typing import Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
from torchgen.utils import assert_never
from torchgen.utils import assert_never, NamespaceHelper
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
@ -454,12 +454,11 @@ class NativeFunction:
funcs = e.pop("func")
assert isinstance(funcs, str), f"not a str: {funcs}"
# only support one level of namespace. E.g., aten::add
namespaced_funcs = funcs.split("::", 1)
if len(namespaced_funcs) == 1:
namespace = "aten"
else:
namespace = namespaced_funcs[0]
func = FunctionSchema.parse(namespaced_funcs[-1])
namespace_helper = NamespaceHelper.from_namespaced_entity(
namespaced_entity=funcs, max_level=1
)
namespace = namespace_helper.get_cpp_namespace(default="aten")
func = FunctionSchema.parse(namespace_helper.entity_name)
cpp_no_default_args_list = e.pop("cpp_no_default_args", [])
assert isinstance(cpp_no_default_args_list, list)
@ -579,19 +578,20 @@ class NativeFunction:
f"Dispatch key {dispatch_key} of kernel {v} "
"is not a supported dispatch key."
)
# We only allow one level of namespace for kernels and operator.
# We only allow at most 2 levels of namespace for kernels.
# We will append "native" to a custom kernel namespace.
tokens = v.split("::", 1)
namespace_helper = NamespaceHelper.from_namespaced_entity(
v, max_level=2
)
kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
# Why is 'structured' included? External backends (e.g.
# XLA) opt into which ops are structured independently
# of which in-tree ops are structured
dispatch[dispatch_key] = BackendMetadata(
kernel=tokens[-1],
kernel=namespace_helper.entity_name,
structured=structured
and is_structured_dispatch_key(dispatch_key),
cpp_namespace=(tokens[0] + "::native")
if len(tokens) > 1
else DEFAULT_KERNEL_NAMESPACE,
cpp_namespace=(kernel_namespace + "::native"),
)
if (
dispatch_key is DispatchKey.CompositeImplicitAutograd

View File

@ -396,3 +396,66 @@ def _format(
indent_str = " " * indent
body = f", {delimiter}{curr_indent_str}".join(fields_str)
return f"{start}{indent_str}{body}{end}"
class NamespaceHelper:
"""A helper for constructing the namespace open and close strings for a nested set of namespaces.
e.g. for namespace_str torch::lazy,
prologue:
namespace torch {
namespace lazy {
epilogue:
} // namespace lazy
} // namespace torch
"""
def __init__(self, namespace_str: str, entity_name: str = "", max_level: int = 2):
# cpp_namespace can be a colon joined string such as torch::lazy
cpp_namespaces = namespace_str.split("::")
assert (
len(cpp_namespaces) <= max_level
), f"Codegen doesn't support more than {max_level} level(s) of custom namespace. Got {namespace_str}."
self.cpp_namespace_ = namespace_str
self.prologue_ = "\n".join([f"namespace {n} {{" for n in cpp_namespaces])
self.epilogue_ = "\n".join(
[f"}} // namespace {n}" for n in reversed(cpp_namespaces)]
)
self.namespaces_ = cpp_namespaces
self.entity_name_ = entity_name
@staticmethod
def from_namespaced_entity(
namespaced_entity: str, max_level: int = 2
) -> "NamespaceHelper":
"""
Generate helper from nested namespaces as long as class/function name. E.g.: "torch::lazy::add"
"""
names = namespaced_entity.split("::")
entity_name = names[-1]
namespace_str = "::".join(names[:-1])
return NamespaceHelper(
namespace_str=namespace_str, entity_name=entity_name, max_level=max_level
)
@property
def prologue(self) -> str:
return self.prologue_
@property
def epilogue(self) -> str:
return self.epilogue_
@property
def entity_name(self) -> str:
return self.entity_name_
# Only allow certain level of namespaces
def get_cpp_namespace(self, default: str = "") -> str:
"""
Return the namespace string from joining all the namespaces by "::" (hence no leading "::").
Return default if namespace string is empty.
"""
return self.cpp_namespace_ if self.cpp_namespace_ else default