mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This reverts commit ffb979032dc149b4c895526fe5b92d713ed7b1e1. Reverted https://github.com/pytorch/pytorch/pull/140225 on behalf of https://github.com/kit1980 due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/140225#issuecomment-2469312229))
150 lines
5.3 KiB
Python
150 lines
5.3 KiB
Python
from __future__ import annotations
|
|
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from typing import Sequence, TYPE_CHECKING
|
|
|
|
from torchgen import dest
|
|
|
|
|
|
# disable import sorting to avoid circular dependency.
|
|
from torchgen.api.types import DispatcherSignature # usort: skip
|
|
from torchgen.context import method_with_native_function
|
|
from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant
|
|
from torchgen.utils import concatMap, Target
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from torchgen.executorch.model import ETKernelIndex
|
|
from torchgen.selective_build.selector import SelectiveBuilder
|
|
|
|
|
|
# Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom operators. This will be used at
# model authoring side.
@dataclass(frozen=True)
class ComputeNativeFunctionStub:
    """Emit a placeholder C++ kernel body for a custom operator.

    The stub carries the correct CPU dispatcher signature but performs no
    real computation: it returns an out argument, a matching input, an
    empty tensor, or a tuple thereof, depending on the op's return schema.
    """

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        # Only function-variant ops receive a stub.
        if Variant.function not in f.variants:
            return None

        sig = DispatcherSignature.from_schema(
            f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False
        )
        assert sig is not None

        returns = f.func.returns
        out_args = f.func.arguments.out

        if not returns:
            ret_name = ""
        elif len(returns) == 1:
            if out_args:
                # Single out-variant return: hand the out argument back.
                ret_name = out_args[0].name
            else:
                # Otherwise prefer echoing an input whose type matches the return.
                matches = [
                    a.name
                    for a in f.func.arguments.flat_non_out
                    if a.type == returns[0].type
                ]
                ret_name = matches[0] if matches else ""
            if not ret_name:
                # if return type is tensor
                if returns[0].type == BaseType(BaseTy.Tensor):
                    # Returns an empty tensor
                    ret_name = "at::Tensor()"
                else:
                    raise Exception(  # noqa: TRY002
                        f"Can't handle this return type {f.func}"
                    )
        elif len(out_args) == len(returns):
            # Returns a tuple of out arguments
            comma = ", "
            tuple_types = comma.join(["at::Tensor &"] * len(returns))
            tuple_values = comma.join(r.name for r in out_args)
            ret_name = f"""::std::tuple<{tuple_types}>(
                {tuple_values}
            )"""
        else:
            assert all(
                a.type == BaseType(BaseTy.Tensor) for a in returns
            ), f"Only support tensor returns but got {returns}"
            # Returns a tuple of empty tensors
            comma = ", "
            tuple_types = comma.join(["at::Tensor"] * len(returns))
            tuple_values = comma.join("at::Tensor()" for _ in returns)
            ret_name = f"""::std::tuple<{tuple_types}>(
                {tuple_values}
            )"""

        ret_str = f"return {ret_name};" if returns else ""
        return f"""
{sig.defn()} {{
    {ret_str}
}}
"""
|
|
|
|
|
|
def gen_custom_ops_registration(
    *,
    native_functions: Sequence[NativeFunction],
    selector: SelectiveBuilder,
    kernel_index: ETKernelIndex,
    rocm: bool,
) -> tuple[str, str]:
    """
    Generate custom ops registration code for dest.RegisterDispatchKey.

    :param native_functions: a sequence of `NativeFunction`
    :param selector: for selective build.
    :param kernel_index: kernels for all the ops.
    :param rocm: bool for dest.RegisterDispatchKey.
    :return: generated C++ code to register custom operators into PyTorch
    """

    # convert kernel index to BackendIndex. This is because we can't handle ETKernelIndex yet.
    # TODO larryliu: evaluate if this code is still needed. If yes let it handle ETKernelIndex.

    dispatch_key = DispatchKey.CPU
    backend_index = kernel_index._to_backend_index()

    def _gen(target: Target, fns: Sequence[NativeFunction]) -> str:
        # Shared driver for both code-gen targets; previously this
        # RegisterDispatchKey construction was duplicated verbatim.
        return "\n".join(
            concatMap(
                dest.RegisterDispatchKey(
                    backend_index,
                    target,
                    selector,
                    rocm=rocm,
                    symint=False,
                    class_method_name=None,
                    skip_dispatcher_op_registration=False,
                ),
                fns,
            )
        )

    # Group ops by namespace: each namespace gets its own TORCH_LIBRARY_IMPL.
    ns_grouped_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
    for native_function in native_functions:
        ns_grouped_native_functions[native_function.namespace].append(native_function)

    static_init_dispatch_registrations = ""
    for namespace, functions in ns_grouped_native_functions.items():
        if not functions:
            continue
        dispatch_registrations_body = _gen(Target.REGISTRATION, functions)
        static_init_dispatch_registrations += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
{dispatch_registrations_body}
}};"""

    # Anonymous definitions are emitted once for ALL functions, not per-namespace.
    anonymous_definition = _gen(Target.ANONYMOUS_DEFINITION, native_functions)
    return anonymous_definition, static_init_dispatch_registrations