diff --git a/tools/alerts/create_alerts.py b/tools/alerts/create_alerts.py index 33e595f98dad..97607e07fa0a 100644 --- a/tools/alerts/create_alerts.py +++ b/tools/alerts/create_alerts.py @@ -1,16 +1,19 @@ #!/usr/bin/env python3 +from __future__ import annotations + import argparse import json import os import re from collections import defaultdict from difflib import SequenceMatcher -from typing import Any, Dict, List, Set, Tuple +from typing import Any import requests from setuptools import distutils # type: ignore[import] + ALL_SKIPPED_THRESHOLD = 100 SIMILARITY_THRESHOLD = 0.75 FAILURE_CHAIN_THRESHOLD = 2 @@ -65,14 +68,14 @@ DISABLED_ALERTS = [ class JobStatus: job_name: str = "" - jobs: List[Any] = [] + jobs: list[Any] = [] current_status: Any = None - job_statuses: List[Any] = [] - filtered_statuses: List[Any] = [] - failure_chain: List[Any] = [] - flaky_jobs: List[Any] = [] + job_statuses: list[Any] = [] + filtered_statuses: list[Any] = [] + failure_chain: list[Any] = [] + flaky_jobs: list[Any] = [] - def __init__(self, job_name: str, job_statuses: List[Any]): + def __init__(self, job_name: str, job_statuses: list[Any]) -> None: self.job_name = job_name self.job_statuses = job_statuses @@ -93,7 +96,7 @@ class JobStatus: return status return None - def get_unique_failures(self, jobs: List[Any]) -> Dict[str, List[Any]]: + def get_unique_failures(self, jobs: list[Any]) -> dict[str, list[Any]]: """ Returns list of jobs grouped by failureCaptures from the input list """ @@ -120,7 +123,7 @@ class JobStatus: return failures # A flaky job is if it's the only job that has that failureCapture and is not the most recent job - def get_flaky_jobs(self) -> List[Any]: + def get_flaky_jobs(self) -> list[Any]: unique_failures = self.get_unique_failures(self.filtered_statuses) flaky_jobs = [] for failure in unique_failures: @@ -134,7 +137,7 @@ class JobStatus: # The most recent failure chain is an array of jobs that have the same-ish failures. # A success in the middle of the chain will terminate the chain. - def get_most_recent_failure_chain(self) -> List[Any]: + def get_most_recent_failure_chain(self) -> list[Any]: failures = [] found_most_recent_failure = False @@ -178,7 +181,7 @@ def fetch_hud_data(repo: str, branch: str) -> Any: # Creates a Dict of Job Name -> [JobData]. Essentially a Column in HUD -def map_job_data(jobNames: Any, shaGrid: Any) -> Dict[str, Any]: +def map_job_data(jobNames: Any, shaGrid: Any) -> dict[str, Any]: jobData = defaultdict(list) for sha in shaGrid: for ind, job in enumerate(sha["jobs"]): @@ -196,13 +199,13 @@ def is_job_skipped(job: Any) -> bool: return conclusion in (NEUTRAL, SKIPPED) or conclusion is None -def get_failed_jobs(job_data: List[Any]) -> List[Any]: +def get_failed_jobs(job_data: list[Any]) -> list[Any]: return [job for job in job_data if job["conclusion"] == "failure"] def classify_jobs( - all_job_names: List[str], sha_grid: Any, filtered_jobs_names: Set[str] -) -> Tuple[List[JobStatus], List[Any]]: + all_job_names: list[str], sha_grid: Any, filtered_jobs_names: set[str] +) -> tuple[list[JobStatus], list[Any]]: """ Creates Job Statuses which has the logic for if need to alert or if there's flaky jobs. Classifies jobs into jobs to alert on and flaky jobs. 
@@ -212,7 +215,7 @@ def classify_jobs( :return: """ job_data = map_job_data(all_job_names, sha_grid) - job_statuses: List[JobStatus] = [] + job_statuses: list[JobStatus] = [] for job in job_data: job_statuses.append(JobStatus(job, job_data[job])) @@ -230,7 +233,7 @@ def classify_jobs( # filter job names that don't match the regex -def filter_job_names(job_names: List[str], job_name_regex: str) -> List[str]: +def filter_job_names(job_names: list[str], job_name_regex: str) -> list[str]: if job_name_regex: return [ job_name for job_name in job_names if re.match(job_name_regex, job_name) @@ -240,7 +243,7 @@ def filter_job_names(job_names: List[str], job_name_regex: str) -> List[str]: def get_recurrently_failing_jobs_alerts( repo: str, branch: str, job_name_regex: str -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: job_names, sha_grid = fetch_hud_data(repo=repo, branch=branch) filtered_job_names = set(filter_job_names(job_names, job_name_regex)) diff --git a/tools/autograd/gen_annotated_fn_args.py b/tools/autograd/gen_annotated_fn_args.py index f935a9adf4c6..c32779b3a282 100644 --- a/tools/autograd/gen_annotated_fn_args.py +++ b/tools/autograd/gen_annotated_fn_args.py @@ -14,18 +14,17 @@ generated. In the full build system, OUTPUT_DIR is torch/testing/_internal/generated """ +from __future__ import annotations + import argparse import os import textwrap from collections import defaultdict - -from typing import Any, Dict, List, Sequence +from typing import Any, Sequence, TYPE_CHECKING import torchgen.api.python as python from torchgen.context import with_native_function - from torchgen.gen import parse_native_yaml -from torchgen.model import Argument, BaseOperatorName, NativeFunction from torchgen.utils import FileManager from .gen_python_functions import ( @@ -39,6 +38,10 @@ from .gen_python_functions import ( ) +if TYPE_CHECKING: + from torchgen.model import Argument, BaseOperatorName, NativeFunction + + def gen_annotated( native_yaml_path: str, tags_yaml_path: str, out: str, autograd_dir: str ) -> None: @@ -53,9 +56,9 @@ def gen_annotated( (is_py_fft_function, "torch._C._fft"), (is_py_variable_method, "torch.Tensor"), ) - annotated_args: List[str] = [] + annotated_args: list[str] = [] for pred, namespace in mappings: - groups: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list) + groups: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list) for f in native_functions: if not should_generate_py_binding(f) or not pred(f): continue @@ -77,7 +80,7 @@ def gen_annotated( @with_native_function def gen_annotated_args(f: NativeFunction) -> str: - def _get_kwargs_func_exclusion_list() -> List[str]: + def _get_kwargs_func_exclusion_list() -> list[str]: # functions that currently don't work with kwargs in test_overrides.py return [ "diagonal", @@ -87,12 +90,12 @@ def gen_annotated_args(f: NativeFunction) -> str: ] def _add_out_arg( - out_args: List[Dict[str, Any]], args: Sequence[Argument], *, is_kwarg_only: bool + out_args: list[dict[str, Any]], args: Sequence[Argument], *, is_kwarg_only: bool ) -> None: for arg in args: if arg.default is not None: continue - out_arg: Dict[str, Any] = {} + out_arg: dict[str, Any] = {} out_arg["is_kwarg_only"] = str(is_kwarg_only) out_arg["name"] = arg.name out_arg["simple_type"] = python.argument_type_str( @@ -103,7 +106,7 @@ def gen_annotated_args(f: NativeFunction) -> str: out_arg["size"] = size_t out_args.append(out_arg) - out_args: List[Dict[str, Any]] = [] + out_args: list[dict[str, Any]] = [] _add_out_arg(out_args, 
f.func.arguments.flat_positional, is_kwarg_only=False) if f"{f.func.name.name}" not in _get_kwargs_func_exclusion_list(): _add_out_arg(out_args, f.func.arguments.flat_kwarg_only, is_kwarg_only=True) diff --git a/tools/autograd/gen_autograd.py b/tools/autograd/gen_autograd.py index 0d4aa91d3fad..f6e7be149ad6 100644 --- a/tools/autograd/gen_autograd.py +++ b/tools/autograd/gen_autograd.py @@ -22,9 +22,10 @@ torch/csrc/autograd/generated/ # gen_python_functions.py: generates Python bindings to THPVariable # +from __future__ import annotations + import argparse import os -from typing import List from torchgen.api import cpp from torchgen.api.autograd import ( @@ -69,7 +70,7 @@ def gen_autograd( ), key=lambda f: cpp.name(f.func), ) - fns_with_diff_infos: List[ + fns_with_diff_infos: list[ NativeFunctionWithDifferentiabilityInfo ] = match_differentiability_info(fns, differentiability_infos) diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py index a1a432e2b3d8..89974fd1ddb6 100644 --- a/tools/autograd/gen_autograd_functions.py +++ b/tools/autograd/gen_autograd_functions.py @@ -4,7 +4,10 @@ # Functions.h/cpp: subclasses of autograd::Node # python_functions.h/cpp: Python bindings for the above classes # -from typing import Dict, List, Sequence, Tuple + +from __future__ import annotations + +from typing import Sequence from torchgen.api.autograd import ( Derivative, @@ -43,6 +46,7 @@ from torchgen.utils import FileManager from .gen_inplace_or_view_type import VIEW_FUNCTIONS + FUNCTION_DECLARATION = CodeTemplate( """\ #ifdef _WIN32 @@ -443,8 +447,8 @@ UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS def get_infos_with_derivatives_list( - differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]] -) -> List[DifferentiabilityInfo]: + differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]] +) -> list[DifferentiabilityInfo]: diff_info_list = [ info for diffinfo_dict in differentiability_infos.values() @@ -456,7 +460,7 @@ def get_infos_with_derivatives_list( def gen_autograd_functions_lib( out: str, - differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], + differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]], template_path: str, ) -> None: """Functions.h and Functions.cpp body @@ -490,7 +494,7 @@ def gen_autograd_functions_lib( def gen_autograd_functions_python( out: str, - differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], + differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]], template_path: str, ) -> None: fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) @@ -536,17 +540,17 @@ def gen_autograd_functions_python( def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str: - saved_variables: List[str] = [] - release_variables: List[str] = [] - saved_list_sizes: List[str] = [] - unpack: List[str] = [] - asserts: List[str] = [] - compute_index_ranges: List[str] = [] - getter_definitions: List[str] = [] - py_getsetdef_structs: List[str] = [] - compiled_args: List[str] = [] - apply_with_saved_before: List[str] = [] - apply_with_saved_after: List[str] = [] + saved_variables: list[str] = [] + release_variables: list[str] = [] + saved_list_sizes: list[str] = [] + unpack: list[str] = [] + asserts: list[str] = [] + compute_index_ranges: list[str] = [] + getter_definitions: list[str] = [] + py_getsetdef_structs: list[str] = [] + compiled_args: list[str] = [] + 
apply_with_saved_before: list[str] = [] + apply_with_saved_after: list[str] = [] for arg in info.args_with_derivatives: if arg.type in TENSOR_LIST_LIKE_CTYPES: @@ -807,7 +811,7 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { else: will_release_variables = "" - body: List[str] = [] + body: list[str] = [] if uses_single_grad(info): body.append("const auto& grad = grads[0];") @@ -821,7 +825,7 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { def emit_derivative( derivative: Derivative, args_with_derivatives: Sequence[Binding], - ) -> Tuple[bool, str]: + ) -> tuple[bool, str]: formula = derivative.formula var_names = derivative.var_names if len(var_names) == 1: @@ -857,7 +861,7 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { else: grad_input_mask = "" idx_ranges = ", ".join(f"{n}_ix" for n in var_names) - copy_ranges: List[str] = [] + copy_ranges: list[str] = [] for i, n in enumerate(var_names): copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i)) return False, DERIVATIVE_MULTI.substitute( diff --git a/tools/autograd/gen_inplace_or_view_type.py b/tools/autograd/gen_inplace_or_view_type.py index d1392f5407c0..7ef97fd0fb7d 100644 --- a/tools/autograd/gen_inplace_or_view_type.py +++ b/tools/autograd/gen_inplace_or_view_type.py @@ -4,7 +4,7 @@ # if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp # The fallback is expected to mimick this codegen, so we should keep the two in sync. -from typing import Dict, List, Optional, Tuple +from __future__ import annotations from torchgen.api import cpp from torchgen.api.autograd import ( @@ -24,8 +24,7 @@ from torchgen.api.types import ( OptionalCType, symIntArrayRefT, SymIntT, - # See Note [Nested Arg Types] - tensorT, + tensorT, # See Note [Nested Arg Types] ) from torchgen.code_template import CodeTemplate from torchgen.context import with_native_function @@ -46,6 +45,7 @@ from .gen_trace_type import ( type_wrapper_name, ) + # See NOTE [ Autograd View Variables ] in variable.h for details. # If you update list VIEW_FUNCTIONS or RETURNS_VIEWS_OF_INPUT, # you **MUST** also update the public list of view ops accordingly in @@ -281,7 +281,7 @@ def inverse_view_name(f: NativeFunction) -> str: return f"{copy_variant}{overload}_inverse" -def extract_bindings(f: NativeFunction) -> List[Binding]: +def extract_bindings(f: NativeFunction) -> list[Binding]: return [ r for a in f.func.schema_order_arguments() @@ -297,9 +297,9 @@ def extract_bindings(f: NativeFunction) -> List[Binding]: @with_native_function -def unpack_args(f: NativeFunction) -> Tuple[List[str], List[Binding]]: - body: List[str] = [] - unpacked_bindings: List[Binding] = [] +def unpack_args(f: NativeFunction) -> tuple[list[str], list[Binding]]: + body: list[str] = [] + unpacked_bindings: list[Binding] = [] for i, binding in enumerate(extract_bindings(f)): assert not isinstance(binding.argument, SelfArgument) @@ -338,7 +338,7 @@ def get_base_name(f: NativeFunction) -> str: return f.func.name.name.base # TODO: should be str(f.func.name.name)? 
-def get_view_info(f: NativeFunction) -> Optional[str]: +def get_view_info(f: NativeFunction) -> str | None: base_name = get_base_name(f) view_info = VIEW_FUNCTIONS.get(base_name, None) if view_info is None and base_name in RETURNS_VIEWS_OF_INPUT: @@ -347,7 +347,7 @@ def get_view_info(f: NativeFunction) -> Optional[str]: def emit_view_func( - f: NativeFunction, bindings: List[Binding], view_idx: Optional[str] = None + f: NativeFunction, bindings: list[Binding], view_idx: str | None = None ) -> str: """Generate an additional lambda function to recover views in backward when as_strided is not supported. See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details. @@ -355,8 +355,8 @@ def emit_view_func( # TODO: Clean this logic up if we get rid of reverse view funcs or reify them. input_base = "input_base" replay_view_func = "" - updated_args: List[str] = [] - known_view_arg_simple_types: List[CType] = [ + updated_args: list[str] = [] + known_view_arg_simple_types: list[CType] = [ BaseCType(longT), OptionalCType(BaseCType(longT)), BaseCType(SymIntT), @@ -448,7 +448,7 @@ def emit_view_func( def emit_view_body( fn: NativeFunctionWithDifferentiabilityInfo, var: str -) -> Tuple[str, str]: +) -> tuple[str, str]: # See NOTE [ Autograd View Variables ] in variable.h for details. f = fn.func base_name = get_base_name(f) @@ -523,9 +523,9 @@ def modifies_arguments(f: NativeFunction) -> bool: @with_native_function_with_differentiability_info -def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]: +def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> list[str]: f = fn.func - inplace_view_body: List[str] = [] + inplace_view_body: list[str] = [] dispatcher_sig = DispatcherSignature.from_schema(f.func) dispatcher_exprs = dispatcher_sig.exprs() @@ -584,7 +584,7 @@ def gen_formals(f: NativeFunction) -> str: @with_native_function_with_differentiability_info def inplace_or_view_method_definition( fn: NativeFunctionWithDifferentiabilityInfo, -) -> Optional[str]: +) -> str | None: f = fn.func if get_view_info(f) is None and ( # For functions that modify their inputs but don't return them, @@ -605,7 +605,7 @@ def inplace_or_view_method_definition( @with_native_function_with_differentiability_info def inplace_or_view_method_registration( fn: NativeFunctionWithDifferentiabilityInfo, -) -> Optional[str]: +) -> str | None: f = fn.func if get_view_info(f) is None and ( not modifies_arguments(f) or len(f.func.returns) == 0 @@ -626,7 +626,7 @@ def use_derived(fn: NativeFunctionWithDifferentiabilityInfo) -> bool: def gen_inplace_or_view_type_env( fn: NativeFunctionWithDifferentiabilityInfo, -) -> Dict[str, List[str]]: +) -> dict[str, list[str]]: definition = inplace_or_view_method_definition(fn) registration = inplace_or_view_method_registration(fn) @@ -649,7 +649,7 @@ def gen_inplace_or_view_type( out: str, native_yaml_path: str, tags_yaml_path: str, - fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo], + fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo], template_path: str, ) -> None: # NOTE: see Note [Sharded File] at the top of the VariableType.cpp diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index fb37354b2fa8..44453306a0ec 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -31,11 +31,12 @@ # message, but use what's there # +from __future__ import annotations + import itertools import re from 
collections import defaultdict - -from typing import Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple +from typing import Callable, Iterable, Sequence import yaml @@ -56,7 +57,6 @@ from torchgen.api.python import ( signature_from_schema, structseq_fieldnames, ) - from torchgen.code_template import CodeTemplate from torchgen.context import with_native_function from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml @@ -75,6 +75,7 @@ from torchgen.yaml_utils import YamlLoader from .gen_inplace_or_view_type import is_tensor_list_type from .gen_trace_type import should_trace + # # declarations blocklist # We skip codegen for these functions, for various reasons. @@ -369,7 +370,7 @@ def gen( valid_tags = parse_tags_yaml(tags_yaml_path) - def gen_tags_enum() -> Dict[str, str]: + def gen_tags_enum() -> dict[str, str]: return { "enum_of_valid_tags": ( "".join( @@ -384,9 +385,9 @@ def gen( def group_filter_overloads( pairs: Sequence[PythonSignatureNativeFunctionPair], pred: Callable[[NativeFunction], bool], -) -> Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]: - grouped: Dict[ - BaseOperatorName, List[PythonSignatureNativeFunctionPair] +) -> dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]: + grouped: dict[ + BaseOperatorName, list[PythonSignatureNativeFunctionPair] ] = defaultdict(list) for pair in pairs: if pred(pair.function): @@ -398,17 +399,17 @@ def create_python_bindings( fm: FileManager, pairs: Sequence[PythonSignatureNativeFunctionPair], pred: Callable[[NativeFunction], bool], - module: Optional[str], + module: str | None, filename: str, *, method: bool, symint: bool = True, ) -> None: """Generates Python bindings to ATen functions""" - py_methods: List[str] = [] - ops_headers: List[str] = [] - py_method_defs: List[str] = [] - py_forwards: List[str] = [] + py_methods: list[str] = [] + ops_headers: list[str] = [] + py_method_defs: list[str] = [] + py_forwards: list[str] = [] grouped = group_filter_overloads(pairs, pred) @@ -445,8 +446,8 @@ def create_python_return_type_bindings( Generate function to initialize and return named tuple for native functions which returns named tuple and registration invocations in `python_return_types.cpp`. """ - py_return_types_definition: List[str] = [] - py_return_types_registrations: List[str] = [] + py_return_types_definition: list[str] = [] + py_return_types_registrations: list[str] = [] grouped = group_filter_overloads(pairs, pred) @@ -484,7 +485,7 @@ def create_python_return_type_bindings_header( Generate function to initialize and return named tuple for native functions which returns named tuple and relevant entry for the map in `python_return_types.cpp`. 
""" - py_return_types_declarations: List[str] = [] + py_return_types_declarations: list[str] = [] grouped = group_filter_overloads(pairs, pred) @@ -510,7 +511,7 @@ def create_python_bindings_sharded( fm: FileManager, pairs: Sequence[PythonSignatureNativeFunctionPair], pred: Callable[[NativeFunction], bool], - module: Optional[str], + module: str | None, filename: str, *, method: bool, @@ -521,13 +522,13 @@ def create_python_bindings_sharded( grouped = group_filter_overloads(pairs, pred) def key_func( - kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]] + kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]] ) -> str: return kv[0].base def env_func( - kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]] - ) -> Dict[str, List[str]]: + kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]] + ) -> dict[str, list[str]]: name, fn_pairs = kv return { "ops_headers": [f"#include "], @@ -553,7 +554,7 @@ def create_python_bindings_sharded( def load_signatures( - native_functions: List[NativeFunction], + native_functions: list[NativeFunction], deprecated_yaml_path: str, *, method: bool, @@ -580,19 +581,19 @@ def load_deprecated_signatures( *, method: bool, pyi: bool, -) -> List[PythonSignatureNativeFunctionPair]: +) -> list[PythonSignatureNativeFunctionPair]: # The deprecated.yaml doesn't have complete type information, we need # find and leverage the original ATen signature (to which it delegates # the call) to generate the full python signature. # We join the deprecated and the original signatures using type-only form. # group the original ATen signatures by name - grouped: Dict[str, List[PythonSignatureNativeFunctionPair]] = defaultdict(list) + grouped: dict[str, list[PythonSignatureNativeFunctionPair]] = defaultdict(list) for pair in pairs: grouped[pair.signature.name].append(pair) # find matching original signatures for each deprecated signature - results: List[PythonSignatureNativeFunctionPair] = [] + results: list[PythonSignatureNativeFunctionPair] = [] with open(deprecated_yaml_path) as f: deprecated_defs = yaml.load(f, Loader=YamlLoader) @@ -701,15 +702,15 @@ def gen_structseq_typename_key(f: NativeFunction) -> str: def emit_structseq_call( overloads: Sequence[PythonSignatureNativeFunctionPair], -) -> Tuple[List[str], Dict[str, str]]: +) -> tuple[list[str], dict[str, str]]: """ Generate block of named tuple type def inits, and add typeref snippets to declarations that use them """ - typenames: Dict[ + typenames: dict[ str, str ] = {} # map from unique name + field name lists to typedef name - typedefs: List[str] = [] # typedef declarations and init code + typedefs: list[str] = [] # typedef declarations and init code for overload in overloads: fieldnames = structseq_fieldnames(overload.function.func.returns) @@ -732,17 +733,17 @@ static PyTypeObject* {typename} = generated::get_{name}_structseq();""" def generate_return_type_definition_and_registrations( overloads: Sequence[PythonSignatureNativeFunctionPair], -) -> Tuple[List[str], List[str]]: +) -> tuple[list[str], list[str]]: """ Generate block of function in `python_return_types.cpp` to initialize and return named tuple for a native function which returns named tuple and registration invocations in same file. 
""" - typenames: Dict[ + typenames: dict[ str, str ] = {} # map from unique name + field name lists to typedef name - definitions: List[str] = [] # function definition to register the typedef - registrations: List[str] = [] # register call for the typedef + definitions: list[str] = [] # function definition to register the typedef + registrations: list[str] = [] # register call for the typedef for overload in overloads: fieldnames = structseq_fieldnames(overload.function.func.returns) @@ -783,15 +784,15 @@ PyTypeObject* get_{name}_structseq() {{ def generate_return_type_declarations( overloads: Sequence[PythonSignatureNativeFunctionPair], -) -> List[str]: +) -> list[str]: """ Generate block of function declarations in `python_return_types.h` to initialize and return named tuple for a native function. """ - typenames: Dict[ + typenames: dict[ str, str ] = {} # map from unique name + field name lists to typedef name - declarations: List[str] = [] # function declaration to register the typedef + declarations: list[str] = [] # function declaration to register the typedef for overload in overloads: fieldnames = structseq_fieldnames(overload.function.func.returns) @@ -891,7 +892,7 @@ static PyObject * ${pycname}(PyObject* self_, PyObject* args) def method_impl( name: BaseOperatorName, - module: Optional[str], + module: str | None, overloads: Sequence[PythonSignatureNativeFunctionPair], *, method: bool, @@ -918,8 +919,8 @@ def method_impl( overloads, symint=symint ) is_singleton = len(grouped_overloads) == 1 - signatures: List[str] = [] - dispatch: List[str] = [] + signatures: list[str] = [] + dispatch: list[str] = [] for overload_index, overload in enumerate(grouped_overloads): signature = overload.signature.signature_str(symint=symint) signatures.append(f"{cpp_string(str(signature))},") @@ -959,7 +960,7 @@ def method_impl( def gen_has_torch_function_check( - name: BaseOperatorName, module: Optional[str], *, noarg: bool, method: bool + name: BaseOperatorName, module: str | None, *, noarg: bool, method: bool ) -> str: if noarg: if method: @@ -1007,7 +1008,7 @@ if (_r.isNone(${out_idx})) { def emit_dispatch_case( overload: PythonSignatureGroup, - structseq_typenames: Dict[str, str], + structseq_typenames: dict[str, str], *, symint: bool = True, ) -> str: @@ -1050,7 +1051,7 @@ def forward_decls( overloads: Sequence[PythonSignatureNativeFunctionPair], *, method: bool, -) -> Tuple[str, ...]: +) -> tuple[str, ...]: if method: return () @@ -1078,7 +1079,7 @@ static PyObject * {pycname}(PyObject* self_, PyObject* args, PyObject* kwargs); def method_def( name: BaseOperatorName, - module: Optional[str], + module: str | None, overloads: Sequence[PythonSignatureNativeFunctionPair], *, method: bool, @@ -1114,8 +1115,8 @@ def method_def( def group_overloads( overloads: Sequence[PythonSignatureNativeFunctionPair], *, symint: bool = True ) -> Sequence[PythonSignatureGroup]: - bases: Dict[str, PythonSignatureNativeFunctionPair] = {} - outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {} + bases: dict[str, PythonSignatureNativeFunctionPair] = {} + outplaces: dict[str, PythonSignatureNativeFunctionPair] = {} # first group by signature ignoring out arguments for overload in overloads: @@ -1137,7 +1138,7 @@ def group_overloads( for sig, out in outplaces.items(): if sig not in bases: - candidates: List[str] = [] + candidates: list[str] = [] for overload in overloads: if ( str(overload.function.func.name.name) @@ -1268,7 +1269,7 @@ def sort_overloads( ) # Construct the relation graph - larger_than: Dict[int, 
Set[int]] = defaultdict(set) + larger_than: dict[int, set[int]] = defaultdict(set) for i1, overload1 in enumerate(grouped_overloads): for i2, overload2 in enumerate(grouped_overloads): if is_smaller(overload1.signature, overload2.signature): @@ -1279,7 +1280,7 @@ def sort_overloads( # Use a topological sort to sort overloads according to the partial order. N = len(grouped_overloads) - sorted_ids: List[int] = list(filter(lambda x: x not in larger_than, range(N))) + sorted_ids: list[int] = list(filter(lambda x: x not in larger_than, range(N))) for idx in range(N): # The size of sorted_ids will grow to N eventually. @@ -1304,7 +1305,7 @@ def sort_overloads( def emit_single_dispatch( ps: PythonSignature, f: NativeFunction, - structseq_typenames: Dict[str, str], + structseq_typenames: dict[str, str], *, symint: bool = True, ) -> str: diff --git a/tools/autograd/gen_trace_type.py b/tools/autograd/gen_trace_type.py index 9d9144bce04c..75ca13852a21 100644 --- a/tools/autograd/gen_trace_type.py +++ b/tools/autograd/gen_trace_type.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import itertools -from typing import Dict, List, Sequence, Union +from typing import Sequence from torchgen.api import cpp from torchgen.api.types import DispatcherSignature @@ -8,6 +10,7 @@ from torchgen.context import with_native_function from torchgen.model import Argument, NativeFunction, SchemaKind, TensorOptionsArguments from torchgen.utils import FileManager + # Note [Manual Backend kernels] # For these ops, we want to manually register to dispatch key Backend and # skip codegen-ed registeration to all keys before Backend. @@ -136,9 +139,7 @@ ADD_TRACE_INPUT = CodeTemplate("""jit::tracer::addInputs(node, "${name}", ${inpu def format_trace_inputs(f: NativeFunction) -> str: - def dispatch_trace_input( - arg: Union[Argument, TensorOptionsArguments] - ) -> Sequence[str]: + def dispatch_trace_input(arg: Argument | TensorOptionsArguments) -> Sequence[str]: if isinstance(arg, TensorOptionsArguments): name = "options" return [ @@ -156,7 +157,7 @@ def format_trace_inputs(f: NativeFunction) -> str: else: return [ADD_TRACE_INPUT.substitute(name=name, input=name)] - args: List[Union[Argument, TensorOptionsArguments]] = list( + args: list[Argument | TensorOptionsArguments] = list( f.func.schema_order_arguments() ) @@ -399,8 +400,8 @@ ${assign_return_values}at::_ops::${unambiguous_name}::redispatch(${unpacked_args ) -def emit_trace_body(f: NativeFunction) -> List[str]: - trace_body: List[str] = [] +def emit_trace_body(f: NativeFunction) -> list[str]: + trace_body: list[str] = [] trace_body.append(format_prerecord_trace(f)) @@ -503,7 +504,7 @@ def method_registration(f: NativeFunction) -> str: ) -def gen_trace_type_func(fn: NativeFunction) -> Dict[str, List[str]]: +def gen_trace_type_func(fn: NativeFunction) -> dict[str, list[str]]: return { "ops_headers": [f"#include "], "trace_method_definitions": [method_definition(fn)], @@ -512,7 +513,7 @@ def gen_trace_type_func(fn: NativeFunction) -> Dict[str, List[str]]: def gen_trace_type( - out: str, native_functions: List[NativeFunction], template_path: str + out: str, native_functions: list[NativeFunction], template_path: str ) -> None: # NOTE: see Note [Sharded File] at the top of the VariableType.cpp # template regarding sharding of the generated files. 
diff --git a/tools/autograd/gen_variable_factories.py b/tools/autograd/gen_variable_factories.py index d7cf4a342328..9e1bb7e8220a 100644 --- a/tools/autograd/gen_variable_factories.py +++ b/tools/autograd/gen_variable_factories.py @@ -2,18 +2,19 @@ # # This writes one file: variable_factories.h +from __future__ import annotations + import re -from typing import List, Optional import torchgen.api.python as python from torchgen.api import cpp - from torchgen.api.types import CppSignatureGroup from torchgen.context import with_native_function from torchgen.gen import parse_native_yaml from torchgen.model import NativeFunction, TensorOptionsArguments, Variant from torchgen.utils import FileManager, mapMaybe + OPTIONAL_TYPE_PATTERN = re.compile(r"c10::optional<(.+)>") TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)") @@ -69,7 +70,7 @@ def is_factory_function(f: NativeFunction) -> bool: @with_native_function -def process_function(f: NativeFunction) -> Optional[str]: +def process_function(f: NativeFunction) -> str | None: name = cpp.name(f.func) has_tensor_options = python.has_tensor_options(f) is_factory = has_tensor_options or name.endswith("_like") @@ -83,8 +84,8 @@ def process_function(f: NativeFunction) -> Optional[str]: sigs.append(cpp_sigs.symint_signature) r = "" for sig in sigs: - formals: List[str] = [] - exprs: List[str] = [] + formals: list[str] = [] + exprs: list[str] = [] requires_grad = "false" for arg in sig.arguments(): qualified_type = fully_qualified_type(arg.type) diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 93cadc2e4636..b28d5a1073fe 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -25,8 +25,11 @@ # which will in turn dispatch back to VariableType for its # differentiable subcomponents. # + +from __future__ import annotations + import re -from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import Callable, Sequence from torchgen.api import cpp from torchgen.api.autograd import ( @@ -38,7 +41,6 @@ from torchgen.api.autograd import ( NativeFunctionWithDifferentiabilityInfo, SavedAttribute, ) - from torchgen.api.types import ( ArrayRefCType, BaseCppType, @@ -103,6 +105,7 @@ from .gen_trace_type import ( type_wrapper_name, ) + # We don't set or modify grad_fn on these methods. Generally, they return # tensors that have requires_grad=False. In-place functions listed here will # not examine or modify requires_grad or grad_fn. 
@@ -837,9 +840,9 @@ def gen_variable_type( out: str, native_yaml_path: str, tags_yaml_path: str, - fns_with_diff_infos: List[NativeFunctionWithDifferentiabilityInfo], + fns_with_diff_infos: list[NativeFunctionWithDifferentiabilityInfo], template_path: str, - used_keys: Set[str], + used_keys: set[str], ) -> None: """VariableType.h and VariableType.cpp body @@ -858,8 +861,8 @@ def gen_variable_type( # helper that generates a TORCH_LIBRARY_IMPL macro for each # dispatch key that appears in derivatives.yaml - def wrapper_registrations(used_keys: Set[str]) -> str: - library_impl_macro_list: List[str] = [] + def wrapper_registrations(used_keys: set[str]) -> str: + library_impl_macro_list: list[str] = [] for key in sorted(used_keys): dispatch_key = key if key == "Default": @@ -926,7 +929,7 @@ def gen_wrapper_registration(f: NativeFunction, key: str = "Default") -> str: def gen_variable_type_func( fn: NativeFunctionWithDifferentiabilityInfo, -) -> Dict[str, List[str]]: +) -> dict[str, list[str]]: f = fn.func result = {} with native_function_manager(f): @@ -1034,7 +1037,7 @@ _foreach_ops_with_different_arity = { @with_native_function_with_differentiability_info_and_key def emit_body( fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default" -) -> List[str]: +) -> list[str]: assert dispatch_strategy(fn) == "use_derived" f = fn.func info = fn.info[key] if fn.info else None @@ -1050,8 +1053,8 @@ def emit_body( is_foreach = name.startswith("_foreach") is_inplace_foreach = is_foreach and inplace if is_inplace_foreach: - inplace_foreacharg2refarg: Dict[Argument, Argument] = {} - refargname2inplace_foreacharg: Dict[str, Argument] = {} + inplace_foreacharg2refarg: dict[Argument, Argument] = {} + refargname2inplace_foreacharg: dict[str, Argument] = {} base_name_and_overload_name = (f.func.name.name.base, f.func.name.overload_name) if info is None: assert ( @@ -1077,8 +1080,8 @@ def emit_body( refargname2inplace_foreacharg[ref_arg.name] = foreach_arg def gen_differentiable_input( - arg: Union[Argument, SelfArgument, TensorOptionsArguments] - ) -> Optional[DifferentiableInput]: + arg: Argument | SelfArgument | TensorOptionsArguments, + ) -> DifferentiableInput | None: if isinstance(arg, TensorOptionsArguments): return None a: Argument = arg.argument if isinstance(arg, SelfArgument) else arg @@ -1097,7 +1100,7 @@ def emit_body( ) @with_native_function - def gen_differentiable_inputs(f: NativeFunction) -> List[DifferentiableInput]: + def gen_differentiable_inputs(f: NativeFunction) -> list[DifferentiableInput]: arguments = list(f.func.arguments.non_out) if is_inplace_foreach and info is not None: for i, arg in enumerate(f.func.arguments.flat_non_out): @@ -1115,8 +1118,8 @@ def emit_body( return list(mapMaybe(gen_differentiable_input, arguments)) def find_args_with_derivatives( - differentiable_inputs: List[DifferentiableInput], - ) -> List[DifferentiableInput]: + differentiable_inputs: list[DifferentiableInput], + ) -> list[DifferentiableInput]: """Find arguments that have derivative definitions""" if info is None or not info.has_derivatives: return differentiable_inputs @@ -1178,8 +1181,8 @@ def emit_body( and (not returns_void) ) - def emit_save_inputs() -> List[str]: - setup: List[str] = [] + def emit_save_inputs() -> list[str]: + setup: list[str] = [] if info is None or not info.has_derivatives: return setup @@ -1189,7 +1192,7 @@ def emit_body( # We don't want to save tensors if we know that they will never be used # when computing the derivative, so we add guards to those statements - def 
guard_for(arg: SavedAttribute) -> Optional[str]: + def guard_for(arg: SavedAttribute) -> str | None: assert info is not None # It's hard to determine the edge offset if we have TensorLists @@ -1276,8 +1279,8 @@ def emit_body( setup.append(f"grad_fn->{arg.name}_size_ = {arg.name}.size();") return setup - def setup_derivative(differentiable_inputs: List[DifferentiableInput]) -> List[str]: - body: List[str] = [] + def setup_derivative(differentiable_inputs: list[DifferentiableInput]) -> list[str]: + body: list[str] = [] if is_out_fn: # For out functions, ensure that no input or output requires grad body.append(DECLARE_GRAD_FN.substitute(op="Node")) @@ -1343,8 +1346,8 @@ def emit_body( body.append(SETUP_DERIVATIVE.substitute(setup=setup)) return body - def emit_check_if_in_complex_autograd_allowlist() -> List[str]: - body: List[str] = [] + def emit_check_if_in_complex_autograd_allowlist() -> list[str]: + body: list[str] = [] if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX: return body for arg in differentiable_outputs: @@ -1355,11 +1358,11 @@ def emit_body( return body def emit_check_no_requires_grad( - tensor_args: List[DifferentiableInput], - args_with_derivatives: List[DifferentiableInput], - ) -> List[str]: + tensor_args: list[DifferentiableInput], + args_with_derivatives: list[DifferentiableInput], + ) -> list[str]: """Checks that arguments without derivatives don't require grad""" - body: List[str] = [] + body: list[str] = [] for arg in tensor_args: if arg in args_with_derivatives: continue @@ -1373,8 +1376,8 @@ def emit_body( body.append(f'check_no_requires_grad({arg_name}, "{arg_name}", "{name}");') return body - def emit_original_self_definition() -> List[str]: - body: List[str] = [] + def emit_original_self_definition() -> list[str]: + body: list[str] = [] if inplace: if is_inplace_foreach: body.append( @@ -1412,17 +1415,17 @@ def emit_body( def save_variables( saved_variables: Sequence[SavedAttribute], is_output: bool, - guard_for: Callable[[SavedAttribute], Optional[str]] = lambda name: None, + guard_for: Callable[[SavedAttribute], str | None] = lambda name: None, ) -> Sequence[str]: # assign the saved variables to the generated grad_fn - stmts: List[str] = [] + stmts: list[str] = [] for arg in sorted(saved_variables, key=lambda sa: str(sa.nctype.name)): name = ( arg.nctype.name.name if isinstance(arg.nctype.name, SpecialArgName) else arg.nctype.name ) - foreacharg: Optional[Argument] = None + foreacharg: Argument | None = None is_foreacharg_list_type: bool = False type = arg.nctype.type expr = arg.expr @@ -1539,10 +1542,10 @@ def emit_body( return call def wrap_output( - f: NativeFunction, unpacked_bindings: List[Binding], var: str + f: NativeFunction, unpacked_bindings: list[Binding], var: str ) -> str: call = "" - rhs_value: Optional[str] = None + rhs_value: str | None = None if not any(r.type.is_tensor_like() for r in f.func.returns): rhs_value = var else: @@ -1554,11 +1557,11 @@ def emit_body( return call def check_tensorimpl_and_storage( - call: str, unpacked_bindings: List[Binding] + call: str, unpacked_bindings: list[Binding] ) -> str: # See NOTE [ TensorImpl and Storage Pointer Sanity Checks ] - stmts_before_call: List[str] = [] - stmts_after_call: List[str] = [] + stmts_before_call: list[str] = [] + stmts_after_call: list[str] = [] if cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE: return call @@ -1665,7 +1668,7 @@ def emit_body( return call def emit_call( - f: NativeFunction, unpacked_bindings: List[Binding], try_jit_decomposition: bool + f: 
NativeFunction, unpacked_bindings: list[Binding], try_jit_decomposition: bool ) -> str: # We only care about adding `at::AutoDispatchBelowAutograd` guard for non-variable dispatch # (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure @@ -1764,7 +1767,7 @@ def emit_body( ) return "" - def emit_any_requires_grad() -> List[str]: + def emit_any_requires_grad() -> list[str]: extra_condition = "" if info and info.output_differentiability_conditions: assert len(info.output_differentiability_conditions) == 1 @@ -1782,14 +1785,14 @@ def emit_body( ) ] - def get_any_has_forward_grad_name(var_names: Tuple[str, ...]) -> str: + def get_any_has_forward_grad_name(var_names: tuple[str, ...]) -> str: if len(var_names) == 1: return f"_any_has_forward_grad_{var_names[0]}" else: return f'_any_has_forward_grad_{"_".join(var_names)}' - def emit_any_has_forward_grad() -> List[str]: - content: List[str] = [] + def emit_any_has_forward_grad() -> list[str]: + content: list[str] = [] if not is_foreach: for derivative in fw_derivatives: requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative) @@ -1844,7 +1847,7 @@ def emit_body( content.append("}") return content - def emit_check_inplace() -> List[str]: + def emit_check_inplace() -> list[str]: if not inplace: return [] return [ @@ -1852,9 +1855,9 @@ def emit_body( for arg in differentiable_outputs ] - def emit_fw_derivatives() -> List[str]: - content: List[str] = [] - fw_grad_setters: List[str] = [] + def emit_fw_derivatives() -> list[str]: + content: list[str] = [] + fw_grad_setters: list[str] = [] for derivative in fw_derivatives: res = derivative.var_names if f.func.name.name.inplace: @@ -2002,7 +2005,7 @@ def emit_body( "(self.size(), c10::nullopt);" ) foreach_forward_grad_formula = derivative.formula - _foreach_arg: Union[Argument, DifferentiableInput] + _foreach_arg: Argument | DifferentiableInput if inplace: for _foreach_arg, _ref_arg in inplace_foreacharg2refarg.items(): # note(crcrpar): Massage only Scalar and ArrayRef here. @@ -2044,7 +2047,7 @@ def emit_body( content.append("\n".join(fw_grad_setters)) return content - def get_any_has_fw_grad_cond(derivative: Optional[ForwardDerivative]) -> str: + def get_any_has_fw_grad_cond(derivative: ForwardDerivative | None) -> str: # # Produces a condition string (e.g, "isFwGradDefined(grad_output) || isFwGradDefined(output)") # @@ -2053,7 +2056,7 @@ def emit_body( # - Used in the out_fn case when we want to forbid fw derivatives # - Used in the case where the fw_derivative is not defined, but we want # To check if there is a decomposition registered for jvp - to_check: List[str] = [] + to_check: list[str] = [] for inp in list( mapMaybe( gen_differentiable_input, @@ -2126,7 +2129,7 @@ def emit_body( else "" ) - body: List[str] = [] + body: list[str] = [] unpack_args_stats, unpacked_bindings = unpack_args(f) body.extend(unpack_args_stats) diff --git a/tools/autograd/gen_view_funcs.py b/tools/autograd/gen_view_funcs.py index c9f7561dca17..7838f255c8c3 100644 --- a/tools/autograd/gen_view_funcs.py +++ b/tools/autograd/gen_view_funcs.py @@ -4,10 +4,11 @@ # if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp # The fallback is expected to mimic this codegen, so we should keep the two in sync. 
-from typing import List, Tuple +from __future__ import annotations + +from typing import TYPE_CHECKING import torchgen.api.dispatcher as dispatcher -from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo from torchgen.api.translate import translate from torchgen.api.types import ( BaseCType, @@ -29,6 +30,11 @@ from .gen_inplace_or_view_type import ( use_derived, ) + +if TYPE_CHECKING: + from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo + + FUNCTION_DECLARATION = CodeTemplate( """\ #define ${uppercase_op}_AVAILABLE @@ -155,9 +161,9 @@ def returns_multi_tensor(fn: NativeFunction) -> bool: # tuple: (list of getter logic strings, list of setter logic strings, string # with num items expression) def generate_state_getter_setter( - bindings: List[Binding], + bindings: list[Binding], state_vec_type: NamedCType, -) -> Tuple[List[str], List[str], str]: +) -> tuple[list[str], list[str], str]: getter_logic = [] setter_logic = [] @@ -302,7 +308,7 @@ def process_function(fn: NativeFunction, template: CodeTemplate) -> str: def gen_view_funcs( out: str, - fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo], + fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo], template_path: str, ) -> None: # don't need the info parts, just the function diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py index 47c975ce1aa0..96a37eb7c946 100644 --- a/tools/autograd/load_derivatives.py +++ b/tools/autograd/load_derivatives.py @@ -2,14 +2,16 @@ # # Each autograd function is represented by `DifferentiabilityInfo` containing # a list of `Derivative`. See `torchgen.api.autograd` for the data models. + +from __future__ import annotations + import re from collections import defaultdict -from typing import Any, Counter, Dict, List, Match, Optional, Sequence, Set, Tuple +from typing import Any, Counter, Dict, Sequence, Set, Tuple import yaml from torchgen.api import cpp - from torchgen.api.autograd import ( Derivative, DifferentiabilityInfo, @@ -50,9 +52,10 @@ from torchgen.model import ( from torchgen.utils import concatMap, IDENT_REGEX, split_name_params from torchgen.yaml_utils import YamlLoader + DerivativeRet = Tuple[Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], Set[str]] -_GLOBAL_LOAD_DERIVATIVE_CACHE: Dict[Tuple[str, str], DerivativeRet] = {} +_GLOBAL_LOAD_DERIVATIVE_CACHE: dict[tuple[str, str], DerivativeRet] = {} _VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS) @@ -62,11 +65,11 @@ _VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS) # we generate them here instead of duplicating them in the yaml. # See Note [Codegen'd {view}_copy Operators] def add_view_copy_derivatives( - infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], - view_groups: List[NativeFunctionsViewGroup], + infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]], + view_groups: list[NativeFunctionsViewGroup], ) -> None: # Get the map from each view op's name to its corresponding view group - view_name_to_group: Dict[OperatorName, NativeFunctionsViewGroup] = { + view_name_to_group: dict[OperatorName, NativeFunctionsViewGroup] = { g.view.func.name: g for g in view_groups } @@ -125,10 +128,10 @@ def load_derivatives( # function schema is the complete declaration including mutability annotation / default value and etc. # signature is the canonical schema for a group of functions (in-place/out/functional variants) # that are semantically related. 
- functions_by_signature: Dict[ - FunctionSchema, List[NativeFunction] + functions_by_signature: dict[ + FunctionSchema, list[NativeFunction] ] = defaultdict(list) - functions_by_schema: Dict[str, NativeFunction] = {} + functions_by_schema: dict[str, NativeFunction] = {} for function in native_functions: functions_by_signature[function.func.signature()].append(function) assert str(function.func) not in functions_by_schema @@ -141,8 +144,8 @@ def load_derivatives( # infos is a dict that maps FunctionSchema -> a dict of per dispatch key DifferentiabilityInfos # this is useful because in tools/autograd/gen_autograd.py:match_differentiability_info # we ultimately need to categorize the DifferentiabilityInfos by FunctionSchema - infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]] = {} - used_dispatch_keys: Set[str] = set() + infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]] = {} + used_dispatch_keys: set[str] = set() for defn_dict in definitions: # Ensure that the old derivatives.yaml schema with no dispatch key can be loaded. if "dispatch" not in defn_dict: @@ -185,11 +188,11 @@ def cpp_arguments(f: NativeFunction) -> Sequence[Binding]: def create_derivative( f: NativeFunction, formula: str, - var_names: Tuple[str, ...], + var_names: tuple[str, ...], available_named_gradients: Sequence[str], ) -> Derivative: original_formula = formula - arguments: List[NamedCType] = [ + arguments: list[NamedCType] = [ a.nctype.remove_const_ref() for a in cpp_arguments(f) ] @@ -230,10 +233,10 @@ def create_derivative( def create_forward_derivative( - f: NativeFunction, formula: str, names: Tuple[str, ...] + f: NativeFunction, formula: str, names: tuple[str, ...] ) -> ForwardDerivative: var_names = names - var_types: Optional[Tuple[Type, ...]] = None + var_types: tuple[Type, ...] | None = None for r in f.func.returns: if r.name in var_names: if var_types is None: @@ -269,12 +272,12 @@ def create_forward_derivative( def postprocess_forward_derivatives( f: NativeFunction, defn_name: str, - all_arg_names: List[str], - derivatives: List[Derivative], - forward_derivatives: List[ForwardDerivative], + all_arg_names: list[str], + derivatives: list[Derivative], + forward_derivatives: list[ForwardDerivative], args_with_derivatives: Sequence[Binding], -) -> List[ForwardDerivative]: - def find_required_inputs(formula: str, postfix: str) -> Tuple[str, ...]: +) -> list[ForwardDerivative]: + def find_required_inputs(formula: str, postfix: str) -> tuple[str, ...]: is_foreach = f.func.name.name.base.startswith("_foreach_") required_inputs = set() for arg in args_with_derivatives: @@ -300,7 +303,7 @@ def postprocess_forward_derivatives( return tuple(required_inputs) - updated_derivatives: List[ForwardDerivative] = [] + updated_derivatives: list[ForwardDerivative] = [] for defn in forward_derivatives: formula = defn.formula @@ -430,7 +433,7 @@ def postprocess_forward_derivatives( def is_forward_derivative_definition( - all_arg_names: List[str], names: Tuple[str, ...] + all_arg_names: list[str], names: tuple[str, ...] 
) -> bool: for name in names: if name not in all_arg_names: @@ -441,12 +444,12 @@ def is_forward_derivative_definition( def create_differentiability_info( - defn_dict: Dict[Any, Any], - functions_by_signature: Dict[FunctionSchema, List[NativeFunction]], - functions_by_schema: Dict[str, NativeFunction], + defn_dict: dict[Any, Any], + functions_by_signature: dict[FunctionSchema, list[NativeFunction]], + functions_by_schema: dict[str, NativeFunction], op_counter: Counter[str], - used_dispatch_keys: Set[str], -) -> Tuple[FunctionSchema, Dict[str, DifferentiabilityInfo]]: + used_dispatch_keys: set[str], +) -> tuple[FunctionSchema, dict[str, DifferentiabilityInfo]]: """Processes a single entry `defn` in derivatives.yaml""" def canonical_function( @@ -463,7 +466,7 @@ def create_differentiability_info( assert name + "_" == cpp.name(functions[0].func) return functions[0] - def split_names(raw_names: str) -> Tuple[str, ...]: + def split_names(raw_names: str) -> tuple[str, ...]: """Given "foo, bar", return ["foo", "bar"].""" return tuple(x.strip() for x in raw_names.split(",")) @@ -477,7 +480,7 @@ def create_differentiability_info( uses_grad = False # true if any derivative uses "grad" num_grads_uses = 0 # count of uses of "grads" or "grads[INDEX]" uses_named_grads = False # true if any derivative uses "grad_{name}" - used_grads_indices: List[int] = [] # which indices of grads are used + used_grads_indices: list[int] = [] # which indices of grads are used for d in derivatives: formula = d.formula uses_grad = uses_grad or bool( @@ -521,7 +524,7 @@ def create_differentiability_info( @with_native_function def set_up_derivatives( f: NativeFunction, - ) -> Tuple[ + ) -> tuple[ Sequence[Derivative], Sequence[ForwardDerivative], Sequence[Binding], @@ -529,10 +532,10 @@ def create_differentiability_info( Sequence[str], ]: # Set up the derivative information - derivatives: List[Derivative] = [] - forward_derivatives: List[ForwardDerivative] = [] - non_differentiable_arg_names: List[str] = [] - args_with_derivatives_set: Set[str] = set() + derivatives: list[Derivative] = [] + forward_derivatives: list[ForwardDerivative] = [] + non_differentiable_arg_names: list[str] = [] + args_with_derivatives_set: set[str] = set() all_arg_names = [a.name for a in cpp_arguments(f)] all_ret_names = [ @@ -699,7 +702,7 @@ def create_differentiability_info( available_named_gradients, ) = set_up_derivatives(canonical) - used_named_gradients: Set[str] = set() + used_named_gradients: set[str] = set() for d in derivatives: used_named_gradients |= d.named_gradients @@ -738,7 +741,7 @@ def create_differentiability_info( GRAD_INDEX_REGEX = r"(?:^|\W)grads\[(\d+)\]" -def used_gradient_indices(formula: str) -> List[int]: +def used_gradient_indices(formula: str) -> list[int]: """Determine a list of gradient indices (the i in grads[i]) that are used by the formula. 
@@ -750,9 +753,9 @@ def used_gradient_indices(formula: str) -> List[int]: def saved_variables( formula: str, - nctypes: List[NamedCType], - var_names: Tuple[str, ...], -) -> Tuple[str, Tuple[SavedAttribute, ...]]: + nctypes: list[NamedCType], + var_names: tuple[str, ...], +) -> tuple[str, tuple[SavedAttribute, ...]]: def stride_expr(name: str) -> str: assert var_names == (name,), ( 'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor ' @@ -760,7 +763,7 @@ def saved_variables( ) return f'strides_or_error({name}, "{name}")' - REPLACEMENTS: List[Tuple[str, Dict[str, Any]]] = [ + REPLACEMENTS: list[tuple[str, dict[str, Any]]] = [ # replace self.sym_sizes() with self_sym_sizes ( r"{}.sym_sizes\(\)", @@ -914,7 +917,7 @@ def saved_variables( ] # find which arguments need to be saved - saved: List[SavedAttribute] = [] + saved: list[SavedAttribute] = [] if ".sizes()" in formula or "->sizes()" in formula: raise RuntimeError( @@ -941,7 +944,7 @@ def saved_variables( # when the autograd Function is created to avoid saving variables for regex, info in REPLACEMENTS: - def repl(m: Match[str]) -> str: + def repl(m: re.Match[str]) -> str: suffix: str = ( info["suffix"](m) if callable(info["suffix"]) else info["suffix"] ) @@ -999,8 +1002,8 @@ def _create_op_prefix(name: str) -> str: def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]: - seen: Set[str] = set() - saved: List[SavedAttribute] = [] + seen: set[str] = set() + saved: list[SavedAttribute] = [] for var in vars: name = ( var.nctype.name.name diff --git a/tools/build_pytorch_libs.py b/tools/build_pytorch_libs.py index dbdf34eeda6e..64e7132a6910 100644 --- a/tools/build_pytorch_libs.py +++ b/tools/build_pytorch_libs.py @@ -1,17 +1,17 @@ +from __future__ import annotations + import os import platform import shutil from glob import glob -from typing import Dict, Optional from setuptools import distutils # type: ignore[import] from .setup_helpers.cmake import CMake, USE_NINJA - from .setup_helpers.env import check_negative_env_flag, IS_64BIT, IS_WINDOWS -def _overlay_windows_vcvars(env: Dict[str, str]) -> Dict[str, str]: +def _overlay_windows_vcvars(env: dict[str, str]) -> dict[str, str]: vc_arch = "x64" if IS_64BIT else "x86" if platform.machine() == "ARM64": @@ -34,7 +34,7 @@ def _overlay_windows_vcvars(env: Dict[str, str]) -> Dict[str, str]: "emulation is enabled!" ) - vc_env: Dict[str, str] = distutils._msvccompiler._get_vc_env(vc_arch) + vc_env: dict[str, str] = distutils._msvccompiler._get_vc_env(vc_arch) # Keys in `_get_vc_env` are always lowercase. # We turn them into uppercase before overlaying vcvars # because OS environ keys are always uppercase on Windows. @@ -47,7 +47,7 @@ def _overlay_windows_vcvars(env: Dict[str, str]) -> Dict[str, str]: return vc_env -def _create_build_env() -> Dict[str, str]: +def _create_build_env() -> dict[str, str]: # XXX - our cmake file sometimes looks at the system environment # and not cmake flags! # you should NEVER add something to this list. 
It is bad practice to @@ -72,8 +72,8 @@ def _create_build_env() -> Dict[str, str]: def build_caffe2( - version: Optional[str], - cmake_python_library: Optional[str], + version: str | None, + cmake_python_library: str | None, build_python: bool, rerun_cmake: bool, cmake_only: bool, diff --git a/tools/build_with_debinfo.py b/tools/build_with_debinfo.py index b8d4a634849a..066d6ce414d6 100755 --- a/tools/build_with_debinfo.py +++ b/tools/build_with_debinfo.py @@ -5,10 +5,13 @@ # - ninja -j1 -v -n torch_python | sed -e 's/-O[23]/-g/g' -e 's#\[[0-9]\+\/[0-9]\+\] \+##' |sh # - Copy libs from build/lib to torch/lib folder +from __future__ import annotations + import subprocess import sys from pathlib import Path -from typing import Any, List, Optional, Tuple +from typing import Any + PYTORCH_ROOTDIR = Path(__file__).resolve().parent.parent TORCH_DIR = PYTORCH_ROOTDIR / "torch" @@ -17,7 +20,7 @@ BUILD_DIR = PYTORCH_ROOTDIR / "build" BUILD_LIB_DIR = BUILD_DIR / "lib" -def check_output(args: List[str], cwd: Optional[str] = None) -> str: +def check_output(args: list[str], cwd: str | None = None) -> str: return subprocess.check_output(args, cwd=cwd).decode("utf-8") @@ -63,7 +66,7 @@ def is_devel_setup() -> bool: return output.strip() == str(TORCH_DIR / "__init__.py") -def create_build_plan() -> List[Tuple[str, str]]: +def create_build_plan() -> list[tuple[str, str]]: output = check_output( ["ninja", "-j1", "-v", "-n", "torch_python"], cwd=str(BUILD_DIR) ) diff --git a/tools/code_analyzer/gen_op_registration_allowlist.py b/tools/code_analyzer/gen_op_registration_allowlist.py index bc73b61ee7e1..073ea3f3d67f 100644 --- a/tools/code_analyzer/gen_op_registration_allowlist.py +++ b/tools/code_analyzer/gen_op_registration_allowlist.py @@ -8,13 +8,15 @@ For custom build with static dispatch, the op dependency graph will be omitted, and it will directly output root ops as the allowlist. 
""" -import argparse +from __future__ import annotations +import argparse from collections import defaultdict -from typing import Dict, List, Set +from typing import Dict, Set import yaml + DepGraph = Dict[str, Set[str]] @@ -34,7 +36,7 @@ def load_op_dep_graph(fname: str) -> DepGraph: return dict(result) -def load_root_ops(fname: str) -> List[str]: +def load_root_ops(fname: str) -> list[str]: result = [] with open(fname) as stream: for op in yaml.safe_load(stream): @@ -44,9 +46,9 @@ def load_root_ops(fname: str) -> List[str]: def gen_transitive_closure( dep_graph: DepGraph, - root_ops: List[str], + root_ops: list[str], train: bool = False, -) -> List[str]: +) -> list[str]: result = set(root_ops) queue = root_ops.copy() @@ -73,7 +75,7 @@ def gen_transitive_closure( return sorted(result) -def gen_transitive_closure_str(dep_graph: DepGraph, root_ops: List[str]) -> str: +def gen_transitive_closure_str(dep_graph: DepGraph, root_ops: list[str]) -> str: return " ".join(gen_transitive_closure(dep_graph, root_ops)) diff --git a/tools/code_analyzer/gen_operators_yaml.py b/tools/code_analyzer/gen_operators_yaml.py index 8d92aa3adb70..ede651679847 100644 --- a/tools/code_analyzer/gen_operators_yaml.py +++ b/tools/code_analyzer/gen_operators_yaml.py @@ -1,8 +1,11 @@ #!/usr/bin/env python3 + +from __future__ import annotations + import argparse import json import sys -from typing import Any, Dict, List, Optional +from typing import Any import yaml from gen_op_registration_allowlist import ( @@ -17,6 +20,7 @@ from torchgen.selective_build.operator import ( ) from torchgen.selective_build.selector import merge_kernel_metadata + # Generate YAML file containing the operators used for a specific PyTorch model. # ------------------------------------------------------------------------------ # @@ -84,17 +88,17 @@ from torchgen.selective_build.selector import merge_kernel_metadata # -def canonical_opnames(opnames: List[str]) -> List[str]: +def canonical_opnames(opnames: list[str]) -> list[str]: return [canonical_name(opname) for opname in opnames] def make_filter_from_options( model_name: str, - model_versions: List[str], - model_assets: Optional[List[str]], - model_backends: Optional[List[str]], + model_versions: list[str], + model_assets: list[str] | None, + model_backends: list[str] | None, ): - def is_model_included(model_info): + def is_model_included(model_info) -> bool: model = model_info["model"] if model["name"] != model_name: return False @@ -109,7 +113,7 @@ def make_filter_from_options( # Returns if a the specified rule is a new or old style pt_operator_library -def is_new_style_rule(model_name: str, model_versions: Optional[List[str]]): +def is_new_style_rule(model_name: str, model_versions: list[str] | None): return model_name is not None and model_versions is not None @@ -117,13 +121,13 @@ def is_new_style_rule(model_name: str, model_versions: Optional[List[str]]): # appear in at least one model yaml. 
Throws if verification is failed, # returns None on success def verify_all_specified_present( - model_assets: Optional[List[str]], - model_versions: List[str], - selected_models_yaml: List[Dict[str, Any]], + model_assets: list[str] | None, + model_versions: list[str], + selected_models_yaml: list[dict[str, Any]], rule_name: str, model_name: str, new_style_rule: bool, -): +) -> None: def find_missing_items(model_items, key, selected_models_yaml): missing_items = [] if not new_style_rule or not model_items: @@ -179,10 +183,10 @@ def verify_all_specified_present( # Uses the selected models configs and then combines them into one dictionary, # formats them as a string, and places the string into output as a top level debug_info def create_debug_info_from_selected_models( - output: Dict[str, object], - selected_models: List[dict], + output: dict[str, object], + selected_models: list[dict], new_style_rule: bool, -): +) -> None: model_dict = { "asset_info": {}, # maps asset name -> dict of asset metadata like hashes "is_new_style_rule": new_style_rule, @@ -201,7 +205,7 @@ def create_debug_info_from_selected_models( output["debug_info"] = [json.dumps(model_dict)] -def fill_output(output: Dict[str, object], options: object): +def fill_output(output: dict[str, object], options: object) -> None: """Populate the output dict with the information required to serialize the YAML file used for selective build. """ @@ -458,7 +462,7 @@ def fill_output(output: Dict[str, object], options: object): # END TRACING BASED BUILD OPS # Merge dictionaries together to remove op duplication - operators: Dict[str, SelectiveBuildOperator] = {} + operators: dict[str, SelectiveBuildOperator] = {} for ops_dict in bucketed_ops: operators = merge_operator_dicts(operators, ops_dict) diff --git a/tools/code_analyzer/gen_oplist.py b/tools/code_analyzer/gen_oplist.py index 4579abf28fa1..0d735cdb3d44 100644 --- a/tools/code_analyzer/gen_oplist.py +++ b/tools/code_analyzer/gen_oplist.py @@ -1,10 +1,13 @@ #!/usr/bin/env python3 + +from __future__ import annotations + import argparse import json import os import sys from functools import reduce -from typing import Any, List, Set +from typing import Any import yaml from tools.lite_interpreter.gen_selected_mobile_ops_header import ( @@ -17,11 +20,11 @@ from torchgen.selective_build.selector import ( ) -def extract_all_operators(selective_builder: SelectiveBuilder) -> Set[str]: +def extract_all_operators(selective_builder: SelectiveBuilder) -> set[str]: return set(selective_builder.operators.keys()) -def extract_training_operators(selective_builder: SelectiveBuilder) -> Set[str]: +def extract_training_operators(selective_builder: SelectiveBuilder) -> set[str]: ops = [] for op_name, op in selective_builder.operators.items(): if op.is_used_for_training: @@ -44,7 +47,7 @@ def throw_if_any_op_includes_overloads(selective_builder: SelectiveBuilder) -> N ) -def gen_supported_mobile_models(model_dicts: List[Any], output_dir: str) -> None: +def gen_supported_mobile_models(model_dicts: list[Any], output_dir: str) -> None: supported_mobile_models_source = """/* * Generated by gen_oplist.py */ @@ -87,7 +90,7 @@ SupportedMobileModelCheckerRegistry register_model_versions; out_file.write(source.encode("utf-8")) -def main(argv: List[Any]) -> None: +def main(argv: list[Any]) -> None: """This binary generates 3 files: 1. 
selected_mobile_ops.h: Primary operators used by templated selective build and Kernel Function diff --git a/tools/code_coverage/package/oss/init.py b/tools/code_coverage/package/oss/init.py index ece3041231eb..1b1c320e3758 100644 --- a/tools/code_coverage/package/oss/init.py +++ b/tools/code_coverage/package/oss/init.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import argparse import os -from typing import cast, List, Optional, Tuple +from typing import cast from ..util.setting import ( CompilerType, @@ -38,7 +40,7 @@ BLOCKED_PYTHON_TESTS = { } -def initialization() -> Tuple[Option, TestList, List[str]]: +def initialization() -> tuple[Option, TestList, list[str]]: # create folder if not exists create_folders() # add arguments @@ -77,7 +79,7 @@ def add_arguments_oss(parser: argparse.ArgumentParser) -> argparse.ArgumentParse def parse_arguments( parser: argparse.ArgumentParser, -) -> Tuple[Option, Optional[List[str]], Optional[List[str]], Optional[bool]]: +) -> tuple[Option, list[str] | None, list[str] | None, bool | None]: # parse args args = parser.parse_args() # get option @@ -85,9 +87,7 @@ def parse_arguments( return (options, args.interest_only, args.run_only, args.clean) -def get_test_list_by_type( - run_only: Optional[List[str]], test_type: TestType -) -> TestList: +def get_test_list_by_type(run_only: list[str] | None, test_type: TestType) -> TestList: test_list: TestList = [] binary_folder = get_oss_binary_folder(test_type) g = os.walk(binary_folder) @@ -106,7 +106,7 @@ def get_test_list_by_type( return test_list -def get_test_list(run_only: Optional[List[str]]) -> TestList: +def get_test_list(run_only: list[str] | None) -> TestList: test_list: TestList = [] # add c++ test list test_list.extend(get_test_list_by_type(run_only, TestType.CPP)) @@ -122,7 +122,7 @@ def get_test_list(run_only: Optional[List[str]]) -> TestList: return test_list -def empty_list_if_none(arg_interested_folder: Optional[List[str]]) -> List[str]: +def empty_list_if_none(arg_interested_folder: list[str] | None) -> list[str]: if arg_interested_folder is None: return [] # if this argument is specified, just return itself @@ -134,7 +134,7 @@ def gcc_export_init() -> None: create_folder(JSON_FOLDER_BASE_DIR) -def get_python_run_only(args_run_only: Optional[List[str]]) -> List[str]: +def get_python_run_only(args_run_only: list[str] | None) -> list[str]: # if user specifies run-only option if args_run_only: return args_run_only @@ -144,7 +144,7 @@ def get_python_run_only(args_run_only: Optional[List[str]]) -> List[str]: return ["run_test.py"] else: # for clang, some tests will result in too large intermediate files that can't be merged by llvm, we need to skip them - run_only: List[str] = [] + run_only: list[str] = [] binary_folder = get_oss_binary_folder(TestType.PY) g = os.walk(binary_folder) for _, _, file_list in g: diff --git a/tools/code_coverage/package/oss/utils.py b/tools/code_coverage/package/oss/utils.py index 1cb67dc71439..c4019d762893 100644 --- a/tools/code_coverage/package/oss/utils.py +++ b/tools/code_coverage/package/oss/utils.py @@ -1,6 +1,7 @@ +from __future__ import annotations + import os import subprocess -from typing import List, Optional from ..util.setting import CompilerType, TestType, TOOLS_FOLDER from ..util.utils import print_error, remove_file @@ -14,7 +15,7 @@ def get_oss_binary_folder(test_type: TestType) -> str: ) -def get_oss_shared_library() -> List[str]: +def get_oss_shared_library() -> list[str]: lib_dir = os.path.join(get_pytorch_folder(), "build", "lib") return [ 
os.path.join(lib_dir, lib) @@ -48,7 +49,7 @@ def get_pytorch_folder() -> str: ) -def detect_compiler_type() -> Optional[CompilerType]: +def detect_compiler_type() -> CompilerType | None: # check if user specifies the compiler type user_specify = os.environ.get("CXX", None) if user_specify: @@ -76,7 +77,7 @@ def clean_up_gcda() -> None: remove_file(item) -def get_gcda_files() -> List[str]: +def get_gcda_files() -> list[str]: folder_has_gcda = os.path.join(get_pytorch_folder(), "build") if os.path.isdir(folder_has_gcda): # TODO use glob diff --git a/tools/code_coverage/package/tool/clang_coverage.py b/tools/code_coverage/package/tool/clang_coverage.py index a6b1fa0c0812..36c353558927 100644 --- a/tools/code_coverage/package/tool/clang_coverage.py +++ b/tools/code_coverage/package/tool/clang_coverage.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import os import subprocess import time -from typing import List from ..util.setting import ( JSON_FOLDER_BASE_DIR, @@ -25,7 +26,7 @@ from .utils import get_tool_path_by_platform, run_cpp_test def create_corresponding_folder( - cur_path: str, prefix_cur_path: str, dir_list: List[str], new_base_folder: str + cur_path: str, prefix_cur_path: str, dir_list: list[str], new_base_folder: str ) -> None: for dir_name in dir_list: relative_path = convert_to_relative_path( @@ -70,7 +71,7 @@ def export_target( merged_file: str, json_file: str, binary_file: str, - shared_library_list: List[str], + shared_library_list: list[str], platform_type: TestPlatform, ) -> None: if binary_file is None: diff --git a/tools/code_coverage/package/tool/gcc_coverage.py b/tools/code_coverage/package/tool/gcc_coverage.py index d8de71aab6c1..e344b2f425aa 100644 --- a/tools/code_coverage/package/tool/gcc_coverage.py +++ b/tools/code_coverage/package/tool/gcc_coverage.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import os import subprocess import time -from typing import Dict # gcc is only used in oss from ..oss.utils import get_gcda_files, run_oss_python_test @@ -10,7 +11,7 @@ from ..util.utils import print_log, print_time from .utils import run_cpp_test -def update_gzip_dict(gzip_dict: Dict[str, int], file_name: str) -> str: +def update_gzip_dict(gzip_dict: dict[str, int], file_name: str) -> str: file_name = file_name.lower() gzip_dict[file_name] = gzip_dict.get(file_name, 0) + 1 num = gzip_dict[file_name] @@ -34,7 +35,7 @@ def export() -> None: # collect .gcda files gcda_files = get_gcda_files() # file name like utils.cpp may have same name in different folder - gzip_dict: Dict[str, int] = {} + gzip_dict: dict[str, int] = {} for gcda_item in gcda_files: # generate json.gz subprocess.check_call(["gcov", "-i", gcda_item]) diff --git a/tools/code_coverage/package/tool/parser/coverage_record.py b/tools/code_coverage/package/tool/parser/coverage_record.py index 1d6698aa861c..7693abf26496 100644 --- a/tools/code_coverage/package/tool/parser/coverage_record.py +++ b/tools/code_coverage/package/tool/parser/coverage_record.py @@ -1,12 +1,14 @@ -import typing as t +from __future__ import annotations + +from typing import Any, NamedTuple -class CoverageRecord(t.NamedTuple): +class CoverageRecord(NamedTuple): filepath: str - covered_lines: t.List[int] - uncovered_lines: t.Optional[t.List[int]] = None + covered_lines: list[int] + uncovered_lines: list[int] | None = None - def to_dict(self) -> t.Dict[str, t.Any]: + def to_dict(self) -> dict[str, Any]: return { "filepath": self.filepath, "covered_lines": self.covered_lines, diff --git 
a/tools/code_coverage/package/tool/parser/gcov_coverage_parser.py b/tools/code_coverage/package/tool/parser/gcov_coverage_parser.py index 000b72cbc2d8..9529cf9f1520 100644 --- a/tools/code_coverage/package/tool/parser/gcov_coverage_parser.py +++ b/tools/code_coverage/package/tool/parser/gcov_coverage_parser.py @@ -1,4 +1,6 @@ -from typing import Any, Dict, List, Set +from __future__ import annotations + +from typing import Any from .coverage_record import CoverageRecord @@ -10,7 +12,7 @@ class GcovCoverageParser: of CoverageRecord(s). """ - def __init__(self, llvm_coverage: Dict[str, Any]) -> None: + def __init__(self, llvm_coverage: dict[str, Any]) -> None: self._llvm_coverage = llvm_coverage @staticmethod @@ -24,17 +26,17 @@ class GcovCoverageParser: return True return False - def parse(self) -> List[CoverageRecord]: + def parse(self) -> list[CoverageRecord]: # The JSON format is described in the gcov source code # https://gcc.gnu.org/onlinedocs/gcc/Invoking-Gcov.html - records: List[CoverageRecord] = [] + records: list[CoverageRecord] = [] for file_info in self._llvm_coverage["files"]: filepath = file_info["file"] if self._skip_coverage(filepath): continue # parse json file - covered_lines: Set[int] = set() - uncovered_lines: Set[int] = set() + covered_lines: set[int] = set() + uncovered_lines: set[int] = set() for line in file_info["lines"]: line_number = line["line_number"] count = line["count"] diff --git a/tools/code_coverage/package/tool/parser/llvm_coverage_parser.py b/tools/code_coverage/package/tool/parser/llvm_coverage_parser.py index 691bb40a2a2b..0ea85df2e2e6 100644 --- a/tools/code_coverage/package/tool/parser/llvm_coverage_parser.py +++ b/tools/code_coverage/package/tool/parser/llvm_coverage_parser.py @@ -1,4 +1,6 @@ -from typing import Any, Dict, List, Set, Tuple +from __future__ import annotations + +from typing import Any from .coverage_record import CoverageRecord from .llvm_coverage_segment import LlvmCoverageSegment, parse_segments @@ -12,7 +14,7 @@ class LlvmCoverageParser: """ - def __init__(self, llvm_coverage: Dict[str, Any]) -> None: + def __init__(self, llvm_coverage: dict[str, Any]) -> None: self._llvm_coverage = llvm_coverage @staticmethod @@ -28,13 +30,13 @@ class LlvmCoverageParser: @staticmethod def _collect_coverage( - segments: List[LlvmCoverageSegment], - ) -> Tuple[List[int], List[int]]: + segments: list[LlvmCoverageSegment], + ) -> tuple[list[int], list[int]]: """ Stateful parsing of coverage segments. 
""" - covered_lines: Set[int] = set() - uncovered_lines: Set[int] = set() + covered_lines: set[int] = set() + uncovered_lines: set[int] = set() prev_segment = LlvmCoverageSegment(1, 0, 0, 0, 0, None) for segment in segments: covered_range, uncovered_range = segment.get_coverage(prev_segment) @@ -45,10 +47,10 @@ class LlvmCoverageParser: uncovered_lines.difference_update(covered_lines) return sorted(covered_lines), sorted(uncovered_lines) - def parse(self, repo_name: str) -> List[CoverageRecord]: + def parse(self, repo_name: str) -> list[CoverageRecord]: # The JSON format is described in the LLVM source code # https://github.com/llvm-mirror/llvm/blob/master/tools/llvm-cov/CoverageExporterJson.cpp - records: List[CoverageRecord] = [] + records: list[CoverageRecord] = [] for export_unit in self._llvm_coverage["data"]: for file_info in export_unit["files"]: filepath = file_info["filename"] diff --git a/tools/code_coverage/package/tool/parser/llvm_coverage_segment.py b/tools/code_coverage/package/tool/parser/llvm_coverage_segment.py index 6f780c1e0f82..63b1e4baf51f 100644 --- a/tools/code_coverage/package/tool/parser/llvm_coverage_segment.py +++ b/tools/code_coverage/package/tool/parser/llvm_coverage_segment.py @@ -1,4 +1,6 @@ -from typing import List, NamedTuple, Optional, Tuple +from __future__ import annotations + +from typing import NamedTuple class LlvmCoverageSegment(NamedTuple): @@ -7,7 +9,7 @@ class LlvmCoverageSegment(NamedTuple): segment_count: int has_count: int is_region_entry: int - is_gap_entry: Optional[int] + is_gap_entry: int | None @property def has_coverage(self) -> bool: @@ -18,8 +20,8 @@ class LlvmCoverageSegment(NamedTuple): return self.has_count > 0 def get_coverage( - self, prev_segment: "LlvmCoverageSegment" - ) -> Tuple[List[int], List[int]]: + self, prev_segment: LlvmCoverageSegment + ) -> tuple[list[int], list[int]]: # Code adapted from testpilot.testinfra.runners.gtestcoveragerunner.py if not prev_segment.is_executable: return [], [] @@ -32,12 +34,12 @@ class LlvmCoverageSegment(NamedTuple): return (lines_range, []) if prev_segment.has_coverage else ([], lines_range) -def parse_segments(raw_segments: List[List[int]]) -> List[LlvmCoverageSegment]: +def parse_segments(raw_segments: list[list[int]]) -> list[LlvmCoverageSegment]: """ Creates LlvmCoverageSegment from a list of lists in llvm export json. each segment is represented by 5-element array. 
""" - ret: List[LlvmCoverageSegment] = [] + ret: list[LlvmCoverageSegment] = [] for raw_segment in raw_segments: assert ( len(raw_segment) == 5 or len(raw_segment) == 6 diff --git a/tools/code_coverage/package/tool/print_report.py b/tools/code_coverage/package/tool/print_report.py index 130d82fe3136..a597e580f8c7 100644 --- a/tools/code_coverage/package/tool/print_report.py +++ b/tools/code_coverage/package/tool/print_report.py @@ -1,10 +1,13 @@ +from __future__ import annotations + import os import subprocess -from typing import Dict, IO, List, Set, Tuple +from typing import IO, Tuple from ..oss.utils import get_pytorch_folder from ..util.setting import SUMMARY_FOLDER_DIR, TestList, TestStatusType + CoverageItem = Tuple[str, float, int, int] @@ -16,7 +19,7 @@ def key_by_name(x: CoverageItem) -> str: return x[0] -def is_intrested_file(file_path: str, interested_folders: List[str]) -> bool: +def is_intrested_file(file_path: str, interested_folders: list[str]) -> bool: if "cuda" in file_path: return False if "aten/gen_aten" in file_path or "aten/aten_" in file_path: @@ -27,7 +30,7 @@ def is_intrested_file(file_path: str, interested_folders: List[str]) -> bool: return False -def is_this_type_of_tests(target_name: str, test_set_by_type: Set[str]) -> bool: +def is_this_type_of_tests(target_name: str, test_set_by_type: set[str]) -> bool: # tests are divided into three types: success / partial success / fail to collect coverage for test in test_set_by_type: if target_name in test: @@ -36,7 +39,7 @@ def is_this_type_of_tests(target_name: str, test_set_by_type: Set[str]) -> bool: def print_test_by_type( - tests: TestList, test_set_by_type: Set[str], type_name: str, summary_file: IO[str] + tests: TestList, test_set_by_type: set[str], type_name: str, summary_file: IO[str] ) -> None: print("Tests " + type_name + " to collect coverage:", file=summary_file) for test in tests: @@ -48,8 +51,8 @@ def print_test_by_type( def print_test_condition( tests: TestList, tests_type: TestStatusType, - interested_folders: List[str], - coverage_only: List[str], + interested_folders: list[str], + coverage_only: list[str], summary_file: IO[str], summary_type: str, ) -> None: @@ -77,10 +80,10 @@ def print_test_condition( def line_oriented_report( tests: TestList, tests_type: TestStatusType, - interested_folders: List[str], - coverage_only: List[str], - covered_lines: Dict[str, Set[int]], - uncovered_lines: Dict[str, Set[int]], + interested_folders: list[str], + coverage_only: list[str], + covered_lines: dict[str, set[int]], + uncovered_lines: dict[str, set[int]], ) -> None: with open(os.path.join(SUMMARY_FOLDER_DIR, "line_summary"), "w+") as report_file: print_test_condition( @@ -119,13 +122,13 @@ def print_file_summary( def print_file_oriented_report( tests_type: TestStatusType, - coverage: List[CoverageItem], + coverage: list[CoverageItem], covered_summary: int, total_summary: int, summary_file: IO[str], tests: TestList, - interested_folders: List[str], - coverage_only: List[str], + interested_folders: list[str], + coverage_only: list[str], ) -> None: coverage_percentage = print_file_summary( covered_summary, total_summary, summary_file @@ -155,10 +158,10 @@ def print_file_oriented_report( def file_oriented_report( tests: TestList, tests_type: TestStatusType, - interested_folders: List[str], - coverage_only: List[str], - covered_lines: Dict[str, Set[int]], - uncovered_lines: Dict[str, Set[int]], + interested_folders: list[str], + coverage_only: list[str], + covered_lines: dict[str, set[int]], + uncovered_lines: 
dict[str, set[int]], ) -> None: with open(os.path.join(SUMMARY_FOLDER_DIR, "file_summary"), "w+") as summary_file: covered_summary = 0 @@ -193,7 +196,7 @@ def file_oriented_report( ) -def get_html_ignored_pattern() -> List[str]: +def get_html_ignored_pattern() -> list[str]: return ["/usr/*", "*anaconda3/*", "*third_party/*"] diff --git a/tools/code_coverage/package/tool/summarize_jsons.py b/tools/code_coverage/package/tool/summarize_jsons.py index 7c5d8891ea83..f97cadde888f 100644 --- a/tools/code_coverage/package/tool/summarize_jsons.py +++ b/tools/code_coverage/package/tool/summarize_jsons.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import json import os import time -from typing import Any, Dict, List, Set, Tuple +from typing import Any, TYPE_CHECKING from ..util.setting import ( CompilerType, @@ -16,7 +18,6 @@ from ..util.utils import ( print_time, related_to_test_list, ) -from .parser.coverage_record import CoverageRecord from .parser.gcov_coverage_parser import GcovCoverageParser from .parser.llvm_coverage_parser import LlvmCoverageParser from .print_report import ( @@ -26,16 +27,20 @@ from .print_report import ( ) +if TYPE_CHECKING: + from .parser.coverage_record import CoverageRecord + + # coverage_records: Dict[str, LineInfo] = {} -covered_lines: Dict[str, Set[int]] = {} -uncovered_lines: Dict[str, Set[int]] = {} +covered_lines: dict[str, set[int]] = {} +uncovered_lines: dict[str, set[int]] = {} tests_type: TestStatusType = {"success": set(), "partial": set(), "fail": set()} def transform_file_name( - file_path: str, interested_folders: List[str], platform: TestPlatform + file_path: str, interested_folders: list[str], platform: TestPlatform ) -> str: - remove_patterns: Set[str] = {".DEFAULT.cpp", ".AVX.cpp", ".AVX2.cpp"} + remove_patterns: set[str] = {".DEFAULT.cpp", ".AVX.cpp", ".AVX2.cpp"} for pattern in remove_patterns: file_path = file_path.replace(pattern, "") # if user has specified interested folder @@ -54,7 +59,7 @@ def transform_file_name( def is_intrested_file( - file_path: str, interested_folders: List[str], platform: TestPlatform + file_path: str, interested_folders: list[str], platform: TestPlatform ) -> bool: ignored_patterns = ["cuda", "aten/gen_aten", "aten/aten_", "build/"] if any(pattern in file_path for pattern in ignored_patterns): @@ -77,7 +82,7 @@ def is_intrested_file( return True -def get_json_obj(json_file: str) -> Tuple[Any, int]: +def get_json_obj(json_file: str) -> tuple[Any, int]: """ Sometimes at the start of file llvm/gcov will complains "fail to find coverage data", then we need to skip these lines @@ -102,7 +107,7 @@ def get_json_obj(json_file: str) -> Tuple[Any, int]: return None, 2 -def parse_json(json_file: str, platform: TestPlatform) -> List[CoverageRecord]: +def parse_json(json_file: str, platform: TestPlatform) -> list[CoverageRecord]: print("start parse:", json_file) json_obj, read_status = get_json_obj(json_file) if read_status == 0: @@ -117,7 +122,7 @@ def parse_json(json_file: str, platform: TestPlatform) -> List[CoverageRecord]: cov_type = detect_compiler_type(platform) - coverage_records: List[CoverageRecord] = [] + coverage_records: list[CoverageRecord] = [] if cov_type == CompilerType.CLANG: coverage_records = LlvmCoverageParser(json_obj).parse("fbcode") # print(coverage_records) @@ -128,7 +133,7 @@ def parse_json(json_file: str, platform: TestPlatform) -> List[CoverageRecord]: def parse_jsons( - test_list: TestList, interested_folders: List[str], platform: TestPlatform + test_list: TestList, interested_folders: 
list[str], platform: TestPlatform ) -> None: g = os.walk(JSON_FOLDER_BASE_DIR) @@ -152,8 +157,8 @@ def parse_jsons( def update_coverage( - coverage_records: List[CoverageRecord], - interested_folders: List[str], + coverage_records: list[CoverageRecord], + interested_folders: list[str], platform: TestPlatform, ) -> None: for item in coverage_records: @@ -187,8 +192,8 @@ def update_set() -> None: def summarize_jsons( test_list: TestList, - interested_folders: List[str], - coverage_only: List[str], + interested_folders: list[str], + coverage_only: list[str], platform: TestPlatform, ) -> None: start_time = time.time() diff --git a/tools/code_coverage/package/util/setting.py b/tools/code_coverage/package/util/setting.py index ed5efc3a751c..a1d7683a8a9a 100644 --- a/tools/code_coverage/package/util/setting.py +++ b/tools/code_coverage/package/util/setting.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os from enum import Enum from typing import Dict, List, Set diff --git a/tools/code_coverage/package/util/utils.py b/tools/code_coverage/package/util/utils.py index ddeef943986e..2b2a4200463e 100644 --- a/tools/code_coverage/package/util/utils.py +++ b/tools/code_coverage/package/util/utils.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import os import shutil import sys import time -from typing import Any, NoReturn, Optional +from typing import Any, NoReturn from .setting import ( CompilerType, @@ -113,7 +115,7 @@ def get_test_name_from_whole_path(path: str) -> str: return path[start + 1 : end] -def check_compiler_type(cov_type: Optional[CompilerType]) -> None: +def check_compiler_type(cov_type: CompilerType | None) -> None: if cov_type is not None and cov_type in [CompilerType.GCC, CompilerType.CLANG]: return raise Exception( # noqa: TRY002 diff --git a/tools/coverage_plugins_package/setup.py b/tools/coverage_plugins_package/setup.py index e3e88067cb08..10214ec72f9f 100644 --- a/tools/coverage_plugins_package/setup.py +++ b/tools/coverage_plugins_package/setup.py @@ -1,5 +1,6 @@ import setuptools # type: ignore[import] + with open("README.md", encoding="utf-8") as fh: long_description = fh.read() diff --git a/tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py b/tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py index e6d0786a32b0..72594abefd0a 100644 --- a/tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py +++ b/tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py @@ -22,6 +22,7 @@ from typing import Any from coverage import CoverageData, CoveragePlugin # type: ignore[import] + # All coverage stats resulting from this plug-in will be in a separate .coverage file that should be merged later with # `coverage combine`. 
The convention seems to be .coverage.dotted.suffix based on the following link: # https://coverage.readthedocs.io/en/coverage-5.5/cmd.html#combining-data-files-coverage-combine diff --git a/tools/download_mnist.py b/tools/download_mnist.py index 021904959a9e..4fe6068fed9b 100644 --- a/tools/download_mnist.py +++ b/tools/download_mnist.py @@ -5,6 +5,7 @@ import sys from urllib.error import URLError from urllib.request import urlretrieve + MIRRORS = [ "http://yann.lecun.com/exdb/mnist/", "https://ossci-datasets.s3.amazonaws.com/mnist/", diff --git a/tools/dynamo/verify_dynamo.py b/tools/dynamo/verify_dynamo.py index e62d74043c7c..ec09fbd2b489 100644 --- a/tools/dynamo/verify_dynamo.py +++ b/tools/dynamo/verify_dynamo.py @@ -5,6 +5,7 @@ import sys import traceback import warnings + MIN_CUDA_VERSION = "11.6" MIN_ROCM_VERSION = "5.4" MIN_PYTHON_VERSION = (3, 8) @@ -141,7 +142,7 @@ def check_rocm(): return rocm_ver if torch.version.hip else "None" -def check_dynamo(backend, device, err_msg): +def check_dynamo(backend, device, err_msg) -> None: import torch if device == "cuda" and not torch.cuda.is_available(): @@ -203,7 +204,7 @@ _SANITY_CHECK_ARGS = ( ) -def main(): +def main() -> None: python_ver = check_python() torch_ver = check_torch() cuda_ver = check_cuda() diff --git a/tools/extract_scripts.py b/tools/extract_scripts.py index 75120c0ece9d..9fdf2a2d6d27 100755 --- a/tools/extract_scripts.py +++ b/tools/extract_scripts.py @@ -1,14 +1,17 @@ #!/usr/bin/env python3 +from __future__ import annotations + import argparse import re import sys from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict from typing_extensions import TypedDict # Python 3.11+ import yaml + Step = Dict[str, Any] @@ -17,7 +20,7 @@ class Script(TypedDict): script: str -def extract(step: Step) -> Optional[Script]: +def extract(step: Step) -> Script | None: run = step.get("run") # https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#using-a-specific-shell diff --git a/tools/gen_vulkan_spv.py b/tools/gen_vulkan_spv.py index 653a18a1ac18..a64fb45591f2 100644 --- a/tools/gen_vulkan_spv.py +++ b/tools/gen_vulkan_spv.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +from __future__ import annotations + import argparse import array import codecs @@ -15,7 +17,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) import subprocess import textwrap from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any import yaml from yaml.constructor import ConstructorError @@ -29,14 +31,14 @@ except ImportError: CPP_H_NAME = "spv.h" CPP_SRC_NAME = "spv.cpp" -DEFAULT_ENV: Dict[str, Any] = { +DEFAULT_ENV: dict[str, Any] = { "PRECISION": "highp", "FLOAT_IMAGE_FORMAT": "rgba16f", "INT_IMAGE_FORMAT": "rgba32i", "UINT_IMAGE_FORMAT": "rgba32ui", } -TYPES_ENV: Dict[str, Any] = { +TYPES_ENV: dict[str, Any] = { "IMAGE_FORMAT": { "float": "rgba32f", "half": "rgba16f", @@ -91,7 +93,7 @@ TYPES_ENV: Dict[str, Any] = { }, } -FUNCS_ENV: Dict[str, Any] = { +FUNCS_ENV: dict[str, Any] = { "GET_POS": { 3: lambda pos: pos, 2: lambda pos: f"{pos}.xy", @@ -169,7 +171,7 @@ def escape(line: str) -> str: # https://github.com/google/XNNPACK/blob/master/tools/xngen.py def preprocess( - input_text: str, variables: Dict[str, Any], input_path: str = "codegen" + input_text: str, variables: dict[str, Any], input_path: str = "codegen" ) -> str: input_lines = input_text.splitlines() python_lines = [] @@ -243,9 +245,9 @@ def 
preprocess( class SPVGenerator: def __init__( self, - src_dir_paths: Union[str, List[str]], - env: Dict[Any, Any], - glslc_path: Optional[str], + src_dir_paths: str | list[str], + env: dict[Any, Any], + glslc_path: str | None, ) -> None: if isinstance(src_dir_paths, str): self.src_dir_paths = [src_dir_paths] @@ -255,18 +257,18 @@ class SPVGenerator: self.env = env self.glslc_path = glslc_path - self.glsl_src_files: Dict[str, str] = {} - self.template_yaml_files: List[str] = [] + self.glsl_src_files: dict[str, str] = {} + self.template_yaml_files: list[str] = [] self.addSrcAndYamlFiles(self.src_dir_paths) - self.shader_template_params: Dict[Any, Any] = {} + self.shader_template_params: dict[Any, Any] = {} for yaml_file in self.template_yaml_files: self.parseTemplateYaml(yaml_file) - self.output_shader_map: Dict[str, Tuple[str, Dict[str, str]]] = {} + self.output_shader_map: dict[str, tuple[str, dict[str, str]]] = {} self.constructOutputMap() - def addSrcAndYamlFiles(self, src_dir_paths: List[str]) -> None: + def addSrcAndYamlFiles(self, src_dir_paths: list[str]) -> None: for src_path in src_dir_paths: # Collect glsl source files glsl_files = glob.glob( @@ -285,9 +287,9 @@ class SPVGenerator: def generateVariantCombinations( self, - iterated_params: Dict[str, Any], - exclude_params: Optional[Set[str]] = None, - ) -> List[Any]: + iterated_params: dict[str, Any], + exclude_params: set[str] | None = None, + ) -> list[Any]: if exclude_params is None: exclude_params = set() all_iterated_params = [] @@ -362,8 +364,8 @@ class SPVGenerator: ) def create_shader_params( - self, variant_params: Optional[Dict[str, Any]] = None - ) -> Dict[str, str]: + self, variant_params: dict[str, Any] | None = None + ) -> dict[str, str]: if variant_params is None: variant_params = {} shader_params = copy.deepcopy(self.env) @@ -409,7 +411,7 @@ class SPVGenerator: self.create_shader_params(), ) - def generateSPV(self, output_dir: str) -> Dict[str, str]: + def generateSPV(self, output_dir: str) -> dict[str, str]: output_file_map = {} for shader_name in self.output_shader_map: source_glsl = self.output_shader_map[shader_name][0] @@ -457,11 +459,11 @@ class SPVGenerator: @dataclass class ShaderInfo: - tile_size: List[int] - layouts: List[str] + tile_size: list[int] + layouts: list[str] weight_storage_type: str = "" bias_storage_type: str = "" - register_for: Optional[Tuple[str, List[str]]] = None + register_for: tuple[str, list[str]] | None = None def getName(filePath: str) -> str: @@ -478,7 +480,7 @@ def isTileSizeLine(lineStr: str) -> bool: return re.search(tile_size_id, lineStr) is not None -def findTileSizes(lineStr: str) -> List[int]: +def findTileSizes(lineStr: str) -> list[int]: tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)" matches = re.search(tile_size_id, lineStr) if matches is None: @@ -520,7 +522,7 @@ def isRegisterForLine(lineStr: str) -> bool: return re.search(register_for_id, lineStr) is not None -def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]: +def findRegisterFor(lineStr: str) -> tuple[str, list[str]]: register_for_pattern = r"'([A-Za-z0-9_]+)'" matches = re.findall(register_for_pattern, lineStr) if matches is None: @@ -609,7 +611,7 @@ static const api::ShaderRegisterInit register_shaders(®ister_fn); """ -def generateSpvBinStr(spvPath: str, name: str) -> Tuple[int, str]: +def generateSpvBinStr(spvPath: str, name: str) -> tuple[int, str]: with open(spvPath, "rb") as fr: next_bin = array.array("I", fr.read()) sizeBytes = 4 * len(next_bin) @@ -665,7 +667,7 @@ def 
generateShaderDispatchStr(shader_info: ShaderInfo, name: str) -> str: def genCppFiles( - spv_files: Dict[str, str], cpp_header_path: str, cpp_src_file_path: str + spv_files: dict[str, str], cpp_header_path: str, cpp_src_file_path: str ) -> None: spv_bin_strs = [] register_shader_info_strs = [] @@ -705,7 +707,7 @@ def genCppFiles( ########## -def parse_arg_env(items: Dict[Any, Any]) -> Dict[Any, Any]: +def parse_arg_env(items: dict[Any, Any]) -> dict[Any, Any]: d = {} if items: for item in items: @@ -716,7 +718,7 @@ def parse_arg_env(items: Dict[Any, Any]) -> Dict[Any, Any]: return d -def main(argv: List[str]) -> int: +def main(argv: list[str]) -> int: parser = argparse.ArgumentParser(description="") parser.add_argument( "-i", diff --git a/tools/generate_torch_version.py b/tools/generate_torch_version.py index 75ab4530e26f..8c5f950141e0 100644 --- a/tools/generate_torch_version.py +++ b/tools/generate_torch_version.py @@ -1,9 +1,10 @@ +from __future__ import annotations + import argparse import os import re import subprocess from pathlib import Path -from typing import Optional, Union from setuptools import distutils # type: ignore[import] @@ -12,7 +13,7 @@ UNKNOWN = "Unknown" RELEASE_PATTERN = re.compile(r"/v[0-9]+(\.[0-9]+)*(-rc[0-9]+)?/") -def get_sha(pytorch_root: Union[str, Path]) -> str: +def get_sha(pytorch_root: str | Path) -> str: try: rev = None if os.path.exists(os.path.join(pytorch_root, ".git")): @@ -30,7 +31,7 @@ def get_sha(pytorch_root: Union[str, Path]) -> str: return UNKNOWN -def get_tag(pytorch_root: Union[str, Path]) -> str: +def get_tag(pytorch_root: str | Path) -> str: try: tag = subprocess.run( ["git", "describe", "--tags", "--exact"], @@ -46,8 +47,8 @@ def get_tag(pytorch_root: Union[str, Path]) -> str: return UNKNOWN -def get_torch_version(sha: Optional[str] = None) -> str: - pytorch_root = Path(__file__).parent.parent +def get_torch_version(sha: str | None = None) -> str: + pytorch_root = Path(__file__).absolute().parent.parent version = open(pytorch_root / "version.txt").read().strip() if os.getenv("PYTORCH_BUILD_VERSION"): diff --git a/tools/github/github_utils.py b/tools/github/github_utils.py index 7424fa181abc..67a7a2e60cbf 100644 --- a/tools/github/github_utils.py +++ b/tools/github/github_utils.py @@ -1,10 +1,10 @@ """GitHub Utilities""" +from __future__ import annotations + import json import os - -from typing import Any, Callable, cast, Dict, Optional, Tuple - +from typing import Any, Callable, cast, Dict from urllib.error import HTTPError from urllib.parse import quote from urllib.request import Request, urlopen @@ -13,11 +13,11 @@ from urllib.request import Request, urlopen def gh_fetch_url_and_headers( url: str, *, - headers: Optional[Dict[str, str]] = None, - data: Optional[Dict[str, Any]] = None, - method: Optional[str] = None, + headers: dict[str, str] | None = None, + data: dict[str, Any] | None = None, + method: str | None = None, reader: Callable[[Any], Any] = lambda x: x.read(), -) -> Tuple[Any, Any]: +) -> tuple[Any, Any]: if headers is None: headers = {} token = os.environ.get("GITHUB_TOKEN") @@ -44,9 +44,9 @@ def gh_fetch_url_and_headers( def gh_fetch_url( url: str, *, - headers: Optional[Dict[str, str]] = None, - data: Optional[Dict[str, Any]] = None, - method: Optional[str] = None, + headers: dict[str, str] | None = None, + data: dict[str, Any] | None = None, + method: str | None = None, reader: Callable[[Any], Any] = lambda x: x.read(), ) -> Any: return gh_fetch_url_and_headers( @@ -56,8 +56,8 @@ def gh_fetch_url( def _gh_fetch_json_any( 
url: str, - params: Optional[Dict[str, Any]] = None, - data: Optional[Dict[str, Any]] = None, + params: dict[str, Any] | None = None, + data: dict[str, Any] | None = None, ) -> Any: headers = {"Accept": "application/vnd.github.v3+json"} if params is not None and len(params) > 0: @@ -69,13 +69,13 @@ def _gh_fetch_json_any( def gh_fetch_json_dict( url: str, - params: Optional[Dict[str, Any]] = None, - data: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: + params: dict[str, Any] | None = None, + data: dict[str, Any] | None = None, +) -> dict[str, Any]: return cast(Dict[str, Any], _gh_fetch_json_any(url, params, data)) -def gh_fetch_commit(org: str, repo: str, sha: str) -> Dict[str, Any]: +def gh_fetch_commit(org: str, repo: str, sha: str) -> dict[str, Any]: return gh_fetch_json_dict( f"https://api.github.com/repos/{org}/{repo}/commits/{sha}" ) diff --git a/tools/iwyu/fixup.py b/tools/iwyu/fixup.py index 2a585762273b..50d2cf1103c8 100644 --- a/tools/iwyu/fixup.py +++ b/tools/iwyu/fixup.py @@ -1,6 +1,7 @@ import re import sys + QUOTE_INCLUDE_RE = re.compile(r'^#include "(.*)"') ANGLE_INCLUDE_RE = re.compile(r"^#include <(.*)>") diff --git a/tools/jit/gen_unboxing.py b/tools/jit/gen_unboxing.py index ee4e2fc2ddb1..a33eb0c98617 100644 --- a/tools/jit/gen_unboxing.py +++ b/tools/jit/gen_unboxing.py @@ -1,10 +1,13 @@ # Generates RegisterCodegenUnboxedKernels.cpp, UnboxingFunctions.h and UnboxingFunctions.cpp. + +from __future__ import annotations + import argparse import os -import pathlib import sys from dataclasses import dataclass -from typing import List, Literal, Sequence, Union +from pathlib import Path +from typing import Literal, Sequence, TYPE_CHECKING import yaml @@ -15,10 +18,13 @@ from torchgen.api.unboxing import convert_arguments from torchgen.context import method_with_native_function from torchgen.gen import cpp_string, get_custom_build_selector, parse_native_yaml from torchgen.model import Argument, NativeFunction, NativeFunctionsGroup, Variant -from torchgen.selective_build.selector import SelectiveBuilder from torchgen.utils import FileManager, make_file_manager, mapMaybe, Target +if TYPE_CHECKING: + from torchgen.selective_build.selector import SelectiveBuilder + + # Generates UnboxingFunctions.h & UnboxingFunctions.cpp. 
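# A minimal, self-contained sketch of the idiom applied in the gen_unboxing.py hunk
# above: once `from __future__ import annotations` is in effect, annotations are never
# evaluated at runtime, so a name that is needed only for type hints can be imported
# under `if TYPE_CHECKING:` while built-in generics and `X | None` remain valid on
# older interpreters. The names below (`Decimal`, `total`) are illustrative stand-ins,
# not taken from this patch.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Needed only by static type checkers; never imported at runtime.
    from decimal import Decimal


def total(values: list[Decimal] | None) -> Decimal | None:
    # The annotations above stay as strings at runtime, so the missing
    # runtime import of Decimal is harmless.
    if not values:
        return None
    result = values[0]
    for value in values[1:]:
        result += value
    return result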
@dataclass(frozen=True) class ComputeUnboxingFunctions: @@ -156,7 +162,7 @@ def gen_unboxing( cpu_fm: FileManager, selector: SelectiveBuilder, ) -> None: - def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str: + def key_func(fn: NativeFunction | NativeFunctionsGroup) -> str: return fn.root_name selected_op_num: int = len(selector.operators) @@ -195,7 +201,7 @@ def gen_unboxing( ) -def main(args: List[str]) -> None: +def main(args: list[str]) -> None: parser = argparse.ArgumentParser(description="Generate unboxing source files") parser.add_argument( "-s", @@ -272,7 +278,7 @@ def main(args: List[str]) -> None: gen_unboxing(native_functions=native_functions, cpu_fm=cpu_fm, selector=selector) if options.output_dependencies: - depfile_path = pathlib.Path(options.output_dependencies).resolve() + depfile_path = Path(options.output_dependencies).resolve() depfile_name = depfile_path.name depfile_stem = depfile_path.stem diff --git a/tools/linter/adapters/actionlint_linter.py b/tools/linter/adapters/actionlint_linter.py index 685fae541573..bebb95c499f8 100644 --- a/tools/linter/adapters/actionlint_linter.py +++ b/tools/linter/adapters/actionlint_linter.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import concurrent.futures import json @@ -8,7 +10,7 @@ import subprocess import sys import time from enum import Enum -from typing import List, NamedTuple, Optional, Pattern +from typing import NamedTuple LINTER_CODE = "ACTIONLINT" @@ -22,18 +24,18 @@ class LintSeverity(str, Enum): class LintMessage(NamedTuple): - path: Optional[str] - line: Optional[int] - char: Optional[int] + path: str | None + line: int | None + char: int | None code: str severity: LintSeverity name: str - original: Optional[str] - replacement: Optional[str] - description: Optional[str] + original: str | None + replacement: str | None + description: str | None -RESULTS_RE: Pattern[str] = re.compile( +RESULTS_RE: re.Pattern[str] = re.compile( r"""(?mx) ^ (?P.*?): @@ -47,8 +49,8 @@ RESULTS_RE: Pattern[str] = re.compile( def run_command( - args: List[str], -) -> "subprocess.CompletedProcess[bytes]": + args: list[str], +) -> subprocess.CompletedProcess[bytes]: logging.debug("$ %s", " ".join(args)) start_time = time.monotonic() try: @@ -64,7 +66,7 @@ def run_command( def check_file( binary: str, file: str, -) -> List[LintMessage]: +) -> list[LintMessage]: try: proc = run_command( [ diff --git a/tools/linter/adapters/bazel_linter.py b/tools/linter/adapters/bazel_linter.py index f3c4a95c6e45..926628d3d76a 100644 --- a/tools/linter/adapters/bazel_linter.py +++ b/tools/linter/adapters/bazel_linter.py @@ -5,6 +5,9 @@ archive is downloaded from some sites like GitHub because it can change. Specifi GitHub gives no guarantee to keep the same value forever. Check for more details at https://github.com/community/community/discussions/46034. 
""" + +from __future__ import annotations + import argparse import json import re @@ -13,7 +16,7 @@ import subprocess import sys import xml.etree.ElementTree as ET from enum import Enum -from typing import List, NamedTuple, Optional, Set +from typing import NamedTuple from urllib.parse import urlparse @@ -30,18 +33,18 @@ class LintSeverity(str, Enum): class LintMessage(NamedTuple): - path: Optional[str] - line: Optional[int] - char: Optional[int] + path: str | None + line: int | None + char: int | None code: str severity: LintSeverity name: str - original: Optional[str] - replacement: Optional[str] - description: Optional[str] + original: str | None + replacement: str | None + description: str | None -def is_required_checksum(urls: List[Optional[str]]) -> bool: +def is_required_checksum(urls: list[str | None]) -> bool: if not urls: return False @@ -58,7 +61,7 @@ def is_required_checksum(urls: List[Optional[str]]) -> bool: def get_disallowed_checksums( binary: str, -) -> Set[str]: +) -> set[str]: """ Return the set of disallowed checksums from all http_archive rules """ @@ -96,8 +99,8 @@ def get_disallowed_checksums( def check_bazel( filename: str, - disallowed_checksums: Set[str], -) -> List[LintMessage]: + disallowed_checksums: set[str], +) -> list[LintMessage]: original = "" replacement = "" diff --git a/tools/linter/adapters/black_linter.py b/tools/linter/adapters/black_linter.py index 617bfb1d39cc..c5229273178c 100644 --- a/tools/linter/adapters/black_linter.py +++ b/tools/linter/adapters/black_linter.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import concurrent.futures import json @@ -7,7 +9,7 @@ import subprocess import sys import time from enum import Enum -from typing import Any, BinaryIO, List, NamedTuple, Optional +from typing import Any, BinaryIO, NamedTuple IS_WINDOWS: bool = os.name == "nt" @@ -25,15 +27,15 @@ class LintSeverity(str, Enum): class LintMessage(NamedTuple): - path: Optional[str] - line: Optional[int] - char: Optional[int] + path: str | None + line: int | None + char: int | None code: str severity: LintSeverity name: str - original: Optional[str] - replacement: Optional[str] - description: Optional[str] + original: str | None + replacement: str | None + description: str | None def as_posix(name: str) -> str: @@ -41,11 +43,11 @@ def as_posix(name: str) -> str: def _run_command( - args: List[str], + args: list[str], *, stdin: BinaryIO, timeout: int, -) -> "subprocess.CompletedProcess[bytes]": +) -> subprocess.CompletedProcess[bytes]: logging.debug("$ %s", " ".join(args)) start_time = time.monotonic() try: @@ -63,12 +65,12 @@ def _run_command( def run_command( - args: List[str], + args: list[str], *, stdin: BinaryIO, retries: int, timeout: int, -) -> "subprocess.CompletedProcess[bytes]": +) -> subprocess.CompletedProcess[bytes]: remaining_retries = retries while True: try: @@ -90,7 +92,7 @@ def check_file( filename: str, retries: int, timeout: int, -) -> List[LintMessage]: +) -> list[LintMessage]: try: with open(filename, "rb") as f: original = f.read() diff --git a/tools/linter/adapters/clangformat_linter.py b/tools/linter/adapters/clangformat_linter.py index f30275684406..c775e2f7e7e8 100644 --- a/tools/linter/adapters/clangformat_linter.py +++ b/tools/linter/adapters/clangformat_linter.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import concurrent.futures import json @@ -8,7 +10,7 @@ import sys import time from enum import Enum from pathlib import Path -from typing import Any, List, NamedTuple, Optional 
+from typing import Any, NamedTuple IS_WINDOWS: bool = os.name == "nt" @@ -26,15 +28,15 @@ class LintSeverity(str, Enum): class LintMessage(NamedTuple): - path: Optional[str] - line: Optional[int] - char: Optional[int] + path: str | None + line: int | None + char: int | None code: str severity: LintSeverity name: str - original: Optional[str] - replacement: Optional[str] - description: Optional[str] + original: str | None + replacement: str | None + description: str | None def as_posix(name: str) -> str: @@ -42,10 +44,10 @@ def as_posix(name: str) -> str: def _run_command( - args: List[str], + args: list[str], *, timeout: int, -) -> "subprocess.CompletedProcess[bytes]": +) -> subprocess.CompletedProcess[bytes]: logging.debug("$ %s", " ".join(args)) start_time = time.monotonic() try: @@ -62,11 +64,11 @@ def _run_command( def run_command( - args: List[str], + args: list[str], *, retries: int, timeout: int, -) -> "subprocess.CompletedProcess[bytes]": +) -> subprocess.CompletedProcess[bytes]: remaining_retries = retries while True: try: @@ -89,7 +91,7 @@ def check_file( binary: str, retries: int, timeout: int, -) -> List[LintMessage]: +) -> list[LintMessage]: try: with open(filename, "rb") as f: original = f.read() diff --git a/tools/linter/adapters/clangtidy_linter.py b/tools/linter/adapters/clangtidy_linter.py index 7fe7a87e98da..0859f6e59d47 100644 --- a/tools/linter/adapters/clangtidy_linter.py +++ b/tools/linter/adapters/clangtidy_linter.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import concurrent.futures import json @@ -11,7 +13,7 @@ import time from enum import Enum from pathlib import Path from sysconfig import get_paths as gp -from typing import Any, List, NamedTuple, Optional, Pattern +from typing import Any, NamedTuple # PyTorch directory root @@ -49,15 +51,15 @@ class LintSeverity(str, Enum): class LintMessage(NamedTuple): - path: Optional[str] - line: Optional[int] - char: Optional[int] + path: str | None + line: int | None + char: int | None code: str severity: LintSeverity name: str - original: Optional[str] - replacement: Optional[str] - description: Optional[str] + original: str | None + replacement: str | None + description: str | None def as_posix(name: str) -> str: @@ -65,7 +67,7 @@ def as_posix(name: str) -> str: # c10/core/DispatchKey.cpp:281:26: error: 'k' used after it was moved [bugprone-use-after-move] -RESULTS_RE: Pattern[str] = re.compile( +RESULTS_RE: re.Pattern[str] = re.compile( r"""(?mx) ^ (?P.*?): @@ -80,8 +82,8 @@ RESULTS_RE: Pattern[str] = re.compile( def run_command( - args: List[str], -) -> "subprocess.CompletedProcess[bytes]": + args: list[str], +) -> subprocess.CompletedProcess[bytes]: logging.debug("$ %s", " ".join(args)) start_time = time.monotonic() try: @@ -103,7 +105,7 @@ severities = { } -def clang_search_dirs() -> List[str]: +def clang_search_dirs() -> list[str]: # Compilers are ordered based on fallback preference # We pick the first one that is available on the system compilers = ["clang", "gcc", "cpp", "cc"] @@ -152,7 +154,7 @@ def check_file( filename: str, binary: str, build_dir: Path, -) -> List[LintMessage]: +) -> list[LintMessage]: try: proc = run_command( [binary, f"-p={build_dir}", *include_args, filename], diff --git a/tools/linter/adapters/cmake_linter.py b/tools/linter/adapters/cmake_linter.py index c5de15352c27..407f62403115 100644 --- a/tools/linter/adapters/cmake_linter.py +++ b/tools/linter/adapters/cmake_linter.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import 
concurrent.futures import json @@ -7,7 +9,7 @@ import re import subprocess import time from enum import Enum -from typing import List, NamedTuple, Optional, Pattern +from typing import NamedTuple LINTER_CODE = "CMAKE" @@ -21,19 +23,19 @@ class LintSeverity(str, Enum): class LintMessage(NamedTuple): - path: Optional[str] - line: Optional[int] - char: Optional[int] + path: str | None + line: int | None + char: int | None code: str severity: LintSeverity name: str - original: Optional[str] - replacement: Optional[str] - description: Optional[str] + original: str | None + replacement: str | None + description: str | None # CMakeLists.txt:901: Lines should be <= 80 characters long [linelength] -RESULTS_RE: Pattern[str] = re.compile( +RESULTS_RE: re.Pattern[str] = re.compile( r"""(?mx) ^ (?P.*?): @@ -46,8 +48,8 @@ RESULTS_RE: Pattern[str] = re.compile( def run_command( - args: List[str], -) -> "subprocess.CompletedProcess[bytes]": + args: list[str], +) -> subprocess.CompletedProcess[bytes]: logging.debug("$ %s", " ".join(args)) start_time = time.monotonic() try: @@ -63,7 +65,7 @@ def run_command( def check_file( filename: str, config: str, -) -> List[LintMessage]: +) -> list[LintMessage]: try: proc = run_command( ["cmakelint", f"--config={config}", filename], diff --git a/tools/linter/adapters/constexpr_linter.py b/tools/linter/adapters/constexpr_linter.py index 24ecc83b238e..adb7fe001749 100644 --- a/tools/linter/adapters/constexpr_linter.py +++ b/tools/linter/adapters/constexpr_linter.py @@ -2,13 +2,15 @@ CONSTEXPR: Ensures users don't use vanilla constexpr since it causes issues """ +from __future__ import annotations + import argparse import json import logging import sys - from enum import Enum -from typing import NamedTuple, Optional +from typing import NamedTuple + CONSTEXPR = "constexpr char" CONSTEXPR_MACRO = "CONSTEXPR_EXCEPT_WIN_CUDA char" @@ -21,18 +23,18 @@ class LintSeverity(str, Enum): class LintMessage(NamedTuple): - path: Optional[str] - line: Optional[int] - char: Optional[int] + path: str | None + line: int | None + char: int | None code: str severity: LintSeverity name: str - original: Optional[str] - replacement: Optional[str] - description: Optional[str] + original: str | None + replacement: str | None + description: str | None -def check_file(filename: str) -> Optional[LintMessage]: +def check_file(filename: str) -> LintMessage | None: logging.debug("Checking file %s", filename) with open(filename) as f: diff --git a/tools/linter/adapters/exec_linter.py b/tools/linter/adapters/exec_linter.py index f00dc60afbb2..5a8017bc2584 100644 --- a/tools/linter/adapters/exec_linter.py +++ b/tools/linter/adapters/exec_linter.py @@ -1,14 +1,17 @@ """ EXEC: Ensure that source files are not executable. 
""" + +from __future__ import annotations + import argparse import json import logging import os import sys - from enum import Enum -from typing import NamedTuple, Optional +from typing import NamedTuple + LINTER_CODE = "EXEC" @@ -21,18 +24,18 @@ class LintSeverity(str, Enum): class LintMessage(NamedTuple): - path: Optional[str] - line: Optional[int] - char: Optional[int] + path: str | None + line: int | None + char: int | None code: str severity: LintSeverity name: str - original: Optional[str] - replacement: Optional[str] - description: Optional[str] + original: str | None + replacement: str | None + description: str | None -def check_file(filename: str) -> Optional[LintMessage]: +def check_file(filename: str) -> LintMessage | None: is_executable = os.access(filename, os.X_OK) if is_executable: return LintMessage( diff --git a/tools/linter/adapters/flake8_linter.py b/tools/linter/adapters/flake8_linter.py index 20c9c7cea316..0a401e69fc8f 100644 --- a/tools/linter/adapters/flake8_linter.py +++ b/tools/linter/adapters/flake8_linter.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import json import logging @@ -7,7 +9,7 @@ import subprocess import sys import time from enum import Enum -from typing import Any, Dict, List, NamedTuple, Optional, Pattern, Set +from typing import Any, NamedTuple IS_WINDOWS: bool = os.name == "nt" @@ -25,15 +27,15 @@ class LintSeverity(str, Enum): class LintMessage(NamedTuple): - path: Optional[str] - line: Optional[int] - char: Optional[int] + path: str | None + line: int | None + char: int | None code: str severity: LintSeverity name: str - original: Optional[str] - replacement: Optional[str] - description: Optional[str] + original: str | None + replacement: str | None + description: str | None def as_posix(name: str) -> str: @@ -42,7 +44,7 @@ def as_posix(name: str) -> str: # fmt: off # https://www.flake8rules.com/ -DOCUMENTED_IN_FLAKE8RULES: Set[str] = { +DOCUMENTED_IN_FLAKE8RULES: set[str] = { "E101", "E111", "E112", "E113", "E114", "E115", "E116", "E117", "E121", "E122", "E123", "E124", "E125", "E126", "E127", "E128", "E129", "E131", "E133", @@ -78,14 +80,14 @@ DOCUMENTED_IN_FLAKE8RULES: Set[str] = { } # https://pypi.org/project/flake8-comprehensions/#rules -DOCUMENTED_IN_FLAKE8COMPREHENSIONS: Set[str] = { +DOCUMENTED_IN_FLAKE8COMPREHENSIONS: set[str] = { "C400", "C401", "C402", "C403", "C404", "C405", "C406", "C407", "C408", "C409", "C410", "C411", "C412", "C413", "C414", "C415", "C416", } # https://github.com/PyCQA/flake8-bugbear#list-of-warnings -DOCUMENTED_IN_BUGBEAR: Set[str] = { +DOCUMENTED_IN_BUGBEAR: set[str] = { "B001", "B002", "B003", "B004", "B005", "B006", "B007", "B008", "B009", "B010", "B011", "B012", "B013", "B014", "B015", "B301", "B302", "B303", "B304", "B305", "B306", @@ -98,7 +100,7 @@ DOCUMENTED_IN_BUGBEAR: Set[str] = { # stdin:3:6: T484 Name 'foo' is not defined # stdin:3:-100: W605 invalid escape sequence '\/' # stdin:3:1: E302 expected 2 blank lines, found 1 -RESULTS_RE: Pattern[str] = re.compile( +RESULTS_RE: re.Pattern[str] = re.compile( r"""(?mx) ^ (?P.*?): @@ -134,10 +136,10 @@ def _test_results_re() -> None: def _run_command( - args: List[str], + args: list[str], *, - extra_env: Optional[Dict[str, str]], -) -> "subprocess.CompletedProcess[str]": + extra_env: dict[str, str] | None, +) -> subprocess.CompletedProcess[str]: logging.debug( "$ %s", " ".join( @@ -158,11 +160,11 @@ def _run_command( def run_command( - args: List[str], + args: list[str], *, - extra_env: Optional[Dict[str, str]], + extra_env: 
dict[str, str] | None, retries: int, -) -> "subprocess.CompletedProcess[str]": +) -> subprocess.CompletedProcess[str]: remaining_retries = retries while True: try: @@ -243,11 +245,11 @@ def get_issue_documentation_url(code: str) -> str: def check_files( - filenames: List[str], - flake8_plugins_path: Optional[str], - severities: Dict[str, LintSeverity], + filenames: list[str], + flake8_plugins_path: str | None, + severities: dict[str, LintSeverity], retries: int, -) -> List[LintMessage]: +) -> list[LintMessage]: try: proc = run_command( [sys.executable, "-mflake8", "--exit-zero"] + filenames, @@ -351,7 +353,7 @@ def main() -> None: else os.path.realpath(args.flake8_plugins_path) ) - severities: Dict[str, LintSeverity] = {} + severities: dict[str, LintSeverity] = {} if args.severity: for severity in args.severity: parts = severity.split(":", 1) diff --git a/tools/linter/adapters/grep_linter.py b/tools/linter/adapters/grep_linter.py index 168800eb447e..e20e7f5f8a2d 100644 --- a/tools/linter/adapters/grep_linter.py +++ b/tools/linter/adapters/grep_linter.py @@ -2,6 +2,8 @@ Generic linter that greps for a pattern and optionally suggests replacements. """ +from __future__ import annotations + import argparse import json import logging @@ -10,7 +12,7 @@ import subprocess import sys import time from enum import Enum -from typing import Any, List, NamedTuple, Optional +from typing import Any, NamedTuple IS_WINDOWS: bool = os.name == "nt" @@ -28,15 +30,15 @@ class LintSeverity(str, Enum): class LintMessage(NamedTuple): - path: Optional[str] - line: Optional[int] - char: Optional[int] + path: str | None + line: int | None + char: int | None code: str severity: LintSeverity name: str - original: Optional[str] - replacement: Optional[str] - description: Optional[str] + original: str | None + replacement: str | None + description: str | None def as_posix(name: str) -> str: @@ -44,8 +46,8 @@ def as_posix(name: str) -> str: def run_command( - args: List[str], -) -> "subprocess.CompletedProcess[bytes]": + args: list[str], +) -> subprocess.CompletedProcess[bytes]: logging.debug("$ %s", " ".join(args)) start_time = time.monotonic() try: @@ -65,7 +67,7 @@ def lint_file( linter_name: str, error_name: str, error_description: str, -) -> Optional[LintMessage]: +) -> LintMessage | None: # matching_line looks like: # tools/linter/clangtidy_linter.py:13:import foo.bar.baz split = matching_line.split(":") diff --git a/tools/linter/adapters/lintrunner_version_linter.py b/tools/linter/adapters/lintrunner_version_linter.py index 48eab1a39a8c..366adc3a0496 100644 --- a/tools/linter/adapters/lintrunner_version_linter.py +++ b/tools/linter/adapters/lintrunner_version_linter.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import json import subprocess import sys from enum import Enum -from typing import NamedTuple, Optional, Tuple +from typing import NamedTuple LINTER_CODE = "LINTRUNNER_VERSION" @@ -16,18 +18,18 @@ class LintSeverity(str, Enum): class LintMessage(NamedTuple): - path: Optional[str] - line: Optional[int] - char: Optional[int] + path: str | None + line: int | None + char: int | None code: str severity: LintSeverity name: str - original: Optional[str] - replacement: Optional[str] - description: Optional[str] + original: str | None + replacement: str | None + description: str | None -def toVersionString(version_tuple: Tuple[int, int, int]) -> str: +def toVersionString(version_tuple: tuple[int, int, int]) -> str: return ".".join(str(x) for x in version_tuple) diff --git 
a/tools/linter/adapters/mypy_linter.py b/tools/linter/adapters/mypy_linter.py index 9c59563f2050..cc5367b19211 100644 --- a/tools/linter/adapters/mypy_linter.py +++ b/tools/linter/adapters/mypy_linter.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import json import logging @@ -8,7 +10,7 @@ import sys import time from enum import Enum from pathlib import Path -from typing import Any, Dict, List, NamedTuple, Optional, Pattern +from typing import Any, NamedTuple IS_WINDOWS: bool = os.name == "nt" @@ -26,15 +28,15 @@ class LintSeverity(str, Enum): class LintMessage(NamedTuple): - path: Optional[str] - line: Optional[int] - char: Optional[int] + path: str | None + line: int | None + char: int | None code: str severity: LintSeverity name: str - original: Optional[str] - replacement: Optional[str] - description: Optional[str] + original: str | None + replacement: str | None + description: str | None def as_posix(name: str) -> str: @@ -42,7 +44,7 @@ def as_posix(name: str) -> str: # tools/linter/flake8_linter.py:15:13: error: Incompatibl...int") [assignment] -RESULTS_RE: Pattern[str] = re.compile( +RESULTS_RE: re.Pattern[str] = re.compile( r"""(?mx) ^ (?P.*?): @@ -56,7 +58,7 @@ RESULTS_RE: Pattern[str] = re.compile( ) # torch/_dynamo/variables/tensor.py:363: error: INTERNAL ERROR -INTERNAL_ERROR_RE: Pattern[str] = re.compile( +INTERNAL_ERROR_RE: re.Pattern[str] = re.compile( r"""(?mx) ^ (?P.*?): @@ -69,11 +71,11 @@ INTERNAL_ERROR_RE: Pattern[str] = re.compile( def run_command( - args: List[str], + args: list[str], *, - extra_env: Optional[Dict[str, str]], + extra_env: dict[str, str] | None, retries: int, -) -> "subprocess.CompletedProcess[bytes]": +) -> subprocess.CompletedProcess[bytes]: logging.debug("$ %s", " ".join(args)) start_time = time.monotonic() try: @@ -94,7 +96,7 @@ severities = { } -def check_mypy_installed(code: str) -> List[LintMessage]: +def check_mypy_installed(code: str) -> list[LintMessage]: cmd = [sys.executable, "-mmypy", "-V"] try: subprocess.run(cmd, check=True, capture_output=True) @@ -117,11 +119,11 @@ def check_mypy_installed(code: str) -> List[LintMessage]: def check_files( - filenames: List[str], + filenames: list[str], config: str, retries: int, code: str, -) -> List[LintMessage]: +) -> list[LintMessage]: # dmypy has a bug where it won't pick up changes if you pass it absolute # file names, see https://github.com/python/mypy/issues/16768 filenames = [os.path.relpath(f) for f in filenames] @@ -224,7 +226,7 @@ def main() -> None: # Use a dictionary here to preserve order. mypy cares about order, # tragically, e.g. https://github.com/python/mypy/issues/2015 - filenames: Dict[str, bool] = {} + filenames: dict[str, bool] = {} # If a stub file exists, have mypy check it instead of the original file, in # accordance with PEP-484 (see https://www.python.org/dev/peps/pep-0484/#stub-files) diff --git a/tools/linter/adapters/nativefunctions_linter.py b/tools/linter/adapters/nativefunctions_linter.py index 12a6c7e0062d..68e81f730543 100644 --- a/tools/linter/adapters/nativefunctions_linter.py +++ b/tools/linter/adapters/nativefunctions_linter.py @@ -14,12 +14,14 @@ is simply to make sure that there is *some* configuration of ruamel that can rou the YAML, not to be prescriptive about it. 
""" +from __future__ import annotations + import argparse import json import sys from enum import Enum from io import StringIO -from typing import NamedTuple, Optional +from typing import NamedTuple import ruamel.yaml # type: ignore[import] @@ -32,15 +34,15 @@ class LintSeverity(str, Enum): class LintMessage(NamedTuple): - path: Optional[str] - line: Optional[int] - char: Optional[int] + path: str | None + line: int | None + char: int | None code: str severity: LintSeverity name: str - original: Optional[str] - replacement: Optional[str] - description: Optional[str] + original: str | None + replacement: str | None + description: str | None if __name__ == "__main__": diff --git a/tools/linter/adapters/newlines_linter.py b/tools/linter/adapters/newlines_linter.py index a2cb1c5ccdc9..9af1d895699b 100644 --- a/tools/linter/adapters/newlines_linter.py +++ b/tools/linter/adapters/newlines_linter.py @@ -1,13 +1,16 @@ """ NEWLINE: Checks files to make sure there are no trailing newlines. """ + +from __future__ import annotations + import argparse import json import logging import sys - from enum import Enum -from typing import List, NamedTuple, Optional +from typing import NamedTuple + NEWLINE = 10 # ASCII "\n" CARRIAGE_RETURN = 13 # ASCII "\r" @@ -22,18 +25,18 @@ class LintSeverity(str, Enum): class LintMessage(NamedTuple): - path: Optional[str] - line: Optional[int] - char: Optional[int] + path: str | None + line: int | None + char: int | None code: str severity: LintSeverity name: str - original: Optional[str] - replacement: Optional[str] - description: Optional[str] + original: str | None + replacement: str | None + description: str | None -def check_file(filename: str) -> Optional[LintMessage]: +def check_file(filename: str) -> LintMessage | None: logging.debug("Checking file %s", filename) with open(filename, "rb") as f: @@ -85,7 +88,7 @@ def check_file(filename: str) -> Optional[LintMessage]: description="Trailing newline found. 
Run `lintrunner --take NEWLINE -a` to apply changes.", ) has_changes = False - original_lines: Optional[List[bytes]] = None + original_lines: list[bytes] | None = None for idx, line in enumerate(lines): if len(line) >= 2 and line[-1] == NEWLINE and line[-2] == CARRIAGE_RETURN: if not has_changes: diff --git a/tools/linter/adapters/no_merge_conflict_csv_linter.py b/tools/linter/adapters/no_merge_conflict_csv_linter.py index 689762a481a9..4b14e03c0496 100644 --- a/tools/linter/adapters/no_merge_conflict_csv_linter.py +++ b/tools/linter/adapters/no_merge_conflict_csv_linter.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import concurrent.futures import json @@ -5,7 +7,7 @@ import logging import os import sys from enum import Enum -from typing import Any, List, NamedTuple, Optional +from typing import Any, NamedTuple IS_WINDOWS: bool = os.name == "nt" @@ -23,18 +25,18 @@ class LintSeverity(str, Enum): class LintMessage(NamedTuple): - path: Optional[str] - line: Optional[int] - char: Optional[int] + path: str | None + line: int | None + char: int | None code: str severity: LintSeverity name: str - original: Optional[str] - replacement: Optional[str] - description: Optional[str] + original: str | None + replacement: str | None + description: str | None -def check_file(filename: str) -> List[LintMessage]: +def check_file(filename: str) -> list[LintMessage]: with open(filename, "rb") as f: original = f.read().decode("utf-8") replacement = "" diff --git a/tools/linter/adapters/pip_init.py b/tools/linter/adapters/pip_init.py index 0c1551f5e301..85bcd200a4bf 100644 --- a/tools/linter/adapters/pip_init.py +++ b/tools/linter/adapters/pip_init.py @@ -1,6 +1,9 @@ """ Initializer script that installs stuff to pip. """ + +from __future__ import annotations + import argparse import logging import os @@ -9,10 +12,8 @@ import subprocess import sys import time -from typing import List - -def run_command(args: List[str]) -> "subprocess.CompletedProcess[bytes]": +def run_command(args: list[str]) -> subprocess.CompletedProcess[bytes]: logging.debug("$ %s", " ".join(args)) start_time = time.monotonic() try: diff --git a/tools/linter/adapters/ruff_linter.py b/tools/linter/adapters/ruff_linter.py index 908947e09fc8..6dcb9fc3fa17 100644 --- a/tools/linter/adapters/ruff_linter.py +++ b/tools/linter/adapters/ruff_linter.py @@ -14,6 +14,7 @@ import sys import time from typing import Any, BinaryIO + LINTER_CODE = "RUFF" IS_WINDOWS: bool = os.name == "nt" diff --git a/tools/linter/adapters/shellcheck_linter.py b/tools/linter/adapters/shellcheck_linter.py index 5d8d7a8052e8..ada235ba553e 100644 --- a/tools/linter/adapters/shellcheck_linter.py +++ b/tools/linter/adapters/shellcheck_linter.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import json import logging @@ -6,7 +8,7 @@ import subprocess import sys import time from enum import Enum -from typing import List, NamedTuple, Optional +from typing import NamedTuple LINTER_CODE = "SHELLCHECK" @@ -20,20 +22,20 @@ class LintSeverity(str, Enum): class LintMessage(NamedTuple): - path: Optional[str] - line: Optional[int] - char: Optional[int] + path: str | None + line: int | None + char: int | None code: str severity: LintSeverity name: str - original: Optional[str] - replacement: Optional[str] - description: Optional[str] + original: str | None + replacement: str | None + description: str | None def run_command( - args: List[str], -) -> "subprocess.CompletedProcess[bytes]": + args: list[str], +) -> 
subprocess.CompletedProcess[bytes]: logging.debug("$ %s", " ".join(args)) start_time = time.monotonic() try: @@ -47,8 +49,8 @@ def run_command( def check_files( - files: List[str], -) -> List[LintMessage]: + files: list[str], +) -> list[LintMessage]: try: proc = run_command( ["shellcheck", "--external-sources", "--format=json1"] + files diff --git a/tools/linter/adapters/test_has_main_linter.py b/tools/linter/adapters/test_has_main_linter.py index 1cd5573726b5..e648a96e0df5 100644 --- a/tools/linter/adapters/test_has_main_linter.py +++ b/tools/linter/adapters/test_has_main_linter.py @@ -6,15 +6,19 @@ calls run_tests to ensure that the test will be run in OSS CI. Takes ~2 minuters to run without the multiprocessing, probably overkill. """ + +from __future__ import annotations + import argparse import json import multiprocessing as mp from enum import Enum -from typing import List, NamedTuple, Optional +from typing import NamedTuple import libcst as cst import libcst.matchers as m + LINTER_CODE = "TEST_HAS_MAIN" @@ -62,18 +66,18 @@ class LintSeverity(str, Enum): class LintMessage(NamedTuple): - path: Optional[str] - line: Optional[int] - char: Optional[int] + path: str | None + line: int | None + char: int | None code: str severity: LintSeverity name: str - original: Optional[str] - replacement: Optional[str] - description: Optional[str] + original: str | None + replacement: str | None + description: str | None -def check_file(filename: str) -> List[LintMessage]: +def check_file(filename: str) -> list[LintMessage]: lint_messages = [] with open(filename) as f: diff --git a/tools/linter/adapters/testowners_linter.py b/tools/linter/adapters/testowners_linter.py index b9e626743e58..b4c35b8ad91b 100755 --- a/tools/linter/adapters/testowners_linter.py +++ b/tools/linter/adapters/testowners_linter.py @@ -8,10 +8,13 @@ has valid ownership information in a comment header. 
Valid means: - Each owner label actually exists in PyTorch - Each owner label starts with "module: " or "oncall: " or is in ACCEPTABLE_OWNER_LABELS """ + +from __future__ import annotations + import argparse import json from enum import Enum -from typing import Any, List, NamedTuple, Optional +from typing import Any, NamedTuple from urllib.request import urlopen @@ -26,15 +29,15 @@ class LintSeverity(str, Enum): class LintMessage(NamedTuple): - path: Optional[str] - line: Optional[int] - char: Optional[int] + path: str | None + line: int | None + char: int | None code: str severity: LintSeverity name: str - original: Optional[str] - replacement: Optional[str] - description: Optional[str] + original: str | None + replacement: str | None + description: str | None # Team/owner labels usually start with "module: " or "oncall: ", but the following are acceptable exceptions @@ -58,8 +61,8 @@ GLOB_EXCEPTIONS = ["**/test/run_test.py"] def check_labels( - labels: List[str], filename: str, line_number: int -) -> List[LintMessage]: + labels: list[str], filename: str, line_number: int +) -> list[LintMessage]: lint_messages = [] for label in labels: if label not in PYTORCH_LABELS: @@ -104,7 +107,7 @@ def check_labels( return lint_messages -def check_file(filename: str) -> List[LintMessage]: +def check_file(filename: str) -> list[LintMessage]: lint_messages = [] has_ownership_info = False diff --git a/tools/linter/adapters/ufmt_linter.py b/tools/linter/adapters/ufmt_linter.py index c31dc109f5e9..7adbba7bea5d 100644 --- a/tools/linter/adapters/ufmt_linter.py +++ b/tools/linter/adapters/ufmt_linter.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import concurrent.futures import json @@ -6,7 +8,7 @@ import os import sys from enum import Enum from pathlib import Path -from typing import Any, List, NamedTuple, Optional +from typing import Any, NamedTuple from ufmt.core import ufmt_string from ufmt.util import make_black_config @@ -28,15 +30,15 @@ class LintSeverity(str, Enum): class LintMessage(NamedTuple): - path: Optional[str] - line: Optional[int] - char: Optional[int] + path: str | None + line: int | None + char: int | None code: str severity: LintSeverity name: str - original: Optional[str] - replacement: Optional[str] - description: Optional[str] + original: str | None + replacement: str | None + description: str | None def as_posix(name: str) -> str: @@ -59,7 +61,7 @@ def format_error_message(filename: str, err: Exception) -> LintMessage: def check_file( filename: str, -) -> List[LintMessage]: +) -> list[LintMessage]: with open(filename, "rb") as f: original = f.read().decode("utf-8") diff --git a/tools/linter/adapters/workflow_consistency_linter.py b/tools/linter/adapters/workflow_consistency_linter.py index d18a19a9f825..562b85a93360 100644 --- a/tools/linter/adapters/workflow_consistency_linter.py +++ b/tools/linter/adapters/workflow_consistency_linter.py @@ -2,16 +2,20 @@ Any job with a specific `sync-tag` must match all other jobs with the same `sync-tag`. 
""" + +from __future__ import annotations + import argparse import itertools import json from collections import defaultdict from enum import Enum from pathlib import Path -from typing import Any, Dict, Iterable, NamedTuple, Optional +from typing import Any, Iterable, NamedTuple from yaml import dump, load + # Safely load fast C Yaml loader/dumper if they are available try: from yaml import CSafeLoader as Loader @@ -27,15 +31,15 @@ class LintSeverity(str, Enum): class LintMessage(NamedTuple): - path: Optional[str] - line: Optional[int] - char: Optional[int] + path: str | None + line: int | None + char: int | None code: str severity: LintSeverity name: str - original: Optional[str] - replacement: Optional[str] - description: Optional[str] + original: str | None + replacement: str | None + description: str | None def glob_yamls(path: Path) -> Iterable[Path]: @@ -51,7 +55,7 @@ def is_workflow(yaml: Any) -> bool: return yaml.get("jobs") is not None -def print_lint_message(path: Path, job: Dict[str, Any], sync_tag: str) -> None: +def print_lint_message(path: Path, job: dict[str, Any], sync_tag: str) -> None: job_id = next(iter(job.keys())) with open(path) as f: lines = f.readlines() diff --git a/tools/linter/clang_tidy/generate_build_files.py b/tools/linter/clang_tidy/generate_build_files.py index 3d0473abda48..af322e754b87 100644 --- a/tools/linter/clang_tidy/generate_build_files.py +++ b/tools/linter/clang_tidy/generate_build_files.py @@ -1,10 +1,11 @@ +from __future__ import annotations + import os import subprocess import sys -from typing import List -def run_cmd(cmd: List[str]) -> None: +def run_cmd(cmd: list[str]) -> None: print(f"Running: {cmd}") result = subprocess.run( cmd, diff --git a/tools/lite_interpreter/gen_selected_mobile_ops_header.py b/tools/lite_interpreter/gen_selected_mobile_ops_header.py index 1aaccad3b2b4..09f0f4e80bba 100644 --- a/tools/lite_interpreter/gen_selected_mobile_ops_header.py +++ b/tools/lite_interpreter/gen_selected_mobile_ops_header.py @@ -1,13 +1,16 @@ #!/usr/bin/env python3 + +from __future__ import annotations + import argparse import os -from typing import Set import yaml from torchgen.code_template import CodeTemplate from torchgen.selective_build.selector import SelectiveBuilder + # Safely load fast C Yaml loader/dumper if they are available try: from yaml import CSafeLoader as Loader @@ -46,7 +49,7 @@ selected_mobile_ops_preamble = """#pragma once """ -def extract_root_operators(selective_builder: SelectiveBuilder) -> Set[str]: +def extract_root_operators(selective_builder: SelectiveBuilder) -> set[str]: ops = [] for op_name, op in selective_builder.operators.items(): if op.is_root_operator: @@ -125,7 +128,7 @@ def write_selected_mobile_ops( # 2. All kernel dtypes def write_selected_mobile_ops_with_all_dtypes( output_file_path: str, - root_ops: Set[str], + root_ops: set[str], ) -> None: with open(output_file_path, "wb") as out_file: body_parts = [selected_mobile_ops_preamble] diff --git a/tools/lldb/deploy_debugger.py b/tools/lldb/deploy_debugger.py index 5a1395898b5c..135a6167e3a4 100644 --- a/tools/lldb/deploy_debugger.py +++ b/tools/lldb/deploy_debugger.py @@ -1,5 +1,6 @@ import lldb # type: ignore[import] + # load into lldb instance with: # command script import tools/lldb/deploy_debugger.py diff --git a/tools/nightly.py b/tools/nightly.py index 983e69150b56..265a286c5d94 100755 --- a/tools/nightly.py +++ b/tools/nightly.py @@ -24,6 +24,9 @@ well. 
This can be done with Pulling will reinstalle the conda dependencies as well as the nightly binaries into the repo directory. """ + +from __future__ import annotations + import contextlib import datetime import functools @@ -40,23 +43,10 @@ import time import uuid from argparse import ArgumentParser from ast import literal_eval -from typing import ( - Any, - Callable, - cast, - Dict, - Generator, - Iterable, - Iterator, - List, - Optional, - Sequence, - Set, - Tuple, - TypeVar, -) +from typing import Any, Callable, cast, Generator, Iterable, Iterator, Sequence, TypeVar -LOGGER: Optional[logging.Logger] = None + +LOGGER: logging.Logger | None = None URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2" DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss" SHA1_RE = re.compile("([0-9a-fA-F]{40})") @@ -68,9 +58,9 @@ SPECS_TO_INSTALL = ("pytorch", "mypy", "pytest", "hypothesis", "ipython", "sphin class Formatter(logging.Formatter): - redactions: Dict[str, str] + redactions: dict[str, str] - def __init__(self, fmt: Optional[str] = None, datefmt: Optional[str] = None): + def __init__(self, fmt: str | None = None, datefmt: str | None = None) -> None: super().__init__(fmt, datefmt) self.redactions = {} @@ -192,7 +182,7 @@ def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, N sys.exit(1) -def check_in_repo() -> Optional[str]: +def check_in_repo() -> str | None: """Ensures that we are in the PyTorch repo.""" if not os.path.isfile("setup.py"): return "Not in root-level PyTorch repo, no setup.py found" @@ -203,7 +193,7 @@ def check_in_repo() -> Optional[str]: return None -def check_branch(subcommand: str, branch: Optional[str]) -> Optional[str]: +def check_branch(subcommand: str, branch: str | None) -> str | None: """Checks that the branch name can be checked out.""" if subcommand != "checkout": return None @@ -259,7 +249,7 @@ def timed(prefix: str) -> Callable[[F], F]: def _make_channel_args( channels: Iterable[str] = ("pytorch-nightly",), override_channels: bool = False, -) -> List[str]: +) -> list[str]: args = [] for channel in channels: args.append("--channel") @@ -271,11 +261,11 @@ def _make_channel_args( @timed("Solving conda environment") def conda_solve( - name: Optional[str] = None, - prefix: Optional[str] = None, + name: str | None = None, + prefix: str | None = None, channels: Iterable[str] = ("pytorch-nightly",), override_channels: bool = False, -) -> Tuple[List[str], str, str, bool, List[str]]: +) -> tuple[list[str], str, str, bool, list[str]]: """Performs the conda solve and splits the deps from the package.""" # compute what environment to use if prefix is not None: @@ -329,7 +319,7 @@ def conda_solve( @timed("Installing dependencies") -def deps_install(deps: List[str], existing_env: bool, env_opts: List[str]) -> None: +def deps_install(deps: list[str], existing_env: bool, env_opts: list[str]) -> None: """Install dependencies to deps environment""" if not existing_env: # first remove previous pytorch-deps env @@ -342,7 +332,7 @@ def deps_install(deps: List[str], existing_env: bool, env_opts: List[str]) -> No @timed("Installing pytorch nightly binaries") -def pytorch_install(url: str) -> "tempfile.TemporaryDirectory[str]": +def pytorch_install(url: str) -> tempfile.TemporaryDirectory[str]: """Install pytorch into a temporary directory""" pytdir = tempfile.TemporaryDirectory() cmd = ["conda", "create", "--yes", "--no-deps", "--prefix", pytdir.name, url] @@ -421,33 +411,33 @@ def pull_nightly_version(spdir: str) -> None: p = subprocess.run(cmd, check=True) -def 
_get_listing_linux(source_dir: str) -> List[str]: +def _get_listing_linux(source_dir: str) -> list[str]: listing = glob.glob(os.path.join(source_dir, "*.so")) listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.so"))) return listing -def _get_listing_osx(source_dir: str) -> List[str]: +def _get_listing_osx(source_dir: str) -> list[str]: # oddly, these are .so files even on Mac listing = glob.glob(os.path.join(source_dir, "*.so")) listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dylib"))) return listing -def _get_listing_win(source_dir: str) -> List[str]: +def _get_listing_win(source_dir: str) -> list[str]: listing = glob.glob(os.path.join(source_dir, "*.pyd")) listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.lib"))) listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dll"))) return listing -def _glob_pyis(d: str) -> Set[str]: +def _glob_pyis(d: str) -> set[str]: search = os.path.join(d, "**", "*.pyi") pyis = {os.path.relpath(p, d) for p in glob.iglob(search)} return pyis -def _find_missing_pyi(source_dir: str, target_dir: str) -> List[str]: +def _find_missing_pyi(source_dir: str, target_dir: str) -> list[str]: source_pyis = _glob_pyis(source_dir) target_pyis = _glob_pyis(target_dir) missing_pyis = [os.path.join(source_dir, p) for p in (source_pyis - target_pyis)] @@ -455,7 +445,7 @@ def _find_missing_pyi(source_dir: str, target_dir: str) -> List[str]: return missing_pyis -def _get_listing(source_dir: str, target_dir: str, platform: str) -> List[str]: +def _get_listing(source_dir: str, target_dir: str, platform: str) -> list[str]: if platform.startswith("linux"): listing = _get_listing_linux(source_dir) elif platform.startswith("osx"): @@ -510,12 +500,12 @@ def _move_single( mover(src, trg) -def _copy_files(listing: List[str], source_dir: str, target_dir: str) -> None: +def _copy_files(listing: list[str], source_dir: str, target_dir: str) -> None: for src in listing: _move_single(src, source_dir, target_dir, shutil.copy2, "Copying") -def _link_files(listing: List[str], source_dir: str, target_dir: str) -> None: +def _link_files(listing: list[str], source_dir: str, target_dir: str) -> None: for src in listing: _move_single(src, source_dir, target_dir, os.link, "Linking") @@ -537,7 +527,7 @@ def move_nightly_files(spdir: str, platform: str) -> None: _copy_files(listing, source_dir, target_dir) -def _available_envs() -> Dict[str, str]: +def _available_envs() -> dict[str, str]: cmd = ["conda", "env", "list"] p = subprocess.run( cmd, @@ -559,7 +549,7 @@ def _available_envs() -> Dict[str, str]: @timed("Writing pytorch-nightly.pth") -def write_pth(env_opts: List[str], platform: str) -> None: +def write_pth(env_opts: list[str], platform: str) -> None: """Writes Python path file for this dir.""" env_type, env_dir = env_opts if env_type == "--name": @@ -582,9 +572,9 @@ def install( *, logger: logging.Logger, subcommand: str = "checkout", - branch: Optional[str] = None, - name: Optional[str] = None, - prefix: Optional[str] = None, + branch: str | None = None, + name: str | None = None, + prefix: str | None = None, channels: Iterable[str] = ("pytorch-nightly",), override_channels: bool = False, ) -> None: @@ -673,7 +663,7 @@ def make_parser() -> ArgumentParser: return p -def main(args: Optional[Sequence[str]] = None) -> None: +def main(args: Sequence[str] | None = None) -> None: """Main entry point""" global LOGGER p = make_parser() diff --git a/tools/nvcc_fix_deps.py b/tools/nvcc_fix_deps.py index 9101e5527626..0c0c9db66693 100644 --- a/tools/nvcc_fix_deps.py +++ 
b/tools/nvcc_fix_deps.py @@ -13,13 +13,15 @@ CMAKE_CUDA_COMPILER_LAUNCHER="python;tools/nvcc_fix_deps.py;ccache" """ +from __future__ import annotations + import subprocess import sys from pathlib import Path -from typing import List, Optional, TextIO +from typing import TextIO -def resolve_include(path: Path, include_dirs: List[Path]) -> Path: +def resolve_include(path: Path, include_dirs: list[Path]) -> Path: for include_path in include_dirs: abs_path = include_path / path if abs_path.exists(): @@ -36,7 +38,7 @@ Tried the following paths, but none existed: ) -def repair_depfile(depfile: TextIO, include_dirs: List[Path]) -> None: +def repair_depfile(depfile: TextIO, include_dirs: list[Path]) -> None: changes_made = False out = "" for line in depfile: @@ -70,8 +72,8 @@ PRE_INCLUDE_ARGS = ["-include", "--pre-include"] POST_INCLUDE_ARGS = ["-I", "--include-path", "-isystem", "--system-include"] -def extract_include_arg(include_dirs: List[Path], i: int, args: List[str]) -> None: - def extract_one(name: str, i: int, args: List[str]) -> Optional[str]: +def extract_include_arg(include_dirs: list[Path], i: int, args: list[str]) -> None: + def extract_one(name: str, i: int, args: list[str]) -> str | None: arg = args[i] if arg == name: return args[i + 1] diff --git a/tools/onnx/gen_diagnostics.py b/tools/onnx/gen_diagnostics.py index 4cf702892960..0d02f1ec4e8f 100644 --- a/tools/onnx/gen_diagnostics.py +++ b/tools/onnx/gen_diagnostics.py @@ -24,6 +24,7 @@ import yaml from torchgen import utils as torchgen_utils from torchgen.yaml_utils import YamlLoader + _RULES_GENERATED_COMMENT = """\ GENERATED CODE - DO NOT EDIT DIRECTLY This file is generated by gen_diagnostics.py. diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index 0604fced8489..8bf65047ae0c 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import argparse import collections import importlib import sys from pprint import pformat -from typing import Dict, List, Sequence +from typing import Sequence from unittest.mock import Mock, patch from warnings import warn @@ -220,7 +222,7 @@ to_py_type_ops = ("bool", "float", "complex", "long", "index", "int", "nonzero") all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops -def sig_for_ops(opname: str) -> List[str]: +def sig_for_ops(opname: str) -> list[str]: """sig_for_ops(opname : str) -> List[str] Returns signatures for operator special functions (__add__ etc.)""" @@ -254,8 +256,8 @@ def sig_for_ops(opname: str) -> List[str]: raise Exception("unknown op", opname) # noqa: TRY002 -def generate_type_hints(sig_group: PythonSignatureGroup) -> List[str]: - type_hints: List[str] = [] +def generate_type_hints(sig_group: PythonSignatureGroup) -> list[str]: + type_hints: list[str] = [] # Some deprecated ops that are on the blocklist are still included in pyi if sig_group.signature.name in blocklist and not sig_group.signature.deprecated: @@ -285,7 +287,7 @@ def generate_type_hints(sig_group: PythonSignatureGroup) -> List[str]: return type_hints -def get_max_pool_dispatch(name: str, arg_list: List[str]) -> Dict[str, List[str]]: +def get_max_pool_dispatch(name: str, arg_list: list[str]) -> dict[str, list[str]]: flag_pos = arg_list.index("{return_indices}") # If return_indices is positional arg, everything before should have no default arg_list_positional = ( @@ -329,7 +331,7 @@ def gen_nn_functional(fm: FileManager) -> None: ) # TODO the list for `torch._C._nn` is nonexhaustive - unsorted_c_nn_function_hints: Dict[str, 
List[str]] = {} + unsorted_c_nn_function_hints: dict[str, list[str]] = {} for d in (2, 3): unsorted_c_nn_function_hints.update( @@ -471,7 +473,7 @@ def gen_nn_functional(fm: FileManager) -> None: } ) - c_nn_function_hints: List[str] = [] + c_nn_function_hints: list[str] = [] for _, hints in sorted(unsorted_c_nn_function_hints.items()): if len(hints) > 1: hints = ["@overload\n" + h for h in hints] @@ -528,7 +530,7 @@ def gen_nn_functional(fm: FileManager) -> None: ) # Functions generated by `torch._jit_internal.boolean_dispatch` in `nn.functional` - unsorted_dispatched_hints: Dict[str, List[str]] = {} + unsorted_dispatched_hints: dict[str, list[str]] = {} for d in (1, 2, 3): unsorted_dispatched_hints.update( @@ -563,7 +565,7 @@ def gen_nn_functional(fm: FileManager) -> None: # There's no fractional_max_pool1d del unsorted_dispatched_hints["fractional_max_pool1d"] - dispatched_hints: List[str] = [] + dispatched_hints: list[str] = [] for _, hints in sorted(unsorted_dispatched_hints.items()): if len(hints) > 1: hints = ["@overload\n" + h for h in hints] @@ -594,7 +596,7 @@ We gather the docstrings for torch with the following steps: """ -def gather_docstrs() -> Dict[str, str]: +def gather_docstrs() -> dict[str, str]: docstrs = {} def mock_add_docstr(func: Mock, docstr: str) -> None: @@ -648,12 +650,12 @@ def gen_pyi( # also needs to update the other file. # Dictionary for NamedTuple definitions - structseqs: Dict[str, str] = {} + structseqs: dict[str, str] = {} # Generate type signatures for top-level functions # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - unsorted_function_hints: Dict[str, List[str]] = collections.defaultdict(list) + unsorted_function_hints: dict[str, list[str]] = collections.defaultdict(list) for n, n1, n2 in [ ("csr", "crow", "col"), @@ -1054,7 +1056,7 @@ def gen_pyi( # Generate type signatures for Tensor methods # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - unsorted_tensor_method_hints: Dict[str, List[str]] = collections.defaultdict(list) + unsorted_tensor_method_hints: dict[str, list[str]] = collections.defaultdict(list) unsorted_tensor_method_hints.update( { "size": [ diff --git a/tools/render_junit.py b/tools/render_junit.py index 0d6effbd0906..22e9fdc60606 100644 --- a/tools/render_junit.py +++ b/tools/render_junit.py @@ -1,8 +1,11 @@ #!/usr/bin/env python3 +from __future__ import annotations + import argparse import os -from typing import Any, List, Union +from typing import Any + try: from junitparser import ( # type: ignore[import] @@ -23,8 +26,8 @@ except ImportError: print("rich not found, for color output use 'pip install rich'") -def parse_junit_reports(path_to_reports: str) -> List[TestCase]: # type: ignore[no-any-unimported] - def parse_file(path: str) -> List[TestCase]: # type: ignore[no-any-unimported] +def parse_junit_reports(path_to_reports: str) -> list[TestCase]: # type: ignore[no-any-unimported] + def parse_file(path: str) -> list[TestCase]: # type: ignore[no-any-unimported] try: return convert_junit_to_testcases(JUnitXml.fromfile(path)) except Exception as err: @@ -46,7 +49,7 @@ def parse_junit_reports(path_to_reports: str) -> List[TestCase]: # type: ignore return ret_xml -def convert_junit_to_testcases(xml: Union[JUnitXml, TestSuite]) -> List[TestCase]: # type: ignore[no-any-unimported] +def convert_junit_to_testcases(xml: JUnitXml | TestSuite) -> list[TestCase]: # type: ignore[no-any-unimported] testcases = [] for item in xml: if isinstance(item, TestSuite): @@ -56,7 +59,7 @@ def convert_junit_to_testcases(xml: Union[JUnitXml, 
TestSuite]) -> List[TestCase return testcases -def render_tests(testcases: List[TestCase]) -> None: # type: ignore[no-any-unimported] +def render_tests(testcases: list[TestCase]) -> None: # type: ignore[no-any-unimported] num_passed = 0 num_skipped = 0 num_failed = 0 diff --git a/tools/setup_helpers/__init__.py b/tools/setup_helpers/__init__.py index 4bf1747e80c6..32731175f180 100644 --- a/tools/setup_helpers/__init__.py +++ b/tools/setup_helpers/__init__.py @@ -1,9 +1,10 @@ +from __future__ import annotations + import os import sys -from typing import Optional -def which(thefile: str) -> Optional[str]: +def which(thefile: str) -> str | None: path = os.environ.get("PATH", os.defpath).split(os.pathsep) for d in path: fname = os.path.join(d, thefile) diff --git a/tools/setup_helpers/cmake.py b/tools/setup_helpers/cmake.py index 4d10b3db1aa3..5481ce46031c 100644 --- a/tools/setup_helpers/cmake.py +++ b/tools/setup_helpers/cmake.py @@ -1,5 +1,6 @@ "Manages CMake." +from __future__ import annotations import multiprocessing import os @@ -8,7 +9,7 @@ import sys import sysconfig from distutils.version import LooseVersion from subprocess import CalledProcessError, check_call, check_output -from typing import Any, cast, Dict, List, Optional +from typing import Any, cast from . import which from .cmake_utils import CMakeValue, get_cmake_cache_variables_from_file @@ -77,7 +78,7 @@ class CMake: return cmake_command @staticmethod - def _get_version(cmd: Optional[str]) -> Any: + def _get_version(cmd: str | None) -> Any: "Returns cmake version." if cmd is None: @@ -87,7 +88,7 @@ class CMake: return LooseVersion(line.strip().split(" ")[2]) raise RuntimeError("no version found") - def run(self, args: List[str], env: Dict[str, str]) -> None: + def run(self, args: list[str], env: dict[str, str]) -> None: "Executes cmake with arguments and an environment." command = [self._cmake_command] + args @@ -101,13 +102,13 @@ class CMake: sys.exit(1) @staticmethod - def defines(args: List[str], **kwargs: CMakeValue) -> None: + def defines(args: list[str], **kwargs: CMakeValue) -> None: "Adds definitions to a cmake argument list." for key, value in sorted(kwargs.items()): if value is not None: args.append(f"-D{key}={value}") - def get_cmake_cache_variables(self) -> Dict[str, CMakeValue]: + def get_cmake_cache_variables(self) -> dict[str, CMakeValue]: r"""Gets values in CMakeCache.txt into a dictionary. Returns: dict: A ``dict`` containing the value of cached CMake variables. @@ -117,11 +118,11 @@ class CMake: def generate( self, - version: Optional[str], - cmake_python_library: Optional[str], + version: str | None, + cmake_python_library: str | None, build_python: bool, build_test: bool, - my_env: Dict[str, str], + my_env: dict[str, str], rerun: bool, ) -> None: "Runs cmake to generate native build files." @@ -181,7 +182,7 @@ class CMake: _mkdir_p(self.build_dir) # Store build options that are directly stored in environment variables - build_options: Dict[str, CMakeValue] = {} + build_options: dict[str, CMakeValue] = {} # Build options that do not start with "BUILD_", "USE_", or "CMAKE_" and are directly controlled by env vars. # This is a dict that maps environment variables to the corresponding variable name in CMake. @@ -340,7 +341,7 @@ class CMake: args.append(base_dir) self.run(args, env=my_env) - def build(self, my_env: Dict[str, str]) -> None: + def build(self, my_env: dict[str, str]) -> None: "Runs cmake to build binaries." 
from .env import build_type diff --git a/tools/setup_helpers/cmake_utils.py b/tools/setup_helpers/cmake_utils.py index bbef56de175b..591dea5b2a35 100644 --- a/tools/setup_helpers/cmake_utils.py +++ b/tools/setup_helpers/cmake_utils.py @@ -3,8 +3,10 @@ This is refactored from cmake.py to avoid circular imports issue with env.py, which calls get_cmake_cache_variables_from_file """ +from __future__ import annotations + import re -from typing import Dict, IO, Optional, Union +from typing import IO, Optional, Union CMakeValue = Optional[Union[bool, str]] @@ -42,7 +44,7 @@ def convert_cmake_value_to_python_value( def get_cmake_cache_variables_from_file( cmake_cache_file: IO[str], -) -> Dict[str, CMakeValue]: +) -> dict[str, CMakeValue]: r"""Gets values in CMakeCache.txt into a dictionary. Args: diff --git a/tools/setup_helpers/env.py b/tools/setup_helpers/env.py index d87e97a2bb5a..24db2f5a8917 100644 --- a/tools/setup_helpers/env.py +++ b/tools/setup_helpers/env.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import os import platform import struct import sys from itertools import chain -from typing import cast, Iterable, List, Optional +from typing import cast, Iterable IS_WINDOWS = platform.system() == "Windows" @@ -30,11 +32,11 @@ def check_negative_env_flag(name: str, default: str = "") -> bool: return os.getenv(name, default).upper() in ["OFF", "0", "NO", "FALSE", "N"] -def gather_paths(env_vars: Iterable[str]) -> List[str]: +def gather_paths(env_vars: Iterable[str]) -> list[str]: return list(chain(*(os.getenv(v, "").split(os.pathsep) for v in env_vars))) -def lib_paths_from_base(base_path: str) -> List[str]: +def lib_paths_from_base(base_path: str) -> list[str]: return [os.path.join(base_path, s) for s in ["lib/x64", "lib", "lib64"]] @@ -54,7 +56,7 @@ class BuildType: """ - def __init__(self, cmake_build_type_env: Optional[str] = None) -> None: + def __init__(self, cmake_build_type_env: str | None = None) -> None: if cmake_build_type_env is not None: self.build_type_string = cmake_build_type_env return diff --git a/tools/setup_helpers/gen_version_header.py b/tools/setup_helpers/gen_version_header.py index 1812957193d0..6a8b1f05a4ef 100644 --- a/tools/setup_helpers/gen_version_header.py +++ b/tools/setup_helpers/gen_version_header.py @@ -2,9 +2,12 @@ # and use the version numbers from there as substitutions for # an expand_template action. Since there isn't, this silly script exists. +from __future__ import annotations + import argparse import os -from typing import cast, Dict, Tuple +from typing import cast, Tuple + Version = Tuple[int, int, int] @@ -30,7 +33,7 @@ def parse_version(version: str) -> Version: return cast(Version, tuple([int(n) for n in version_number_str.split(".")])) -def apply_replacements(replacements: Dict[str, str], text: str) -> str: +def apply_replacements(replacements: dict[str, str], text: str) -> str: """ Applies the given replacements within the text. 
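
The pattern behind these hunks, shown as a minimal self-contained sketch: with `from __future__ import annotations` (PEP 563) every annotation is stored as a string and never evaluated at import time, so builtin generics such as `list[str]` and `dict[str, Any]` (PEP 585), `X | None` unions (PEP 604), and parametrized classes like `re.Pattern[str]` are safe in signatures even on interpreters older than Python 3.9/3.10, and previously quoted return types such as `"subprocess.CompletedProcess[bytes]"` can drop their quotes. Expressions that are still evaluated at runtime, for example the module-level aliases `Version = Tuple[int, int, int]` and `CMakeValue = Optional[Union[bool, str]]` above, keep their `typing` spellings, which is why some `typing` imports survive in these files. The names below loosely mirror the linter adapters in this patch, but the bodies are simplified assumptions, not the actual code:

# illustrative_sketch.py (hypothetical file, not part of the patch)
from __future__ import annotations  # PEP 563: annotations become strings, never evaluated

import re
import subprocess
from typing import Optional, Tuple, Union

# Runtime expressions are evaluated at import time, so aliases keep the typing forms:
Version = Tuple[int, int, int]
CMakeValue = Optional[Union[bool, str]]

# In annotations the new spellings work even on Python 3.8, because they stay strings:
RESULTS_RE: re.Pattern[str] = re.compile(r"^(?P<file>.*?):(?P<line>\d+): (?P<message>.*)$", re.M)

def run_command(
    args: list[str],
    extra_env: dict[str, str] | None = None,
) -> subprocess.CompletedProcess[bytes]:
    # The return annotation no longer needs quoting; it is never evaluated.
    # env=None inherits the parent environment.
    return subprocess.run(args, env=extra_env, capture_output=True)

def to_version_string(version: Version) -> str:
    return ".".join(str(x) for x in version)

On Python 3.9+/3.10+ the runtime aliases themselves could also use `tuple[...]` and `|`, but this patch only touches annotations and leaves runtime type expressions unchanged.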
diff --git a/tools/setup_helpers/generate_code.py b/tools/setup_helpers/generate_code.py index 6c939fe1e52b..6ef951d8f021 100644 --- a/tools/setup_helpers/generate_code.py +++ b/tools/setup_helpers/generate_code.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import argparse import os import pathlib import sys -from typing import Any, cast, Optional +from typing import Any, cast import yaml @@ -18,10 +20,10 @@ TAGS_PATH = "aten/src/ATen/native/tags.yaml" def generate_code( gen_dir: pathlib.Path, - native_functions_path: Optional[str] = None, - tags_path: Optional[str] = None, - install_dir: Optional[str] = None, - subset: Optional[str] = None, + native_functions_path: str | None = None, + tags_path: str | None = None, + install_dir: str | None = None, + subset: str | None = None, disable_autograd: bool = False, force_schema_registration: bool = False, operator_selector: Any = None, @@ -102,8 +104,8 @@ def get_selector_from_legacy_operator_selection_list( def get_selector( - selected_op_list_path: Optional[str], - operators_yaml_path: Optional[str], + selected_op_list_path: str | None, + operators_yaml_path: str | None, ) -> Any: # cwrap depends on pyyaml, so we can't import it earlier root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) diff --git a/tools/stats/check_disabled_tests.py b/tools/stats/check_disabled_tests.py index e9713a0951e2..b0204bbf8b26 100644 --- a/tools/stats/check_disabled_tests.py +++ b/tools/stats/check_disabled_tests.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import argparse import json import os import xml.etree.ElementTree as ET from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any, Dict, Generator, Tuple +from typing import Any, Generator from tools.stats.upload_stats_lib import ( download_s3_artifacts, @@ -14,13 +16,14 @@ from tools.stats.upload_stats_lib import ( ) from tools.stats.upload_test_stats import process_xml_element + TESTCASE_TAG = "testcase" SEPARATOR = ";" def process_report( report: Path, -) -> Dict[str, Dict[str, int]]: +) -> dict[str, dict[str, int]]: """ Return a list of disabled tests that should be re-enabled and those that are still flaky (failed or skipped) @@ -36,7 +39,7 @@ def process_report( # * Skipped tests from unittest # # We want to keep track of how many times the test fails (num_red) or passes (num_green) - all_tests: Dict[str, Dict[str, int]] = {} + all_tests: dict[str, dict[str, int]] = {} for test_case in root.iter(TESTCASE_TAG): parsed_test_case = process_xml_element(test_case) @@ -116,7 +119,7 @@ def get_test_reports( yield from Path(".").glob("**/*.xml") -def get_disabled_test_name(test_id: str) -> Tuple[str, str, str, str]: +def get_disabled_test_name(test_id: str) -> tuple[str, str, str, str]: """ Follow flaky bot convention here, if that changes, this will also need to be updated """ @@ -133,7 +136,7 @@ def prepare_record( flaky: bool, num_red: int = 0, num_green: int = 0, -) -> Tuple[Any, Dict[str, Any]]: +) -> tuple[Any, dict[str, Any]]: """ Prepare the record to save onto S3 """ @@ -162,7 +165,7 @@ def prepare_record( def save_results( workflow_id: int, workflow_run_attempt: int, - all_tests: Dict[str, Dict[str, int]], + all_tests: dict[str, dict[str, int]], ) -> None: """ Save the result to S3, so it can go to Rockset @@ -228,7 +231,7 @@ def main(repo: str, workflow_run_id: int, workflow_run_attempt: int) -> None: Find the list of all disabled tests that should be re-enabled """ # Aggregated across all jobs - all_tests: 
Dict[str, Dict[str, int]] = {} + all_tests: dict[str, dict[str, int]] = {} for report in get_test_reports( args.repo, args.workflow_run_id, args.workflow_run_attempt diff --git a/tools/stats/import_test_stats.py b/tools/stats/import_test_stats.py index 513edb12fcfe..5f1d4ced59be 100644 --- a/tools/stats/import_test_stats.py +++ b/tools/stats/import_test_stats.py @@ -1,17 +1,19 @@ #!/usr/bin/env python3 +from __future__ import annotations + import datetime import json import os import pathlib import shutil -from typing import Any, Callable, cast, Dict, List, Optional, Union +from typing import Any, Callable, cast, Dict from urllib.request import urlopen REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent -def get_disabled_issues() -> List[str]: +def get_disabled_issues() -> list[str]: reenabled_issues = os.getenv("REENABLED_ISSUES", "") issue_numbers = reenabled_issues.split(",") print("Ignoring disabled issues: ", issue_numbers) @@ -34,11 +36,11 @@ FILE_CACHE_LIFESPAN_SECONDS = datetime.timedelta(hours=3).seconds def fetch_and_cache( - dirpath: Union[str, pathlib.Path], + dirpath: str | pathlib.Path, name: str, url: str, - process_fn: Callable[[Dict[str, Any]], Dict[str, Any]], -) -> Dict[str, Any]: + process_fn: Callable[[dict[str, Any]], dict[str, Any]], +) -> dict[str, Any]: """ This fetch and cache utils allows sharing between different process. """ @@ -76,7 +78,7 @@ def fetch_and_cache( def get_slow_tests( dirpath: str, filename: str = SLOW_TESTS_FILE -) -> Optional[Dict[str, float]]: +) -> dict[str, float] | None: url = "https://ossci-metrics.s3.amazonaws.com/slow-tests.json" try: return fetch_and_cache(dirpath, filename, url, lambda x: x) @@ -85,7 +87,7 @@ def get_slow_tests( return {} -def get_test_times() -> Dict[str, Dict[str, float]]: +def get_test_times() -> dict[str, dict[str, float]]: return get_from_test_infra_generated_stats( "test-times.json", TEST_TIMES_FILE, @@ -93,7 +95,7 @@ def get_test_times() -> Dict[str, Dict[str, float]]: ) -def get_test_class_times() -> Dict[str, Dict[str, float]]: +def get_test_class_times() -> dict[str, dict[str, float]]: return get_from_test_infra_generated_stats( "test-class-times.json", TEST_CLASS_TIMES_FILE, @@ -103,8 +105,8 @@ def get_test_class_times() -> Dict[str, Dict[str, float]]: def get_disabled_tests( dirpath: str, filename: str = DISABLED_TESTS_FILE -) -> Optional[Dict[str, Any]]: - def process_disabled_test(the_response: Dict[str, Any]) -> Dict[str, Any]: +) -> dict[str, Any] | None: + def process_disabled_test(the_response: dict[str, Any]) -> dict[str, Any]: # remove re-enabled tests and condense even further by getting rid of pr_num disabled_issues = get_disabled_issues() disabled_test_from_issues = dict() @@ -124,7 +126,7 @@ def get_disabled_tests( return {} -def get_test_file_ratings() -> Dict[str, Any]: +def get_test_file_ratings() -> dict[str, Any]: return get_from_test_infra_generated_stats( "file_test_rating.json", TEST_FILE_RATINGS_FILE, @@ -132,7 +134,7 @@ def get_test_file_ratings() -> Dict[str, Any]: ) -def get_test_class_ratings() -> Dict[str, Any]: +def get_test_class_ratings() -> dict[str, Any]: return get_from_test_infra_generated_stats( "file_test_class_rating.json", TEST_CLASS_RATINGS_FILE, @@ -140,7 +142,7 @@ def get_test_class_ratings() -> Dict[str, Any]: ) -def get_td_heuristic_historial_edited_files_json() -> Dict[str, Any]: +def get_td_heuristic_historial_edited_files_json() -> dict[str, Any]: return get_from_test_infra_generated_stats( "td_heuristic_historical_edited_files.json", 
TD_HEURISTIC_HISTORICAL_EDITED_FILES, @@ -148,7 +150,7 @@ def get_td_heuristic_historial_edited_files_json() -> Dict[str, Any]: ) -def get_td_heuristic_profiling_json() -> Dict[str, Any]: +def get_td_heuristic_profiling_json() -> dict[str, Any]: return get_from_test_infra_generated_stats( "td_heuristic_profiling.json", TD_HEURISTIC_PROFILING_FILE, @@ -182,7 +184,7 @@ def copy_additional_previous_failures() -> None: def get_from_test_infra_generated_stats( from_file: str, to_file: str, failure_explanation: str -) -> Dict[str, Any]: +) -> dict[str, Any]: url = f"https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/{from_file}" try: return fetch_and_cache( diff --git a/tools/stats/monitor.py b/tools/stats/monitor.py index 45ad385aea90..6f190aa52e70 100644 --- a/tools/stats/monitor.py +++ b/tools/stats/monitor.py @@ -1,14 +1,17 @@ #!/usr/bin/env python3 + +from __future__ import annotations + import datetime import json import signal import time -from typing import Any, Dict, List +from typing import Any import psutil # type: ignore[import] -def get_processes_running_python_tests() -> List[Any]: +def get_processes_running_python_tests() -> list[Any]: python_processes = [] for process in psutil.process_iter(): try: @@ -20,7 +23,7 @@ def get_processes_running_python_tests() -> List[Any]: return python_processes -def get_per_process_cpu_info() -> List[Dict[str, Any]]: +def get_per_process_cpu_info() -> list[dict[str, Any]]: processes = get_processes_running_python_tests() per_process_info = [] for p in processes: @@ -49,7 +52,7 @@ def get_per_process_cpu_info() -> List[Dict[str, Any]]: return per_process_info -def get_per_process_gpu_info(handle: Any) -> List[Dict[str, Any]]: +def get_per_process_gpu_info(handle: Any) -> list[dict[str, Any]]: processes = pynvml.nvmlDeviceGetComputeRunningProcesses(handle) per_process_info = [] for p in processes: @@ -58,7 +61,7 @@ def get_per_process_gpu_info(handle: Any) -> List[Dict[str, Any]]: return per_process_info -def rocm_get_per_process_gpu_info(handle: Any) -> List[Dict[str, Any]]: +def rocm_get_per_process_gpu_info(handle: Any) -> list[dict[str, Any]]: processes = amdsmi.amdsmi_get_gpu_process_list(handle) per_process_info = [] for p in processes: diff --git a/tools/stats/test_dashboard.py b/tools/stats/test_dashboard.py index 777f1e6b5637..2d3a2e0e1521 100644 --- a/tools/stats/test_dashboard.py +++ b/tools/stats/test_dashboard.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import os import re @@ -6,7 +8,7 @@ from collections import defaultdict from functools import lru_cache from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any, cast, Dict, List +from typing import Any, cast import requests @@ -18,6 +20,7 @@ from tools.stats.upload_stats_lib import ( upload_workflow_stats_to_s3, ) + REGEX_JOB_INFO = r"(.*) \/ .*test \(([^,]*), .*\)" @@ -56,7 +59,7 @@ def get_test_config(job_name: str) -> str: def get_td_exclusions( workflow_run_id: int, workflow_run_attempt: int -) -> Dict[str, Any]: +) -> dict[str, Any]: with TemporaryDirectory() as temp_dir: print("Using temporary directory:", temp_dir) os.chdir(temp_dir) @@ -68,7 +71,7 @@ def get_td_exclusions( for path in s3_paths: unzip(path) - grouped_tests: Dict[str, Any] = defaultdict(lambda: defaultdict(set)) + grouped_tests: dict[str, Any] = defaultdict(lambda: defaultdict(set)) for td_exclusions in Path(".").glob("**/td_exclusions*.json"): with open(td_exclusions) as f: exclusions = json.load(f) @@ -85,9 +88,9 @@ def 
get_td_exclusions( return grouped_tests -def group_test_cases(test_cases: List[Dict[str, Any]]) -> Dict[str, Any]: +def group_test_cases(test_cases: list[dict[str, Any]]) -> dict[str, Any]: start = time.time() - grouped_tests: Dict[str, Any] = defaultdict( + grouped_tests: dict[str, Any] = defaultdict( lambda: defaultdict( lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))) ) @@ -112,8 +115,8 @@ def group_test_cases(test_cases: List[Dict[str, Any]]) -> Dict[str, Any]: return grouped_tests -def get_reruns(grouped_tests: Dict[str, Any]) -> Dict[str, Any]: - reruns: Dict[str, Any] = defaultdict( +def get_reruns(grouped_tests: dict[str, Any]) -> dict[str, Any]: + reruns: dict[str, Any] = defaultdict( lambda: defaultdict( lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))) ) @@ -136,8 +139,8 @@ def get_reruns(grouped_tests: Dict[str, Any]) -> Dict[str, Any]: return reruns -def get_invoking_file_summary(grouped_tests: Dict[str, Any]) -> Dict[str, Any]: - invoking_file_summary: Dict[str, Any] = defaultdict( +def get_invoking_file_summary(grouped_tests: dict[str, Any]) -> dict[str, Any]: + invoking_file_summary: dict[str, Any] = defaultdict( lambda: defaultdict(lambda: defaultdict(lambda: {"count": 0, "time": 0.0})) ) for build_name, build in grouped_tests.items(): @@ -157,7 +160,7 @@ def get_invoking_file_summary(grouped_tests: Dict[str, Any]) -> Dict[str, Any]: def upload_additional_info( - workflow_run_id: int, workflow_run_attempt: int, test_cases: List[Dict[str, Any]] + workflow_run_id: int, workflow_run_attempt: int, test_cases: list[dict[str, Any]] ) -> None: grouped_tests = group_test_cases(test_cases) reruns = get_reruns(grouped_tests) diff --git a/tools/stats/upload_artifacts.py b/tools/stats/upload_artifacts.py index eb0fde7f38ac..5036d745d963 100644 --- a/tools/stats/upload_artifacts.py +++ b/tools/stats/upload_artifacts.py @@ -5,6 +5,7 @@ from tempfile import TemporaryDirectory from tools.stats.upload_stats_lib import download_gha_artifacts, upload_file_to_s3 + ARTIFACTS = [ "sccache-stats", "test-jsons", diff --git a/tools/stats/upload_dynamo_perf_stats.py b/tools/stats/upload_dynamo_perf_stats.py index c6c507863f44..b7564091808d 100644 --- a/tools/stats/upload_dynamo_perf_stats.py +++ b/tools/stats/upload_dynamo_perf_stats.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import argparse import csv import os import re from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any, Dict, List +from typing import Any from tools.stats.upload_stats_lib import download_s3_artifacts, unzip, upload_to_rockset @@ -23,7 +25,7 @@ def upload_dynamo_perf_stats_to_rockset( workflow_run_attempt: int, head_branch: str, match_filename: str, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: match_filename_regex = re.compile(match_filename) perf_stats = [] with TemporaryDirectory() as temp_dir: diff --git a/tools/stats/upload_external_contrib_stats.py b/tools/stats/upload_external_contrib_stats.py index 142019065f69..a90811592e2f 100644 --- a/tools/stats/upload_external_contrib_stats.py +++ b/tools/stats/upload_external_contrib_stats.py @@ -1,16 +1,18 @@ +from __future__ import annotations + import argparse import datetime import json import os - import time import urllib.parse -from typing import Any, Callable, cast, Dict, List, Optional, Set +from typing import Any, Callable, cast, Dict, List from urllib.error import HTTPError from urllib.request import Request, urlopen from tools.stats.upload_stats_lib import upload_to_s3 + 
FILTER_OUT_USERS = { "pytorchmergebot", "facebook-github-bot", @@ -23,9 +25,9 @@ FILTER_OUT_USERS = { def _fetch_url( url: str, - headers: Dict[str, str], - data: Optional[Dict[str, Any]] = None, - method: Optional[str] = None, + headers: dict[str, str], + data: dict[str, Any] | None = None, + method: str | None = None, reader: Callable[[Any], Any] = lambda x: x.read(), ) -> Any: token = os.environ.get("GITHUB_TOKEN") @@ -49,9 +51,9 @@ def _fetch_url( def fetch_json( url: str, - params: Optional[Dict[str, Any]] = None, - data: Optional[Dict[str, Any]] = None, -) -> List[Dict[str, Any]]: + params: dict[str, Any] | None = None, + data: dict[str, Any] | None = None, +) -> list[dict[str, Any]]: headers = {"Accept": "application/vnd.github.v3+json"} if params is not None and len(params) > 0: url += "?" + "&".join( @@ -65,16 +67,16 @@ def fetch_json( def get_external_pr_data( start_date: datetime.date, end_date: datetime.date, period_length: int = 1 -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: pr_info = [] period_begin_date = start_date pr_count = 0 - users: Set[str] = set() + users: set[str] = set() while period_begin_date < end_date: period_end_date = period_begin_date + datetime.timedelta(days=period_length - 1) page = 1 - responses: List[Dict[str, Any]] = [] + responses: list[dict[str, Any]] = [] while len(responses) > 0 or page == 1: response = cast( Dict[str, Any], diff --git a/tools/stats/upload_metrics.py b/tools/stats/upload_metrics.py index 16688c340ced..2a574165f19a 100644 --- a/tools/stats/upload_metrics.py +++ b/tools/stats/upload_metrics.py @@ -1,13 +1,15 @@ +from __future__ import annotations + import datetime import inspect import os import time import uuid - from decimal import Decimal -from typing import Any, Dict +from typing import Any from warnings import warn + # boto3 is an optional dependency. If it's not installed, # we'll just not emit the metrics. # Keeping this logic here so that callers don't have to @@ -65,7 +67,7 @@ class EnvVarMetric: return value -global_metrics: Dict[str, Any] = {} +global_metrics: dict[str, Any] = {} def add_global_metric(metric_name: str, metric_value: Any) -> None: @@ -79,7 +81,7 @@ def add_global_metric(metric_name: str, metric_value: Any) -> None: def emit_metric( metric_name: str, - metrics: Dict[str, Any], + metrics: dict[str, Any], ) -> None: """ Upload a metric to DynamoDB (and from there, Rockset). @@ -174,7 +176,7 @@ def emit_metric( print(f"Not emitting metrics for {metric_name}. 
Boto wasn't imported.") -def _convert_float_values_to_decimals(data: Dict[str, Any]) -> Dict[str, Any]: +def _convert_float_values_to_decimals(data: dict[str, Any]) -> dict[str, Any]: # Attempt to recurse def _helper(o: Any) -> Any: if isinstance(o, float): diff --git a/tools/stats/upload_sccache_stats.py b/tools/stats/upload_sccache_stats.py index 2252ac966d38..0f59d10ff698 100644 --- a/tools/stats/upload_sccache_stats.py +++ b/tools/stats/upload_sccache_stats.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import argparse import json import os from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any, Dict, List +from typing import Any from tools.stats.upload_stats_lib import ( download_s3_artifacts, @@ -13,7 +15,7 @@ from tools.stats.upload_stats_lib import ( def get_sccache_stats( workflow_run_id: int, workflow_run_attempt: int -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: with TemporaryDirectory() as temp_dir: print("Using temporary directory:", temp_dir) os.chdir(temp_dir) diff --git a/tools/stats/upload_stats_lib.py b/tools/stats/upload_stats_lib.py index 770af32247d6..caffa35a3b17 100644 --- a/tools/stats/upload_stats_lib.py +++ b/tools/stats/upload_stats_lib.py @@ -1,16 +1,18 @@ +from __future__ import annotations + import gzip import io import json import os import zipfile - from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any import boto3 # type: ignore[import] import requests import rockset # type: ignore[import] + PYTORCH_REPO = "https://api.github.com/repos/pytorch/pytorch" S3_RESOURCE = boto3.resource("s3") @@ -21,14 +23,14 @@ MAX_RETRY_IN_NON_DISABLED_MODE = 3 * 3 BATCH_SIZE = 5000 -def _get_request_headers() -> Dict[str, str]: +def _get_request_headers() -> dict[str, str]: return { "Accept": "application/vnd.github.v3+json", "Authorization": "token " + os.environ["GITHUB_TOKEN"], } -def _get_artifact_urls(prefix: str, workflow_run_id: int) -> Dict[Path, str]: +def _get_artifact_urls(prefix: str, workflow_run_id: int) -> dict[Path, str]: """Get all workflow artifacts with 'test-report' in the name.""" response = requests.get( f"{PYTORCH_REPO}/actions/runs/{workflow_run_id}/artifacts?per_page=100", @@ -78,7 +80,7 @@ def _download_artifact( def download_s3_artifacts( prefix: str, workflow_run_id: int, workflow_run_attempt: int -) -> List[Path]: +) -> list[Path]: bucket = S3_RESOURCE.Bucket("gha-artifacts") objs = bucket.objects.filter( Prefix=f"pytorch/pytorch/{workflow_run_id}/{workflow_run_attempt}/artifact/{prefix}" @@ -104,7 +106,7 @@ def download_s3_artifacts( def download_gha_artifacts( prefix: str, workflow_run_id: int, workflow_run_attempt: int -) -> List[Path]: +) -> list[Path]: artifact_urls = _get_artifact_urls(prefix, workflow_run_id) paths = [] for name, url in artifact_urls.items(): @@ -114,7 +116,7 @@ def download_gha_artifacts( def upload_to_rockset( collection: str, - docs: List[Any], + docs: list[Any], workspace: str = "commons", client: Any = None, ) -> None: @@ -142,7 +144,7 @@ def upload_to_rockset( def upload_to_s3( bucket_name: str, key: str, - docs: List[Dict[str, Any]], + docs: list[dict[str, Any]], ) -> None: print(f"Writing {len(docs)} documents to S3") body = io.StringIO() @@ -164,7 +166,7 @@ def upload_to_s3( def read_from_s3( bucket_name: str, key: str, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: print(f"Reading from s3://{bucket_name}/{key}") body = ( S3_RESOURCE.Object( @@ -182,7 +184,7 @@ def upload_workflow_stats_to_s3( workflow_run_id: 
int, workflow_run_attempt: int, collection: str, - docs: List[Dict[str, Any]], + docs: list[dict[str, Any]], ) -> None: bucket_name = "ossci-raw-job-status" key = f"{collection}/{workflow_run_id}/{workflow_run_attempt}" @@ -220,7 +222,7 @@ def unzip(p: Path) -> None: zip.extractall(unzipped_dir) -def is_rerun_disabled_tests(tests: Dict[str, Dict[str, int]]) -> bool: +def is_rerun_disabled_tests(tests: dict[str, dict[str, int]]) -> bool: """ Check if the test report is coming from rerun_disabled_tests workflow where each test is run multiple times @@ -231,7 +233,7 @@ def is_rerun_disabled_tests(tests: Dict[str, Dict[str, int]]) -> bool: ) -def get_job_id(report: Path) -> Optional[int]: +def get_job_id(report: Path) -> int | None: # [Job id in artifacts] # Retrieve the job id from the report path. In our GHA workflows, we append # the job id to the end of the report name, so `report` looks like: diff --git a/tools/stats/upload_test_stat_aggregates.py b/tools/stats/upload_test_stat_aggregates.py index 5eb9a12d9833..e128ca4bf14f 100644 --- a/tools/stats/upload_test_stat_aggregates.py +++ b/tools/stats/upload_test_stat_aggregates.py @@ -1,17 +1,19 @@ +from __future__ import annotations + import argparse import ast import datetime import json import os import re -from typing import Any, List, Union +from typing import Any import rockset # type: ignore[import] from tools.stats.upload_stats_lib import upload_to_s3 -def get_oncall_from_testfile(testfile: str) -> Union[List[str], None]: +def get_oncall_from_testfile(testfile: str) -> list[str] | None: path = f"test/{testfile}" if not path.endswith(".py"): path += ".py" diff --git a/tools/stats/upload_test_stats.py b/tools/stats/upload_test_stats.py index 41fe6e76cd62..6984d3c73c40 100644 --- a/tools/stats/upload_test_stats.py +++ b/tools/stats/upload_test_stats.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import os import sys @@ -5,7 +7,7 @@ import xml.etree.ElementTree as ET from multiprocessing import cpu_count, Pool from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any, Dict, List +from typing import Any from tools.stats.test_dashboard import upload_additional_info from tools.stats.upload_stats_lib import ( @@ -21,14 +23,14 @@ def parse_xml_report( report: Path, workflow_id: int, workflow_run_attempt: int, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """Convert a test report xml file into a JSON-serializable list of test cases.""" print(f"Parsing {tag}s for test report: {report}") job_id = get_job_id(report) print(f"Found job id: {job_id}") - test_cases: List[Dict[str, Any]] = [] + test_cases: list[dict[str, Any]] = [] root = ET.parse(report) for test_case in root.iter(tag): @@ -53,9 +55,9 @@ def parse_xml_report( return test_cases -def process_xml_element(element: ET.Element) -> Dict[str, Any]: +def process_xml_element(element: ET.Element) -> dict[str, Any]: """Convert a test suite element into a JSON-serializable dict.""" - ret: Dict[str, Any] = {} + ret: dict[str, Any] = {} # Convert attributes directly into dict elements. # e.g. 
@@ -110,7 +112,7 @@ def process_xml_element(element: ET.Element) -> Dict[str, Any]: return ret -def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> List[Dict[str, Any]]: +def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> list[dict[str, Any]]: with TemporaryDirectory() as temp_dir: print("Using temporary directory:", temp_dir) os.chdir(temp_dir) @@ -146,7 +148,7 @@ def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> List[Dict[str, def get_tests_for_circleci( workflow_run_id: int, workflow_run_attempt: int -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: # Parse the reports and transform them to JSON test_cases = [] for xml_report in Path(".").glob("**/test/test-reports/**/*.xml"): @@ -159,13 +161,13 @@ def get_tests_for_circleci( return test_cases -def summarize_test_cases(test_cases: List[Dict[str, Any]]) -> List[Dict[str, Any]]: +def summarize_test_cases(test_cases: list[dict[str, Any]]) -> list[dict[str, Any]]: """Group test cases by classname, file, and job_id. We perform the aggregation manually instead of using the `test-suite` XML tag because xmlrunner does not produce reliable output for it. """ - def get_key(test_case: Dict[str, Any]) -> Any: + def get_key(test_case: dict[str, Any]) -> Any: return ( test_case.get("file"), test_case.get("classname"), @@ -176,7 +178,7 @@ def summarize_test_cases(test_cases: List[Dict[str, Any]]) -> List[Dict[str, Any test_case["invoking_file"], ) - def init_value(test_case: Dict[str, Any]) -> Dict[str, Any]: + def init_value(test_case: dict[str, Any]) -> dict[str, Any]: return { "file": test_case.get("file"), "classname": test_case.get("classname"), diff --git a/tools/stats/upload_test_stats_intermediate.py b/tools/stats/upload_test_stats_intermediate.py index 77cab472367b..d0a32f0630e8 100644 --- a/tools/stats/upload_test_stats_intermediate.py +++ b/tools/stats/upload_test_stats_intermediate.py @@ -4,6 +4,7 @@ import sys from tools.stats.test_dashboard import upload_additional_info from tools.stats.upload_test_stats import get_tests + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Upload test stats to Rockset") parser.add_argument( diff --git a/tools/test/gen_operators_yaml_test.py b/tools/test/gen_operators_yaml_test.py index 956ec0f34124..ef129974febf 100644 --- a/tools/test/gen_operators_yaml_test.py +++ b/tools/test/gen_operators_yaml_test.py @@ -5,7 +5,6 @@ import argparse import json import unittest from collections import defaultdict - from unittest.mock import Mock, patch from gen_operators_yaml import ( @@ -43,10 +42,10 @@ def _mock_load_op_dep_graph(): class GenOperatorsYAMLTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: pass - def test_filter_creation(self): + def test_filter_creation(self) -> None: filter_func = make_filter_from_options( model_name="abc", model_versions=["100", "101"], @@ -99,7 +98,7 @@ class GenOperatorsYAMLTest(unittest.TestCase): len(filtered_configs) == 2 ), f"Expected 2 elements in filtered_configs, but got {len(filtered_configs)}" - def test_verification_success(self): + def test_verification_success(self) -> None: filter_func = make_filter_from_options( model_name="abc", model_versions=["100", "101"], @@ -142,7 +141,7 @@ class GenOperatorsYAMLTest(unittest.TestCase): "expected verify_all_specified_present to succeed instead it raised an exception" ) - def test_verification_fail(self): + def test_verification_fail(self) -> None: config = [ { "model": { @@ -229,7 +228,7 @@ class 
GenOperatorsYAMLTest(unittest.TestCase): ) def test_fill_output_with_arguments_not_include_all_overloads( self, mock_parse_options: Mock, mock_load_op_dep_graph: Mock - ): + ) -> None: parser = argparse.ArgumentParser(description="Generate used operators YAML") options = get_parser_options(parser) diff --git a/tools/test/gen_oplist_test.py b/tools/test/gen_oplist_test.py index 33f9fb293edc..482366260853 100644 --- a/tools/test/gen_oplist_test.py +++ b/tools/test/gen_oplist_test.py @@ -8,10 +8,10 @@ from tools.code_analyzer.gen_oplist import throw_if_any_op_includes_overloads class GenOplistTest(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: pass - def test_throw_if_any_op_includes_overloads(self): + def test_throw_if_any_op_includes_overloads(self) -> None: selective_builder = MagicMock() selective_builder.operators = MagicMock() selective_builder.operators.items.return_value = [ diff --git a/tools/test/heuristics/test_heuristics.py b/tools/test/heuristics/test_heuristics.py index a1d1534704da..f18ad5b1eaad 100644 --- a/tools/test/heuristics/test_heuristics.py +++ b/tools/test/heuristics/test_heuristics.py @@ -1,10 +1,12 @@ # For testing specific heuristics +from __future__ import annotations + import io import json import pathlib import sys import unittest -from typing import Any, Dict, List, Set +from typing import Any from unittest import mock REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent @@ -28,14 +30,14 @@ sys.path.remove(str(REPO_ROOT)) HEURISTIC_CLASS = "tools.testing.target_determination.heuristics.historical_class_failure_correlation." -def mocked_file(contents: Dict[Any, Any]) -> io.IOBase: +def mocked_file(contents: dict[Any, Any]) -> io.IOBase: file_object = io.StringIO() json.dump(contents, file_object) file_object.seek(0) return file_object -def gen_historical_class_failures() -> Dict[str, Dict[str, float]]: +def gen_historical_class_failures() -> dict[str, dict[str, float]]: return { "file1": { "test1::classA": 0.5, @@ -80,8 +82,8 @@ class TestHistoricalClassFailureCorrelation(TestTD): ) def test_get_prediction_confidence( self, - historical_class_failures: Dict[str, Dict[str, float]], - changed_files: List[str], + historical_class_failures: dict[str, dict[str, float]], + changed_files: list[str], ) -> None: tests_to_prioritize = ALL_TESTS @@ -113,7 +115,7 @@ class TestHistoricalClassFailureCorrelation(TestTD): class TestParsePrevTests(TestTD): @mock.patch("os.path.exists", return_value=False) def test_cache_does_not_exist(self, mock_exists: Any) -> None: - expected_failing_test_files: Set[str] = set() + expected_failing_test_files: set[str] = set() found_tests = get_previous_failures() @@ -122,7 +124,7 @@ class TestParsePrevTests(TestTD): @mock.patch("os.path.exists", return_value=True) @mock.patch("builtins.open", return_value=mocked_file({"": True})) def test_empty_cache(self, mock_exists: Any, mock_open: Any) -> None: - expected_failing_test_files: Set[str] = set() + expected_failing_test_files: set[str] = set() found_tests = get_previous_failures() diff --git a/tools/test/heuristics/test_interface.py b/tools/test/heuristics/test_interface.py index df122ab7d56f..78ead7d0d623 100644 --- a/tools/test/heuristics/test_interface.py +++ b/tools/test/heuristics/test_interface.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import pathlib import sys import unittest -from typing import Any, Dict, List +from typing import Any REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent 
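The test_heuristics.py changes keep relying on mocked_file, which turns a dict into an in-memory JSON "file" so builtins.open can be patched in tests. A self-contained sketch of that technique; read_config is a hypothetical consumer added only for the demo:

from __future__ import annotations

import io
import json
from typing import Any
from unittest import mock


def mocked_file(contents: dict[Any, Any]) -> io.StringIO:
    # Serialize the dict into an in-memory buffer that behaves like an open file.
    file_object = io.StringIO()
    json.dump(contents, file_object)
    file_object.seek(0)
    return file_object


def read_config(path: str) -> dict[str, Any]:
    # Anything that calls open() picks up the fake file while the patch is active.
    with open(path) as f:
        return json.load(f)


with mock.patch("builtins.open", return_value=mocked_file({"test1": 0.5})):
    print(read_config("ignored.json"))  # {'test1': 0.5}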
sys.path.append(str(REPO_ROOT)) @@ -13,7 +15,7 @@ sys.path.remove(str(REPO_ROOT)) class TestTD(unittest.TestCase): def assert_test_scores_almost_equal( - self, d1: Dict[TestRun, float], d2: Dict[TestRun, float] + self, d1: dict[TestRun, float], d2: dict[TestRun, float] ) -> None: # Check that dictionaries are the same, except for floating point errors self.assertEqual(set(d1.keys()), set(d2.keys())) @@ -24,7 +26,7 @@ class TestTD(unittest.TestCase): # Create a dummy heuristic class class Heuristic(interface.HeuristicInterface): def get_prediction_confidence( - self, tests: List[str] + self, tests: list[str] ) -> interface.TestPrioritizations: # Return junk return interface.TestPrioritizations([], {}) @@ -259,9 +261,9 @@ class TestTestPrioritizations(TestTD): class TestAggregatedHeuristics(TestTD): def check( self, - tests: List[str], - test_prioritizations: List[Dict[TestRun, float]], - expected: Dict[TestRun, float], + tests: list[str], + test_prioritizations: list[dict[TestRun, float]], + expected: dict[TestRun, float], ) -> None: aggregated_heuristics = interface.AggregatedHeuristics(tests) for i, test_prioritization in enumerate(test_prioritizations): @@ -429,7 +431,7 @@ class TestAggregatedHeuristicsTestStats(TestTD): stats3 = aggregator.get_test_stats(TestRun("test3")) stats5 = aggregator.get_test_stats(TestRun("test5::classA")) - def assert_valid_dict(dict_contents: Dict[str, Any]) -> None: + def assert_valid_dict(dict_contents: dict[str, Any]) -> None: for key, value in dict_contents.items(): self.assertTrue(isinstance(key, str)) self.assertTrue( diff --git a/tools/test/heuristics/test_utils.py b/tools/test/heuristics/test_utils.py index 934897271ae5..ddc2a72a5ff8 100644 --- a/tools/test/heuristics/test_utils.py +++ b/tools/test/heuristics/test_utils.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import pathlib import sys import unittest -from typing import Any, Dict +from typing import Any REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent @@ -14,14 +16,14 @@ sys.path.remove(str(REPO_ROOT)) class TestHeuristicsUtils(unittest.TestCase): def assertDictAlmostEqual( - self, first: Dict[TestRun, Any], second: Dict[TestRun, Any] + self, first: dict[TestRun, Any], second: dict[TestRun, Any] ) -> None: self.assertEqual(first.keys(), second.keys()) for key in first.keys(): self.assertAlmostEqual(first[key], second[key]) def test_normalize_ratings(self) -> None: - ratings: Dict[TestRun, float] = { + ratings: dict[TestRun, float] = { TestRun("test1"): 1, TestRun("test2"): 2, TestRun("test3"): 4, diff --git a/tools/test/test_cmake.py b/tools/test/test_cmake.py index 618b951a8c54..4a87043dccbf 100644 --- a/tools/test/test_cmake.py +++ b/tools/test/test_cmake.py @@ -1,12 +1,13 @@ +from __future__ import annotations + import contextlib import os import typing import unittest import unittest.mock -from typing import Iterator, Optional, Sequence +from typing import Iterator, Sequence import tools.setup_helpers.cmake - import tools.setup_helpers.env # noqa: F401 unused but resolves circular import @@ -79,7 +80,7 @@ class TestCMake(unittest.TestCase): @contextlib.contextmanager -def env_var(key: str, value: Optional[str]) -> Iterator[None]: +def env_var(key: str, value: str | None) -> Iterator[None]: """Sets/clears an environment variable within a Python context.""" # Get the previous value and then override it. 
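For context on the test_cmake.py hunk being retyped here, env_var is a set-and-restore context manager around os.environ. A standalone approximation of that pattern, reconstructed from the visible lines rather than copied from the repo:

from __future__ import annotations

import contextlib
import os
from typing import Iterator


@contextlib.contextmanager
def env_var(key: str, value: str | None) -> Iterator[None]:
    # Remember the previous value, override (or clear) it, and restore on exit.
    previous = os.environ.get(key)
    try:
        if value is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = value
        yield
    finally:
        if previous is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = previous


with env_var("MAX_JOBS", "8"):
    assert os.environ["MAX_JOBS"] == "8"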
previous_value = os.environ.get(key) @@ -91,7 +92,7 @@ def env_var(key: str, value: Optional[str]) -> Iterator[None]: set_env_var(key, previous_value) -def set_env_var(key: str, value: Optional[str]) -> None: +def set_env_var(key: str, value: str | None) -> None: """Sets/clears an environment variable.""" if value is None: os.environ.pop(key, None) diff --git a/tools/test/test_codegen.py b/tools/test/test_codegen.py index 8e11ac8ea054..cefd8aeeded6 100644 --- a/tools/test/test_codegen.py +++ b/tools/test/test_codegen.py @@ -1,14 +1,13 @@ +from __future__ import annotations + import dataclasses import typing import unittest from collections import defaultdict -from typing import Dict, List import yaml - from tools.autograd import gen_autograd_functions, load_derivatives -import torchgen.model from torchgen import dest from torchgen.api.types import CppSignatureGroup, DispatcherSignature from torchgen.context import native_function_manager @@ -22,6 +21,7 @@ from torchgen.model import ( BackendIndex, BackendMetadata, DispatchKey, + FunctionSchema, Location, NativeFunction, OperatorName, @@ -32,7 +32,7 @@ from torchgen.selective_build.selector import SelectiveBuilder class TestCreateDerivative(unittest.TestCase): def test_named_grads(self) -> None: - schema = torchgen.model.FunctionSchema.parse( + schema = FunctionSchema.parse( "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)" ) native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) @@ -47,7 +47,7 @@ class TestCreateDerivative(unittest.TestCase): def test_non_differentiable_output(self) -> None: specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)" - schema = torchgen.model.FunctionSchema.parse(specification) + schema = FunctionSchema.parse(specification) native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) _, differentiability_info = load_derivatives.create_differentiability_info( @@ -69,7 +69,7 @@ class TestCreateDerivative(unittest.TestCase): ) def test_indexed_grads(self) -> None: - schema = torchgen.model.FunctionSchema.parse( + schema = FunctionSchema.parse( "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)" ) native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) @@ -84,7 +84,7 @@ class TestCreateDerivative(unittest.TestCase): def test_named_grads_and_indexed_grads(self) -> None: specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)" - schema = torchgen.model.FunctionSchema.parse(specification) + schema = FunctionSchema.parse(specification) native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) with self.assertRaisesRegex( @@ -112,7 +112,7 @@ class TestCreateDerivative(unittest.TestCase): class TestGenAutogradFunctions(unittest.TestCase): def test_non_differentiable_output_invalid_type(self) -> None: specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)" - schema = torchgen.model.FunctionSchema.parse(specification) + schema = FunctionSchema.parse(specification) native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) _, differentiability_info = load_derivatives.create_differentiability_info( @@ -141,7 +141,7 @@ class TestGenAutogradFunctions(unittest.TestCase): def test_non_differentiable_output_output_differentiability(self) -> None: specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y, Tensor z)" - schema = torchgen.model.FunctionSchema.parse(specification) + schema = FunctionSchema.parse(specification) native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, 
func=schema) _, differentiability_info = load_derivatives.create_differentiability_info( @@ -182,7 +182,7 @@ class TestGenAutogradFunctions(unittest.TestCase): def test_register_bogus_dispatch_key(self) -> None: specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)" - schema = torchgen.model.FunctionSchema.parse(specification) + schema = FunctionSchema.parse(specification) native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema) with self.assertRaisesRegex( @@ -213,17 +213,17 @@ class TestGenAutogradFunctions(unittest.TestCase): class TestGenSchemaRegistration(unittest.TestCase): def setUp(self) -> None: self.selector = SelectiveBuilder.get_nop_selector() - self.custom_native_function, _ = torchgen.model.NativeFunction.from_yaml( + self.custom_native_function, _ = NativeFunction.from_yaml( {"func": "custom::func() -> bool"}, - loc=torchgen.model.Location(__file__, 1), + loc=Location(__file__, 1), valid_tags=set(), ) ( self.fragment_custom_native_function, _, - ) = torchgen.model.NativeFunction.from_yaml( + ) = NativeFunction.from_yaml( {"func": "quantized_decomposed::func() -> bool"}, - loc=torchgen.model.Location(__file__, 1), + loc=Location(__file__, 1), valid_tags=set(), ) @@ -285,9 +285,9 @@ TORCH_LIBRARY(custom, m) { ) def test_3_namespaces_schema_registration_code_valid(self) -> None: - custom2_native_function, _ = torchgen.model.NativeFunction.from_yaml( + custom2_native_function, _ = NativeFunction.from_yaml( {"func": "custom2::func() -> bool"}, - loc=torchgen.model.Location(__file__, 1), + loc=Location(__file__, 1), valid_tags=set(), ) ( @@ -320,7 +320,7 @@ class TestGenNativeFunctionDeclaration(unittest.TestCase): def setUp(self) -> None: self.op_1_native_function, op_1_backend_index = NativeFunction.from_yaml( {"func": "op_1() -> bool", "dispatch": {"CPU": "kernel_1"}}, - loc=torchgen.model.Location(__file__, 1), + loc=Location(__file__, 1), valid_tags=set(), ) self.op_2_native_function, op_2_backend_index = NativeFunction.from_yaml( @@ -328,11 +328,11 @@ class TestGenNativeFunctionDeclaration(unittest.TestCase): "func": "op_2() -> bool", "dispatch": {"CPU": "kernel_2", "QuantizedCPU": "custom::kernel_3"}, }, - loc=torchgen.model.Location(__file__, 1), + loc=Location(__file__, 1), valid_tags=set(), ) - backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = { + backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = { DispatchKey.CPU: {}, DispatchKey.QuantizedCPU: {}, } @@ -382,9 +382,9 @@ TORCH_API bool kernel_1(); # Test for native_function_generation class TestNativeFunctionGeneratrion(unittest.TestCase): def setUp(self) -> None: - self.native_functions: List[NativeFunction] = [] - self.backend_indices: Dict[ - DispatchKey, Dict[OperatorName, BackendMetadata] + self.native_functions: list[NativeFunction] = [] + self.backend_indices: dict[ + DispatchKey, dict[OperatorName, BackendMetadata] ] = defaultdict(dict) yaml_entry = """ - func: op(Tensor self) -> Tensor @@ -405,7 +405,7 @@ class TestNativeFunctionGeneratrion(unittest.TestCase): "dispatch": {"CPU": "kernel_1"}, "autogen": "op_2.out", }, - loc=torchgen.model.Location(__file__, 1), + loc=Location(__file__, 1), valid_tags=set(), ) BackendIndex.grow_index(self.backend_indices, two_returns_backend_index) @@ -442,8 +442,8 @@ class TestNativeFunctionGeneratrion(unittest.TestCase): # Test for static_dispatch class TestStaticDispatchGeneratrion(unittest.TestCase): def setUp(self) -> None: - self.backend_indices: Dict[ - DispatchKey, Dict[OperatorName, 
BackendMetadata] + self.backend_indices: dict[ + DispatchKey, dict[OperatorName, BackendMetadata] ] = defaultdict(dict) yaml_entry = """ - func: op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) @@ -500,9 +500,9 @@ class TestStaticDispatchGeneratrion(unittest.TestCase): # Represents the most basic NativeFunction. Use dataclasses.replace() # to edit for use. -DEFAULT_NATIVE_FUNCTION, _ = torchgen.model.NativeFunction.from_yaml( +DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml( {"func": "func() -> bool"}, - loc=torchgen.model.Location(__file__, 1), + loc=Location(__file__, 1), valid_tags=set(), ) diff --git a/tools/test/test_create_alerts.py b/tools/test/test_create_alerts.py index a9a4ab6deba7..11afebf85573 100644 --- a/tools/test/test_create_alerts.py +++ b/tools/test/test_create_alerts.py @@ -1,4 +1,6 @@ -from typing import Any, List +from __future__ import annotations + +from typing import Any from unittest import main, TestCase from tools.alerts.create_alerts import filter_job_names, JobStatus @@ -38,7 +40,7 @@ MOCK_TEST_DATA = [ class TestGitHubPR(TestCase): # Should fail when jobs are ? ? Fail Fail def test_alert(self) -> None: - modified_data: List[Any] = [{}] + modified_data: list[Any] = [{}] modified_data.append({}) modified_data.extend(MOCK_TEST_DATA) status = JobStatus(JOB_NAME, modified_data) diff --git a/tools/test/test_executorch_custom_ops.py b/tools/test/test_executorch_custom_ops.py index 771ecb7a7a3f..767fe0580b17 100644 --- a/tools/test/test_executorch_custom_ops.py +++ b/tools/test/test_executorch_custom_ops.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import tempfile import unittest -from typing import Any, Dict +from typing import Any from unittest.mock import ANY, Mock, patch import expecttest @@ -13,10 +15,11 @@ from torchgen.model import Location, NativeFunction from torchgen.selective_build.selector import SelectiveBuilder from torchgen.utils import FileManager + SPACES = " " -def _get_native_function_from_yaml(yaml_obj: Dict[str, object]) -> NativeFunction: +def _get_native_function_from_yaml(yaml_obj: dict[str, object]) -> NativeFunction: native_function, _ = NativeFunction.from_yaml( yaml_obj, loc=Location(__file__, 1), @@ -33,7 +36,7 @@ class TestComputeNativeFunctionStub(expecttest.TestCase): """ def _test_function_schema_generates_correct_kernel( - self, obj: Dict[str, Any], expected: str + self, obj: dict[str, Any], expected: str ) -> None: func = _get_native_function_from_yaml(obj) diff --git a/tools/test/test_executorch_gen.py b/tools/test/test_executorch_gen.py index 3123c274a522..3f74bde15952 100644 --- a/tools/test/test_executorch_gen.py +++ b/tools/test/test_executorch_gen.py @@ -1,13 +1,13 @@ +from __future__ import annotations + import os import tempfile import unittest -from typing import Dict import yaml from torchgen.executorch.model import ETKernelIndex, ETKernelKey from torchgen.gen import LineLoader - from torchgen.gen_executorch import ( ComputeCodegenUnboxedKernels, gen_functions_declarations, @@ -24,6 +24,7 @@ from torchgen.model import ( ) from torchgen.selective_build.selector import SelectiveBuilder + TEST_YAML = """ - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 
device_check: NoCheck # TensorIterator @@ -345,7 +346,7 @@ class TestGenFunctionsDeclarations(unittest.TestCase): valid_tags=set(), ) - backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = { + backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = { DispatchKey.CPU: {}, DispatchKey.QuantizedCPU: {}, } diff --git a/tools/test/test_executorch_signatures.py b/tools/test/test_executorch_signatures.py index 543926d4c31e..79f291aba3d2 100644 --- a/tools/test/test_executorch_signatures.py +++ b/tools/test/test_executorch_signatures.py @@ -4,6 +4,7 @@ from torchgen.executorch.api.types import ExecutorchCppSignature from torchgen.local import parametrize from torchgen.model import Location, NativeFunction + DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml( {"func": "foo.out(Tensor input, *, Tensor(a!) out) -> Tensor(a!)"}, loc=Location(__file__, 1), diff --git a/tools/test/test_gen_backend_stubs.py b/tools/test/test_gen_backend_stubs.py index 34bd2fe230fb..e103c573b3ca 100644 --- a/tools/test/test_gen_backend_stubs.py +++ b/tools/test/test_gen_backend_stubs.py @@ -1,9 +1,10 @@ # Owner(s): ["module: codegen"] +from __future__ import annotations + import os import tempfile import unittest -from typing import Optional import expecttest @@ -29,7 +30,7 @@ class TestGenBackendStubs(expecttest.TestCase): run(fp.name, "", True) def get_errors_from_gen_backend_stubs( - self, yaml_str: str, *, kernels_str: Optional[str] = None + self, yaml_str: str, *, kernels_str: str | None = None ) -> str: with tempfile.NamedTemporaryFile(mode="w") as fp: fp.write(yaml_str) diff --git a/tools/test/test_selective_build.py b/tools/test/test_selective_build.py index d4fbea6c7690..59e6e617072e 100644 --- a/tools/test/test_selective_build.py +++ b/tools/test/test_selective_build.py @@ -1,7 +1,7 @@ import unittest -from torchgen.selective_build.operator import * # noqa: F403 from torchgen.model import Location, NativeFunction +from torchgen.selective_build.operator import * # noqa: F403 from torchgen.selective_build.selector import ( combine_selective_builders, SelectiveBuilder, @@ -9,7 +9,7 @@ from torchgen.selective_build.selector import ( class TestSelectiveBuild(unittest.TestCase): - def test_selective_build_operator(self): + def test_selective_build_operator(self) -> None: op = SelectiveBuildOperator( "aten::add.int", is_root_operator=True, @@ -21,7 +21,7 @@ class TestSelectiveBuild(unittest.TestCase): self.assertFalse(op.is_used_for_training) self.assertFalse(op.include_all_overloads) - def test_selector_factory(self): + def test_selector_factory(self) -> None: yaml_config_v1 = """ debug_info: - model1@v100 @@ -132,7 +132,7 @@ operators: selector_legacy_v1.is_operator_selected_for_training("aten::add.float") ) - def test_operator_combine(self): + def test_operator_combine(self) -> None: op1 = SelectiveBuildOperator( "aten::add.int", is_root_operator=True, @@ -177,7 +177,7 @@ operators: self.assertRaises(Exception, gen_new_op) - def test_training_op_fetch(self): + def test_training_op_fetch(self) -> None: yaml_config = """ operators: aten::add.int: @@ -194,7 +194,7 @@ operators: self.assertTrue(selector.is_operator_selected_for_training("aten::add.int")) self.assertTrue(selector.is_operator_selected_for_training("aten::add")) - def test_kernel_dtypes(self): + def test_kernel_dtypes(self) -> None: yaml_config = """ kernel_metadata: add_kernel: @@ -221,7 +221,7 @@ kernel_metadata: self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16")) 
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32")) - def test_merge_kernel_dtypes(self): + def test_merge_kernel_dtypes(self) -> None: yaml_config1 = """ kernel_metadata: add_kernel: @@ -266,7 +266,7 @@ kernel_metadata: self.assertTrue(selector.is_kernel_dtype_selected("mul_kernel", "int8")) self.assertFalse(selector.is_kernel_dtype_selected("mul_kernel", "int32")) - def test_all_kernel_dtypes_selected(self): + def test_all_kernel_dtypes_selected(self) -> None: yaml_config = """ include_all_non_op_selectives: True """ @@ -279,7 +279,7 @@ include_all_non_op_selectives: True self.assertTrue(selector.is_kernel_dtype_selected("add1_kernel", "int32")) self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "float")) - def test_custom_namespace_selected_correctly(self): + def test_custom_namespace_selected_correctly(self) -> None: yaml_config = """ operators: aten::add.int: @@ -301,7 +301,7 @@ operators: class TestExecuTorchSelectiveBuild(unittest.TestCase): - def test_et_kernel_selected(self): + def test_et_kernel_selected(self) -> None: yaml_config = """ et_kernel_metadata: aten::add.out: diff --git a/tools/test/test_test_selections.py b/tools/test/test_test_selections.py index cc9bf5f4435d..2fd59545414f 100644 --- a/tools/test/test_test_selections.py +++ b/tools/test/test_test_selections.py @@ -1,10 +1,11 @@ +from __future__ import annotations + import functools import pathlib import random import sys import unittest from collections import defaultdict -from typing import Dict, List, Tuple REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent try: @@ -17,12 +18,12 @@ except ModuleNotFoundError: sys.exit(1) -def gen_class_times(test_times: Dict[str, float]) -> Dict[str, Dict[str, float]]: +def gen_class_times(test_times: dict[str, float]) -> dict[str, dict[str, float]]: return {k: {"class1": v} for k, v in test_times.items()} class TestCalculateShards(unittest.TestCase): - tests: List[TestRun] = [ + tests: list[TestRun] = [ TestRun("super_long_test"), TestRun("long_test1"), TestRun("long_test2"), @@ -36,7 +37,7 @@ class TestCalculateShards(unittest.TestCase): TestRun("short_test5"), ] - test_times: Dict[str, float] = { + test_times: dict[str, float] = { "super_long_test": 55, "long_test1": 22, "long_test2": 18, @@ -50,7 +51,7 @@ class TestCalculateShards(unittest.TestCase): "short_test5": 0.01, } - test_class_times: Dict[str, Dict[str, float]] = { + test_class_times: dict[str, dict[str, float]] = { "super_long_test": {"class1": 55}, "long_test1": {"class1": 1, "class2": 21}, "long_test2": {"class1": 10, "class2": 8}, @@ -66,8 +67,8 @@ class TestCalculateShards(unittest.TestCase): def assert_shards_equal( self, - expected_shards: List[Tuple[float, List[ShardedTest]]], - actual_shards: List[Tuple[float, List[ShardedTest]]], + expected_shards: list[tuple[float, list[ShardedTest]]], + actual_shards: list[tuple[float, list[ShardedTest]]], ) -> None: for expected, actual in zip(expected_shards, actual_shards): self.assertAlmostEqual(expected[0], actual[0]) @@ -363,7 +364,7 @@ class TestCalculateShards(unittest.TestCase): ) def test_split_shards(self) -> None: - test_times: Dict[str, float] = {"test1": THRESHOLD, "test2": THRESHOLD} + test_times: dict[str, float] = {"test1": THRESHOLD, "test2": THRESHOLD} expected_shards = [ (600.0, [ShardedTest(test="test1", shard=1, num_shards=1, time=THRESHOLD)]), (600.0, [ShardedTest(test="test2", shard=1, num_shards=1, time=THRESHOLD)]), @@ -438,7 +439,7 @@ class TestCalculateShards(unittest.TestCase): tests = 
[TestRun(x) for x in test_names] serial = [x for x in test_names if random.randint(0, 1) == 0] has_times = [x for x in test_names if random.randint(0, 1) == 0] - random_times: Dict[str, float] = { + random_times: dict[str, float] = { i: random.randint(0, THRESHOLD * 10) for i in has_times } sort_by_time = random.randint(0, 1) == 0 @@ -456,7 +457,7 @@ class TestCalculateShards(unittest.TestCase): max_diff = max(times) - min(times) self.assertTrue(max_diff <= THRESHOLD + (num_tests - len(has_times)) * 60) - all_sharded_tests: Dict[str, List[ShardedTest]] = defaultdict(list) + all_sharded_tests: dict[str, list[ShardedTest]] = defaultdict(list) for _, sharded_tests in shards: for sharded_test in sharded_tests: all_sharded_tests[sharded_test.name].append(sharded_test) diff --git a/tools/test/test_upload_stats_lib.py b/tools/test/test_upload_stats_lib.py index 0baf323966e1..6642aeaa4211 100644 --- a/tools/test/test_upload_stats_lib.py +++ b/tools/test/test_upload_stats_lib.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import decimal import inspect import pathlib import sys import unittest -from typing import Any, Dict +from typing import Any from unittest import mock REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent @@ -81,9 +83,9 @@ class TestUploadStats(unittest.TestCase): } # Preserve the metric emitted - emitted_metric: Dict[str, Any] = {} + emitted_metric: dict[str, Any] = {} - def mock_put_item(Item: Dict[str, Any]) -> None: + def mock_put_item(Item: dict[str, Any]) -> None: nonlocal emitted_metric emitted_metric = Item @@ -115,9 +117,9 @@ class TestUploadStats(unittest.TestCase): } # Preserve the metric emitted - emitted_metric: Dict[str, Any] = {} + emitted_metric: dict[str, Any] = {} - def mock_put_item(Item: Dict[str, Any]) -> None: + def mock_put_item(Item: dict[str, Any]) -> None: nonlocal emitted_metric emitted_metric = Item @@ -151,9 +153,9 @@ class TestUploadStats(unittest.TestCase): } # Preserve the metric emitted - emitted_metric: Dict[str, Any] = {} + emitted_metric: dict[str, Any] = {} - def mock_put_item(Item: Dict[str, Any]) -> None: + def mock_put_item(Item: dict[str, Any]) -> None: nonlocal emitted_metric emitted_metric = Item @@ -187,9 +189,9 @@ class TestUploadStats(unittest.TestCase): ).start() # Preserve the metric emitted - emitted_metric: Dict[str, Any] = {} + emitted_metric: dict[str, Any] = {} - def mock_put_item(Item: Dict[str, Any]) -> None: + def mock_put_item(Item: dict[str, Any]) -> None: nonlocal emitted_metric emitted_metric = Item @@ -208,7 +210,7 @@ class TestUploadStats(unittest.TestCase): ) -> None: metric = {"some_number": 123} - emit_should_include: Dict[str, Any] = metric.copy() + emit_should_include: dict[str, Any] = metric.copy() # Github Actions defaults some env vars to an empty string default_val = "" @@ -220,9 +222,9 @@ class TestUploadStats(unittest.TestCase): ).start() # Preserve the metric emitted - emitted_metric: Dict[str, Any] = {} + emitted_metric: dict[str, Any] = {} - def mock_put_item(Item: Dict[str, Any]) -> None: + def mock_put_item(Item: dict[str, Any]) -> None: nonlocal emitted_metric emitted_metric = Item @@ -264,7 +266,7 @@ class TestUploadStats(unittest.TestCase): put_item_invoked = False - def mock_put_item(Item: Dict[str, Any]) -> None: + def mock_put_item(Item: dict[str, Any]) -> None: nonlocal put_item_invoked put_item_invoked = True @@ -289,7 +291,7 @@ class TestUploadStats(unittest.TestCase): put_item_invoked = False - def mock_put_item(Item: Dict[str, Any]) -> None: + def mock_put_item(Item: 
dict[str, Any]) -> None: nonlocal put_item_invoked put_item_invoked = True diff --git a/tools/test/test_upload_test_stats.py b/tools/test/test_upload_test_stats.py index be27d341576c..4bf49ff936a8 100644 --- a/tools/test/test_upload_test_stats.py +++ b/tools/test/test_upload_test_stats.py @@ -3,6 +3,7 @@ import unittest from tools.stats.upload_test_stats import get_tests, summarize_test_cases + IN_CI = os.environ.get("CI") diff --git a/tools/test/test_vulkan_codegen.py b/tools/test/test_vulkan_codegen.py index 0726a3efb9d5..465612c95549 100644 --- a/tools/test/test_vulkan_codegen.py +++ b/tools/test/test_vulkan_codegen.py @@ -3,6 +3,7 @@ import unittest from tools.gen_vulkan_spv import DEFAULT_ENV, SPVGenerator + #################### # Data for testing # #################### diff --git a/tools/testing/discover_tests.py b/tools/testing/discover_tests.py index 5c37fdd4fe9d..6c505ad65e54 100644 --- a/tools/testing/discover_tests.py +++ b/tools/testing/discover_tests.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import glob import os import sys from pathlib import Path -from typing import List, Optional, Union CPP_TEST_PREFIX = "cpp" CPP_TEST_PATH = "build/bin" @@ -16,11 +17,11 @@ def parse_test_module(test: str) -> str: def discover_tests( base_dir: Path = REPO_ROOT / "test", - cpp_tests_dir: Optional[Union[str, Path]] = None, - blocklisted_patterns: Optional[List[str]] = None, - blocklisted_tests: Optional[List[str]] = None, - extra_tests: Optional[List[str]] = None, -) -> List[str]: + cpp_tests_dir: str | Path | None = None, + blocklisted_patterns: list[str] | None = None, + blocklisted_tests: list[str] | None = None, + extra_tests: list[str] | None = None, +) -> list[str]: """ Searches for all python files starting with test_ excluding one specified by patterns. If cpp_tests_dir is provided, also scan for all C++ tests under that directory. They diff --git a/tools/testing/explicit_ci_jobs.py b/tools/testing/explicit_ci_jobs.py index 594e00d437f9..b81e2cb6215a 100755 --- a/tools/testing/explicit_ci_jobs.py +++ b/tools/testing/explicit_ci_jobs.py @@ -1,12 +1,13 @@ #!/usr/bin/env python3 +from __future__ import annotations + import argparse import fnmatch import pathlib import subprocess import textwrap - -from typing import Any, Dict, List +from typing import Any import yaml @@ -29,11 +30,11 @@ WORKFLOWS_TO_CHECK = [ def add_job( - workflows: Dict[str, Any], + workflows: dict[str, Any], workflow_name: str, type: str, - job: Dict[str, Any], - past_jobs: Dict[str, Any], + job: dict[str, Any], + past_jobs: dict[str, Any], ) -> None: """ Add job 'job' under 'type' and 'workflow_name' to 'workflow' in place. 
Also @@ -58,14 +59,14 @@ def add_job( def get_filtered_circleci_config( - workflows: Dict[str, Any], relevant_jobs: List[str] -) -> Dict[str, Any]: + workflows: dict[str, Any], relevant_jobs: list[str] +) -> dict[str, Any]: """ Given an existing CircleCI config, remove every job that's not listed in 'relevant_jobs' """ - new_workflows: Dict[str, Any] = {} - past_jobs: Dict[str, Any] = {} + new_workflows: dict[str, Any] = {} + past_jobs: dict[str, Any] = {} for workflow_name, workflow in workflows.items(): if workflow_name not in WORKFLOWS_TO_CHECK: # Don't care about this workflow, skip it entirely @@ -92,7 +93,7 @@ def get_filtered_circleci_config( return new_workflows -def commit_ci(files: List[str], message: str) -> None: +def commit_ci(files: list[str], message: str) -> None: # Check that there are no other modified files than the ones edited by this # tool stdout = subprocess.run( diff --git a/tools/testing/modulefinder_determinator.py b/tools/testing/modulefinder_determinator.py index ba58d75c57fe..eba68d78b16b 100644 --- a/tools/testing/modulefinder_determinator.py +++ b/tools/testing/modulefinder_determinator.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import modulefinder import os import pathlib import sys import warnings -from typing import Any, Dict, List, Set +from typing import Any REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent @@ -51,11 +53,11 @@ TARGET_DET_LIST = [ ] -_DEP_MODULES_CACHE: Dict[str, Set[str]] = {} +_DEP_MODULES_CACHE: dict[str, set[str]] = {} def should_run_test( - target_det_list: List[str], test: str, touched_files: List[str], options: Any + target_det_list: list[str], test: str, touched_files: list[str], options: Any ) -> bool: test = parse_test_module(test) # Some tests are faster to execute than to determine. 
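modulefinder_determinator.py memoizes its per-test dependency sets in a module-level dict, which is what the _DEP_MODULES_CACHE annotation change touches. A stripped-down sketch of that caching, with the real modulefinder walk replaced by a placeholder:

from __future__ import annotations

_DEP_MODULES_CACHE: dict[str, set[str]] = {}


def get_dep_modules(test: str) -> set[str]:
    # Compute once per test file, reuse on every later call.
    if test in _DEP_MODULES_CACHE:
        return _DEP_MODULES_CACHE[test]
    deps = {test, "torch"}  # placeholder for the real dependency walk
    _DEP_MODULES_CACHE[test] = deps
    return deps


print(get_dep_modules("test_nn"))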
@@ -139,7 +141,7 @@ def log_test_reason(file_type: str, filename: str, test: str, options: Any) -> N ) -def get_dep_modules(test: str) -> Set[str]: +def get_dep_modules(test: str) -> set[str]: # Cache results in case of repetition if test in _DEP_MODULES_CACHE: return _DEP_MODULES_CACHE[test] diff --git a/tools/testing/target_determination/determinator.py b/tools/testing/target_determination/determinator.py index 17320a73a194..ff65251945ed 100644 --- a/tools/testing/target_determination/determinator.py +++ b/tools/testing/target_determination/determinator.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import sys -from typing import Any, List +from typing import Any from tools.testing.target_determination.heuristics import ( AggregatedHeuristics as AggregatedHeuristics, @@ -9,7 +11,7 @@ from tools.testing.target_determination.heuristics import ( def get_test_prioritizations( - tests: List[str], file: Any = sys.stdout + tests: list[str], file: Any = sys.stdout ) -> AggregatedHeuristics: aggregated_results = AggregatedHeuristics(tests) print(f"Received {len(tests)} tests to prioritize", file=file) diff --git a/tools/testing/target_determination/gen_artifact.py b/tools/testing/target_determination/gen_artifact.py index c5165cbb8108..f69924c451ff 100644 --- a/tools/testing/target_determination/gen_artifact.py +++ b/tools/testing/target_determination/gen_artifact.py @@ -1,12 +1,14 @@ +from __future__ import annotations + import json import os import pathlib -from typing import Any, List +from typing import Any REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent -def gen_ci_artifact(included: List[Any], excluded: List[Any]) -> None: +def gen_ci_artifact(included: list[Any], excluded: list[Any]) -> None: file_name = f"td_exclusions-{os.urandom(10).hex()}.json" with open(REPO_ROOT / "test" / "test-reports" / file_name, "w") as f: json.dump({"included": included, "excluded": excluded}, f) diff --git a/tools/testing/target_determination/heuristics/__init__.py b/tools/testing/target_determination/heuristics/__init__.py index 62b92a15edfe..711f40c7dab1 100644 --- a/tools/testing/target_determination/heuristics/__init__.py +++ b/tools/testing/target_determination/heuristics/__init__.py @@ -1,4 +1,6 @@ -from typing import List, Tuple +from __future__ import annotations + +from typing import TYPE_CHECKING from tools.testing.target_determination.heuristics.correlated_with_historical_failures import ( CorrelatedWithHistoricalFailures, @@ -11,23 +13,27 @@ from tools.testing.target_determination.heuristics.historical_class_failure_corr from tools.testing.target_determination.heuristics.historical_edited_files import ( HistorialEditedFiles, ) - from tools.testing.target_determination.heuristics.interface import ( AggregatedHeuristics as AggregatedHeuristics, - HeuristicInterface as HeuristicInterface, TestPrioritizations as TestPrioritizations, ) from tools.testing.target_determination.heuristics.llm import LLM from tools.testing.target_determination.heuristics.mentioned_in_pr import MentionedInPR - from tools.testing.target_determination.heuristics.previously_failed_in_pr import ( PreviouslyFailedInPR, ) from tools.testing.target_determination.heuristics.profiling import Profiling + +if TYPE_CHECKING: + from tools.testing.target_determination.heuristics.interface import ( + HeuristicInterface as HeuristicInterface, + ) + + # All currently running heuristics. # To add a heurstic in trial mode, specify the keywork argument `trial_mode=True`. 
-HEURISTICS: List[HeuristicInterface] = [ +HEURISTICS: list[HeuristicInterface] = [ PreviouslyFailedInPR(), EditedByPR(), MentionedInPR(), diff --git a/tools/testing/target_determination/heuristics/correlated_with_historical_failures.py b/tools/testing/target_determination/heuristics/correlated_with_historical_failures.py index 590ca468ae95..1fe93a3ef601 100644 --- a/tools/testing/target_determination/heuristics/correlated_with_historical_failures.py +++ b/tools/testing/target_determination/heuristics/correlated_with_historical_failures.py @@ -1,15 +1,15 @@ -from typing import Any, Dict, List +from __future__ import annotations + +from typing import Any from tools.stats.import_test_stats import ( ADDITIONAL_CI_FILES_FOLDER, TEST_FILE_RATINGS_FILE, ) - from tools.testing.target_determination.heuristics.interface import ( HeuristicInterface, TestPrioritizations, ) - from tools.testing.target_determination.heuristics.utils import ( get_ratings_for_tests, normalize_ratings, @@ -18,10 +18,10 @@ from tools.testing.test_run import TestRun class CorrelatedWithHistoricalFailures(HeuristicInterface): - def __init__(self, **kwargs: Dict[str, Any]): + def __init__(self, **kwargs: dict[str, Any]) -> None: super().__init__(**kwargs) - def get_prediction_confidence(self, tests: List[str]) -> TestPrioritizations: + def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations: test_ratings = get_ratings_for_tests( ADDITIONAL_CI_FILES_FOLDER / TEST_FILE_RATINGS_FILE ) diff --git a/tools/testing/target_determination/heuristics/edited_by_pr.py b/tools/testing/target_determination/heuristics/edited_by_pr.py index d0a473db78ae..b21235365215 100644 --- a/tools/testing/target_determination/heuristics/edited_by_pr.py +++ b/tools/testing/target_determination/heuristics/edited_by_pr.py @@ -1,4 +1,6 @@ -from typing import Any, Dict, List, Set +from __future__ import annotations + +from typing import Any from warnings import warn from tools.testing.target_determination.heuristics.interface import ( @@ -13,17 +15,17 @@ from tools.testing.test_run import TestRun class EditedByPR(HeuristicInterface): - def __init__(self, **kwargs: Dict[str, Any]): + def __init__(self, **kwargs: dict[str, Any]) -> None: super().__init__(**kwargs) - def get_prediction_confidence(self, tests: List[str]) -> TestPrioritizations: + def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations: critical_tests = _get_modified_tests() return TestPrioritizations( tests, {TestRun(test): 1 for test in critical_tests if test in tests} ) -def _get_modified_tests() -> Set[str]: +def _get_modified_tests() -> set[str]: try: changed_files = query_changed_files() except Exception as e: diff --git a/tools/testing/target_determination/heuristics/filepath.py b/tools/testing/target_determination/heuristics/filepath.py index 066d13f49b1e..31f6cd71b25a 100644 --- a/tools/testing/target_determination/heuristics/filepath.py +++ b/tools/testing/target_determination/heuristics/filepath.py @@ -1,7 +1,9 @@ +from __future__ import annotations + from collections import defaultdict from functools import lru_cache from pathlib import Path -from typing import Any, Callable, Dict, List +from typing import Any, Callable from warnings import warn from tools.testing.target_determination.heuristics.interface import ( @@ -17,7 +19,7 @@ from tools.testing.test_run import TestRun REPO_ROOT = Path(__file__).parent.parent.parent.parent -keyword_synonyms: Dict[str, List[str]] = { +keyword_synonyms: dict[str, list[str]] = { "amp": ["mixed_precision"], 
"quant": ["quantized", "quantization", "quantize"], "decomp": ["decomposition", "decompositions"], @@ -39,14 +41,14 @@ not_keyword = [ "internal", ] -custom_matchers: Dict[str, Callable[[str], bool]] = { +custom_matchers: dict[str, Callable[[str], bool]] = { "nn": lambda x: "nn" in x.replace("onnx", "_"), "c10": lambda x: "c10" in x.replace("c10d", "_"), } @lru_cache(maxsize=1) -def get_keywords(file: str) -> List[str]: +def get_keywords(file: str) -> list[str]: keywords = [] for folder in Path(file).parts[:-1]: folder = sanitize_folder_name(folder) @@ -79,11 +81,11 @@ def file_matches_keyword(file: str, keyword: str) -> bool: class Filepath(HeuristicInterface): # Heuristic based on folders in the file path. Takes each folder of each # changed file and attempts to find matches based on those folders - def __init__(self, **kwargs: Dict[str, Any]) -> None: + def __init__(self, **kwargs: dict[str, Any]) -> None: super().__init__(**kwargs) - def get_prediction_confidence(self, tests: List[str]) -> TestPrioritizations: - keyword_frequency: Dict[str, int] = defaultdict(int) + def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations: + keyword_frequency: dict[str, int] = defaultdict(int) try: changed_files = query_changed_files() except Exception as e: @@ -95,7 +97,7 @@ class Filepath(HeuristicInterface): for keyword in keywords: keyword_frequency[keyword] += 1 - test_ratings: Dict[str, float] = defaultdict(float) + test_ratings: dict[str, float] = defaultdict(float) for test in tests: for keyword, frequency in keyword_frequency.items(): diff --git a/tools/testing/target_determination/heuristics/historical_class_failure_correlation.py b/tools/testing/target_determination/heuristics/historical_class_failure_correlation.py index 0a3097607af3..b489a65326dd 100644 --- a/tools/testing/target_determination/heuristics/historical_class_failure_correlation.py +++ b/tools/testing/target_determination/heuristics/historical_class_failure_correlation.py @@ -1,19 +1,19 @@ +from __future__ import annotations + import json import os from collections import defaultdict -from typing import Any, cast, Dict, List, Set +from typing import Any, cast, Dict from warnings import warn from tools.stats.import_test_stats import ( ADDITIONAL_CI_FILES_FOLDER, TEST_CLASS_RATINGS_FILE, ) - from tools.testing.target_determination.heuristics.interface import ( HeuristicInterface, TestPrioritizations, ) - from tools.testing.target_determination.heuristics.utils import ( normalize_ratings, query_changed_files, @@ -28,10 +28,10 @@ class HistoricalClassFailurCorrelation(HeuristicInterface): when the files edited by current PR were modified. 
""" - def __init__(self, **kwargs: Any): + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - def get_prediction_confidence(self, tests: List[str]) -> TestPrioritizations: + def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations: ratings = _get_ratings_for_tests(set(tests)) test_ratings = { TestRun(k): v for (k, v) in ratings.items() if TestRun(k).test_file in tests @@ -39,7 +39,7 @@ class HistoricalClassFailurCorrelation(HeuristicInterface): return TestPrioritizations(tests, normalize_ratings(test_ratings, 0.25)) -def _get_historical_test_class_correlations() -> Dict[str, Dict[str, float]]: +def _get_historical_test_class_correlations() -> dict[str, dict[str, float]]: path = REPO_ROOT / ADDITIONAL_CI_FILES_FOLDER / TEST_CLASS_RATINGS_FILE if not os.path.exists(path): print(f"could not find path {path}") @@ -50,8 +50,8 @@ def _get_historical_test_class_correlations() -> Dict[str, Dict[str, float]]: def _get_ratings_for_tests( - tests_to_run: Set[str], -) -> Dict[str, float]: + tests_to_run: set[str], +) -> dict[str, float]: # Get the files edited try: changed_files = query_changed_files() @@ -65,7 +65,7 @@ def _get_ratings_for_tests( # Find the tests failures that are correlated with the edited files. # Filter the list to only include tests we want to run. - ratings: Dict[str, float] = defaultdict(float) + ratings: dict[str, float] = defaultdict(float) for file in changed_files: for qualified_test_class, score in test_class_correlations.get( file, {} @@ -79,8 +79,8 @@ def _get_ratings_for_tests( def _rank_correlated_tests( - tests_to_run: List[str], -) -> List[str]: + tests_to_run: list[str], +) -> list[str]: # Find the tests failures that are correlated with the edited files. # Filter the list to only include tests we want to run. tests_to_run = set(tests_to_run) diff --git a/tools/testing/target_determination/heuristics/historical_edited_files.py b/tools/testing/target_determination/heuristics/historical_edited_files.py index c3b7f07719d2..be855538ca25 100644 --- a/tools/testing/target_determination/heuristics/historical_edited_files.py +++ b/tools/testing/target_determination/heuristics/historical_edited_files.py @@ -1,15 +1,15 @@ -from typing import Any, List +from __future__ import annotations + +from typing import Any from tools.stats.import_test_stats import ( ADDITIONAL_CI_FILES_FOLDER, TD_HEURISTIC_HISTORICAL_EDITED_FILES, ) - from tools.testing.target_determination.heuristics.interface import ( HeuristicInterface, TestPrioritizations, ) - from tools.testing.target_determination.heuristics.utils import ( get_ratings_for_tests, normalize_ratings, @@ -23,10 +23,10 @@ from tools.testing.test_run import TestRun # future commits that change fileA should probably run testFileA. Based on this, # a correlation dict is built based on what files were edited in commits on main. 
class HistorialEditedFiles(HeuristicInterface): - def __init__(self, **kwargs: Any): + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - def get_prediction_confidence(self, tests: List[str]) -> TestPrioritizations: + def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations: test_ratings = get_ratings_for_tests( ADDITIONAL_CI_FILES_FOLDER / TD_HEURISTIC_HISTORICAL_EDITED_FILES ) diff --git a/tools/testing/target_determination/heuristics/interface.py b/tools/testing/target_determination/heuristics/interface.py index 77052e6bba0d..8df3f786e92c 100644 --- a/tools/testing/target_determination/heuristics/interface.py +++ b/tools/testing/target_determination/heuristics/interface.py @@ -1,6 +1,8 @@ +from __future__ import annotations + from abc import abstractmethod from copy import copy -from typing import Any, Dict, FrozenSet, Iterable, Iterator, List, Tuple +from typing import Any, Iterable, Iterator from tools.testing.test_run import TestRun @@ -18,13 +20,13 @@ class TestPrioritizations: otherwise it breaks the test sharding logic """ - _original_tests: FrozenSet[str] - _test_scores: Dict[TestRun, float] + _original_tests: frozenset[str] + _test_scores: dict[TestRun, float] def __init__( self, tests_being_ranked: Iterable[str], # The tests that are being prioritized. - scores: Dict[TestRun, float], + scores: dict[TestRun, float], ) -> None: self._original_tests = frozenset(tests_being_ranked) self._test_scores = {TestRun(test): 0.0 for test in self._original_tests} @@ -59,7 +61,7 @@ class TestPrioritizations: files.keys() ), "The set of tests in the TestPrioritizations must be identical to the set of tests passed in" - def _traverse_scores(self) -> Iterator[Tuple[float, TestRun]]: + def _traverse_scores(self) -> Iterator[tuple[float, TestRun]]: # Sort by score, then alphabetically by test name for test, score in sorted( self._test_scores.items(), key=lambda x: (-x[1], str(x[0])) @@ -70,7 +72,7 @@ class TestPrioritizations: if test_run.test_file not in self._original_tests: return # We don't need this test - relevant_test_runs: List[TestRun] = [ + relevant_test_runs: list[TestRun] = [ tr for tr in self._test_scores.keys() if tr & test_run and tr != test_run ] @@ -90,7 +92,7 @@ class TestPrioritizations: if test_run.test_file not in self._original_tests: return - relevant_test_runs: List[TestRun] = [ + relevant_test_runs: list[TestRun] = [ tr for tr in self._test_scores.keys() if tr & test_run ] @@ -108,11 +110,11 @@ class TestPrioritizations: self.validate() - def get_all_tests(self) -> List[TestRun]: + def get_all_tests(self) -> list[TestRun]: """Returns all tests in the TestPrioritizations""" return [x[1] for x in self._traverse_scores()] - def get_top_per_tests(self, n: int) -> Tuple[List[TestRun], List[TestRun]]: + def get_top_per_tests(self, n: int) -> tuple[list[TestRun], list[TestRun]]: """Divides list of tests into two based on the top n% of scores. 
The first list is the top, and the second is the rest.""" tests = [x[1] for x in self._traverse_scores()] @@ -132,7 +134,7 @@ class TestPrioritizations: def print_info(self) -> None: print(self.get_info_str()) - def get_priority_info_for_test(self, test_run: TestRun) -> Dict[str, Any]: + def get_priority_info_for_test(self, test_run: TestRun) -> dict[str, Any]: """Given a failing test, returns information about it's prioritization that we want to emit in our metrics.""" for idx, (score, test) in enumerate(self._traverse_scores()): # Different heuristics may result in a given test file being split @@ -142,7 +144,7 @@ class TestPrioritizations: return {"position": idx, "score": score} raise AssertionError(f"Test run {test_run} not found") - def get_test_stats(self, test: TestRun) -> Dict[str, Any]: + def get_test_stats(self, test: TestRun) -> dict[str, Any]: return { "test_name": test.test_file, "test_filters": test.get_pytest_filter(), @@ -154,7 +156,7 @@ class TestPrioritizations: }, } - def to_json(self) -> Dict[str, Any]: + def to_json(self) -> dict[str, Any]: """ Returns a JSON dict that describes this TestPrioritizations object. """ @@ -169,7 +171,7 @@ class TestPrioritizations: return json_dict @staticmethod - def from_json(json_dict: Dict[str, Any]) -> "TestPrioritizations": + def from_json(json_dict: dict[str, Any]) -> TestPrioritizations: """ Returns a TestPrioritizations object from a JSON dict. """ @@ -182,7 +184,7 @@ class TestPrioritizations: ) return test_prioritizations - def amend_tests(self, tests: List[str]) -> None: + def amend_tests(self, tests: list[str]) -> None: """ Removes tests that are not in the given list from the TestPrioritizations. Adds tests that are in the list but not in the @@ -210,13 +212,13 @@ class AggregatedHeuristics: It saves the individual results from each heuristic and exposes an aggregated view. """ - _heuristic_results: Dict[ - "HeuristicInterface", TestPrioritizations + _heuristic_results: dict[ + HeuristicInterface, TestPrioritizations ] # Key is the Heuristic's name. Dicts will preserve the order of insertion, which is important for sharding - _all_tests: FrozenSet[str] + _all_tests: frozenset[str] - def __init__(self, all_tests: List[str]) -> None: + def __init__(self, all_tests: list[str]) -> None: self._all_tests = frozenset(all_tests) self._heuristic_results = {} self.validate() @@ -229,7 +231,7 @@ class AggregatedHeuristics: ), f"Tests in {heuristic.name} are not the same as the tests in the AggregatedHeuristics" def add_heuristic_results( - self, heuristic: "HeuristicInterface", heuristic_results: TestPrioritizations + self, heuristic: HeuristicInterface, heuristic_results: TestPrioritizations ) -> None: if heuristic in self._heuristic_results: raise ValueError(f"We already have heuristics for {heuristic.name}") @@ -257,11 +259,11 @@ class AggregatedHeuristics: new_tp.validate() return new_tp - def get_test_stats(self, test: TestRun) -> Dict[str, Any]: + def get_test_stats(self, test: TestRun) -> dict[str, Any]: """ Returns the aggregated statistics for a given test. """ - stats: Dict[str, Any] = { + stats: dict[str, Any] = { "test_name": test.test_file, "test_filters": test.get_pytest_filter(), } @@ -287,11 +289,11 @@ class AggregatedHeuristics: return stats - def to_json(self) -> Dict[str, Any]: + def to_json(self) -> dict[str, Any]: """ Returns a JSON dict that describes this AggregatedHeuristics object. 
""" - json_dict: Dict[str, Any] = {} + json_dict: dict[str, Any] = {} for heuristic, heuristic_results in self._heuristic_results.items(): json_dict[heuristic.name] = heuristic_results.to_json() @@ -321,7 +323,7 @@ class HeuristicInterface: return self.name @abstractmethod - def get_prediction_confidence(self, tests: List[str]) -> TestPrioritizations: + def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations: """ Returns a float ranking ranging from -1 to 1, where negative means skip, positive means run, 0 means no idea, and magnitude = how confident the diff --git a/tools/testing/target_determination/heuristics/llm.py b/tools/testing/target_determination/heuristics/llm.py index d3021d93b1f0..b046f96dafbb 100644 --- a/tools/testing/target_determination/heuristics/llm.py +++ b/tools/testing/target_determination/heuristics/llm.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import json import os import re from collections import defaultdict from pathlib import Path -from typing import Any, Dict, List +from typing import Any from tools.stats.import_test_stats import ADDITIONAL_CI_FILES_FOLDER from tools.testing.target_determination.heuristics.interface import ( @@ -18,10 +20,10 @@ REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent.parent class LLM(HeuristicInterface): - def __init__(self, **kwargs: Dict[str, Any]): + def __init__(self, **kwargs: dict[str, Any]) -> None: super().__init__(**kwargs) - def get_prediction_confidence(self, tests: List[str]) -> TestPrioritizations: + def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations: critical_tests = self.get_mappings() filter_valid_tests = { TestRun(test): score @@ -31,7 +33,7 @@ class LLM(HeuristicInterface): normalized_scores = normalize_ratings(filter_valid_tests, 0.25) return TestPrioritizations(tests, normalized_scores) - def get_mappings(self) -> Dict[str, float]: + def get_mappings(self) -> dict[str, float]: path = ( REPO_ROOT / ADDITIONAL_CI_FILES_FOLDER diff --git a/tools/testing/target_determination/heuristics/mentioned_in_pr.py b/tools/testing/target_determination/heuristics/mentioned_in_pr.py index 074a375e2dde..66da4e42d80f 100644 --- a/tools/testing/target_determination/heuristics/mentioned_in_pr.py +++ b/tools/testing/target_determination/heuristics/mentioned_in_pr.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import re -from typing import Any, List +from typing import Any from tools.testing.target_determination.heuristics.interface import ( HeuristicInterface, @@ -19,13 +21,13 @@ from tools.testing.test_run import TestRun # body, test_foo will be rated 1. If I mention #123 in the PR body, and #123 # mentions "test_foo", test_foo will be rated 1. 
class MentionedInPR(HeuristicInterface): - def __init__(self, **kwargs: Any): + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - def _search_for_linked_issues(self, s: str) -> List[str]: + def _search_for_linked_issues(self, s: str) -> list[str]: return re.findall(r"#(\d+)", s) + re.findall(r"/pytorch/pytorch/.*/(\d+)", s) - def get_prediction_confidence(self, tests: List[str]) -> TestPrioritizations: + def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations: try: commit_messages = get_git_commit_info() except Exception as e: @@ -42,7 +44,7 @@ class MentionedInPR(HeuristicInterface): pr_body = "" # Search for linked issues or PRs - linked_issue_bodies: List[str] = [] + linked_issue_bodies: list[str] = [] for issue in self._search_for_linked_issues( commit_messages ) + self._search_for_linked_issues(pr_body): diff --git a/tools/testing/target_determination/heuristics/previously_failed_in_pr.py b/tools/testing/target_determination/heuristics/previously_failed_in_pr.py index 26439f2cc4e7..ed8486227e1f 100644 --- a/tools/testing/target_determination/heuristics/previously_failed_in_pr.py +++ b/tools/testing/target_determination/heuristics/previously_failed_in_pr.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import json import os from pathlib import Path -from typing import Any, Dict, List, Set +from typing import Any from tools.stats.import_test_stats import ( ADDITIONAL_CI_FILES_FOLDER, @@ -22,17 +24,17 @@ REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent.parent class PreviouslyFailedInPR(HeuristicInterface): - def __init__(self, **kwargs: Dict[str, Any]): + def __init__(self, **kwargs: dict[str, Any]) -> None: super().__init__(**kwargs) - def get_prediction_confidence(self, tests: List[str]) -> TestPrioritizations: + def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations: critical_tests = get_previous_failures() | read_additional_test_failures_file() return TestPrioritizations( tests, {TestRun(test): 1 for test in critical_tests if test in tests} ) -def get_previous_failures() -> Set[str]: +def get_previous_failures() -> set[str]: path = REPO_ROOT / ADDITIONAL_CI_FILES_FOLDER / TD_HEURISTIC_PREVIOUSLY_FAILED if not os.path.exists(path): print(f"could not find path {path}") @@ -43,7 +45,7 @@ def get_previous_failures() -> Set[str]: ) -def _parse_prev_failing_test_files(last_failed_tests: Dict[str, bool]) -> Set[str]: +def _parse_prev_failing_test_files(last_failed_tests: dict[str, bool]) -> set[str]: prioritized_tests = set() # The keys are formatted as "test_file.py::test_class::test_method[params]" @@ -57,7 +59,7 @@ def _parse_prev_failing_test_files(last_failed_tests: Dict[str, bool]) -> Set[st return prioritized_tests -def gen_additional_test_failures_file(tests: List[str]) -> None: +def gen_additional_test_failures_file(tests: list[str]) -> None: # Segfaults usually result in no xml and some tests don't run through pytest # (ex doctests). 
In these cases, there will be no entry in the pytest # cache, so we should generate a separate file for them and upload it to s3 @@ -69,7 +71,7 @@ def gen_additional_test_failures_file(tests: List[str]) -> None: json.dump(tests, f, indent=2) -def read_additional_test_failures_file() -> Set[str]: +def read_additional_test_failures_file() -> set[str]: path = ( REPO_ROOT / ADDITIONAL_CI_FILES_FOLDER diff --git a/tools/testing/target_determination/heuristics/profiling.py b/tools/testing/target_determination/heuristics/profiling.py index a1cf8e3a40d6..8f17c51ca11e 100644 --- a/tools/testing/target_determination/heuristics/profiling.py +++ b/tools/testing/target_determination/heuristics/profiling.py @@ -1,15 +1,15 @@ -from typing import Any, List +from __future__ import annotations + +from typing import Any from tools.stats.import_test_stats import ( ADDITIONAL_CI_FILES_FOLDER, TD_HEURISTIC_PROFILING_FILE, ) - from tools.testing.target_determination.heuristics.interface import ( HeuristicInterface, TestPrioritizations, ) - from tools.testing.target_determination.heuristics.utils import ( get_ratings_for_tests, normalize_ratings, @@ -21,10 +21,10 @@ from tools.testing.test_run import TestRun # test to see files were involved in each tests and used to build a correlation # dict (where all ratings are 1). class Profiling(HeuristicInterface): - def __init__(self, **kwargs: Any): + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - def get_prediction_confidence(self, tests: List[str]) -> TestPrioritizations: + def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations: test_ratings = get_ratings_for_tests( ADDITIONAL_CI_FILES_FOLDER / TD_HEURISTIC_PROFILING_FILE ) diff --git a/tools/testing/target_determination/heuristics/utils.py b/tools/testing/target_determination/heuristics/utils.py index 7d8297d56e56..17259756533d 100644 --- a/tools/testing/target_determination/heuristics/utils.py +++ b/tools/testing/target_determination/heuristics/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import os import re @@ -5,16 +7,18 @@ import subprocess from collections import defaultdict from functools import lru_cache from pathlib import Path -from typing import cast, Dict, List, Optional, Set, Union +from typing import cast, Dict, TYPE_CHECKING from urllib.request import Request, urlopen from warnings import warn -from tools.testing.test_run import TestRun + +if TYPE_CHECKING: + from tools.testing.test_run import TestRun REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent.parent -def python_test_file_to_test_name(tests: Set[str]) -> Set[str]: +def python_test_file_to_test_name(tests: set[str]) -> set[str]: prefix = f"test{os.path.sep}" valid_tests = {f for f in tests if f.startswith(prefix) and f.endswith(".py")} valid_tests = {f[len(prefix) : -len(".py")] for f in valid_tests} @@ -23,7 +27,7 @@ def python_test_file_to_test_name(tests: Set[str]) -> Set[str]: @lru_cache(maxsize=None) -def get_pr_number() -> Optional[int]: +def get_pr_number() -> int | None: pr_number = os.environ.get("PR_NUMBER", "") if pr_number == "": re_match = re.match(r"^refs/tags/.*/(\d+)$", os.environ.get("GITHUB_REF", "")) @@ -68,7 +72,7 @@ def get_merge_base() -> str: return merge_base -def query_changed_files() -> List[str]: +def query_changed_files() -> list[str]: base_commit = get_merge_base() proc = subprocess.run( @@ -117,8 +121,8 @@ def get_issue_or_pr_body(number: int) -> str: def normalize_ratings( - ratings: Dict[TestRun, float], max_value: float, 
min_value: float = 0 -) -> Dict[TestRun, float]: + ratings: dict[TestRun, float], max_value: float, min_value: float = 0 +) -> dict[TestRun, float]: # Takse the ratings, makes the max value into max_value, and proportionally # distributes the rest of the ratings. # Ex [1,2,3,4] and max_value 8 gets converted to [2,4,6,8] @@ -138,7 +142,7 @@ def normalize_ratings( return normalized_ratings -def get_ratings_for_tests(file: Union[str, Path]) -> Dict[str, float]: +def get_ratings_for_tests(file: str | Path) -> dict[str, float]: path = REPO_ROOT / file if not os.path.exists(path): print(f"could not find path {path}") @@ -150,14 +154,14 @@ def get_ratings_for_tests(file: Union[str, Path]) -> Dict[str, float]: except Exception as e: warn(f"Can't query changed test files due to {e}") return {} - ratings: Dict[str, float] = defaultdict(float) + ratings: dict[str, float] = defaultdict(float) for file in changed_files: for test_file, score in test_file_ratings.get(file, {}).items(): ratings[test_file] += score return ratings -def get_correlated_tests(file: Union[str, Path]) -> List[str]: +def get_correlated_tests(file: str | Path) -> list[str]: ratings = get_ratings_for_tests(file) prioritize = sorted(ratings, key=lambda x: -ratings[x]) return prioritize diff --git a/tools/testing/test_run.py b/tools/testing/test_run.py index 70ba67421eef..9d77f969989f 100644 --- a/tools/testing/test_run.py +++ b/tools/testing/test_run.py @@ -1,6 +1,8 @@ +from __future__ import annotations + from copy import copy from functools import total_ordering -from typing import Any, Dict, FrozenSet, Iterable, List, Optional, Union +from typing import Any, Iterable class TestRun: @@ -15,16 +17,16 @@ class TestRun: """ test_file: str - _excluded: FrozenSet[str] # Tests that should be excluded from this test run - _included: FrozenSet[ + _excluded: frozenset[str] # Tests that should be excluded from this test run + _included: frozenset[ str ] # If non-empy, only these tests should be run in this test run def __init__( self, name: str, - excluded: Optional[Iterable[str]] = None, - included: Optional[Iterable[str]] = None, + excluded: Iterable[str] | None = None, + included: Iterable[str] | None = None, ) -> None: if excluded and included: raise ValueError("Can't specify both included and excluded") @@ -45,7 +47,7 @@ class TestRun: self._included = frozenset(ins) @staticmethod - def empty() -> "TestRun": + def empty() -> TestRun: return TestRun("") def is_empty(self) -> bool: @@ -56,10 +58,10 @@ class TestRun: def is_full_file(self) -> bool: return not self._included and not self._excluded - def included(self) -> FrozenSet[str]: + def included(self) -> frozenset[str]: return self._included - def excluded(self) -> FrozenSet[str]: + def excluded(self) -> frozenset[str]: return self._excluded def get_pytest_filter(self) -> str: @@ -70,7 +72,7 @@ class TestRun: else: return "" - def contains(self, test: "TestRun") -> bool: + def contains(self, test: TestRun) -> bool: if self.test_file != test.test_file: return False @@ -92,7 +94,7 @@ class TestRun: # Does self exclude anything test includes? 
If not, we're good return not self._excluded.intersection(test._included) - def __copy__(self) -> "TestRun": + def __copy__(self) -> TestRun: return TestRun(self.test_file, excluded=self._excluded, included=self._included) def __bool__(self) -> bool: @@ -126,7 +128,7 @@ class TestRun: def __hash__(self) -> int: return hash((self.test_file, self._included, self._excluded)) - def __or__(self, other: "TestRun") -> "TestRun": + def __or__(self, other: TestRun) -> TestRun: """ To OR/Union test runs means to run all the tests that either of the two runs specify. """ @@ -167,7 +169,7 @@ class TestRun: excluded = self._excluded | other._excluded return TestRun(self.test_file, excluded=excluded - included) - def __sub__(self, other: "TestRun") -> "TestRun": + def __sub__(self, other: TestRun) -> TestRun: """ To subtract test runs means to run all the tests in the first run except for what the second run specifies. """ @@ -186,7 +188,7 @@ class TestRun: if other.is_full_file(): return TestRun.empty() - def return_inclusions_or_empty(inclusions: FrozenSet[str]) -> TestRun: + def return_inclusions_or_empty(inclusions: frozenset[str]) -> TestRun: if inclusions: return TestRun(self.test_file, included=inclusions) return TestRun.empty() @@ -204,14 +206,14 @@ class TestRun: else: return return_inclusions_or_empty(other._excluded - self._excluded) - def __and__(self, other: "TestRun") -> "TestRun": + def __and__(self, other: TestRun) -> TestRun: if self.test_file != other.test_file: return TestRun.empty() return (self | other) - (self - other) - (other - self) - def to_json(self) -> Dict[str, Any]: - r: Dict[str, Any] = { + def to_json(self) -> dict[str, Any]: + r: dict[str, Any] = { "test_file": self.test_file, } if self._included: @@ -221,7 +223,7 @@ class TestRun: return r @staticmethod - def from_json(json: Dict[str, Any]) -> "TestRun": + def from_json(json: dict[str, Any]) -> TestRun: return TestRun( json["test_file"], included=json.get("included", []), @@ -234,14 +236,14 @@ class ShardedTest: test: TestRun shard: int num_shards: int - time: Optional[float] # In seconds + time: float | None # In seconds def __init__( self, - test: Union[TestRun, str], + test: TestRun | str, shard: int, num_shards: int, - time: Optional[float] = None, + time: float | None = None, ) -> None: if isinstance(test, str): test = TestRun(test) @@ -296,7 +298,7 @@ class ShardedTest: def get_time(self, default: float = 0) -> float: return self.time if self.time is not None else default - def get_pytest_args(self) -> List[str]: + def get_pytest_args(self) -> list[str]: filter = self.test.get_pytest_filter() if filter: return ["-k", self.test.get_pytest_filter()] diff --git a/tools/testing/test_selections.py b/tools/testing/test_selections.py index 3e43edd50247..5ac1d946d7ec 100644 --- a/tools/testing/test_selections.py +++ b/tools/testing/test_selections.py @@ -1,9 +1,10 @@ +from __future__ import annotations + import math import os import subprocess from pathlib import Path - -from typing import Callable, Dict, FrozenSet, List, Optional, Sequence, Tuple +from typing import Callable, Sequence from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests from tools.testing.test_run import ShardedTest, TestRun @@ -47,8 +48,8 @@ if IS_ROCM and not IS_MEM_LEAK_CHECK: class ShardJob: def __init__(self) -> None: - self.serial: List[ShardedTest] = [] - self.parallel: List[ShardedTest] = [] + self.serial: list[ShardedTest] = [] + self.parallel: list[ShardedTest] = [] def get_total_time(self) -> float: """Default is the 
value for which to substitute if a test has no time""" @@ -59,16 +60,16 @@ class ShardJob: time = max(procs) + sum(test.get_time() for test in self.serial) return time - def convert_to_tuple(self) -> Tuple[float, List[ShardedTest]]: + def convert_to_tuple(self) -> tuple[float, list[ShardedTest]]: return (self.get_total_time(), self.serial + self.parallel) def get_with_pytest_shard( tests: Sequence[TestRun], - test_file_times: Dict[str, float], - test_class_times: Optional[Dict[str, Dict[str, float]]], -) -> List[ShardedTest]: - sharded_tests: List[ShardedTest] = [] + test_file_times: dict[str, float], + test_class_times: dict[str, dict[str, float]] | None, +) -> list[ShardedTest]: + sharded_tests: list[ShardedTest] = [] for test in tests: duration = get_duration(test, test_file_times, test_class_times or {}) @@ -86,9 +87,9 @@ def get_with_pytest_shard( def get_duration( test: TestRun, - test_file_times: Dict[str, float], - test_class_times: Dict[str, Dict[str, float]], -) -> Optional[float]: + test_file_times: dict[str, float], + test_class_times: dict[str, dict[str, float]], +) -> float | None: """Calculate the time for a TestRun based on the given test_file_times and test_class_times. Returns None if the time is unknown.""" file_duration = test_file_times.get(test.test_file, None) @@ -96,8 +97,8 @@ def get_duration( return file_duration def get_duration_for_classes( - test_file: str, test_classes: FrozenSet[str] - ) -> Optional[float]: + test_file: str, test_classes: frozenset[str] + ) -> float | None: duration: float = 0 for test_class in test_classes: @@ -127,9 +128,9 @@ def get_duration( def shard( - sharded_jobs: List[ShardJob], + sharded_jobs: list[ShardJob], pytest_sharded_tests: Sequence[ShardedTest], - estimated_time_limit: Optional[float] = None, + estimated_time_limit: float | None = None, serial: bool = False, ) -> None: # Modifies sharded_jobs in place @@ -142,7 +143,7 @@ def shard( round_robin_index = 0 def _get_min_sharded_job( - sharded_jobs: List[ShardJob], test: ShardedTest + sharded_jobs: list[ShardJob], test: ShardedTest ) -> ShardJob: if test.time is None: nonlocal round_robin_index @@ -152,7 +153,7 @@ def shard( return min(sharded_jobs, key=lambda j: j.get_total_time()) def _shard_serial( - tests: Sequence[ShardedTest], sharded_jobs: List[ShardJob] + tests: Sequence[ShardedTest], sharded_jobs: list[ShardJob] ) -> None: assert estimated_time_limit is not None, "Estimated time limit must be provided" new_sharded_jobs = sharded_jobs @@ -166,7 +167,7 @@ def shard( min_sharded_job.serial.append(test) def _shard_parallel( - tests: Sequence[ShardedTest], sharded_jobs: List[ShardJob] + tests: Sequence[ShardedTest], sharded_jobs: list[ShardJob] ) -> None: for test in tests: min_sharded_job = _get_min_sharded_job(sharded_jobs, test) @@ -183,11 +184,11 @@ def shard( def calculate_shards( num_shards: int, tests: Sequence[TestRun], - test_file_times: Dict[str, float], - test_class_times: Optional[Dict[str, Dict[str, float]]], - must_serial: Optional[Callable[[str], bool]] = None, + test_file_times: dict[str, float], + test_class_times: dict[str, dict[str, float]] | None, + must_serial: Callable[[str], bool] | None = None, sort_by_time: bool = True, -) -> List[Tuple[float, List[ShardedTest]]]: +) -> list[tuple[float, list[ShardedTest]]]: must_serial = must_serial or (lambda x: True) test_class_times = test_class_times or {}