mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE][Easy] enable postponed annotations in tools
(#129375)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129375 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
58f346c874
commit
8a67daf283
@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import pathlib
|
||||
import random
|
||||
import sys
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
|
||||
try:
|
||||
@ -17,12 +18,12 @@ except ModuleNotFoundError:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def gen_class_times(test_times: Dict[str, float]) -> Dict[str, Dict[str, float]]:
|
||||
def gen_class_times(test_times: dict[str, float]) -> dict[str, dict[str, float]]:
|
||||
return {k: {"class1": v} for k, v in test_times.items()}
|
||||
|
||||
|
||||
class TestCalculateShards(unittest.TestCase):
|
||||
tests: List[TestRun] = [
|
||||
tests: list[TestRun] = [
|
||||
TestRun("super_long_test"),
|
||||
TestRun("long_test1"),
|
||||
TestRun("long_test2"),
|
||||
@ -36,7 +37,7 @@ class TestCalculateShards(unittest.TestCase):
|
||||
TestRun("short_test5"),
|
||||
]
|
||||
|
||||
test_times: Dict[str, float] = {
|
||||
test_times: dict[str, float] = {
|
||||
"super_long_test": 55,
|
||||
"long_test1": 22,
|
||||
"long_test2": 18,
|
||||
@ -50,7 +51,7 @@ class TestCalculateShards(unittest.TestCase):
|
||||
"short_test5": 0.01,
|
||||
}
|
||||
|
||||
test_class_times: Dict[str, Dict[str, float]] = {
|
||||
test_class_times: dict[str, dict[str, float]] = {
|
||||
"super_long_test": {"class1": 55},
|
||||
"long_test1": {"class1": 1, "class2": 21},
|
||||
"long_test2": {"class1": 10, "class2": 8},
|
||||
@ -66,8 +67,8 @@ class TestCalculateShards(unittest.TestCase):
|
||||
|
||||
def assert_shards_equal(
|
||||
self,
|
||||
expected_shards: List[Tuple[float, List[ShardedTest]]],
|
||||
actual_shards: List[Tuple[float, List[ShardedTest]]],
|
||||
expected_shards: list[tuple[float, list[ShardedTest]]],
|
||||
actual_shards: list[tuple[float, list[ShardedTest]]],
|
||||
) -> None:
|
||||
for expected, actual in zip(expected_shards, actual_shards):
|
||||
self.assertAlmostEqual(expected[0], actual[0])
|
||||
@ -363,7 +364,7 @@ class TestCalculateShards(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_split_shards(self) -> None:
|
||||
test_times: Dict[str, float] = {"test1": THRESHOLD, "test2": THRESHOLD}
|
||||
test_times: dict[str, float] = {"test1": THRESHOLD, "test2": THRESHOLD}
|
||||
expected_shards = [
|
||||
(600.0, [ShardedTest(test="test1", shard=1, num_shards=1, time=THRESHOLD)]),
|
||||
(600.0, [ShardedTest(test="test2", shard=1, num_shards=1, time=THRESHOLD)]),
|
||||
@ -438,7 +439,7 @@ class TestCalculateShards(unittest.TestCase):
|
||||
tests = [TestRun(x) for x in test_names]
|
||||
serial = [x for x in test_names if random.randint(0, 1) == 0]
|
||||
has_times = [x for x in test_names if random.randint(0, 1) == 0]
|
||||
random_times: Dict[str, float] = {
|
||||
random_times: dict[str, float] = {
|
||||
i: random.randint(0, THRESHOLD * 10) for i in has_times
|
||||
}
|
||||
sort_by_time = random.randint(0, 1) == 0
|
||||
@ -456,7 +457,7 @@ class TestCalculateShards(unittest.TestCase):
|
||||
max_diff = max(times) - min(times)
|
||||
self.assertTrue(max_diff <= THRESHOLD + (num_tests - len(has_times)) * 60)
|
||||
|
||||
all_sharded_tests: Dict[str, List[ShardedTest]] = defaultdict(list)
|
||||
all_sharded_tests: dict[str, list[ShardedTest]] = defaultdict(list)
|
||||
for _, sharded_tests in shards:
|
||||
for sharded_test in sharded_tests:
|
||||
all_sharded_tests[sharded_test.name].append(sharded_test)
|
||||
|
Reference in New Issue
Block a user