[BE][Easy] enable postponed annotations in torchgen (#129376)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129376
Approved by: https://github.com/ezyang
ghstack dependencies: #129375
This commit is contained in:
Xuehai Pan
2024-06-28 16:28:16 +08:00
committed by PyTorch MergeBot
parent 59eb2897f1
commit 494057d6d4
45 changed files with 977 additions and 901 deletions

View File

@ -1,5 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
from typing import Sequence
from torchgen.api import cpp
from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
@ -197,14 +199,14 @@ from torchgen.model import (
@dataclass(frozen=True)
class PythonReturns:
returns: Tuple[Return, ...]
returns: tuple[Return, ...]
@dataclass(frozen=True)
class PythonArgument:
name: str
type: Type
default: Optional[str]
default: str | None
# Used to generate the default init expr for some PythonArgParser outputs, e.g.:
#
@ -212,7 +214,7 @@ class PythonArgument:
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ^
# +--- default_init str
default_init: Optional[str]
default_init: str | None
# Compute argument formal for python argument parsing.
# Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
@ -300,12 +302,10 @@ class PythonOutArgument(PythonArgument):
# 'auto out = _r.tensorlist_n<2>(2);',
# then bound to scattered C++ output arguments as 'out[0]', 'out[1]', and etc.
# TODO: maybe don't need to keep scattered out fields for python signature?
outputs: Tuple[PythonArgument, ...]
outputs: tuple[PythonArgument, ...]
@staticmethod
def from_outputs(
outputs: Tuple[PythonArgument, ...]
) -> Optional["PythonOutArgument"]:
def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None:
if not outputs:
return None
@ -339,13 +339,13 @@ class PythonSignature:
# Positional arguments.
# TODO: create a dedicated SelfArgument type for 'self'?
input_args: Tuple[PythonArgument, ...]
input_args: tuple[PythonArgument, ...]
# Keyword arguments excluding the 'out' argument and scattered kwargs belonging
# to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc).
input_kwargs: Tuple[PythonArgument, ...]
input_kwargs: tuple[PythonArgument, ...]
output_args: Optional[PythonOutArgument]
output_args: PythonOutArgument | None
# Return types, which are only used by pyi
returns: PythonReturns
@ -356,7 +356,7 @@ class PythonSignature:
# for out variant), in which case they will be used as scattered fields without
# being packed into 'options'.
# TODO: maybe create a PythonTensorOptionsArgument?
tensor_options_args: Tuple[PythonArgument, ...]
tensor_options_args: tuple[PythonArgument, ...]
# method or function signature?
method: bool
@ -367,8 +367,8 @@ class PythonSignature:
def arguments(
self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
) -> Tuple[Union[PythonArgument, PythonOutArgument], ...]:
result: List[Union[PythonArgument, PythonOutArgument]] = []
) -> tuple[PythonArgument | PythonOutArgument, ...]:
result: list[PythonArgument | PythonOutArgument] = []
result.extend(self.input_args)
result.extend(self.input_kwargs)
if self.output_args is not None and not skip_outputs:
@ -394,7 +394,7 @@ class PythonSignature:
# signature_str_pyi().
def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
args = self.arguments(skip_outputs=skip_outputs)
schema_formals: List[str] = [
schema_formals: list[str] = [
a.argument_str(method=self.method, symint=symint) for a in args
]
positional_argc = len(self.input_args)
@ -405,7 +405,7 @@ class PythonSignature:
def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
args = self.arguments(skip_outputs=skip_outputs)
schema_formals: List[str] = [
schema_formals: list[str] = [
a.argument_str_pyi(method=self.method) for a in args
]
positional_argc = len(self.input_args)
@ -419,10 +419,10 @@ class PythonSignature:
schema_formals.insert(0, "self")
return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]:
def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
# only pyi uses vararg signatures
args = self.arguments(skip_outputs=skip_outputs)
schema_formals: List[str] = [
schema_formals: list[str] = [
a.argument_str_pyi(method=self.method) for a in args
]
# vararg only applies to pyi signatures. vararg variants are not generated for all signatures
@ -470,7 +470,7 @@ class PythonSignatureDeprecated(PythonSignature):
# [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
# [func call]: self.addmm(mat1, mat2, beta, 1)
# We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case.
deprecated_args_exprs: Tuple[str, ...]
deprecated_args_exprs: tuple[str, ...]
@property
def deprecated(self) -> bool:
@ -486,7 +486,7 @@ class PythonSignatureDeprecated(PythonSignature):
def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
args = self.arguments(skip_outputs=skip_outputs)
schema_formals: List[str] = [
schema_formals: list[str] = [
a.argument_str_pyi(method=self.method, deprecated=True) for a in args
]
positional_argc = len(self.input_args)
@ -496,7 +496,7 @@ class PythonSignatureDeprecated(PythonSignature):
returns_str = returns_str_pyi(self)
return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]:
def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
# the codegen doesn't include vararg variants for deprecated signatures
return None
@ -530,14 +530,14 @@ class PythonSignatureGroup:
base: NativeFunction
# The out variant (e.g. conv2d_out)
outplace: Optional[NativeFunction]
outplace: NativeFunction | None
@classmethod
def from_pairs(
cls,
functional: PythonSignatureNativeFunctionPair,
out: Optional[PythonSignatureNativeFunctionPair],
) -> "PythonSignatureGroup":
out: PythonSignatureNativeFunctionPair | None,
) -> PythonSignatureGroup:
if out is None:
return PythonSignatureGroup(
signature=functional.signature,
@ -716,7 +716,7 @@ def argument_type_str(
raise RuntimeError(f"unrecognized type {repr(t)}")
def argument_type_size(t: Type) -> Optional[int]:
def argument_type_size(t: Type) -> int | None:
l = t.is_list_like()
if l is not None and str(l.elem) != "bool":
return l.size
@ -750,11 +750,11 @@ def signature(
def signature_from_schema(
func: FunctionSchema,
*,
category_override: Optional[str],
category_override: str | None,
method: bool = False,
pyi: bool = False,
) -> PythonSignature:
args: List[Argument] = []
args: list[Argument] = []
args.extend(func.arguments.pre_self_positional)
# Skip SelfArgument if this is method.
if not method and func.arguments.self_arg is not None:
@ -807,10 +807,10 @@ def signature_from_schema(
)
is_dummy_function = category_override == "dummy"
tensor_options_args: List[PythonArgument] = []
tensor_options_args: list[PythonArgument] = []
if (is_factory_function or is_like_or_new_function) and not is_dummy_function:
def topt_default_init(name: str) -> Optional[str]:
def topt_default_init(name: str) -> str | None:
topt_args = func.arguments.tensor_options
if topt_args is None:
return None
@ -891,7 +891,7 @@ def signature_from_schema(
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def structseq_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
def structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]:
if len(returns) <= 1 or all(r.name is None for r in returns):
return []
else:
@ -1002,7 +1002,7 @@ def return_type_str_pyi(t: Type) -> str:
return argument_type_str_pyi(t)
def returns_structseq_pyi(signature: PythonSignature) -> Optional[Tuple[str, str]]:
def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
structseq_name = signature.name
field_names = structseq_fieldnames(signature.returns.returns)
@ -1104,7 +1104,7 @@ def returns_str_pyi(signature: PythonSignature) -> str:
def dispatch_lambda_args(
ps: PythonSignature, f: NativeFunction, symint: bool = True
) -> Tuple[DispatchLambdaArgument, ...]:
) -> tuple[DispatchLambdaArgument, ...]:
if isinstance(ps, PythonSignatureDeprecated):
schema = ps.deprecated_schema
else:
@ -1118,7 +1118,7 @@ def dispatch_lambda_args(
method=False,
cpp_no_default_args=f.cpp_no_default_args,
)
out_args: Set[str] = {a.name for a in schema.arguments.out}
out_args: set[str] = {a.name for a in schema.arguments.out}
# Convert from cpp argument to lambda argument
def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument:
@ -1224,11 +1224,11 @@ def cpp_dispatch_target(f: NativeFunction) -> str:
def cpp_dispatch_exprs(
f: NativeFunction,
*,
python_signature: Optional[PythonSignature] = None,
) -> Tuple[str, ...]:
python_signature: PythonSignature | None = None,
) -> tuple[str, ...]:
cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments()
exprs: Tuple[str, ...] = tuple()
exprs: tuple[str, ...] = tuple()
if not isinstance(python_signature, PythonSignatureDeprecated):
# By default the exprs are consistent with the C++ signature.
exprs = tuple(a.name for a in cpp_args)
@ -1262,7 +1262,7 @@ def cpp_dispatch_exprs(
# For certain cases it is intentionally more restrictive than necessary,
# e.g.: it doesn't accepts doublelist with definite size.
def arg_parser_unpack_method(
t: Type, default: Optional[str], default_init: Optional[str], *, symint: bool = True
t: Type, default: str | None, default_init: str | None, *, symint: bool = True
) -> str:
has_default_init = default_init is not None
if has_default_init and str(t) not in (
@ -1377,7 +1377,7 @@ def arg_parser_output_expr(
# Returns a map with key = arg_name and value = PythonArgParserOutputExpr.
def arg_parser_output_exprs(
ps: PythonSignature, f: NativeFunction, *, symint: bool = True
) -> Dict[str, PythonArgParserOutputExpr]:
) -> dict[str, PythonArgParserOutputExpr]:
return {
e.name: e
for i, a in enumerate(ps.arguments())
@ -1404,8 +1404,8 @@ def dispatch_lambda_exprs(
# outputs.
arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
lambda_args = dispatch_lambda_args(ps, f, symint=symint)
inits: List[str] = []
lambda_args_exprs: Dict[str, str] = {}
inits: list[str] = []
lambda_args_exprs: dict[str, str] = {}
has_toptions = has_tensor_options(f)