Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[BE] Cleanup old ExecuTorch codegen and runtime code (#154165)
Summary: These files were added to pytorch/pytorch before ExecuTorch was open-sourced. Now is a good time to remove them from pytorch/pytorch, since the code has already moved to pytorch/executorch.

Test Plan: Rely on CI jobs.

Differential Revision: [D75985423](https://our.internmc.facebook.com/intern/diff/D75985423)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154165
Approved by: https://github.com/kimishpatel, https://github.com/Skylion007, https://github.com/cyyever
Committed by PyTorch MergeBot
parent da1f8980df · commit 386aa72003
torchgen/BUILD.bzl
@@ -18,13 +18,3 @@ def define_targets(rules):
            rules.requirement("typing-extensions"),
        ],
    )

    rules.py_binary(
        name = "gen_executorch",
        srcs = [":torchgen"],
        visibility = ["//visibility:public"],
        deps = [
            rules.requirement("PyYAML"),
            rules.requirement("typing-extensions"),
        ],
    )
torchgen/executorch/api/custom_ops.py (deleted)
@@ -1,151 +0,0 @@
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING

from torchgen import dest


# disable import sorting to avoid circular dependency.
from torchgen.api.types import DispatcherSignature  # usort: skip
from torchgen.context import method_with_native_function
from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant
from torchgen.utils import concatMap, Target


if TYPE_CHECKING:
    from collections.abc import Sequence

    from torchgen.executorch.model import ETKernelIndex
    from torchgen.selective_build.selector import SelectiveBuilder


# Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom
# operators. This is used on the model authoring side.
@dataclass(frozen=True)
class ComputeNativeFunctionStub:
    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        if Variant.function not in f.variants:
            return None

        sig = DispatcherSignature.from_schema(
            f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False
        )
        assert sig is not None
        if len(f.func.returns) == 0:
            ret_name = ""
        elif len(f.func.returns) == 1:
            if f.func.arguments.out:
                ret_name = f.func.arguments.out[0].name
            else:
                ret_name = next(
                    (
                        a.name
                        for a in f.func.arguments.flat_non_out
                        if a.type == f.func.returns[0].type
                    ),
                    "",
                )
            if not ret_name:
                # if return type is tensor
                if f.func.returns[0].type == BaseType(BaseTy.Tensor):
                    # Returns an empty tensor
                    ret_name = "at::Tensor()"
                else:
                    raise Exception(  # noqa: TRY002
                        f"Can't handle this return type {f.func}"
                    )
        elif len(f.func.arguments.out) == len(f.func.returns):
            # Returns a tuple of out arguments
            tensor_type = "at::Tensor &"
            comma = ", "
            ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
                {comma.join([r.name for r in f.func.arguments.out])}
            )"""
        else:
            assert all(a.type == BaseType(BaseTy.Tensor) for a in f.func.returns), (
                f"Only support tensor returns but got {f.func.returns}"
            )
            # Returns a tuple of empty tensors
            tensor_type = "at::Tensor"
            comma = ", "
            ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
                {comma.join(["at::Tensor()" for _ in f.func.returns])}
            )"""
        ret_str = f"return {ret_name};" if len(f.func.returns) > 0 else ""
        return f"""
{sig.defn()} {{
    {ret_str}
}}
    """
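To make the stub shape concrete: for a hypothetical functional schema custom::foo(Tensor self) -> Tensor, the generator above emits roughly the following (the wrapper name comes from the DispatcherSignature prefix, and ret_name resolves to `self` because the first non-out argument matches the return type):

# Sketch of the emitted placeholder kernel for the hypothetical schema
#   custom::foo(Tensor self) -> Tensor
#
#   at::Tensor wrapper_CPU__foo(const at::Tensor & self) {
#       return self;
#   }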

def gen_custom_ops_registration(
    *,
    native_functions: Sequence[NativeFunction],
    selector: SelectiveBuilder,
    kernel_index: ETKernelIndex,
    rocm: bool,
) -> tuple[str, str]:
    """
    Generate custom ops registration code for dest.RegisterDispatchKey.

    :param native_functions: a sequence of `NativeFunction`
    :param selector: for selective build.
    :param kernel_index: kernels for all the ops.
    :param rocm: bool for dest.RegisterDispatchKey.
    :return: generated C++ code to register custom operators into PyTorch
    """

    # convert kernel index to BackendIndex. This is because we can't handle ETKernelIndex yet.
    # TODO larryliu: evaluate if this code is still needed. If yes let it handle ETKernelIndex.

    dispatch_key = DispatchKey.CPU
    backend_index = kernel_index._to_backend_index()
    static_init_dispatch_registrations = ""
    ns_grouped_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
    for native_function in native_functions:
        ns_grouped_native_functions[native_function.namespace].append(native_function)

    for namespace, functions in ns_grouped_native_functions.items():
        if len(functions) == 0:
            continue
        dispatch_registrations_body = "\n".join(
            list(
                concatMap(
                    dest.RegisterDispatchKey(
                        backend_index,
                        Target.REGISTRATION,
                        selector,
                        rocm=rocm,
                        symint=False,
                        class_method_name=None,
                        skip_dispatcher_op_registration=False,
                    ),
                    functions,
                )
            )
        )
        static_init_dispatch_registrations += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
{dispatch_registrations_body}
}}"""
    anonymous_definition = "\n".join(
        list(
            concatMap(
                dest.RegisterDispatchKey(
                    backend_index,
                    Target.ANONYMOUS_DEFINITION,
                    selector,
                    rocm=rocm,
                    symint=False,
                    class_method_name=None,
                    skip_dispatcher_op_registration=False,
                ),
                native_functions,
            )
        )
    )
    return anonymous_definition, static_init_dispatch_registrations
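The second element of the returned tuple is a series of TORCH_LIBRARY_IMPL blocks, one per namespace; a sketch of the shape (namespace and op are hypothetical):

# Sketch of one generated registration block:
#
#   TORCH_LIBRARY_IMPL(custom, CPU, m) {
#       m.impl("foo", TORCH_FN(wrapper_CPU__foo));
#   }
#
# The first element (anonymous_definition) carries the wrapper definitions
# that these registrations reference.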
torchgen/executorch/api/et_cpp.py (deleted)
@@ -1,367 +0,0 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from typing_extensions import assert_never

from torchgen import local
from torchgen.api.types import (
    ArgName,
    BaseCType,
    Binding,
    ConstRefCType,
    CType,
    MutRefCType,
    NamedCType,
    SpecialArgName,
    TupleCType,
    VectorCType,
    voidT,
)
from torchgen.executorch.api.types import (
    ArrayRefCType,
    BaseTypeToCppMapping,
    OptionalCType,
    scalarT,
    tensorListT,
    tensorT,
)
from torchgen.model import (
    Argument,
    Arguments,
    BaseTy,
    BaseType,
    ListType,
    NativeFunction,
    OptionalType,
    Return,
    SelfArgument,
    TensorOptionsArguments,
    Type,
)


if TYPE_CHECKING:
    from collections.abc import Sequence


"""
This file describes the translation of JIT schema to the public C++ API, which is what people use when they call
functions like at::add. It also serves as a native function API, which is the signature of kernels,
since in Executorch CppSignature is the same as NativeSignature.

Differences between this file and torchgen.api.cpp:

- Executorch doesn't support TensorOptions; however, this file still keeps the logic to stay compatible with
  torchgen.api.cpp, so that we can do things like ATen mode (running ATen kernels in Executorch).

- Executorch doesn't support Dimname.

- Executorch runtime doesn't support SymInt; it is treated as int.
"""


# Translation of "value types" in JIT schema to C++ API type. Value
# types look the same no matter if they are argument types or return
# types. Returns None if the type in question is not a value type.
def valuetype_type(
    t: Type,
    *,
    binds: ArgName,
) -> NamedCType | None:
    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
            return None
        # For SymInt we simply treat it as int.
        elif str(t) == "SymInt":
            return NamedCType(binds, BaseCType(BaseTypeToCppMapping[BaseTy.int]))
        # All other BaseType currently map directly to BaseCppTypes.
        return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
    elif isinstance(t, OptionalType):
        elem = valuetype_type(t.elem, binds=binds)
        if elem is None:
            return None
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        if str(t.elem) == "bool":
            assert t.size is not None
            return NamedCType(
                binds, ArrayRefCType(BaseCType(BaseTypeToCppMapping[BaseTy.bool]))
            )
        else:
            return None
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")


# Translation of types occurring in JIT arguments to a C++ argument type.
# If remove_non_owning_ref_types is set, we'll guarantee that the outputted CType is not a non-owning reference type.
# For example, we'll return std::vector<int> instead of IntArrayRef.
# See Note [translation from C++ reference to value types]
def argumenttype_type(
    t: Type,
    *,
    mutable: bool,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
) -> NamedCType:
    # If it's a value type, do the value type translation
    r = valuetype_type(
        t,
        binds=binds,
    )
    if r is not None:
        return r
    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable and not local.use_const_ref_for_mutable_tensors():
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            else:
                return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        elif t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if str(t.elem) == "Tensor":
            if mutable and not local.use_const_ref_for_mutable_tensors():
                return NamedCType(
                    binds, MutRefCType(BaseCType(tensorT))
                )  # TODO: fix this discrepancy
            else:
                return NamedCType(
                    binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
                )
        elif str(t.elem) == "Scalar":
            return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        # TODO: keeping these special cases for Tensor[] and Tensor?[] so that we can hookup with ATen kernels.
        if str(t.elem) == "Tensor":
            return NamedCType(binds, BaseCType(tensorListT))
        elif str(t.elem) == "Dimname":
            raise NotImplementedError("Executorch doesn't support Dimname")
        elif str(t.elem) == "Tensor?":
            return NamedCType(binds, ArrayRefCType(OptionalCType(BaseCType(tensorT))))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(elem.type))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
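To make the translation concrete, a minimal sketch with hypothetical calls (local.parametrize supplies the thread-local flag the mutable-Tensor branch reads; the expected cpp_type() renderings are shown in comments):

from torchgen import local
from torchgen.model import BaseTy, BaseType, ListType, OptionalType

with local.parametrize(
    use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
):
    # Tensor     -> const torch::executor::Tensor &
    argumenttype_type(BaseType(BaseTy.Tensor), mutable=False, binds="self")
    # Tensor(a!) -> torch::executor::Tensor &
    argumenttype_type(BaseType(BaseTy.Tensor), mutable=True, binds="out")
    # int[]      -> torch::executor::ArrayRef<int64_t>
    argumenttype_type(ListType(BaseType(BaseTy.int), None), mutable=False, binds="sizes")
    # Scalar?    -> const torch::executor::optional<torch::executor::Scalar> &
    argumenttype_type(OptionalType(BaseType(BaseTy.Scalar)), mutable=False, binds="alpha")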

# Translate a JIT argument into its C++ type
def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
    return argumenttype_type(a.type, mutable=a.is_write, binds=binds)


# Translation of a (non-multi) return type from JIT to C++
# N.B: returntype_type returns a CType, not a NamedCType.
# This is mostly because of the mismatch between return types and return names.
# e.g. a function with a return type of 'void' has 0 return names,
# and a function with a return type of 'std::tuple' has >1 return name.
def returntype_type(t: Type, *, mutable: bool) -> CType:
    # placeholder is ignored
    r = valuetype_type(t, binds="__placeholder__")
    if r is not None:
        return r.type

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable:
                if local.use_const_ref_for_mutable_tensors():
                    return ConstRefCType(BaseCType(tensorT))
                else:
                    return MutRefCType(BaseCType(tensorT))
            else:
                # Note [Tensor Copy Returns]
                # Currently, we use "Argument.is_write" to determine
                # whether or not Tensor return types should be copies or references.
                # If that ever changes, take a look at other locations of this note!
                return BaseCType(tensorT)
        elif t.name == BaseTy.Scalar:
            return BaseCType(scalarT)
    elif isinstance(t, ListType):
        assert not mutable, (
            "Native functions should never return a mutable tensor list. They should return void."
        )
        elem = returntype_type(t.elem, mutable=False)
        assert t.size is None, f"fixed size list returns not supported: {t}"
        return VectorCType(elem)

    raise AssertionError(f"unrecognized return type {t}")


# Translation of a single return to its C++ type
def return_type(r: Return) -> CType:
    return returntype_type(r.type, mutable=r.is_write)


# Translation of a full (possibly multi) return from JIT to its C++ type
def returns_type(rs: Sequence[Return]) -> CType:
    if len(rs) == 0:
        return BaseCType(voidT)
    elif len(rs) == 1:
        return return_type(rs[0])
    else:
        return TupleCType([return_type(r) for r in rs])


def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
    returns: list[str] = []
    for i, r in enumerate(f.func.returns):
        # If we have an inplace function, the return argument is
        # implicitly named self.
        # TODO: Consider incorporating this into the data model
        if f.func.name.name.inplace:
            assert i == 0, "illegal inplace function with multiple returns"
            name = "self"
        # If we are an out function, the name is the name of the
        # corresponding output argument (r.name will get recorded
        # in field_name later.)
        elif f.func.is_out_fn():
            name = f.func.arguments.out[i].name
        # If the return argument is explicitly named...
        elif r.name:
            name_conflict = any(
                r.name == a.name for a in f.func.schema_order_arguments()
            )
            if name_conflict and not f.func.is_out_fn():
                name = f"{r.name}_return"
            else:
                name = r.name
        # If there is no explicit name, we just use the fallback name,
        # unless it's a multi-return, in which case it's result0,
        # result1, etc. (zero-indexed)
        else:
            name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}"
        returns.append(name)
    return returns


JIT_TO_CPP_DEFAULT = {
    "False": "false",
    "True": "true",
    "None": "torch::executor::nullopt",  # UGH this one is type directed
    "[]": "{}",
    "contiguous_format": "torch::executor::MemoryFormat::Contiguous",
    "long": "torch::executor::kLong",
}


# Convert a JIT default into C++ expression representing the default
def default_expr(d: str, t: Type) -> str:
    if d == "None" and str(t) == "Tensor?":
        return "{}"
    if isinstance(t, BaseType) and t.name is BaseTy.str:
        # Schema allows single quotes but C++ needs double
        if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
            s = ""
            i = 1
            while i + 1 < len(d):
                if d[i] != "\\":
                    if d[i] == '"':
                        s += '\\"'
                    else:
                        s += d[i]
                    i += 1
                else:
                    if d[i + 1] == "'":
                        s += "'"
                    else:
                        s += d[i : i + 2]
                    i += 2

            return f'"{s}"'

    if isinstance(t, OptionalType):
        if d == "None":
            return "torch::executor::nullopt"

        return default_expr(d, t.elem)

    if isinstance(t, ListType):
        if d.startswith("[") and d.endswith("]"):
            return "{" + d[1:-1] + "}"
        elif t.size is None:
            # NOTE: Sized lists can have scalar defaults
            raise ValueError(f"Expected a list default '[...]' but found: '{d}'")

    return JIT_TO_CPP_DEFAULT.get(d, d)
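A few hypothetical conversions, following the branches above (expected results in comments):

from torchgen.model import BaseTy, BaseType, ListType, OptionalType

default_expr("True", BaseType(BaseTy.bool))                   # -> "true"
default_expr("None", OptionalType(BaseType(BaseTy.Tensor)))   # -> "{}" (Tensor? special case)
default_expr("None", OptionalType(BaseType(BaseTy.int)))      # -> "torch::executor::nullopt"
default_expr("[0, 1]", ListType(BaseType(BaseTy.int), None))  # -> "{0, 1}"
default_expr("'nearest'", BaseType(BaseTy.str))               # -> '"nearest"'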

# Convert an argument into its C++ API form


def argument(
    a: Argument | TensorOptionsArguments | SelfArgument,
    *,
    cpp_no_default_args: set[str],
    method: bool,
    faithful: bool,
    has_tensor_options: bool,
) -> list[Binding]:
    def sub_argument(
        a: Argument | TensorOptionsArguments | SelfArgument,
    ) -> list[Binding]:
        return argument(
            a,
            cpp_no_default_args=cpp_no_default_args,
            method=method,
            faithful=faithful,
            has_tensor_options=has_tensor_options,
        )

    if isinstance(a, Argument):
        binds: ArgName
        if a.name == "memory_format" and has_tensor_options:
            binds = SpecialArgName.possibly_redundant_memory_format
        else:
            binds = a.name
        default: str | None = None
        if a.name not in cpp_no_default_args and a.default is not None:
            default = default_expr(a.default, a.type)
        return [
            Binding(
                nctype=argument_type(a, binds=binds),
                name=a.name,
                default=default,
                argument=a,
            )
        ]
    elif isinstance(a, TensorOptionsArguments):
        raise NotImplementedError("Need to implement type resolution for TensorOptions")
    elif isinstance(a, SelfArgument):
        if method:
            # Caller is responsible for installing implicit this in context!
            return []
        else:
            return sub_argument(a.argument)
    else:
        assert_never(a)


def arguments(
    arguments: Arguments,
    *,
    faithful: bool,
    method: bool,
    cpp_no_default_args: set[str],
) -> list[Binding]:
    args: list[Argument | TensorOptionsArguments | SelfArgument] = []
    if faithful:
        args.extend(arguments.non_out)
        args.extend(arguments.out)
    else:
        args.extend(arguments.out)
        args.extend(arguments.non_out)
    return [
        r.no_default() if faithful else r
        for a in args
        for r in argument(
            a,
            faithful=faithful,
            method=method,
            has_tensor_options=arguments.tensor_options is not None,
            cpp_no_default_args=cpp_no_default_args,
        )
    ]
torchgen/executorch/api/types/__init__.py (deleted)
@@ -1,4 +0,0 @@
from torchgen.executorch.api.types.types import *


from torchgen.executorch.api.types.signatures import *  # usort: skip
torchgen/executorch/api/types/signatures.py (deleted)
@@ -1,76 +0,0 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

import torchgen.api.cpp as aten_cpp

from torchgen.executorch.api.types.types import contextArg


if TYPE_CHECKING:
    from torchgen.api.types import Binding, CType
    from torchgen.model import FunctionSchema, NativeFunction


@dataclass(frozen=True)
class ExecutorchCppSignature:
    """
    This signature is merely a CppSignature with Executorch types (optionally
    contains KernelRuntimeContext as well). The inline definition of
    CppSignature is generated in Functions.h and it's used by unboxing
    functions.
    """

    # The schema this signature is derived from
    func: FunctionSchema

    # The set of C++ arguments which should not have defaults applied to them
    cpp_no_default_args: set[str]

    # Allows you to prepend an arbitrary prefix to the signature name.
    # This is useful for parts of the codegen that generate wrappers around kernels,
    # and need to avoid naming collisions.
    prefix: str = ""

    def arguments(self, *, include_context: bool = True) -> list[Binding]:
        return ([contextArg] if include_context else []) + et_cpp.arguments(
            self.func.arguments,
            faithful=True,  # always faithful, out argument at the end
            method=False,  # method not supported
            cpp_no_default_args=self.cpp_no_default_args,
        )

    def name(self) -> str:
        return self.prefix + aten_cpp.name(
            self.func,
            faithful_name_for_out_overloads=True,
        )

    def decl(self, name: str | None = None, *, include_context: bool = True) -> str:
        args_str = ", ".join(
            a.decl() for a in self.arguments(include_context=include_context)
        )
        if name is None:
            name = self.name()
        return f"{self.returns_type().cpp_type()} {name}({args_str})"

    def defn(self, name: str | None = None) -> str:
        args = [a.defn() for a in self.arguments()]
        args_str = ", ".join(args)
        if name is None:
            name = self.name()
        return f"{self.returns_type().cpp_type()} {name}({args_str})"

    def returns_type(self) -> CType:
        return et_cpp.returns_type(self.func.returns)

    @staticmethod
    def from_native_function(
        f: NativeFunction, *, prefix: str = ""
    ) -> ExecutorchCppSignature:
        return ExecutorchCppSignature(
            func=f.func, prefix=prefix, cpp_no_default_args=f.cpp_no_default_args
        )


# Imported at the bottom (rather than the top) to break a circular dependency:
# et_cpp itself imports from torchgen.executorch.api.types.
from torchgen.executorch.api import et_cpp
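As an illustration, for a hypothetical out-variant schema the rendered declaration comes out roughly as below (shape follows decl() above; the exact reference qualifiers on the return type depend on the thread-local mutable-tensor flag):

# Sketch: for a hypothetical schema
#   add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
# ExecutorchCppSignature.from_native_function(f).decl() renders roughly:
#
#   torch::executor::Tensor & add_outf(
#       torch::executor::KernelRuntimeContext & context,
#       const torch::executor::Tensor & self,
#       const torch::executor::Tensor & other,
#       const torch::executor::Scalar & alpha,
#       torch::executor::Tensor & out)
#
# Note the faithful ordering (out argument last) and the prepended context
# argument, which arguments()/name() implement above.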
torchgen/executorch/api/types/types.py (deleted)
@@ -1,77 +0,0 @@
from __future__ import annotations

from dataclasses import dataclass

from torchgen.api.types import (
    BaseCppType,
    BaseCType,
    Binding,
    boolT,
    CType,
    doubleT,
    Expr,
    longT,
    MutRefCType,
    NamedCType,
)
from torchgen.model import BaseTy


halfT = BaseCppType("torch::executor", "Half")
bfloat16T = BaseCppType("torch::executor", "BFloat16")
stringT = BaseCppType("torch::executor", "string_view")
scalarTypeT = BaseCppType("torch::executor", "ScalarType")
tensorT = BaseCppType("torch::executor", "Tensor")
tensorListT = BaseCppType("torch::executor", "TensorList")
scalarT = BaseCppType("torch::executor", "Scalar")
memoryFormatT = BaseCppType("torch::executor", "MemoryFormat")
intArrayRefT = BaseCppType("torch::executor", "IntArrayRef")
optionalT = BaseCppType("torch::executor", "optional")
contextT = BaseCppType("torch::executor", "KernelRuntimeContext")

contextExpr = Expr(
    expr="context",
    type=NamedCType(name="context", type=MutRefCType(BaseCType(contextT))),
)

contextArg = Binding(
    name="context",
    nctype=contextExpr.type,
    argument=None,  # type: ignore[arg-type]
    default=None,
)

BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
    BaseTy.int: longT,
    BaseTy.float: doubleT,
    BaseTy.bool: boolT,
    BaseTy.str: stringT,
    BaseTy.ScalarType: scalarTypeT,
    BaseTy.Tensor: tensorT,
    BaseTy.Scalar: scalarT,
    BaseTy.MemoryFormat: memoryFormatT,
}


@dataclass(frozen=True)
class OptionalCType(CType):
    elem: CType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        # Do not pass `strip_ref` recursively.
        return f"torch::executor::optional<{self.elem.cpp_type()}>"

    def remove_const_ref(self) -> CType:
        return OptionalCType(self.elem.remove_const_ref())


@dataclass(frozen=True)
class ArrayRefCType(CType):
    elem: CType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        # Do not pass `strip_ref` recursively.
        return f"torch::executor::ArrayRef<{self.elem.cpp_type()}>"

    def remove_const_ref(self) -> CType:
        return ArrayRefCType(self.elem.remove_const_ref())
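These wrappers nest structurally; a quick sketch of how a `Tensor?[]` argument renders:

from torchgen.api.types import BaseCType

ArrayRefCType(OptionalCType(BaseCType(tensorT))).cpp_type()
# -> 'torch::executor::ArrayRef<torch::executor::optional<torch::executor::Tensor>>'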
torchgen/executorch/api/unboxing.py (deleted)
@@ -1,218 +0,0 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, TYPE_CHECKING

from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    ListType,
    NativeFunction,
    OptionalType,
    Type,
)


if TYPE_CHECKING:
    from collections.abc import Sequence

    from torchgen.api.types import Binding, CType, NamedCType


connector = "\n\t"


# Return unboxing function name for a NativeFunction
def name(f: NativeFunction) -> str:
    return f.func.name.unambiguous_name()


@dataclass(frozen=True)
class Unboxing:
    """
    Takes a sequence of Bindings and unboxes EValues into those Bindings. Returns generated code that performs the unboxing.
    A sample of generated code:
    // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
    void mul_out(EValue** stack) {
        EValue& self = *stack[0];
        EValue& other = *stack[1];
        EValue& out = *stack[2];
        const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
        const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
        torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();

        EXECUTORCH_SCOPE_PROF("native_call_mul.out");
        torch::executor::mul_outf(self_base, other_base, out_base);
    }
    """

    # This is a callable that converts a JIT argument into its C++ type.
    # Translates (type, mutability, binds) to NamedCType. E.g., torchgen.api.cpp.argumenttype_type.
    argument_type_gen: Callable[
        ...,
        NamedCType,
    ]

    # Convert all the arguments in a NativeFunction to C++ code
    def convert_arguments(
        self, args: Sequence[Binding]
    ) -> tuple[list[Binding], list[str]]:
        code_list = [f"EValue& {args[i].name} = *stack[{i}];" for i in range(len(args))]
        binding_list = []
        for arg in args:
            # expecting only Argument
            if not isinstance(arg.argument, Argument):
                raise Exception(  # noqa: TRY002
                    f"Unexpected argument type, expecting `Argument` but got {arg}"
                )
            argument: Argument = arg.argument
            unboxed_name, _, code, decl = self.argumenttype_evalue_convert(
                argument.type, argument.name, mutable=argument.is_write
            )
            code_list.extend(decl)
            code_list.extend(code)
            binding_list.append(arg.with_name(unboxed_name))
        return binding_list, code_list

    def argumenttype_evalue_convert(
        self, t: Type, arg_name: str, *, mutable: bool = False
    ) -> tuple[str, CType, list[str], list[str]]:
        """
        Takes in the type, name and mutability corresponding to an argument, and generates a tuple of:
        (1) the C++ code necessary to unbox the argument
        (2) A Binding corresponding to the newly created unboxed variable, including variable name and its CType
        :param t: a `Type` of an argument
        :param arg_name: argument name
        :param mutable: boolean for whether this argument type is mutable
        :return: unboxed result
        """
        ctype = self.argument_type_gen(t, mutable=mutable, binds=arg_name).type

        if isinstance(t, BaseType):
            out_name = f"{arg_name}_base"
            code, decl = self._gen_code_base_type(
                arg_name=arg_name, out_name=out_name, ctype=ctype
            )
        elif isinstance(t, OptionalType):
            out_name = f"{arg_name}_opt_out"
            code, decl = self._gen_code_optional_type(
                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
            )
        elif isinstance(t, ListType):
            out_name = f"{arg_name}_list_out"
            code, decl = self._gen_code_list_type(
                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
            )
        else:
            raise Exception(  # noqa: TRY002
                f"Cannot handle type {t}. arg_name: {arg_name}"
            )
        return out_name, ctype, code, decl

    def _gen_code_base_type(
        self, arg_name: str, out_name: str, ctype: CType
    ) -> tuple[list[str], list[str]]:
        return [
            f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
        ], []

    def _gen_code_optional_type(
        self, arg_name: str, out_name: str, t: OptionalType, ctype: CType
    ) -> tuple[list[str], list[str]]:
        in_name = f"{arg_name}_opt_in"
        res_name, base_type, res_code, decl = self.argumenttype_evalue_convert(
            t.elem, in_name
        )
        return (
            f"""
auto {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
            """.split("\n"),
            decl,
        )

    def _gen_code_list_type(
        self, arg_name: str, out_name: str, t: ListType, ctype: CType
    ) -> tuple[list[str], list[str]]:
        in_name = f"{arg_name}_list_in"
        elem_name = f"{arg_name}_elem"
        code = []
        res_name, res_ctype, res_code, decl = self.argumenttype_evalue_convert(
            t.elem, elem_name
        )

        if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.Tensor:
            code.extend(
                f"""
auto {out_name} = {arg_name}.toTensorList();
                """.split("\n")
            )
        elif isinstance(t.elem, BaseType) and (
            t.elem.name == BaseTy.int or t.elem.name == BaseTy.SymInt
        ):
            code.extend(
                f"""
auto {out_name} = {arg_name}.toIntList();
                """.split("\n")
            )
        elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.float:
            code.extend(
                f"""
auto {out_name} = {arg_name}.toDoubleList();
                """.split("\n")
            )
        elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool:
            # handle list type with size, e.g., bool[4]
            code.extend(
                f"""
#ifdef USE_ATEN_LIB
std::array<bool, {t.size}> {out_name};
auto {in_name} = {arg_name}.toBoolList();
size_t _i = 0;
for (auto {elem_name}: {in_name}) {{
    {out_name}[_i++] = {elem_name};
}}
#else
auto {out_name} = {arg_name}.toBoolList();
#endif
                """.split("\n")
            )
        # pytorch codegen:
        # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
        elif (
            isinstance(t.elem, OptionalType)
            and isinstance(t.elem.elem, BaseType)
            and t.elem.elem.name == BaseTy.Tensor
        ):
            code.extend(
                f"""
#ifdef USE_ATEN_LIB
auto {in_name} = {arg_name}.toListOptionalTensor();
c10::List<::std::optional<at::Tensor>> {out_name};
for (auto {elem_name}: {in_name}) {{
    {out_name}.push_back({elem_name});
}}
#else
auto {out_name} = {arg_name}.toListOptionalTensor();
#endif
                """.split("\n")
            )
        else:
            # use ArrayRef as default.
            vec_name = arg_name + "_vec"
            # need to bring vector instantiation out of scope so that ArrayRef has valid data
            decl.append(
                f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};"
            )
            code.extend(
                f"""
for (EValue {elem_name}: {in_name}) {{
    {connector.join(res_code)}
    {vec_name}.push_back({res_name});
}}
{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
                """.split("\n")
            )
        return code, decl
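For context, a minimal sketch of how this class is wired up. The argument_type_gen hook is satisfied by et_cpp.argumenttype_type above; the usage shape is an assumption based on the dataclass fields, not a quote from the deleted caller:

from torchgen.executorch.api.et_cpp import argumenttype_type

unboxing = Unboxing(argument_type_gen=argumenttype_type)
# Given the Bindings of a signature, this emits one "EValue& x = *stack[i];"
# line per argument plus the .to<T>() conversions shown in the docstring:
# bindings, code_lines = unboxing.convert_arguments(sig.arguments())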
torchgen/executorch/model.py (deleted)
@@ -1,220 +0,0 @@
# Represents all kernels used by an Executorch model.
# It maintains a dict[OperatorName, dict[ETKernelKey, BackendMetadata]] structure.

from __future__ import annotations

import itertools
from collections import defaultdict, namedtuple
from dataclasses import dataclass
from enum import IntEnum

from typing_extensions import assert_never

from torchgen.model import (
    BackendIndex,
    BackendMetadata,
    DispatchKey,
    NativeFunction,
    NativeFunctionsGroup,
    OperatorName,
)


KERNEL_KEY_VERSION = 1


# TODO: Duplicated Subset from codegen.tool.gen_oplist, remove declaration in codegen
class ScalarType(IntEnum):
    Byte = 0
    Char = 1
    Short = 2
    Int = 3
    Long = 4
    Float = 6
    Double = 7
    Bool = 11


ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "kernel_index"])


@dataclass(frozen=True)
class ETKernelKeyOpArgMeta:
    arg_name: str
    dtype: str
    # The order of the dimensions if entry is a Tensor
    dim_order: tuple[int, ...]

    def to_native_string(self) -> str:
        dtype_str = ScalarType[self.dtype].value
        dim_str = str(self.dim_order)[1:-1].replace(" ", "")
        return f"{dtype_str};{dim_str}"


@dataclass(frozen=True)
class ETKernelKey:
    # Field undefined is default = True
    arg_meta: tuple[ETKernelKeyOpArgMeta, ...] = ()

    # Indicator for this kernel being used as a catch all
    default: bool = False

    version: int = KERNEL_KEY_VERSION

    @staticmethod
    def gen_from_yaml(
        args: dict[str, tuple[str, str]],
        type_alias_map: dict[str, list[str]],  # TODO: Support unwrapped str val
        dim_order_alias_map: dict[str, list[int]],
    ) -> list[ETKernelKey]:
        """Generate ETKernelKeys from arg kernel specs

        Multiple ETKernelKeys are returned due to dtype permutations from utilizing
        type_alias_map (actualizing each potential type permutation as a KernelKey)

        Args:
            args: Mapping from argument name to kernel specs
                Kernel specs are a tuple of (dtype, dim_order).
                Currently tuple entries must be aliased via the alias map arguments
            type_alias_map: Mapping from type alias to potential type enums
                i.e { T0 : [Double, Int] } means T0 can be either Double or Int
                Used for lookup by args
            dim_order_alias_map: Mapping from alias to a list of dimension orders
                Used for lookup by args
        """
        # Cast dim order values to int
        dim_order_alias_map = {
            k: [int(alias) for alias in v] for k, v in dim_order_alias_map.items()
        }
        kernel_keys = []

        # Get all used Dtype Alias
        dtype_alias_used = set()
        for type_alias, dim_order in args.values():
            # Enforce usage of alias initially
            # TODO: Support inlined arguments
            assert type_alias in type_alias_map, "Undefined type alias: " + str(
                type_alias
            )
            assert dim_order in dim_order_alias_map, (
                f"Undefined dim_order alias: {dim_order}"
            )
            dtype_alias_used.add(type_alias)

        # Generate all permutations of dtype alias values
        alias_dtypes = [
            [(alias, dtype) for dtype in type_alias_map[alias]]
            for alias in dtype_alias_used
        ]
        alias_permutations = [
            dict(permutation) for permutation in list(itertools.product(*alias_dtypes))
        ]

        # Using each alias value permutation, generate kernel keys
        op_arg_cache = {}
        for permutation in alias_permutations:
            arg_list = []
            for arg_name, arg_spec in args.items():
                dtype = permutation[arg_spec[0]]
                dim_order = dim_order_alias_map[arg_spec[1]]  # type: ignore[assignment]
                if (
                    cache_key := (arg_name, dtype, tuple(dim_order))
                ) not in op_arg_cache:
                    op_arg_cache[cache_key] = ETKernelKeyOpArgMeta(*cache_key)  # type: ignore[arg-type]

                arg_list.append(op_arg_cache[cache_key])
            kernel_keys.append(ETKernelKey(tuple(arg_list)))

        return kernel_keys

    def to_native_string(self) -> str:
        if self.default:
            return "default"
        return (
            "v"
            + str(KERNEL_KEY_VERSION)
            + "/"
            + "|".join([arg.to_native_string() for arg in self.arg_meta])
        )


@dataclass(frozen=True)
class ETKernelIndex:
    index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]]

    def has_kernels(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
        m = self.get_kernels(g)
        return m is not None

    def get_kernels(
        self, g: NativeFunction | NativeFunctionsGroup
    ) -> dict[ETKernelKey, BackendMetadata]:
        if isinstance(g, NativeFunction):
            f = g
        elif isinstance(g, NativeFunctionsGroup):
            f = g.functional
        else:
            assert_never(g)
        if f.func.name not in self.index:
            return {}
        return self.index[f.func.name]

    @staticmethod
    def grow_from_backend_indices(
        kernel_index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]],
        backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
    ) -> None:
        for dk in backend_indices:
            index = backend_indices[dk]
            for op, backend_metadata in index.items():
                if op in kernel_index:
                    kernel_index[op][ETKernelKey(default=True)] = backend_metadata
                else:
                    kernel_index[op] = {ETKernelKey(default=True): backend_metadata}

    @staticmethod
    def from_backend_indices(
        backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
    ) -> ETKernelIndex:
        kernel_index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] = (
            defaultdict(dict)
        )
        ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices)
        return ETKernelIndex(kernel_index)

    def grow(
        self, backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
    ) -> ETKernelIndex:
        ETKernelIndex.grow_from_backend_indices(self.index, backend_indices)
        return self

    def _to_backend_index(self) -> BackendIndex:
        """
        WARNING: this will be deprecated once all the codegen places know how to handle ETKernelIndex.
        """
        index: dict[OperatorName, BackendMetadata] = {}
        for op in self.index:
            kernel_dict = self.index[op]
            assert len(kernel_dict.values()) == 1, (
                f"Can't convert ETKernelIndex to BackendIndex because {op} has more than one kernel. Got {kernel_dict}"
            )
            index[op] = kernel_dict.get(
                ETKernelKey(default=True),
                BackendMetadata(kernel="", structured=False, cpp_namespace=""),
            )
        return BackendIndex(
            dispatch_key=DispatchKey.CPU,
            use_out_as_primary=False,
            device_guard=False,
            external=False,
            index=index,
        )

    # Note: a duplicate ETKernelKey from index_b will clobber the metadata from index_a
    @staticmethod
    def merge_indices(index_a: ETKernelIndex, index_b: ETKernelIndex) -> ETKernelIndex:
        combined = defaultdict(dict, index_a.index.copy())

        for op, entry in index_b.index.items():
            for key, metadata in entry.items():
                combined[op][key] = metadata

        return ETKernelIndex(combined)
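To make the key format concrete, a hypothetical spec and the keys it expands to (dtype values follow the ScalarType enum above):

# One argument aliased over two dtypes expands into two kernel keys:
keys = ETKernelKey.gen_from_yaml(
    args={"self": ("T0", "D0")},
    type_alias_map={"T0": ["Float", "Double"]},
    dim_order_alias_map={"D0": [0, 1, 2, 3]},
)
for k in keys:
    print(k.to_native_string())
# v1/6;0,1,2,3   (Float == 6 in the ScalarType enum)
# v1/7;0,1,2,3   (Double == 7)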
torchgen/executorch/parse.py (deleted)
@@ -1,153 +0,0 @@
from __future__ import annotations

from collections import defaultdict, namedtuple
from typing import Any

import yaml

from torchgen.executorch.model import ETKernelIndex, ETKernelKey
from torchgen.gen import LineLoader, parse_native_yaml
from torchgen.model import (
    BackendMetadata,
    DispatchKey,
    FunctionSchema,
    NativeFunction,
    OperatorName,
)
from torchgen.utils import NamespaceHelper


# Parse native_functions.yaml into a sequence of NativeFunctions and ET Backend Indices.
ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "et_kernel_indices"])

# Fields in native_functions.yaml used to determine which kernels should be used
ET_FIELDS = ["kernels", "type_alias", "dim_order_alias"]


def parse_from_yaml(ei: dict[str, object]) -> dict[ETKernelKey, BackendMetadata]:
    """Given a loaded yaml representing kernel assignment information, extract the
    mapping from `kernel keys` to `BackendMetadata` (the latter representing the kernel instance)

    Args:
        ei: Dict keys {kernels, type_alias, dim_order_alias}
            See ETKernelKey for description of arguments
    """
    e = ei.copy()
    if (kernels := e.pop("kernels", None)) is None:
        return {}

    type_alias: dict[str, list[str]] = e.pop("type_alias", {})  # type: ignore[assignment]
    dim_order_alias: dict[str, list[str]] = e.pop("dim_order_alias", {})  # type: ignore[assignment]
    dim_order_alias.pop("__line__", None)

    kernel_mapping: dict[ETKernelKey, BackendMetadata] = {}

    for entry in kernels:  # type: ignore[attr-defined]
        arg_meta = entry.get("arg_meta")
        if arg_meta is not None:
            arg_meta.pop("__line__")

        kernel_name = entry.get("kernel_name")
        namespace_helper = NamespaceHelper.from_namespaced_entity(
            kernel_name, max_level=3
        )
        kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
        backend_metadata = BackendMetadata(
            kernel=namespace_helper.entity_name,
            structured=False,
            cpp_namespace=(kernel_namespace + "::native"),
        )

        kernel_keys = (
            [ETKernelKey((), default=True)]
            if arg_meta is None
            else ETKernelKey.gen_from_yaml(arg_meta, type_alias, dim_order_alias)  # type: ignore[arg-type]
        )

        for kernel_key in kernel_keys:
            assert kernel_key not in kernel_mapping, (
                "Duplicate kernel key: " + str(kernel_key) + " " + str(e)
            )
            kernel_mapping[kernel_key] = backend_metadata

    return kernel_mapping


def parse_et_yaml_struct(es: object) -> ETKernelIndex:
    """Given a loaded yaml representing a list of operators, for each op extract the mapping
    of `kernel keys` to `BackendMetadata` (the latter representing the kernel instance
    that should be used by the kernel key).
    """
    indices: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] = {}
    for ei in es:  # type: ignore[attr-defined]
        e = ei.copy()

        funcs = e.pop("func")
        assert isinstance(funcs, str), f"not a str: {funcs}"
        namespace_helper = NamespaceHelper.from_namespaced_entity(
            namespaced_entity=funcs, max_level=1
        )
        opname = FunctionSchema.parse(namespace_helper.entity_name).name

        assert opname not in indices, f"Duplicate func found in yaml: {opname} already"

        if len(index := parse_from_yaml(e)) != 0:
            indices[opname] = index

    return ETKernelIndex(indices)


def extract_kernel_fields(es: object) -> dict[OperatorName, dict[str, Any]]:
    """Given a loaded yaml representing a list of operators, extract the
    kernel key related fields indexed by the operator name.
    """
    fields: dict[OperatorName, dict[str, Any]] = defaultdict(dict)
    for ei in es:  # type: ignore[attr-defined]
        funcs = ei.get("func")
        assert isinstance(funcs, str), f"not a str: {funcs}"
        namespace_helper = NamespaceHelper.from_namespaced_entity(
            namespaced_entity=funcs, max_level=1
        )
        opname = FunctionSchema.parse(namespace_helper.entity_name).name

        for field in ET_FIELDS:
            if (value := ei.get(field)) is not None:
                fields[opname][field] = value

    return fields


def parse_et_yaml(
    path: str,
    tags_yaml_path: str,
    ignore_keys: set[DispatchKey] | None = None,
    skip_native_fns_gen: bool = False,
) -> tuple[list[NativeFunction], dict[OperatorName, dict[str, Any]]]:
    """Parse native_functions.yaml into NativeFunctions and an Operator Indexed Dict
    of fields to persist from native_functions.yaml to functions.yaml
    """
    with open(path) as f:
        es = yaml.load(f, Loader=LineLoader)

    et_kernel = extract_kernel_fields(es)

    # Remove ET specific fields from entries for BC compatibility
    strip_et_fields(es)

    native_yaml = parse_native_yaml(
        path,
        tags_yaml_path,
        ignore_keys,
        skip_native_fns_gen=skip_native_fns_gen,
        loaded_yaml=es,
    )
    return native_yaml.native_functions, et_kernel


def strip_et_fields(es: object) -> None:
    """Given a loaded yaml representing a list of operators,
    remove ET specific fields from every entry for BC compatibility
    """
    for entry in es:  # type: ignore[attr-defined]
        for field in ET_FIELDS:
            entry.pop(field, None)
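For reference, a hypothetical sketch of the yaml shape these parsers consumed. The op and kernel spec are made up from the field names above, and LineLoader is used (as in parse_et_yaml) because the code pops the "__line__" keys that loader injects:

import yaml
from torchgen.gen import LineLoader

es = yaml.load(
    """
- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
  kernels:
    - arg_meta:
        self: [T0, D0]
        other: [T0, D0]
        out: [T0, D0]
      kernel_name: torch::executor::add_out
  type_alias:
    T0: [Float, Double]
  dim_order_alias:
    D0: [0, 1, 2, 3]
""",
    Loader=LineLoader,
)
index = parse_et_yaml_struct(es)  # one ETKernelKey per dtype permutation of T0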
File diff suppressed because it is too large