[BE][Easy] enable postponed annotations in tools (#129375)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129375
Approved by: https://github.com/malfet
This commit is contained in:
Xuehai Pan
2024-06-28 16:28:16 +08:00
committed by PyTorch MergeBot
parent 2e3ff394bf
commit 59eb2897f1
123 changed files with 1274 additions and 1051 deletions

View File

@ -1,6 +1,8 @@
from __future__ import annotations
from copy import copy
from functools import total_ordering
from typing import Any, Dict, FrozenSet, Iterable, List, Optional, Union
from typing import Any, Iterable
class TestRun:
@ -15,16 +17,16 @@ class TestRun:
"""
test_file: str
_excluded: FrozenSet[str] # Tests that should be excluded from this test run
_included: FrozenSet[
_excluded: frozenset[str] # Tests that should be excluded from this test run
_included: frozenset[
str
] # If non-empy, only these tests should be run in this test run
def __init__(
self,
name: str,
excluded: Optional[Iterable[str]] = None,
included: Optional[Iterable[str]] = None,
excluded: Iterable[str] | None = None,
included: Iterable[str] | None = None,
) -> None:
if excluded and included:
raise ValueError("Can't specify both included and excluded")
@ -45,7 +47,7 @@ class TestRun:
self._included = frozenset(ins)
@staticmethod
def empty() -> "TestRun":
def empty() -> TestRun:
return TestRun("")
def is_empty(self) -> bool:
@ -56,10 +58,10 @@ class TestRun:
def is_full_file(self) -> bool:
return not self._included and not self._excluded
def included(self) -> FrozenSet[str]:
def included(self) -> frozenset[str]:
return self._included
def excluded(self) -> FrozenSet[str]:
def excluded(self) -> frozenset[str]:
return self._excluded
def get_pytest_filter(self) -> str:
@ -70,7 +72,7 @@ class TestRun:
else:
return ""
def contains(self, test: "TestRun") -> bool:
def contains(self, test: TestRun) -> bool:
if self.test_file != test.test_file:
return False
@ -92,7 +94,7 @@ class TestRun:
# Does self exclude anything test includes? If not, we're good
return not self._excluded.intersection(test._included)
def __copy__(self) -> "TestRun":
def __copy__(self) -> TestRun:
return TestRun(self.test_file, excluded=self._excluded, included=self._included)
def __bool__(self) -> bool:
@ -126,7 +128,7 @@ class TestRun:
def __hash__(self) -> int:
return hash((self.test_file, self._included, self._excluded))
def __or__(self, other: "TestRun") -> "TestRun":
def __or__(self, other: TestRun) -> TestRun:
"""
To OR/Union test runs means to run all the tests that either of the two runs specify.
"""
@ -167,7 +169,7 @@ class TestRun:
excluded = self._excluded | other._excluded
return TestRun(self.test_file, excluded=excluded - included)
def __sub__(self, other: "TestRun") -> "TestRun":
def __sub__(self, other: TestRun) -> TestRun:
"""
To subtract test runs means to run all the tests in the first run except for what the second run specifies.
"""
@ -186,7 +188,7 @@ class TestRun:
if other.is_full_file():
return TestRun.empty()
def return_inclusions_or_empty(inclusions: FrozenSet[str]) -> TestRun:
def return_inclusions_or_empty(inclusions: frozenset[str]) -> TestRun:
if inclusions:
return TestRun(self.test_file, included=inclusions)
return TestRun.empty()
@ -204,14 +206,14 @@ class TestRun:
else:
return return_inclusions_or_empty(other._excluded - self._excluded)
def __and__(self, other: "TestRun") -> "TestRun":
def __and__(self, other: TestRun) -> TestRun:
if self.test_file != other.test_file:
return TestRun.empty()
return (self | other) - (self - other) - (other - self)
def to_json(self) -> Dict[str, Any]:
r: Dict[str, Any] = {
def to_json(self) -> dict[str, Any]:
r: dict[str, Any] = {
"test_file": self.test_file,
}
if self._included:
@ -221,7 +223,7 @@ class TestRun:
return r
@staticmethod
def from_json(json: Dict[str, Any]) -> "TestRun":
def from_json(json: dict[str, Any]) -> TestRun:
return TestRun(
json["test_file"],
included=json.get("included", []),
@ -234,14 +236,14 @@ class ShardedTest:
test: TestRun
shard: int
num_shards: int
time: Optional[float] # In seconds
time: float | None # In seconds
def __init__(
self,
test: Union[TestRun, str],
test: TestRun | str,
shard: int,
num_shards: int,
time: Optional[float] = None,
time: float | None = None,
) -> None:
if isinstance(test, str):
test = TestRun(test)
@ -296,7 +298,7 @@ class ShardedTest:
def get_time(self, default: float = 0) -> float:
return self.time if self.time is not None else default
def get_pytest_args(self) -> List[str]:
def get_pytest_args(self) -> list[str]:
filter = self.test.get_pytest_filter()
if filter:
return ["-k", self.test.get_pytest_filter()]