mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[torchgen] Support multiple namespace in NativeFunctions.h (#79733)
Summary: This is a follow up to #78015. This PR * introduces namespace logic for generating `NativeFunctions.h`. * adds helper function to extract namespace from string * relaxes the constraint on the levels we support for custom kernel namespace to 2 Test Plan: Yaml entry: ``` - func: unsqueeze.out(Tensor(a) self, int dim, *, Tensor(a!) out) -> Tensor(a!) variants: function device_check: NoCheck dispatch: CPU: custom_1::custom_2::unsqueeze ``` Generated `NativeFunctions.h`: ``` namespace custom_1 { namespace custom_2 { namespace native { TORCH_API at::Tensor & unsqueeze(const at::Tensor & self, int64_t dim, at::Tensor & out); } // namespace native } // namespace custom_2 } // namespace custom_1 ``` Differential Revision: D37198111 Pull Request resolved: https://github.com/pytorch/pytorch/pull/79733 Approved by: https://github.com/bdhirsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
ff6655defb
commit
5c8a9803c8
@ -14,10 +14,4 @@
|
||||
#include <vector>
|
||||
${extra_includes}
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
${native_function_declarations}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
@ -30,10 +30,4 @@
|
||||
|
||||
${NativeFunctions_includes}
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
${NativeFunctions_declarations}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
22
tools/test/test_utils.py
Normal file
22
tools/test/test_utils.py
Normal file
@ -0,0 +1,22 @@
|
||||
import unittest
|
||||
|
||||
from torchgen.utils import NamespaceHelper
|
||||
|
||||
|
||||
class TestNamespaceHelper(unittest.TestCase):
|
||||
def test_create_from_namespaced_tuple(self) -> None:
|
||||
helper = NamespaceHelper.from_namespaced_entity("aten::add")
|
||||
self.assertEqual(helper.entity_name, "add")
|
||||
self.assertEqual(helper.get_cpp_namespace(), "aten")
|
||||
|
||||
def test_default_namespace(self) -> None:
|
||||
helper = NamespaceHelper.from_namespaced_entity("add")
|
||||
self.assertEqual(helper.entity_name, "add")
|
||||
self.assertEqual(helper.get_cpp_namespace(), "")
|
||||
self.assertEqual(helper.get_cpp_namespace("default"), "default")
|
||||
|
||||
def test_namespace_levels_more_than_max(self) -> None:
|
||||
with self.assertRaises(AssertionError):
|
||||
NamespaceHelper(
|
||||
namespace_str="custom_1::custom_2", entity_name="", max_level=1
|
||||
)
|
124
torchgen/gen.py
124
torchgen/gen.py
@ -32,6 +32,7 @@ from torchgen.model import (
|
||||
NativeFunctionsViewGroup,
|
||||
ViewSchemaKind,
|
||||
BaseOperatorName,
|
||||
DEFAULT_KERNEL_NAMESPACE,
|
||||
)
|
||||
from torchgen.native_function_generation import (
|
||||
pre_group_native_functions,
|
||||
@ -64,6 +65,7 @@ from torchgen.utils import (
|
||||
FileManager,
|
||||
assert_never,
|
||||
make_file_manager,
|
||||
NamespaceHelper,
|
||||
)
|
||||
from torchgen.context import (
|
||||
method_with_native_function,
|
||||
@ -111,37 +113,6 @@ T = TypeVar("T")
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||||
|
||||
|
||||
class NamespaceHelper:
|
||||
"""A helper for constructing the namespace open and close strings for a nested set of namespaces.
|
||||
|
||||
e.g. for namespace_str torch::lazy,
|
||||
|
||||
prologue:
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
epilogue:
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
||||
"""
|
||||
|
||||
def __init__(self, namespace_str: str):
|
||||
# cpp_namespace can be a colon joined string such as torch::lazy
|
||||
cpp_namespaces = namespace_str.split("::")
|
||||
self.prologue_ = "\n".join([f"namespace {n} {{" for n in cpp_namespaces])
|
||||
self.epilogue_ = "\n".join(
|
||||
[f"}} // namespace {n}" for n in reversed(cpp_namespaces)]
|
||||
)
|
||||
|
||||
@property
|
||||
def prologue(self) -> str:
|
||||
return self.prologue_
|
||||
|
||||
@property
|
||||
def epilogue(self) -> str:
|
||||
return self.epilogue_
|
||||
|
||||
|
||||
# A custom loader for YAML to let us also keep track of line numbers
|
||||
# of each entry in the YAML file
|
||||
class LineLoader(YamlLoader):
|
||||
@ -1374,6 +1345,53 @@ def get_grouped_native_functions(
|
||||
)
|
||||
|
||||
|
||||
# Return native function declarations grouped by their namespaces.
|
||||
def get_native_function_declarations(
|
||||
*,
|
||||
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
|
||||
backend_indices: Dict[DispatchKey, BackendIndex],
|
||||
) -> List[str]:
|
||||
declarations: List[str] = []
|
||||
ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
|
||||
newline = "\n"
|
||||
for f in grouped_native_functions:
|
||||
native_function_namespaces = set()
|
||||
for backend_idx in backend_indices.values():
|
||||
backend_metadata = backend_idx.get_kernel(f)
|
||||
namespace = (
|
||||
backend_metadata.cpp_namespace
|
||||
if backend_metadata
|
||||
else DEFAULT_KERNEL_NAMESPACE
|
||||
)
|
||||
native_function_namespaces.add(namespace)
|
||||
assert (
|
||||
len(native_function_namespaces) == 1
|
||||
), "Codegen only supports one namespace per operator."
|
||||
ns_grouped_kernels[namespace].extend(
|
||||
dest.compute_native_function_declaration(f, backend_idx)
|
||||
)
|
||||
|
||||
for namespace, kernels in ns_grouped_kernels.items():
|
||||
ns_helper = NamespaceHelper(
|
||||
namespace_str=namespace,
|
||||
entity_name="",
|
||||
max_level=3,
|
||||
)
|
||||
# Convert to a set first to remove duplicate kernel names. Backends are
|
||||
# allowed to repeat kernel names; only generate the declaration once!
|
||||
ordered_kernels = list(OrderedDict.fromkeys(kernels))
|
||||
declarations.extend(
|
||||
f"""
|
||||
{ns_helper.prologue}
|
||||
{newline.join(ordered_kernels)}
|
||||
{ns_helper.epilogue}
|
||||
""".split(
|
||||
newline
|
||||
)
|
||||
)
|
||||
return declarations
|
||||
|
||||
|
||||
def gen_aggregated_headers(
|
||||
*,
|
||||
native_functions: Sequence[NativeFunction],
|
||||
@ -1450,27 +1468,15 @@ def gen_aggregated_headers(
|
||||
),
|
||||
},
|
||||
)
|
||||
declarations = get_native_function_declarations(
|
||||
grouped_native_functions=grouped_native_functions,
|
||||
backend_indices=backend_indices,
|
||||
)
|
||||
cpu_fm.write(
|
||||
"NativeFunctions.h",
|
||||
lambda: {
|
||||
"NativeFunctions_includes": ["#include <ATen/NativeMetaFunctions.h>"],
|
||||
"NativeFunctions_declarations": list(
|
||||
concatMap(
|
||||
# Convert to a set first to remove duplicate kernel names.
|
||||
# Backends are allowed to repeat kernel names; only generate the declaration once!
|
||||
lambda f: list(
|
||||
OrderedDict.fromkeys(
|
||||
concatMap(
|
||||
lambda backend_idx: dest.compute_native_function_declaration(
|
||||
f, backend_idx
|
||||
),
|
||||
backend_indices.values(),
|
||||
)
|
||||
)
|
||||
),
|
||||
grouped_native_functions,
|
||||
)
|
||||
),
|
||||
"NativeFunctions_declarations": declarations,
|
||||
},
|
||||
)
|
||||
|
||||
@ -1598,7 +1604,9 @@ def gen_per_operator_headers(
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
declarations = get_native_function_declarations(
|
||||
grouped_native_functions=grouped_functions, backend_indices=backend_indices
|
||||
)
|
||||
ops_fm.write_with_template(
|
||||
f"{name}_native.h",
|
||||
"NativeFunction.h",
|
||||
@ -1606,23 +1614,7 @@ def gen_per_operator_headers(
|
||||
"extra_includes": (
|
||||
f"#include <ATen/ops/{name}_meta.h>" if is_structured else []
|
||||
),
|
||||
"native_function_declarations": list(
|
||||
concatMap(
|
||||
# Convert to a set first to remove duplicate kernel names.
|
||||
# Backends are allowed to repeat kernel names; only generate the declaration once!
|
||||
lambda f: list(
|
||||
OrderedDict.fromkeys(
|
||||
concatMap(
|
||||
lambda backend_idx: dest.compute_native_function_declaration(
|
||||
f, backend_idx
|
||||
),
|
||||
backend_indices.values(),
|
||||
)
|
||||
)
|
||||
),
|
||||
grouped_functions,
|
||||
)
|
||||
),
|
||||
"native_function_declarations": declarations,
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -8,7 +8,6 @@ from typing import List, Dict, Union, Sequence, Optional
|
||||
from torchgen.gen import (
|
||||
get_grouped_native_functions,
|
||||
parse_native_yaml,
|
||||
NamespaceHelper,
|
||||
)
|
||||
from torchgen.model import (
|
||||
BackendIndex,
|
||||
@ -19,7 +18,14 @@ from torchgen.model import (
|
||||
OperatorName,
|
||||
)
|
||||
from torchgen.selective_build.selector import SelectiveBuilder
|
||||
from torchgen.utils import Target, concatMap, context, YamlLoader, FileManager
|
||||
from torchgen.utils import (
|
||||
Target,
|
||||
concatMap,
|
||||
context,
|
||||
YamlLoader,
|
||||
FileManager,
|
||||
NamespaceHelper,
|
||||
)
|
||||
from torchgen.context import native_function_manager
|
||||
from torchgen.code_template import CodeTemplate
|
||||
import torchgen.dest as dest
|
||||
|
@ -22,7 +22,6 @@ from torchgen.dest.lazy_ir import GenLazyIR, GenTSLazyIR
|
||||
from torchgen.gen import (
|
||||
get_grouped_native_functions,
|
||||
parse_native_yaml,
|
||||
NamespaceHelper,
|
||||
)
|
||||
|
||||
from torchgen.api.lazy import setValueT
|
||||
@ -33,7 +32,7 @@ from torchgen.model import (
|
||||
OperatorName,
|
||||
)
|
||||
from torchgen.selective_build.selector import SelectiveBuilder
|
||||
from torchgen.utils import concatMap, YamlLoader, FileManager
|
||||
from torchgen.utils import concatMap, YamlLoader, FileManager, NamespaceHelper
|
||||
import torchgen.dest as dest
|
||||
from .gen_backend_stubs import (
|
||||
parse_backend_yaml,
|
||||
|
@ -6,7 +6,7 @@ from dataclasses import dataclass
|
||||
from enum import auto, Enum
|
||||
from typing import Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
|
||||
|
||||
from torchgen.utils import assert_never
|
||||
from torchgen.utils import assert_never, NamespaceHelper
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||||
#
|
||||
@ -454,12 +454,11 @@ class NativeFunction:
|
||||
funcs = e.pop("func")
|
||||
assert isinstance(funcs, str), f"not a str: {funcs}"
|
||||
# only support one level of namespace. E.g., aten::add
|
||||
namespaced_funcs = funcs.split("::", 1)
|
||||
if len(namespaced_funcs) == 1:
|
||||
namespace = "aten"
|
||||
else:
|
||||
namespace = namespaced_funcs[0]
|
||||
func = FunctionSchema.parse(namespaced_funcs[-1])
|
||||
namespace_helper = NamespaceHelper.from_namespaced_entity(
|
||||
namespaced_entity=funcs, max_level=1
|
||||
)
|
||||
namespace = namespace_helper.get_cpp_namespace(default="aten")
|
||||
func = FunctionSchema.parse(namespace_helper.entity_name)
|
||||
|
||||
cpp_no_default_args_list = e.pop("cpp_no_default_args", [])
|
||||
assert isinstance(cpp_no_default_args_list, list)
|
||||
@ -579,19 +578,20 @@ class NativeFunction:
|
||||
f"Dispatch key {dispatch_key} of kernel {v} "
|
||||
"is not a supported dispatch key."
|
||||
)
|
||||
# We only allow one level of namespace for kernels and operator.
|
||||
# We only allow at most 2 levels of namespace for kernels.
|
||||
# We will append "native" to a custom kernel namespace.
|
||||
tokens = v.split("::", 1)
|
||||
namespace_helper = NamespaceHelper.from_namespaced_entity(
|
||||
v, max_level=2
|
||||
)
|
||||
kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
|
||||
# Why is 'structured' included? External backends (e.g.
|
||||
# XLA) opt into which ops are structured independently
|
||||
# of which in-tree ops are structured
|
||||
dispatch[dispatch_key] = BackendMetadata(
|
||||
kernel=tokens[-1],
|
||||
kernel=namespace_helper.entity_name,
|
||||
structured=structured
|
||||
and is_structured_dispatch_key(dispatch_key),
|
||||
cpp_namespace=(tokens[0] + "::native")
|
||||
if len(tokens) > 1
|
||||
else DEFAULT_KERNEL_NAMESPACE,
|
||||
cpp_namespace=(kernel_namespace + "::native"),
|
||||
)
|
||||
if (
|
||||
dispatch_key is DispatchKey.CompositeImplicitAutograd
|
||||
|
@ -396,3 +396,66 @@ def _format(
|
||||
indent_str = " " * indent
|
||||
body = f", {delimiter}{curr_indent_str}".join(fields_str)
|
||||
return f"{start}{indent_str}{body}{end}"
|
||||
|
||||
|
||||
class NamespaceHelper:
|
||||
"""A helper for constructing the namespace open and close strings for a nested set of namespaces.
|
||||
|
||||
e.g. for namespace_str torch::lazy,
|
||||
|
||||
prologue:
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
epilogue:
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
||||
"""
|
||||
|
||||
def __init__(self, namespace_str: str, entity_name: str = "", max_level: int = 2):
|
||||
# cpp_namespace can be a colon joined string such as torch::lazy
|
||||
cpp_namespaces = namespace_str.split("::")
|
||||
assert (
|
||||
len(cpp_namespaces) <= max_level
|
||||
), f"Codegen doesn't support more than {max_level} level(s) of custom namespace. Got {namespace_str}."
|
||||
self.cpp_namespace_ = namespace_str
|
||||
self.prologue_ = "\n".join([f"namespace {n} {{" for n in cpp_namespaces])
|
||||
self.epilogue_ = "\n".join(
|
||||
[f"}} // namespace {n}" for n in reversed(cpp_namespaces)]
|
||||
)
|
||||
self.namespaces_ = cpp_namespaces
|
||||
self.entity_name_ = entity_name
|
||||
|
||||
@staticmethod
|
||||
def from_namespaced_entity(
|
||||
namespaced_entity: str, max_level: int = 2
|
||||
) -> "NamespaceHelper":
|
||||
"""
|
||||
Generate helper from nested namespaces as long as class/function name. E.g.: "torch::lazy::add"
|
||||
"""
|
||||
names = namespaced_entity.split("::")
|
||||
entity_name = names[-1]
|
||||
namespace_str = "::".join(names[:-1])
|
||||
return NamespaceHelper(
|
||||
namespace_str=namespace_str, entity_name=entity_name, max_level=max_level
|
||||
)
|
||||
|
||||
@property
|
||||
def prologue(self) -> str:
|
||||
return self.prologue_
|
||||
|
||||
@property
|
||||
def epilogue(self) -> str:
|
||||
return self.epilogue_
|
||||
|
||||
@property
|
||||
def entity_name(self) -> str:
|
||||
return self.entity_name_
|
||||
|
||||
# Only allow certain level of namespaces
|
||||
def get_cpp_namespace(self, default: str = "") -> str:
|
||||
"""
|
||||
Return the namespace string from joining all the namespaces by "::" (hence no leading "::").
|
||||
Return default if namespace string is empty.
|
||||
"""
|
||||
return self.cpp_namespace_ if self.cpp_namespace_ else default
|
||||
|
Reference in New Issue
Block a user