[BE][Easy] enable postponed annotations in torchgen (#129376)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129376
Approved by: https://github.com/ezyang
ghstack dependencies: #129375
This commit is contained in:
Xuehai Pan
2024-06-28 16:28:16 +08:00
committed by PyTorch MergeBot
parent 59eb2897f1
commit 494057d6d4
45 changed files with 977 additions and 901 deletions

View File

@ -1,6 +1,8 @@
from __future__ import annotations
import textwrap
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple
from typing import Sequence
from torchgen.api.translate import translate
from torchgen.api.types import DispatcherSignature
@ -32,7 +34,7 @@ def is_tensor_list(typ: Type) -> bool:
return isinstance(typ, ListType) and is_tensor(typ.elem)
def unwrap_tensor(name: str, cur_level_var: str) -> List[str]:
def unwrap_tensor(name: str, cur_level_var: str) -> list[str]:
result = f"""\
Tensor {name}_value;
optional<int64_t> {name}_bdim;
@ -40,7 +42,7 @@ def unwrap_tensor(name: str, cur_level_var: str) -> List[str]:
return textwrap.dedent(result).split("\n")
def unwrap_optional_tensor(name: str, cur_level_var: str) -> List[str]:
def unwrap_optional_tensor(name: str, cur_level_var: str) -> list[str]:
result = f"""\
optional<Tensor> {name}_value;
optional<int64_t> {name}_bdim;
@ -52,7 +54,7 @@ def unwrap_optional_tensor(name: str, cur_level_var: str) -> List[str]:
def gen_unwraps(
flat_arguments: Sequence[Argument], cur_level_var: str
) -> Tuple[str, List[str]]:
) -> tuple[str, list[str]]:
arg_names = [a.name for a in flat_arguments]
arg_types = [a.type for a in flat_arguments]
@ -99,7 +101,7 @@ if ({' && '.join(conditions)}) {{
def gen_returns(
returns: Tuple[Return, ...], cur_level_var: str, results_var: str
returns: tuple[Return, ...], cur_level_var: str, results_var: str
) -> str:
idx = 0
wrapped_returns = []
@ -132,7 +134,7 @@ def is_mutated_arg(argument: Argument) -> bool:
return argument.annotation is not None and argument.annotation.is_write
def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> Optional[str]:
def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> str | None:
# Assumptions:
# - only one argument is being modified in-place
# - the argument that is being modified in-place is the first argument
@ -197,7 +199,7 @@ template <typename batch_rule_t, batch_rule_t batch_rule>
}}"""
def gen_vmap_plumbing(native_function: NativeFunction) -> Optional[str]:
def gen_vmap_plumbing(native_function: NativeFunction) -> str | None:
schema = native_function.func
sig = DispatcherSignature.from_schema(schema)
returns = schema.returns
@ -244,7 +246,7 @@ template <typename batch_rule_t, batch_rule_t batch_rule>
@dataclass(frozen=True)
class ComputeBatchRulePlumbing:
@method_with_native_function
def __call__(self, f: NativeFunction) -> Optional[str]:
def __call__(self, f: NativeFunction) -> str | None:
result = gen_vmap_plumbing(f)
return result