[BE][Easy] enable postponed annotations in tools (#129375)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129375
Approved by: https://github.com/malfet
Author: Xuehai Pan
Date: 2024-06-29 12:48:06 +08:00
Committed by: PyTorch MergeBot
Parent: 58f346c874
Commit: 8a67daf283
123 changed files with 1274 additions and 1053 deletions
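
The substance of the change: with from __future__ import annotations (PEP 563) at the top of each file, annotations are stored as strings and never evaluated at import time, so these tools can use built-in generics such as dict[str, Any] (PEP 585) and X | None unions (PEP 604) even on Python 3.8, and the corresponding Dict/List/Tuple/Set/Optional imports from typing can be dropped. A minimal, self-contained sketch of the pattern follows; the summarize helper is illustrative only and does not appear in the diff.

# Minimal sketch of the annotation style this PR adopts, assuming Python 3.8+.
# With postponed evaluation (PEP 563), annotations are stored as strings and not
# evaluated at runtime, so built-in generics (PEP 585) and "X | None" unions
# (PEP 604) are safe to write before Python 3.9/3.10.
from __future__ import annotations

from typing import Any


def summarize(stats: dict[str, float], tags: list[str] | None = None) -> dict[str, Any]:
    # Hypothetical helper; only the annotation style mirrors the diff.
    return {"total": sum(stats.values()), "tags": tags or []}


print(summarize({"a": 1.0, "b": 2.0}))  # {'total': 3.0, 'tags': []}

The diff also adds -> None to previously unannotated test methods; a likely motivation is that mypy skips the bodies of unannotated functions by default, so annotating the return type makes the test bodies participate in type checking.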

View File

@ -5,7 +5,6 @@ import argparse
import json
import unittest
from collections import defaultdict
from unittest.mock import Mock, patch
from gen_operators_yaml import (
@ -43,10 +42,10 @@ def _mock_load_op_dep_graph():
class GenOperatorsYAMLTest(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
pass
def test_filter_creation(self):
def test_filter_creation(self) -> None:
filter_func = make_filter_from_options(
model_name="abc",
model_versions=["100", "101"],
@ -99,7 +98,7 @@ class GenOperatorsYAMLTest(unittest.TestCase):
len(filtered_configs) == 2
), f"Expected 2 elements in filtered_configs, but got {len(filtered_configs)}"
def test_verification_success(self):
def test_verification_success(self) -> None:
filter_func = make_filter_from_options(
model_name="abc",
model_versions=["100", "101"],
@ -142,7 +141,7 @@ class GenOperatorsYAMLTest(unittest.TestCase):
"expected verify_all_specified_present to succeed instead it raised an exception"
)
def test_verification_fail(self):
def test_verification_fail(self) -> None:
config = [
{
"model": {
@ -229,7 +228,7 @@ class GenOperatorsYAMLTest(unittest.TestCase):
)
def test_fill_output_with_arguments_not_include_all_overloads(
self, mock_parse_options: Mock, mock_load_op_dep_graph: Mock
):
) -> None:
parser = argparse.ArgumentParser(description="Generate used operators YAML")
options = get_parser_options(parser)

View File

@ -8,10 +8,10 @@ from tools.code_analyzer.gen_oplist import throw_if_any_op_includes_overloads
class GenOplistTest(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
pass
def test_throw_if_any_op_includes_overloads(self):
def test_throw_if_any_op_includes_overloads(self) -> None:
selective_builder = MagicMock()
selective_builder.operators = MagicMock()
selective_builder.operators.items.return_value = [

View File

@ -1,10 +1,12 @@
# For testing specific heuristics
from __future__ import annotations
import io
import json
import pathlib
import sys
import unittest
from typing import Any, Dict, List, Set
from typing import Any
from unittest import mock
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
@ -28,14 +30,14 @@ sys.path.remove(str(REPO_ROOT))
HEURISTIC_CLASS = "tools.testing.target_determination.heuristics.historical_class_failure_correlation."
def mocked_file(contents: Dict[Any, Any]) -> io.IOBase:
def mocked_file(contents: dict[Any, Any]) -> io.IOBase:
file_object = io.StringIO()
json.dump(contents, file_object)
file_object.seek(0)
return file_object
def gen_historical_class_failures() -> Dict[str, Dict[str, float]]:
def gen_historical_class_failures() -> dict[str, dict[str, float]]:
return {
"file1": {
"test1::classA": 0.5,
@ -80,8 +82,8 @@ class TestHistoricalClassFailureCorrelation(TestTD):
)
def test_get_prediction_confidence(
self,
historical_class_failures: Dict[str, Dict[str, float]],
changed_files: List[str],
historical_class_failures: dict[str, dict[str, float]],
changed_files: list[str],
) -> None:
tests_to_prioritize = ALL_TESTS
@ -113,7 +115,7 @@ class TestHistoricalClassFailureCorrelation(TestTD):
class TestParsePrevTests(TestTD):
@mock.patch("os.path.exists", return_value=False)
def test_cache_does_not_exist(self, mock_exists: Any) -> None:
expected_failing_test_files: Set[str] = set()
expected_failing_test_files: set[str] = set()
found_tests = get_previous_failures()
@ -122,7 +124,7 @@ class TestParsePrevTests(TestTD):
@mock.patch("os.path.exists", return_value=True)
@mock.patch("builtins.open", return_value=mocked_file({"": True}))
def test_empty_cache(self, mock_exists: Any, mock_open: Any) -> None:
expected_failing_test_files: Set[str] = set()
expected_failing_test_files: set[str] = set()
found_tests = get_previous_failures()
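
The mocked_file helper in the hunk above builds an in-memory JSON "file" that the tests hand to mock.patch("builtins.open", ...). A self-contained sketch of that round trip; the helper body is copied from the hunk, while the patched open() call is illustrative.

# Sketch of the mocked_file pattern from the hunk above: dump a dict into an
# in-memory StringIO, rewind it, and let code under test read it as a file.
from __future__ import annotations

import io
import json
from typing import Any
from unittest import mock


def mocked_file(contents: dict[Any, Any]) -> io.IOBase:
    file_object = io.StringIO()
    json.dump(contents, file_object)
    file_object.seek(0)
    return file_object


with mock.patch("builtins.open", return_value=mocked_file({"failed": True})):
    with open("ignored.json") as f:  # open() now returns the in-memory file
        assert json.load(f) == {"failed": True}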

View File

@ -1,7 +1,9 @@
from __future__ import annotations
import pathlib
import sys
import unittest
from typing import Any, Dict, List
from typing import Any
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
sys.path.append(str(REPO_ROOT))
@ -13,7 +15,7 @@ sys.path.remove(str(REPO_ROOT))
class TestTD(unittest.TestCase):
def assert_test_scores_almost_equal(
self, d1: Dict[TestRun, float], d2: Dict[TestRun, float]
self, d1: dict[TestRun, float], d2: dict[TestRun, float]
) -> None:
# Check that dictionaries are the same, except for floating point errors
self.assertEqual(set(d1.keys()), set(d2.keys()))
@ -24,7 +26,7 @@ class TestTD(unittest.TestCase):
# Create a dummy heuristic class
class Heuristic(interface.HeuristicInterface):
def get_prediction_confidence(
self, tests: List[str]
self, tests: list[str]
) -> interface.TestPrioritizations:
# Return junk
return interface.TestPrioritizations([], {})
@ -259,9 +261,9 @@ class TestTestPrioritizations(TestTD):
class TestAggregatedHeuristics(TestTD):
def check(
self,
tests: List[str],
test_prioritizations: List[Dict[TestRun, float]],
expected: Dict[TestRun, float],
tests: list[str],
test_prioritizations: list[dict[TestRun, float]],
expected: dict[TestRun, float],
) -> None:
aggregated_heuristics = interface.AggregatedHeuristics(tests)
for i, test_prioritization in enumerate(test_prioritizations):
@ -429,7 +431,7 @@ class TestAggregatedHeuristicsTestStats(TestTD):
stats3 = aggregator.get_test_stats(TestRun("test3"))
stats5 = aggregator.get_test_stats(TestRun("test5::classA"))
def assert_valid_dict(dict_contents: Dict[str, Any]) -> None:
def assert_valid_dict(dict_contents: dict[str, Any]) -> None:
for key, value in dict_contents.items():
self.assertTrue(isinstance(key, str))
self.assertTrue(

View File

@ -1,7 +1,9 @@
from __future__ import annotations
import pathlib
import sys
import unittest
from typing import Any, Dict
from typing import Any
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
@ -14,14 +16,14 @@ sys.path.remove(str(REPO_ROOT))
class TestHeuristicsUtils(unittest.TestCase):
def assertDictAlmostEqual(
self, first: Dict[TestRun, Any], second: Dict[TestRun, Any]
self, first: dict[TestRun, Any], second: dict[TestRun, Any]
) -> None:
self.assertEqual(first.keys(), second.keys())
for key in first.keys():
self.assertAlmostEqual(first[key], second[key])
def test_normalize_ratings(self) -> None:
ratings: Dict[TestRun, float] = {
ratings: dict[TestRun, float] = {
TestRun("test1"): 1,
TestRun("test2"): 2,
TestRun("test3"): 4,

View File

@ -1,12 +1,13 @@
from __future__ import annotations
import contextlib
import os
import typing
import unittest
import unittest.mock
from typing import Iterator, Optional, Sequence
from typing import Iterator, Sequence
import tools.setup_helpers.cmake
import tools.setup_helpers.env # noqa: F401 unused but resolves circular import
@ -79,7 +80,7 @@ class TestCMake(unittest.TestCase):
@contextlib.contextmanager
def env_var(key: str, value: Optional[str]) -> Iterator[None]:
def env_var(key: str, value: str | None) -> Iterator[None]:
"""Sets/clears an environment variable within a Python context."""
# Get the previous value and then override it.
previous_value = os.environ.get(key)
@ -91,7 +92,7 @@ def env_var(key: str, value: Optional[str]) -> Iterator[None]:
set_env_var(key, previous_value)
def set_env_var(key: str, value: Optional[str]) -> None:
def set_env_var(key: str, value: str | None) -> None:
"""Sets/clears an environment variable."""
if value is None:
os.environ.pop(key, None)
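
The hunk above only swaps Optional[str] for str | None; the surrounding env_var/set_env_var helpers are partly truncated by the viewer. A plausible reconstruction of the whole pattern, with a one-line usage, assuming the set-then-restore flow the visible lines imply.

# Plausible reconstruction of the env_var / set_env_var pattern shown above;
# the middle of the original hunk is truncated, so treat this as a sketch.
from __future__ import annotations

import contextlib
import os
from typing import Iterator


def set_env_var(key: str, value: str | None) -> None:
    # Sets/clears an environment variable.
    if value is None:
        os.environ.pop(key, None)
    else:
        os.environ[key] = value


@contextlib.contextmanager
def env_var(key: str, value: str | None) -> Iterator[None]:
    # Sets/clears an environment variable within a Python context.
    previous_value = os.environ.get(key)
    set_env_var(key, value)
    try:
        yield
    finally:
        set_env_var(key, previous_value)


with env_var("MY_FLAG", "1"):
    assert os.environ["MY_FLAG"] == "1"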

View File

@ -1,14 +1,13 @@
from __future__ import annotations
import dataclasses
import typing
import unittest
from collections import defaultdict
from typing import Dict, List
import yaml
from tools.autograd import gen_autograd_functions, load_derivatives
import torchgen.model
from torchgen import dest
from torchgen.api.types import CppSignatureGroup, DispatcherSignature
from torchgen.context import native_function_manager
@ -22,6 +21,7 @@ from torchgen.model import (
BackendIndex,
BackendMetadata,
DispatchKey,
FunctionSchema,
Location,
NativeFunction,
OperatorName,
@ -32,7 +32,7 @@ from torchgen.selective_build.selector import SelectiveBuilder
class TestCreateDerivative(unittest.TestCase):
def test_named_grads(self) -> None:
schema = torchgen.model.FunctionSchema.parse(
schema = FunctionSchema.parse(
"func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
@ -47,7 +47,7 @@ class TestCreateDerivative(unittest.TestCase):
def test_non_differentiable_output(self) -> None:
specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
schema = torchgen.model.FunctionSchema.parse(specification)
schema = FunctionSchema.parse(specification)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
_, differentiability_info = load_derivatives.create_differentiability_info(
@ -69,7 +69,7 @@ class TestCreateDerivative(unittest.TestCase):
)
def test_indexed_grads(self) -> None:
schema = torchgen.model.FunctionSchema.parse(
schema = FunctionSchema.parse(
"func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
@ -84,7 +84,7 @@ class TestCreateDerivative(unittest.TestCase):
def test_named_grads_and_indexed_grads(self) -> None:
specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
schema = torchgen.model.FunctionSchema.parse(specification)
schema = FunctionSchema.parse(specification)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
with self.assertRaisesRegex(
@ -112,7 +112,7 @@ class TestCreateDerivative(unittest.TestCase):
class TestGenAutogradFunctions(unittest.TestCase):
def test_non_differentiable_output_invalid_type(self) -> None:
specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
schema = torchgen.model.FunctionSchema.parse(specification)
schema = FunctionSchema.parse(specification)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
_, differentiability_info = load_derivatives.create_differentiability_info(
@ -141,7 +141,7 @@ class TestGenAutogradFunctions(unittest.TestCase):
def test_non_differentiable_output_output_differentiability(self) -> None:
specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y, Tensor z)"
schema = torchgen.model.FunctionSchema.parse(specification)
schema = FunctionSchema.parse(specification)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
_, differentiability_info = load_derivatives.create_differentiability_info(
@ -182,7 +182,7 @@ class TestGenAutogradFunctions(unittest.TestCase):
def test_register_bogus_dispatch_key(self) -> None:
specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
schema = torchgen.model.FunctionSchema.parse(specification)
schema = FunctionSchema.parse(specification)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
with self.assertRaisesRegex(
@ -213,17 +213,17 @@ class TestGenAutogradFunctions(unittest.TestCase):
class TestGenSchemaRegistration(unittest.TestCase):
def setUp(self) -> None:
self.selector = SelectiveBuilder.get_nop_selector()
self.custom_native_function, _ = torchgen.model.NativeFunction.from_yaml(
self.custom_native_function, _ = NativeFunction.from_yaml(
{"func": "custom::func() -> bool"},
loc=torchgen.model.Location(__file__, 1),
loc=Location(__file__, 1),
valid_tags=set(),
)
(
self.fragment_custom_native_function,
_,
) = torchgen.model.NativeFunction.from_yaml(
) = NativeFunction.from_yaml(
{"func": "quantized_decomposed::func() -> bool"},
loc=torchgen.model.Location(__file__, 1),
loc=Location(__file__, 1),
valid_tags=set(),
)
@ -285,9 +285,9 @@ TORCH_LIBRARY(custom, m) {
)
def test_3_namespaces_schema_registration_code_valid(self) -> None:
custom2_native_function, _ = torchgen.model.NativeFunction.from_yaml(
custom2_native_function, _ = NativeFunction.from_yaml(
{"func": "custom2::func() -> bool"},
loc=torchgen.model.Location(__file__, 1),
loc=Location(__file__, 1),
valid_tags=set(),
)
(
@ -320,7 +320,7 @@ class TestGenNativeFunctionDeclaration(unittest.TestCase):
def setUp(self) -> None:
self.op_1_native_function, op_1_backend_index = NativeFunction.from_yaml(
{"func": "op_1() -> bool", "dispatch": {"CPU": "kernel_1"}},
loc=torchgen.model.Location(__file__, 1),
loc=Location(__file__, 1),
valid_tags=set(),
)
self.op_2_native_function, op_2_backend_index = NativeFunction.from_yaml(
@ -328,11 +328,11 @@ class TestGenNativeFunctionDeclaration(unittest.TestCase):
"func": "op_2() -> bool",
"dispatch": {"CPU": "kernel_2", "QuantizedCPU": "custom::kernel_3"},
},
loc=torchgen.model.Location(__file__, 1),
loc=Location(__file__, 1),
valid_tags=set(),
)
backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = {
backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = {
DispatchKey.CPU: {},
DispatchKey.QuantizedCPU: {},
}
@ -382,9 +382,9 @@ TORCH_API bool kernel_1();
# Test for native_function_generation
class TestNativeFunctionGeneratrion(unittest.TestCase):
def setUp(self) -> None:
self.native_functions: List[NativeFunction] = []
self.backend_indices: Dict[
DispatchKey, Dict[OperatorName, BackendMetadata]
self.native_functions: list[NativeFunction] = []
self.backend_indices: dict[
DispatchKey, dict[OperatorName, BackendMetadata]
] = defaultdict(dict)
yaml_entry = """
- func: op(Tensor self) -> Tensor
@ -405,7 +405,7 @@ class TestNativeFunctionGeneratrion(unittest.TestCase):
"dispatch": {"CPU": "kernel_1"},
"autogen": "op_2.out",
},
loc=torchgen.model.Location(__file__, 1),
loc=Location(__file__, 1),
valid_tags=set(),
)
BackendIndex.grow_index(self.backend_indices, two_returns_backend_index)
@ -442,8 +442,8 @@ class TestNativeFunctionGeneratrion(unittest.TestCase):
# Test for static_dispatch
class TestStaticDispatchGeneratrion(unittest.TestCase):
def setUp(self) -> None:
self.backend_indices: Dict[
DispatchKey, Dict[OperatorName, BackendMetadata]
self.backend_indices: dict[
DispatchKey, dict[OperatorName, BackendMetadata]
] = defaultdict(dict)
yaml_entry = """
- func: op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
@ -500,9 +500,9 @@ class TestStaticDispatchGeneratrion(unittest.TestCase):
# Represents the most basic NativeFunction. Use dataclasses.replace()
# to edit for use.
DEFAULT_NATIVE_FUNCTION, _ = torchgen.model.NativeFunction.from_yaml(
DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml(
{"func": "func() -> bool"},
loc=torchgen.model.Location(__file__, 1),
loc=Location(__file__, 1),
valid_tags=set(),
)
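
A recurring fixture pattern in this file: build one DEFAULT_NATIVE_FUNCTION via NativeFunction.from_yaml and derive per-test variants with dataclasses.replace. A generic, torchgen-free sketch of that pattern; Spec is a stand-in class and is purely illustrative.

# Generic sketch of the "default fixture + dataclasses.replace" pattern used above.
# Spec stands in for torchgen.model.NativeFunction and is purely illustrative.
from __future__ import annotations

import dataclasses


@dataclasses.dataclass(frozen=True)
class Spec:
    func: str
    dispatch: dict[str, str] = dataclasses.field(default_factory=dict)


DEFAULT_SPEC = Spec(func="func() -> bool")

# Each test overrides only the fields it cares about; the rest of the fixture stays stable.
variant = dataclasses.replace(DEFAULT_SPEC, func="func(Tensor a, Tensor b) -> (Tensor x, Tensor y)")
assert variant.func != DEFAULT_SPEC.func and variant.dispatch == {}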

View File

@ -1,4 +1,6 @@
from typing import Any, List
from __future__ import annotations
from typing import Any
from unittest import main, TestCase
from tools.alerts.create_alerts import filter_job_names, JobStatus
@ -38,7 +40,7 @@ MOCK_TEST_DATA = [
class TestGitHubPR(TestCase):
# Should fail when jobs are ? ? Fail Fail
def test_alert(self) -> None:
modified_data: List[Any] = [{}]
modified_data: list[Any] = [{}]
modified_data.append({})
modified_data.extend(MOCK_TEST_DATA)
status = JobStatus(JOB_NAME, modified_data)

View File

@ -1,6 +1,8 @@
from __future__ import annotations
import tempfile
import unittest
from typing import Any, Dict
from typing import Any
from unittest.mock import ANY, Mock, patch
import expecttest
@ -13,10 +15,11 @@ from torchgen.model import Location, NativeFunction
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import FileManager
SPACES = " "
def _get_native_function_from_yaml(yaml_obj: Dict[str, object]) -> NativeFunction:
def _get_native_function_from_yaml(yaml_obj: dict[str, object]) -> NativeFunction:
native_function, _ = NativeFunction.from_yaml(
yaml_obj,
loc=Location(__file__, 1),
@ -33,7 +36,7 @@ class TestComputeNativeFunctionStub(expecttest.TestCase):
"""
def _test_function_schema_generates_correct_kernel(
self, obj: Dict[str, Any], expected: str
self, obj: dict[str, Any], expected: str
) -> None:
func = _get_native_function_from_yaml(obj)

View File

@ -1,13 +1,13 @@
from __future__ import annotations
import os
import tempfile
import unittest
from typing import Dict
import yaml
from torchgen.executorch.model import ETKernelIndex, ETKernelKey
from torchgen.gen import LineLoader
from torchgen.gen_executorch import (
ComputeCodegenUnboxedKernels,
gen_functions_declarations,
@ -24,6 +24,7 @@ from torchgen.model import (
)
from torchgen.selective_build.selector import SelectiveBuilder
TEST_YAML = """
- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -345,7 +346,7 @@ class TestGenFunctionsDeclarations(unittest.TestCase):
valid_tags=set(),
)
backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = {
backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = {
DispatchKey.CPU: {},
DispatchKey.QuantizedCPU: {},
}

View File

@ -4,6 +4,7 @@ from torchgen.executorch.api.types import ExecutorchCppSignature
from torchgen.local import parametrize
from torchgen.model import Location, NativeFunction
DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml(
{"func": "foo.out(Tensor input, *, Tensor(a!) out) -> Tensor(a!)"},
loc=Location(__file__, 1),

View File

@ -1,9 +1,10 @@
# Owner(s): ["module: codegen"]
from __future__ import annotations
import os
import tempfile
import unittest
from typing import Optional
import expecttest
@ -29,7 +30,7 @@ class TestGenBackendStubs(expecttest.TestCase):
run(fp.name, "", True)
def get_errors_from_gen_backend_stubs(
self, yaml_str: str, *, kernels_str: Optional[str] = None
self, yaml_str: str, *, kernels_str: str | None = None
) -> str:
with tempfile.NamedTemporaryFile(mode="w") as fp:
fp.write(yaml_str)

View File

@ -1,7 +1,7 @@
import unittest
from torchgen.selective_build.operator import * # noqa: F403
from torchgen.model import Location, NativeFunction
from torchgen.selective_build.operator import * # noqa: F403
from torchgen.selective_build.selector import (
combine_selective_builders,
SelectiveBuilder,
@ -9,7 +9,7 @@ from torchgen.selective_build.selector import (
class TestSelectiveBuild(unittest.TestCase):
def test_selective_build_operator(self):
def test_selective_build_operator(self) -> None:
op = SelectiveBuildOperator(
"aten::add.int",
is_root_operator=True,
@ -21,7 +21,7 @@ class TestSelectiveBuild(unittest.TestCase):
self.assertFalse(op.is_used_for_training)
self.assertFalse(op.include_all_overloads)
def test_selector_factory(self):
def test_selector_factory(self) -> None:
yaml_config_v1 = """
debug_info:
- model1@v100
@ -132,7 +132,7 @@ operators:
selector_legacy_v1.is_operator_selected_for_training("aten::add.float")
)
def test_operator_combine(self):
def test_operator_combine(self) -> None:
op1 = SelectiveBuildOperator(
"aten::add.int",
is_root_operator=True,
@ -177,7 +177,7 @@ operators:
self.assertRaises(Exception, gen_new_op)
def test_training_op_fetch(self):
def test_training_op_fetch(self) -> None:
yaml_config = """
operators:
aten::add.int:
@ -194,7 +194,7 @@ operators:
self.assertTrue(selector.is_operator_selected_for_training("aten::add.int"))
self.assertTrue(selector.is_operator_selected_for_training("aten::add"))
def test_kernel_dtypes(self):
def test_kernel_dtypes(self) -> None:
yaml_config = """
kernel_metadata:
add_kernel:
@ -221,7 +221,7 @@ kernel_metadata:
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16"))
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32"))
def test_merge_kernel_dtypes(self):
def test_merge_kernel_dtypes(self) -> None:
yaml_config1 = """
kernel_metadata:
add_kernel:
@ -266,7 +266,7 @@ kernel_metadata:
self.assertTrue(selector.is_kernel_dtype_selected("mul_kernel", "int8"))
self.assertFalse(selector.is_kernel_dtype_selected("mul_kernel", "int32"))
def test_all_kernel_dtypes_selected(self):
def test_all_kernel_dtypes_selected(self) -> None:
yaml_config = """
include_all_non_op_selectives: True
"""
@ -279,7 +279,7 @@ include_all_non_op_selectives: True
self.assertTrue(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "float"))
def test_custom_namespace_selected_correctly(self):
def test_custom_namespace_selected_correctly(self) -> None:
yaml_config = """
operators:
aten::add.int:
@ -301,7 +301,7 @@ operators:
class TestExecuTorchSelectiveBuild(unittest.TestCase):
def test_et_kernel_selected(self):
def test_et_kernel_selected(self) -> None:
yaml_config = """
et_kernel_metadata:
aten::add.out:

View File

@ -1,10 +1,11 @@
from __future__ import annotations
import functools
import pathlib
import random
import sys
import unittest
from collections import defaultdict
from typing import Dict, List, Tuple
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
try:
@ -17,12 +18,12 @@ except ModuleNotFoundError:
sys.exit(1)
def gen_class_times(test_times: Dict[str, float]) -> Dict[str, Dict[str, float]]:
def gen_class_times(test_times: dict[str, float]) -> dict[str, dict[str, float]]:
return {k: {"class1": v} for k, v in test_times.items()}
class TestCalculateShards(unittest.TestCase):
tests: List[TestRun] = [
tests: list[TestRun] = [
TestRun("super_long_test"),
TestRun("long_test1"),
TestRun("long_test2"),
@ -36,7 +37,7 @@ class TestCalculateShards(unittest.TestCase):
TestRun("short_test5"),
]
test_times: Dict[str, float] = {
test_times: dict[str, float] = {
"super_long_test": 55,
"long_test1": 22,
"long_test2": 18,
@ -50,7 +51,7 @@ class TestCalculateShards(unittest.TestCase):
"short_test5": 0.01,
}
test_class_times: Dict[str, Dict[str, float]] = {
test_class_times: dict[str, dict[str, float]] = {
"super_long_test": {"class1": 55},
"long_test1": {"class1": 1, "class2": 21},
"long_test2": {"class1": 10, "class2": 8},
@ -66,8 +67,8 @@ class TestCalculateShards(unittest.TestCase):
def assert_shards_equal(
self,
expected_shards: List[Tuple[float, List[ShardedTest]]],
actual_shards: List[Tuple[float, List[ShardedTest]]],
expected_shards: list[tuple[float, list[ShardedTest]]],
actual_shards: list[tuple[float, list[ShardedTest]]],
) -> None:
for expected, actual in zip(expected_shards, actual_shards):
self.assertAlmostEqual(expected[0], actual[0])
@ -363,7 +364,7 @@ class TestCalculateShards(unittest.TestCase):
)
def test_split_shards(self) -> None:
test_times: Dict[str, float] = {"test1": THRESHOLD, "test2": THRESHOLD}
test_times: dict[str, float] = {"test1": THRESHOLD, "test2": THRESHOLD}
expected_shards = [
(600.0, [ShardedTest(test="test1", shard=1, num_shards=1, time=THRESHOLD)]),
(600.0, [ShardedTest(test="test2", shard=1, num_shards=1, time=THRESHOLD)]),
@ -438,7 +439,7 @@ class TestCalculateShards(unittest.TestCase):
tests = [TestRun(x) for x in test_names]
serial = [x for x in test_names if random.randint(0, 1) == 0]
has_times = [x for x in test_names if random.randint(0, 1) == 0]
random_times: Dict[str, float] = {
random_times: dict[str, float] = {
i: random.randint(0, THRESHOLD * 10) for i in has_times
}
sort_by_time = random.randint(0, 1) == 0
@ -456,7 +457,7 @@ class TestCalculateShards(unittest.TestCase):
max_diff = max(times) - min(times)
self.assertTrue(max_diff <= THRESHOLD + (num_tests - len(has_times)) * 60)
all_sharded_tests: Dict[str, List[ShardedTest]] = defaultdict(list)
all_sharded_tests: dict[str, list[ShardedTest]] = defaultdict(list)
for _, sharded_tests in shards:
for sharded_test in sharded_tests:
all_sharded_tests[sharded_test.name].append(sharded_test)

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import decimal
import inspect
import pathlib
import sys
import unittest
from typing import Any, Dict
from typing import Any
from unittest import mock
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
@ -81,9 +83,9 @@ class TestUploadStats(unittest.TestCase):
}
# Preserve the metric emitted
emitted_metric: Dict[str, Any] = {}
emitted_metric: dict[str, Any] = {}
def mock_put_item(Item: Dict[str, Any]) -> None:
def mock_put_item(Item: dict[str, Any]) -> None:
nonlocal emitted_metric
emitted_metric = Item
@ -115,9 +117,9 @@ class TestUploadStats(unittest.TestCase):
}
# Preserve the metric emitted
emitted_metric: Dict[str, Any] = {}
emitted_metric: dict[str, Any] = {}
def mock_put_item(Item: Dict[str, Any]) -> None:
def mock_put_item(Item: dict[str, Any]) -> None:
nonlocal emitted_metric
emitted_metric = Item
@ -151,9 +153,9 @@ class TestUploadStats(unittest.TestCase):
}
# Preserve the metric emitted
emitted_metric: Dict[str, Any] = {}
emitted_metric: dict[str, Any] = {}
def mock_put_item(Item: Dict[str, Any]) -> None:
def mock_put_item(Item: dict[str, Any]) -> None:
nonlocal emitted_metric
emitted_metric = Item
@ -187,9 +189,9 @@ class TestUploadStats(unittest.TestCase):
).start()
# Preserve the metric emitted
emitted_metric: Dict[str, Any] = {}
emitted_metric: dict[str, Any] = {}
def mock_put_item(Item: Dict[str, Any]) -> None:
def mock_put_item(Item: dict[str, Any]) -> None:
nonlocal emitted_metric
emitted_metric = Item
@ -208,7 +210,7 @@ class TestUploadStats(unittest.TestCase):
) -> None:
metric = {"some_number": 123}
emit_should_include: Dict[str, Any] = metric.copy()
emit_should_include: dict[str, Any] = metric.copy()
# Github Actions defaults some env vars to an empty string
default_val = ""
@ -220,9 +222,9 @@ class TestUploadStats(unittest.TestCase):
).start()
# Preserve the metric emitted
emitted_metric: Dict[str, Any] = {}
emitted_metric: dict[str, Any] = {}
def mock_put_item(Item: Dict[str, Any]) -> None:
def mock_put_item(Item: dict[str, Any]) -> None:
nonlocal emitted_metric
emitted_metric = Item
@ -264,7 +266,7 @@ class TestUploadStats(unittest.TestCase):
put_item_invoked = False
def mock_put_item(Item: Dict[str, Any]) -> None:
def mock_put_item(Item: dict[str, Any]) -> None:
nonlocal put_item_invoked
put_item_invoked = True
@ -289,7 +291,7 @@ class TestUploadStats(unittest.TestCase):
put_item_invoked = False
def mock_put_item(Item: Dict[str, Any]) -> None:
def mock_put_item(Item: dict[str, Any]) -> None:
nonlocal put_item_invoked
put_item_invoked = True
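
The upload_stats tests above capture whatever the code under test tries to emit by routing a mock's side_effect into a closure. A self-contained sketch of that capture pattern; the table mock here is a stand-in, since the object the real tests patch lies outside the visible hunks.

# Sketch of the nonlocal-capture pattern used in the hunks above: a closure
# records the Item passed to a mocked put_item so the test can assert on it.
from __future__ import annotations

from typing import Any
from unittest.mock import MagicMock


def capture_emitted_metric() -> dict[str, Any]:
    emitted_metric: dict[str, Any] = {}

    def mock_put_item(Item: dict[str, Any]) -> None:
        nonlocal emitted_metric
        emitted_metric = Item

    table = MagicMock()  # stand-in for the real upload target
    table.put_item.side_effect = mock_put_item
    table.put_item(Item={"some_number": 123})
    return emitted_metric


assert capture_emitted_metric() == {"some_number": 123}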

View File

@ -3,6 +3,7 @@ import unittest
from tools.stats.upload_test_stats import get_tests, summarize_test_cases
IN_CI = os.environ.get("CI")

View File

@ -3,6 +3,7 @@ import unittest
from tools.gen_vulkan_spv import DEFAULT_ENV, SPVGenerator
####################
# Data for testing #
####################