mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE][Easy] enable postponed annotations in tools (#129375)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129375 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
2e3ff394bf
commit
59eb2897f1
@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import copy
|
||||
from functools import total_ordering
|
||||
from typing import Any, Dict, FrozenSet, Iterable, List, Optional, Union
|
||||
from typing import Any, Iterable
|
||||
|
||||
|
||||
class TestRun:
|
||||
@ -15,16 +17,16 @@ class TestRun:
|
||||
"""
|
||||
|
||||
test_file: str
|
||||
_excluded: FrozenSet[str] # Tests that should be excluded from this test run
|
||||
_included: FrozenSet[
|
||||
_excluded: frozenset[str] # Tests that should be excluded from this test run
|
||||
_included: frozenset[
|
||||
str
|
||||
] # If non-empy, only these tests should be run in this test run
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
excluded: Optional[Iterable[str]] = None,
|
||||
included: Optional[Iterable[str]] = None,
|
||||
excluded: Iterable[str] | None = None,
|
||||
included: Iterable[str] | None = None,
|
||||
) -> None:
|
||||
if excluded and included:
|
||||
raise ValueError("Can't specify both included and excluded")
|
||||
@ -45,7 +47,7 @@ class TestRun:
|
||||
self._included = frozenset(ins)
|
||||
|
||||
@staticmethod
|
||||
def empty() -> "TestRun":
|
||||
def empty() -> TestRun:
|
||||
return TestRun("")
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
@ -56,10 +58,10 @@ class TestRun:
|
||||
def is_full_file(self) -> bool:
|
||||
return not self._included and not self._excluded
|
||||
|
||||
def included(self) -> FrozenSet[str]:
|
||||
def included(self) -> frozenset[str]:
|
||||
return self._included
|
||||
|
||||
def excluded(self) -> FrozenSet[str]:
|
||||
def excluded(self) -> frozenset[str]:
|
||||
return self._excluded
|
||||
|
||||
def get_pytest_filter(self) -> str:
|
||||
@ -70,7 +72,7 @@ class TestRun:
|
||||
else:
|
||||
return ""
|
||||
|
||||
def contains(self, test: "TestRun") -> bool:
|
||||
def contains(self, test: TestRun) -> bool:
|
||||
if self.test_file != test.test_file:
|
||||
return False
|
||||
|
||||
@ -92,7 +94,7 @@ class TestRun:
|
||||
# Does self exclude anything test includes? If not, we're good
|
||||
return not self._excluded.intersection(test._included)
|
||||
|
||||
def __copy__(self) -> "TestRun":
|
||||
def __copy__(self) -> TestRun:
|
||||
return TestRun(self.test_file, excluded=self._excluded, included=self._included)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
@ -126,7 +128,7 @@ class TestRun:
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.test_file, self._included, self._excluded))
|
||||
|
||||
def __or__(self, other: "TestRun") -> "TestRun":
|
||||
def __or__(self, other: TestRun) -> TestRun:
|
||||
"""
|
||||
To OR/Union test runs means to run all the tests that either of the two runs specify.
|
||||
"""
|
||||
@ -167,7 +169,7 @@ class TestRun:
|
||||
excluded = self._excluded | other._excluded
|
||||
return TestRun(self.test_file, excluded=excluded - included)
|
||||
|
||||
def __sub__(self, other: "TestRun") -> "TestRun":
|
||||
def __sub__(self, other: TestRun) -> TestRun:
|
||||
"""
|
||||
To subtract test runs means to run all the tests in the first run except for what the second run specifies.
|
||||
"""
|
||||
@ -186,7 +188,7 @@ class TestRun:
|
||||
if other.is_full_file():
|
||||
return TestRun.empty()
|
||||
|
||||
def return_inclusions_or_empty(inclusions: FrozenSet[str]) -> TestRun:
|
||||
def return_inclusions_or_empty(inclusions: frozenset[str]) -> TestRun:
|
||||
if inclusions:
|
||||
return TestRun(self.test_file, included=inclusions)
|
||||
return TestRun.empty()
|
||||
@ -204,14 +206,14 @@ class TestRun:
|
||||
else:
|
||||
return return_inclusions_or_empty(other._excluded - self._excluded)
|
||||
|
||||
def __and__(self, other: "TestRun") -> "TestRun":
|
||||
def __and__(self, other: TestRun) -> TestRun:
|
||||
if self.test_file != other.test_file:
|
||||
return TestRun.empty()
|
||||
|
||||
return (self | other) - (self - other) - (other - self)
|
||||
|
||||
def to_json(self) -> Dict[str, Any]:
|
||||
r: Dict[str, Any] = {
|
||||
def to_json(self) -> dict[str, Any]:
|
||||
r: dict[str, Any] = {
|
||||
"test_file": self.test_file,
|
||||
}
|
||||
if self._included:
|
||||
@ -221,7 +223,7 @@ class TestRun:
|
||||
return r
|
||||
|
||||
@staticmethod
|
||||
def from_json(json: Dict[str, Any]) -> "TestRun":
|
||||
def from_json(json: dict[str, Any]) -> TestRun:
|
||||
return TestRun(
|
||||
json["test_file"],
|
||||
included=json.get("included", []),
|
||||
@ -234,14 +236,14 @@ class ShardedTest:
|
||||
test: TestRun
|
||||
shard: int
|
||||
num_shards: int
|
||||
time: Optional[float] # In seconds
|
||||
time: float | None # In seconds
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
test: Union[TestRun, str],
|
||||
test: TestRun | str,
|
||||
shard: int,
|
||||
num_shards: int,
|
||||
time: Optional[float] = None,
|
||||
time: float | None = None,
|
||||
) -> None:
|
||||
if isinstance(test, str):
|
||||
test = TestRun(test)
|
||||
@ -296,7 +298,7 @@ class ShardedTest:
|
||||
def get_time(self, default: float = 0) -> float:
|
||||
return self.time if self.time is not None else default
|
||||
|
||||
def get_pytest_args(self) -> List[str]:
|
||||
def get_pytest_args(self) -> list[str]:
|
||||
filter = self.test.get_pytest_filter()
|
||||
if filter:
|
||||
return ["-k", self.test.get_pytest_filter()]
|
||||
|
||||
Reference in New Issue
Block a user