mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Revert "[BE][Easy] enable postponed annotations in tools
(#129375)"
This reverts commit 59eb2897f1745f513edb6c63065ffad481c4c8d0. Reverted https://github.com/pytorch/pytorch/pull/129375 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I need to revert to cleanly revert https://github.com/pytorch/pytorch/pull/129374, please do a rebase and reland this ([comment](https://github.com/pytorch/pytorch/pull/129375#issuecomment-2197800541))
This commit is contained in:
@ -31,12 +31,11 @@
|
||||
# message, but use what's there
|
||||
#
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Iterable, Sequence
|
||||
|
||||
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple
|
||||
|
||||
import yaml
|
||||
|
||||
@ -57,6 +56,7 @@ from torchgen.api.python import (
|
||||
signature_from_schema,
|
||||
structseq_fieldnames,
|
||||
)
|
||||
|
||||
from torchgen.code_template import CodeTemplate
|
||||
from torchgen.context import with_native_function
|
||||
from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml
|
||||
@ -75,7 +75,6 @@ from torchgen.yaml_utils import YamlLoader
|
||||
from .gen_inplace_or_view_type import is_tensor_list_type
|
||||
from .gen_trace_type import should_trace
|
||||
|
||||
|
||||
#
|
||||
# declarations blocklist
|
||||
# We skip codegen for these functions, for various reasons.
|
||||
@ -370,7 +369,7 @@ def gen(
|
||||
|
||||
valid_tags = parse_tags_yaml(tags_yaml_path)
|
||||
|
||||
def gen_tags_enum() -> dict[str, str]:
|
||||
def gen_tags_enum() -> Dict[str, str]:
|
||||
return {
|
||||
"enum_of_valid_tags": (
|
||||
"".join(
|
||||
@ -385,9 +384,9 @@ def gen(
|
||||
def group_filter_overloads(
|
||||
pairs: Sequence[PythonSignatureNativeFunctionPair],
|
||||
pred: Callable[[NativeFunction], bool],
|
||||
) -> dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]:
|
||||
grouped: dict[
|
||||
BaseOperatorName, list[PythonSignatureNativeFunctionPair]
|
||||
) -> Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]:
|
||||
grouped: Dict[
|
||||
BaseOperatorName, List[PythonSignatureNativeFunctionPair]
|
||||
] = defaultdict(list)
|
||||
for pair in pairs:
|
||||
if pred(pair.function):
|
||||
@ -399,17 +398,17 @@ def create_python_bindings(
|
||||
fm: FileManager,
|
||||
pairs: Sequence[PythonSignatureNativeFunctionPair],
|
||||
pred: Callable[[NativeFunction], bool],
|
||||
module: str | None,
|
||||
module: Optional[str],
|
||||
filename: str,
|
||||
*,
|
||||
method: bool,
|
||||
symint: bool = True,
|
||||
) -> None:
|
||||
"""Generates Python bindings to ATen functions"""
|
||||
py_methods: list[str] = []
|
||||
ops_headers: list[str] = []
|
||||
py_method_defs: list[str] = []
|
||||
py_forwards: list[str] = []
|
||||
py_methods: List[str] = []
|
||||
ops_headers: List[str] = []
|
||||
py_method_defs: List[str] = []
|
||||
py_forwards: List[str] = []
|
||||
|
||||
grouped = group_filter_overloads(pairs, pred)
|
||||
|
||||
@ -446,8 +445,8 @@ def create_python_return_type_bindings(
|
||||
Generate function to initialize and return named tuple for native functions
|
||||
which returns named tuple and registration invocations in `python_return_types.cpp`.
|
||||
"""
|
||||
py_return_types_definition: list[str] = []
|
||||
py_return_types_registrations: list[str] = []
|
||||
py_return_types_definition: List[str] = []
|
||||
py_return_types_registrations: List[str] = []
|
||||
|
||||
grouped = group_filter_overloads(pairs, pred)
|
||||
|
||||
@ -485,7 +484,7 @@ def create_python_return_type_bindings_header(
|
||||
Generate function to initialize and return named tuple for native functions
|
||||
which returns named tuple and relevant entry for the map in `python_return_types.cpp`.
|
||||
"""
|
||||
py_return_types_declarations: list[str] = []
|
||||
py_return_types_declarations: List[str] = []
|
||||
|
||||
grouped = group_filter_overloads(pairs, pred)
|
||||
|
||||
@ -511,7 +510,7 @@ def create_python_bindings_sharded(
|
||||
fm: FileManager,
|
||||
pairs: Sequence[PythonSignatureNativeFunctionPair],
|
||||
pred: Callable[[NativeFunction], bool],
|
||||
module: str | None,
|
||||
module: Optional[str],
|
||||
filename: str,
|
||||
*,
|
||||
method: bool,
|
||||
@ -522,13 +521,13 @@ def create_python_bindings_sharded(
|
||||
grouped = group_filter_overloads(pairs, pred)
|
||||
|
||||
def key_func(
|
||||
kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]
|
||||
kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
|
||||
) -> str:
|
||||
return kv[0].base
|
||||
|
||||
def env_func(
|
||||
kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]
|
||||
) -> dict[str, list[str]]:
|
||||
kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
|
||||
) -> Dict[str, List[str]]:
|
||||
name, fn_pairs = kv
|
||||
return {
|
||||
"ops_headers": [f"#include <ATen/ops/{name.base}.h>"],
|
||||
@ -554,7 +553,7 @@ def create_python_bindings_sharded(
|
||||
|
||||
|
||||
def load_signatures(
|
||||
native_functions: list[NativeFunction],
|
||||
native_functions: List[NativeFunction],
|
||||
deprecated_yaml_path: str,
|
||||
*,
|
||||
method: bool,
|
||||
@ -581,19 +580,19 @@ def load_deprecated_signatures(
|
||||
*,
|
||||
method: bool,
|
||||
pyi: bool,
|
||||
) -> list[PythonSignatureNativeFunctionPair]:
|
||||
) -> List[PythonSignatureNativeFunctionPair]:
|
||||
# The deprecated.yaml doesn't have complete type information, we need
|
||||
# find and leverage the original ATen signature (to which it delegates
|
||||
# the call) to generate the full python signature.
|
||||
# We join the deprecated and the original signatures using type-only form.
|
||||
|
||||
# group the original ATen signatures by name
|
||||
grouped: dict[str, list[PythonSignatureNativeFunctionPair]] = defaultdict(list)
|
||||
grouped: Dict[str, List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
|
||||
for pair in pairs:
|
||||
grouped[pair.signature.name].append(pair)
|
||||
|
||||
# find matching original signatures for each deprecated signature
|
||||
results: list[PythonSignatureNativeFunctionPair] = []
|
||||
results: List[PythonSignatureNativeFunctionPair] = []
|
||||
|
||||
with open(deprecated_yaml_path) as f:
|
||||
deprecated_defs = yaml.load(f, Loader=YamlLoader)
|
||||
@ -702,15 +701,15 @@ def gen_structseq_typename_key(f: NativeFunction) -> str:
|
||||
|
||||
def emit_structseq_call(
|
||||
overloads: Sequence[PythonSignatureNativeFunctionPair],
|
||||
) -> tuple[list[str], dict[str, str]]:
|
||||
) -> Tuple[List[str], Dict[str, str]]:
|
||||
"""
|
||||
Generate block of named tuple type def inits, and add typeref snippets
|
||||
to declarations that use them
|
||||
"""
|
||||
typenames: dict[
|
||||
typenames: Dict[
|
||||
str, str
|
||||
] = {} # map from unique name + field name lists to typedef name
|
||||
typedefs: list[str] = [] # typedef declarations and init code
|
||||
typedefs: List[str] = [] # typedef declarations and init code
|
||||
|
||||
for overload in overloads:
|
||||
fieldnames = structseq_fieldnames(overload.function.func.returns)
|
||||
@ -733,17 +732,17 @@ static PyTypeObject* {typename} = generated::get_{name}_structseq();"""
|
||||
|
||||
def generate_return_type_definition_and_registrations(
|
||||
overloads: Sequence[PythonSignatureNativeFunctionPair],
|
||||
) -> tuple[list[str], list[str]]:
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Generate block of function in `python_return_types.cpp` to initialize
|
||||
and return named tuple for a native function which returns named tuple
|
||||
and registration invocations in same file.
|
||||
"""
|
||||
typenames: dict[
|
||||
typenames: Dict[
|
||||
str, str
|
||||
] = {} # map from unique name + field name lists to typedef name
|
||||
definitions: list[str] = [] # function definition to register the typedef
|
||||
registrations: list[str] = [] # register call for the typedef
|
||||
definitions: List[str] = [] # function definition to register the typedef
|
||||
registrations: List[str] = [] # register call for the typedef
|
||||
|
||||
for overload in overloads:
|
||||
fieldnames = structseq_fieldnames(overload.function.func.returns)
|
||||
@ -784,15 +783,15 @@ PyTypeObject* get_{name}_structseq() {{
|
||||
|
||||
def generate_return_type_declarations(
|
||||
overloads: Sequence[PythonSignatureNativeFunctionPair],
|
||||
) -> list[str]:
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generate block of function declarations in `python_return_types.h` to initialize
|
||||
and return named tuple for a native function.
|
||||
"""
|
||||
typenames: dict[
|
||||
typenames: Dict[
|
||||
str, str
|
||||
] = {} # map from unique name + field name lists to typedef name
|
||||
declarations: list[str] = [] # function declaration to register the typedef
|
||||
declarations: List[str] = [] # function declaration to register the typedef
|
||||
|
||||
for overload in overloads:
|
||||
fieldnames = structseq_fieldnames(overload.function.func.returns)
|
||||
@ -892,7 +891,7 @@ static PyObject * ${pycname}(PyObject* self_, PyObject* args)
|
||||
|
||||
def method_impl(
|
||||
name: BaseOperatorName,
|
||||
module: str | None,
|
||||
module: Optional[str],
|
||||
overloads: Sequence[PythonSignatureNativeFunctionPair],
|
||||
*,
|
||||
method: bool,
|
||||
@ -919,8 +918,8 @@ def method_impl(
|
||||
overloads, symint=symint
|
||||
)
|
||||
is_singleton = len(grouped_overloads) == 1
|
||||
signatures: list[str] = []
|
||||
dispatch: list[str] = []
|
||||
signatures: List[str] = []
|
||||
dispatch: List[str] = []
|
||||
for overload_index, overload in enumerate(grouped_overloads):
|
||||
signature = overload.signature.signature_str(symint=symint)
|
||||
signatures.append(f"{cpp_string(str(signature))},")
|
||||
@ -960,7 +959,7 @@ def method_impl(
|
||||
|
||||
|
||||
def gen_has_torch_function_check(
|
||||
name: BaseOperatorName, module: str | None, *, noarg: bool, method: bool
|
||||
name: BaseOperatorName, module: Optional[str], *, noarg: bool, method: bool
|
||||
) -> str:
|
||||
if noarg:
|
||||
if method:
|
||||
@ -1008,7 +1007,7 @@ if (_r.isNone(${out_idx})) {
|
||||
|
||||
def emit_dispatch_case(
|
||||
overload: PythonSignatureGroup,
|
||||
structseq_typenames: dict[str, str],
|
||||
structseq_typenames: Dict[str, str],
|
||||
*,
|
||||
symint: bool = True,
|
||||
) -> str:
|
||||
@ -1051,7 +1050,7 @@ def forward_decls(
|
||||
overloads: Sequence[PythonSignatureNativeFunctionPair],
|
||||
*,
|
||||
method: bool,
|
||||
) -> tuple[str, ...]:
|
||||
) -> Tuple[str, ...]:
|
||||
if method:
|
||||
return ()
|
||||
|
||||
@ -1079,7 +1078,7 @@ static PyObject * {pycname}(PyObject* self_, PyObject* args, PyObject* kwargs);
|
||||
|
||||
def method_def(
|
||||
name: BaseOperatorName,
|
||||
module: str | None,
|
||||
module: Optional[str],
|
||||
overloads: Sequence[PythonSignatureNativeFunctionPair],
|
||||
*,
|
||||
method: bool,
|
||||
@ -1115,8 +1114,8 @@ def method_def(
|
||||
def group_overloads(
|
||||
overloads: Sequence[PythonSignatureNativeFunctionPair], *, symint: bool = True
|
||||
) -> Sequence[PythonSignatureGroup]:
|
||||
bases: dict[str, PythonSignatureNativeFunctionPair] = {}
|
||||
outplaces: dict[str, PythonSignatureNativeFunctionPair] = {}
|
||||
bases: Dict[str, PythonSignatureNativeFunctionPair] = {}
|
||||
outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {}
|
||||
|
||||
# first group by signature ignoring out arguments
|
||||
for overload in overloads:
|
||||
@ -1138,7 +1137,7 @@ def group_overloads(
|
||||
|
||||
for sig, out in outplaces.items():
|
||||
if sig not in bases:
|
||||
candidates: list[str] = []
|
||||
candidates: List[str] = []
|
||||
for overload in overloads:
|
||||
if (
|
||||
str(overload.function.func.name.name)
|
||||
@ -1269,7 +1268,7 @@ def sort_overloads(
|
||||
)
|
||||
|
||||
# Construct the relation graph
|
||||
larger_than: dict[int, set[int]] = defaultdict(set)
|
||||
larger_than: Dict[int, Set[int]] = defaultdict(set)
|
||||
for i1, overload1 in enumerate(grouped_overloads):
|
||||
for i2, overload2 in enumerate(grouped_overloads):
|
||||
if is_smaller(overload1.signature, overload2.signature):
|
||||
@ -1280,7 +1279,7 @@ def sort_overloads(
|
||||
|
||||
# Use a topological sort to sort overloads according to the partial order.
|
||||
N = len(grouped_overloads)
|
||||
sorted_ids: list[int] = list(filter(lambda x: x not in larger_than, range(N)))
|
||||
sorted_ids: List[int] = list(filter(lambda x: x not in larger_than, range(N)))
|
||||
|
||||
for idx in range(N):
|
||||
# The size of sorted_ids will grow to N eventually.
|
||||
@ -1305,7 +1304,7 @@ def sort_overloads(
|
||||
def emit_single_dispatch(
|
||||
ps: PythonSignature,
|
||||
f: NativeFunction,
|
||||
structseq_typenames: dict[str, str],
|
||||
structseq_typenames: Dict[str, str],
|
||||
*,
|
||||
symint: bool = True,
|
||||
) -> str:
|
||||
|
Reference in New Issue
Block a user