[BE][Easy] enable postponed annotations in tools (#129375)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129375
Approved by: https://github.com/malfet
This commit is contained in:
Xuehai Pan
2024-06-29 12:48:06 +08:00
committed by PyTorch MergeBot
parent 58f346c874
commit 8a67daf283
123 changed files with 1274 additions and 1053 deletions

View File

@ -1,16 +1,19 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import os
import re
from collections import defaultdict
from difflib import SequenceMatcher
from typing import Any, Dict, List, Set, Tuple
from typing import Any
import requests
from setuptools import distutils # type: ignore[import]
ALL_SKIPPED_THRESHOLD = 100
SIMILARITY_THRESHOLD = 0.75
FAILURE_CHAIN_THRESHOLD = 2
@ -65,14 +68,14 @@ DISABLED_ALERTS = [
class JobStatus:
job_name: str = ""
jobs: List[Any] = []
jobs: list[Any] = []
current_status: Any = None
job_statuses: List[Any] = []
filtered_statuses: List[Any] = []
failure_chain: List[Any] = []
flaky_jobs: List[Any] = []
job_statuses: list[Any] = []
filtered_statuses: list[Any] = []
failure_chain: list[Any] = []
flaky_jobs: list[Any] = []
def __init__(self, job_name: str, job_statuses: List[Any]):
def __init__(self, job_name: str, job_statuses: list[Any]) -> None:
self.job_name = job_name
self.job_statuses = job_statuses
@ -93,7 +96,7 @@ class JobStatus:
return status
return None
def get_unique_failures(self, jobs: List[Any]) -> Dict[str, List[Any]]:
def get_unique_failures(self, jobs: list[Any]) -> dict[str, list[Any]]:
"""
Returns list of jobs grouped by failureCaptures from the input list
"""
@ -120,7 +123,7 @@ class JobStatus:
return failures
# A flaky job is if it's the only job that has that failureCapture and is not the most recent job
def get_flaky_jobs(self) -> List[Any]:
def get_flaky_jobs(self) -> list[Any]:
unique_failures = self.get_unique_failures(self.filtered_statuses)
flaky_jobs = []
for failure in unique_failures:
@ -134,7 +137,7 @@ class JobStatus:
# The most recent failure chain is an array of jobs that have the same-ish failures.
# A success in the middle of the chain will terminate the chain.
def get_most_recent_failure_chain(self) -> List[Any]:
def get_most_recent_failure_chain(self) -> list[Any]:
failures = []
found_most_recent_failure = False
@ -178,7 +181,7 @@ def fetch_hud_data(repo: str, branch: str) -> Any:
# Creates a Dict of Job Name -> [JobData]. Essentially a Column in HUD
def map_job_data(jobNames: Any, shaGrid: Any) -> Dict[str, Any]:
def map_job_data(jobNames: Any, shaGrid: Any) -> dict[str, Any]:
jobData = defaultdict(list)
for sha in shaGrid:
for ind, job in enumerate(sha["jobs"]):
@ -196,13 +199,13 @@ def is_job_skipped(job: Any) -> bool:
return conclusion in (NEUTRAL, SKIPPED) or conclusion is None
def get_failed_jobs(job_data: List[Any]) -> List[Any]:
def get_failed_jobs(job_data: list[Any]) -> list[Any]:
return [job for job in job_data if job["conclusion"] == "failure"]
def classify_jobs(
all_job_names: List[str], sha_grid: Any, filtered_jobs_names: Set[str]
) -> Tuple[List[JobStatus], List[Any]]:
all_job_names: list[str], sha_grid: Any, filtered_jobs_names: set[str]
) -> tuple[list[JobStatus], list[Any]]:
"""
Creates Job Statuses which has the logic for if need to alert or if there's flaky jobs.
Classifies jobs into jobs to alert on and flaky jobs.
@ -212,7 +215,7 @@ def classify_jobs(
:return:
"""
job_data = map_job_data(all_job_names, sha_grid)
job_statuses: List[JobStatus] = []
job_statuses: list[JobStatus] = []
for job in job_data:
job_statuses.append(JobStatus(job, job_data[job]))
@ -230,7 +233,7 @@ def classify_jobs(
# filter job names that don't match the regex
def filter_job_names(job_names: List[str], job_name_regex: str) -> List[str]:
def filter_job_names(job_names: list[str], job_name_regex: str) -> list[str]:
if job_name_regex:
return [
job_name for job_name in job_names if re.match(job_name_regex, job_name)
@ -240,7 +243,7 @@ def filter_job_names(job_names: List[str], job_name_regex: str) -> List[str]:
def get_recurrently_failing_jobs_alerts(
repo: str, branch: str, job_name_regex: str
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
job_names, sha_grid = fetch_hud_data(repo=repo, branch=branch)
filtered_job_names = set(filter_job_names(job_names, job_name_regex))

View File

@ -14,18 +14,17 @@ generated. In the full build system, OUTPUT_DIR is
torch/testing/_internal/generated
"""
from __future__ import annotations
import argparse
import os
import textwrap
from collections import defaultdict
from typing import Any, Dict, List, Sequence
from typing import Any, Sequence, TYPE_CHECKING
import torchgen.api.python as python
from torchgen.context import with_native_function
from torchgen.gen import parse_native_yaml
from torchgen.model import Argument, BaseOperatorName, NativeFunction
from torchgen.utils import FileManager
from .gen_python_functions import (
@ -39,6 +38,10 @@ from .gen_python_functions import (
)
if TYPE_CHECKING:
from torchgen.model import Argument, BaseOperatorName, NativeFunction
def gen_annotated(
native_yaml_path: str, tags_yaml_path: str, out: str, autograd_dir: str
) -> None:
@ -53,9 +56,9 @@ def gen_annotated(
(is_py_fft_function, "torch._C._fft"),
(is_py_variable_method, "torch.Tensor"),
)
annotated_args: List[str] = []
annotated_args: list[str] = []
for pred, namespace in mappings:
groups: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list)
groups: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
for f in native_functions:
if not should_generate_py_binding(f) or not pred(f):
continue
@ -77,7 +80,7 @@ def gen_annotated(
@with_native_function
def gen_annotated_args(f: NativeFunction) -> str:
def _get_kwargs_func_exclusion_list() -> List[str]:
def _get_kwargs_func_exclusion_list() -> list[str]:
# functions that currently don't work with kwargs in test_overrides.py
return [
"diagonal",
@ -87,12 +90,12 @@ def gen_annotated_args(f: NativeFunction) -> str:
]
def _add_out_arg(
out_args: List[Dict[str, Any]], args: Sequence[Argument], *, is_kwarg_only: bool
out_args: list[dict[str, Any]], args: Sequence[Argument], *, is_kwarg_only: bool
) -> None:
for arg in args:
if arg.default is not None:
continue
out_arg: Dict[str, Any] = {}
out_arg: dict[str, Any] = {}
out_arg["is_kwarg_only"] = str(is_kwarg_only)
out_arg["name"] = arg.name
out_arg["simple_type"] = python.argument_type_str(
@ -103,7 +106,7 @@ def gen_annotated_args(f: NativeFunction) -> str:
out_arg["size"] = size_t
out_args.append(out_arg)
out_args: List[Dict[str, Any]] = []
out_args: list[dict[str, Any]] = []
_add_out_arg(out_args, f.func.arguments.flat_positional, is_kwarg_only=False)
if f"{f.func.name.name}" not in _get_kwargs_func_exclusion_list():
_add_out_arg(out_args, f.func.arguments.flat_kwarg_only, is_kwarg_only=True)

View File

@ -22,9 +22,10 @@ torch/csrc/autograd/generated/
# gen_python_functions.py: generates Python bindings to THPVariable
#
from __future__ import annotations
import argparse
import os
from typing import List
from torchgen.api import cpp
from torchgen.api.autograd import (
@ -69,7 +70,7 @@ def gen_autograd(
),
key=lambda f: cpp.name(f.func),
)
fns_with_diff_infos: List[
fns_with_diff_infos: list[
NativeFunctionWithDifferentiabilityInfo
] = match_differentiability_info(fns, differentiability_infos)

View File

@ -4,7 +4,10 @@
# Functions.h/cpp: subclasses of autograd::Node
# python_functions.h/cpp: Python bindings for the above classes
#
from typing import Dict, List, Sequence, Tuple
from __future__ import annotations
from typing import Sequence
from torchgen.api.autograd import (
Derivative,
@ -43,6 +46,7 @@ from torchgen.utils import FileManager
from .gen_inplace_or_view_type import VIEW_FUNCTIONS
FUNCTION_DECLARATION = CodeTemplate(
"""\
#ifdef _WIN32
@ -443,8 +447,8 @@ UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS
def get_infos_with_derivatives_list(
differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]]
) -> List[DifferentiabilityInfo]:
differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]]
) -> list[DifferentiabilityInfo]:
diff_info_list = [
info
for diffinfo_dict in differentiability_infos.values()
@ -456,7 +460,7 @@ def get_infos_with_derivatives_list(
def gen_autograd_functions_lib(
out: str,
differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
template_path: str,
) -> None:
"""Functions.h and Functions.cpp body
@ -490,7 +494,7 @@ def gen_autograd_functions_lib(
def gen_autograd_functions_python(
out: str,
differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
template_path: str,
) -> None:
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
@ -536,17 +540,17 @@ def gen_autograd_functions_python(
def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str:
saved_variables: List[str] = []
release_variables: List[str] = []
saved_list_sizes: List[str] = []
unpack: List[str] = []
asserts: List[str] = []
compute_index_ranges: List[str] = []
getter_definitions: List[str] = []
py_getsetdef_structs: List[str] = []
compiled_args: List[str] = []
apply_with_saved_before: List[str] = []
apply_with_saved_after: List[str] = []
saved_variables: list[str] = []
release_variables: list[str] = []
saved_list_sizes: list[str] = []
unpack: list[str] = []
asserts: list[str] = []
compute_index_ranges: list[str] = []
getter_definitions: list[str] = []
py_getsetdef_structs: list[str] = []
compiled_args: list[str] = []
apply_with_saved_before: list[str] = []
apply_with_saved_after: list[str] = []
for arg in info.args_with_derivatives:
if arg.type in TENSOR_LIST_LIKE_CTYPES:
@ -807,7 +811,7 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
else:
will_release_variables = ""
body: List[str] = []
body: list[str] = []
if uses_single_grad(info):
body.append("const auto& grad = grads[0];")
@ -821,7 +825,7 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
def emit_derivative(
derivative: Derivative,
args_with_derivatives: Sequence[Binding],
) -> Tuple[bool, str]:
) -> tuple[bool, str]:
formula = derivative.formula
var_names = derivative.var_names
if len(var_names) == 1:
@ -857,7 +861,7 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
else:
grad_input_mask = ""
idx_ranges = ", ".join(f"{n}_ix" for n in var_names)
copy_ranges: List[str] = []
copy_ranges: list[str] = []
for i, n in enumerate(var_names):
copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i))
return False, DERIVATIVE_MULTI.substitute(

View File

@ -4,7 +4,7 @@
# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
# The fallback is expected to mimick this codegen, so we should keep the two in sync.
from typing import Dict, List, Optional, Tuple
from __future__ import annotations
from torchgen.api import cpp
from torchgen.api.autograd import (
@ -24,8 +24,7 @@ from torchgen.api.types import (
OptionalCType,
symIntArrayRefT,
SymIntT,
# See Note [Nested Arg Types]
tensorT,
tensorT, # See Note [Nested Arg Types]
)
from torchgen.code_template import CodeTemplate
from torchgen.context import with_native_function
@ -46,6 +45,7 @@ from .gen_trace_type import (
type_wrapper_name,
)
# See NOTE [ Autograd View Variables ] in variable.h for details.
# If you update list VIEW_FUNCTIONS or RETURNS_VIEWS_OF_INPUT,
# you **MUST** also update the public list of view ops accordingly in
@ -281,7 +281,7 @@ def inverse_view_name(f: NativeFunction) -> str:
return f"{copy_variant}{overload}_inverse"
def extract_bindings(f: NativeFunction) -> List[Binding]:
def extract_bindings(f: NativeFunction) -> list[Binding]:
return [
r
for a in f.func.schema_order_arguments()
@ -297,9 +297,9 @@ def extract_bindings(f: NativeFunction) -> List[Binding]:
@with_native_function
def unpack_args(f: NativeFunction) -> Tuple[List[str], List[Binding]]:
body: List[str] = []
unpacked_bindings: List[Binding] = []
def unpack_args(f: NativeFunction) -> tuple[list[str], list[Binding]]:
body: list[str] = []
unpacked_bindings: list[Binding] = []
for i, binding in enumerate(extract_bindings(f)):
assert not isinstance(binding.argument, SelfArgument)
@ -338,7 +338,7 @@ def get_base_name(f: NativeFunction) -> str:
return f.func.name.name.base # TODO: should be str(f.func.name.name)?
def get_view_info(f: NativeFunction) -> Optional[str]:
def get_view_info(f: NativeFunction) -> str | None:
base_name = get_base_name(f)
view_info = VIEW_FUNCTIONS.get(base_name, None)
if view_info is None and base_name in RETURNS_VIEWS_OF_INPUT:
@ -347,7 +347,7 @@ def get_view_info(f: NativeFunction) -> Optional[str]:
def emit_view_func(
f: NativeFunction, bindings: List[Binding], view_idx: Optional[str] = None
f: NativeFunction, bindings: list[Binding], view_idx: str | None = None
) -> str:
"""Generate an additional lambda function to recover views in backward when as_strided is not supported.
See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details.
@ -355,8 +355,8 @@ def emit_view_func(
# TODO: Clean this logic up if we get rid of reverse view funcs or reify them.
input_base = "input_base"
replay_view_func = ""
updated_args: List[str] = []
known_view_arg_simple_types: List[CType] = [
updated_args: list[str] = []
known_view_arg_simple_types: list[CType] = [
BaseCType(longT),
OptionalCType(BaseCType(longT)),
BaseCType(SymIntT),
@ -448,7 +448,7 @@ def emit_view_func(
def emit_view_body(
fn: NativeFunctionWithDifferentiabilityInfo, var: str
) -> Tuple[str, str]:
) -> tuple[str, str]:
# See NOTE [ Autograd View Variables ] in variable.h for details.
f = fn.func
base_name = get_base_name(f)
@ -523,9 +523,9 @@ def modifies_arguments(f: NativeFunction) -> bool:
@with_native_function_with_differentiability_info
def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> list[str]:
f = fn.func
inplace_view_body: List[str] = []
inplace_view_body: list[str] = []
dispatcher_sig = DispatcherSignature.from_schema(f.func)
dispatcher_exprs = dispatcher_sig.exprs()
@ -584,7 +584,7 @@ def gen_formals(f: NativeFunction) -> str:
@with_native_function_with_differentiability_info
def inplace_or_view_method_definition(
fn: NativeFunctionWithDifferentiabilityInfo,
) -> Optional[str]:
) -> str | None:
f = fn.func
if get_view_info(f) is None and (
# For functions that modify their inputs but don't return them,
@ -605,7 +605,7 @@ def inplace_or_view_method_definition(
@with_native_function_with_differentiability_info
def inplace_or_view_method_registration(
fn: NativeFunctionWithDifferentiabilityInfo,
) -> Optional[str]:
) -> str | None:
f = fn.func
if get_view_info(f) is None and (
not modifies_arguments(f) or len(f.func.returns) == 0
@ -626,7 +626,7 @@ def use_derived(fn: NativeFunctionWithDifferentiabilityInfo) -> bool:
def gen_inplace_or_view_type_env(
fn: NativeFunctionWithDifferentiabilityInfo,
) -> Dict[str, List[str]]:
) -> dict[str, list[str]]:
definition = inplace_or_view_method_definition(fn)
registration = inplace_or_view_method_registration(fn)
@ -649,7 +649,7 @@ def gen_inplace_or_view_type(
out: str,
native_yaml_path: str,
tags_yaml_path: str,
fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo],
template_path: str,
) -> None:
# NOTE: see Note [Sharded File] at the top of the VariableType.cpp

View File

@ -31,11 +31,12 @@
# message, but use what's there
#
from __future__ import annotations
import itertools
import re
from collections import defaultdict
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple
from typing import Callable, Iterable, Sequence
import yaml
@ -56,7 +57,6 @@ from torchgen.api.python import (
signature_from_schema,
structseq_fieldnames,
)
from torchgen.code_template import CodeTemplate
from torchgen.context import with_native_function
from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml
@ -75,6 +75,7 @@ from torchgen.yaml_utils import YamlLoader
from .gen_inplace_or_view_type import is_tensor_list_type
from .gen_trace_type import should_trace
#
# declarations blocklist
# We skip codegen for these functions, for various reasons.
@ -369,7 +370,7 @@ def gen(
valid_tags = parse_tags_yaml(tags_yaml_path)
def gen_tags_enum() -> Dict[str, str]:
def gen_tags_enum() -> dict[str, str]:
return {
"enum_of_valid_tags": (
"".join(
@ -384,9 +385,9 @@ def gen(
def group_filter_overloads(
pairs: Sequence[PythonSignatureNativeFunctionPair],
pred: Callable[[NativeFunction], bool],
) -> Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]:
grouped: Dict[
BaseOperatorName, List[PythonSignatureNativeFunctionPair]
) -> dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]:
grouped: dict[
BaseOperatorName, list[PythonSignatureNativeFunctionPair]
] = defaultdict(list)
for pair in pairs:
if pred(pair.function):
@ -398,17 +399,17 @@ def create_python_bindings(
fm: FileManager,
pairs: Sequence[PythonSignatureNativeFunctionPair],
pred: Callable[[NativeFunction], bool],
module: Optional[str],
module: str | None,
filename: str,
*,
method: bool,
symint: bool = True,
) -> None:
"""Generates Python bindings to ATen functions"""
py_methods: List[str] = []
ops_headers: List[str] = []
py_method_defs: List[str] = []
py_forwards: List[str] = []
py_methods: list[str] = []
ops_headers: list[str] = []
py_method_defs: list[str] = []
py_forwards: list[str] = []
grouped = group_filter_overloads(pairs, pred)
@ -445,8 +446,8 @@ def create_python_return_type_bindings(
Generate function to initialize and return named tuple for native functions
which returns named tuple and registration invocations in `python_return_types.cpp`.
"""
py_return_types_definition: List[str] = []
py_return_types_registrations: List[str] = []
py_return_types_definition: list[str] = []
py_return_types_registrations: list[str] = []
grouped = group_filter_overloads(pairs, pred)
@ -484,7 +485,7 @@ def create_python_return_type_bindings_header(
Generate function to initialize and return named tuple for native functions
which returns named tuple and relevant entry for the map in `python_return_types.cpp`.
"""
py_return_types_declarations: List[str] = []
py_return_types_declarations: list[str] = []
grouped = group_filter_overloads(pairs, pred)
@ -510,7 +511,7 @@ def create_python_bindings_sharded(
fm: FileManager,
pairs: Sequence[PythonSignatureNativeFunctionPair],
pred: Callable[[NativeFunction], bool],
module: Optional[str],
module: str | None,
filename: str,
*,
method: bool,
@ -521,13 +522,13 @@ def create_python_bindings_sharded(
grouped = group_filter_overloads(pairs, pred)
def key_func(
kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]
) -> str:
return kv[0].base
def env_func(
kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
) -> Dict[str, List[str]]:
kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]
) -> dict[str, list[str]]:
name, fn_pairs = kv
return {
"ops_headers": [f"#include <ATen/ops/{name.base}.h>"],
@ -553,7 +554,7 @@ def create_python_bindings_sharded(
def load_signatures(
native_functions: List[NativeFunction],
native_functions: list[NativeFunction],
deprecated_yaml_path: str,
*,
method: bool,
@ -580,19 +581,19 @@ def load_deprecated_signatures(
*,
method: bool,
pyi: bool,
) -> List[PythonSignatureNativeFunctionPair]:
) -> list[PythonSignatureNativeFunctionPair]:
# The deprecated.yaml doesn't have complete type information, we need
# find and leverage the original ATen signature (to which it delegates
# the call) to generate the full python signature.
# We join the deprecated and the original signatures using type-only form.
# group the original ATen signatures by name
grouped: Dict[str, List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
grouped: dict[str, list[PythonSignatureNativeFunctionPair]] = defaultdict(list)
for pair in pairs:
grouped[pair.signature.name].append(pair)
# find matching original signatures for each deprecated signature
results: List[PythonSignatureNativeFunctionPair] = []
results: list[PythonSignatureNativeFunctionPair] = []
with open(deprecated_yaml_path) as f:
deprecated_defs = yaml.load(f, Loader=YamlLoader)
@ -701,15 +702,15 @@ def gen_structseq_typename_key(f: NativeFunction) -> str:
def emit_structseq_call(
overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], Dict[str, str]]:
) -> tuple[list[str], dict[str, str]]:
"""
Generate block of named tuple type def inits, and add typeref snippets
to declarations that use them
"""
typenames: Dict[
typenames: dict[
str, str
] = {} # map from unique name + field name lists to typedef name
typedefs: List[str] = [] # typedef declarations and init code
typedefs: list[str] = [] # typedef declarations and init code
for overload in overloads:
fieldnames = structseq_fieldnames(overload.function.func.returns)
@ -732,17 +733,17 @@ static PyTypeObject* {typename} = generated::get_{name}_structseq();"""
def generate_return_type_definition_and_registrations(
overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], List[str]]:
) -> tuple[list[str], list[str]]:
"""
Generate block of function in `python_return_types.cpp` to initialize
and return named tuple for a native function which returns named tuple
and registration invocations in same file.
"""
typenames: Dict[
typenames: dict[
str, str
] = {} # map from unique name + field name lists to typedef name
definitions: List[str] = [] # function definition to register the typedef
registrations: List[str] = [] # register call for the typedef
definitions: list[str] = [] # function definition to register the typedef
registrations: list[str] = [] # register call for the typedef
for overload in overloads:
fieldnames = structseq_fieldnames(overload.function.func.returns)
@ -783,15 +784,15 @@ PyTypeObject* get_{name}_structseq() {{
def generate_return_type_declarations(
overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> List[str]:
) -> list[str]:
"""
Generate block of function declarations in `python_return_types.h` to initialize
and return named tuple for a native function.
"""
typenames: Dict[
typenames: dict[
str, str
] = {} # map from unique name + field name lists to typedef name
declarations: List[str] = [] # function declaration to register the typedef
declarations: list[str] = [] # function declaration to register the typedef
for overload in overloads:
fieldnames = structseq_fieldnames(overload.function.func.returns)
@ -891,7 +892,7 @@ static PyObject * ${pycname}(PyObject* self_, PyObject* args)
def method_impl(
name: BaseOperatorName,
module: Optional[str],
module: str | None,
overloads: Sequence[PythonSignatureNativeFunctionPair],
*,
method: bool,
@ -918,8 +919,8 @@ def method_impl(
overloads, symint=symint
)
is_singleton = len(grouped_overloads) == 1
signatures: List[str] = []
dispatch: List[str] = []
signatures: list[str] = []
dispatch: list[str] = []
for overload_index, overload in enumerate(grouped_overloads):
signature = overload.signature.signature_str(symint=symint)
signatures.append(f"{cpp_string(str(signature))},")
@ -959,7 +960,7 @@ def method_impl(
def gen_has_torch_function_check(
name: BaseOperatorName, module: Optional[str], *, noarg: bool, method: bool
name: BaseOperatorName, module: str | None, *, noarg: bool, method: bool
) -> str:
if noarg:
if method:
@ -1007,7 +1008,7 @@ if (_r.isNone(${out_idx})) {
def emit_dispatch_case(
overload: PythonSignatureGroup,
structseq_typenames: Dict[str, str],
structseq_typenames: dict[str, str],
*,
symint: bool = True,
) -> str:
@ -1050,7 +1051,7 @@ def forward_decls(
overloads: Sequence[PythonSignatureNativeFunctionPair],
*,
method: bool,
) -> Tuple[str, ...]:
) -> tuple[str, ...]:
if method:
return ()
@ -1078,7 +1079,7 @@ static PyObject * {pycname}(PyObject* self_, PyObject* args, PyObject* kwargs);
def method_def(
name: BaseOperatorName,
module: Optional[str],
module: str | None,
overloads: Sequence[PythonSignatureNativeFunctionPair],
*,
method: bool,
@ -1114,8 +1115,8 @@ def method_def(
def group_overloads(
overloads: Sequence[PythonSignatureNativeFunctionPair], *, symint: bool = True
) -> Sequence[PythonSignatureGroup]:
bases: Dict[str, PythonSignatureNativeFunctionPair] = {}
outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {}
bases: dict[str, PythonSignatureNativeFunctionPair] = {}
outplaces: dict[str, PythonSignatureNativeFunctionPair] = {}
# first group by signature ignoring out arguments
for overload in overloads:
@ -1137,7 +1138,7 @@ def group_overloads(
for sig, out in outplaces.items():
if sig not in bases:
candidates: List[str] = []
candidates: list[str] = []
for overload in overloads:
if (
str(overload.function.func.name.name)
@ -1268,7 +1269,7 @@ def sort_overloads(
)
# Construct the relation graph
larger_than: Dict[int, Set[int]] = defaultdict(set)
larger_than: dict[int, set[int]] = defaultdict(set)
for i1, overload1 in enumerate(grouped_overloads):
for i2, overload2 in enumerate(grouped_overloads):
if is_smaller(overload1.signature, overload2.signature):
@ -1279,7 +1280,7 @@ def sort_overloads(
# Use a topological sort to sort overloads according to the partial order.
N = len(grouped_overloads)
sorted_ids: List[int] = list(filter(lambda x: x not in larger_than, range(N)))
sorted_ids: list[int] = list(filter(lambda x: x not in larger_than, range(N)))
for idx in range(N):
# The size of sorted_ids will grow to N eventually.
@ -1304,7 +1305,7 @@ def sort_overloads(
def emit_single_dispatch(
ps: PythonSignature,
f: NativeFunction,
structseq_typenames: Dict[str, str],
structseq_typenames: dict[str, str],
*,
symint: bool = True,
) -> str:

View File

@ -1,5 +1,7 @@
from __future__ import annotations
import itertools
from typing import Dict, List, Sequence, Union
from typing import Sequence
from torchgen.api import cpp
from torchgen.api.types import DispatcherSignature
@ -8,6 +10,7 @@ from torchgen.context import with_native_function
from torchgen.model import Argument, NativeFunction, SchemaKind, TensorOptionsArguments
from torchgen.utils import FileManager
# Note [Manual Backend kernels]
# For these ops, we want to manually register to dispatch key Backend and
# skip codegen-ed registeration to all keys before Backend.
@ -136,9 +139,7 @@ ADD_TRACE_INPUT = CodeTemplate("""jit::tracer::addInputs(node, "${name}", ${inpu
def format_trace_inputs(f: NativeFunction) -> str:
def dispatch_trace_input(
arg: Union[Argument, TensorOptionsArguments]
) -> Sequence[str]:
def dispatch_trace_input(arg: Argument | TensorOptionsArguments) -> Sequence[str]:
if isinstance(arg, TensorOptionsArguments):
name = "options"
return [
@ -156,7 +157,7 @@ def format_trace_inputs(f: NativeFunction) -> str:
else:
return [ADD_TRACE_INPUT.substitute(name=name, input=name)]
args: List[Union[Argument, TensorOptionsArguments]] = list(
args: list[Argument | TensorOptionsArguments] = list(
f.func.schema_order_arguments()
)
@ -399,8 +400,8 @@ ${assign_return_values}at::_ops::${unambiguous_name}::redispatch(${unpacked_args
)
def emit_trace_body(f: NativeFunction) -> List[str]:
trace_body: List[str] = []
def emit_trace_body(f: NativeFunction) -> list[str]:
trace_body: list[str] = []
trace_body.append(format_prerecord_trace(f))
@ -503,7 +504,7 @@ def method_registration(f: NativeFunction) -> str:
)
def gen_trace_type_func(fn: NativeFunction) -> Dict[str, List[str]]:
def gen_trace_type_func(fn: NativeFunction) -> dict[str, list[str]]:
return {
"ops_headers": [f"#include <ATen/ops/{fn.root_name}_ops.h>"],
"trace_method_definitions": [method_definition(fn)],
@ -512,7 +513,7 @@ def gen_trace_type_func(fn: NativeFunction) -> Dict[str, List[str]]:
def gen_trace_type(
out: str, native_functions: List[NativeFunction], template_path: str
out: str, native_functions: list[NativeFunction], template_path: str
) -> None:
# NOTE: see Note [Sharded File] at the top of the VariableType.cpp
# template regarding sharding of the generated files.

View File

@ -2,18 +2,19 @@
#
# This writes one file: variable_factories.h
from __future__ import annotations
import re
from typing import List, Optional
import torchgen.api.python as python
from torchgen.api import cpp
from torchgen.api.types import CppSignatureGroup
from torchgen.context import with_native_function
from torchgen.gen import parse_native_yaml
from torchgen.model import NativeFunction, TensorOptionsArguments, Variant
from torchgen.utils import FileManager, mapMaybe
OPTIONAL_TYPE_PATTERN = re.compile(r"c10::optional<(.+)>")
TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)")
@ -69,7 +70,7 @@ def is_factory_function(f: NativeFunction) -> bool:
@with_native_function
def process_function(f: NativeFunction) -> Optional[str]:
def process_function(f: NativeFunction) -> str | None:
name = cpp.name(f.func)
has_tensor_options = python.has_tensor_options(f)
is_factory = has_tensor_options or name.endswith("_like")
@ -83,8 +84,8 @@ def process_function(f: NativeFunction) -> Optional[str]:
sigs.append(cpp_sigs.symint_signature)
r = ""
for sig in sigs:
formals: List[str] = []
exprs: List[str] = []
formals: list[str] = []
exprs: list[str] = []
requires_grad = "false"
for arg in sig.arguments():
qualified_type = fully_qualified_type(arg.type)

View File

@ -25,8 +25,11 @@
# which will in turn dispatch back to VariableType for its
# differentiable subcomponents.
#
from __future__ import annotations
import re
from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
from typing import Callable, Sequence
from torchgen.api import cpp
from torchgen.api.autograd import (
@ -38,7 +41,6 @@ from torchgen.api.autograd import (
NativeFunctionWithDifferentiabilityInfo,
SavedAttribute,
)
from torchgen.api.types import (
ArrayRefCType,
BaseCppType,
@ -103,6 +105,7 @@ from .gen_trace_type import (
type_wrapper_name,
)
# We don't set or modify grad_fn on these methods. Generally, they return
# tensors that have requires_grad=False. In-place functions listed here will
# not examine or modify requires_grad or grad_fn.
@ -837,9 +840,9 @@ def gen_variable_type(
out: str,
native_yaml_path: str,
tags_yaml_path: str,
fns_with_diff_infos: List[NativeFunctionWithDifferentiabilityInfo],
fns_with_diff_infos: list[NativeFunctionWithDifferentiabilityInfo],
template_path: str,
used_keys: Set[str],
used_keys: set[str],
) -> None:
"""VariableType.h and VariableType.cpp body
@ -858,8 +861,8 @@ def gen_variable_type(
# helper that generates a TORCH_LIBRARY_IMPL macro for each
# dispatch key that appears in derivatives.yaml
def wrapper_registrations(used_keys: Set[str]) -> str:
library_impl_macro_list: List[str] = []
def wrapper_registrations(used_keys: set[str]) -> str:
library_impl_macro_list: list[str] = []
for key in sorted(used_keys):
dispatch_key = key
if key == "Default":
@ -926,7 +929,7 @@ def gen_wrapper_registration(f: NativeFunction, key: str = "Default") -> str:
def gen_variable_type_func(
fn: NativeFunctionWithDifferentiabilityInfo,
) -> Dict[str, List[str]]:
) -> dict[str, list[str]]:
f = fn.func
result = {}
with native_function_manager(f):
@ -1034,7 +1037,7 @@ _foreach_ops_with_different_arity = {
@with_native_function_with_differentiability_info_and_key
def emit_body(
fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
) -> List[str]:
) -> list[str]:
assert dispatch_strategy(fn) == "use_derived"
f = fn.func
info = fn.info[key] if fn.info else None
@ -1050,8 +1053,8 @@ def emit_body(
is_foreach = name.startswith("_foreach")
is_inplace_foreach = is_foreach and inplace
if is_inplace_foreach:
inplace_foreacharg2refarg: Dict[Argument, Argument] = {}
refargname2inplace_foreacharg: Dict[str, Argument] = {}
inplace_foreacharg2refarg: dict[Argument, Argument] = {}
refargname2inplace_foreacharg: dict[str, Argument] = {}
base_name_and_overload_name = (f.func.name.name.base, f.func.name.overload_name)
if info is None:
assert (
@ -1077,8 +1080,8 @@ def emit_body(
refargname2inplace_foreacharg[ref_arg.name] = foreach_arg
def gen_differentiable_input(
arg: Union[Argument, SelfArgument, TensorOptionsArguments]
) -> Optional[DifferentiableInput]:
arg: Argument | SelfArgument | TensorOptionsArguments,
) -> DifferentiableInput | None:
if isinstance(arg, TensorOptionsArguments):
return None
a: Argument = arg.argument if isinstance(arg, SelfArgument) else arg
@ -1097,7 +1100,7 @@ def emit_body(
)
@with_native_function
def gen_differentiable_inputs(f: NativeFunction) -> List[DifferentiableInput]:
def gen_differentiable_inputs(f: NativeFunction) -> list[DifferentiableInput]:
arguments = list(f.func.arguments.non_out)
if is_inplace_foreach and info is not None:
for i, arg in enumerate(f.func.arguments.flat_non_out):
@ -1115,8 +1118,8 @@ def emit_body(
return list(mapMaybe(gen_differentiable_input, arguments))
def find_args_with_derivatives(
differentiable_inputs: List[DifferentiableInput],
) -> List[DifferentiableInput]:
differentiable_inputs: list[DifferentiableInput],
) -> list[DifferentiableInput]:
"""Find arguments that have derivative definitions"""
if info is None or not info.has_derivatives:
return differentiable_inputs
@ -1178,8 +1181,8 @@ def emit_body(
and (not returns_void)
)
def emit_save_inputs() -> List[str]:
setup: List[str] = []
def emit_save_inputs() -> list[str]:
setup: list[str] = []
if info is None or not info.has_derivatives:
return setup
@ -1189,7 +1192,7 @@ def emit_body(
# We don't want to save tensors if we know that they will never be used
# when computing the derivative, so we add guards to those statements
def guard_for(arg: SavedAttribute) -> Optional[str]:
def guard_for(arg: SavedAttribute) -> str | None:
assert info is not None
# It's hard to determine the edge offset if we have TensorLists
@ -1276,8 +1279,8 @@ def emit_body(
setup.append(f"grad_fn->{arg.name}_size_ = {arg.name}.size();")
return setup
def setup_derivative(differentiable_inputs: List[DifferentiableInput]) -> List[str]:
body: List[str] = []
def setup_derivative(differentiable_inputs: list[DifferentiableInput]) -> list[str]:
body: list[str] = []
if is_out_fn:
# For out functions, ensure that no input or output requires grad
body.append(DECLARE_GRAD_FN.substitute(op="Node"))
@ -1343,8 +1346,8 @@ def emit_body(
body.append(SETUP_DERIVATIVE.substitute(setup=setup))
return body
def emit_check_if_in_complex_autograd_allowlist() -> List[str]:
body: List[str] = []
def emit_check_if_in_complex_autograd_allowlist() -> list[str]:
body: list[str] = []
if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX:
return body
for arg in differentiable_outputs:
@ -1355,11 +1358,11 @@ def emit_body(
return body
def emit_check_no_requires_grad(
tensor_args: List[DifferentiableInput],
args_with_derivatives: List[DifferentiableInput],
) -> List[str]:
tensor_args: list[DifferentiableInput],
args_with_derivatives: list[DifferentiableInput],
) -> list[str]:
"""Checks that arguments without derivatives don't require grad"""
body: List[str] = []
body: list[str] = []
for arg in tensor_args:
if arg in args_with_derivatives:
continue
@ -1373,8 +1376,8 @@ def emit_body(
body.append(f'check_no_requires_grad({arg_name}, "{arg_name}", "{name}");')
return body
def emit_original_self_definition() -> List[str]:
body: List[str] = []
def emit_original_self_definition() -> list[str]:
body: list[str] = []
if inplace:
if is_inplace_foreach:
body.append(
@ -1412,17 +1415,17 @@ def emit_body(
def save_variables(
saved_variables: Sequence[SavedAttribute],
is_output: bool,
guard_for: Callable[[SavedAttribute], Optional[str]] = lambda name: None,
guard_for: Callable[[SavedAttribute], str | None] = lambda name: None,
) -> Sequence[str]:
# assign the saved variables to the generated grad_fn
stmts: List[str] = []
stmts: list[str] = []
for arg in sorted(saved_variables, key=lambda sa: str(sa.nctype.name)):
name = (
arg.nctype.name.name
if isinstance(arg.nctype.name, SpecialArgName)
else arg.nctype.name
)
foreacharg: Optional[Argument] = None
foreacharg: Argument | None = None
is_foreacharg_list_type: bool = False
type = arg.nctype.type
expr = arg.expr
@ -1539,10 +1542,10 @@ def emit_body(
return call
def wrap_output(
f: NativeFunction, unpacked_bindings: List[Binding], var: str
f: NativeFunction, unpacked_bindings: list[Binding], var: str
) -> str:
call = ""
rhs_value: Optional[str] = None
rhs_value: str | None = None
if not any(r.type.is_tensor_like() for r in f.func.returns):
rhs_value = var
else:
@ -1554,11 +1557,11 @@ def emit_body(
return call
def check_tensorimpl_and_storage(
call: str, unpacked_bindings: List[Binding]
call: str, unpacked_bindings: list[Binding]
) -> str:
# See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
stmts_before_call: List[str] = []
stmts_after_call: List[str] = []
stmts_before_call: list[str] = []
stmts_after_call: list[str] = []
if cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
return call
@ -1665,7 +1668,7 @@ def emit_body(
return call
def emit_call(
f: NativeFunction, unpacked_bindings: List[Binding], try_jit_decomposition: bool
f: NativeFunction, unpacked_bindings: list[Binding], try_jit_decomposition: bool
) -> str:
# We only care about adding `at::AutoDispatchBelowAutograd` guard for non-variable dispatch
# (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure
@ -1764,7 +1767,7 @@ def emit_body(
)
return ""
def emit_any_requires_grad() -> List[str]:
def emit_any_requires_grad() -> list[str]:
extra_condition = ""
if info and info.output_differentiability_conditions:
assert len(info.output_differentiability_conditions) == 1
@ -1782,14 +1785,14 @@ def emit_body(
)
]
def get_any_has_forward_grad_name(var_names: Tuple[str, ...]) -> str:
def get_any_has_forward_grad_name(var_names: tuple[str, ...]) -> str:
if len(var_names) == 1:
return f"_any_has_forward_grad_{var_names[0]}"
else:
return f'_any_has_forward_grad_{"_".join(var_names)}'
def emit_any_has_forward_grad() -> List[str]:
content: List[str] = []
def emit_any_has_forward_grad() -> list[str]:
content: list[str] = []
if not is_foreach:
for derivative in fw_derivatives:
requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative)
@ -1844,7 +1847,7 @@ def emit_body(
content.append("}")
return content
def emit_check_inplace() -> List[str]:
def emit_check_inplace() -> list[str]:
if not inplace:
return []
return [
@ -1852,9 +1855,9 @@ def emit_body(
for arg in differentiable_outputs
]
def emit_fw_derivatives() -> List[str]:
content: List[str] = []
fw_grad_setters: List[str] = []
def emit_fw_derivatives() -> list[str]:
content: list[str] = []
fw_grad_setters: list[str] = []
for derivative in fw_derivatives:
res = derivative.var_names
if f.func.name.name.inplace:
@ -2002,7 +2005,7 @@ def emit_body(
"(self.size(), c10::nullopt);"
)
foreach_forward_grad_formula = derivative.formula
_foreach_arg: Union[Argument, DifferentiableInput]
_foreach_arg: Argument | DifferentiableInput
if inplace:
for _foreach_arg, _ref_arg in inplace_foreacharg2refarg.items():
# note(crcrpar): Massage only Scalar and ArrayRef<Scalar> here.
@ -2044,7 +2047,7 @@ def emit_body(
content.append("\n".join(fw_grad_setters))
return content
def get_any_has_fw_grad_cond(derivative: Optional[ForwardDerivative]) -> str:
def get_any_has_fw_grad_cond(derivative: ForwardDerivative | None) -> str:
#
# Produces a condition string (e.g, "isFwGradDefined(grad_output) || isFwGradDefined(output)")
#
@ -2053,7 +2056,7 @@ def emit_body(
# - Used in the out_fn case when we want to forbid fw derivatives
# - Used in the case where the fw_derivative is not defined, but we want
# To check if there is a decomposition registered for jvp
to_check: List[str] = []
to_check: list[str] = []
for inp in list(
mapMaybe(
gen_differentiable_input,
@ -2126,7 +2129,7 @@ def emit_body(
else ""
)
body: List[str] = []
body: list[str] = []
unpack_args_stats, unpacked_bindings = unpack_args(f)
body.extend(unpack_args_stats)

View File

@ -4,10 +4,11 @@
# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
# The fallback is expected to mimic this codegen, so we should keep the two in sync.
from typing import List, Tuple
from __future__ import annotations
from typing import TYPE_CHECKING
import torchgen.api.dispatcher as dispatcher
from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo
from torchgen.api.translate import translate
from torchgen.api.types import (
BaseCType,
@ -29,6 +30,11 @@ from .gen_inplace_or_view_type import (
use_derived,
)
if TYPE_CHECKING:
from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo
FUNCTION_DECLARATION = CodeTemplate(
"""\
#define ${uppercase_op}_AVAILABLE
@ -155,9 +161,9 @@ def returns_multi_tensor(fn: NativeFunction) -> bool:
# tuple: (list of getter logic strings, list of setter logic strings, string
# with num items expression)
def generate_state_getter_setter(
bindings: List[Binding],
bindings: list[Binding],
state_vec_type: NamedCType,
) -> Tuple[List[str], List[str], str]:
) -> tuple[list[str], list[str], str]:
getter_logic = []
setter_logic = []
@ -302,7 +308,7 @@ def process_function(fn: NativeFunction, template: CodeTemplate) -> str:
def gen_view_funcs(
out: str,
fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo],
template_path: str,
) -> None:
# don't need the info parts, just the function

View File

@ -2,14 +2,16 @@
#
# Each autograd function is represented by `DifferentiabilityInfo` containing
# a list of `Derivative`. See `torchgen.api.autograd` for the data models.
from __future__ import annotations
import re
from collections import defaultdict
from typing import Any, Counter, Dict, List, Match, Optional, Sequence, Set, Tuple
from typing import Any, Counter, Dict, Sequence, Set, Tuple
import yaml
from torchgen.api import cpp
from torchgen.api.autograd import (
Derivative,
DifferentiabilityInfo,
@ -50,9 +52,10 @@ from torchgen.model import (
from torchgen.utils import concatMap, IDENT_REGEX, split_name_params
from torchgen.yaml_utils import YamlLoader
DerivativeRet = Tuple[Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], Set[str]]
_GLOBAL_LOAD_DERIVATIVE_CACHE: Dict[Tuple[str, str], DerivativeRet] = {}
_GLOBAL_LOAD_DERIVATIVE_CACHE: dict[tuple[str, str], DerivativeRet] = {}
_VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS)
@ -62,11 +65,11 @@ _VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS)
# we generate them here instead of duplicating them in the yaml.
# See Note [Codegen'd {view}_copy Operators]
def add_view_copy_derivatives(
infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
view_groups: List[NativeFunctionsViewGroup],
infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
view_groups: list[NativeFunctionsViewGroup],
) -> None:
# Get the map from each view op's name to its corresponding view group
view_name_to_group: Dict[OperatorName, NativeFunctionsViewGroup] = {
view_name_to_group: dict[OperatorName, NativeFunctionsViewGroup] = {
g.view.func.name: g for g in view_groups
}
@ -125,10 +128,10 @@ def load_derivatives(
# function schema is the complete declaration including mutability annotation / default value and etc.
# signature is the canonical schema for a group of functions (in-place/out/functional variants)
# that are semantically related.
functions_by_signature: Dict[
FunctionSchema, List[NativeFunction]
functions_by_signature: dict[
FunctionSchema, list[NativeFunction]
] = defaultdict(list)
functions_by_schema: Dict[str, NativeFunction] = {}
functions_by_schema: dict[str, NativeFunction] = {}
for function in native_functions:
functions_by_signature[function.func.signature()].append(function)
assert str(function.func) not in functions_by_schema
@ -141,8 +144,8 @@ def load_derivatives(
# infos is a dict that maps FunctionSchema -> a dict of per dispatch key DifferentiabilityInfos
# this is useful because in tools/autograd/gen_autograd.py:match_differentiability_info
# we ultimately need to categorize the DifferentiabilityInfos by FunctionSchema
infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]] = {}
used_dispatch_keys: Set[str] = set()
infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]] = {}
used_dispatch_keys: set[str] = set()
for defn_dict in definitions:
# Ensure that the old derivatives.yaml schema with no dispatch key can be loaded.
if "dispatch" not in defn_dict:
@ -185,11 +188,11 @@ def cpp_arguments(f: NativeFunction) -> Sequence[Binding]:
def create_derivative(
f: NativeFunction,
formula: str,
var_names: Tuple[str, ...],
var_names: tuple[str, ...],
available_named_gradients: Sequence[str],
) -> Derivative:
original_formula = formula
arguments: List[NamedCType] = [
arguments: list[NamedCType] = [
a.nctype.remove_const_ref() for a in cpp_arguments(f)
]
@ -230,10 +233,10 @@ def create_derivative(
def create_forward_derivative(
f: NativeFunction, formula: str, names: Tuple[str, ...]
f: NativeFunction, formula: str, names: tuple[str, ...]
) -> ForwardDerivative:
var_names = names
var_types: Optional[Tuple[Type, ...]] = None
var_types: tuple[Type, ...] | None = None
for r in f.func.returns:
if r.name in var_names:
if var_types is None:
@ -269,12 +272,12 @@ def create_forward_derivative(
def postprocess_forward_derivatives(
f: NativeFunction,
defn_name: str,
all_arg_names: List[str],
derivatives: List[Derivative],
forward_derivatives: List[ForwardDerivative],
all_arg_names: list[str],
derivatives: list[Derivative],
forward_derivatives: list[ForwardDerivative],
args_with_derivatives: Sequence[Binding],
) -> List[ForwardDerivative]:
def find_required_inputs(formula: str, postfix: str) -> Tuple[str, ...]:
) -> list[ForwardDerivative]:
def find_required_inputs(formula: str, postfix: str) -> tuple[str, ...]:
is_foreach = f.func.name.name.base.startswith("_foreach_")
required_inputs = set()
for arg in args_with_derivatives:
@ -300,7 +303,7 @@ def postprocess_forward_derivatives(
return tuple(required_inputs)
updated_derivatives: List[ForwardDerivative] = []
updated_derivatives: list[ForwardDerivative] = []
for defn in forward_derivatives:
formula = defn.formula
@ -430,7 +433,7 @@ def postprocess_forward_derivatives(
def is_forward_derivative_definition(
all_arg_names: List[str], names: Tuple[str, ...]
all_arg_names: list[str], names: tuple[str, ...]
) -> bool:
for name in names:
if name not in all_arg_names:
@ -441,12 +444,12 @@ def is_forward_derivative_definition(
def create_differentiability_info(
defn_dict: Dict[Any, Any],
functions_by_signature: Dict[FunctionSchema, List[NativeFunction]],
functions_by_schema: Dict[str, NativeFunction],
defn_dict: dict[Any, Any],
functions_by_signature: dict[FunctionSchema, list[NativeFunction]],
functions_by_schema: dict[str, NativeFunction],
op_counter: Counter[str],
used_dispatch_keys: Set[str],
) -> Tuple[FunctionSchema, Dict[str, DifferentiabilityInfo]]:
used_dispatch_keys: set[str],
) -> tuple[FunctionSchema, dict[str, DifferentiabilityInfo]]:
"""Processes a single entry `defn` in derivatives.yaml"""
def canonical_function(
@ -463,7 +466,7 @@ def create_differentiability_info(
assert name + "_" == cpp.name(functions[0].func)
return functions[0]
def split_names(raw_names: str) -> Tuple[str, ...]:
def split_names(raw_names: str) -> tuple[str, ...]:
"""Given "foo, bar", return ["foo", "bar"]."""
return tuple(x.strip() for x in raw_names.split(","))
@ -477,7 +480,7 @@ def create_differentiability_info(
uses_grad = False # true if any derivative uses "grad"
num_grads_uses = 0 # count of uses of "grads" or "grads[INDEX]"
uses_named_grads = False # true if any derivative uses "grad_{name}"
used_grads_indices: List[int] = [] # which indices of grads are used
used_grads_indices: list[int] = [] # which indices of grads are used
for d in derivatives:
formula = d.formula
uses_grad = uses_grad or bool(
@ -521,7 +524,7 @@ def create_differentiability_info(
@with_native_function
def set_up_derivatives(
f: NativeFunction,
) -> Tuple[
) -> tuple[
Sequence[Derivative],
Sequence[ForwardDerivative],
Sequence[Binding],
@ -529,10 +532,10 @@ def create_differentiability_info(
Sequence[str],
]:
# Set up the derivative information
derivatives: List[Derivative] = []
forward_derivatives: List[ForwardDerivative] = []
non_differentiable_arg_names: List[str] = []
args_with_derivatives_set: Set[str] = set()
derivatives: list[Derivative] = []
forward_derivatives: list[ForwardDerivative] = []
non_differentiable_arg_names: list[str] = []
args_with_derivatives_set: set[str] = set()
all_arg_names = [a.name for a in cpp_arguments(f)]
all_ret_names = [
@ -699,7 +702,7 @@ def create_differentiability_info(
available_named_gradients,
) = set_up_derivatives(canonical)
used_named_gradients: Set[str] = set()
used_named_gradients: set[str] = set()
for d in derivatives:
used_named_gradients |= d.named_gradients
@ -738,7 +741,7 @@ def create_differentiability_info(
GRAD_INDEX_REGEX = r"(?:^|\W)grads\[(\d+)\]"
def used_gradient_indices(formula: str) -> List[int]:
def used_gradient_indices(formula: str) -> list[int]:
"""Determine a list of gradient indices (the i in grads[i]) that
are used by the formula.
@ -750,9 +753,9 @@ def used_gradient_indices(formula: str) -> List[int]:
def saved_variables(
formula: str,
nctypes: List[NamedCType],
var_names: Tuple[str, ...],
) -> Tuple[str, Tuple[SavedAttribute, ...]]:
nctypes: list[NamedCType],
var_names: tuple[str, ...],
) -> tuple[str, tuple[SavedAttribute, ...]]:
def stride_expr(name: str) -> str:
assert var_names == (name,), (
'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor '
@ -760,7 +763,7 @@ def saved_variables(
)
return f'strides_or_error({name}, "{name}")'
REPLACEMENTS: List[Tuple[str, Dict[str, Any]]] = [
REPLACEMENTS: list[tuple[str, dict[str, Any]]] = [
# replace self.sym_sizes() with self_sym_sizes
(
r"{}.sym_sizes\(\)",
@ -914,7 +917,7 @@ def saved_variables(
]
# find which arguments need to be saved
saved: List[SavedAttribute] = []
saved: list[SavedAttribute] = []
if ".sizes()" in formula or "->sizes()" in formula:
raise RuntimeError(
@ -941,7 +944,7 @@ def saved_variables(
# when the autograd Function is created to avoid saving variables
for regex, info in REPLACEMENTS:
def repl(m: Match[str]) -> str:
def repl(m: re.Match[str]) -> str:
suffix: str = (
info["suffix"](m) if callable(info["suffix"]) else info["suffix"]
)
@ -999,8 +1002,8 @@ def _create_op_prefix(name: str) -> str:
def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]:
seen: Set[str] = set()
saved: List[SavedAttribute] = []
seen: set[str] = set()
saved: list[SavedAttribute] = []
for var in vars:
name = (
var.nctype.name.name

View File

@ -1,17 +1,17 @@
from __future__ import annotations
import os
import platform
import shutil
from glob import glob
from typing import Dict, Optional
from setuptools import distutils # type: ignore[import]
from .setup_helpers.cmake import CMake, USE_NINJA
from .setup_helpers.env import check_negative_env_flag, IS_64BIT, IS_WINDOWS
def _overlay_windows_vcvars(env: Dict[str, str]) -> Dict[str, str]:
def _overlay_windows_vcvars(env: dict[str, str]) -> dict[str, str]:
vc_arch = "x64" if IS_64BIT else "x86"
if platform.machine() == "ARM64":
@ -34,7 +34,7 @@ def _overlay_windows_vcvars(env: Dict[str, str]) -> Dict[str, str]:
"emulation is enabled!"
)
vc_env: Dict[str, str] = distutils._msvccompiler._get_vc_env(vc_arch)
vc_env: dict[str, str] = distutils._msvccompiler._get_vc_env(vc_arch)
# Keys in `_get_vc_env` are always lowercase.
# We turn them into uppercase before overlaying vcvars
# because OS environ keys are always uppercase on Windows.
@ -47,7 +47,7 @@ def _overlay_windows_vcvars(env: Dict[str, str]) -> Dict[str, str]:
return vc_env
def _create_build_env() -> Dict[str, str]:
def _create_build_env() -> dict[str, str]:
# XXX - our cmake file sometimes looks at the system environment
# and not cmake flags!
# you should NEVER add something to this list. It is bad practice to
@ -72,8 +72,8 @@ def _create_build_env() -> Dict[str, str]:
def build_caffe2(
version: Optional[str],
cmake_python_library: Optional[str],
version: str | None,
cmake_python_library: str | None,
build_python: bool,
rerun_cmake: bool,
cmake_only: bool,

View File

@ -5,10 +5,13 @@
# - ninja -j1 -v -n torch_python | sed -e 's/-O[23]/-g/g' -e 's#\[[0-9]\+\/[0-9]\+\] \+##' |sh
# - Copy libs from build/lib to torch/lib folder
from __future__ import annotations
import subprocess
import sys
from pathlib import Path
from typing import Any, List, Optional, Tuple
from typing import Any
PYTORCH_ROOTDIR = Path(__file__).resolve().parent.parent
TORCH_DIR = PYTORCH_ROOTDIR / "torch"
@ -17,7 +20,7 @@ BUILD_DIR = PYTORCH_ROOTDIR / "build"
BUILD_LIB_DIR = BUILD_DIR / "lib"
def check_output(args: List[str], cwd: Optional[str] = None) -> str:
def check_output(args: list[str], cwd: str | None = None) -> str:
return subprocess.check_output(args, cwd=cwd).decode("utf-8")
@ -63,7 +66,7 @@ def is_devel_setup() -> bool:
return output.strip() == str(TORCH_DIR / "__init__.py")
def create_build_plan() -> List[Tuple[str, str]]:
def create_build_plan() -> list[tuple[str, str]]:
output = check_output(
["ninja", "-j1", "-v", "-n", "torch_python"], cwd=str(BUILD_DIR)
)

View File

@ -8,13 +8,15 @@ For custom build with static dispatch, the op dependency graph will be omitted,
and it will directly output root ops as the allowlist.
"""
import argparse
from __future__ import annotations
import argparse
from collections import defaultdict
from typing import Dict, List, Set
from typing import Dict, Set
import yaml
DepGraph = Dict[str, Set[str]]
@ -34,7 +36,7 @@ def load_op_dep_graph(fname: str) -> DepGraph:
return dict(result)
def load_root_ops(fname: str) -> List[str]:
def load_root_ops(fname: str) -> list[str]:
result = []
with open(fname) as stream:
for op in yaml.safe_load(stream):
@ -44,9 +46,9 @@ def load_root_ops(fname: str) -> List[str]:
def gen_transitive_closure(
dep_graph: DepGraph,
root_ops: List[str],
root_ops: list[str],
train: bool = False,
) -> List[str]:
) -> list[str]:
result = set(root_ops)
queue = root_ops.copy()
@ -73,7 +75,7 @@ def gen_transitive_closure(
return sorted(result)
def gen_transitive_closure_str(dep_graph: DepGraph, root_ops: List[str]) -> str:
def gen_transitive_closure_str(dep_graph: DepGraph, root_ops: list[str]) -> str:
return " ".join(gen_transitive_closure(dep_graph, root_ops))

View File

@ -1,8 +1,11 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import sys
from typing import Any, Dict, List, Optional
from typing import Any
import yaml
from gen_op_registration_allowlist import (
@ -17,6 +20,7 @@ from torchgen.selective_build.operator import (
)
from torchgen.selective_build.selector import merge_kernel_metadata
# Generate YAML file containing the operators used for a specific PyTorch model.
# ------------------------------------------------------------------------------
#
@ -84,17 +88,17 @@ from torchgen.selective_build.selector import merge_kernel_metadata
#
def canonical_opnames(opnames: List[str]) -> List[str]:
def canonical_opnames(opnames: list[str]) -> list[str]:
return [canonical_name(opname) for opname in opnames]
def make_filter_from_options(
model_name: str,
model_versions: List[str],
model_assets: Optional[List[str]],
model_backends: Optional[List[str]],
model_versions: list[str],
model_assets: list[str] | None,
model_backends: list[str] | None,
):
def is_model_included(model_info):
def is_model_included(model_info) -> bool:
model = model_info["model"]
if model["name"] != model_name:
return False
@ -109,7 +113,7 @@ def make_filter_from_options(
# Returns if a the specified rule is a new or old style pt_operator_library
def is_new_style_rule(model_name: str, model_versions: Optional[List[str]]):
def is_new_style_rule(model_name: str, model_versions: list[str] | None):
return model_name is not None and model_versions is not None
@ -117,13 +121,13 @@ def is_new_style_rule(model_name: str, model_versions: Optional[List[str]]):
# appear in at least one model yaml. Throws if verification is failed,
# returns None on success
def verify_all_specified_present(
model_assets: Optional[List[str]],
model_versions: List[str],
selected_models_yaml: List[Dict[str, Any]],
model_assets: list[str] | None,
model_versions: list[str],
selected_models_yaml: list[dict[str, Any]],
rule_name: str,
model_name: str,
new_style_rule: bool,
):
) -> None:
def find_missing_items(model_items, key, selected_models_yaml):
missing_items = []
if not new_style_rule or not model_items:
@ -179,10 +183,10 @@ def verify_all_specified_present(
# Uses the selected models configs and then combines them into one dictionary,
# formats them as a string, and places the string into output as a top level debug_info
def create_debug_info_from_selected_models(
output: Dict[str, object],
selected_models: List[dict],
output: dict[str, object],
selected_models: list[dict],
new_style_rule: bool,
):
) -> None:
model_dict = {
"asset_info": {}, # maps asset name -> dict of asset metadata like hashes
"is_new_style_rule": new_style_rule,
@ -201,7 +205,7 @@ def create_debug_info_from_selected_models(
output["debug_info"] = [json.dumps(model_dict)]
def fill_output(output: Dict[str, object], options: object):
def fill_output(output: dict[str, object], options: object) -> None:
"""Populate the output dict with the information required to serialize
the YAML file used for selective build.
"""
@ -458,7 +462,7 @@ def fill_output(output: Dict[str, object], options: object):
# END TRACING BASED BUILD OPS
# Merge dictionaries together to remove op duplication
operators: Dict[str, SelectiveBuildOperator] = {}
operators: dict[str, SelectiveBuildOperator] = {}
for ops_dict in bucketed_ops:
operators = merge_operator_dicts(operators, ops_dict)

View File

@ -1,10 +1,13 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import os
import sys
from functools import reduce
from typing import Any, List, Set
from typing import Any
import yaml
from tools.lite_interpreter.gen_selected_mobile_ops_header import (
@ -17,11 +20,11 @@ from torchgen.selective_build.selector import (
)
def extract_all_operators(selective_builder: SelectiveBuilder) -> Set[str]:
def extract_all_operators(selective_builder: SelectiveBuilder) -> set[str]:
return set(selective_builder.operators.keys())
def extract_training_operators(selective_builder: SelectiveBuilder) -> Set[str]:
def extract_training_operators(selective_builder: SelectiveBuilder) -> set[str]:
ops = []
for op_name, op in selective_builder.operators.items():
if op.is_used_for_training:
@ -44,7 +47,7 @@ def throw_if_any_op_includes_overloads(selective_builder: SelectiveBuilder) -> N
)
def gen_supported_mobile_models(model_dicts: List[Any], output_dir: str) -> None:
def gen_supported_mobile_models(model_dicts: list[Any], output_dir: str) -> None:
supported_mobile_models_source = """/*
* Generated by gen_oplist.py
*/
@ -87,7 +90,7 @@ SupportedMobileModelCheckerRegistry register_model_versions;
out_file.write(source.encode("utf-8"))
def main(argv: List[Any]) -> None:
def main(argv: list[Any]) -> None:
"""This binary generates 3 files:
1. selected_mobile_ops.h: Primary operators used by templated selective build and Kernel Function

View File

@ -1,6 +1,8 @@
from __future__ import annotations
import argparse
import os
from typing import cast, List, Optional, Tuple
from typing import cast
from ..util.setting import (
CompilerType,
@ -38,7 +40,7 @@ BLOCKED_PYTHON_TESTS = {
}
def initialization() -> Tuple[Option, TestList, List[str]]:
def initialization() -> tuple[Option, TestList, list[str]]:
# create folder if not exists
create_folders()
# add arguments
@ -77,7 +79,7 @@ def add_arguments_oss(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
def parse_arguments(
parser: argparse.ArgumentParser,
) -> Tuple[Option, Optional[List[str]], Optional[List[str]], Optional[bool]]:
) -> tuple[Option, list[str] | None, list[str] | None, bool | None]:
# parse args
args = parser.parse_args()
# get option
@ -85,9 +87,7 @@ def parse_arguments(
return (options, args.interest_only, args.run_only, args.clean)
def get_test_list_by_type(
run_only: Optional[List[str]], test_type: TestType
) -> TestList:
def get_test_list_by_type(run_only: list[str] | None, test_type: TestType) -> TestList:
test_list: TestList = []
binary_folder = get_oss_binary_folder(test_type)
g = os.walk(binary_folder)
@ -106,7 +106,7 @@ def get_test_list_by_type(
return test_list
def get_test_list(run_only: Optional[List[str]]) -> TestList:
def get_test_list(run_only: list[str] | None) -> TestList:
test_list: TestList = []
# add c++ test list
test_list.extend(get_test_list_by_type(run_only, TestType.CPP))
@ -122,7 +122,7 @@ def get_test_list(run_only: Optional[List[str]]) -> TestList:
return test_list
def empty_list_if_none(arg_interested_folder: Optional[List[str]]) -> List[str]:
def empty_list_if_none(arg_interested_folder: list[str] | None) -> list[str]:
if arg_interested_folder is None:
return []
# if this argument is specified, just return itself
@ -134,7 +134,7 @@ def gcc_export_init() -> None:
create_folder(JSON_FOLDER_BASE_DIR)
def get_python_run_only(args_run_only: Optional[List[str]]) -> List[str]:
def get_python_run_only(args_run_only: list[str] | None) -> list[str]:
# if user specifies run-only option
if args_run_only:
return args_run_only
@ -144,7 +144,7 @@ def get_python_run_only(args_run_only: Optional[List[str]]) -> List[str]:
return ["run_test.py"]
else:
# for clang, some tests will result in too large intermediate files that can't be merged by llvm, we need to skip them
run_only: List[str] = []
run_only: list[str] = []
binary_folder = get_oss_binary_folder(TestType.PY)
g = os.walk(binary_folder)
for _, _, file_list in g:

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import os
import subprocess
from typing import List, Optional
from ..util.setting import CompilerType, TestType, TOOLS_FOLDER
from ..util.utils import print_error, remove_file
@ -14,7 +15,7 @@ def get_oss_binary_folder(test_type: TestType) -> str:
)
def get_oss_shared_library() -> List[str]:
def get_oss_shared_library() -> list[str]:
lib_dir = os.path.join(get_pytorch_folder(), "build", "lib")
return [
os.path.join(lib_dir, lib)
@ -48,7 +49,7 @@ def get_pytorch_folder() -> str:
)
def detect_compiler_type() -> Optional[CompilerType]:
def detect_compiler_type() -> CompilerType | None:
# check if user specifies the compiler type
user_specify = os.environ.get("CXX", None)
if user_specify:
@ -76,7 +77,7 @@ def clean_up_gcda() -> None:
remove_file(item)
def get_gcda_files() -> List[str]:
def get_gcda_files() -> list[str]:
folder_has_gcda = os.path.join(get_pytorch_folder(), "build")
if os.path.isdir(folder_has_gcda):
# TODO use glob

View File

@ -1,7 +1,8 @@
from __future__ import annotations
import os
import subprocess
import time
from typing import List
from ..util.setting import (
JSON_FOLDER_BASE_DIR,
@ -25,7 +26,7 @@ from .utils import get_tool_path_by_platform, run_cpp_test
def create_corresponding_folder(
cur_path: str, prefix_cur_path: str, dir_list: List[str], new_base_folder: str
cur_path: str, prefix_cur_path: str, dir_list: list[str], new_base_folder: str
) -> None:
for dir_name in dir_list:
relative_path = convert_to_relative_path(
@ -70,7 +71,7 @@ def export_target(
merged_file: str,
json_file: str,
binary_file: str,
shared_library_list: List[str],
shared_library_list: list[str],
platform_type: TestPlatform,
) -> None:
if binary_file is None:

View File

@ -1,7 +1,8 @@
from __future__ import annotations
import os
import subprocess
import time
from typing import Dict
# gcc is only used in oss
from ..oss.utils import get_gcda_files, run_oss_python_test
@ -10,7 +11,7 @@ from ..util.utils import print_log, print_time
from .utils import run_cpp_test
def update_gzip_dict(gzip_dict: Dict[str, int], file_name: str) -> str:
def update_gzip_dict(gzip_dict: dict[str, int], file_name: str) -> str:
file_name = file_name.lower()
gzip_dict[file_name] = gzip_dict.get(file_name, 0) + 1
num = gzip_dict[file_name]
@ -34,7 +35,7 @@ def export() -> None:
# collect .gcda files
gcda_files = get_gcda_files()
# file name like utils.cpp may have same name in different folder
gzip_dict: Dict[str, int] = {}
gzip_dict: dict[str, int] = {}
for gcda_item in gcda_files:
# generate json.gz
subprocess.check_call(["gcov", "-i", gcda_item])

View File

@ -1,12 +1,14 @@
import typing as t
from __future__ import annotations
from typing import Any, NamedTuple
class CoverageRecord(t.NamedTuple):
class CoverageRecord(NamedTuple):
filepath: str
covered_lines: t.List[int]
uncovered_lines: t.Optional[t.List[int]] = None
covered_lines: list[int]
uncovered_lines: list[int] | None = None
def to_dict(self) -> t.Dict[str, t.Any]:
def to_dict(self) -> dict[str, Any]:
return {
"filepath": self.filepath,
"covered_lines": self.covered_lines,

View File

@ -1,4 +1,6 @@
from typing import Any, Dict, List, Set
from __future__ import annotations
from typing import Any
from .coverage_record import CoverageRecord
@ -10,7 +12,7 @@ class GcovCoverageParser:
of CoverageRecord(s).
"""
def __init__(self, llvm_coverage: Dict[str, Any]) -> None:
def __init__(self, llvm_coverage: dict[str, Any]) -> None:
self._llvm_coverage = llvm_coverage
@staticmethod
@ -24,17 +26,17 @@ class GcovCoverageParser:
return True
return False
def parse(self) -> List[CoverageRecord]:
def parse(self) -> list[CoverageRecord]:
# The JSON format is described in the gcov source code
# https://gcc.gnu.org/onlinedocs/gcc/Invoking-Gcov.html
records: List[CoverageRecord] = []
records: list[CoverageRecord] = []
for file_info in self._llvm_coverage["files"]:
filepath = file_info["file"]
if self._skip_coverage(filepath):
continue
# parse json file
covered_lines: Set[int] = set()
uncovered_lines: Set[int] = set()
covered_lines: set[int] = set()
uncovered_lines: set[int] = set()
for line in file_info["lines"]:
line_number = line["line_number"]
count = line["count"]

View File

@ -1,4 +1,6 @@
from typing import Any, Dict, List, Set, Tuple
from __future__ import annotations
from typing import Any
from .coverage_record import CoverageRecord
from .llvm_coverage_segment import LlvmCoverageSegment, parse_segments
@ -12,7 +14,7 @@ class LlvmCoverageParser:
"""
def __init__(self, llvm_coverage: Dict[str, Any]) -> None:
def __init__(self, llvm_coverage: dict[str, Any]) -> None:
self._llvm_coverage = llvm_coverage
@staticmethod
@ -28,13 +30,13 @@ class LlvmCoverageParser:
@staticmethod
def _collect_coverage(
segments: List[LlvmCoverageSegment],
) -> Tuple[List[int], List[int]]:
segments: list[LlvmCoverageSegment],
) -> tuple[list[int], list[int]]:
"""
Stateful parsing of coverage segments.
"""
covered_lines: Set[int] = set()
uncovered_lines: Set[int] = set()
covered_lines: set[int] = set()
uncovered_lines: set[int] = set()
prev_segment = LlvmCoverageSegment(1, 0, 0, 0, 0, None)
for segment in segments:
covered_range, uncovered_range = segment.get_coverage(prev_segment)
@ -45,10 +47,10 @@ class LlvmCoverageParser:
uncovered_lines.difference_update(covered_lines)
return sorted(covered_lines), sorted(uncovered_lines)
def parse(self, repo_name: str) -> List[CoverageRecord]:
def parse(self, repo_name: str) -> list[CoverageRecord]:
# The JSON format is described in the LLVM source code
# https://github.com/llvm-mirror/llvm/blob/master/tools/llvm-cov/CoverageExporterJson.cpp
records: List[CoverageRecord] = []
records: list[CoverageRecord] = []
for export_unit in self._llvm_coverage["data"]:
for file_info in export_unit["files"]:
filepath = file_info["filename"]

View File

@ -1,4 +1,6 @@
from typing import List, NamedTuple, Optional, Tuple
from __future__ import annotations
from typing import NamedTuple
class LlvmCoverageSegment(NamedTuple):
@ -7,7 +9,7 @@ class LlvmCoverageSegment(NamedTuple):
segment_count: int
has_count: int
is_region_entry: int
is_gap_entry: Optional[int]
is_gap_entry: int | None
@property
def has_coverage(self) -> bool:
@ -18,8 +20,8 @@ class LlvmCoverageSegment(NamedTuple):
return self.has_count > 0
def get_coverage(
self, prev_segment: "LlvmCoverageSegment"
) -> Tuple[List[int], List[int]]:
self, prev_segment: LlvmCoverageSegment
) -> tuple[list[int], list[int]]:
# Code adapted from testpilot.testinfra.runners.gtestcoveragerunner.py
if not prev_segment.is_executable:
return [], []
@ -32,12 +34,12 @@ class LlvmCoverageSegment(NamedTuple):
return (lines_range, []) if prev_segment.has_coverage else ([], lines_range)
def parse_segments(raw_segments: List[List[int]]) -> List[LlvmCoverageSegment]:
def parse_segments(raw_segments: list[list[int]]) -> list[LlvmCoverageSegment]:
"""
Creates LlvmCoverageSegment from a list of lists in llvm export json.
each segment is represented by 5-element array.
"""
ret: List[LlvmCoverageSegment] = []
ret: list[LlvmCoverageSegment] = []
for raw_segment in raw_segments:
assert (
len(raw_segment) == 5 or len(raw_segment) == 6

View File

@ -1,10 +1,13 @@
from __future__ import annotations
import os
import subprocess
from typing import Dict, IO, List, Set, Tuple
from typing import IO, Tuple
from ..oss.utils import get_pytorch_folder
from ..util.setting import SUMMARY_FOLDER_DIR, TestList, TestStatusType
CoverageItem = Tuple[str, float, int, int]
@ -16,7 +19,7 @@ def key_by_name(x: CoverageItem) -> str:
return x[0]
def is_intrested_file(file_path: str, interested_folders: List[str]) -> bool:
def is_intrested_file(file_path: str, interested_folders: list[str]) -> bool:
if "cuda" in file_path:
return False
if "aten/gen_aten" in file_path or "aten/aten_" in file_path:
@ -27,7 +30,7 @@ def is_intrested_file(file_path: str, interested_folders: List[str]) -> bool:
return False
def is_this_type_of_tests(target_name: str, test_set_by_type: Set[str]) -> bool:
def is_this_type_of_tests(target_name: str, test_set_by_type: set[str]) -> bool:
# tests are divided into three types: success / partial success / fail to collect coverage
for test in test_set_by_type:
if target_name in test:
@ -36,7 +39,7 @@ def is_this_type_of_tests(target_name: str, test_set_by_type: Set[str]) -> bool:
def print_test_by_type(
tests: TestList, test_set_by_type: Set[str], type_name: str, summary_file: IO[str]
tests: TestList, test_set_by_type: set[str], type_name: str, summary_file: IO[str]
) -> None:
print("Tests " + type_name + " to collect coverage:", file=summary_file)
for test in tests:
@ -48,8 +51,8 @@ def print_test_by_type(
def print_test_condition(
tests: TestList,
tests_type: TestStatusType,
interested_folders: List[str],
coverage_only: List[str],
interested_folders: list[str],
coverage_only: list[str],
summary_file: IO[str],
summary_type: str,
) -> None:
@ -77,10 +80,10 @@ def print_test_condition(
def line_oriented_report(
tests: TestList,
tests_type: TestStatusType,
interested_folders: List[str],
coverage_only: List[str],
covered_lines: Dict[str, Set[int]],
uncovered_lines: Dict[str, Set[int]],
interested_folders: list[str],
coverage_only: list[str],
covered_lines: dict[str, set[int]],
uncovered_lines: dict[str, set[int]],
) -> None:
with open(os.path.join(SUMMARY_FOLDER_DIR, "line_summary"), "w+") as report_file:
print_test_condition(
@ -119,13 +122,13 @@ def print_file_summary(
def print_file_oriented_report(
tests_type: TestStatusType,
coverage: List[CoverageItem],
coverage: list[CoverageItem],
covered_summary: int,
total_summary: int,
summary_file: IO[str],
tests: TestList,
interested_folders: List[str],
coverage_only: List[str],
interested_folders: list[str],
coverage_only: list[str],
) -> None:
coverage_percentage = print_file_summary(
covered_summary, total_summary, summary_file
@ -155,10 +158,10 @@ def print_file_oriented_report(
def file_oriented_report(
tests: TestList,
tests_type: TestStatusType,
interested_folders: List[str],
coverage_only: List[str],
covered_lines: Dict[str, Set[int]],
uncovered_lines: Dict[str, Set[int]],
interested_folders: list[str],
coverage_only: list[str],
covered_lines: dict[str, set[int]],
uncovered_lines: dict[str, set[int]],
) -> None:
with open(os.path.join(SUMMARY_FOLDER_DIR, "file_summary"), "w+") as summary_file:
covered_summary = 0
@ -193,7 +196,7 @@ def file_oriented_report(
)
def get_html_ignored_pattern() -> List[str]:
def get_html_ignored_pattern() -> list[str]:
return ["/usr/*", "*anaconda3/*", "*third_party/*"]

View File

@ -1,7 +1,9 @@
from __future__ import annotations
import json
import os
import time
from typing import Any, Dict, List, Set, Tuple
from typing import Any, TYPE_CHECKING
from ..util.setting import (
CompilerType,
@ -16,7 +18,6 @@ from ..util.utils import (
print_time,
related_to_test_list,
)
from .parser.coverage_record import CoverageRecord
from .parser.gcov_coverage_parser import GcovCoverageParser
from .parser.llvm_coverage_parser import LlvmCoverageParser
from .print_report import (
@ -26,16 +27,20 @@ from .print_report import (
)
if TYPE_CHECKING:
from .parser.coverage_record import CoverageRecord
# coverage_records: Dict[str, LineInfo] = {}
covered_lines: Dict[str, Set[int]] = {}
uncovered_lines: Dict[str, Set[int]] = {}
covered_lines: dict[str, set[int]] = {}
uncovered_lines: dict[str, set[int]] = {}
tests_type: TestStatusType = {"success": set(), "partial": set(), "fail": set()}
def transform_file_name(
file_path: str, interested_folders: List[str], platform: TestPlatform
file_path: str, interested_folders: list[str], platform: TestPlatform
) -> str:
remove_patterns: Set[str] = {".DEFAULT.cpp", ".AVX.cpp", ".AVX2.cpp"}
remove_patterns: set[str] = {".DEFAULT.cpp", ".AVX.cpp", ".AVX2.cpp"}
for pattern in remove_patterns:
file_path = file_path.replace(pattern, "")
# if user has specified interested folder
@ -54,7 +59,7 @@ def transform_file_name(
def is_intrested_file(
file_path: str, interested_folders: List[str], platform: TestPlatform
file_path: str, interested_folders: list[str], platform: TestPlatform
) -> bool:
ignored_patterns = ["cuda", "aten/gen_aten", "aten/aten_", "build/"]
if any(pattern in file_path for pattern in ignored_patterns):
@ -77,7 +82,7 @@ def is_intrested_file(
return True
def get_json_obj(json_file: str) -> Tuple[Any, int]:
def get_json_obj(json_file: str) -> tuple[Any, int]:
"""
Sometimes at the start of file llvm/gcov will complains "fail to find coverage data",
then we need to skip these lines
@ -102,7 +107,7 @@ def get_json_obj(json_file: str) -> Tuple[Any, int]:
return None, 2
def parse_json(json_file: str, platform: TestPlatform) -> List[CoverageRecord]:
def parse_json(json_file: str, platform: TestPlatform) -> list[CoverageRecord]:
print("start parse:", json_file)
json_obj, read_status = get_json_obj(json_file)
if read_status == 0:
@ -117,7 +122,7 @@ def parse_json(json_file: str, platform: TestPlatform) -> List[CoverageRecord]:
cov_type = detect_compiler_type(platform)
coverage_records: List[CoverageRecord] = []
coverage_records: list[CoverageRecord] = []
if cov_type == CompilerType.CLANG:
coverage_records = LlvmCoverageParser(json_obj).parse("fbcode")
# print(coverage_records)
@ -128,7 +133,7 @@ def parse_json(json_file: str, platform: TestPlatform) -> List[CoverageRecord]:
def parse_jsons(
test_list: TestList, interested_folders: List[str], platform: TestPlatform
test_list: TestList, interested_folders: list[str], platform: TestPlatform
) -> None:
g = os.walk(JSON_FOLDER_BASE_DIR)
@ -152,8 +157,8 @@ def parse_jsons(
def update_coverage(
coverage_records: List[CoverageRecord],
interested_folders: List[str],
coverage_records: list[CoverageRecord],
interested_folders: list[str],
platform: TestPlatform,
) -> None:
for item in coverage_records:
@ -187,8 +192,8 @@ def update_set() -> None:
def summarize_jsons(
test_list: TestList,
interested_folders: List[str],
coverage_only: List[str],
interested_folders: list[str],
coverage_only: list[str],
platform: TestPlatform,
) -> None:
start_time = time.time()

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import os
from enum import Enum
from typing import Dict, List, Set

View File

@ -1,8 +1,10 @@
from __future__ import annotations
import os
import shutil
import sys
import time
from typing import Any, NoReturn, Optional
from typing import Any, NoReturn
from .setting import (
CompilerType,
@ -113,7 +115,7 @@ def get_test_name_from_whole_path(path: str) -> str:
return path[start + 1 : end]
def check_compiler_type(cov_type: Optional[CompilerType]) -> None:
def check_compiler_type(cov_type: CompilerType | None) -> None:
if cov_type is not None and cov_type in [CompilerType.GCC, CompilerType.CLANG]:
return
raise Exception( # noqa: TRY002

View File

@ -1,5 +1,6 @@
import setuptools # type: ignore[import]
with open("README.md", encoding="utf-8") as fh:
long_description = fh.read()

View File

@ -22,6 +22,7 @@ from typing import Any
from coverage import CoverageData, CoveragePlugin # type: ignore[import]
# All coverage stats resulting from this plug-in will be in a separate .coverage file that should be merged later with
# `coverage combine`. The convention seems to be .coverage.dotted.suffix based on the following link:
# https://coverage.readthedocs.io/en/coverage-5.5/cmd.html#combining-data-files-coverage-combine

View File

@ -5,6 +5,7 @@ import sys
from urllib.error import URLError
from urllib.request import urlretrieve
MIRRORS = [
"http://yann.lecun.com/exdb/mnist/",
"https://ossci-datasets.s3.amazonaws.com/mnist/",

View File

@ -5,6 +5,7 @@ import sys
import traceback
import warnings
MIN_CUDA_VERSION = "11.6"
MIN_ROCM_VERSION = "5.4"
MIN_PYTHON_VERSION = (3, 8)
@ -141,7 +142,7 @@ def check_rocm():
return rocm_ver if torch.version.hip else "None"
def check_dynamo(backend, device, err_msg):
def check_dynamo(backend, device, err_msg) -> None:
import torch
if device == "cuda" and not torch.cuda.is_available():
@ -203,7 +204,7 @@ _SANITY_CHECK_ARGS = (
)
def main():
def main() -> None:
python_ver = check_python()
torch_ver = check_torch()
cuda_ver = check_cuda()

View File

@ -1,14 +1,17 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import re
import sys
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any, Dict
from typing_extensions import TypedDict # Python 3.11+
import yaml
Step = Dict[str, Any]
@ -17,7 +20,7 @@ class Script(TypedDict):
script: str
def extract(step: Step) -> Optional[Script]:
def extract(step: Step) -> Script | None:
run = step.get("run")
# https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#using-a-specific-shell

View File

@ -1,5 +1,7 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import array
import codecs
@ -15,7 +17,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import subprocess
import textwrap
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any
import yaml
from yaml.constructor import ConstructorError
@ -29,14 +31,14 @@ except ImportError:
CPP_H_NAME = "spv.h"
CPP_SRC_NAME = "spv.cpp"
DEFAULT_ENV: Dict[str, Any] = {
DEFAULT_ENV: dict[str, Any] = {
"PRECISION": "highp",
"FLOAT_IMAGE_FORMAT": "rgba16f",
"INT_IMAGE_FORMAT": "rgba32i",
"UINT_IMAGE_FORMAT": "rgba32ui",
}
TYPES_ENV: Dict[str, Any] = {
TYPES_ENV: dict[str, Any] = {
"IMAGE_FORMAT": {
"float": "rgba32f",
"half": "rgba16f",
@ -91,7 +93,7 @@ TYPES_ENV: Dict[str, Any] = {
},
}
FUNCS_ENV: Dict[str, Any] = {
FUNCS_ENV: dict[str, Any] = {
"GET_POS": {
3: lambda pos: pos,
2: lambda pos: f"{pos}.xy",
@ -169,7 +171,7 @@ def escape(line: str) -> str:
# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
def preprocess(
input_text: str, variables: Dict[str, Any], input_path: str = "codegen"
input_text: str, variables: dict[str, Any], input_path: str = "codegen"
) -> str:
input_lines = input_text.splitlines()
python_lines = []
@ -243,9 +245,9 @@ def preprocess(
class SPVGenerator:
def __init__(
self,
src_dir_paths: Union[str, List[str]],
env: Dict[Any, Any],
glslc_path: Optional[str],
src_dir_paths: str | list[str],
env: dict[Any, Any],
glslc_path: str | None,
) -> None:
if isinstance(src_dir_paths, str):
self.src_dir_paths = [src_dir_paths]
@ -255,18 +257,18 @@ class SPVGenerator:
self.env = env
self.glslc_path = glslc_path
self.glsl_src_files: Dict[str, str] = {}
self.template_yaml_files: List[str] = []
self.glsl_src_files: dict[str, str] = {}
self.template_yaml_files: list[str] = []
self.addSrcAndYamlFiles(self.src_dir_paths)
self.shader_template_params: Dict[Any, Any] = {}
self.shader_template_params: dict[Any, Any] = {}
for yaml_file in self.template_yaml_files:
self.parseTemplateYaml(yaml_file)
self.output_shader_map: Dict[str, Tuple[str, Dict[str, str]]] = {}
self.output_shader_map: dict[str, tuple[str, dict[str, str]]] = {}
self.constructOutputMap()
def addSrcAndYamlFiles(self, src_dir_paths: List[str]) -> None:
def addSrcAndYamlFiles(self, src_dir_paths: list[str]) -> None:
for src_path in src_dir_paths:
# Collect glsl source files
glsl_files = glob.glob(
@ -285,9 +287,9 @@ class SPVGenerator:
def generateVariantCombinations(
self,
iterated_params: Dict[str, Any],
exclude_params: Optional[Set[str]] = None,
) -> List[Any]:
iterated_params: dict[str, Any],
exclude_params: set[str] | None = None,
) -> list[Any]:
if exclude_params is None:
exclude_params = set()
all_iterated_params = []
@ -362,8 +364,8 @@ class SPVGenerator:
)
def create_shader_params(
self, variant_params: Optional[Dict[str, Any]] = None
) -> Dict[str, str]:
self, variant_params: dict[str, Any] | None = None
) -> dict[str, str]:
if variant_params is None:
variant_params = {}
shader_params = copy.deepcopy(self.env)
@ -409,7 +411,7 @@ class SPVGenerator:
self.create_shader_params(),
)
def generateSPV(self, output_dir: str) -> Dict[str, str]:
def generateSPV(self, output_dir: str) -> dict[str, str]:
output_file_map = {}
for shader_name in self.output_shader_map:
source_glsl = self.output_shader_map[shader_name][0]
@ -457,11 +459,11 @@ class SPVGenerator:
@dataclass
class ShaderInfo:
tile_size: List[int]
layouts: List[str]
tile_size: list[int]
layouts: list[str]
weight_storage_type: str = ""
bias_storage_type: str = ""
register_for: Optional[Tuple[str, List[str]]] = None
register_for: tuple[str, list[str]] | None = None
def getName(filePath: str) -> str:
@ -478,7 +480,7 @@ def isTileSizeLine(lineStr: str) -> bool:
return re.search(tile_size_id, lineStr) is not None
def findTileSizes(lineStr: str) -> List[int]:
def findTileSizes(lineStr: str) -> list[int]:
tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)"
matches = re.search(tile_size_id, lineStr)
if matches is None:
@ -520,7 +522,7 @@ def isRegisterForLine(lineStr: str) -> bool:
return re.search(register_for_id, lineStr) is not None
def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
def findRegisterFor(lineStr: str) -> tuple[str, list[str]]:
register_for_pattern = r"'([A-Za-z0-9_]+)'"
matches = re.findall(register_for_pattern, lineStr)
if matches is None:
@ -609,7 +611,7 @@ static const api::ShaderRegisterInit register_shaders(&register_fn);
"""
def generateSpvBinStr(spvPath: str, name: str) -> Tuple[int, str]:
def generateSpvBinStr(spvPath: str, name: str) -> tuple[int, str]:
with open(spvPath, "rb") as fr:
next_bin = array.array("I", fr.read())
sizeBytes = 4 * len(next_bin)
@ -665,7 +667,7 @@ def generateShaderDispatchStr(shader_info: ShaderInfo, name: str) -> str:
def genCppFiles(
spv_files: Dict[str, str], cpp_header_path: str, cpp_src_file_path: str
spv_files: dict[str, str], cpp_header_path: str, cpp_src_file_path: str
) -> None:
spv_bin_strs = []
register_shader_info_strs = []
@ -705,7 +707,7 @@ def genCppFiles(
##########
def parse_arg_env(items: Dict[Any, Any]) -> Dict[Any, Any]:
def parse_arg_env(items: dict[Any, Any]) -> dict[Any, Any]:
d = {}
if items:
for item in items:
@ -716,7 +718,7 @@ def parse_arg_env(items: Dict[Any, Any]) -> Dict[Any, Any]:
return d
def main(argv: List[str]) -> int:
def main(argv: list[str]) -> int:
parser = argparse.ArgumentParser(description="")
parser.add_argument(
"-i",

View File

@ -1,9 +1,10 @@
from __future__ import annotations
import argparse
import os
import re
import subprocess
from pathlib import Path
from typing import Optional, Union
from setuptools import distutils # type: ignore[import]
@ -12,7 +13,7 @@ UNKNOWN = "Unknown"
RELEASE_PATTERN = re.compile(r"/v[0-9]+(\.[0-9]+)*(-rc[0-9]+)?/")
def get_sha(pytorch_root: Union[str, Path]) -> str:
def get_sha(pytorch_root: str | Path) -> str:
try:
rev = None
if os.path.exists(os.path.join(pytorch_root, ".git")):
@ -30,7 +31,7 @@ def get_sha(pytorch_root: Union[str, Path]) -> str:
return UNKNOWN
def get_tag(pytorch_root: Union[str, Path]) -> str:
def get_tag(pytorch_root: str | Path) -> str:
try:
tag = subprocess.run(
["git", "describe", "--tags", "--exact"],
@ -46,8 +47,8 @@ def get_tag(pytorch_root: Union[str, Path]) -> str:
return UNKNOWN
def get_torch_version(sha: Optional[str] = None) -> str:
pytorch_root = Path(__file__).parent.parent
def get_torch_version(sha: str | None = None) -> str:
pytorch_root = Path(__file__).absolute().parent.parent
version = open(pytorch_root / "version.txt").read().strip()
if os.getenv("PYTORCH_BUILD_VERSION"):

View File

@ -1,10 +1,10 @@
"""GitHub Utilities"""
from __future__ import annotations
import json
import os
from typing import Any, Callable, cast, Dict, Optional, Tuple
from typing import Any, Callable, cast, Dict
from urllib.error import HTTPError
from urllib.parse import quote
from urllib.request import Request, urlopen
@ -13,11 +13,11 @@ from urllib.request import Request, urlopen
def gh_fetch_url_and_headers(
url: str,
*,
headers: Optional[Dict[str, str]] = None,
data: Optional[Dict[str, Any]] = None,
method: Optional[str] = None,
headers: dict[str, str] | None = None,
data: dict[str, Any] | None = None,
method: str | None = None,
reader: Callable[[Any], Any] = lambda x: x.read(),
) -> Tuple[Any, Any]:
) -> tuple[Any, Any]:
if headers is None:
headers = {}
token = os.environ.get("GITHUB_TOKEN")
@ -44,9 +44,9 @@ def gh_fetch_url_and_headers(
def gh_fetch_url(
url: str,
*,
headers: Optional[Dict[str, str]] = None,
data: Optional[Dict[str, Any]] = None,
method: Optional[str] = None,
headers: dict[str, str] | None = None,
data: dict[str, Any] | None = None,
method: str | None = None,
reader: Callable[[Any], Any] = lambda x: x.read(),
) -> Any:
return gh_fetch_url_and_headers(
@ -56,8 +56,8 @@ def gh_fetch_url(
def _gh_fetch_json_any(
url: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
params: dict[str, Any] | None = None,
data: dict[str, Any] | None = None,
) -> Any:
headers = {"Accept": "application/vnd.github.v3+json"}
if params is not None and len(params) > 0:
@ -69,13 +69,13 @@ def _gh_fetch_json_any(
def gh_fetch_json_dict(
url: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
params: dict[str, Any] | None = None,
data: dict[str, Any] | None = None,
) -> dict[str, Any]:
return cast(Dict[str, Any], _gh_fetch_json_any(url, params, data))
def gh_fetch_commit(org: str, repo: str, sha: str) -> Dict[str, Any]:
def gh_fetch_commit(org: str, repo: str, sha: str) -> dict[str, Any]:
return gh_fetch_json_dict(
f"https://api.github.com/repos/{org}/{repo}/commits/{sha}"
)

View File

@ -1,6 +1,7 @@
import re
import sys
QUOTE_INCLUDE_RE = re.compile(r'^#include "(.*)"')
ANGLE_INCLUDE_RE = re.compile(r"^#include <(.*)>")

View File

@ -1,10 +1,13 @@
# Generates RegisterCodegenUnboxedKernels.cpp, UnboxingFunctions.h and UnboxingFunctions.cpp.
from __future__ import annotations
import argparse
import os
import pathlib
import sys
from dataclasses import dataclass
from typing import List, Literal, Sequence, Union
from pathlib import Path
from typing import Literal, Sequence, TYPE_CHECKING
import yaml
@ -15,10 +18,13 @@ from torchgen.api.unboxing import convert_arguments
from torchgen.context import method_with_native_function
from torchgen.gen import cpp_string, get_custom_build_selector, parse_native_yaml
from torchgen.model import Argument, NativeFunction, NativeFunctionsGroup, Variant
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import FileManager, make_file_manager, mapMaybe, Target
if TYPE_CHECKING:
from torchgen.selective_build.selector import SelectiveBuilder
# Generates UnboxingFunctions.h & UnboxingFunctions.cpp.
@dataclass(frozen=True)
class ComputeUnboxingFunctions:
@ -156,7 +162,7 @@ def gen_unboxing(
cpu_fm: FileManager,
selector: SelectiveBuilder,
) -> None:
def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str:
def key_func(fn: NativeFunction | NativeFunctionsGroup) -> str:
return fn.root_name
selected_op_num: int = len(selector.operators)
@ -195,7 +201,7 @@ def gen_unboxing(
)
def main(args: List[str]) -> None:
def main(args: list[str]) -> None:
parser = argparse.ArgumentParser(description="Generate unboxing source files")
parser.add_argument(
"-s",
@ -272,7 +278,7 @@ def main(args: List[str]) -> None:
gen_unboxing(native_functions=native_functions, cpu_fm=cpu_fm, selector=selector)
if options.output_dependencies:
depfile_path = pathlib.Path(options.output_dependencies).resolve()
depfile_path = Path(options.output_dependencies).resolve()
depfile_name = depfile_path.name
depfile_stem = depfile_path.stem

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import argparse
import concurrent.futures
import json
@ -8,7 +10,7 @@ import subprocess
import sys
import time
from enum import Enum
from typing import List, NamedTuple, Optional, Pattern
from typing import NamedTuple
LINTER_CODE = "ACTIONLINT"
@ -22,18 +24,18 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple):
path: Optional[str]
line: Optional[int]
char: Optional[int]
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: Optional[str]
replacement: Optional[str]
description: Optional[str]
original: str | None
replacement: str | None
description: str | None
RESULTS_RE: Pattern[str] = re.compile(
RESULTS_RE: re.Pattern[str] = re.compile(
r"""(?mx)
^
(?P<file>.*?):
@ -47,8 +49,8 @@ RESULTS_RE: Pattern[str] = re.compile(
def run_command(
args: List[str],
) -> "subprocess.CompletedProcess[bytes]":
args: list[str],
) -> subprocess.CompletedProcess[bytes]:
logging.debug("$ %s", " ".join(args))
start_time = time.monotonic()
try:
@ -64,7 +66,7 @@ def run_command(
def check_file(
binary: str,
file: str,
) -> List[LintMessage]:
) -> list[LintMessage]:
try:
proc = run_command(
[

View File

@ -5,6 +5,9 @@ archive is downloaded from some sites like GitHub because it can change. Specifi
GitHub gives no guarantee to keep the same value forever. Check for more details at
https://github.com/community/community/discussions/46034.
"""
from __future__ import annotations
import argparse
import json
import re
@ -13,7 +16,7 @@ import subprocess
import sys
import xml.etree.ElementTree as ET
from enum import Enum
from typing import List, NamedTuple, Optional, Set
from typing import NamedTuple
from urllib.parse import urlparse
@ -30,18 +33,18 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple):
path: Optional[str]
line: Optional[int]
char: Optional[int]
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: Optional[str]
replacement: Optional[str]
description: Optional[str]
original: str | None
replacement: str | None
description: str | None
def is_required_checksum(urls: List[Optional[str]]) -> bool:
def is_required_checksum(urls: list[str | None]) -> bool:
if not urls:
return False
@ -58,7 +61,7 @@ def is_required_checksum(urls: List[Optional[str]]) -> bool:
def get_disallowed_checksums(
binary: str,
) -> Set[str]:
) -> set[str]:
"""
Return the set of disallowed checksums from all http_archive rules
"""
@ -96,8 +99,8 @@ def get_disallowed_checksums(
def check_bazel(
filename: str,
disallowed_checksums: Set[str],
) -> List[LintMessage]:
disallowed_checksums: set[str],
) -> list[LintMessage]:
original = ""
replacement = ""

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import argparse
import concurrent.futures
import json
@ -7,7 +9,7 @@ import subprocess
import sys
import time
from enum import Enum
from typing import Any, BinaryIO, List, NamedTuple, Optional
from typing import Any, BinaryIO, NamedTuple
IS_WINDOWS: bool = os.name == "nt"
@ -25,15 +27,15 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple):
path: Optional[str]
line: Optional[int]
char: Optional[int]
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: Optional[str]
replacement: Optional[str]
description: Optional[str]
original: str | None
replacement: str | None
description: str | None
def as_posix(name: str) -> str:
@ -41,11 +43,11 @@ def as_posix(name: str) -> str:
def _run_command(
args: List[str],
args: list[str],
*,
stdin: BinaryIO,
timeout: int,
) -> "subprocess.CompletedProcess[bytes]":
) -> subprocess.CompletedProcess[bytes]:
logging.debug("$ %s", " ".join(args))
start_time = time.monotonic()
try:
@ -63,12 +65,12 @@ def _run_command(
def run_command(
args: List[str],
args: list[str],
*,
stdin: BinaryIO,
retries: int,
timeout: int,
) -> "subprocess.CompletedProcess[bytes]":
) -> subprocess.CompletedProcess[bytes]:
remaining_retries = retries
while True:
try:
@ -90,7 +92,7 @@ def check_file(
filename: str,
retries: int,
timeout: int,
) -> List[LintMessage]:
) -> list[LintMessage]:
try:
with open(filename, "rb") as f:
original = f.read()

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import argparse
import concurrent.futures
import json
@ -8,7 +10,7 @@ import sys
import time
from enum import Enum
from pathlib import Path
from typing import Any, List, NamedTuple, Optional
from typing import Any, NamedTuple
IS_WINDOWS: bool = os.name == "nt"
@ -26,15 +28,15 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple):
path: Optional[str]
line: Optional[int]
char: Optional[int]
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: Optional[str]
replacement: Optional[str]
description: Optional[str]
original: str | None
replacement: str | None
description: str | None
def as_posix(name: str) -> str:
@ -42,10 +44,10 @@ def as_posix(name: str) -> str:
def _run_command(
args: List[str],
args: list[str],
*,
timeout: int,
) -> "subprocess.CompletedProcess[bytes]":
) -> subprocess.CompletedProcess[bytes]:
logging.debug("$ %s", " ".join(args))
start_time = time.monotonic()
try:
@ -62,11 +64,11 @@ def _run_command(
def run_command(
args: List[str],
args: list[str],
*,
retries: int,
timeout: int,
) -> "subprocess.CompletedProcess[bytes]":
) -> subprocess.CompletedProcess[bytes]:
remaining_retries = retries
while True:
try:
@ -89,7 +91,7 @@ def check_file(
binary: str,
retries: int,
timeout: int,
) -> List[LintMessage]:
) -> list[LintMessage]:
try:
with open(filename, "rb") as f:
original = f.read()

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import argparse
import concurrent.futures
import json
@ -11,7 +13,7 @@ import time
from enum import Enum
from pathlib import Path
from sysconfig import get_paths as gp
from typing import Any, List, NamedTuple, Optional, Pattern
from typing import Any, NamedTuple
# PyTorch directory root
@ -49,15 +51,15 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple):
path: Optional[str]
line: Optional[int]
char: Optional[int]
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: Optional[str]
replacement: Optional[str]
description: Optional[str]
original: str | None
replacement: str | None
description: str | None
def as_posix(name: str) -> str:
@ -65,7 +67,7 @@ def as_posix(name: str) -> str:
# c10/core/DispatchKey.cpp:281:26: error: 'k' used after it was moved [bugprone-use-after-move]
RESULTS_RE: Pattern[str] = re.compile(
RESULTS_RE: re.Pattern[str] = re.compile(
r"""(?mx)
^
(?P<file>.*?):
@ -80,8 +82,8 @@ RESULTS_RE: Pattern[str] = re.compile(
def run_command(
args: List[str],
) -> "subprocess.CompletedProcess[bytes]":
args: list[str],
) -> subprocess.CompletedProcess[bytes]:
logging.debug("$ %s", " ".join(args))
start_time = time.monotonic()
try:
@ -103,7 +105,7 @@ severities = {
}
def clang_search_dirs() -> List[str]:
def clang_search_dirs() -> list[str]:
# Compilers are ordered based on fallback preference
# We pick the first one that is available on the system
compilers = ["clang", "gcc", "cpp", "cc"]
@ -152,7 +154,7 @@ def check_file(
filename: str,
binary: str,
build_dir: Path,
) -> List[LintMessage]:
) -> list[LintMessage]:
try:
proc = run_command(
[binary, f"-p={build_dir}", *include_args, filename],

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import argparse
import concurrent.futures
import json
@ -7,7 +9,7 @@ import re
import subprocess
import time
from enum import Enum
from typing import List, NamedTuple, Optional, Pattern
from typing import NamedTuple
LINTER_CODE = "CMAKE"
@ -21,19 +23,19 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple):
path: Optional[str]
line: Optional[int]
char: Optional[int]
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: Optional[str]
replacement: Optional[str]
description: Optional[str]
original: str | None
replacement: str | None
description: str | None
# CMakeLists.txt:901: Lines should be <= 80 characters long [linelength]
RESULTS_RE: Pattern[str] = re.compile(
RESULTS_RE: re.Pattern[str] = re.compile(
r"""(?mx)
^
(?P<file>.*?):
@ -46,8 +48,8 @@ RESULTS_RE: Pattern[str] = re.compile(
def run_command(
args: List[str],
) -> "subprocess.CompletedProcess[bytes]":
args: list[str],
) -> subprocess.CompletedProcess[bytes]:
logging.debug("$ %s", " ".join(args))
start_time = time.monotonic()
try:
@ -63,7 +65,7 @@ def run_command(
def check_file(
filename: str,
config: str,
) -> List[LintMessage]:
) -> list[LintMessage]:
try:
proc = run_command(
["cmakelint", f"--config={config}", filename],

View File

@ -2,13 +2,15 @@
CONSTEXPR: Ensures users don't use vanilla constexpr since it causes issues
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
from enum import Enum
from typing import NamedTuple, Optional
from typing import NamedTuple
CONSTEXPR = "constexpr char"
CONSTEXPR_MACRO = "CONSTEXPR_EXCEPT_WIN_CUDA char"
@ -21,18 +23,18 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple):
path: Optional[str]
line: Optional[int]
char: Optional[int]
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: Optional[str]
replacement: Optional[str]
description: Optional[str]
original: str | None
replacement: str | None
description: str | None
def check_file(filename: str) -> Optional[LintMessage]:
def check_file(filename: str) -> LintMessage | None:
logging.debug("Checking file %s", filename)
with open(filename) as f:

View File

@ -1,14 +1,17 @@
"""
EXEC: Ensure that source files are not executable.
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import sys
from enum import Enum
from typing import NamedTuple, Optional
from typing import NamedTuple
LINTER_CODE = "EXEC"
@ -21,18 +24,18 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple):
path: Optional[str]
line: Optional[int]
char: Optional[int]
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: Optional[str]
replacement: Optional[str]
description: Optional[str]
original: str | None
replacement: str | None
description: str | None
def check_file(filename: str) -> Optional[LintMessage]:
def check_file(filename: str) -> LintMessage | None:
is_executable = os.access(filename, os.X_OK)
if is_executable:
return LintMessage(

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import argparse
import json
import logging
@ -7,7 +9,7 @@ import subprocess
import sys
import time
from enum import Enum
from typing import Any, Dict, List, NamedTuple, Optional, Pattern, Set
from typing import Any, NamedTuple
IS_WINDOWS: bool = os.name == "nt"
@ -25,15 +27,15 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple):
path: Optional[str]
line: Optional[int]
char: Optional[int]
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: Optional[str]
replacement: Optional[str]
description: Optional[str]
original: str | None
replacement: str | None
description: str | None
def as_posix(name: str) -> str:
@ -42,7 +44,7 @@ def as_posix(name: str) -> str:
# fmt: off
# https://www.flake8rules.com/
DOCUMENTED_IN_FLAKE8RULES: Set[str] = {
DOCUMENTED_IN_FLAKE8RULES: set[str] = {
"E101", "E111", "E112", "E113", "E114", "E115", "E116", "E117",
"E121", "E122", "E123", "E124", "E125", "E126", "E127", "E128", "E129",
"E131", "E133",
@ -78,14 +80,14 @@ DOCUMENTED_IN_FLAKE8RULES: Set[str] = {
}
# https://pypi.org/project/flake8-comprehensions/#rules
DOCUMENTED_IN_FLAKE8COMPREHENSIONS: Set[str] = {
DOCUMENTED_IN_FLAKE8COMPREHENSIONS: set[str] = {
"C400", "C401", "C402", "C403", "C404", "C405", "C406", "C407", "C408", "C409",
"C410",
"C411", "C412", "C413", "C414", "C415", "C416",
}
# https://github.com/PyCQA/flake8-bugbear#list-of-warnings
DOCUMENTED_IN_BUGBEAR: Set[str] = {
DOCUMENTED_IN_BUGBEAR: set[str] = {
"B001", "B002", "B003", "B004", "B005", "B006", "B007", "B008", "B009", "B010",
"B011", "B012", "B013", "B014", "B015",
"B301", "B302", "B303", "B304", "B305", "B306",
@ -98,7 +100,7 @@ DOCUMENTED_IN_BUGBEAR: Set[str] = {
# stdin:3:6: T484 Name 'foo' is not defined
# stdin:3:-100: W605 invalid escape sequence '\/'
# stdin:3:1: E302 expected 2 blank lines, found 1
RESULTS_RE: Pattern[str] = re.compile(
RESULTS_RE: re.Pattern[str] = re.compile(
r"""(?mx)
^
(?P<file>.*?):
@ -134,10 +136,10 @@ def _test_results_re() -> None:
def _run_command(
args: List[str],
args: list[str],
*,
extra_env: Optional[Dict[str, str]],
) -> "subprocess.CompletedProcess[str]":
extra_env: dict[str, str] | None,
) -> subprocess.CompletedProcess[str]:
logging.debug(
"$ %s",
" ".join(
@ -158,11 +160,11 @@ def _run_command(
def run_command(
args: List[str],
args: list[str],
*,
extra_env: Optional[Dict[str, str]],
extra_env: dict[str, str] | None,
retries: int,
) -> "subprocess.CompletedProcess[str]":
) -> subprocess.CompletedProcess[str]:
remaining_retries = retries
while True:
try:
@ -243,11 +245,11 @@ def get_issue_documentation_url(code: str) -> str:
def check_files(
filenames: List[str],
flake8_plugins_path: Optional[str],
severities: Dict[str, LintSeverity],
filenames: list[str],
flake8_plugins_path: str | None,
severities: dict[str, LintSeverity],
retries: int,
) -> List[LintMessage]:
) -> list[LintMessage]:
try:
proc = run_command(
[sys.executable, "-mflake8", "--exit-zero"] + filenames,
@ -351,7 +353,7 @@ def main() -> None:
else os.path.realpath(args.flake8_plugins_path)
)
severities: Dict[str, LintSeverity] = {}
severities: dict[str, LintSeverity] = {}
if args.severity:
for severity in args.severity:
parts = severity.split(":", 1)

View File

@ -2,6 +2,8 @@
Generic linter that greps for a pattern and optionally suggests replacements.
"""
from __future__ import annotations
import argparse
import json
import logging
@ -10,7 +12,7 @@ import subprocess
import sys
import time
from enum import Enum
from typing import Any, List, NamedTuple, Optional
from typing import Any, NamedTuple
IS_WINDOWS: bool = os.name == "nt"
@ -28,15 +30,15 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple):
path: Optional[str]
line: Optional[int]
char: Optional[int]
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: Optional[str]
replacement: Optional[str]
description: Optional[str]
original: str | None
replacement: str | None
description: str | None
def as_posix(name: str) -> str:
@ -44,8 +46,8 @@ def as_posix(name: str) -> str:
def run_command(
args: List[str],
) -> "subprocess.CompletedProcess[bytes]":
args: list[str],
) -> subprocess.CompletedProcess[bytes]:
logging.debug("$ %s", " ".join(args))
start_time = time.monotonic()
try:
@ -65,7 +67,7 @@ def lint_file(
linter_name: str,
error_name: str,
error_description: str,
) -> Optional[LintMessage]:
) -> LintMessage | None:
# matching_line looks like:
# tools/linter/clangtidy_linter.py:13:import foo.bar.baz
split = matching_line.split(":")

View File

@ -1,8 +1,10 @@
from __future__ import annotations
import json
import subprocess
import sys
from enum import Enum
from typing import NamedTuple, Optional, Tuple
from typing import NamedTuple
LINTER_CODE = "LINTRUNNER_VERSION"
@ -16,18 +18,18 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple):
path: Optional[str]
line: Optional[int]
char: Optional[int]
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: Optional[str]
replacement: Optional[str]
description: Optional[str]
original: str | None
replacement: str | None
description: str | None
def toVersionString(version_tuple: Tuple[int, int, int]) -> str:
def toVersionString(version_tuple: tuple[int, int, int]) -> str:
return ".".join(str(x) for x in version_tuple)

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import argparse
import json
import logging
@ -8,7 +10,7 @@ import sys
import time
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, NamedTuple, Optional, Pattern
from typing import Any, NamedTuple
IS_WINDOWS: bool = os.name == "nt"
@ -26,15 +28,15 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple):
path: Optional[str]
line: Optional[int]
char: Optional[int]
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: Optional[str]
replacement: Optional[str]
description: Optional[str]
original: str | None
replacement: str | None
description: str | None
def as_posix(name: str) -> str:
@ -42,7 +44,7 @@ def as_posix(name: str) -> str:
# tools/linter/flake8_linter.py:15:13: error: Incompatibl...int") [assignment]
RESULTS_RE: Pattern[str] = re.compile(
RESULTS_RE: re.Pattern[str] = re.compile(
r"""(?mx)
^
(?P<file>.*?):
@ -56,7 +58,7 @@ RESULTS_RE: Pattern[str] = re.compile(
)
# torch/_dynamo/variables/tensor.py:363: error: INTERNAL ERROR
INTERNAL_ERROR_RE: Pattern[str] = re.compile(
INTERNAL_ERROR_RE: re.Pattern[str] = re.compile(
r"""(?mx)
^
(?P<file>.*?):
@ -69,11 +71,11 @@ INTERNAL_ERROR_RE: Pattern[str] = re.compile(
def run_command(
args: List[str],
args: list[str],
*,
extra_env: Optional[Dict[str, str]],
extra_env: dict[str, str] | None,
retries: int,
) -> "subprocess.CompletedProcess[bytes]":
) -> subprocess.CompletedProcess[bytes]:
logging.debug("$ %s", " ".join(args))
start_time = time.monotonic()
try:
@ -94,7 +96,7 @@ severities = {
}
def check_mypy_installed(code: str) -> List[LintMessage]:
def check_mypy_installed(code: str) -> list[LintMessage]:
cmd = [sys.executable, "-mmypy", "-V"]
try:
subprocess.run(cmd, check=True, capture_output=True)
@ -117,11 +119,11 @@ def check_mypy_installed(code: str) -> List[LintMessage]:
def check_files(
filenames: List[str],
filenames: list[str],
config: str,
retries: int,
code: str,
) -> List[LintMessage]:
) -> list[LintMessage]:
# dmypy has a bug where it won't pick up changes if you pass it absolute
# file names, see https://github.com/python/mypy/issues/16768
filenames = [os.path.relpath(f) for f in filenames]
@ -224,7 +226,7 @@ def main() -> None:
# Use a dictionary here to preserve order. mypy cares about order,
# tragically, e.g. https://github.com/python/mypy/issues/2015
filenames: Dict[str, bool] = {}
filenames: dict[str, bool] = {}
# If a stub file exists, have mypy check it instead of the original file, in
# accordance with PEP-484 (see https://www.python.org/dev/peps/pep-0484/#stub-files)

View File

@ -14,12 +14,14 @@ is simply to make sure that there is *some* configuration of ruamel that can rou
the YAML, not to be prescriptive about it.
"""
from __future__ import annotations
import argparse
import json
import sys
from enum import Enum
from io import StringIO
from typing import NamedTuple, Optional
from typing import NamedTuple
import ruamel.yaml # type: ignore[import]
@ -32,15 +34,15 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple):
path: Optional[str]
line: Optional[int]
char: Optional[int]
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: Optional[str]
replacement: Optional[str]
description: Optional[str]
original: str | None
replacement: str | None
description: str | None
if __name__ == "__main__":

View File

@ -1,13 +1,16 @@
"""
NEWLINE: Checks files to make sure there are no trailing newlines.
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
from enum import Enum
from typing import List, NamedTuple, Optional
from typing import NamedTuple
NEWLINE = 10 # ASCII "\n"
CARRIAGE_RETURN = 13 # ASCII "\r"
@ -22,18 +25,18 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple):
path: Optional[str]
line: Optional[int]
char: Optional[int]
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: Optional[str]
replacement: Optional[str]
description: Optional[str]
original: str | None
replacement: str | None
description: str | None
def check_file(filename: str) -> Optional[LintMessage]:
def check_file(filename: str) -> LintMessage | None:
logging.debug("Checking file %s", filename)
with open(filename, "rb") as f:
@ -85,7 +88,7 @@ def check_file(filename: str) -> Optional[LintMessage]:
description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
)
has_changes = False
original_lines: Optional[List[bytes]] = None
original_lines: list[bytes] | None = None
for idx, line in enumerate(lines):
if len(line) >= 2 and line[-1] == NEWLINE and line[-2] == CARRIAGE_RETURN:
if not has_changes:

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import argparse
import concurrent.futures
import json
@ -5,7 +7,7 @@ import logging
import os
import sys
from enum import Enum
from typing import Any, List, NamedTuple, Optional
from typing import Any, NamedTuple
IS_WINDOWS: bool = os.name == "nt"
@ -23,18 +25,18 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple):
path: Optional[str]
line: Optional[int]
char: Optional[int]
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: Optional[str]
replacement: Optional[str]
description: Optional[str]
original: str | None
replacement: str | None
description: str | None
def check_file(filename: str) -> List[LintMessage]:
def check_file(filename: str) -> list[LintMessage]:
with open(filename, "rb") as f:
original = f.read().decode("utf-8")
replacement = ""

View File

@ -1,6 +1,9 @@
"""
Initializer script that installs stuff to pip.
"""
from __future__ import annotations
import argparse
import logging
import os
@ -9,10 +12,8 @@ import subprocess
import sys
import time
from typing import List
def run_command(args: List[str]) -> "subprocess.CompletedProcess[bytes]":
def run_command(args: list[str]) -> subprocess.CompletedProcess[bytes]:
logging.debug("$ %s", " ".join(args))
start_time = time.monotonic()
try:

View File

@ -14,6 +14,7 @@ import sys
import time
from typing import Any, BinaryIO
LINTER_CODE = "RUFF"
IS_WINDOWS: bool = os.name == "nt"

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import argparse
import json
import logging
@ -6,7 +8,7 @@ import subprocess
import sys
import time
from enum import Enum
from typing import List, NamedTuple, Optional
from typing import NamedTuple
LINTER_CODE = "SHELLCHECK"
@ -20,20 +22,20 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple):
path: Optional[str]
line: Optional[int]
char: Optional[int]
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: Optional[str]
replacement: Optional[str]
description: Optional[str]
original: str | None
replacement: str | None
description: str | None
def run_command(
args: List[str],
) -> "subprocess.CompletedProcess[bytes]":
args: list[str],
) -> subprocess.CompletedProcess[bytes]:
logging.debug("$ %s", " ".join(args))
start_time = time.monotonic()
try:
@ -47,8 +49,8 @@ def run_command(
def check_files(
files: List[str],
) -> List[LintMessage]:
files: list[str],
) -> list[LintMessage]:
try:
proc = run_command(
["shellcheck", "--external-sources", "--format=json1"] + files

View File

@ -6,15 +6,19 @@ calls run_tests to ensure that the test will be run in OSS CI.
Takes ~2 minuters to run without the multiprocessing, probably overkill.
"""
from __future__ import annotations
import argparse
import json
import multiprocessing as mp
from enum import Enum
from typing import List, NamedTuple, Optional
from typing import NamedTuple
import libcst as cst
import libcst.matchers as m
LINTER_CODE = "TEST_HAS_MAIN"
@ -62,18 +66,18 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple):
path: Optional[str]
line: Optional[int]
char: Optional[int]
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: Optional[str]
replacement: Optional[str]
description: Optional[str]
original: str | None
replacement: str | None
description: str | None
def check_file(filename: str) -> List[LintMessage]:
def check_file(filename: str) -> list[LintMessage]:
lint_messages = []
with open(filename) as f:

View File

@ -8,10 +8,13 @@ has valid ownership information in a comment header. Valid means:
- Each owner label actually exists in PyTorch
- Each owner label starts with "module: " or "oncall: " or is in ACCEPTABLE_OWNER_LABELS
"""
from __future__ import annotations
import argparse
import json
from enum import Enum
from typing import Any, List, NamedTuple, Optional
from typing import Any, NamedTuple
from urllib.request import urlopen
@ -26,15 +29,15 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple):
path: Optional[str]
line: Optional[int]
char: Optional[int]
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: Optional[str]
replacement: Optional[str]
description: Optional[str]
original: str | None
replacement: str | None
description: str | None
# Team/owner labels usually start with "module: " or "oncall: ", but the following are acceptable exceptions
@ -58,8 +61,8 @@ GLOB_EXCEPTIONS = ["**/test/run_test.py"]
def check_labels(
labels: List[str], filename: str, line_number: int
) -> List[LintMessage]:
labels: list[str], filename: str, line_number: int
) -> list[LintMessage]:
lint_messages = []
for label in labels:
if label not in PYTORCH_LABELS:
@ -104,7 +107,7 @@ def check_labels(
return lint_messages
def check_file(filename: str) -> List[LintMessage]:
def check_file(filename: str) -> list[LintMessage]:
lint_messages = []
has_ownership_info = False

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import argparse
import concurrent.futures
import json
@ -6,7 +8,7 @@ import os
import sys
from enum import Enum
from pathlib import Path
from typing import Any, List, NamedTuple, Optional
from typing import Any, NamedTuple
from ufmt.core import ufmt_string
from ufmt.util import make_black_config
@ -28,15 +30,15 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple):
path: Optional[str]
line: Optional[int]
char: Optional[int]
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: Optional[str]
replacement: Optional[str]
description: Optional[str]
original: str | None
replacement: str | None
description: str | None
def as_posix(name: str) -> str:
@ -59,7 +61,7 @@ def format_error_message(filename: str, err: Exception) -> LintMessage:
def check_file(
filename: str,
) -> List[LintMessage]:
) -> list[LintMessage]:
with open(filename, "rb") as f:
original = f.read().decode("utf-8")

View File

@ -2,16 +2,20 @@
Any job with a specific `sync-tag` must match all other jobs with the same `sync-tag`.
"""
from __future__ import annotations
import argparse
import itertools
import json
from collections import defaultdict
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Iterable, NamedTuple, Optional
from typing import Any, Iterable, NamedTuple
from yaml import dump, load
# Safely load fast C Yaml loader/dumper if they are available
try:
from yaml import CSafeLoader as Loader
@ -27,15 +31,15 @@ class LintSeverity(str, Enum):
class LintMessage(NamedTuple):
path: Optional[str]
line: Optional[int]
char: Optional[int]
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: Optional[str]
replacement: Optional[str]
description: Optional[str]
original: str | None
replacement: str | None
description: str | None
def glob_yamls(path: Path) -> Iterable[Path]:
@ -51,7 +55,7 @@ def is_workflow(yaml: Any) -> bool:
return yaml.get("jobs") is not None
def print_lint_message(path: Path, job: Dict[str, Any], sync_tag: str) -> None:
def print_lint_message(path: Path, job: dict[str, Any], sync_tag: str) -> None:
job_id = next(iter(job.keys()))
with open(path) as f:
lines = f.readlines()

View File

@ -1,10 +1,11 @@
from __future__ import annotations
import os
import subprocess
import sys
from typing import List
def run_cmd(cmd: List[str]) -> None:
def run_cmd(cmd: list[str]) -> None:
print(f"Running: {cmd}")
result = subprocess.run(
cmd,

View File

@ -1,13 +1,16 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import os
from typing import Set
import yaml
from torchgen.code_template import CodeTemplate
from torchgen.selective_build.selector import SelectiveBuilder
# Safely load fast C Yaml loader/dumper if they are available
try:
from yaml import CSafeLoader as Loader
@ -46,7 +49,7 @@ selected_mobile_ops_preamble = """#pragma once
"""
def extract_root_operators(selective_builder: SelectiveBuilder) -> Set[str]:
def extract_root_operators(selective_builder: SelectiveBuilder) -> set[str]:
ops = []
for op_name, op in selective_builder.operators.items():
if op.is_root_operator:
@ -125,7 +128,7 @@ def write_selected_mobile_ops(
# 2. All kernel dtypes
def write_selected_mobile_ops_with_all_dtypes(
output_file_path: str,
root_ops: Set[str],
root_ops: set[str],
) -> None:
with open(output_file_path, "wb") as out_file:
body_parts = [selected_mobile_ops_preamble]

View File

@ -1,5 +1,6 @@
import lldb # type: ignore[import]
# load into lldb instance with:
# command script import tools/lldb/deploy_debugger.py

View File

@ -24,6 +24,9 @@ well. This can be done with
Pulling will reinstalle the conda dependencies as well as the nightly binaries into
the repo directory.
"""
from __future__ import annotations
import contextlib
import datetime
import functools
@ -40,23 +43,10 @@ import time
import uuid
from argparse import ArgumentParser
from ast import literal_eval
from typing import (
Any,
Callable,
cast,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
)
from typing import Any, Callable, cast, Generator, Iterable, Iterator, Sequence, TypeVar
LOGGER: Optional[logging.Logger] = None
LOGGER: logging.Logger | None = None
URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2"
DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss"
SHA1_RE = re.compile("([0-9a-fA-F]{40})")
@ -68,9 +58,9 @@ SPECS_TO_INSTALL = ("pytorch", "mypy", "pytest", "hypothesis", "ipython", "sphin
class Formatter(logging.Formatter):
redactions: Dict[str, str]
redactions: dict[str, str]
def __init__(self, fmt: Optional[str] = None, datefmt: Optional[str] = None):
def __init__(self, fmt: str | None = None, datefmt: str | None = None) -> None:
super().__init__(fmt, datefmt)
self.redactions = {}
@ -192,7 +182,7 @@ def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, N
sys.exit(1)
def check_in_repo() -> Optional[str]:
def check_in_repo() -> str | None:
"""Ensures that we are in the PyTorch repo."""
if not os.path.isfile("setup.py"):
return "Not in root-level PyTorch repo, no setup.py found"
@ -203,7 +193,7 @@ def check_in_repo() -> Optional[str]:
return None
def check_branch(subcommand: str, branch: Optional[str]) -> Optional[str]:
def check_branch(subcommand: str, branch: str | None) -> str | None:
"""Checks that the branch name can be checked out."""
if subcommand != "checkout":
return None
@ -259,7 +249,7 @@ def timed(prefix: str) -> Callable[[F], F]:
def _make_channel_args(
channels: Iterable[str] = ("pytorch-nightly",),
override_channels: bool = False,
) -> List[str]:
) -> list[str]:
args = []
for channel in channels:
args.append("--channel")
@ -271,11 +261,11 @@ def _make_channel_args(
@timed("Solving conda environment")
def conda_solve(
name: Optional[str] = None,
prefix: Optional[str] = None,
name: str | None = None,
prefix: str | None = None,
channels: Iterable[str] = ("pytorch-nightly",),
override_channels: bool = False,
) -> Tuple[List[str], str, str, bool, List[str]]:
) -> tuple[list[str], str, str, bool, list[str]]:
"""Performs the conda solve and splits the deps from the package."""
# compute what environment to use
if prefix is not None:
@ -329,7 +319,7 @@ def conda_solve(
@timed("Installing dependencies")
def deps_install(deps: List[str], existing_env: bool, env_opts: List[str]) -> None:
def deps_install(deps: list[str], existing_env: bool, env_opts: list[str]) -> None:
"""Install dependencies to deps environment"""
if not existing_env:
# first remove previous pytorch-deps env
@ -342,7 +332,7 @@ def deps_install(deps: List[str], existing_env: bool, env_opts: List[str]) -> No
@timed("Installing pytorch nightly binaries")
def pytorch_install(url: str) -> "tempfile.TemporaryDirectory[str]":
def pytorch_install(url: str) -> tempfile.TemporaryDirectory[str]:
"""Install pytorch into a temporary directory"""
pytdir = tempfile.TemporaryDirectory()
cmd = ["conda", "create", "--yes", "--no-deps", "--prefix", pytdir.name, url]
@ -421,33 +411,33 @@ def pull_nightly_version(spdir: str) -> None:
p = subprocess.run(cmd, check=True)
def _get_listing_linux(source_dir: str) -> List[str]:
def _get_listing_linux(source_dir: str) -> list[str]:
listing = glob.glob(os.path.join(source_dir, "*.so"))
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.so")))
return listing
def _get_listing_osx(source_dir: str) -> List[str]:
def _get_listing_osx(source_dir: str) -> list[str]:
# oddly, these are .so files even on Mac
listing = glob.glob(os.path.join(source_dir, "*.so"))
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dylib")))
return listing
def _get_listing_win(source_dir: str) -> List[str]:
def _get_listing_win(source_dir: str) -> list[str]:
listing = glob.glob(os.path.join(source_dir, "*.pyd"))
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.lib")))
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dll")))
return listing
def _glob_pyis(d: str) -> Set[str]:
def _glob_pyis(d: str) -> set[str]:
search = os.path.join(d, "**", "*.pyi")
pyis = {os.path.relpath(p, d) for p in glob.iglob(search)}
return pyis
def _find_missing_pyi(source_dir: str, target_dir: str) -> List[str]:
def _find_missing_pyi(source_dir: str, target_dir: str) -> list[str]:
source_pyis = _glob_pyis(source_dir)
target_pyis = _glob_pyis(target_dir)
missing_pyis = [os.path.join(source_dir, p) for p in (source_pyis - target_pyis)]
@ -455,7 +445,7 @@ def _find_missing_pyi(source_dir: str, target_dir: str) -> List[str]:
return missing_pyis
def _get_listing(source_dir: str, target_dir: str, platform: str) -> List[str]:
def _get_listing(source_dir: str, target_dir: str, platform: str) -> list[str]:
if platform.startswith("linux"):
listing = _get_listing_linux(source_dir)
elif platform.startswith("osx"):
@ -510,12 +500,12 @@ def _move_single(
mover(src, trg)
def _copy_files(listing: List[str], source_dir: str, target_dir: str) -> None:
def _copy_files(listing: list[str], source_dir: str, target_dir: str) -> None:
for src in listing:
_move_single(src, source_dir, target_dir, shutil.copy2, "Copying")
def _link_files(listing: List[str], source_dir: str, target_dir: str) -> None:
def _link_files(listing: list[str], source_dir: str, target_dir: str) -> None:
for src in listing:
_move_single(src, source_dir, target_dir, os.link, "Linking")
@ -537,7 +527,7 @@ def move_nightly_files(spdir: str, platform: str) -> None:
_copy_files(listing, source_dir, target_dir)
def _available_envs() -> Dict[str, str]:
def _available_envs() -> dict[str, str]:
cmd = ["conda", "env", "list"]
p = subprocess.run(
cmd,
@ -559,7 +549,7 @@ def _available_envs() -> Dict[str, str]:
@timed("Writing pytorch-nightly.pth")
def write_pth(env_opts: List[str], platform: str) -> None:
def write_pth(env_opts: list[str], platform: str) -> None:
"""Writes Python path file for this dir."""
env_type, env_dir = env_opts
if env_type == "--name":
@ -582,9 +572,9 @@ def install(
*,
logger: logging.Logger,
subcommand: str = "checkout",
branch: Optional[str] = None,
name: Optional[str] = None,
prefix: Optional[str] = None,
branch: str | None = None,
name: str | None = None,
prefix: str | None = None,
channels: Iterable[str] = ("pytorch-nightly",),
override_channels: bool = False,
) -> None:
@ -673,7 +663,7 @@ def make_parser() -> ArgumentParser:
return p
def main(args: Optional[Sequence[str]] = None) -> None:
def main(args: Sequence[str] | None = None) -> None:
"""Main entry point"""
global LOGGER
p = make_parser()

View File

@ -13,13 +13,15 @@ CMAKE_CUDA_COMPILER_LAUNCHER="python;tools/nvcc_fix_deps.py;ccache"
"""
from __future__ import annotations
import subprocess
import sys
from pathlib import Path
from typing import List, Optional, TextIO
from typing import TextIO
def resolve_include(path: Path, include_dirs: List[Path]) -> Path:
def resolve_include(path: Path, include_dirs: list[Path]) -> Path:
for include_path in include_dirs:
abs_path = include_path / path
if abs_path.exists():
@ -36,7 +38,7 @@ Tried the following paths, but none existed:
)
def repair_depfile(depfile: TextIO, include_dirs: List[Path]) -> None:
def repair_depfile(depfile: TextIO, include_dirs: list[Path]) -> None:
changes_made = False
out = ""
for line in depfile:
@ -70,8 +72,8 @@ PRE_INCLUDE_ARGS = ["-include", "--pre-include"]
POST_INCLUDE_ARGS = ["-I", "--include-path", "-isystem", "--system-include"]
def extract_include_arg(include_dirs: List[Path], i: int, args: List[str]) -> None:
def extract_one(name: str, i: int, args: List[str]) -> Optional[str]:
def extract_include_arg(include_dirs: list[Path], i: int, args: list[str]) -> None:
def extract_one(name: str, i: int, args: list[str]) -> str | None:
arg = args[i]
if arg == name:
return args[i + 1]

View File

@ -24,6 +24,7 @@ import yaml
from torchgen import utils as torchgen_utils
from torchgen.yaml_utils import YamlLoader
_RULES_GENERATED_COMMENT = """\
GENERATED CODE - DO NOT EDIT DIRECTLY
This file is generated by gen_diagnostics.py.

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import argparse
import collections
import importlib
import sys
from pprint import pformat
from typing import Dict, List, Sequence
from typing import Sequence
from unittest.mock import Mock, patch
from warnings import warn
@ -220,7 +222,7 @@ to_py_type_ops = ("bool", "float", "complex", "long", "index", "int", "nonzero")
all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops
def sig_for_ops(opname: str) -> List[str]:
def sig_for_ops(opname: str) -> list[str]:
"""sig_for_ops(opname : str) -> List[str]
Returns signatures for operator special functions (__add__ etc.)"""
@ -254,8 +256,8 @@ def sig_for_ops(opname: str) -> List[str]:
raise Exception("unknown op", opname) # noqa: TRY002
def generate_type_hints(sig_group: PythonSignatureGroup) -> List[str]:
type_hints: List[str] = []
def generate_type_hints(sig_group: PythonSignatureGroup) -> list[str]:
type_hints: list[str] = []
# Some deprecated ops that are on the blocklist are still included in pyi
if sig_group.signature.name in blocklist and not sig_group.signature.deprecated:
@ -285,7 +287,7 @@ def generate_type_hints(sig_group: PythonSignatureGroup) -> List[str]:
return type_hints
def get_max_pool_dispatch(name: str, arg_list: List[str]) -> Dict[str, List[str]]:
def get_max_pool_dispatch(name: str, arg_list: list[str]) -> dict[str, list[str]]:
flag_pos = arg_list.index("{return_indices}")
# If return_indices is positional arg, everything before should have no default
arg_list_positional = (
@ -329,7 +331,7 @@ def gen_nn_functional(fm: FileManager) -> None:
)
# TODO the list for `torch._C._nn` is nonexhaustive
unsorted_c_nn_function_hints: Dict[str, List[str]] = {}
unsorted_c_nn_function_hints: dict[str, list[str]] = {}
for d in (2, 3):
unsorted_c_nn_function_hints.update(
@ -471,7 +473,7 @@ def gen_nn_functional(fm: FileManager) -> None:
}
)
c_nn_function_hints: List[str] = []
c_nn_function_hints: list[str] = []
for _, hints in sorted(unsorted_c_nn_function_hints.items()):
if len(hints) > 1:
hints = ["@overload\n" + h for h in hints]
@ -528,7 +530,7 @@ def gen_nn_functional(fm: FileManager) -> None:
)
# Functions generated by `torch._jit_internal.boolean_dispatch` in `nn.functional`
unsorted_dispatched_hints: Dict[str, List[str]] = {}
unsorted_dispatched_hints: dict[str, list[str]] = {}
for d in (1, 2, 3):
unsorted_dispatched_hints.update(
@ -563,7 +565,7 @@ def gen_nn_functional(fm: FileManager) -> None:
# There's no fractional_max_pool1d
del unsorted_dispatched_hints["fractional_max_pool1d"]
dispatched_hints: List[str] = []
dispatched_hints: list[str] = []
for _, hints in sorted(unsorted_dispatched_hints.items()):
if len(hints) > 1:
hints = ["@overload\n" + h for h in hints]
@ -594,7 +596,7 @@ We gather the docstrings for torch with the following steps:
"""
def gather_docstrs() -> Dict[str, str]:
def gather_docstrs() -> dict[str, str]:
docstrs = {}
def mock_add_docstr(func: Mock, docstr: str) -> None:
@ -648,12 +650,12 @@ def gen_pyi(
# also needs to update the other file.
# Dictionary for NamedTuple definitions
structseqs: Dict[str, str] = {}
structseqs: dict[str, str] = {}
# Generate type signatures for top-level functions
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
unsorted_function_hints: Dict[str, List[str]] = collections.defaultdict(list)
unsorted_function_hints: dict[str, list[str]] = collections.defaultdict(list)
for n, n1, n2 in [
("csr", "crow", "col"),
@ -1054,7 +1056,7 @@ def gen_pyi(
# Generate type signatures for Tensor methods
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
unsorted_tensor_method_hints: Dict[str, List[str]] = collections.defaultdict(list)
unsorted_tensor_method_hints: dict[str, list[str]] = collections.defaultdict(list)
unsorted_tensor_method_hints.update(
{
"size": [

View File

@ -1,8 +1,11 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import os
from typing import Any, List, Union
from typing import Any
try:
from junitparser import ( # type: ignore[import]
@ -23,8 +26,8 @@ except ImportError:
print("rich not found, for color output use 'pip install rich'")
def parse_junit_reports(path_to_reports: str) -> List[TestCase]: # type: ignore[no-any-unimported]
def parse_file(path: str) -> List[TestCase]: # type: ignore[no-any-unimported]
def parse_junit_reports(path_to_reports: str) -> list[TestCase]: # type: ignore[no-any-unimported]
def parse_file(path: str) -> list[TestCase]: # type: ignore[no-any-unimported]
try:
return convert_junit_to_testcases(JUnitXml.fromfile(path))
except Exception as err:
@ -46,7 +49,7 @@ def parse_junit_reports(path_to_reports: str) -> List[TestCase]: # type: ignore
return ret_xml
def convert_junit_to_testcases(xml: Union[JUnitXml, TestSuite]) -> List[TestCase]: # type: ignore[no-any-unimported]
def convert_junit_to_testcases(xml: JUnitXml | TestSuite) -> list[TestCase]: # type: ignore[no-any-unimported]
testcases = []
for item in xml:
if isinstance(item, TestSuite):
@ -56,7 +59,7 @@ def convert_junit_to_testcases(xml: Union[JUnitXml, TestSuite]) -> List[TestCase
return testcases
def render_tests(testcases: List[TestCase]) -> None: # type: ignore[no-any-unimported]
def render_tests(testcases: list[TestCase]) -> None: # type: ignore[no-any-unimported]
num_passed = 0
num_skipped = 0
num_failed = 0

View File

@ -1,9 +1,10 @@
from __future__ import annotations
import os
import sys
from typing import Optional
def which(thefile: str) -> Optional[str]:
def which(thefile: str) -> str | None:
path = os.environ.get("PATH", os.defpath).split(os.pathsep)
for d in path:
fname = os.path.join(d, thefile)

View File

@ -1,5 +1,6 @@
"Manages CMake."
from __future__ import annotations
import multiprocessing
import os
@ -8,7 +9,7 @@ import sys
import sysconfig
from distutils.version import LooseVersion
from subprocess import CalledProcessError, check_call, check_output
from typing import Any, cast, Dict, List, Optional
from typing import Any, cast
from . import which
from .cmake_utils import CMakeValue, get_cmake_cache_variables_from_file
@ -77,7 +78,7 @@ class CMake:
return cmake_command
@staticmethod
def _get_version(cmd: Optional[str]) -> Any:
def _get_version(cmd: str | None) -> Any:
"Returns cmake version."
if cmd is None:
@ -87,7 +88,7 @@ class CMake:
return LooseVersion(line.strip().split(" ")[2])
raise RuntimeError("no version found")
def run(self, args: List[str], env: Dict[str, str]) -> None:
def run(self, args: list[str], env: dict[str, str]) -> None:
"Executes cmake with arguments and an environment."
command = [self._cmake_command] + args
@ -101,13 +102,13 @@ class CMake:
sys.exit(1)
@staticmethod
def defines(args: List[str], **kwargs: CMakeValue) -> None:
def defines(args: list[str], **kwargs: CMakeValue) -> None:
"Adds definitions to a cmake argument list."
for key, value in sorted(kwargs.items()):
if value is not None:
args.append(f"-D{key}={value}")
def get_cmake_cache_variables(self) -> Dict[str, CMakeValue]:
def get_cmake_cache_variables(self) -> dict[str, CMakeValue]:
r"""Gets values in CMakeCache.txt into a dictionary.
Returns:
dict: A ``dict`` containing the value of cached CMake variables.
@ -117,11 +118,11 @@ class CMake:
def generate(
self,
version: Optional[str],
cmake_python_library: Optional[str],
version: str | None,
cmake_python_library: str | None,
build_python: bool,
build_test: bool,
my_env: Dict[str, str],
my_env: dict[str, str],
rerun: bool,
) -> None:
"Runs cmake to generate native build files."
@ -181,7 +182,7 @@ class CMake:
_mkdir_p(self.build_dir)
# Store build options that are directly stored in environment variables
build_options: Dict[str, CMakeValue] = {}
build_options: dict[str, CMakeValue] = {}
# Build options that do not start with "BUILD_", "USE_", or "CMAKE_" and are directly controlled by env vars.
# This is a dict that maps environment variables to the corresponding variable name in CMake.
@ -340,7 +341,7 @@ class CMake:
args.append(base_dir)
self.run(args, env=my_env)
def build(self, my_env: Dict[str, str]) -> None:
def build(self, my_env: dict[str, str]) -> None:
"Runs cmake to build binaries."
from .env import build_type

View File

@ -3,8 +3,10 @@ This is refactored from cmake.py to avoid circular imports issue with env.py,
which calls get_cmake_cache_variables_from_file
"""
from __future__ import annotations
import re
from typing import Dict, IO, Optional, Union
from typing import IO, Optional, Union
CMakeValue = Optional[Union[bool, str]]
@ -42,7 +44,7 @@ def convert_cmake_value_to_python_value(
def get_cmake_cache_variables_from_file(
cmake_cache_file: IO[str],
) -> Dict[str, CMakeValue]:
) -> dict[str, CMakeValue]:
r"""Gets values in CMakeCache.txt into a dictionary.
Args:

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import os
import platform
import struct
import sys
from itertools import chain
from typing import cast, Iterable, List, Optional
from typing import cast, Iterable
IS_WINDOWS = platform.system() == "Windows"
@ -30,11 +32,11 @@ def check_negative_env_flag(name: str, default: str = "") -> bool:
return os.getenv(name, default).upper() in ["OFF", "0", "NO", "FALSE", "N"]
def gather_paths(env_vars: Iterable[str]) -> List[str]:
def gather_paths(env_vars: Iterable[str]) -> list[str]:
return list(chain(*(os.getenv(v, "").split(os.pathsep) for v in env_vars)))
def lib_paths_from_base(base_path: str) -> List[str]:
def lib_paths_from_base(base_path: str) -> list[str]:
return [os.path.join(base_path, s) for s in ["lib/x64", "lib", "lib64"]]
@ -54,7 +56,7 @@ class BuildType:
"""
def __init__(self, cmake_build_type_env: Optional[str] = None) -> None:
def __init__(self, cmake_build_type_env: str | None = None) -> None:
if cmake_build_type_env is not None:
self.build_type_string = cmake_build_type_env
return

View File

@ -2,9 +2,12 @@
# and use the version numbers from there as substitutions for
# an expand_template action. Since there isn't, this silly script exists.
from __future__ import annotations
import argparse
import os
from typing import cast, Dict, Tuple
from typing import cast, Tuple
Version = Tuple[int, int, int]
@ -30,7 +33,7 @@ def parse_version(version: str) -> Version:
return cast(Version, tuple([int(n) for n in version_number_str.split(".")]))
def apply_replacements(replacements: Dict[str, str], text: str) -> str:
def apply_replacements(replacements: dict[str, str], text: str) -> str:
"""
Applies the given replacements within the text.

View File

@ -1,8 +1,10 @@
from __future__ import annotations
import argparse
import os
import pathlib
import sys
from typing import Any, cast, Optional
from typing import Any, cast
import yaml
@ -18,10 +20,10 @@ TAGS_PATH = "aten/src/ATen/native/tags.yaml"
def generate_code(
gen_dir: pathlib.Path,
native_functions_path: Optional[str] = None,
tags_path: Optional[str] = None,
install_dir: Optional[str] = None,
subset: Optional[str] = None,
native_functions_path: str | None = None,
tags_path: str | None = None,
install_dir: str | None = None,
subset: str | None = None,
disable_autograd: bool = False,
force_schema_registration: bool = False,
operator_selector: Any = None,
@ -102,8 +104,8 @@ def get_selector_from_legacy_operator_selection_list(
def get_selector(
selected_op_list_path: Optional[str],
operators_yaml_path: Optional[str],
selected_op_list_path: str | None,
operators_yaml_path: str | None,
) -> Any:
# cwrap depends on pyyaml, so we can't import it earlier
root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

View File

@ -1,10 +1,12 @@
from __future__ import annotations
import argparse
import json
import os
import xml.etree.ElementTree as ET
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, Generator, Tuple
from typing import Any, Generator
from tools.stats.upload_stats_lib import (
download_s3_artifacts,
@ -14,13 +16,14 @@ from tools.stats.upload_stats_lib import (
)
from tools.stats.upload_test_stats import process_xml_element
TESTCASE_TAG = "testcase"
SEPARATOR = ";"
def process_report(
report: Path,
) -> Dict[str, Dict[str, int]]:
) -> dict[str, dict[str, int]]:
"""
Return a list of disabled tests that should be re-enabled and those that are still
flaky (failed or skipped)
@ -36,7 +39,7 @@ def process_report(
# * Skipped tests from unittest
#
# We want to keep track of how many times the test fails (num_red) or passes (num_green)
all_tests: Dict[str, Dict[str, int]] = {}
all_tests: dict[str, dict[str, int]] = {}
for test_case in root.iter(TESTCASE_TAG):
parsed_test_case = process_xml_element(test_case)
@ -116,7 +119,7 @@ def get_test_reports(
yield from Path(".").glob("**/*.xml")
def get_disabled_test_name(test_id: str) -> Tuple[str, str, str, str]:
def get_disabled_test_name(test_id: str) -> tuple[str, str, str, str]:
"""
Follow flaky bot convention here, if that changes, this will also need to be updated
"""
@ -133,7 +136,7 @@ def prepare_record(
flaky: bool,
num_red: int = 0,
num_green: int = 0,
) -> Tuple[Any, Dict[str, Any]]:
) -> tuple[Any, dict[str, Any]]:
"""
Prepare the record to save onto S3
"""
@ -162,7 +165,7 @@ def prepare_record(
def save_results(
workflow_id: int,
workflow_run_attempt: int,
all_tests: Dict[str, Dict[str, int]],
all_tests: dict[str, dict[str, int]],
) -> None:
"""
Save the result to S3, so it can go to Rockset
@ -228,7 +231,7 @@ def main(repo: str, workflow_run_id: int, workflow_run_attempt: int) -> None:
Find the list of all disabled tests that should be re-enabled
"""
# Aggregated across all jobs
all_tests: Dict[str, Dict[str, int]] = {}
all_tests: dict[str, dict[str, int]] = {}
for report in get_test_reports(
args.repo, args.workflow_run_id, args.workflow_run_attempt

View File

@ -1,17 +1,19 @@
#!/usr/bin/env python3
from __future__ import annotations
import datetime
import json
import os
import pathlib
import shutil
from typing import Any, Callable, cast, Dict, List, Optional, Union
from typing import Any, Callable, cast, Dict
from urllib.request import urlopen
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
def get_disabled_issues() -> List[str]:
def get_disabled_issues() -> list[str]:
reenabled_issues = os.getenv("REENABLED_ISSUES", "")
issue_numbers = reenabled_issues.split(",")
print("Ignoring disabled issues: ", issue_numbers)
@ -34,11 +36,11 @@ FILE_CACHE_LIFESPAN_SECONDS = datetime.timedelta(hours=3).seconds
def fetch_and_cache(
dirpath: Union[str, pathlib.Path],
dirpath: str | pathlib.Path,
name: str,
url: str,
process_fn: Callable[[Dict[str, Any]], Dict[str, Any]],
) -> Dict[str, Any]:
process_fn: Callable[[dict[str, Any]], dict[str, Any]],
) -> dict[str, Any]:
"""
This fetch and cache utils allows sharing between different process.
"""
@ -76,7 +78,7 @@ def fetch_and_cache(
def get_slow_tests(
dirpath: str, filename: str = SLOW_TESTS_FILE
) -> Optional[Dict[str, float]]:
) -> dict[str, float] | None:
url = "https://ossci-metrics.s3.amazonaws.com/slow-tests.json"
try:
return fetch_and_cache(dirpath, filename, url, lambda x: x)
@ -85,7 +87,7 @@ def get_slow_tests(
return {}
def get_test_times() -> Dict[str, Dict[str, float]]:
def get_test_times() -> dict[str, dict[str, float]]:
return get_from_test_infra_generated_stats(
"test-times.json",
TEST_TIMES_FILE,
@ -93,7 +95,7 @@ def get_test_times() -> Dict[str, Dict[str, float]]:
)
def get_test_class_times() -> Dict[str, Dict[str, float]]:
def get_test_class_times() -> dict[str, dict[str, float]]:
return get_from_test_infra_generated_stats(
"test-class-times.json",
TEST_CLASS_TIMES_FILE,
@ -103,8 +105,8 @@ def get_test_class_times() -> Dict[str, Dict[str, float]]:
def get_disabled_tests(
dirpath: str, filename: str = DISABLED_TESTS_FILE
) -> Optional[Dict[str, Any]]:
def process_disabled_test(the_response: Dict[str, Any]) -> Dict[str, Any]:
) -> dict[str, Any] | None:
def process_disabled_test(the_response: dict[str, Any]) -> dict[str, Any]:
# remove re-enabled tests and condense even further by getting rid of pr_num
disabled_issues = get_disabled_issues()
disabled_test_from_issues = dict()
@ -124,7 +126,7 @@ def get_disabled_tests(
return {}
def get_test_file_ratings() -> Dict[str, Any]:
def get_test_file_ratings() -> dict[str, Any]:
return get_from_test_infra_generated_stats(
"file_test_rating.json",
TEST_FILE_RATINGS_FILE,
@ -132,7 +134,7 @@ def get_test_file_ratings() -> Dict[str, Any]:
)
def get_test_class_ratings() -> Dict[str, Any]:
def get_test_class_ratings() -> dict[str, Any]:
return get_from_test_infra_generated_stats(
"file_test_class_rating.json",
TEST_CLASS_RATINGS_FILE,
@ -140,7 +142,7 @@ def get_test_class_ratings() -> Dict[str, Any]:
)
def get_td_heuristic_historial_edited_files_json() -> Dict[str, Any]:
def get_td_heuristic_historial_edited_files_json() -> dict[str, Any]:
return get_from_test_infra_generated_stats(
"td_heuristic_historical_edited_files.json",
TD_HEURISTIC_HISTORICAL_EDITED_FILES,
@ -148,7 +150,7 @@ def get_td_heuristic_historial_edited_files_json() -> Dict[str, Any]:
)
def get_td_heuristic_profiling_json() -> Dict[str, Any]:
def get_td_heuristic_profiling_json() -> dict[str, Any]:
return get_from_test_infra_generated_stats(
"td_heuristic_profiling.json",
TD_HEURISTIC_PROFILING_FILE,
@ -182,7 +184,7 @@ def copy_additional_previous_failures() -> None:
def get_from_test_infra_generated_stats(
from_file: str, to_file: str, failure_explanation: str
) -> Dict[str, Any]:
) -> dict[str, Any]:
url = f"https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/{from_file}"
try:
return fetch_and_cache(

View File

@ -1,14 +1,17 @@
#!/usr/bin/env python3
from __future__ import annotations
import datetime
import json
import signal
import time
from typing import Any, Dict, List
from typing import Any
import psutil # type: ignore[import]
def get_processes_running_python_tests() -> List[Any]:
def get_processes_running_python_tests() -> list[Any]:
python_processes = []
for process in psutil.process_iter():
try:
@ -20,7 +23,7 @@ def get_processes_running_python_tests() -> List[Any]:
return python_processes
def get_per_process_cpu_info() -> List[Dict[str, Any]]:
def get_per_process_cpu_info() -> list[dict[str, Any]]:
processes = get_processes_running_python_tests()
per_process_info = []
for p in processes:
@ -49,7 +52,7 @@ def get_per_process_cpu_info() -> List[Dict[str, Any]]:
return per_process_info
def get_per_process_gpu_info(handle: Any) -> List[Dict[str, Any]]:
def get_per_process_gpu_info(handle: Any) -> list[dict[str, Any]]:
processes = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
per_process_info = []
for p in processes:
@ -58,7 +61,7 @@ def get_per_process_gpu_info(handle: Any) -> List[Dict[str, Any]]:
return per_process_info
def rocm_get_per_process_gpu_info(handle: Any) -> List[Dict[str, Any]]:
def rocm_get_per_process_gpu_info(handle: Any) -> list[dict[str, Any]]:
processes = amdsmi.amdsmi_get_gpu_process_list(handle)
per_process_info = []
for p in processes:

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json
import os
import re
@ -6,7 +8,7 @@ from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, cast, Dict, List
from typing import Any, cast
import requests
@ -18,6 +20,7 @@ from tools.stats.upload_stats_lib import (
upload_workflow_stats_to_s3,
)
REGEX_JOB_INFO = r"(.*) \/ .*test \(([^,]*), .*\)"
@ -56,7 +59,7 @@ def get_test_config(job_name: str) -> str:
def get_td_exclusions(
workflow_run_id: int, workflow_run_attempt: int
) -> Dict[str, Any]:
) -> dict[str, Any]:
with TemporaryDirectory() as temp_dir:
print("Using temporary directory:", temp_dir)
os.chdir(temp_dir)
@ -68,7 +71,7 @@ def get_td_exclusions(
for path in s3_paths:
unzip(path)
grouped_tests: Dict[str, Any] = defaultdict(lambda: defaultdict(set))
grouped_tests: dict[str, Any] = defaultdict(lambda: defaultdict(set))
for td_exclusions in Path(".").glob("**/td_exclusions*.json"):
with open(td_exclusions) as f:
exclusions = json.load(f)
@ -85,9 +88,9 @@ def get_td_exclusions(
return grouped_tests
def group_test_cases(test_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
def group_test_cases(test_cases: list[dict[str, Any]]) -> dict[str, Any]:
start = time.time()
grouped_tests: Dict[str, Any] = defaultdict(
grouped_tests: dict[str, Any] = defaultdict(
lambda: defaultdict(
lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
)
@ -112,8 +115,8 @@ def group_test_cases(test_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
return grouped_tests
def get_reruns(grouped_tests: Dict[str, Any]) -> Dict[str, Any]:
reruns: Dict[str, Any] = defaultdict(
def get_reruns(grouped_tests: dict[str, Any]) -> dict[str, Any]:
reruns: dict[str, Any] = defaultdict(
lambda: defaultdict(
lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
)
@ -136,8 +139,8 @@ def get_reruns(grouped_tests: Dict[str, Any]) -> Dict[str, Any]:
return reruns
def get_invoking_file_summary(grouped_tests: Dict[str, Any]) -> Dict[str, Any]:
invoking_file_summary: Dict[str, Any] = defaultdict(
def get_invoking_file_summary(grouped_tests: dict[str, Any]) -> dict[str, Any]:
invoking_file_summary: dict[str, Any] = defaultdict(
lambda: defaultdict(lambda: defaultdict(lambda: {"count": 0, "time": 0.0}))
)
for build_name, build in grouped_tests.items():
@ -157,7 +160,7 @@ def get_invoking_file_summary(grouped_tests: Dict[str, Any]) -> Dict[str, Any]:
def upload_additional_info(
workflow_run_id: int, workflow_run_attempt: int, test_cases: List[Dict[str, Any]]
workflow_run_id: int, workflow_run_attempt: int, test_cases: list[dict[str, Any]]
) -> None:
grouped_tests = group_test_cases(test_cases)
reruns = get_reruns(grouped_tests)

View File

@ -5,6 +5,7 @@ from tempfile import TemporaryDirectory
from tools.stats.upload_stats_lib import download_gha_artifacts, upload_file_to_s3
ARTIFACTS = [
"sccache-stats",
"test-jsons",

View File

@ -1,10 +1,12 @@
from __future__ import annotations
import argparse
import csv
import os
import re
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, List
from typing import Any
from tools.stats.upload_stats_lib import download_s3_artifacts, unzip, upload_to_rockset
@ -23,7 +25,7 @@ def upload_dynamo_perf_stats_to_rockset(
workflow_run_attempt: int,
head_branch: str,
match_filename: str,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
match_filename_regex = re.compile(match_filename)
perf_stats = []
with TemporaryDirectory() as temp_dir:

View File

@ -1,16 +1,18 @@
from __future__ import annotations
import argparse
import datetime
import json
import os
import time
import urllib.parse
from typing import Any, Callable, cast, Dict, List, Optional, Set
from typing import Any, Callable, cast, Dict, List
from urllib.error import HTTPError
from urllib.request import Request, urlopen
from tools.stats.upload_stats_lib import upload_to_s3
FILTER_OUT_USERS = {
"pytorchmergebot",
"facebook-github-bot",
@ -23,9 +25,9 @@ FILTER_OUT_USERS = {
def _fetch_url(
url: str,
headers: Dict[str, str],
data: Optional[Dict[str, Any]] = None,
method: Optional[str] = None,
headers: dict[str, str],
data: dict[str, Any] | None = None,
method: str | None = None,
reader: Callable[[Any], Any] = lambda x: x.read(),
) -> Any:
token = os.environ.get("GITHUB_TOKEN")
@ -49,9 +51,9 @@ def _fetch_url(
def fetch_json(
url: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
params: dict[str, Any] | None = None,
data: dict[str, Any] | None = None,
) -> list[dict[str, Any]]:
headers = {"Accept": "application/vnd.github.v3+json"}
if params is not None and len(params) > 0:
url += "?" + "&".join(
@ -65,16 +67,16 @@ def fetch_json(
def get_external_pr_data(
start_date: datetime.date, end_date: datetime.date, period_length: int = 1
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
pr_info = []
period_begin_date = start_date
pr_count = 0
users: Set[str] = set()
users: set[str] = set()
while period_begin_date < end_date:
period_end_date = period_begin_date + datetime.timedelta(days=period_length - 1)
page = 1
responses: List[Dict[str, Any]] = []
responses: list[dict[str, Any]] = []
while len(responses) > 0 or page == 1:
response = cast(
Dict[str, Any],

View File

@ -1,13 +1,15 @@
from __future__ import annotations
import datetime
import inspect
import os
import time
import uuid
from decimal import Decimal
from typing import Any, Dict
from typing import Any
from warnings import warn
# boto3 is an optional dependency. If it's not installed,
# we'll just not emit the metrics.
# Keeping this logic here so that callers don't have to
@ -65,7 +67,7 @@ class EnvVarMetric:
return value
global_metrics: Dict[str, Any] = {}
global_metrics: dict[str, Any] = {}
def add_global_metric(metric_name: str, metric_value: Any) -> None:
@ -79,7 +81,7 @@ def add_global_metric(metric_name: str, metric_value: Any) -> None:
def emit_metric(
metric_name: str,
metrics: Dict[str, Any],
metrics: dict[str, Any],
) -> None:
"""
Upload a metric to DynamoDB (and from there, Rockset).
@ -174,7 +176,7 @@ def emit_metric(
print(f"Not emitting metrics for {metric_name}. Boto wasn't imported.")
def _convert_float_values_to_decimals(data: Dict[str, Any]) -> Dict[str, Any]:
def _convert_float_values_to_decimals(data: dict[str, Any]) -> dict[str, Any]:
# Attempt to recurse
def _helper(o: Any) -> Any:
if isinstance(o, float):

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, List
from typing import Any
from tools.stats.upload_stats_lib import (
download_s3_artifacts,
@ -13,7 +15,7 @@ from tools.stats.upload_stats_lib import (
def get_sccache_stats(
workflow_run_id: int, workflow_run_attempt: int
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
with TemporaryDirectory() as temp_dir:
print("Using temporary directory:", temp_dir)
os.chdir(temp_dir)

View File

@ -1,16 +1,18 @@
from __future__ import annotations
import gzip
import io
import json
import os
import zipfile
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any
import boto3 # type: ignore[import]
import requests
import rockset # type: ignore[import]
PYTORCH_REPO = "https://api.github.com/repos/pytorch/pytorch"
S3_RESOURCE = boto3.resource("s3")
@ -21,14 +23,14 @@ MAX_RETRY_IN_NON_DISABLED_MODE = 3 * 3
BATCH_SIZE = 5000
def _get_request_headers() -> Dict[str, str]:
def _get_request_headers() -> dict[str, str]:
return {
"Accept": "application/vnd.github.v3+json",
"Authorization": "token " + os.environ["GITHUB_TOKEN"],
}
def _get_artifact_urls(prefix: str, workflow_run_id: int) -> Dict[Path, str]:
def _get_artifact_urls(prefix: str, workflow_run_id: int) -> dict[Path, str]:
"""Get all workflow artifacts with 'test-report' in the name."""
response = requests.get(
f"{PYTORCH_REPO}/actions/runs/{workflow_run_id}/artifacts?per_page=100",
@ -78,7 +80,7 @@ def _download_artifact(
def download_s3_artifacts(
prefix: str, workflow_run_id: int, workflow_run_attempt: int
) -> List[Path]:
) -> list[Path]:
bucket = S3_RESOURCE.Bucket("gha-artifacts")
objs = bucket.objects.filter(
Prefix=f"pytorch/pytorch/{workflow_run_id}/{workflow_run_attempt}/artifact/{prefix}"
@ -104,7 +106,7 @@ def download_s3_artifacts(
def download_gha_artifacts(
prefix: str, workflow_run_id: int, workflow_run_attempt: int
) -> List[Path]:
) -> list[Path]:
artifact_urls = _get_artifact_urls(prefix, workflow_run_id)
paths = []
for name, url in artifact_urls.items():
@ -114,7 +116,7 @@ def download_gha_artifacts(
def upload_to_rockset(
collection: str,
docs: List[Any],
docs: list[Any],
workspace: str = "commons",
client: Any = None,
) -> None:
@ -142,7 +144,7 @@ def upload_to_rockset(
def upload_to_s3(
bucket_name: str,
key: str,
docs: List[Dict[str, Any]],
docs: list[dict[str, Any]],
) -> None:
print(f"Writing {len(docs)} documents to S3")
body = io.StringIO()
@ -164,7 +166,7 @@ def upload_to_s3(
def read_from_s3(
bucket_name: str,
key: str,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
print(f"Reading from s3://{bucket_name}/{key}")
body = (
S3_RESOURCE.Object(
@ -182,7 +184,7 @@ def upload_workflow_stats_to_s3(
workflow_run_id: int,
workflow_run_attempt: int,
collection: str,
docs: List[Dict[str, Any]],
docs: list[dict[str, Any]],
) -> None:
bucket_name = "ossci-raw-job-status"
key = f"{collection}/{workflow_run_id}/{workflow_run_attempt}"
@ -220,7 +222,7 @@ def unzip(p: Path) -> None:
zip.extractall(unzipped_dir)
def is_rerun_disabled_tests(tests: Dict[str, Dict[str, int]]) -> bool:
def is_rerun_disabled_tests(tests: dict[str, dict[str, int]]) -> bool:
"""
Check if the test report is coming from rerun_disabled_tests workflow where
each test is run multiple times
@ -231,7 +233,7 @@ def is_rerun_disabled_tests(tests: Dict[str, Dict[str, int]]) -> bool:
)
def get_job_id(report: Path) -> Optional[int]:
def get_job_id(report: Path) -> int | None:
# [Job id in artifacts]
# Retrieve the job id from the report path. In our GHA workflows, we append
# the job id to the end of the report name, so `report` looks like:

View File

@ -1,17 +1,19 @@
from __future__ import annotations
import argparse
import ast
import datetime
import json
import os
import re
from typing import Any, List, Union
from typing import Any
import rockset # type: ignore[import]
from tools.stats.upload_stats_lib import upload_to_s3
def get_oncall_from_testfile(testfile: str) -> Union[List[str], None]:
def get_oncall_from_testfile(testfile: str) -> list[str] | None:
path = f"test/{testfile}"
if not path.endswith(".py"):
path += ".py"

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import argparse
import os
import sys
@ -5,7 +7,7 @@ import xml.etree.ElementTree as ET
from multiprocessing import cpu_count, Pool
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, List
from typing import Any
from tools.stats.test_dashboard import upload_additional_info
from tools.stats.upload_stats_lib import (
@ -21,14 +23,14 @@ def parse_xml_report(
report: Path,
workflow_id: int,
workflow_run_attempt: int,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""Convert a test report xml file into a JSON-serializable list of test cases."""
print(f"Parsing {tag}s for test report: {report}")
job_id = get_job_id(report)
print(f"Found job id: {job_id}")
test_cases: List[Dict[str, Any]] = []
test_cases: list[dict[str, Any]] = []
root = ET.parse(report)
for test_case in root.iter(tag):
@ -53,9 +55,9 @@ def parse_xml_report(
return test_cases
def process_xml_element(element: ET.Element) -> Dict[str, Any]:
def process_xml_element(element: ET.Element) -> dict[str, Any]:
"""Convert a test suite element into a JSON-serializable dict."""
ret: Dict[str, Any] = {}
ret: dict[str, Any] = {}
# Convert attributes directly into dict elements.
# e.g.
@ -110,7 +112,7 @@ def process_xml_element(element: ET.Element) -> Dict[str, Any]:
return ret
def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> List[Dict[str, Any]]:
def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> list[dict[str, Any]]:
with TemporaryDirectory() as temp_dir:
print("Using temporary directory:", temp_dir)
os.chdir(temp_dir)
@ -146,7 +148,7 @@ def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> List[Dict[str,
def get_tests_for_circleci(
workflow_run_id: int, workflow_run_attempt: int
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
# Parse the reports and transform them to JSON
test_cases = []
for xml_report in Path(".").glob("**/test/test-reports/**/*.xml"):
@ -159,13 +161,13 @@ def get_tests_for_circleci(
return test_cases
def summarize_test_cases(test_cases: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
def summarize_test_cases(test_cases: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Group test cases by classname, file, and job_id. We perform the aggregation
manually instead of using the `test-suite` XML tag because xmlrunner does
not produce reliable output for it.
"""
def get_key(test_case: Dict[str, Any]) -> Any:
def get_key(test_case: dict[str, Any]) -> Any:
return (
test_case.get("file"),
test_case.get("classname"),
@ -176,7 +178,7 @@ def summarize_test_cases(test_cases: List[Dict[str, Any]]) -> List[Dict[str, Any
test_case["invoking_file"],
)
def init_value(test_case: Dict[str, Any]) -> Dict[str, Any]:
def init_value(test_case: dict[str, Any]) -> dict[str, Any]:
return {
"file": test_case.get("file"),
"classname": test_case.get("classname"),

View File

@ -4,6 +4,7 @@ import sys
from tools.stats.test_dashboard import upload_additional_info
from tools.stats.upload_test_stats import get_tests
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Upload test stats to Rockset")
parser.add_argument(

View File

@ -5,7 +5,6 @@ import argparse
import json
import unittest
from collections import defaultdict
from unittest.mock import Mock, patch
from gen_operators_yaml import (
@ -43,10 +42,10 @@ def _mock_load_op_dep_graph():
class GenOperatorsYAMLTest(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
pass
def test_filter_creation(self):
def test_filter_creation(self) -> None:
filter_func = make_filter_from_options(
model_name="abc",
model_versions=["100", "101"],
@ -99,7 +98,7 @@ class GenOperatorsYAMLTest(unittest.TestCase):
len(filtered_configs) == 2
), f"Expected 2 elements in filtered_configs, but got {len(filtered_configs)}"
def test_verification_success(self):
def test_verification_success(self) -> None:
filter_func = make_filter_from_options(
model_name="abc",
model_versions=["100", "101"],
@ -142,7 +141,7 @@ class GenOperatorsYAMLTest(unittest.TestCase):
"expected verify_all_specified_present to succeed instead it raised an exception"
)
def test_verification_fail(self):
def test_verification_fail(self) -> None:
config = [
{
"model": {
@ -229,7 +228,7 @@ class GenOperatorsYAMLTest(unittest.TestCase):
)
def test_fill_output_with_arguments_not_include_all_overloads(
self, mock_parse_options: Mock, mock_load_op_dep_graph: Mock
):
) -> None:
parser = argparse.ArgumentParser(description="Generate used operators YAML")
options = get_parser_options(parser)

View File

@ -8,10 +8,10 @@ from tools.code_analyzer.gen_oplist import throw_if_any_op_includes_overloads
class GenOplistTest(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
pass
def test_throw_if_any_op_includes_overloads(self):
def test_throw_if_any_op_includes_overloads(self) -> None:
selective_builder = MagicMock()
selective_builder.operators = MagicMock()
selective_builder.operators.items.return_value = [

View File

@ -1,10 +1,12 @@
# For testing specific heuristics
from __future__ import annotations
import io
import json
import pathlib
import sys
import unittest
from typing import Any, Dict, List, Set
from typing import Any
from unittest import mock
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
@ -28,14 +30,14 @@ sys.path.remove(str(REPO_ROOT))
HEURISTIC_CLASS = "tools.testing.target_determination.heuristics.historical_class_failure_correlation."
def mocked_file(contents: Dict[Any, Any]) -> io.IOBase:
def mocked_file(contents: dict[Any, Any]) -> io.IOBase:
file_object = io.StringIO()
json.dump(contents, file_object)
file_object.seek(0)
return file_object
def gen_historical_class_failures() -> Dict[str, Dict[str, float]]:
def gen_historical_class_failures() -> dict[str, dict[str, float]]:
return {
"file1": {
"test1::classA": 0.5,
@ -80,8 +82,8 @@ class TestHistoricalClassFailureCorrelation(TestTD):
)
def test_get_prediction_confidence(
self,
historical_class_failures: Dict[str, Dict[str, float]],
changed_files: List[str],
historical_class_failures: dict[str, dict[str, float]],
changed_files: list[str],
) -> None:
tests_to_prioritize = ALL_TESTS
@ -113,7 +115,7 @@ class TestHistoricalClassFailureCorrelation(TestTD):
class TestParsePrevTests(TestTD):
@mock.patch("os.path.exists", return_value=False)
def test_cache_does_not_exist(self, mock_exists: Any) -> None:
expected_failing_test_files: Set[str] = set()
expected_failing_test_files: set[str] = set()
found_tests = get_previous_failures()
@ -122,7 +124,7 @@ class TestParsePrevTests(TestTD):
@mock.patch("os.path.exists", return_value=True)
@mock.patch("builtins.open", return_value=mocked_file({"": True}))
def test_empty_cache(self, mock_exists: Any, mock_open: Any) -> None:
expected_failing_test_files: Set[str] = set()
expected_failing_test_files: set[str] = set()
found_tests = get_previous_failures()

View File

@ -1,7 +1,9 @@
from __future__ import annotations
import pathlib
import sys
import unittest
from typing import Any, Dict, List
from typing import Any
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
sys.path.append(str(REPO_ROOT))
@ -13,7 +15,7 @@ sys.path.remove(str(REPO_ROOT))
class TestTD(unittest.TestCase):
def assert_test_scores_almost_equal(
self, d1: Dict[TestRun, float], d2: Dict[TestRun, float]
self, d1: dict[TestRun, float], d2: dict[TestRun, float]
) -> None:
# Check that dictionaries are the same, except for floating point errors
self.assertEqual(set(d1.keys()), set(d2.keys()))
@ -24,7 +26,7 @@ class TestTD(unittest.TestCase):
# Create a dummy heuristic class
class Heuristic(interface.HeuristicInterface):
def get_prediction_confidence(
self, tests: List[str]
self, tests: list[str]
) -> interface.TestPrioritizations:
# Return junk
return interface.TestPrioritizations([], {})
@ -259,9 +261,9 @@ class TestTestPrioritizations(TestTD):
class TestAggregatedHeuristics(TestTD):
def check(
self,
tests: List[str],
test_prioritizations: List[Dict[TestRun, float]],
expected: Dict[TestRun, float],
tests: list[str],
test_prioritizations: list[dict[TestRun, float]],
expected: dict[TestRun, float],
) -> None:
aggregated_heuristics = interface.AggregatedHeuristics(tests)
for i, test_prioritization in enumerate(test_prioritizations):
@ -429,7 +431,7 @@ class TestAggregatedHeuristicsTestStats(TestTD):
stats3 = aggregator.get_test_stats(TestRun("test3"))
stats5 = aggregator.get_test_stats(TestRun("test5::classA"))
def assert_valid_dict(dict_contents: Dict[str, Any]) -> None:
def assert_valid_dict(dict_contents: dict[str, Any]) -> None:
for key, value in dict_contents.items():
self.assertTrue(isinstance(key, str))
self.assertTrue(

View File

@ -1,7 +1,9 @@
from __future__ import annotations
import pathlib
import sys
import unittest
from typing import Any, Dict
from typing import Any
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
@ -14,14 +16,14 @@ sys.path.remove(str(REPO_ROOT))
class TestHeuristicsUtils(unittest.TestCase):
def assertDictAlmostEqual(
self, first: Dict[TestRun, Any], second: Dict[TestRun, Any]
self, first: dict[TestRun, Any], second: dict[TestRun, Any]
) -> None:
self.assertEqual(first.keys(), second.keys())
for key in first.keys():
self.assertAlmostEqual(first[key], second[key])
def test_normalize_ratings(self) -> None:
ratings: Dict[TestRun, float] = {
ratings: dict[TestRun, float] = {
TestRun("test1"): 1,
TestRun("test2"): 2,
TestRun("test3"): 4,

View File

@ -1,12 +1,13 @@
from __future__ import annotations
import contextlib
import os
import typing
import unittest
import unittest.mock
from typing import Iterator, Optional, Sequence
from typing import Iterator, Sequence
import tools.setup_helpers.cmake
import tools.setup_helpers.env # noqa: F401 unused but resolves circular import
@ -79,7 +80,7 @@ class TestCMake(unittest.TestCase):
@contextlib.contextmanager
def env_var(key: str, value: Optional[str]) -> Iterator[None]:
def env_var(key: str, value: str | None) -> Iterator[None]:
"""Sets/clears an environment variable within a Python context."""
# Get the previous value and then override it.
previous_value = os.environ.get(key)
@ -91,7 +92,7 @@ def env_var(key: str, value: Optional[str]) -> Iterator[None]:
set_env_var(key, previous_value)
def set_env_var(key: str, value: Optional[str]) -> None:
def set_env_var(key: str, value: str | None) -> None:
"""Sets/clears an environment variable."""
if value is None:
os.environ.pop(key, None)

View File

@ -1,14 +1,13 @@
from __future__ import annotations
import dataclasses
import typing
import unittest
from collections import defaultdict
from typing import Dict, List
import yaml
from tools.autograd import gen_autograd_functions, load_derivatives
import torchgen.model
from torchgen import dest
from torchgen.api.types import CppSignatureGroup, DispatcherSignature
from torchgen.context import native_function_manager
@ -22,6 +21,7 @@ from torchgen.model import (
BackendIndex,
BackendMetadata,
DispatchKey,
FunctionSchema,
Location,
NativeFunction,
OperatorName,
@ -32,7 +32,7 @@ from torchgen.selective_build.selector import SelectiveBuilder
class TestCreateDerivative(unittest.TestCase):
def test_named_grads(self) -> None:
schema = torchgen.model.FunctionSchema.parse(
schema = FunctionSchema.parse(
"func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
@ -47,7 +47,7 @@ class TestCreateDerivative(unittest.TestCase):
def test_non_differentiable_output(self) -> None:
specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
schema = torchgen.model.FunctionSchema.parse(specification)
schema = FunctionSchema.parse(specification)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
_, differentiability_info = load_derivatives.create_differentiability_info(
@ -69,7 +69,7 @@ class TestCreateDerivative(unittest.TestCase):
)
def test_indexed_grads(self) -> None:
schema = torchgen.model.FunctionSchema.parse(
schema = FunctionSchema.parse(
"func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
@ -84,7 +84,7 @@ class TestCreateDerivative(unittest.TestCase):
def test_named_grads_and_indexed_grads(self) -> None:
specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
schema = torchgen.model.FunctionSchema.parse(specification)
schema = FunctionSchema.parse(specification)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
with self.assertRaisesRegex(
@ -112,7 +112,7 @@ class TestCreateDerivative(unittest.TestCase):
class TestGenAutogradFunctions(unittest.TestCase):
def test_non_differentiable_output_invalid_type(self) -> None:
specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
schema = torchgen.model.FunctionSchema.parse(specification)
schema = FunctionSchema.parse(specification)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
_, differentiability_info = load_derivatives.create_differentiability_info(
@ -141,7 +141,7 @@ class TestGenAutogradFunctions(unittest.TestCase):
def test_non_differentiable_output_output_differentiability(self) -> None:
specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y, Tensor z)"
schema = torchgen.model.FunctionSchema.parse(specification)
schema = FunctionSchema.parse(specification)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
_, differentiability_info = load_derivatives.create_differentiability_info(
@ -182,7 +182,7 @@ class TestGenAutogradFunctions(unittest.TestCase):
def test_register_bogus_dispatch_key(self) -> None:
specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
schema = torchgen.model.FunctionSchema.parse(specification)
schema = FunctionSchema.parse(specification)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
with self.assertRaisesRegex(
@ -213,17 +213,17 @@ class TestGenAutogradFunctions(unittest.TestCase):
class TestGenSchemaRegistration(unittest.TestCase):
def setUp(self) -> None:
self.selector = SelectiveBuilder.get_nop_selector()
self.custom_native_function, _ = torchgen.model.NativeFunction.from_yaml(
self.custom_native_function, _ = NativeFunction.from_yaml(
{"func": "custom::func() -> bool"},
loc=torchgen.model.Location(__file__, 1),
loc=Location(__file__, 1),
valid_tags=set(),
)
(
self.fragment_custom_native_function,
_,
) = torchgen.model.NativeFunction.from_yaml(
) = NativeFunction.from_yaml(
{"func": "quantized_decomposed::func() -> bool"},
loc=torchgen.model.Location(__file__, 1),
loc=Location(__file__, 1),
valid_tags=set(),
)
@ -285,9 +285,9 @@ TORCH_LIBRARY(custom, m) {
)
def test_3_namespaces_schema_registration_code_valid(self) -> None:
custom2_native_function, _ = torchgen.model.NativeFunction.from_yaml(
custom2_native_function, _ = NativeFunction.from_yaml(
{"func": "custom2::func() -> bool"},
loc=torchgen.model.Location(__file__, 1),
loc=Location(__file__, 1),
valid_tags=set(),
)
(
@ -320,7 +320,7 @@ class TestGenNativeFunctionDeclaration(unittest.TestCase):
def setUp(self) -> None:
self.op_1_native_function, op_1_backend_index = NativeFunction.from_yaml(
{"func": "op_1() -> bool", "dispatch": {"CPU": "kernel_1"}},
loc=torchgen.model.Location(__file__, 1),
loc=Location(__file__, 1),
valid_tags=set(),
)
self.op_2_native_function, op_2_backend_index = NativeFunction.from_yaml(
@ -328,11 +328,11 @@ class TestGenNativeFunctionDeclaration(unittest.TestCase):
"func": "op_2() -> bool",
"dispatch": {"CPU": "kernel_2", "QuantizedCPU": "custom::kernel_3"},
},
loc=torchgen.model.Location(__file__, 1),
loc=Location(__file__, 1),
valid_tags=set(),
)
backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = {
backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = {
DispatchKey.CPU: {},
DispatchKey.QuantizedCPU: {},
}
@ -382,9 +382,9 @@ TORCH_API bool kernel_1();
# Test for native_function_generation
class TestNativeFunctionGeneratrion(unittest.TestCase):
def setUp(self) -> None:
self.native_functions: List[NativeFunction] = []
self.backend_indices: Dict[
DispatchKey, Dict[OperatorName, BackendMetadata]
self.native_functions: list[NativeFunction] = []
self.backend_indices: dict[
DispatchKey, dict[OperatorName, BackendMetadata]
] = defaultdict(dict)
yaml_entry = """
- func: op(Tensor self) -> Tensor
@ -405,7 +405,7 @@ class TestNativeFunctionGeneratrion(unittest.TestCase):
"dispatch": {"CPU": "kernel_1"},
"autogen": "op_2.out",
},
loc=torchgen.model.Location(__file__, 1),
loc=Location(__file__, 1),
valid_tags=set(),
)
BackendIndex.grow_index(self.backend_indices, two_returns_backend_index)
@ -442,8 +442,8 @@ class TestNativeFunctionGeneratrion(unittest.TestCase):
# Test for static_dispatch
class TestStaticDispatchGeneratrion(unittest.TestCase):
def setUp(self) -> None:
self.backend_indices: Dict[
DispatchKey, Dict[OperatorName, BackendMetadata]
self.backend_indices: dict[
DispatchKey, dict[OperatorName, BackendMetadata]
] = defaultdict(dict)
yaml_entry = """
- func: op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
@ -500,9 +500,9 @@ class TestStaticDispatchGeneratrion(unittest.TestCase):
# Represents the most basic NativeFunction. Use dataclasses.replace()
# to edit for use.
DEFAULT_NATIVE_FUNCTION, _ = torchgen.model.NativeFunction.from_yaml(
DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml(
{"func": "func() -> bool"},
loc=torchgen.model.Location(__file__, 1),
loc=Location(__file__, 1),
valid_tags=set(),
)

View File

@ -1,4 +1,6 @@
from typing import Any, List
from __future__ import annotations
from typing import Any
from unittest import main, TestCase
from tools.alerts.create_alerts import filter_job_names, JobStatus
@ -38,7 +40,7 @@ MOCK_TEST_DATA = [
class TestGitHubPR(TestCase):
# Should fail when jobs are ? ? Fail Fail
def test_alert(self) -> None:
modified_data: List[Any] = [{}]
modified_data: list[Any] = [{}]
modified_data.append({})
modified_data.extend(MOCK_TEST_DATA)
status = JobStatus(JOB_NAME, modified_data)

View File

@ -1,6 +1,8 @@
from __future__ import annotations
import tempfile
import unittest
from typing import Any, Dict
from typing import Any
from unittest.mock import ANY, Mock, patch
import expecttest
@ -13,10 +15,11 @@ from torchgen.model import Location, NativeFunction
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import FileManager
SPACES = " "
def _get_native_function_from_yaml(yaml_obj: Dict[str, object]) -> NativeFunction:
def _get_native_function_from_yaml(yaml_obj: dict[str, object]) -> NativeFunction:
native_function, _ = NativeFunction.from_yaml(
yaml_obj,
loc=Location(__file__, 1),
@ -33,7 +36,7 @@ class TestComputeNativeFunctionStub(expecttest.TestCase):
"""
def _test_function_schema_generates_correct_kernel(
self, obj: Dict[str, Any], expected: str
self, obj: dict[str, Any], expected: str
) -> None:
func = _get_native_function_from_yaml(obj)

View File

@ -1,13 +1,13 @@
from __future__ import annotations
import os
import tempfile
import unittest
from typing import Dict
import yaml
from torchgen.executorch.model import ETKernelIndex, ETKernelKey
from torchgen.gen import LineLoader
from torchgen.gen_executorch import (
ComputeCodegenUnboxedKernels,
gen_functions_declarations,
@ -24,6 +24,7 @@ from torchgen.model import (
)
from torchgen.selective_build.selector import SelectiveBuilder
TEST_YAML = """
- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -345,7 +346,7 @@ class TestGenFunctionsDeclarations(unittest.TestCase):
valid_tags=set(),
)
backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = {
backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = {
DispatchKey.CPU: {},
DispatchKey.QuantizedCPU: {},
}

View File

@ -4,6 +4,7 @@ from torchgen.executorch.api.types import ExecutorchCppSignature
from torchgen.local import parametrize
from torchgen.model import Location, NativeFunction
DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml(
{"func": "foo.out(Tensor input, *, Tensor(a!) out) -> Tensor(a!)"},
loc=Location(__file__, 1),

View File

@ -1,9 +1,10 @@
# Owner(s): ["module: codegen"]
from __future__ import annotations
import os
import tempfile
import unittest
from typing import Optional
import expecttest
@ -29,7 +30,7 @@ class TestGenBackendStubs(expecttest.TestCase):
run(fp.name, "", True)
def get_errors_from_gen_backend_stubs(
self, yaml_str: str, *, kernels_str: Optional[str] = None
self, yaml_str: str, *, kernels_str: str | None = None
) -> str:
with tempfile.NamedTemporaryFile(mode="w") as fp:
fp.write(yaml_str)

View File

@ -1,7 +1,7 @@
import unittest
from torchgen.selective_build.operator import * # noqa: F403
from torchgen.model import Location, NativeFunction
from torchgen.selective_build.operator import * # noqa: F403
from torchgen.selective_build.selector import (
combine_selective_builders,
SelectiveBuilder,
@ -9,7 +9,7 @@ from torchgen.selective_build.selector import (
class TestSelectiveBuild(unittest.TestCase):
def test_selective_build_operator(self):
def test_selective_build_operator(self) -> None:
op = SelectiveBuildOperator(
"aten::add.int",
is_root_operator=True,
@ -21,7 +21,7 @@ class TestSelectiveBuild(unittest.TestCase):
self.assertFalse(op.is_used_for_training)
self.assertFalse(op.include_all_overloads)
def test_selector_factory(self):
def test_selector_factory(self) -> None:
yaml_config_v1 = """
debug_info:
- model1@v100
@ -132,7 +132,7 @@ operators:
selector_legacy_v1.is_operator_selected_for_training("aten::add.float")
)
def test_operator_combine(self):
def test_operator_combine(self) -> None:
op1 = SelectiveBuildOperator(
"aten::add.int",
is_root_operator=True,
@ -177,7 +177,7 @@ operators:
self.assertRaises(Exception, gen_new_op)
def test_training_op_fetch(self):
def test_training_op_fetch(self) -> None:
yaml_config = """
operators:
aten::add.int:
@ -194,7 +194,7 @@ operators:
self.assertTrue(selector.is_operator_selected_for_training("aten::add.int"))
self.assertTrue(selector.is_operator_selected_for_training("aten::add"))
def test_kernel_dtypes(self):
def test_kernel_dtypes(self) -> None:
yaml_config = """
kernel_metadata:
add_kernel:
@ -221,7 +221,7 @@ kernel_metadata:
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16"))
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32"))
def test_merge_kernel_dtypes(self):
def test_merge_kernel_dtypes(self) -> None:
yaml_config1 = """
kernel_metadata:
add_kernel:
@ -266,7 +266,7 @@ kernel_metadata:
self.assertTrue(selector.is_kernel_dtype_selected("mul_kernel", "int8"))
self.assertFalse(selector.is_kernel_dtype_selected("mul_kernel", "int32"))
def test_all_kernel_dtypes_selected(self):
def test_all_kernel_dtypes_selected(self) -> None:
yaml_config = """
include_all_non_op_selectives: True
"""
@ -279,7 +279,7 @@ include_all_non_op_selectives: True
self.assertTrue(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "float"))
def test_custom_namespace_selected_correctly(self):
def test_custom_namespace_selected_correctly(self) -> None:
yaml_config = """
operators:
aten::add.int:
@ -301,7 +301,7 @@ operators:
class TestExecuTorchSelectiveBuild(unittest.TestCase):
def test_et_kernel_selected(self):
def test_et_kernel_selected(self) -> None:
yaml_config = """
et_kernel_metadata:
aten::add.out:

Some files were not shown because too many files have changed in this diff Show More