Revert "[BE][Easy] enable postponed annotations in tools (#129375)"

This reverts commit 59eb2897f1745f513edb6c63065ffad481c4c8d0.

Reverted https://github.com/pytorch/pytorch/pull/129375 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I need to revert to cleanly revert https://github.com/pytorch/pytorch/pull/129374, please do a rebase and reland this ([comment](https://github.com/pytorch/pytorch/pull/129375#issuecomment-2197800541))
This commit is contained in:
PyTorch MergeBot
2024-06-29 00:44:25 +00:00
parent 6063bb9d45
commit a32ce5ce34
123 changed files with 1052 additions and 1275 deletions

View File

@ -31,12 +31,11 @@
# message, but use what's there
#
from __future__ import annotations
import itertools
import re
from collections import defaultdict
from typing import Callable, Iterable, Sequence
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple
import yaml
@ -57,6 +56,7 @@ from torchgen.api.python import (
signature_from_schema,
structseq_fieldnames,
)
from torchgen.code_template import CodeTemplate
from torchgen.context import with_native_function
from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml
@ -75,7 +75,6 @@ from torchgen.yaml_utils import YamlLoader
from .gen_inplace_or_view_type import is_tensor_list_type
from .gen_trace_type import should_trace
#
# declarations blocklist
# We skip codegen for these functions, for various reasons.
@ -370,7 +369,7 @@ def gen(
valid_tags = parse_tags_yaml(tags_yaml_path)
def gen_tags_enum() -> dict[str, str]:
def gen_tags_enum() -> Dict[str, str]:
return {
"enum_of_valid_tags": (
"".join(
@ -385,9 +384,9 @@ def gen(
def group_filter_overloads(
pairs: Sequence[PythonSignatureNativeFunctionPair],
pred: Callable[[NativeFunction], bool],
) -> dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]:
grouped: dict[
BaseOperatorName, list[PythonSignatureNativeFunctionPair]
) -> Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]:
grouped: Dict[
BaseOperatorName, List[PythonSignatureNativeFunctionPair]
] = defaultdict(list)
for pair in pairs:
if pred(pair.function):
@ -399,17 +398,17 @@ def create_python_bindings(
fm: FileManager,
pairs: Sequence[PythonSignatureNativeFunctionPair],
pred: Callable[[NativeFunction], bool],
module: str | None,
module: Optional[str],
filename: str,
*,
method: bool,
symint: bool = True,
) -> None:
"""Generates Python bindings to ATen functions"""
py_methods: list[str] = []
ops_headers: list[str] = []
py_method_defs: list[str] = []
py_forwards: list[str] = []
py_methods: List[str] = []
ops_headers: List[str] = []
py_method_defs: List[str] = []
py_forwards: List[str] = []
grouped = group_filter_overloads(pairs, pred)
@ -446,8 +445,8 @@ def create_python_return_type_bindings(
Generate function to initialize and return named tuple for native functions
which returns named tuple and registration invocations in `python_return_types.cpp`.
"""
py_return_types_definition: list[str] = []
py_return_types_registrations: list[str] = []
py_return_types_definition: List[str] = []
py_return_types_registrations: List[str] = []
grouped = group_filter_overloads(pairs, pred)
@ -485,7 +484,7 @@ def create_python_return_type_bindings_header(
Generate function to initialize and return named tuple for native functions
which returns named tuple and relevant entry for the map in `python_return_types.cpp`.
"""
py_return_types_declarations: list[str] = []
py_return_types_declarations: List[str] = []
grouped = group_filter_overloads(pairs, pred)
@ -511,7 +510,7 @@ def create_python_bindings_sharded(
fm: FileManager,
pairs: Sequence[PythonSignatureNativeFunctionPair],
pred: Callable[[NativeFunction], bool],
module: str | None,
module: Optional[str],
filename: str,
*,
method: bool,
@ -522,13 +521,13 @@ def create_python_bindings_sharded(
grouped = group_filter_overloads(pairs, pred)
def key_func(
kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]
kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
) -> str:
return kv[0].base
def env_func(
kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]
) -> dict[str, list[str]]:
kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
) -> Dict[str, List[str]]:
name, fn_pairs = kv
return {
"ops_headers": [f"#include <ATen/ops/{name.base}.h>"],
@ -554,7 +553,7 @@ def create_python_bindings_sharded(
def load_signatures(
native_functions: list[NativeFunction],
native_functions: List[NativeFunction],
deprecated_yaml_path: str,
*,
method: bool,
@ -581,19 +580,19 @@ def load_deprecated_signatures(
*,
method: bool,
pyi: bool,
) -> list[PythonSignatureNativeFunctionPair]:
) -> List[PythonSignatureNativeFunctionPair]:
# The deprecated.yaml doesn't have complete type information, we need
# find and leverage the original ATen signature (to which it delegates
# the call) to generate the full python signature.
# We join the deprecated and the original signatures using type-only form.
# group the original ATen signatures by name
grouped: dict[str, list[PythonSignatureNativeFunctionPair]] = defaultdict(list)
grouped: Dict[str, List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
for pair in pairs:
grouped[pair.signature.name].append(pair)
# find matching original signatures for each deprecated signature
results: list[PythonSignatureNativeFunctionPair] = []
results: List[PythonSignatureNativeFunctionPair] = []
with open(deprecated_yaml_path) as f:
deprecated_defs = yaml.load(f, Loader=YamlLoader)
@ -702,15 +701,15 @@ def gen_structseq_typename_key(f: NativeFunction) -> str:
def emit_structseq_call(
overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> tuple[list[str], dict[str, str]]:
) -> Tuple[List[str], Dict[str, str]]:
"""
Generate block of named tuple type def inits, and add typeref snippets
to declarations that use them
"""
typenames: dict[
typenames: Dict[
str, str
] = {} # map from unique name + field name lists to typedef name
typedefs: list[str] = [] # typedef declarations and init code
typedefs: List[str] = [] # typedef declarations and init code
for overload in overloads:
fieldnames = structseq_fieldnames(overload.function.func.returns)
@ -733,17 +732,17 @@ static PyTypeObject* {typename} = generated::get_{name}_structseq();"""
def generate_return_type_definition_and_registrations(
overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> tuple[list[str], list[str]]:
) -> Tuple[List[str], List[str]]:
"""
Generate block of function in `python_return_types.cpp` to initialize
and return named tuple for a native function which returns named tuple
and registration invocations in same file.
"""
typenames: dict[
typenames: Dict[
str, str
] = {} # map from unique name + field name lists to typedef name
definitions: list[str] = [] # function definition to register the typedef
registrations: list[str] = [] # register call for the typedef
definitions: List[str] = [] # function definition to register the typedef
registrations: List[str] = [] # register call for the typedef
for overload in overloads:
fieldnames = structseq_fieldnames(overload.function.func.returns)
@ -784,15 +783,15 @@ PyTypeObject* get_{name}_structseq() {{
def generate_return_type_declarations(
overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> list[str]:
) -> List[str]:
"""
Generate block of function declarations in `python_return_types.h` to initialize
and return named tuple for a native function.
"""
typenames: dict[
typenames: Dict[
str, str
] = {} # map from unique name + field name lists to typedef name
declarations: list[str] = [] # function declaration to register the typedef
declarations: List[str] = [] # function declaration to register the typedef
for overload in overloads:
fieldnames = structseq_fieldnames(overload.function.func.returns)
@ -892,7 +891,7 @@ static PyObject * ${pycname}(PyObject* self_, PyObject* args)
def method_impl(
name: BaseOperatorName,
module: str | None,
module: Optional[str],
overloads: Sequence[PythonSignatureNativeFunctionPair],
*,
method: bool,
@ -919,8 +918,8 @@ def method_impl(
overloads, symint=symint
)
is_singleton = len(grouped_overloads) == 1
signatures: list[str] = []
dispatch: list[str] = []
signatures: List[str] = []
dispatch: List[str] = []
for overload_index, overload in enumerate(grouped_overloads):
signature = overload.signature.signature_str(symint=symint)
signatures.append(f"{cpp_string(str(signature))},")
@ -960,7 +959,7 @@ def method_impl(
def gen_has_torch_function_check(
name: BaseOperatorName, module: str | None, *, noarg: bool, method: bool
name: BaseOperatorName, module: Optional[str], *, noarg: bool, method: bool
) -> str:
if noarg:
if method:
@ -1008,7 +1007,7 @@ if (_r.isNone(${out_idx})) {
def emit_dispatch_case(
overload: PythonSignatureGroup,
structseq_typenames: dict[str, str],
structseq_typenames: Dict[str, str],
*,
symint: bool = True,
) -> str:
@ -1051,7 +1050,7 @@ def forward_decls(
overloads: Sequence[PythonSignatureNativeFunctionPair],
*,
method: bool,
) -> tuple[str, ...]:
) -> Tuple[str, ...]:
if method:
return ()
@ -1079,7 +1078,7 @@ static PyObject * {pycname}(PyObject* self_, PyObject* args, PyObject* kwargs);
def method_def(
name: BaseOperatorName,
module: str | None,
module: Optional[str],
overloads: Sequence[PythonSignatureNativeFunctionPair],
*,
method: bool,
@ -1115,8 +1114,8 @@ def method_def(
def group_overloads(
overloads: Sequence[PythonSignatureNativeFunctionPair], *, symint: bool = True
) -> Sequence[PythonSignatureGroup]:
bases: dict[str, PythonSignatureNativeFunctionPair] = {}
outplaces: dict[str, PythonSignatureNativeFunctionPair] = {}
bases: Dict[str, PythonSignatureNativeFunctionPair] = {}
outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {}
# first group by signature ignoring out arguments
for overload in overloads:
@ -1138,7 +1137,7 @@ def group_overloads(
for sig, out in outplaces.items():
if sig not in bases:
candidates: list[str] = []
candidates: List[str] = []
for overload in overloads:
if (
str(overload.function.func.name.name)
@ -1269,7 +1268,7 @@ def sort_overloads(
)
# Construct the relation graph
larger_than: dict[int, set[int]] = defaultdict(set)
larger_than: Dict[int, Set[int]] = defaultdict(set)
for i1, overload1 in enumerate(grouped_overloads):
for i2, overload2 in enumerate(grouped_overloads):
if is_smaller(overload1.signature, overload2.signature):
@ -1280,7 +1279,7 @@ def sort_overloads(
# Use a topological sort to sort overloads according to the partial order.
N = len(grouped_overloads)
sorted_ids: list[int] = list(filter(lambda x: x not in larger_than, range(N)))
sorted_ids: List[int] = list(filter(lambda x: x not in larger_than, range(N)))
for idx in range(N):
# The size of sorted_ids will grow to N eventually.
@ -1305,7 +1304,7 @@ def sort_overloads(
def emit_single_dispatch(
ps: PythonSignature,
f: NativeFunction,
structseq_typenames: dict[str, str],
structseq_typenames: Dict[str, str],
*,
symint: bool = True,
) -> str: