mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[BE][Easy] enable postponed annotations in torchgen
(#129376)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129376 Approved by: https://github.com/ezyang ghstack dependencies: #129375
This commit is contained in:
committed by
PyTorch MergeBot
parent
59eb2897f1
commit
494057d6d4
@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
|
||||
from typing import Sequence
|
||||
|
||||
from torchgen.api import cpp
|
||||
from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
|
||||
@ -197,14 +199,14 @@ from torchgen.model import (
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PythonReturns:
|
||||
returns: Tuple[Return, ...]
|
||||
returns: tuple[Return, ...]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PythonArgument:
|
||||
name: str
|
||||
type: Type
|
||||
default: Optional[str]
|
||||
default: str | None
|
||||
|
||||
# Used to generate the default init expr for some PythonArgParser outputs, e.g.:
|
||||
#
|
||||
@ -212,7 +214,7 @@ class PythonArgument:
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# ^
|
||||
# +--- default_init str
|
||||
default_init: Optional[str]
|
||||
default_init: str | None
|
||||
|
||||
# Compute argument formal for python argument parsing.
|
||||
# Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
|
||||
@ -300,12 +302,10 @@ class PythonOutArgument(PythonArgument):
|
||||
# 'auto out = _r.tensorlist_n<2>(2);',
|
||||
# then binded to scattered C++ output arguments as 'out[0]', 'out[1]', and etc.
|
||||
# TODO: maybe don't need keep scattered out fields for python signature?
|
||||
outputs: Tuple[PythonArgument, ...]
|
||||
outputs: tuple[PythonArgument, ...]
|
||||
|
||||
@staticmethod
|
||||
def from_outputs(
|
||||
outputs: Tuple[PythonArgument, ...]
|
||||
) -> Optional["PythonOutArgument"]:
|
||||
def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None:
|
||||
if not outputs:
|
||||
return None
|
||||
|
||||
@ -339,13 +339,13 @@ class PythonSignature:
|
||||
|
||||
# Positional arguments.
|
||||
# TODO: create a dedicated SelfArgument type for 'self'?
|
||||
input_args: Tuple[PythonArgument, ...]
|
||||
input_args: tuple[PythonArgument, ...]
|
||||
|
||||
# Keyword arguments excluding the 'out' argument and scattered kwargs belonging
|
||||
# to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc).
|
||||
input_kwargs: Tuple[PythonArgument, ...]
|
||||
input_kwargs: tuple[PythonArgument, ...]
|
||||
|
||||
output_args: Optional[PythonOutArgument]
|
||||
output_args: PythonOutArgument | None
|
||||
|
||||
# Return types, which are only used by pyi
|
||||
returns: PythonReturns
|
||||
@ -356,7 +356,7 @@ class PythonSignature:
|
||||
# for out variant), in which case they will be used as scattered fields without
|
||||
# being packed into 'options'.
|
||||
# TODO: maybe create a PythonTensorOptionsArgument?
|
||||
tensor_options_args: Tuple[PythonArgument, ...]
|
||||
tensor_options_args: tuple[PythonArgument, ...]
|
||||
|
||||
# method or function signature?
|
||||
method: bool
|
||||
@ -367,8 +367,8 @@ class PythonSignature:
|
||||
|
||||
def arguments(
|
||||
self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
|
||||
) -> Tuple[Union[PythonArgument, PythonOutArgument], ...]:
|
||||
result: List[Union[PythonArgument, PythonOutArgument]] = []
|
||||
) -> tuple[PythonArgument | PythonOutArgument, ...]:
|
||||
result: list[PythonArgument | PythonOutArgument] = []
|
||||
result.extend(self.input_args)
|
||||
result.extend(self.input_kwargs)
|
||||
if self.output_args is not None and not skip_outputs:
|
||||
@ -394,7 +394,7 @@ class PythonSignature:
|
||||
# signature_str_pyi().
|
||||
def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
|
||||
args = self.arguments(skip_outputs=skip_outputs)
|
||||
schema_formals: List[str] = [
|
||||
schema_formals: list[str] = [
|
||||
a.argument_str(method=self.method, symint=symint) for a in args
|
||||
]
|
||||
positional_argc = len(self.input_args)
|
||||
@ -405,7 +405,7 @@ class PythonSignature:
|
||||
|
||||
def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
|
||||
args = self.arguments(skip_outputs=skip_outputs)
|
||||
schema_formals: List[str] = [
|
||||
schema_formals: list[str] = [
|
||||
a.argument_str_pyi(method=self.method) for a in args
|
||||
]
|
||||
positional_argc = len(self.input_args)
|
||||
@ -419,10 +419,10 @@ class PythonSignature:
|
||||
schema_formals.insert(0, "self")
|
||||
return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
|
||||
|
||||
def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]:
|
||||
def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
|
||||
# only pyi uses vararg signatures
|
||||
args = self.arguments(skip_outputs=skip_outputs)
|
||||
schema_formals: List[str] = [
|
||||
schema_formals: list[str] = [
|
||||
a.argument_str_pyi(method=self.method) for a in args
|
||||
]
|
||||
# vararg only applies to pyi signatures. vararg variants are not generated for all signatures
|
||||
@ -470,7 +470,7 @@ class PythonSignatureDeprecated(PythonSignature):
|
||||
# [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
|
||||
# [func call]: self.addmm(mat1, mat2, beta, 1)
|
||||
# We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case.
|
||||
deprecated_args_exprs: Tuple[str, ...]
|
||||
deprecated_args_exprs: tuple[str, ...]
|
||||
|
||||
@property
|
||||
def deprecated(self) -> bool:
|
||||
@ -486,7 +486,7 @@ class PythonSignatureDeprecated(PythonSignature):
|
||||
|
||||
def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
|
||||
args = self.arguments(skip_outputs=skip_outputs)
|
||||
schema_formals: List[str] = [
|
||||
schema_formals: list[str] = [
|
||||
a.argument_str_pyi(method=self.method, deprecated=True) for a in args
|
||||
]
|
||||
positional_argc = len(self.input_args)
|
||||
@ -496,7 +496,7 @@ class PythonSignatureDeprecated(PythonSignature):
|
||||
returns_str = returns_str_pyi(self)
|
||||
return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
|
||||
|
||||
def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]:
|
||||
def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
|
||||
# the codegen doesn't include vararg variants for deprecated signatures
|
||||
return None
|
||||
|
||||
@ -530,14 +530,14 @@ class PythonSignatureGroup:
|
||||
base: NativeFunction
|
||||
|
||||
# The out variant (e.g. conv2d_out)
|
||||
outplace: Optional[NativeFunction]
|
||||
outplace: NativeFunction | None
|
||||
|
||||
@classmethod
|
||||
def from_pairs(
|
||||
cls,
|
||||
functional: PythonSignatureNativeFunctionPair,
|
||||
out: Optional[PythonSignatureNativeFunctionPair],
|
||||
) -> "PythonSignatureGroup":
|
||||
out: PythonSignatureNativeFunctionPair | None,
|
||||
) -> PythonSignatureGroup:
|
||||
if out is None:
|
||||
return PythonSignatureGroup(
|
||||
signature=functional.signature,
|
||||
@ -716,7 +716,7 @@ def argument_type_str(
|
||||
raise RuntimeError(f"unrecognized type {repr(t)}")
|
||||
|
||||
|
||||
def argument_type_size(t: Type) -> Optional[int]:
|
||||
def argument_type_size(t: Type) -> int | None:
|
||||
l = t.is_list_like()
|
||||
if l is not None and str(l.elem) != "bool":
|
||||
return l.size
|
||||
@ -750,11 +750,11 @@ def signature(
|
||||
def signature_from_schema(
|
||||
func: FunctionSchema,
|
||||
*,
|
||||
category_override: Optional[str],
|
||||
category_override: str | None,
|
||||
method: bool = False,
|
||||
pyi: bool = False,
|
||||
) -> PythonSignature:
|
||||
args: List[Argument] = []
|
||||
args: list[Argument] = []
|
||||
args.extend(func.arguments.pre_self_positional)
|
||||
# Skip SelfArgument if this is method.
|
||||
if not method and func.arguments.self_arg is not None:
|
||||
@ -807,10 +807,10 @@ def signature_from_schema(
|
||||
)
|
||||
is_dummy_function = category_override == "dummy"
|
||||
|
||||
tensor_options_args: List[PythonArgument] = []
|
||||
tensor_options_args: list[PythonArgument] = []
|
||||
if (is_factory_function or is_like_or_new_function) and not is_dummy_function:
|
||||
|
||||
def topt_default_init(name: str) -> Optional[str]:
|
||||
def topt_default_init(name: str) -> str | None:
|
||||
topt_args = func.arguments.tensor_options
|
||||
if topt_args is None:
|
||||
return None
|
||||
@ -891,7 +891,7 @@ def signature_from_schema(
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||||
|
||||
|
||||
def structseq_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
|
||||
def structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]:
|
||||
if len(returns) <= 1 or all(r.name is None for r in returns):
|
||||
return []
|
||||
else:
|
||||
@ -1002,7 +1002,7 @@ def return_type_str_pyi(t: Type) -> str:
|
||||
return argument_type_str_pyi(t)
|
||||
|
||||
|
||||
def returns_structseq_pyi(signature: PythonSignature) -> Optional[Tuple[str, str]]:
|
||||
def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
|
||||
python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
|
||||
structseq_name = signature.name
|
||||
field_names = structseq_fieldnames(signature.returns.returns)
|
||||
@ -1104,7 +1104,7 @@ def returns_str_pyi(signature: PythonSignature) -> str:
|
||||
|
||||
def dispatch_lambda_args(
|
||||
ps: PythonSignature, f: NativeFunction, symint: bool = True
|
||||
) -> Tuple[DispatchLambdaArgument, ...]:
|
||||
) -> tuple[DispatchLambdaArgument, ...]:
|
||||
if isinstance(ps, PythonSignatureDeprecated):
|
||||
schema = ps.deprecated_schema
|
||||
else:
|
||||
@ -1118,7 +1118,7 @@ def dispatch_lambda_args(
|
||||
method=False,
|
||||
cpp_no_default_args=f.cpp_no_default_args,
|
||||
)
|
||||
out_args: Set[str] = {a.name for a in schema.arguments.out}
|
||||
out_args: set[str] = {a.name for a in schema.arguments.out}
|
||||
|
||||
# Convert from cpp argument to lambda argument
|
||||
def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument:
|
||||
@ -1224,11 +1224,11 @@ def cpp_dispatch_target(f: NativeFunction) -> str:
|
||||
def cpp_dispatch_exprs(
|
||||
f: NativeFunction,
|
||||
*,
|
||||
python_signature: Optional[PythonSignature] = None,
|
||||
) -> Tuple[str, ...]:
|
||||
python_signature: PythonSignature | None = None,
|
||||
) -> tuple[str, ...]:
|
||||
cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments()
|
||||
|
||||
exprs: Tuple[str, ...] = tuple()
|
||||
exprs: tuple[str, ...] = tuple()
|
||||
if not isinstance(python_signature, PythonSignatureDeprecated):
|
||||
# By default the exprs are consistent with the C++ signature.
|
||||
exprs = tuple(a.name for a in cpp_args)
|
||||
@ -1262,7 +1262,7 @@ def cpp_dispatch_exprs(
|
||||
# For certain cases it is intentionally more restrictive than necessary,
|
||||
# e.g.: it doesn't accepts doublelist with definite size.
|
||||
def arg_parser_unpack_method(
|
||||
t: Type, default: Optional[str], default_init: Optional[str], *, symint: bool = True
|
||||
t: Type, default: str | None, default_init: str | None, *, symint: bool = True
|
||||
) -> str:
|
||||
has_default_init = default_init is not None
|
||||
if has_default_init and str(t) not in (
|
||||
@ -1377,7 +1377,7 @@ def arg_parser_output_expr(
|
||||
# Returns a map with key = arg_name and value = PythonArgParserOutputExpr.
|
||||
def arg_parser_output_exprs(
|
||||
ps: PythonSignature, f: NativeFunction, *, symint: bool = True
|
||||
) -> Dict[str, PythonArgParserOutputExpr]:
|
||||
) -> dict[str, PythonArgParserOutputExpr]:
|
||||
return {
|
||||
e.name: e
|
||||
for i, a in enumerate(ps.arguments())
|
||||
@ -1404,8 +1404,8 @@ def dispatch_lambda_exprs(
|
||||
# outputs.
|
||||
arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
|
||||
lambda_args = dispatch_lambda_args(ps, f, symint=symint)
|
||||
inits: List[str] = []
|
||||
lambda_args_exprs: Dict[str, str] = {}
|
||||
inits: list[str] = []
|
||||
lambda_args_exprs: dict[str, str] = {}
|
||||
|
||||
has_toptions = has_tensor_options(f)
|
||||
|
||||
|
Reference in New Issue
Block a user