Revert "[BE][Easy] enable postponed annotations in torchgen (#129376)"

This reverts commit 494057d6d4e9b40daf81a6a4d7a8c839b7424b14.

Reverted https://github.com/pytorch/pytorch/pull/129376 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I need to revert to cleanly revert https://github.com/pytorch/pytorch/pull/129374, please do a rebase and reland this ([comment](https://github.com/pytorch/pytorch/pull/129375#issuecomment-2197800541))
This commit is contained in:
PyTorch MergeBot
2024-06-29 00:44:24 +00:00
parent 83caf4960f
commit 6063bb9d45
45 changed files with 900 additions and 976 deletions

View File

@ -1,5 +1,3 @@
from __future__ import annotations
import contextlib
import functools
import hashlib
@ -7,29 +5,31 @@ import os
import re
import sys
import textwrap
from argparse import Namespace
from dataclasses import fields, is_dataclass
from enum import auto, Enum
from typing import (
Any,
Callable,
Dict,
Generic,
Iterable,
Iterator,
List,
Literal,
NoReturn,
Optional,
Sequence,
TYPE_CHECKING,
Set,
Tuple,
TypeVar,
Union,
)
from typing_extensions import Self
from torchgen.code_template import CodeTemplate
if TYPE_CHECKING:
from argparse import Namespace
# Many of these functions share logic for defining both the definition
# and declaration (for example, the function signature is the same), so
# we organize them into one function that takes a Target to say which
@ -57,7 +57,7 @@ IDENT_REGEX = r"(^|\W){}($|\W)"
# TODO: Use a real parser here; this will get bamboozled
def split_name_params(schema: str) -> tuple[str, list[str]]:
def split_name_params(schema: str) -> Tuple[str, List[str]]:
m = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
if m is None:
raise RuntimeError(f"Unsupported function schema: {schema}")
@ -73,7 +73,7 @@ S = TypeVar("S")
# Map over function that may return None; omit Nones from output sequence
def mapMaybe(func: Callable[[T], S | None], xs: Iterable[T]) -> Iterator[S]:
def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
for x in xs:
r = func(x)
if r is not None:
@ -127,7 +127,7 @@ class FileManager:
install_dir: str
template_dir: str
dry_run: bool
filenames: set[str]
filenames: Set[str]
def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
self.install_dir = install_dir
@ -136,7 +136,7 @@ class FileManager:
self.dry_run = dry_run
def _write_if_changed(self, filename: str, contents: str) -> None:
old_contents: str | None
old_contents: Optional[str]
try:
with open(filename) as f:
old_contents = f.read()
@ -150,7 +150,7 @@ class FileManager:
# Read from template file and replace pattern with callable (type could be dict or str).
def substitute_with_template(
self, template_fn: str, env_callable: Callable[[], str | dict[str, Any]]
self, template_fn: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]
) -> str:
template_path = os.path.join(self.template_dir, template_fn)
env = env_callable()
@ -171,7 +171,7 @@ class FileManager:
self,
filename: str,
template_fn: str,
env_callable: Callable[[], str | dict[str, Any]],
env_callable: Callable[[], Union[str, Dict[str, Any]]],
) -> None:
filename = f"{self.install_dir}/{filename}"
assert filename not in self.filenames, "duplicate file write {filename}"
@ -186,7 +186,7 @@ class FileManager:
def write(
self,
filename: str,
env_callable: Callable[[], str | dict[str, Any]],
env_callable: Callable[[], Union[str, Dict[str, Any]]],
) -> None:
self.write_with_template(filename, filename, env_callable)
@ -196,13 +196,13 @@ class FileManager:
items: Iterable[T],
*,
key_fn: Callable[[T], str],
env_callable: Callable[[T], dict[str, list[str]]],
env_callable: Callable[[T], Dict[str, List[str]]],
num_shards: int,
base_env: dict[str, Any] | None = None,
sharded_keys: set[str],
base_env: Optional[Dict[str, Any]] = None,
sharded_keys: Set[str],
) -> None:
everything: dict[str, Any] = {"shard_id": "Everything"}
shards: list[dict[str, Any]] = [
everything: Dict[str, Any] = {"shard_id": "Everything"}
shards: List[Dict[str, Any]] = [
{"shard_id": f"_{i}"} for i in range(num_shards)
]
all_shards = [everything] + shards
@ -221,7 +221,7 @@ class FileManager:
else:
shard[key] = []
def merge_env(into: dict[str, list[str]], from_: dict[str, list[str]]) -> None:
def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
for k, v in from_.items():
assert k in sharded_keys, f"undeclared sharded key {k}"
into[k] += v
@ -275,7 +275,7 @@ class FileManager:
# Helper function to generate file manager
def make_file_manager(
options: Namespace, install_dir: str | None = None
options: Namespace, install_dir: Optional[str] = None
) -> FileManager:
template_dir = os.path.join(options.source_path, "templates")
install_dir = install_dir if install_dir else options.install_dir
@ -335,7 +335,7 @@ def _pformat(
def _format_dict(
attr: dict[Any, Any],
attr: Dict[Any, Any],
indent: int,
width: int,
curr_indent: int,
@ -355,7 +355,7 @@ def _format_dict(
def _format_list(
attr: list[Any] | set[Any] | tuple[Any, ...],
attr: Union[List[Any], Set[Any], Tuple[Any, ...]],
indent: int,
width: int,
curr_indent: int,
@ -370,7 +370,7 @@ def _format_list(
def _format(
fields_str: list[str],
fields_str: List[str],
indent: int,
width: int,
curr_indent: int,
@ -402,9 +402,7 @@ class NamespaceHelper:
} // namespace torch
"""
def __init__(
self, namespace_str: str, entity_name: str = "", max_level: int = 2
) -> None:
def __init__(self, namespace_str: str, entity_name: str = "", max_level: int = 2):
# cpp_namespace can be a colon joined string such as torch::lazy
cpp_namespaces = namespace_str.split("::")
assert (
@ -421,7 +419,7 @@ class NamespaceHelper:
@staticmethod
def from_namespaced_entity(
namespaced_entity: str, max_level: int = 2
) -> NamespaceHelper:
) -> "NamespaceHelper":
"""
Generate helper from nested namespaces as long as class/function name. E.g.: "torch::lazy::add"
"""
@ -454,9 +452,9 @@ class NamespaceHelper:
class OrderedSet(Generic[T]):
storage: dict[T, Literal[None]]
storage: Dict[T, Literal[None]]
def __init__(self, iterable: Iterable[T] | None = None) -> None:
def __init__(self, iterable: Optional[Iterable[T]] = None):
if iterable is None:
self.storage = {}
else:
@ -468,28 +466,28 @@ class OrderedSet(Generic[T]):
def __iter__(self) -> Iterator[T]:
return iter(self.storage.keys())
def update(self, items: OrderedSet[T]) -> None:
def update(self, items: "OrderedSet[T]") -> None:
self.storage.update(items.storage)
def add(self, item: T) -> None:
self.storage[item] = None
def copy(self) -> OrderedSet[T]:
def copy(self) -> "OrderedSet[T]":
ret: OrderedSet[T] = OrderedSet()
ret.storage = self.storage.copy()
return ret
@staticmethod
def union(*args: OrderedSet[T]) -> OrderedSet[T]:
def union(*args: "OrderedSet[T]") -> "OrderedSet[T]":
ret = args[0].copy()
for s in args[1:]:
ret.update(s)
return ret
def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]:
def __or__(self, other: "OrderedSet[T]") -> "OrderedSet[T]":
return OrderedSet.union(self, other)
def __ior__(self, other: OrderedSet[T]) -> Self:
def __ior__(self, other: "OrderedSet[T]") -> Self:
self.update(other)
return self