mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE][Easy] enable postponed annotations in torchgen
(#129376)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129376 Approved by: https://github.com/ezyang ghstack dependencies: #129375
This commit is contained in:
committed by
PyTorch MergeBot
parent
59eb2897f1
commit
494057d6d4
@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import textwrap
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
from typing import Sequence
|
||||
|
||||
from torchgen.api.translate import translate
|
||||
from torchgen.api.types import DispatcherSignature
|
||||
@ -32,7 +34,7 @@ def is_tensor_list(typ: Type) -> bool:
|
||||
return isinstance(typ, ListType) and is_tensor(typ.elem)
|
||||
|
||||
|
||||
def unwrap_tensor(name: str, cur_level_var: str) -> List[str]:
|
||||
def unwrap_tensor(name: str, cur_level_var: str) -> list[str]:
|
||||
result = f"""\
|
||||
Tensor {name}_value;
|
||||
optional<int64_t> {name}_bdim;
|
||||
@ -40,7 +42,7 @@ def unwrap_tensor(name: str, cur_level_var: str) -> List[str]:
|
||||
return textwrap.dedent(result).split("\n")
|
||||
|
||||
|
||||
def unwrap_optional_tensor(name: str, cur_level_var: str) -> List[str]:
|
||||
def unwrap_optional_tensor(name: str, cur_level_var: str) -> list[str]:
|
||||
result = f"""\
|
||||
optional<Tensor> {name}_value;
|
||||
optional<int64_t> {name}_bdim;
|
||||
@ -52,7 +54,7 @@ def unwrap_optional_tensor(name: str, cur_level_var: str) -> List[str]:
|
||||
|
||||
def gen_unwraps(
|
||||
flat_arguments: Sequence[Argument], cur_level_var: str
|
||||
) -> Tuple[str, List[str]]:
|
||||
) -> tuple[str, list[str]]:
|
||||
arg_names = [a.name for a in flat_arguments]
|
||||
arg_types = [a.type for a in flat_arguments]
|
||||
|
||||
@ -99,7 +101,7 @@ if ({' && '.join(conditions)}) {{
|
||||
|
||||
|
||||
def gen_returns(
|
||||
returns: Tuple[Return, ...], cur_level_var: str, results_var: str
|
||||
returns: tuple[Return, ...], cur_level_var: str, results_var: str
|
||||
) -> str:
|
||||
idx = 0
|
||||
wrapped_returns = []
|
||||
@ -132,7 +134,7 @@ def is_mutated_arg(argument: Argument) -> bool:
|
||||
return argument.annotation is not None and argument.annotation.is_write
|
||||
|
||||
|
||||
def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> Optional[str]:
|
||||
def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> str | None:
|
||||
# Assumptions:
|
||||
# - only one argument is being modified in-place
|
||||
# - the argument that is being modified in-place is the first argument
|
||||
@ -197,7 +199,7 @@ template <typename batch_rule_t, batch_rule_t batch_rule>
|
||||
}}"""
|
||||
|
||||
|
||||
def gen_vmap_plumbing(native_function: NativeFunction) -> Optional[str]:
|
||||
def gen_vmap_plumbing(native_function: NativeFunction) -> str | None:
|
||||
schema = native_function.func
|
||||
sig = DispatcherSignature.from_schema(schema)
|
||||
returns = schema.returns
|
||||
@ -244,7 +246,7 @@ template <typename batch_rule_t, batch_rule_t batch_rule>
|
||||
@dataclass(frozen=True)
|
||||
class ComputeBatchRulePlumbing:
|
||||
@method_with_native_function
|
||||
def __call__(self, f: NativeFunction) -> Optional[str]:
|
||||
def __call__(self, f: NativeFunction) -> str | None:
|
||||
result = gen_vmap_plumbing(f)
|
||||
return result
|
||||
|
||||
|
Reference in New Issue
Block a user