# Generates Python bindings for ATen functions
#
# The bindings are generated as methods on python_variable or functions on the
# torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse,
# or torch._C._special objects.
#

# Code tries to stick to the following rules:
#
# - templates should be colocated with the functions that use them.
#   no templates are currently shared between functions, but if that
#   happens, maybe put the template with the first one
#
# - don't use environment dictionaries when calling template.substitute().
#   pass named arguments directly for everything, otherwise it's much too
#   hard to track what's actually being used and by whom
#
# - colocate any new hacks/adjustments with existing ones of the same kind.
#   ideally in a data structure rather than code if possible. See e.g.
#   SCHEMA_DEFAULT_CONVERSION_HACKS, etc.
#
# - similarly, conversions from one format to another should ideally happen
#   all at once in a single place.
#
# - no nontrivial nested functions. couple-liners are ok but please no more.
#   especially avoid functions that read/write outer variables defined far away.
#
# - raise RuntimeError instead of asserting, and put as much
#   information as is available into the message. I.e. no need to
#   plumb in new params whose only purpose is to fill out an error
#   message, but use what's there
#
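#
# As an illustration of the substitute() rule above, prefer named arguments
# over an environment dict (a sketch using the PY_VARIABLE_CASE template
# defined later in this file):
#
#     PY_VARIABLE_CASE.substitute({"overload_index": i, "body": body})  # avoid
#     PY_VARIABLE_CASE.substitute(overload_index=i, body=body)          # prefer
#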

import itertools
import re
from collections import defaultdict

from typing import Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple

import yaml
from torchgen.api import cpp
from torchgen.api.python import (
    arg_parser_output_exprs,
    cpp_dispatch_exprs,
    cpp_dispatch_target,
    dispatch_lambda_args,
    dispatch_lambda_exprs,
    dispatch_lambda_return_str,
    has_tensor_options,
    namedtuple_fieldnames,
    PythonSignature,
    PythonSignatureDeprecated,
    PythonSignatureGroup,
    PythonSignatureNativeFunctionPair,
    signature,
    signature_from_schema,
)

from torchgen.code_template import CodeTemplate
from torchgen.context import with_native_function
from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml
from torchgen.model import (
    Argument,
    BaseOperatorName,
    FunctionSchema,
    NativeFunction,
    Type,
    Variant,
)
from torchgen.utils import FileManager, split_name_params, YamlLoader

from .gen_trace_type import should_trace

#
# declarations blocklist
# We skip codegen for these functions, for various reasons.
# Future PRs will categorize this list and eliminate or hoist
# them out of eager-only codegen.
# See https://github.com/pytorch/pytorch/issues/30788
#

# These functions require manual Python bindings or are not exposed to Python
_SKIP_PYTHON_BINDINGS = [
    "alias",
    "contiguous",
    "is_cuda",
    "is_sparse",
    "is_sparse_csr",
    "size",
    "stride",
    ".*_backward",
    ".*_backward_(out|input|weight|bias)",
    ".*_forward",
    ".*_forward_out",
    ".*_jvp",
    "_unsafe_view",
    "tensor",
    "_?sparse_(coo|compressed|csr|csc|bsr|bsc)_tensor.*",
    "_range.*",
    "_sparse_add_out",
    "_sparse_div.*",
    "_sparse_mul.*",
    "_sparse_sub.*",
    "_sparse_dense_add_out",
    "index",
    "index_out",
    "unique_dim_consecutive",
    "_cumsum.*",
    "_cumprod.*",
    "_sum.*",
    "_prod.*",
    "_th_.*",
    "_thnn_.*",
    "range.*",
    "_solve.*",
    "_inverse.*",
    "_cholesky.*",
    "_triangular_solve.*",
    "_qr.*",
    "_svd.*",
    "slice",
    "item",
    "_local_scalar_dense",
    "to",
    "_to_copy",
    "_to_copy_out",
    "_reshape_copy",
    "_reshape_copy_out",
    "copy_sparse_to_sparse_",
    "copy_",
    "numpy_T",
    "matrix_H",
    "mT",
    "mH",  # these need to be attributes in Python, not functions
    "nonzero(_(out|numpy))?",
    "set_data",
    ".*_overrideable",  # overrideable functions for backend extension
    "data",
    "is_leaf",
    "output_nr",
    "_version",
    "requires_grad_",
    "retains_grad",
    "set_",
    "_fw_primal",
    "fake_quantize_per_tensor_affine_cachemask",
    "fake_quantize_per_channel_affine_cachemask",
    "_new_zeros_with_same_feature_meta",
    "_has_same_storage_numel",  # used for forward AD internals
    "_reshape_alias",
    "replace_",  # only used by the functionalization pass, doesn't need to be exposed to python
    "copy",  # only used by the functionalization pass
    "fill.Tensor",  # only used by the functionalization pass
    "fill.Scalar",  # only used by the functionalization pass
    "lift.*",
    "normal_functional",  # only used by the functionalization pass
    "_nested_view_from_buffer",  # View-only version of _nested_from_buffer. This will force users to only use the "safe" version.
    "_nested_view_from_buffer_copy",
    "_nested_view_from_buffer_copy_out",
]

SKIP_PYTHON_BINDINGS = [
    re.compile(rf"^{pattern}$") for pattern in _SKIP_PYTHON_BINDINGS
]
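
# Note that each pattern above is compiled with ^...$ anchors, so an entry must
# match the whole operator name, not just a substring. Illustrative checks
# (hypothetical, not part of codegen):
#
#     any(p.match("max_pool2d_backward") for p in SKIP_PYTHON_BINDINGS)  # True, via ".*_backward"
#     any(p.match("backward_hook") for p in SKIP_PYTHON_BINDINGS)        # False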

# These function signatures are not exposed to Python. Note that this signature
# list does not support regex.
SKIP_PYTHON_BINDINGS_SIGNATURES = [
    "add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
    "add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)",
    "sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
    "sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)",
    "mul.Scalar(Tensor self, Scalar other) -> Tensor",
    "mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
    "div.Scalar(Tensor self, Scalar other) -> Tensor",
    "div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
]


@with_native_function
def should_generate_py_binding(f: NativeFunction) -> bool:
    # NativeFunctions that are entirely code-generated should not get python bindings
    # because these codegen implementations are often inefficient. A handful of
    # view_copy style ops were exposed accidentally when they were handwritten, and
    # now that we are moving them to codegen, for BC reasons we need to keep them
    # exposed in python.
    if "generated" in f.tags and "view_copy" not in f.tags:
        return False

    name = cpp.name(f.func)
    for skip_regex in SKIP_PYTHON_BINDINGS:
        if skip_regex.match(name):
            return False

    signature = str(f.func)
    for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES:
        if pattern == signature:
            return False
    return True


def get_pycname(name: BaseOperatorName) -> str:
    return f"THPVariable_{name}"


def is_noarg(overloads: Sequence[PythonSignatureNativeFunctionPair]) -> bool:
    return len(overloads) == 1 and overloads[0].signature.arguments_count() == 0


def is_py_variable_method(f: NativeFunction) -> bool:
    return f.python_module is None and Variant.method in f.variants


def is_py_torch_function(f: NativeFunction) -> bool:
    return f.python_module is None and Variant.function in f.variants


def is_py_nn_function(f: NativeFunction) -> bool:
    return f.python_module == "nn"


def is_py_fft_function(f: NativeFunction) -> bool:
    return f.python_module == "fft"


def is_py_linalg_function(f: NativeFunction) -> bool:
    return f.python_module == "linalg"


def is_py_nested_function(f: NativeFunction) -> bool:
    return f.python_module == "nested"


def is_py_sparse_function(f: NativeFunction) -> bool:
    return f.python_module == "sparse"


def is_py_special_function(f: NativeFunction) -> bool:
    return f.python_module == "special"


def is_py_dist_function(f: NativeFunction) -> bool:
    return f.python_module == "dist"


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Main Function
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def gen(
    out: str,
    native_yaml_path: str,
    tags_yaml_path: str,
    deprecated_yaml_path: str,
    template_path: str,
    *,
    symint: bool = True,
) -> None:
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    native_functions = parse_native_yaml(
        native_yaml_path, tags_yaml_path
    ).native_functions
    native_functions = list(filter(should_generate_py_binding, native_functions))

    methods = load_signatures(native_functions, deprecated_yaml_path, method=True)
    create_python_bindings(
        fm,
        methods,
        is_py_variable_method,
        None,
        "python_variable_methods.cpp",
        method=True,
        symint=symint,
    )

    # NOTE: num_shards here must be synced with gatherTorchFunctions in
    # torch/csrc/autograd/python_torch_functions_manual.cpp
    functions = load_signatures(native_functions, deprecated_yaml_path, method=False)
    create_python_bindings_sharded(
        fm,
        functions,
        is_py_torch_function,
        "torch",
        "python_torch_functions.cpp",
        method=False,
        num_shards=3,
        symint=symint,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_nn_function,
        "torch.nn",
        "python_nn_functions.cpp",
        method=False,
        symint=symint,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_fft_function,
        "torch.fft",
        "python_fft_functions.cpp",
        method=False,
        symint=symint,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_linalg_function,
        "torch.linalg",
        "python_linalg_functions.cpp",
        method=False,
        symint=symint,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_nested_function,
        "torch.nested",
        "python_nested_functions.cpp",
        method=False,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_sparse_function,
        "torch.sparse",
        "python_sparse_functions.cpp",
        method=False,
        symint=symint,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_special_function,
        "torch.special",
        "python_special_functions.cpp",
        method=False,
        symint=symint,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_dist_function,
        "torch.distributed.functional",
        "python_dist_functions.cpp",
        method=False,
    )

    # Currently, we only use `functions` to generate `return_types` bindings.
    # All methods which return a namedtuple have a function variant at this point.
    # If any method-only operator returning a namedtuple is added in the future,
    # we will have to address that.
    create_python_return_type_bindings(
        fm, functions, lambda fn: True, "python_return_types.cpp"
    )

    valid_tags = parse_tags_yaml(tags_yaml_path)

    def gen_tags_enum() -> Dict[str, str]:
        return {
            "enum_of_valid_tags": (
                "".join([f'\n.value("{tag}", at::Tag::{tag})' for tag in valid_tags])
            )
        }

    fm.write("python_enum_tag.cpp", gen_tags_enum)


def group_filter_overloads(
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
) -> Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]:
    grouped: Dict[
        BaseOperatorName, List[PythonSignatureNativeFunctionPair]
    ] = defaultdict(list)
    for pair in pairs:
        if pred(pair.function):
            grouped[pair.function.func.name.name].append(pair)
    return grouped


def create_python_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: Optional[str],
    filename: str,
    *,
    method: bool,
    symint: bool = True,
) -> None:
    """Generates Python bindings to ATen functions"""
    py_methods: List[str] = []
    ops_headers: List[str] = []
    py_method_defs: List[str] = []
    py_forwards: List[str] = []

    grouped = group_filter_overloads(pairs, pred)

    for name in sorted(grouped.keys(), key=lambda x: str(x)):
        overloads = grouped[name]
        py_methods.append(
            method_impl(name, module, overloads, method=method, symint=symint)
        )
        py_method_defs.append(method_def(name, module, overloads, method=method))
        py_forwards.extend(forward_decls(name, overloads, method=method))
        ops_headers.append(f"#include <ATen/ops/{name.base}.h>")

    fm.write_with_template(
        filename,
        filename,
        lambda: {
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/{filename}",
            "ops_headers": ops_headers,
            "py_forwards": py_forwards,
            "py_methods": py_methods,
            "py_method_defs": py_method_defs,
        },
    )


def create_python_return_type_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    filename: str,
) -> None:
    """
    Generate the functions in `python_return_types.cpp` that initialize and
    return the namedtuple for each native function that returns one, plus the
    relevant entries for the map in the same file.
    """
    py_return_types_definition: List[str] = []
    py_return_types_map: List[str] = []

    grouped = group_filter_overloads(pairs, pred)

    for name in sorted(grouped.keys(), key=lambda x: str(x)):
        overloads = grouped[name]
        definitions, map_entries = generate_return_type_definition_and_map_entry(
            overloads
        )
        py_return_types_definition.append(
            "" if not definitions else "\n".join(definitions)
        )
        py_return_types_map.append("" if not map_entries else "\n".join(map_entries))

    fm.write_with_template(
        filename,
        filename,
        lambda: {
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/{filename}",
            "py_return_types": py_return_types_definition,
            "py_return_types_map": py_return_types_map,
        },
    )


def create_python_bindings_sharded(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: Optional[str],
    filename: str,
    *,
    method: bool,
    num_shards: int,
    symint: bool = True,
) -> None:
    """Generates Python bindings to ATen functions"""
    grouped = group_filter_overloads(pairs, pred)

    def key_func(
        kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
    ) -> str:
        return kv[0].base

    def env_func(
        kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
    ) -> Dict[str, List[str]]:
        name, fn_pairs = kv
        return {
            "ops_headers": [f"#include <ATen/ops/{name.base}.h>"],
            "py_forwards": list(forward_decls(name, fn_pairs, method=method)),
            "py_methods": [
                method_impl(name, module, fn_pairs, method=method, symint=symint)
            ],
            "py_method_defs": [method_def(name, module, fn_pairs, method=method)],
        }

    fm.write_sharded(
        filename,
        grouped.items(),
        base_env={
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/{filename}",
        },
        key_fn=key_func,
        env_callable=env_func,
        num_shards=num_shards,
        sharded_keys={"ops_headers", "py_forwards", "py_methods", "py_method_defs"},
    )


def load_signatures(
    native_functions: List[NativeFunction],
    deprecated_yaml_path: str,
    *,
    method: bool,
    skip_deprecated: bool = False,
    pyi: bool = False,
) -> Sequence[PythonSignatureNativeFunctionPair]:
    @with_native_function
    def gen_signature_pairs(f: NativeFunction) -> PythonSignatureNativeFunctionPair:
        return PythonSignatureNativeFunctionPair(
            signature=signature(f, method=method, pyi=pyi),
            function=f,
        )

    pairs = list(map(gen_signature_pairs, native_functions))
    deprecated = load_deprecated_signatures(
        pairs, deprecated_yaml_path, method=method, pyi=pyi
    )
    return pairs if skip_deprecated else pairs + deprecated
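
# Sketch of the resulting pairing (illustrative): for a native function such as
#     add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
# load_signatures(..., method=True) yields a PythonSignatureNativeFunctionPair
# whose .signature describes the Python-level overload
#     add(Tensor other, *, Scalar alpha=1)
# (self is implicit for methods) and whose .function is the underlying
# NativeFunction.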


def load_deprecated_signatures(
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    deprecated_yaml_path: str,
    *,
    method: bool,
    pyi: bool,
) -> List[PythonSignatureNativeFunctionPair]:
    # The deprecated.yaml doesn't have complete type information, so we need to
    # find and leverage the original ATen signature (to which it delegates the
    # call) to generate the full python signature.
    # We join the deprecated and the original signatures using type-only form.

    # group the original ATen signatures by name
    grouped: Dict[str, List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
    for pair in pairs:
        grouped[pair.signature.name].append(pair)

    # find matching original signatures for each deprecated signature
    results: List[PythonSignatureNativeFunctionPair] = []

    with open(deprecated_yaml_path, "r") as f:
        deprecated_defs = yaml.load(f, Loader=YamlLoader)

    for deprecated in deprecated_defs:
        schema = FunctionSchema.parse(deprecated["name"])
        aten_name, call_args = split_name_params(deprecated["aten"])
        is_out = aten_name.endswith("_out")
        if is_out:
            aten_name = aten_name.replace("_out", "")

        # HACK: these are fixed constants used to pass to the aten function.
        # The type must be known ahead of time
        known_constants = {
            "1": Type.parse("Scalar"),
        }
        schema_args_by_name = {a.name: a for a in schema.arguments.flat_all}
        for name in call_args:
            assert (
                name in schema_args_by_name or name in known_constants
            ), f"deprecation definition: Unrecognized value {name}"

        # Map deprecated signature arguments to their aten signature and test
        # if the types and alias annotations match.
        def is_schema_compatible(
            aten_schema: FunctionSchema,
        ) -> bool:
            arguments: Iterable[Argument]
            if is_out:
                arguments = itertools.chain(
                    aten_schema.arguments.out, aten_schema.arguments.flat_non_out
                )
            else:
                arguments = aten_schema.arguments.flat_all

            for i, arg in enumerate(arguments):
                if i < len(call_args):
                    arg_name = call_args[i]
                    if arg_name in known_constants:
                        schema_type = known_constants[arg_name]
                        schema_annotation = None
                    else:
                        schema_arg = schema_args_by_name[arg_name]
                        schema_type = schema_arg.type
                        schema_annotation = schema_arg.annotation

                    if schema_type != arg.type or schema_annotation != arg.annotation:
                        return False
                else:
                    if arg.default is None:
                        return False

            return len(schema.returns) == len(aten_schema.returns) and all(
                a == b for a, b in zip(schema.returns, aten_schema.returns)
            )

        any_schema_found = False
        for pair in grouped[aten_name]:
            if not is_schema_compatible(pair.function.func):
                continue
            any_schema_found = True

            python_sig = signature_from_schema(
                schema,
                category_override=pair.function.category_override,
                method=method,
                pyi=pyi,
            )

            results.append(
                PythonSignatureNativeFunctionPair(
                    signature=PythonSignatureDeprecated(
                        name=python_sig.name,
                        input_args=python_sig.input_args,
                        input_kwargs=python_sig.input_kwargs,
                        output_args=python_sig.output_args,
                        tensor_options_args=python_sig.tensor_options_args,
                        method=python_sig.method,
                        deprecated_schema=schema,
                        deprecated_args_exprs=tuple(call_args),
                        returns=python_sig.returns,
                    ),
                    function=pair.function,
                )
            )
        assert (
            any_schema_found
        ), f"No native function with name {aten_name} matched signature:\n {str(schema)}"

    return results
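
# For reference, entries in deprecated.yaml have roughly this shape (a sketch
# inferred from the keys read above; the exact signature is illustrative):
#
#     - name: add(Tensor self, Scalar alpha, Tensor other) -> Tensor
#       aten: add(self, other, alpha)
#
# "name" is the deprecated Python-facing schema, and "aten" names the ATen
# function (with its fixed argument order) that the binding delegates to.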


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Named Tuple Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


@with_native_function
def gen_namedtuple_typename_key(f: NativeFunction) -> str:
    name = cpp.name(f.func)
    fieldnames = namedtuple_fieldnames(f.func.returns)
    return "_".join([name] + fieldnames)


def emit_namedtuple_call(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], Dict[str, str]]:
    """
    Generate the block of namedtuple typedef inits, and add typeref snippets
    to the declarations that use them.
    """
    typenames: Dict[
        str, str
    ] = {}  # map from unique name + field name lists to typedef name
    typedefs: List[str] = []  # typedef declarations and init code

    for overload in overloads:
        fieldnames = namedtuple_fieldnames(overload.function.func.returns)
        if not fieldnames:
            continue

        name = cpp.name(overload.function.func)  # use @with_native_function?
        tn_key = gen_namedtuple_typename_key(overload.function)
        typename = typenames.get(tn_key)
        if typename is None:
            typename = f'NamedTuple{"" if not typedefs else len(typedefs)}'
            typenames[tn_key] = typename
            typedefs.append(
                f"""\
static PyTypeObject* {typename} = get_namedtuple("{name}");"""
            )

    return typedefs, typenames


def generate_return_type_definition_and_map_entry(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], List[str]]:
    """
    Generate the function in `python_return_types.cpp` that initializes and
    returns the namedtuple for a native function which returns one, along with
    the relevant entry for the map in the same file.
    """
    typenames: Dict[
        str, str
    ] = {}  # map from unique name + field name lists to typedef name
    definitions: List[str] = []  # function definitions that register the typedefs
    map_entries: List[
        str
    ] = []  # C++ map entries of <function_name, function that creates its namedtuple>

    for overload in overloads:
        fieldnames = namedtuple_fieldnames(overload.function.func.returns)
        if not fieldnames:
            continue

        fields = ", ".join(f'{{"{fn}", ""}}' for fn in fieldnames)

        name = cpp.name(overload.function.func)  # use @with_native_function?
        tn_key = gen_namedtuple_typename_key(overload.function)
        typename = typenames.get(tn_key)

        if typename is None:
            typename = f'{name}NamedTuple{"" if not definitions else len(definitions)}'
            typenames[tn_key] = typename
            definitions.append(
                f"""\
PyTypeObject* get_{name}_namedtuple() {{
    static PyStructSequence_Field NamedTuple_fields[] = {{ {fields}, {{nullptr}} }};
    static PyTypeObject {typename};
    static bool is_initialized = false;
    static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, NamedTuple_fields, {len(fieldnames)} }};
    if (!is_initialized) {{
        PyStructSequence_InitType(&{typename}, &desc);
        {typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
        is_initialized = true;
    }}
    return &{typename};
}}
"""
            )
        map_entries.append(f'{{"{name}", get_{name}_namedtuple()}}, ')

    return definitions, map_entries
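
# Generated-code sketch (illustrative): for an op like max.dim, whose
# namedtuple fields are (values, indices), the f-string above expands to a
# get_max_namedtuple() function that lazily initializes a PyStructSequence
# named "torch.return_types.max", plus the map entry
#
#     {"max", get_max_namedtuple()},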


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Method Impl Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

# python binding for all overloads of a particular function/method
PY_VARIABLE_METHOD_VARARGS = CodeTemplate(
    r"""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ${method_header}
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});

  ParsedArgs<${max_args}> parsed_args;
  auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
  ${check_has_torch_function}
  switch (_r.idx) {
    ${dispatch}
  }
  ${method_footer}
}

"""
)

# handler for a single parsed signature - may be a single overload or
# a pair of overloads whose signatures differ only in output params
# (plugged into PY_VARIABLE_METHOD_VARARGS as an item in ${dispatch})
PY_VARIABLE_CASE = CodeTemplate(
    """\
case ${overload_index}: {
  ${body}
}
"""
)

# python binding for single-overload function/method
PY_VARIABLE_METHOD_VARARGS_SINGLETON = CodeTemplate(
    """\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ${method_header}
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});

  ParsedArgs<${max_args}> parsed_args;
  auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
  ${check_has_torch_function}
  ${dispatch}
  ${method_footer}
}

"""
)

# python binding for a method with no args, shortcuts parsing
PY_VARIABLE_METHOD_NOARGS = CodeTemplate(
    """\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
  ${method_header}
  ${check_has_torch_function}
  ${dispatch}
  ${method_footer}
}

"""
)


def method_impl(
    name: BaseOperatorName,
    module: Optional[str],
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool,
    symint: bool = True,
) -> str:
    """
    Generate a python binding for all overloads of an op.
    """
    pycname = get_pycname(name)
    noarg = is_noarg(overloads)
    namedtuple_inits, namedtuple_typenames = emit_namedtuple_call(overloads)

    method_header = ["HANDLE_TH_ERRORS"]
    method_header += namedtuple_inits
    method_header += (
        ["const Tensor& self = THPVariable_Unpack(self_);"] if method else []
    )

    method_footer = ([] if noarg else ["Py_RETURN_NONE;"]) + ["END_HANDLE_TH_ERRORS"]

    traceable = "true" if all(should_trace(o.function) for o in overloads) else "false"

    grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads(
        overloads, symint=symint
    )
    is_singleton = len(grouped_overloads) == 1
    signatures: List[str] = []
    dispatch: List[str] = []
    for overload_index, overload in enumerate(grouped_overloads):
        signature = overload.signature.signature_str(symint=symint)
        signatures.append(f"{cpp_string(str(signature))},")
        dispatch_body = emit_dispatch_case(
            overload, namedtuple_typenames, symint=symint
        )
        dispatch.append(
            PY_VARIABLE_CASE.substitute(
                overload_index=overload_index, body=dispatch_body
            )
            if not is_singleton
            else dispatch_body
        )

    if noarg:
        template = PY_VARIABLE_METHOD_NOARGS
    elif is_singleton:
        template = PY_VARIABLE_METHOD_VARARGS_SINGLETON
    else:
        template = PY_VARIABLE_METHOD_VARARGS

    return template.substitute(
        name=name,
        pycname=pycname,
        method_header=method_header,
        max_args=max(o.signature.arguments_count() for o in overloads),
        signatures=signatures,
        traceable=traceable,
        check_has_torch_function=gen_has_torch_function_check(
            name=name,
            module=module,
            noarg=noarg,
            method=method,
        ),
        dispatch=dispatch,
        method_footer=method_footer,
        self_="self_" if method else "nullptr",
    )
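
# Template selection sketch (illustrative): Tensor.dim() has a single zero-arg
# overload, so is_noarg(...) is True and PY_VARIABLE_METHOD_NOARGS is chosen
# (no arg parsing at all); torch.add groups into several parsed signatures, so
# PY_VARIABLE_METHOD_VARARGS is used and dispatch switches on _r.idx.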


def gen_has_torch_function_check(
    name: BaseOperatorName, module: Optional[str], *, noarg: bool, method: bool
) -> str:
    if noarg:
        if method:
            return f"""\
if(check_has_torch_function(self_)) {{
  return handle_torch_function(self_, "{name}");
}}
"""
        else:
            return ""

    self_ = "self_" if method else "nullptr"
    namespace = (
        {
            "torch": "THPVariableFunctionsModule",
            "torch.nn": "THPNNVariableFunctionsModule",
            "torch.fft": "THPFFTVariableFunctionsModule",
            "torch.linalg": "THPLinalgVariableFunctionsModule",
            "torch.nested": "THPNestedVariableFunctionsModule",
            "torch.sparse": "THPSparseVariableFunctionsModule",
            "torch.special": "THPSpecialVariableFunctionsModule",
            "torch.distributed.functional": "THPDistVariableFunctionsModule",
        }[module]
        if module
        else "THPVariableClass"
    )

    return f"""\
if(_r.has_torch_function()) {{
  return handle_torch_function(_r, {self_}, args, kwargs, {namespace}, "{module or "torch.Tensor"}");
}}
"""


# handler for output/no-output overload pair
PY_VARIABLE_OUT = CodeTemplate(
    """\
if (_r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  ${call_dispatch_out}
}
"""
)


def emit_dispatch_case(
    overload: PythonSignatureGroup,
    namedtuple_typenames: Dict[str, str],
    *,
    symint: bool = True,
) -> str:
    """
    Emit dispatch code for a single parsed signature. This corresponds to either
    a single native function, or a pair that differ only in output params. In the
    latter case, a single python signature is used for both and dispatching
    switches on the presence/absence of passed output args.
    """
    if overload.outplace is not None:
        # dispatch output and no-output variants, branch on _r.isNone(<out_idx>)
        return PY_VARIABLE_OUT.substitute(
            out_idx=overload.signature.output_idx(),
            call_dispatch=emit_single_dispatch(
                overload.signature, overload.base, namedtuple_typenames, symint=symint
            ),
            call_dispatch_out=emit_single_dispatch(
                overload.signature,
                overload.outplace,
                namedtuple_typenames,
                symint=symint,
            ),
        )
    else:
        # no-output version only
        return emit_single_dispatch(
            overload.signature, overload.base, namedtuple_typenames, symint=symint
        )


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Forward Declarations Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def forward_decls(
    name: BaseOperatorName,
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool,
) -> Tuple[str, ...]:
    if method:
        return ()

    pycname = get_pycname(name)
    if is_noarg(overloads):
        return (
            f"""\
static PyObject * {pycname}(PyObject* self_, PyObject* args);
""",
        )
    else:
        return (
            f"""\
static PyObject * {pycname}(PyObject* self_, PyObject* args, PyObject* kwargs);
""",
        )


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Method Def (Binding Table Entry) Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def method_def(
    name: BaseOperatorName,
    module: Optional[str],
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool,
) -> str:
    """
    Generate method def entry.
    """
    pycname = get_pycname(name)

    if is_noarg(overloads):
        pyfunc_cast = ""
        flags = "METH_NOARGS" if method else "METH_VARARGS | METH_KEYWORDS"
    else:
        pyfunc_cast = "castPyCFunctionWithKeywords"
        flags = "METH_VARARGS | METH_KEYWORDS"

    if module == "torch":
        flags += " | METH_STATIC"

    if name.dunder_method:
        # PyMethodDef entry for binary op, throws not implemented error
        return f"""\
{{"{name}", {pyfunc_cast}(TypeError_to_NotImplemented_<{pycname}>), {flags}, NULL}},"""
    else:
        # PyMethodDef entry
        return f"""\
{{"{name}", {pyfunc_cast}({pycname}), {flags}, NULL}},"""


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Overload Sorting and Grouping
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def group_overloads(
    overloads: Sequence[PythonSignatureNativeFunctionPair], *, symint: bool = True
) -> Sequence[PythonSignatureGroup]:
    bases: Dict[str, PythonSignatureNativeFunctionPair] = {}
    outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {}

    # first group by signature ignoring out arguments
    for overload in overloads:
        sig = overload.signature.signature_str(skip_outputs=True, symint=symint)
        if overload.function.func.is_out_fn():
            if sig in outplaces:
                raise RuntimeError(
                    f"Found duplicated function definition:\n- {overload.function.func}.\n"
                    f"Existing definition:\n- {outplaces[sig].function.func}."
                )
            outplaces[sig] = overload
        else:
            if sig in bases:
                raise RuntimeError(
                    f"Found duplicated function definition:\n- {overload.function.func}.\n"
                    f"Existing definition:\n- {bases[sig].function.func}."
                )
            bases[sig] = overload

    for sig, out in outplaces.items():
        if sig not in bases:
            candidates: List[str] = []
            for overload in overloads:
                if (
                    str(overload.function.func.name.name)
                    == str(out.function.func.name.name)
                    and not overload.function.func.is_out_fn()
                    and not overload.signature.deprecated
                ):
                    candidates.append(
                        overload.signature.signature_str(
                            skip_outputs=True, symint=symint
                        )
                    )
            out_sig = out.signature.signature_str(symint=symint)
            raise RuntimeError(
                f"While identifying overloads, we found an out schema {out_sig} without a corresponding non-out variant. "
                f"We expected the non-out variant to have schema: \n- {sig}\nPlease check that you spelled the schema "
                "correctly in native_functions.yaml. We discovered the following candidate(s): \n"
                + "\n".join(f"- {candidate}" for candidate in candidates)
            )

    grouped = [
        PythonSignatureGroup.from_pairs(
            functional=base,
            out=outplaces.get(sig),
        )
        for sig, base in bases.items()
    ]
    return sort_overloads(grouped, symint=symint)
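
# Pairing sketch (illustrative): for
#     abs(Tensor self) -> Tensor
#     abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
# both render the same signature string once outputs are skipped, so they fuse
# into one PythonSignatureGroup whose single python signature carries the
# optional out= argument.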


# This function declares a partial order on declarations, and sorts them according
# to its linear extension. This is necessary because there is some ambiguity in
# the choice of overload, and the order in which overloads are tried matters.
#
# See Note[Order of overloads matters]
#
# A few examples of ambiguous python signature pairs.
#
#   All parameters have the same type, except one taking Tensor, the other taking
#   Scalar. A numeric PyObject can be cast into a Tensor, and a zero-dim Tensor
#   object can be accepted as a Scalar type parameter (see python_arg_parser.cpp).
#   Therefore, the same input arguments might be accepted by either python signature.
#   We want to always parse the one taking Tensor first.
#
#     bitwise_and(Tensor input, Tensor other, *, Tensor out=None)
#     bitwise_and(Tensor input, Scalar other, *, Tensor out=None)
#
#   If they have different numbers of parameters then they are not ambiguous - but
#   the difference on output param can be ignored as it's optional.
#
#     multiply(Tensor input, Tensor other, *, Tensor out=None)
#     multiply(Tensor input, Scalar other)
#
#   Both positional args and keyword-only args are considered together.
#
#     subtract(Tensor other, *, Scalar alpha=1)
#     subtract(Scalar other, Scalar alpha=1)
#
# A few ambiguous cases which it does NOT handle yet.
#
#   If there is any difference in other parameters besides the Tensor/Scalar
#   difference, then they are not considered ambiguous by this method anymore.
#   However, the difference could be too trivial to disambiguate.
#
#     foo(Tensor input, Scalar other, Scalar bar)
#     foo(Tensor input, Tensor other, double bar)
#
#   If they take different numbers of parameters then they are not considered
#   ambiguous anymore, even if the difference is only on optional kwargs.
#
#     foo(Scalar other, Scalar alpha=1)
#     foo(Tensor other, *, Scalar alpha=1, Scalar beta=1)
#


def sort_overloads(
    grouped_overloads: Sequence[PythonSignatureGroup], *, symint: bool = True
) -> Sequence[PythonSignatureGroup]:
    # NB: Smaller here means lower priority

    def is_arg_smaller(t1: Type, t2: Type) -> bool:
        return (
            str(t1) == "Scalar"
            and str(t2) == "Tensor"
            or str(t1) == "Scalar?"
            and str(t2) == "Tensor?"
            or "Dimname" in str(t1)
            and "Dimname" not in str(t2)
            or
            # In the discussion https://github.com/pytorch/pytorch/issues/54555 it has been
            # discussed why it is important to prioritize int/int? over int[]
            str(t1) == "int[]"
            and (str(t2) == "int" or str(t2) == "int?")
            or
            # TensorList currently throws an error during argument parsing, that's why it needs to be
            # last in signature ordering. See discussion: https://github.com/pytorch/pytorch/issues/58087
            str(t1) == "Tensor[]"
            and str(t2).find("[]") != -1
            or
            # Prioritize IntArrayRef overload over SymIntArrayRef
            str(t1) == "SymInt[]"
            and str(t2) == "int[]"
            or
            # Make sure both int and SymInt are sorted consistently w.r.t. Tensor since Tensor can be implicitly
            # converted to either int or SymInt. Prioritize the Tensor overload since it otherwise gets shadowed.
            (str(t1) == "SymInt" or str(t1) == "int")
            and str(t2) == "Tensor"
        )

    def is_smaller(s1: PythonSignature, s2: PythonSignature) -> bool:
        """Returns True if s1 < s2 in the partial order."""
        args1, args2 = s1.arguments(skip_outputs=True), s2.arguments(skip_outputs=True)
        if len(args1) != len(args2):
            return False
        # TODO: should use some canonical form instead of 'str(arg.type)' - see comments
        # above. The old codegen used the deprecated 'dynamic_type(arg.type)', which
        # ignores the optional annotation, i.e. 'Scalar' and 'Scalar?'.
        equal = all(arg1.type == arg2.type for arg1, arg2 in zip(args1, args2))
        smaller_or_equal = all(
            str(arg1.type) == str(arg2.type) or is_arg_smaller(arg1.type, arg2.type)
            for arg1, arg2 in zip(args1, args2)
        )
        return smaller_or_equal and not equal

    # First sort by signature
    grouped_overloads = sorted(
        grouped_overloads, key=lambda x: x.signature.signature_str(symint=symint)
    )

    # Construct the relation graph
    larger_than: Dict[int, Set[int]] = defaultdict(set)
    for i1, overload1 in enumerate(grouped_overloads):
        for i2, overload2 in enumerate(grouped_overloads):
            if is_smaller(overload1.signature, overload2.signature):
                larger_than[i1].add(i2)

    if not larger_than:
        return list(grouped_overloads)

    # Use a topological sort to sort overloads according to the partial order.
    N = len(grouped_overloads)
    sorted_ids: List[int] = list(filter(lambda x: x not in larger_than, range(N)))

    for idx in range(N):
        # The size of sorted_ids will grow to N eventually.
        i = sorted_ids[idx]
        for j in sorted(larger_than.keys()):
            larger = larger_than[j]
            larger.discard(i)
            if not larger:
                del larger_than[j]
                sorted_ids.append(j)

    return [grouped_overloads[x] for x in sorted_ids]
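
# The loop above is Kahn's algorithm on the priority order: larger_than[j]
# holds the overloads that must be tried before j, sorted_ids seeds with the
# overloads nothing outranks, and placing an overload releases any overload
# waiting only on it. Ties preserve the deterministic signature-string order
# from the pre-sort.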


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Codegen API Integration
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def emit_single_dispatch(
    ps: PythonSignature,
    f: NativeFunction,
    namedtuple_typenames: Dict[str, str],
    *,
    symint: bool = True,
) -> str:
    """
    Emit dispatch code for a single native function.
    """

    @with_native_function
    def go(f: NativeFunction) -> str:
        # header comments
        if isinstance(ps, PythonSignatureDeprecated):
            schema_comment = f"// [deprecated] aten::{ps.deprecated_schema}"
        else:
            schema_comment = f"// aten::{f.func}"

        deprecated = "[deprecated] " if ps.deprecated else ""

        # dispatch lambda signature
        name = cpp.name(f.func)
        lambda_formals = ", ".join(
            f"{a.type_str} {a.name}" for a in dispatch_lambda_args(ps, f, symint=symint)
        )
        lambda_return = dispatch_lambda_return_str(f)

        # dispatch lambda body
        dispatch_callee = cpp_dispatch_target(f)
        dispatch_args = ", ".join(cpp_dispatch_exprs(f, python_signature=ps))

        # from arg parser outputs to dispatch lambda arguments
        parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
        lambda_arg_exprs = dispatch_lambda_exprs(ps, f, symint=symint)
        inits = "\n".join(lambda_arg_exprs.inits)
        lambda_args = ", ".join(lambda_arg_exprs.exprs)

        # scatter fields
        # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky
        # solution for enabling the 'requires_grad' argument for tensor methods
        # new_full, new_empty, and new_zeros. A much better but more difficult to
        # implement solution involves refactoring according to Ed's description here:
        # https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589
        need_set_requires_grad = ps.tensor_options_args and (
            not has_tensor_options(f)
            or (ps.method and ("requires_grad" in parser_outputs))
        )
        set_requires_grad = (
            f'.set_requires_grad({parser_outputs["requires_grad"].expr})'
            if need_set_requires_grad
            else ""
        )

        if lambda_return == "void":
            return f"""\
{schema_comment}
{inits}
auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
  pybind11::gil_scoped_release no_gil;
  {dispatch_callee}({dispatch_args});
}};
dispatch_{name}({lambda_args}){set_requires_grad};
Py_RETURN_NONE;
"""
        else:
            typename = namedtuple_typenames.get(gen_namedtuple_typename_key(f))
            namedtuple_typeref = f"{typename}, " if typename is not None else ""
            return f"""\
{schema_comment}
{inits}
auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
  pybind11::gil_scoped_release no_gil;
  return {dispatch_callee}({dispatch_args});
}};
return wrap({namedtuple_typeref}dispatch_{name}({lambda_args}){set_requires_grad});
"""

    return go(f)
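

# Generated-code sketch (illustrative; exact formals and parser expressions
# depend on the signature helpers above): for the method binding of
# aten::mul.Tensor, the non-void branch expands to roughly
#
#     // aten::mul.Tensor(Tensor self, Tensor other) -> Tensor
#     auto dispatch_mul = [](const at::Tensor & self, const at::Tensor & other) -> at::Tensor {
#       pybind11::gil_scoped_release no_gil;
#       return self.mul(other);
#     };
#     return wrap(dispatch_mul(self, _r.tensor(0)));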