mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-31 12:15:03 +08:00 
			
		
		
		
	Tracking issue: #55070 Pull Request resolved: https://github.com/pytorch/pytorch/pull/69607 Approved by: https://github.com/bdhirsh
		
			
				
	
	
		
			1250 lines
		
	
	
		
			40 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			1250 lines
		
	
	
		
			40 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Generates Python bindings for ATen functions
 | |
| #
 | |
| # The bindings are generated as methods on python_variable or functions on the
 | |
| # torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._sparse or torch._C._special objects.
 | |
| #
 | |
| 
 | |
| # Code tries to stick to the following rules:
 | |
| #
 | |
| # - templates should be colocated with the functions that use them.
 | |
| #   no templates are currently shared between functions, but if that
 | |
| #   happens, maybe put the template with the first one
 | |
| #
 | |
| # - don't use environment dictionaries when calling template.substitute().
 | |
| #   pass named arguments directly for everything, otherwise it's much too
 | |
| #   hard to track what's actually being used and by who
 | |
| #
 | |
| # - colocate any new hacks/adjustments with existing ones of the same kind.
 | |
| #   ideally in a data structure rather than code if possible. See e.g.
 | |
| #   SCHEMA_DEFAULT_CONVERSION_HACKS, etc.
 | |
| #
 | |
| # - similarly, conversions from one format to another should ideally happen
 | |
| #   all at once in a single place.
 | |
| #
 | |
| # - no nontrivial nested functions. couple-liners are ok but please no more.
 | |
| #   especially avoid functions that read/write outer variables defined far away.
 | |
| #
 | |
| # - raise RuntimeError instead of asserting, and put as much
 | |
| #   information as is available into the message. I.e. no need to
 | |
| #   plumb in new params whose only purpose is to fill out an error
 | |
| #   message, but use what's there
 | |
| #
 | |
| 
 | |
| from collections import defaultdict
 | |
| import itertools
 | |
| import re
 | |
| import yaml
 | |
| 
 | |
| from .gen_trace_type import should_trace
 | |
| 
 | |
| from torchgen.code_template import CodeTemplate
 | |
| from torchgen.api import cpp
 | |
| from torchgen.api.types import CppSignatureGroup
 | |
| from torchgen.api.python import (
 | |
|     PythonArgument,
 | |
|     PythonSignature,
 | |
|     PythonSignatureDeprecated,
 | |
|     PythonSignatureGroup,
 | |
|     PythonSignatureNativeFunctionPair,
 | |
|     arg_parser_output_exprs,
 | |
|     argument_type_str,
 | |
|     cpp_dispatch_exprs,
 | |
|     cpp_dispatch_target,
 | |
|     dispatch_lambda_args,
 | |
|     dispatch_lambda_exprs,
 | |
|     dispatch_lambda_return_str,
 | |
|     has_tensor_options,
 | |
|     namedtuple_fieldnames,
 | |
|     signature,
 | |
| )
 | |
| from torchgen.gen import cpp_string, parse_native_yaml
 | |
| from torchgen.context import with_native_function
 | |
| from torchgen.model import (
 | |
|     Argument,
 | |
|     BaseOperatorName,
 | |
|     NativeFunction,
 | |
|     Type,
 | |
|     Variant,
 | |
| )
 | |
| from torchgen.utils import split_name_params, YamlLoader, FileManager
 | |
| 
 | |
| from typing import Dict, Optional, List, Tuple, Set, Sequence, Callable
 | |
| 
 | |
| #
 | |
| # declarations blocklist
 | |
| # We skip codegen for these functions, for various reasons.
 | |
| # Future PRs will categorize this list and eliminate or hoist
 | |
| # them out of eager-only codegen.
 | |
| # See https://github.com/pytorch/pytorch/issues/30788
 | |
| #
 | |
| 
 | |
# Operators whose name matches one of these patterns get no generated Python
# binding: they either need a manual binding or are not exposed to Python.
# Every entry is a regular expression that has to match the whole name
# (see SKIP_PYTHON_BINDINGS below, which anchors each pattern).
_SKIP_PYTHON_BINDINGS = [
    "alias",
    "contiguous",
    "is_cuda",
    "is_sparse",
    "is_sparse_csr",
    "size",
    "stride",
    ".*_backward",
    ".*_backward_(out|input|weight|bias)",
    ".*_forward",
    ".*_forward_out",
    ".*_jvp",
    "_unsafe_view",
    "tensor",
    "_?sparse_(coo|compressed|csr|csc|bsr|bsc)_tensor.*",
    "_arange.*",
    "_range.*",
    "linspace.*",
    "logspace.*",
    "_sparse_add_out",
    "_sparse_div.*",
    "_sparse_mul.*",
    "_sparse_sub.*",
    "_sparse_dense_add_out",
    "index",
    "index_out",
    "unique_dim_consecutive",
    "_cumsum.*",
    "_cumprod.*",
    "_sum.*",
    "_prod.*",
    "_th_.*",
    "_thnn_.*",
    "arange.*",
    "range.*",
    "_solve.*",
    "_inverse.*",
    "full(_out)?",
    "_cholesky.*",
    "_triangular_solve.*",
    "_qr.*",
    "_symeig.*",
    "_svd.*",
    "slice",
    "randint(_out)?",
    "item",
    "_local_scalar_dense",
    "to",
    "_to_copy",
    "copy_sparse_to_sparse_",
    "copy_",
    "numpy_T",
    "matrix_H",
    "mT",
    "mH",  # these need to be attributes in Python, not functions
    "nonzero(_(out|numpy))?",
    "set_data",
    ".*_overrideable",  # overrideable functions for backend extension
    "data",
    "is_leaf",
    "output_nr",
    "_version",
    "requires_grad_",
    "retains_grad",
    "set_",
    "_fw_primal",
    "fake_quantize_per_tensor_affine_cachemask",
    "fake_quantize_per_channel_affine_cachemask",
    "_new_zeros_with_same_feature_meta",
    "_has_same_storage_numel",  # used for forward AD internals
    "_reshape_alias",
    "replace_",  # only used by the functionalization pass, doesn't need to be exposed to python
    "copy",  # only used by the functionalization pass
    # NOTE(review): "." is a regex wildcard in the two "fill" entries below;
    # confirm they match the operator names they are meant to skip.
    "fill.Tensor",  # only used by the functionalization pass
    "fill.Scalar",  # only used by the functionalization pass
    "lift",
]

# Compiled forms of the patterns above, anchored at both ends so that a
# pattern only matches a complete operator name.
SKIP_PYTHON_BINDINGS = [
    re.compile(rf"^{pattern}$") for pattern in _SKIP_PYTHON_BINDINGS
]
 | |
| 
 | |
# Exact function schemas that get no Python binding. Unlike the name patterns
# above, entries here are compared verbatim against ``str(f.func)``; regular
# expressions are not supported.
SKIP_PYTHON_BINDINGS_SIGNATURES = [
    "add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
    "add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)",
    "sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
    "sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)",
    "mul.Scalar(Tensor self, Scalar other) -> Tensor",
    "mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
    "div.Scalar(Tensor self, Scalar other) -> Tensor",
    "div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
]
 | |
| 
 | |
| 
 | |
@with_native_function
def should_generate_py_binding(f: NativeFunction) -> bool:
    """Decide whether ``f`` gets a generated Python binding.

    Returns False when ``f`` is tagged "generated", when its C++ name matches
    one of the SKIP_PYTHON_BINDINGS patterns, or when its schema string is
    listed in SKIP_PYTHON_BINDINGS_SIGNATURES. Returns True otherwise.
    """
    # So far, all NativeFunctions that are entirely code-generated do not get python bindings.
    if "generated" in f.tags:
        return False

    op_name = cpp.name(f.func)
    if any(skip_regex.match(op_name) for skip_regex in SKIP_PYTHON_BINDINGS):
        return False

    # Schemas are compared as exact strings; no regex here.
    return str(f.func) not in SKIP_PYTHON_BINDINGS_SIGNATURES
 | |
| 
 | |
| 
 | |
def get_pycname(name: BaseOperatorName) -> str:
    """Return the identifier of the generated C function for operator ``name``."""
    return "THPVariable_{}".format(name)
 | |
| 
 | |
| 
 | |
def is_noarg(overloads: Sequence[PythonSignatureNativeFunctionPair]) -> bool:
    """Return True when there is exactly one overload and it takes no arguments."""
    if len(overloads) != 1:
        return False
    return overloads[0].signature.arguments_count() == 0
 | |
| 
 | |
| 
 | |
def is_py_variable_method(f: NativeFunction) -> bool:
    """Return True when ``f`` has no ``python_module`` and has a method variant."""
    if f.python_module is not None:
        return False
    return Variant.method in f.variants
 | |
| 
 | |
| 
 | |
def is_py_torch_function(f: NativeFunction) -> bool:
    """Return True when ``f`` has no ``python_module`` and has a function variant."""
    if f.python_module is not None:
        return False
    return Variant.function in f.variants
 | |
| 
 | |
| 
 | |
def is_py_nn_function(f: NativeFunction) -> bool:
    """Return True when ``f`` declares ``"nn"`` as its ``python_module``."""
    python_module = f.python_module
    return python_module == "nn"
 | |
| 
 | |
| 
 | |
def is_py_fft_function(f: NativeFunction) -> bool:
    """Return True when ``f`` declares ``"fft"`` as its ``python_module``."""
    python_module = f.python_module
    return python_module == "fft"
 | |
| 
 | |
| 
 | |
def is_py_linalg_function(f: NativeFunction) -> bool:
    """Return True when ``f`` declares ``"linalg"`` as its ``python_module``."""
    python_module = f.python_module
    return python_module == "linalg"
 | |
| 
 | |
| 
 | |
def is_py_sparse_function(f: NativeFunction) -> bool:
    """Return True when ``f`` declares ``"sparse"`` as its ``python_module``."""
    python_module = f.python_module
    return python_module == "sparse"
 | |
| 
 | |
| 
 | |
def is_py_special_function(f: NativeFunction) -> bool:
    """Return True when ``f`` declares ``"special"`` as its ``python_module``."""
    python_module = f.python_module
    return python_module == "special"
 | |
| 
 | |
| 
 | |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 | |
| #
 | |
| #                            Main Function
 | |
| #
 | |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 | |
| 
 | |
| 
 | |
def gen(
    out: str,
    native_yaml_path: str,
    tags_yaml_path: str,
    deprecated_yaml_path: str,
    template_path: str,
) -> None:
    """Generate every Python binding source file into ``out``.

    Parses the native functions, drops those rejected by
    ``should_generate_py_binding``, then writes, in order: the Tensor method
    bindings, the sharded ``torch`` function bindings, one file per
    ``python_module`` (nn, fft, linalg, sparse, special), and the return-type
    (named tuple) bindings.

    Args:
        out: install directory handed to the FileManager.
        native_yaml_path: path passed to ``parse_native_yaml``.
        tags_yaml_path: path passed to ``parse_native_yaml``.
        deprecated_yaml_path: path of the deprecated-signatures YAML file.
        template_path: template directory handed to the FileManager.
    """
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
    native_functions = [
        f for f in parsed_yaml.native_functions if should_generate_py_binding(f)
    ]

    methods = load_signatures(native_functions, deprecated_yaml_path, method=True)
    create_python_bindings(
        fm,
        methods,
        is_py_variable_method,
        None,
        "python_variable_methods.cpp",
        method=True,
    )

    functions = load_signatures(native_functions, deprecated_yaml_path, method=False)

    # NOTE: num_shards here must be synced with gatherTorchFunctions in
    #       torch/csrc/autograd/python_torch_functions_manual.cpp
    create_python_bindings_sharded(
        fm,
        functions,
        is_py_torch_function,
        "torch",
        "python_torch_functions.cpp",
        method=False,
        num_shards=3,
    )

    # One unsharded output file per Python submodule: (predicate, module, file).
    submodule_outputs: List[Tuple[Callable[[NativeFunction], bool], str, str]] = [
        (is_py_nn_function, "torch.nn", "python_nn_functions.cpp"),
        (is_py_fft_function, "torch.fft", "python_fft_functions.cpp"),
        (is_py_linalg_function, "torch.linalg", "python_linalg_functions.cpp"),
        (is_py_sparse_function, "torch.sparse", "python_sparse_functions.cpp"),
        (is_py_special_function, "torch.special", "python_special_functions.cpp"),
    ]
    for pred, module, filename in submodule_outputs:
        create_python_bindings(
            fm,
            functions,
            pred,
            module,
            filename,
            method=False,
        )

    # Currently, we only use `functions` to generate `return_types` bindings.
    # All methods which return namedtuple have function variant at this point.
    # If any method-only operator with namedtuple is added in the future,
    # we will have to address that.
    create_python_return_type_bindings(
        fm, functions, lambda fn: True, "python_return_types.cpp"
    )
 | |
| 
 | |
| 
 | |
def group_filter_overloads(
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
) -> Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]:
    """Group the pairs accepted by ``pred`` by their base operator name.

    Pairs whose ``function`` fails ``pred`` are dropped. Within each group the
    pairs keep the order they had in ``pairs``.
    """
    by_name: Dict[
        BaseOperatorName, List[PythonSignatureNativeFunctionPair]
    ] = defaultdict(list)
    for pair in pairs:
        native_fn = pair.function
        if not pred(native_fn):
            continue
        by_name[native_fn.func.name.name].append(pair)
    return by_name
 | |
| 
 | |
| 
 | |
def create_python_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: Optional[str],
    filename: str,
    *,
    method: bool,
) -> None:
    """Generates Python bindings to ATen functions.

    Keeps the pairs accepted by ``pred``, groups them by operator name and,
    visiting the names in string-sorted order, collects the implementation,
    method-def entry, forward declarations and ops header include for each.
    The result is written through ``fm`` using ``filename`` as both the
    template name and the output name.
    """
    grouped = group_filter_overloads(pairs, pred)

    py_methods: List[str] = []
    ops_headers: List[str] = []
    py_method_defs: List[str] = []
    py_forwards: List[str] = []

    for name in sorted(grouped, key=str):
        overloads = grouped[name]
        py_methods.append(method_impl(name, module, overloads, method=method))
        py_method_defs.append(method_def(name, module, overloads, method=method))
        py_forwards.extend(forward_decls(name, overloads, method=method))
        ops_headers.append(f"#include <ATen/ops/{name.base}.h>")

    fm.write_with_template(
        filename,
        filename,
        lambda: {
            "generated_comment": "@" + f"generated from {fm.template_dir}/{filename}",
            "ops_headers": ops_headers,
            "py_forwards": py_forwards,
            "py_methods": py_methods,
            "py_method_defs": py_method_defs,
        },
    )
 | |
| 
 | |
| 
 | |
def create_python_return_type_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    filename: str,
) -> None:
    """
    Generate function to initialize and return named tuple for native functions
    which returns named tuple and relevant entry for the map in `python_return_types.cpp`.

    Operator names are visited in string-sorted order. Each name contributes
    one string to both output lists; that string is empty when none of its
    overloads return a named tuple.
    """
    grouped = group_filter_overloads(pairs, pred)

    py_return_types_definition: List[str] = []
    py_return_types_map: List[str] = []

    for name in sorted(grouped, key=str):
        definitions, map_entries = generate_return_type_definition_and_map_entry(
            grouped[name]
        )
        # Joining an empty list already yields "", so no special case is needed.
        py_return_types_definition.append("\n".join(definitions))
        py_return_types_map.append("\n".join(map_entries))

    fm.write_with_template(
        filename,
        filename,
        lambda: {
            "generated_comment": "@" + f"generated from {fm.template_dir}/{filename}",
            "py_return_types": py_return_types_definition,
            "py_return_types_map": py_return_types_map,
        },
    )
 | |
| 
 | |
| 
 | |
def create_python_bindings_sharded(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: Optional[str],
    filename: str,
    *,
    method: bool,
    num_shards: int,
) -> None:
    """Generates Python bindings to ATen functions, split across shards.

    Same per-operator content as ``create_python_bindings``, but the grouped
    operators are handed to ``fm.write_sharded`` together with ``num_shards``.
    Each operator is keyed by its base name.
    """
    OverloadGroup = Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]

    def key_func(kv: OverloadGroup) -> str:
        op_name, _ = kv
        return op_name.base

    def env_func(kv: OverloadGroup) -> Dict[str, List[str]]:
        op_name, fn_pairs = kv
        return {
            "ops_headers": [f"#include <ATen/ops/{op_name.base}.h>"],
            "py_forwards": list(forward_decls(op_name, fn_pairs, method=method)),
            "py_methods": [method_impl(op_name, module, fn_pairs, method=method)],
            "py_method_defs": [method_def(op_name, module, fn_pairs, method=method)],
        }

    grouped = group_filter_overloads(pairs, pred)

    fm.write_sharded(
        filename,
        grouped.items(),
        base_env={
            "generated_comment": "@" + f"generated from {fm.template_dir}/{filename}",
        },
        key_fn=key_func,
        env_callable=env_func,
        num_shards=num_shards,
        sharded_keys={"ops_headers", "py_forwards", "py_methods", "py_method_defs"},
    )
 | |
| 
 | |
| 
 | |
def load_signatures(
    native_functions: List[NativeFunction],
    deprecated_yaml_path: str,
    *,
    method: bool,
    skip_deprecated: bool = False,
    pyi: bool = False,
) -> Sequence[PythonSignatureNativeFunctionPair]:
    """Pair every native function with its Python signature.

    Unless ``skip_deprecated`` is set, the pairs built from
    ``deprecated_yaml_path`` are appended after the regular ones.
    """

    @with_native_function
    def to_pair(f: NativeFunction) -> PythonSignatureNativeFunctionPair:
        return PythonSignatureNativeFunctionPair(
            signature=signature(f, method=method, pyi=pyi),
            function=f,
        )

    pairs = [to_pair(f) for f in native_functions]
    # The deprecated overloads are loaded even when they end up being skipped.
    deprecated = load_deprecated_signatures(
        pairs, deprecated_yaml_path, method=method, pyi=pyi
    )
    if skip_deprecated:
        return pairs
    return pairs + deprecated
 | |
| 
 | |
| 
 | |
def load_deprecated_signatures(
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    deprecated_yaml_path: str,
    *,
    method: bool,
    pyi: bool,
) -> List[PythonSignatureNativeFunctionPair]:
    """Build signature pairs for the overloads listed in the deprecated YAML.

    The deprecated.yaml doesn't have complete type information, so each entry
    is joined with the original ATen signature it delegates to. The join key
    is a type-only signature string, ``opname(type, type, ...)``, computed for
    both sides. Every matching original yields one deprecated pair.

    Args:
        pairs: the regular (non-deprecated) signature pairs to match against.
        deprecated_yaml_path: YAML file whose entries have "name" and "aten" keys.
        method: whether method signatures are being generated; a positional
            ``self`` is then left out of the input args.
        pyi: when set, in-place operators keep their trailing "_" in the key.
    """

    # native function -> type-only signature
    @with_native_function
    def signature_original(f: NativeFunction) -> str:
        # Start from the base name, then add "_out" for out variants and,
        # for pyi only, "_" for in-place variants.
        opname = str(f.func.name.name.base)
        if f.func.is_out_fn():
            opname += "_out"
        if f.func.name.name.inplace and pyi:
            opname += "_"
        cpp_args = CppSignatureGroup.from_native_function(
            f, method=False
        ).signature.arguments()
        # Simply ignore TensorOptionsArguments as it does not exist in deprecated.yaml.
        type_strs = [
            argument_type_str(a.argument.type)
            for a in cpp_args
            if isinstance(a.argument, Argument)
        ]
        return f"{opname}({', '.join(type_strs)})"

    # deprecated -> type-only native signature (according to the call order)
    def signature_deprecated(
        opname: str, params: List[str], call_args: List[str]
    ) -> str:
        # Map each declared parameter name to its declared type.
        type_by_name: Dict[str, str] = {}
        for param in params:
            if param != "*":
                param_type, param_name = param.split(" ")
                type_by_name[param_name] = param_type
        # A call argument that is not a declared parameter is assumed to be
        # a literal Scalar.
        call_types = [type_by_name.get(arg, "Scalar") for arg in call_args]
        return f"{opname}({', '.join(call_types)})"

    # group the original ATen signatures by type-only signature
    originals_by_key: Dict[
        str, List[PythonSignatureNativeFunctionPair]
    ] = defaultdict(list)
    for pair in pairs:
        originals_by_key[signature_original(pair.function)].append(pair)

    with open(deprecated_yaml_path, "r") as f:
        deprecated_defs = yaml.load(f, Loader=YamlLoader)

    # find matching original signatures for each deprecated signature
    results: List[PythonSignatureNativeFunctionPair] = []
    for deprecated in deprecated_defs:
        _, params = split_name_params(deprecated["name"])
        aten_name, call_args = split_name_params(deprecated["aten"])
        lookup_key = signature_deprecated(aten_name, params, call_args)

        for pair in originals_by_key[lookup_key]:
            # The new signature uses the types from the original ATen
            # declaration, but the ordering and parameter names from the
            # deprecated overload. Default values from the original ATen
            # declaration are dropped.
            # A deprecated signature might reorder input_args and
            # input_kwargs, but never changes output_args nor TensorOptions,
            # so only those two kinds of arguments are looked at here.
            python_sig = pair.signature
            src_args: Dict[str, PythonArgument] = {
                a.name: PythonArgument(
                    name=a.name,
                    type=a.type,
                    default=None,
                    default_init=None,
                )
                for a in (*python_sig.input_args, *python_sig.input_kwargs)
            }

            arg_names: List[str] = []
            input_args: List[PythonArgument] = []
            input_kwargs: List[PythonArgument] = []

            kwarg_only = False
            for param in params:
                if param == "*":
                    # Everything after the bare "*" is keyword-only.
                    kwarg_only = True
                    continue
                _, param_name = param.split(" ")
                arg_names.append(param_name)

                src_arg = src_args.get(param_name)
                if src_arg is None:
                    # output argument
                    continue

                if kwarg_only:
                    input_kwargs.append(src_arg)
                elif not (method and param_name == "self"):
                    input_args.append(src_arg)

            deprecated_sig = PythonSignatureDeprecated(
                name=python_sig.name,
                input_args=tuple(input_args),
                input_kwargs=tuple(input_kwargs),
                output_args=python_sig.output_args,
                tensor_options_args=python_sig.tensor_options_args,
                method=python_sig.method,
                deprecated_args_names=tuple(arg_names),
                deprecated_args_exprs=tuple(call_args),
                returns=python_sig.returns,
            )
            results.append(
                PythonSignatureNativeFunctionPair(
                    signature=deprecated_sig,
                    function=pair.function,
                )
            )

    return results
 | |
| 
 | |
| 
 | |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 | |
| #
 | |
| #                         Named Tuple Codegen
 | |
| #
 | |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 | |
| 
 | |
| 
 | |
@with_native_function
def gen_namedtuple_typename_key(f: NativeFunction) -> str:
    """Return a key made of ``f``'s C++ name and its return field names, joined by "_"."""
    parts = [cpp.name(f.func), *namedtuple_fieldnames(f.func.returns)]
    return "_".join(parts)
 | |
| 
 | |
| 
 | |
def emit_namedtuple_call(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], Dict[str, str]]:
    """
    Generate block of named tuple type def inits, and add typeref snippets
    to declarations that use them.

    Returns ``(typedefs, typenames)``: the C++ declarations to emit, and a map
    from each overload's typename key to the C++ variable holding its type.
    Overloads without named return fields are ignored. Overloads that share
    a key share one declaration.
    """
    # map from unique name + field name lists to typedef name
    typenames: Dict[str, str] = {}
    # typedef declarations and init code
    typedefs: List[str] = []

    for overload in overloads:
        native_fn = overload.function
        if not namedtuple_fieldnames(native_fn.func.returns):
            continue

        name = cpp.name(native_fn.func)  # use @with_native_function?
        tn_key = gen_namedtuple_typename_key(native_fn)
        if tn_key in typenames:
            continue

        # The first typedef has no suffix; later ones are numbered 1, 2, ...
        suffix = len(typedefs) if typedefs else ""
        typename = f"NamedTuple{suffix}"
        typenames[tn_key] = typename
        typedefs.append(
            f'static PyTypeObject* {typename} = get_namedtuple("{name}");'
        )

    return typedefs, typenames
 | |
| 
 | |
| 
 | |
def generate_return_type_definition_and_map_entry(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], List[str]]:
    """
    Generate block of function in `python_return_types.cpp` to initialize
    and return named tuple for a native function which returns named tuple
    and relevant entry for the map in same file.

    Returns ``(definitions, map_entries)``. Overloads without named return
    fields are ignored. Overloads that share a typename key produce a single
    definition and a single map entry.
    """
    # typename keys (unique name + field name lists) already emitted
    seen_keys: Set[str] = set()
    # function definition to register the typedef
    definitions: List[str] = []
    # C++ map entry of <function_name, function that creates its namedtuple>
    map_entries: List[str] = []

    for overload in overloads:
        native_fn = overload.function
        fieldnames = namedtuple_fieldnames(native_fn.func.returns)
        if not fieldnames:
            continue

        name = cpp.name(native_fn.func)  # use @with_native_function?
        tn_key = gen_namedtuple_typename_key(native_fn)
        if tn_key in seen_keys:
            continue
        seen_keys.add(tn_key)

        fields = ", ".join(f'{{"{fn}", ""}}' for fn in fieldnames)
        # The first definition has no suffix; later ones are numbered 1, 2, ...
        suffix = len(definitions) if definitions else ""
        typename = f"{name}NamedTuple{suffix}"
        definitions.append(
            f"""\
PyTypeObject* get_{name}_namedtuple() {{
    static PyStructSequence_Field NamedTuple_fields[] = {{ {fields},  {{nullptr}} }};
    static PyTypeObject {typename};
    static bool is_initialized = false;
    static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, NamedTuple_fields, {len(fieldnames)} }};
    if (!is_initialized) {{
        PyStructSequence_InitType(&{typename}, &desc);
        {typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
        is_initialized = true;
    }}
    return &{typename};
}}
"""
        )
        map_entries.append(f'{{"{name}", get_{name}_namedtuple()}}, ')

    return definitions, map_entries
 | |
| 
 | |
| 
 | |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 | |
| #
 | |
| #                         Method Impl Codegen
 | |
| #
 | |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 | |
| 
 | |
# python binding for all overloads of a particular function/method
#
# The literal must NOT be a raw string: in a raw string the leading
# backslash-newline is kept verbatim instead of acting as a line continuation,
# which would make every expansion of this template start with a stray "\"
# line. The template contains no other backslashes, so a plain string is safe
# and matches the other templates in this file.
PY_VARIABLE_METHOD_VARARGS = CodeTemplate(
    """\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ${method_header}
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});

  ParsedArgs<${max_args}> parsed_args;
  auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
  ${check_has_torch_function}
  switch (_r.idx) {
    ${dispatch}
  }
  ${method_footer}
}

"""
)
 | |
| 
 | |
# Handler for a single parsed signature. This may be a single overload, or a
# pair of overloads whose signatures differ only in their output params.
# (plugged into PY_VARIABLE_METHOD_VARARGS as an item in ${dispatch})
PY_VARIABLE_CASE = CodeTemplate(
    """\
case ${overload_index}: {
  ${body}
}
"""
)
 | |
| 
 | |
# Python binding for a function/method with a single overload group: the
# dispatch body is emitted directly, with no switch over the parsed index.
PY_VARIABLE_METHOD_VARARGS_SINGLETON = CodeTemplate(
    """\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ${method_header}
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});

  ParsedArgs<${max_args}> parsed_args;
  auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
  ${check_has_torch_function}
  ${dispatch}
  ${method_footer}
}

"""
)
 | |
| 
 | |
# Python binding for a method that takes no arguments: the argument parser is
# skipped entirely.
PY_VARIABLE_METHOD_NOARGS = CodeTemplate(
    """\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
  ${method_header}
  ${check_has_torch_function}
  ${dispatch}
  ${method_footer}
}

"""
)
 | |
| 
 | |
| 
 | |
def method_impl(
    name: BaseOperatorName,
    module: Optional[str],
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool,
) -> str:
    """
    Generate a python binding for all overloads of an op.

    Picks the no-arg template when ``is_noarg(overloads)`` holds, the
    singleton template when the overloads collapse into a single signature
    group, and the switch-based varargs template otherwise. Returns the
    substituted template text.
    """
    pycname = get_pycname(name)
    noarg = is_noarg(overloads)
    namedtuple_inits, namedtuple_typenames = emit_namedtuple_call(overloads)

    method_header = ["HANDLE_TH_ERRORS", *namedtuple_inits]
    if method:
        method_header.append("const Tensor& self = THPVariable_Unpack(self_);")

    if noarg:
        method_footer = ["END_HANDLE_TH_ERRORS"]
    else:
        method_footer = ["Py_RETURN_NONE;", "END_HANDLE_TH_ERRORS"]

    # The parser is traceable only if every overload is.
    if all(should_trace(o.function) for o in overloads):
        traceable = "true"
    else:
        traceable = "false"

    grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads(overloads)
    is_singleton = len(grouped_overloads) == 1

    signatures: List[str] = []
    dispatch: List[str] = []
    for overload_index, overload in enumerate(grouped_overloads):
        sig_str = overload.signature.signature_str()
        signatures.append(f"{cpp_string(str(sig_str))},")

        dispatch_body = emit_dispatch_case(overload, namedtuple_typenames)
        if is_singleton:
            # A lone signature needs no `case` wrapper.
            dispatch.append(dispatch_body)
        else:
            dispatch.append(
                PY_VARIABLE_CASE.substitute(
                    overload_index=overload_index, body=dispatch_body
                )
            )

    if noarg:
        template = PY_VARIABLE_METHOD_NOARGS
    elif is_singleton:
        template = PY_VARIABLE_METHOD_VARARGS_SINGLETON
    else:
        template = PY_VARIABLE_METHOD_VARARGS

    max_args = max(o.signature.arguments_count() for o in overloads)
    torch_function_check = gen_has_torch_function_check(
        name=name,
        module=module,
        noarg=noarg,
        method=method,
    )

    return template.substitute(
        name=name,
        pycname=pycname,
        method_header=method_header,
        max_args=max_args,
        signatures=signatures,
        traceable=traceable,
        check_has_torch_function=torch_function_check,
        dispatch=dispatch,
        method_footer=method_footer,
        self_="self_" if method else "nullptr",
    )
 | |
| 
 | |
| 
 | |
def gen_has_torch_function_check(
    name: BaseOperatorName, module: Optional[str], *, noarg: bool, method: bool
) -> str:
    """
    Return the C++ snippet that hands the call over to __torch_function__.

    name:   operator name, embedded in the no-argument method snippet.
    module: python module the binding belongs to ("torch", "torch.nn", ...);
            a falsy value means the binding is a Tensor method.
    noarg:  True when the binding takes no arguments (no parsed `_r`).
    method: True when generating a Tensor method, False for a function.

    A no-argument function gets no check at all (empty string).

    Raises RuntimeError when `module` is set but has no known module object.
    """
    if noarg:
        if not method:
            return ""
        # No parser result exists here, so the check looks at `self_` only.
        return f"""\
if(check_has_torch_function(self_)) {{
  return handle_torch_function(self_, "{name}");
}}
"""

    # C++ object that stands for each supported python module.
    module_objects = {
        "torch": "THPVariableFunctionsModule",
        "torch.nn": "THPNNVariableFunctionsModule",
        "torch.fft": "THPFFTVariableFunctionsModule",
        "torch.linalg": "THPLinalgVariableFunctionsModule",
        "torch.sparse": "THPSparseVariableFunctionsModule",
        "torch.special": "THPSpecialVariableFunctionsModule",
    }
    if not module:
        namespace = "THPVariableClass"
    elif module in module_objects:
        namespace = module_objects[module]
    else:
        # Fail with context instead of a bare KeyError on the lookup.
        raise RuntimeError(
            f"gen_has_torch_function_check: no module object is known for "
            f"module {module!r} (while generating the binding for '{name}'). "
            f"Known modules: {', '.join(sorted(module_objects))}."
        )

    self_ = "self_" if method else "nullptr"
    module_name = module or "torch.Tensor"
    return f"""\
if(_r.has_torch_function()) {{
  return handle_torch_function(_r, {self_}, args, kwargs, {namespace}, "{module_name}");
}}
"""
 | |
| 
 | |
| 
 | |
# Handler for an output/no-output overload pair that shares one python
# signature: when the parsed argument at index ${out_idx} is None the code
# in ${call_dispatch} runs, otherwise the code in ${call_dispatch_out} runs.
PY_VARIABLE_OUT = CodeTemplate(
    """\
if (_r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  ${call_dispatch_out}
}
"""
)
 | |
| 
 | |
| 
 | |
def emit_dispatch_case(
    overload: PythonSignatureGroup,
    namedtuple_typenames: Dict[str, str],
) -> str:
    """
    Produce the dispatch code for one parsed python signature.

    A group holds either a single native function, or a base/out pair that
    differs only in its output params. For a pair, one python signature
    parses the input for both functions and the generated code picks between
    them depending on whether output args were passed.
    """
    signature = overload.signature

    if overload.outplace is None:
        # only the no-output function exists
        return emit_single_dispatch(signature, overload.base, namedtuple_typenames)

    # both variants exist: the generated code branches on _r.isNone(<out_idx>)
    out_idx = signature.output_idx()
    without_out = emit_single_dispatch(
        signature, overload.base, namedtuple_typenames
    )
    with_out = emit_single_dispatch(
        signature, overload.outplace, namedtuple_typenames
    )
    return PY_VARIABLE_OUT.substitute(
        out_idx=out_idx,
        call_dispatch=without_out,
        call_dispatch_out=with_out,
    )
 | |
| 
 | |
| 
 | |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 | |
| #
 | |
| #                    Forward Declarations Codegen
 | |
| #
 | |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 | |
| 
 | |
| 
 | |
def forward_decls(
    name: BaseOperatorName,
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool,
) -> Tuple[str, ...]:
    """
    Return the C++ forward declaration of a function binding as a 1-tuple.

    Methods get no forward declaration (empty tuple). The declared parameter
    list carries `kwargs` unless the binding takes no arguments.
    """
    if method:
        return ()

    pycname = get_pycname(name)
    parameters = "PyObject* self_, PyObject* args"
    if not is_noarg(overloads):
        parameters += ", PyObject* kwargs"
    return (f"static PyObject * {pycname}({parameters});\n",)
 | |
| 
 | |
| 
 | |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 | |
| #
 | |
| #              Method Def (Binding Table Entry) Codegen
 | |
| #
 | |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 | |
| 
 | |
| 
 | |
def method_def(
    name: BaseOperatorName,
    module: Optional[str],
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool,
) -> str:
    """
    Build the PyMethodDef table entry for one binding.
    """
    pycname = get_pycname(name)

    if not is_noarg(overloads):
        pyfunc_cast = "castPyCFunctionWithKeywords"
        flags = "METH_VARARGS | METH_KEYWORDS"
    elif method:
        pyfunc_cast = ""
        flags = "METH_NOARGS"
    else:
        pyfunc_cast = ""
        flags = "METH_VARARGS | METH_KEYWORDS"

    if module == "torch":
        flags += " | METH_STATIC"

    # dunder methods are routed through the TypeError_to_NotImplemented_
    # wrapper; every other binding is registered directly
    if name.dunder_method:
        target = f"TypeError_to_NotImplemented_<{pycname}>"
    else:
        target = f"{pycname}"

    return f'{{"{name}", {pyfunc_cast}({target}), {flags}, NULL}},'
 | |
| 
 | |
| 
 | |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 | |
| #
 | |
| #                   Overload Sorting and Grouping
 | |
| #
 | |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 | |
| 
 | |
| 
 | |
def group_overloads(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Sequence[PythonSignatureGroup]:
    """
    Pair every out= overload with the non-out overload that has the same
    signature (ignoring outputs) and return the resulting groups, ordered by
    sort_overloads().

    Raises RuntimeError for a duplicated definition, or for an out overload
    that has no matching non-out overload.
    """
    bases: Dict[str, PythonSignatureNativeFunctionPair] = {}
    outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {}

    # bucket the overloads by their signature with out arguments left out
    for pair in overloads:
        key = pair.signature.signature_str(skip_outputs=True)
        bucket = outplaces if pair.function.func.is_out_fn() else bases
        if key in bucket:
            raise RuntimeError(
                f"Found duplicated function definition:\n- {pair.function.func}.\n"
                f"Existing definition:\n- {bucket[key].function.func}."
            )
        bucket[key] = pair

    # every out overload must have a non-out partner
    for key, out in outplaces.items():
        if key in bases:
            continue
        op_name = str(out.function.func.name.name)
        candidates: List[str] = [
            other.signature.signature_str(skip_outputs=True)
            for other in overloads
            if str(other.function.func.name.name) == op_name
            and not other.function.func.is_out_fn()
            and not other.signature.deprecated
        ]
        out_sig = out.signature.signature_str()
        raise RuntimeError(
            f"While identifying overloads, we found an out schema {out_sig} without a corresponding non-out variant. "
            f"We expected the non-out variant to have schema: \n- {key}\nPlease check that you spelled the schema "
            "correctly in native_functions.yaml. We discovered the following candidate(s): \n"
            + "\n".join(f"- {candidate}" for candidate in candidates)
        )

    groups: List[PythonSignatureGroup] = []
    for key, base in bases.items():
        outplace = outplaces.get(key)
        if outplace is None:
            group = PythonSignatureGroup(
                signature=base.signature,
                base=base.function,
                outplace=None,
            )
        else:
            # the signature with the optional out=... arguments is a superset
            # that can parse the input of both the base and the out function
            group = PythonSignatureGroup(
                signature=outplace.signature,
                base=base.function,
                outplace=outplace.function,
            )
        groups.append(group)

    return sort_overloads(groups)
 | |
| 
 | |
| 
 | |
| # This function declares a partial order on declarations, and sorts them according
 | |
| # to its linear extension. This is necessary, because there's some ambiguity in the
 | |
| # choice of overload, and we want a different order.
 | |
| #
 | |
| # See Note[Order of overloads matters]
 | |
| #
 | |
| # A few examples of ambiguous python signature pairs.
 | |
| #
 | |
| #   All parameters have the same type, except one taking Tensor the other taking
 | |
| #   Scalar. A numeric PyObject can be cast into Tensor, and a zero-dim Tensor
 | |
| #   object can be accepted as Scalar type parameter (see python_arg_parser.cpp).
 | |
| #   Therefore, same input arguments might be accepted by either python signature.
 | |
| #   We want to always parse the one taking Tensor first.
 | |
| #
 | |
| #     bitwise_and(Tensor input, Tensor other, *, Tensor out=None)
 | |
| #     bitwise_and(Tensor input, Scalar other, *, Tensor out=None)
 | |
| #
 | |
| #   If they have different number of parameters then they are not ambiguous - but
 | |
| #   the difference on output param can be ignored as it's optional.
 | |
| #
 | |
| #     multiply(Tensor input, Tensor other, *, Tensor out=None)
 | |
| #     multiply(Tensor input, Scalar other)
 | |
| #
 | |
| #   Both positional args and keyword-only args are considered together.
 | |
| #
 | |
| #     subtract(Tensor other, *, Scalar alpha=1)
 | |
| #     subtract(Scalar other, Scalar alpha=1)
 | |
| #
 | |
| # A few ambiguous cases which it does NOT handle yet.
 | |
| #
 | |
| #   If there is any difference in other parameters besides the Tensor/Scalar
 | |
| #   difference, then they are not considered ambiguous by this method anymore.
 | |
| #   However, the difference could be too trivial to disambiguate.
 | |
| #
 | |
| #     foo(Tensor input, Scalar other, Scalar bar)
 | |
| #     foo(Tensor input, Tensor other, double bar)
 | |
| #
 | |
| #   If they are taking different number of parameters then they are not considered
 | |
| #   ambiguous anymore, even if the difference is only on optional kwargs.
 | |
| #
 | |
| #     foo(Scalar other, Scalar alpha=1)
 | |
| #     foo(Tensor other, *, Scalar alpha=1, Scalar beta=1)
 | |
| #
 | |
| 
 | |
| 
 | |
def sort_overloads(
    grouped_overloads: Sequence[PythonSignatureGroup],
) -> Sequence[PythonSignatureGroup]:
    """
    Order signature groups so that, of two ambiguous signatures, the one that
    must be tried first by the argument parser comes first.

    The groups are first sorted by their signature string. If no pair of
    signatures is related by the partial order below, that sorted list is
    returned; otherwise the groups are emitted in a topological order in which
    a signature only appears after every signature that is "larger" than it.

    Raises RuntimeError if the precedence relation turns out to be cyclic, in
    which case no valid order exists.
    """

    def is_arg_smaller(t1: Type, t2: Type) -> bool:
        """Returns True if an argument of type t1 must be tried after t2."""
        s1, s2 = str(t1), str(t2)
        return (
            (s1 == "Scalar" and s2 == "Tensor")
            or (s1 == "Scalar?" and s2 == "Tensor?")
            or ("Dimname" in s1 and "Dimname" not in s2)
            # In the discussion https://github.com/pytorch/pytorch/issues/54555 it has been
            # discussed why it is important to prioritize int/int? over int[]
            or (s1 == "int[]" and s2 in ("int", "int?"))
            # TensorList currently throws an error during argument parsing, that's why it needs to be
            # last in signature ordering. See discussion: https://github.com/pytorch/pytorch/issues/58087
            or (s1 == "Tensor[]" and "[]" in s2)
        )

    def is_smaller(s1: PythonSignature, s2: PythonSignature) -> bool:
        """Returns True if s1 < s2 in the partial order."""
        args1, args2 = s1.arguments(skip_outputs=True), s2.arguments(skip_outputs=True)
        # signatures with a different number of non-output args are unrelated
        if len(args1) != len(args2):
            return False
        # TODO: should use some canonical form instead of 'str(arg.type)' - see comments
        # above. The old codegen used the deprecated 'dynamic_type(arg.type)', which
        # ignores the optional annotation, i.e. 'Scalar' and 'Scalar?'.
        equal = all(arg1.type == arg2.type for arg1, arg2 in zip(args1, args2))
        smaller_or_equal = all(
            str(arg1.type) == str(arg2.type) or is_arg_smaller(arg1.type, arg2.type)
            for arg1, arg2 in zip(args1, args2)
        )
        return smaller_or_equal and not equal

    # First sort by signature
    grouped_overloads = sorted(
        grouped_overloads, key=lambda x: x.signature.signature_str()
    )

    # Construct the relation graph: larger_than[i] holds the indices of all
    # signatures that are larger than signature i.
    larger_than: Dict[int, Set[int]] = defaultdict(set)
    for i1, overload1 in enumerate(grouped_overloads):
        for i2, overload2 in enumerate(grouped_overloads):
            if is_smaller(overload1.signature, overload2.signature):
                larger_than[i1].add(i2)

    if not larger_than:
        return list(grouped_overloads)

    # Use a topological sort to sort overloads according to the partial order.
    # Start with the signatures that have nothing larger than them.
    N = len(grouped_overloads)
    sorted_ids: List[int] = [i for i in range(N) if i not in larger_than]

    for idx in range(N):
        # The size of sorted_ids grows to N as long as the relation is acyclic.
        if idx >= len(sorted_ids):
            # Nothing left to emit, yet some signatures are still waiting:
            # they depend on each other, so no valid order exists.
            stuck = "\n".join(
                f"- {grouped_overloads[k].signature.signature_str()}"
                for k in sorted(larger_than.keys())
            )
            raise RuntimeError(
                "sort_overloads: cannot order the following signatures because "
                f"their precedence relation contains a cycle:\n{stuck}"
            )
        i = sorted_ids[idx]
        # iterate over a sorted copy of the keys: entries are deleted below
        for j in sorted(larger_than.keys()):
            larger = larger_than[j]
            larger.discard(i)
            if not larger:
                del larger_than[j]
                sorted_ids.append(j)

    return [grouped_overloads[i] for i in sorted_ids]
 | |
| 
 | |
| 
 | |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 | |
| #
 | |
| #                       Codegen API Integration
 | |
| #
 | |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 | |
| 
 | |
| 
 | |
def emit_single_dispatch(
    ps: PythonSignature, f: NativeFunction, namedtuple_typenames: Dict[str, str]
) -> str:
    """
    Render the C++ dispatch snippet for one native function: a schema
    comment, the argument initializers, a dispatch lambda that releases the
    GIL around the call, and the statement that invokes the lambda.
    """

    @with_native_function
    def go(f: NativeFunction) -> str:
        # comment line that names the schema being dispatched
        deprecation_tag = "[deprecated] " if ps.deprecated else ""
        schema_comment = f"// {deprecation_tag}aten::{f.func}"

        # signature of the dispatch lambda
        name = cpp.name(f.func)
        lambda_formals = ", ".join(
            f"{formal.type_str} {formal.name}"
            for formal in dispatch_lambda_args(ps, f)
        )
        lambda_return = dispatch_lambda_return_str(f)

        # body of the dispatch lambda
        dispatch_callee = cpp_dispatch_target(f)
        dispatch_args = ", ".join(cpp_dispatch_exprs(f, python_signature=ps))

        # mapping from the arg parser outputs to the lambda arguments
        parser_outputs = arg_parser_output_exprs(ps, f)
        bound = dispatch_lambda_exprs(ps, f)
        inits = "\n".join(bound.inits)
        lambda_args = ", ".join(bound.exprs)

        # scatter fields
        # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky
        #       solution for enabling the 'requires_grad' argument for tensor methods
        #       new_full, new_empty, and new_zeros. A much better but more difficult to
        #       implement solution involves refactoring according to Ed's description here:
        #       https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589
        set_requires_grad = ""
        if ps.tensor_options_args and (
            not has_tensor_options(f)
            or (ps.method and ("requires_grad" in parser_outputs))
        ):
            set_requires_grad = (
                f'.set_requires_grad({parser_outputs["requires_grad"].expr})'
            )

        if lambda_return != "void":
            # value-returning function: wrap the result, passing the
            # namedtuple type first when one is registered for it
            typename = namedtuple_typenames.get(gen_namedtuple_typename_key(f))
            namedtuple_typeref = "" if typename is None else f"{typename}, "
            return f"""\
{schema_comment}
{inits}
auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
  pybind11::gil_scoped_release no_gil;
  return {dispatch_callee}({dispatch_args});
}};
return wrap({namedtuple_typeref}dispatch_{name}({lambda_args}){set_requires_grad});
"""

        # void function: call it and return None to python
        return f"""\
{schema_comment}
{inits}
auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
  pybind11::gil_scoped_release no_gil;
  {dispatch_callee}({dispatch_args});
}};
dispatch_{name}({lambda_args}){set_requires_grad};
Py_RETURN_NONE;
"""

    return go(f)
 |