Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)

[BE][Easy] enable postponed annotations in tools (#129375)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129375
Approved by: https://github.com/malfet
Committed by: PyTorch MergeBot

parent 58f346c874
commit 8a67daf283
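Every file in the diff below gets the same mechanical treatment: add `from __future__ import annotations` (PEP 563) so annotations are stored as strings instead of being evaluated at import time, rewrite annotations to the builtin generics and `X | None` union syntax (PEP 585/604), and move imports that are only needed for annotations under `if TYPE_CHECKING:`. A minimal sketch of that pattern, with illustrative names that are not taken from this PR:

from __future__ import annotations  # annotations below become lazily evaluated strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # imported only for type checkers, never at runtime
    from collections.abc import Sequence


def group_names(names: Sequence[str]) -> dict[str, list[str]]:
    # builtin generics and "X | None" are legal in these annotations even on
    # interpreters that predate PEP 585/604, because they are never evaluated
    grouped: dict[str, list[str]] = {}
    for name in names:
        grouped.setdefault(name[:1], []).append(name)
    return grouped

The caveat, visible in the last file of the diff, is that anything evaluated at runtime must keep the typing spellings: the module-level DerivativeRet alias still uses Tuple/Dict/Set because an alias assignment is an ordinary expression, not an annotation.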
@@ -1,16 +1,19 @@
#!/usr/bin/env python3

+from __future__ import annotations
+
import argparse
import json
import os
import re
from collections import defaultdict
from difflib import SequenceMatcher
-from typing import Any, Dict, List, Set, Tuple
+from typing import Any

import requests
from setuptools import distutils # type: ignore[import]


ALL_SKIPPED_THRESHOLD = 100
SIMILARITY_THRESHOLD = 0.75
FAILURE_CHAIN_THRESHOLD = 2
@@ -65,14 +68,14 @@ DISABLED_ALERTS = [

class JobStatus:
job_name: str = ""
-jobs: List[Any] = []
+jobs: list[Any] = []
current_status: Any = None
-job_statuses: List[Any] = []
-filtered_statuses: List[Any] = []
-failure_chain: List[Any] = []
-flaky_jobs: List[Any] = []
+job_statuses: list[Any] = []
+filtered_statuses: list[Any] = []
+failure_chain: list[Any] = []
+flaky_jobs: list[Any] = []

-def __init__(self, job_name: str, job_statuses: List[Any]):
+def __init__(self, job_name: str, job_statuses: list[Any]) -> None:
self.job_name = job_name
self.job_statuses = job_statuses

@@ -93,7 +96,7 @@ class JobStatus:
return status
return None

-def get_unique_failures(self, jobs: List[Any]) -> Dict[str, List[Any]]:
+def get_unique_failures(self, jobs: list[Any]) -> dict[str, list[Any]]:
"""
Returns list of jobs grouped by failureCaptures from the input list
"""
@@ -120,7 +123,7 @@ class JobStatus:
return failures

# A flaky job is if it's the only job that has that failureCapture and is not the most recent job
-def get_flaky_jobs(self) -> List[Any]:
+def get_flaky_jobs(self) -> list[Any]:
unique_failures = self.get_unique_failures(self.filtered_statuses)
flaky_jobs = []
for failure in unique_failures:
@@ -134,7 +137,7 @@ class JobStatus:

# The most recent failure chain is an array of jobs that have the same-ish failures.
# A success in the middle of the chain will terminate the chain.
-def get_most_recent_failure_chain(self) -> List[Any]:
+def get_most_recent_failure_chain(self) -> list[Any]:
failures = []
found_most_recent_failure = False

@@ -178,7 +181,7 @@ def fetch_hud_data(repo: str, branch: str) -> Any:


# Creates a Dict of Job Name -> [JobData]. Essentially a Column in HUD
-def map_job_data(jobNames: Any, shaGrid: Any) -> Dict[str, Any]:
+def map_job_data(jobNames: Any, shaGrid: Any) -> dict[str, Any]:
jobData = defaultdict(list)
for sha in shaGrid:
for ind, job in enumerate(sha["jobs"]):
@@ -196,13 +199,13 @@ def is_job_skipped(job: Any) -> bool:
return conclusion in (NEUTRAL, SKIPPED) or conclusion is None


-def get_failed_jobs(job_data: List[Any]) -> List[Any]:
+def get_failed_jobs(job_data: list[Any]) -> list[Any]:
return [job for job in job_data if job["conclusion"] == "failure"]


def classify_jobs(
-all_job_names: List[str], sha_grid: Any, filtered_jobs_names: Set[str]
-) -> Tuple[List[JobStatus], List[Any]]:
+all_job_names: list[str], sha_grid: Any, filtered_jobs_names: set[str]
+) -> tuple[list[JobStatus], list[Any]]:
"""
Creates Job Statuses which has the logic for if need to alert or if there's flaky jobs.
Classifies jobs into jobs to alert on and flaky jobs.
@@ -212,7 +215,7 @@ def classify_jobs(
:return:
"""
job_data = map_job_data(all_job_names, sha_grid)
-job_statuses: List[JobStatus] = []
+job_statuses: list[JobStatus] = []
for job in job_data:
job_statuses.append(JobStatus(job, job_data[job]))

@@ -230,7 +233,7 @@ def classify_jobs(


# filter job names that don't match the regex
-def filter_job_names(job_names: List[str], job_name_regex: str) -> List[str]:
+def filter_job_names(job_names: list[str], job_name_regex: str) -> list[str]:
if job_name_regex:
return [
job_name for job_name in job_names if re.match(job_name_regex, job_name)
@@ -240,7 +243,7 @@ def filter_job_names(job_names: List[str], job_name_regex: str) -> List[str]:


def get_recurrently_failing_jobs_alerts(
repo: str, branch: str, job_name_regex: str
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
job_names, sha_grid = fetch_hud_data(repo=repo, branch=branch)

filtered_job_names = set(filter_job_names(job_names, job_name_regex))
@@ -14,18 +14,17 @@ generated. In the full build system, OUTPUT_DIR is
torch/testing/_internal/generated
"""

+from __future__ import annotations
+
import argparse
import os
import textwrap
from collections import defaultdict

-from typing import Any, Dict, List, Sequence
+from typing import Any, Sequence, TYPE_CHECKING

import torchgen.api.python as python
from torchgen.context import with_native_function

from torchgen.gen import parse_native_yaml
-from torchgen.model import Argument, BaseOperatorName, NativeFunction
from torchgen.utils import FileManager

from .gen_python_functions import (
@@ -39,6 +38,10 @@ from .gen_python_functions import (
)


+if TYPE_CHECKING:
+from torchgen.model import Argument, BaseOperatorName, NativeFunction


def gen_annotated(
native_yaml_path: str, tags_yaml_path: str, out: str, autograd_dir: str
) -> None:
@@ -53,9 +56,9 @@ def gen_annotated(
(is_py_fft_function, "torch._C._fft"),
(is_py_variable_method, "torch.Tensor"),
)
-annotated_args: List[str] = []
+annotated_args: list[str] = []
for pred, namespace in mappings:
-groups: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list)
+groups: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
for f in native_functions:
if not should_generate_py_binding(f) or not pred(f):
continue
@@ -77,7 +80,7 @@ def gen_annotated(

@with_native_function
def gen_annotated_args(f: NativeFunction) -> str:
-def _get_kwargs_func_exclusion_list() -> List[str]:
+def _get_kwargs_func_exclusion_list() -> list[str]:
# functions that currently don't work with kwargs in test_overrides.py
return [
"diagonal",
@@ -87,12 +90,12 @@ def gen_annotated_args(f: NativeFunction) -> str:
]

def _add_out_arg(
-out_args: List[Dict[str, Any]], args: Sequence[Argument], *, is_kwarg_only: bool
+out_args: list[dict[str, Any]], args: Sequence[Argument], *, is_kwarg_only: bool
) -> None:
for arg in args:
if arg.default is not None:
continue
-out_arg: Dict[str, Any] = {}
+out_arg: dict[str, Any] = {}
out_arg["is_kwarg_only"] = str(is_kwarg_only)
out_arg["name"] = arg.name
out_arg["simple_type"] = python.argument_type_str(
@@ -103,7 +106,7 @@ def gen_annotated_args(f: NativeFunction) -> str:
out_arg["size"] = size_t
out_args.append(out_arg)

-out_args: List[Dict[str, Any]] = []
+out_args: list[dict[str, Any]] = []
_add_out_arg(out_args, f.func.arguments.flat_positional, is_kwarg_only=False)
if f"{f.func.name.name}" not in _get_kwargs_func_exclusion_list():
_add_out_arg(out_args, f.func.arguments.flat_kwarg_only, is_kwarg_only=True)
@@ -22,9 +22,10 @@ torch/csrc/autograd/generated/
# gen_python_functions.py: generates Python bindings to THPVariable
#

+from __future__ import annotations
+
import argparse
import os
-from typing import List

from torchgen.api import cpp
from torchgen.api.autograd import (
@@ -69,7 +70,7 @@ def gen_autograd(
),
key=lambda f: cpp.name(f.func),
)
-fns_with_diff_infos: List[
+fns_with_diff_infos: list[
NativeFunctionWithDifferentiabilityInfo
] = match_differentiability_info(fns, differentiability_infos)
@@ -4,7 +4,10 @@
# Functions.h/cpp: subclasses of autograd::Node
# python_functions.h/cpp: Python bindings for the above classes
#
-from typing import Dict, List, Sequence, Tuple

+from __future__ import annotations

+from typing import Sequence

from torchgen.api.autograd import (
Derivative,
@@ -43,6 +46,7 @@ from torchgen.utils import FileManager

from .gen_inplace_or_view_type import VIEW_FUNCTIONS


FUNCTION_DECLARATION = CodeTemplate(
"""\
#ifdef _WIN32
@@ -443,8 +447,8 @@ UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS


def get_infos_with_derivatives_list(
-differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]]
-) -> List[DifferentiabilityInfo]:
+differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]]
+) -> list[DifferentiabilityInfo]:
diff_info_list = [
info
for diffinfo_dict in differentiability_infos.values()
@@ -456,7 +460,7 @@ def get_infos_with_derivatives_list(

def gen_autograd_functions_lib(
out: str,
-differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
+differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
template_path: str,
) -> None:
"""Functions.h and Functions.cpp body
@@ -490,7 +494,7 @@ def gen_autograd_functions_lib(

def gen_autograd_functions_python(
out: str,
-differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
+differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
template_path: str,
) -> None:
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
@@ -536,17 +540,17 @@ def gen_autograd_functions_python(


def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str:
-saved_variables: List[str] = []
-release_variables: List[str] = []
-saved_list_sizes: List[str] = []
-unpack: List[str] = []
-asserts: List[str] = []
-compute_index_ranges: List[str] = []
-getter_definitions: List[str] = []
-py_getsetdef_structs: List[str] = []
-compiled_args: List[str] = []
-apply_with_saved_before: List[str] = []
-apply_with_saved_after: List[str] = []
+saved_variables: list[str] = []
+release_variables: list[str] = []
+saved_list_sizes: list[str] = []
+unpack: list[str] = []
+asserts: list[str] = []
+compute_index_ranges: list[str] = []
+getter_definitions: list[str] = []
+py_getsetdef_structs: list[str] = []
+compiled_args: list[str] = []
+apply_with_saved_before: list[str] = []
+apply_with_saved_after: list[str] = []

for arg in info.args_with_derivatives:
if arg.type in TENSOR_LIST_LIKE_CTYPES:
@@ -807,7 +811,7 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
else:
will_release_variables = ""

-body: List[str] = []
+body: list[str] = []

if uses_single_grad(info):
body.append("const auto& grad = grads[0];")
@@ -821,7 +825,7 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
def emit_derivative(
derivative: Derivative,
args_with_derivatives: Sequence[Binding],
-) -> Tuple[bool, str]:
+) -> tuple[bool, str]:
formula = derivative.formula
var_names = derivative.var_names
if len(var_names) == 1:
@@ -857,7 +861,7 @@ PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
else:
grad_input_mask = ""
idx_ranges = ", ".join(f"{n}_ix" for n in var_names)
-copy_ranges: List[str] = []
+copy_ranges: list[str] = []
for i, n in enumerate(var_names):
copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i))
return False, DERIVATIVE_MULTI.substitute(
@@ -4,7 +4,7 @@
# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
# The fallback is expected to mimick this codegen, so we should keep the two in sync.

-from typing import Dict, List, Optional, Tuple
+from __future__ import annotations

from torchgen.api import cpp
from torchgen.api.autograd import (
@@ -24,8 +24,7 @@ from torchgen.api.types import (
OptionalCType,
symIntArrayRefT,
SymIntT,
-# See Note [Nested Arg Types]
-tensorT,
+tensorT, # See Note [Nested Arg Types]
)
from torchgen.code_template import CodeTemplate
from torchgen.context import with_native_function
@@ -46,6 +45,7 @@ from .gen_trace_type import (
type_wrapper_name,
)


# See NOTE [ Autograd View Variables ] in variable.h for details.
# If you update list VIEW_FUNCTIONS or RETURNS_VIEWS_OF_INPUT,
# you **MUST** also update the public list of view ops accordingly in
@@ -281,7 +281,7 @@ def inverse_view_name(f: NativeFunction) -> str:
return f"{copy_variant}{overload}_inverse"


-def extract_bindings(f: NativeFunction) -> List[Binding]:
+def extract_bindings(f: NativeFunction) -> list[Binding]:
return [
r
for a in f.func.schema_order_arguments()
@@ -297,9 +297,9 @@ def extract_bindings(f: NativeFunction) -> List[Binding]:


@with_native_function
-def unpack_args(f: NativeFunction) -> Tuple[List[str], List[Binding]]:
-body: List[str] = []
-unpacked_bindings: List[Binding] = []
+def unpack_args(f: NativeFunction) -> tuple[list[str], list[Binding]]:
+body: list[str] = []
+unpacked_bindings: list[Binding] = []

for i, binding in enumerate(extract_bindings(f)):
assert not isinstance(binding.argument, SelfArgument)
@@ -338,7 +338,7 @@ def get_base_name(f: NativeFunction) -> str:
return f.func.name.name.base # TODO: should be str(f.func.name.name)?


-def get_view_info(f: NativeFunction) -> Optional[str]:
+def get_view_info(f: NativeFunction) -> str | None:
base_name = get_base_name(f)
view_info = VIEW_FUNCTIONS.get(base_name, None)
if view_info is None and base_name in RETURNS_VIEWS_OF_INPUT:
@@ -347,7 +347,7 @@ def get_view_info(f: NativeFunction) -> Optional[str]:


def emit_view_func(
-f: NativeFunction, bindings: List[Binding], view_idx: Optional[str] = None
+f: NativeFunction, bindings: list[Binding], view_idx: str | None = None
) -> str:
"""Generate an additional lambda function to recover views in backward when as_strided is not supported.
See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details.
@@ -355,8 +355,8 @@ def emit_view_func(
# TODO: Clean this logic up if we get rid of reverse view funcs or reify them.
input_base = "input_base"
replay_view_func = ""
-updated_args: List[str] = []
-known_view_arg_simple_types: List[CType] = [
+updated_args: list[str] = []
+known_view_arg_simple_types: list[CType] = [
BaseCType(longT),
OptionalCType(BaseCType(longT)),
BaseCType(SymIntT),
@@ -448,7 +448,7 @@ def emit_view_func(

def emit_view_body(
fn: NativeFunctionWithDifferentiabilityInfo, var: str
-) -> Tuple[str, str]:
+) -> tuple[str, str]:
# See NOTE [ Autograd View Variables ] in variable.h for details.
f = fn.func
base_name = get_base_name(f)
@@ -523,9 +523,9 @@ def modifies_arguments(f: NativeFunction) -> bool:


@with_native_function_with_differentiability_info
-def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
+def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> list[str]:
f = fn.func
-inplace_view_body: List[str] = []
+inplace_view_body: list[str] = []

dispatcher_sig = DispatcherSignature.from_schema(f.func)
dispatcher_exprs = dispatcher_sig.exprs()
@@ -584,7 +584,7 @@ def gen_formals(f: NativeFunction) -> str:
@with_native_function_with_differentiability_info
def inplace_or_view_method_definition(
fn: NativeFunctionWithDifferentiabilityInfo,
-) -> Optional[str]:
+) -> str | None:
f = fn.func
if get_view_info(f) is None and (
# For functions that modify their inputs but don't return them,
@@ -605,7 +605,7 @@ def inplace_or_view_method_definition(
@with_native_function_with_differentiability_info
def inplace_or_view_method_registration(
fn: NativeFunctionWithDifferentiabilityInfo,
-) -> Optional[str]:
+) -> str | None:
f = fn.func
if get_view_info(f) is None and (
not modifies_arguments(f) or len(f.func.returns) == 0
@@ -626,7 +626,7 @@ def use_derived(fn: NativeFunctionWithDifferentiabilityInfo) -> bool:

def gen_inplace_or_view_type_env(
fn: NativeFunctionWithDifferentiabilityInfo,
-) -> Dict[str, List[str]]:
+) -> dict[str, list[str]]:
definition = inplace_or_view_method_definition(fn)
registration = inplace_or_view_method_registration(fn)

@@ -649,7 +649,7 @@ def gen_inplace_or_view_type(
out: str,
native_yaml_path: str,
tags_yaml_path: str,
-fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
+fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo],
template_path: str,
) -> None:
# NOTE: see Note [Sharded File] at the top of the VariableType.cpp
@@ -31,11 +31,12 @@
# message, but use what's there
#

+from __future__ import annotations
+
import itertools
import re
from collections import defaultdict

-from typing import Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple
+from typing import Callable, Iterable, Sequence

import yaml

@@ -56,7 +57,6 @@ from torchgen.api.python import (
signature_from_schema,
structseq_fieldnames,
)

from torchgen.code_template import CodeTemplate
from torchgen.context import with_native_function
from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml
@@ -75,6 +75,7 @@ from torchgen.yaml_utils import YamlLoader
from .gen_inplace_or_view_type import is_tensor_list_type
from .gen_trace_type import should_trace


#
# declarations blocklist
# We skip codegen for these functions, for various reasons.
@@ -369,7 +370,7 @@ def gen(

valid_tags = parse_tags_yaml(tags_yaml_path)

-def gen_tags_enum() -> Dict[str, str]:
+def gen_tags_enum() -> dict[str, str]:
return {
"enum_of_valid_tags": (
"".join(
@@ -384,9 +385,9 @@ def gen(
def group_filter_overloads(
pairs: Sequence[PythonSignatureNativeFunctionPair],
pred: Callable[[NativeFunction], bool],
-) -> Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]:
-grouped: Dict[
-BaseOperatorName, List[PythonSignatureNativeFunctionPair]
+) -> dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]:
+grouped: dict[
+BaseOperatorName, list[PythonSignatureNativeFunctionPair]
] = defaultdict(list)
for pair in pairs:
if pred(pair.function):
@@ -398,17 +399,17 @@ def create_python_bindings(
fm: FileManager,
pairs: Sequence[PythonSignatureNativeFunctionPair],
pred: Callable[[NativeFunction], bool],
-module: Optional[str],
+module: str | None,
filename: str,
*,
method: bool,
symint: bool = True,
) -> None:
"""Generates Python bindings to ATen functions"""
-py_methods: List[str] = []
-ops_headers: List[str] = []
-py_method_defs: List[str] = []
-py_forwards: List[str] = []
+py_methods: list[str] = []
+ops_headers: list[str] = []
+py_method_defs: list[str] = []
+py_forwards: list[str] = []

grouped = group_filter_overloads(pairs, pred)

@@ -445,8 +446,8 @@ def create_python_return_type_bindings(
Generate function to initialize and return named tuple for native functions
which returns named tuple and registration invocations in `python_return_types.cpp`.
"""
-py_return_types_definition: List[str] = []
-py_return_types_registrations: List[str] = []
+py_return_types_definition: list[str] = []
+py_return_types_registrations: list[str] = []

grouped = group_filter_overloads(pairs, pred)

@@ -484,7 +485,7 @@ def create_python_return_type_bindings_header(
Generate function to initialize and return named tuple for native functions
which returns named tuple and relevant entry for the map in `python_return_types.cpp`.
"""
-py_return_types_declarations: List[str] = []
+py_return_types_declarations: list[str] = []

grouped = group_filter_overloads(pairs, pred)

@@ -510,7 +511,7 @@ def create_python_bindings_sharded(
fm: FileManager,
pairs: Sequence[PythonSignatureNativeFunctionPair],
pred: Callable[[NativeFunction], bool],
-module: Optional[str],
+module: str | None,
filename: str,
*,
method: bool,
@@ -521,13 +522,13 @@ def create_python_bindings_sharded(
grouped = group_filter_overloads(pairs, pred)

def key_func(
-kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
+kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]
) -> str:
return kv[0].base

def env_func(
-kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
-) -> Dict[str, List[str]]:
+kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]
+) -> dict[str, list[str]]:
name, fn_pairs = kv
return {
"ops_headers": [f"#include <ATen/ops/{name.base}.h>"],
@@ -553,7 +554,7 @@ def create_python_bindings_sharded(


def load_signatures(
-native_functions: List[NativeFunction],
+native_functions: list[NativeFunction],
deprecated_yaml_path: str,
*,
method: bool,
@@ -580,19 +581,19 @@ def load_deprecated_signatures(
*,
method: bool,
pyi: bool,
-) -> List[PythonSignatureNativeFunctionPair]:
+) -> list[PythonSignatureNativeFunctionPair]:
# The deprecated.yaml doesn't have complete type information, we need
# find and leverage the original ATen signature (to which it delegates
# the call) to generate the full python signature.
# We join the deprecated and the original signatures using type-only form.

# group the original ATen signatures by name
-grouped: Dict[str, List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
+grouped: dict[str, list[PythonSignatureNativeFunctionPair]] = defaultdict(list)
for pair in pairs:
grouped[pair.signature.name].append(pair)

# find matching original signatures for each deprecated signature
-results: List[PythonSignatureNativeFunctionPair] = []
+results: list[PythonSignatureNativeFunctionPair] = []

with open(deprecated_yaml_path) as f:
deprecated_defs = yaml.load(f, Loader=YamlLoader)
@@ -701,15 +702,15 @@ def gen_structseq_typename_key(f: NativeFunction) -> str:

def emit_structseq_call(
overloads: Sequence[PythonSignatureNativeFunctionPair],
-) -> Tuple[List[str], Dict[str, str]]:
+) -> tuple[list[str], dict[str, str]]:
"""
Generate block of named tuple type def inits, and add typeref snippets
to declarations that use them
"""
-typenames: Dict[
+typenames: dict[
str, str
] = {} # map from unique name + field name lists to typedef name
-typedefs: List[str] = [] # typedef declarations and init code
+typedefs: list[str] = [] # typedef declarations and init code

for overload in overloads:
fieldnames = structseq_fieldnames(overload.function.func.returns)
@@ -732,17 +733,17 @@ static PyTypeObject* {typename} = generated::get_{name}_structseq();"""

def generate_return_type_definition_and_registrations(
overloads: Sequence[PythonSignatureNativeFunctionPair],
-) -> Tuple[List[str], List[str]]:
+) -> tuple[list[str], list[str]]:
"""
Generate block of function in `python_return_types.cpp` to initialize
and return named tuple for a native function which returns named tuple
and registration invocations in same file.
"""
-typenames: Dict[
+typenames: dict[
str, str
] = {} # map from unique name + field name lists to typedef name
-definitions: List[str] = [] # function definition to register the typedef
-registrations: List[str] = [] # register call for the typedef
+definitions: list[str] = [] # function definition to register the typedef
+registrations: list[str] = [] # register call for the typedef

for overload in overloads:
fieldnames = structseq_fieldnames(overload.function.func.returns)
@@ -783,15 +784,15 @@ PyTypeObject* get_{name}_structseq() {{

def generate_return_type_declarations(
overloads: Sequence[PythonSignatureNativeFunctionPair],
-) -> List[str]:
+) -> list[str]:
"""
Generate block of function declarations in `python_return_types.h` to initialize
and return named tuple for a native function.
"""
-typenames: Dict[
+typenames: dict[
str, str
] = {} # map from unique name + field name lists to typedef name
-declarations: List[str] = [] # function declaration to register the typedef
+declarations: list[str] = [] # function declaration to register the typedef

for overload in overloads:
fieldnames = structseq_fieldnames(overload.function.func.returns)
@@ -891,7 +892,7 @@ static PyObject * ${pycname}(PyObject* self_, PyObject* args)

def method_impl(
name: BaseOperatorName,
-module: Optional[str],
+module: str | None,
overloads: Sequence[PythonSignatureNativeFunctionPair],
*,
method: bool,
@@ -918,8 +919,8 @@ def method_impl(
overloads, symint=symint
)
is_singleton = len(grouped_overloads) == 1
-signatures: List[str] = []
-dispatch: List[str] = []
+signatures: list[str] = []
+dispatch: list[str] = []
for overload_index, overload in enumerate(grouped_overloads):
signature = overload.signature.signature_str(symint=symint)
signatures.append(f"{cpp_string(str(signature))},")
@@ -959,7 +960,7 @@ def method_impl(


def gen_has_torch_function_check(
-name: BaseOperatorName, module: Optional[str], *, noarg: bool, method: bool
+name: BaseOperatorName, module: str | None, *, noarg: bool, method: bool
) -> str:
if noarg:
if method:
@@ -1007,7 +1008,7 @@ if (_r.isNone(${out_idx})) {

def emit_dispatch_case(
overload: PythonSignatureGroup,
-structseq_typenames: Dict[str, str],
+structseq_typenames: dict[str, str],
*,
symint: bool = True,
) -> str:
@@ -1050,7 +1051,7 @@ def forward_decls(
overloads: Sequence[PythonSignatureNativeFunctionPair],
*,
method: bool,
-) -> Tuple[str, ...]:
+) -> tuple[str, ...]:
if method:
return ()

@@ -1078,7 +1079,7 @@ static PyObject * {pycname}(PyObject* self_, PyObject* args, PyObject* kwargs);

def method_def(
name: BaseOperatorName,
-module: Optional[str],
+module: str | None,
overloads: Sequence[PythonSignatureNativeFunctionPair],
*,
method: bool,
@@ -1114,8 +1115,8 @@ def method_def(
def group_overloads(
overloads: Sequence[PythonSignatureNativeFunctionPair], *, symint: bool = True
) -> Sequence[PythonSignatureGroup]:
-bases: Dict[str, PythonSignatureNativeFunctionPair] = {}
-outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {}
+bases: dict[str, PythonSignatureNativeFunctionPair] = {}
+outplaces: dict[str, PythonSignatureNativeFunctionPair] = {}

# first group by signature ignoring out arguments
for overload in overloads:
@@ -1137,7 +1138,7 @@ def group_overloads(

for sig, out in outplaces.items():
if sig not in bases:
-candidates: List[str] = []
+candidates: list[str] = []
for overload in overloads:
if (
str(overload.function.func.name.name)
@@ -1268,7 +1269,7 @@ def sort_overloads(
)

# Construct the relation graph
-larger_than: Dict[int, Set[int]] = defaultdict(set)
+larger_than: dict[int, set[int]] = defaultdict(set)
for i1, overload1 in enumerate(grouped_overloads):
for i2, overload2 in enumerate(grouped_overloads):
if is_smaller(overload1.signature, overload2.signature):
@@ -1279,7 +1280,7 @@ def sort_overloads(

# Use a topological sort to sort overloads according to the partial order.
N = len(grouped_overloads)
-sorted_ids: List[int] = list(filter(lambda x: x not in larger_than, range(N)))
+sorted_ids: list[int] = list(filter(lambda x: x not in larger_than, range(N)))

for idx in range(N):
# The size of sorted_ids will grow to N eventually.
@@ -1304,7 +1305,7 @@ def sort_overloads(
def emit_single_dispatch(
ps: PythonSignature,
f: NativeFunction,
-structseq_typenames: Dict[str, str],
+structseq_typenames: dict[str, str],
*,
symint: bool = True,
) -> str:
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
import itertools
-from typing import Dict, List, Sequence, Union
+from typing import Sequence

from torchgen.api import cpp
from torchgen.api.types import DispatcherSignature
@@ -8,6 +10,7 @@ from torchgen.context import with_native_function
from torchgen.model import Argument, NativeFunction, SchemaKind, TensorOptionsArguments
from torchgen.utils import FileManager


# Note [Manual Backend kernels]
# For these ops, we want to manually register to dispatch key Backend and
# skip codegen-ed registeration to all keys before Backend.
@@ -136,9 +139,7 @@ ADD_TRACE_INPUT = CodeTemplate("""jit::tracer::addInputs(node, "${name}", ${inpu


def format_trace_inputs(f: NativeFunction) -> str:
-def dispatch_trace_input(
-arg: Union[Argument, TensorOptionsArguments]
-) -> Sequence[str]:
+def dispatch_trace_input(arg: Argument | TensorOptionsArguments) -> Sequence[str]:
if isinstance(arg, TensorOptionsArguments):
name = "options"
return [
@@ -156,7 +157,7 @@ def format_trace_inputs(f: NativeFunction) -> str:
else:
return [ADD_TRACE_INPUT.substitute(name=name, input=name)]

-args: List[Union[Argument, TensorOptionsArguments]] = list(
+args: list[Argument | TensorOptionsArguments] = list(
f.func.schema_order_arguments()
)

@@ -399,8 +400,8 @@ ${assign_return_values}at::_ops::${unambiguous_name}::redispatch(${unpacked_args
)


-def emit_trace_body(f: NativeFunction) -> List[str]:
-trace_body: List[str] = []
+def emit_trace_body(f: NativeFunction) -> list[str]:
+trace_body: list[str] = []

trace_body.append(format_prerecord_trace(f))

@@ -503,7 +504,7 @@ def method_registration(f: NativeFunction) -> str:
)


-def gen_trace_type_func(fn: NativeFunction) -> Dict[str, List[str]]:
+def gen_trace_type_func(fn: NativeFunction) -> dict[str, list[str]]:
return {
"ops_headers": [f"#include <ATen/ops/{fn.root_name}_ops.h>"],
"trace_method_definitions": [method_definition(fn)],
@@ -512,7 +513,7 @@ def gen_trace_type_func(fn: NativeFunction) -> Dict[str, List[str]]:


def gen_trace_type(
-out: str, native_functions: List[NativeFunction], template_path: str
+out: str, native_functions: list[NativeFunction], template_path: str
) -> None:
# NOTE: see Note [Sharded File] at the top of the VariableType.cpp
# template regarding sharding of the generated files.
@@ -2,18 +2,19 @@
#
# This writes one file: variable_factories.h

+from __future__ import annotations
+
import re
-from typing import List, Optional

import torchgen.api.python as python
from torchgen.api import cpp

from torchgen.api.types import CppSignatureGroup
from torchgen.context import with_native_function
from torchgen.gen import parse_native_yaml
from torchgen.model import NativeFunction, TensorOptionsArguments, Variant
from torchgen.utils import FileManager, mapMaybe


OPTIONAL_TYPE_PATTERN = re.compile(r"c10::optional<(.+)>")
TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)")

@@ -69,7 +70,7 @@ def is_factory_function(f: NativeFunction) -> bool:


@with_native_function
-def process_function(f: NativeFunction) -> Optional[str]:
+def process_function(f: NativeFunction) -> str | None:
name = cpp.name(f.func)
has_tensor_options = python.has_tensor_options(f)
is_factory = has_tensor_options or name.endswith("_like")
@@ -83,8 +84,8 @@ def process_function(f: NativeFunction) -> Optional[str]:
sigs.append(cpp_sigs.symint_signature)
r = ""
for sig in sigs:
-formals: List[str] = []
-exprs: List[str] = []
+formals: list[str] = []
+exprs: list[str] = []
requires_grad = "false"
for arg in sig.arguments():
qualified_type = fully_qualified_type(arg.type)
@@ -25,8 +25,11 @@
# which will in turn dispatch back to VariableType for its
# differentiable subcomponents.
#

+from __future__ import annotations
+
import re
-from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import Callable, Sequence

from torchgen.api import cpp
from torchgen.api.autograd import (
@@ -38,7 +41,6 @@ from torchgen.api.autograd import (
NativeFunctionWithDifferentiabilityInfo,
SavedAttribute,
)

from torchgen.api.types import (
ArrayRefCType,
BaseCppType,
@@ -103,6 +105,7 @@ from .gen_trace_type import (
type_wrapper_name,
)


# We don't set or modify grad_fn on these methods. Generally, they return
# tensors that have requires_grad=False. In-place functions listed here will
# not examine or modify requires_grad or grad_fn.
@@ -837,9 +840,9 @@ def gen_variable_type(
out: str,
native_yaml_path: str,
tags_yaml_path: str,
-fns_with_diff_infos: List[NativeFunctionWithDifferentiabilityInfo],
+fns_with_diff_infos: list[NativeFunctionWithDifferentiabilityInfo],
template_path: str,
-used_keys: Set[str],
+used_keys: set[str],
) -> None:
"""VariableType.h and VariableType.cpp body

@@ -858,8 +861,8 @@ def gen_variable_type(

# helper that generates a TORCH_LIBRARY_IMPL macro for each
# dispatch key that appears in derivatives.yaml
-def wrapper_registrations(used_keys: Set[str]) -> str:
-library_impl_macro_list: List[str] = []
+def wrapper_registrations(used_keys: set[str]) -> str:
+library_impl_macro_list: list[str] = []
for key in sorted(used_keys):
dispatch_key = key
if key == "Default":
@@ -926,7 +929,7 @@ def gen_wrapper_registration(f: NativeFunction, key: str = "Default") -> str:

def gen_variable_type_func(
fn: NativeFunctionWithDifferentiabilityInfo,
-) -> Dict[str, List[str]]:
+) -> dict[str, list[str]]:
f = fn.func
result = {}
with native_function_manager(f):
@@ -1034,7 +1037,7 @@ _foreach_ops_with_different_arity = {
@with_native_function_with_differentiability_info_and_key
def emit_body(
fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
-) -> List[str]:
+) -> list[str]:
assert dispatch_strategy(fn) == "use_derived"
f = fn.func
info = fn.info[key] if fn.info else None
@@ -1050,8 +1053,8 @@ def emit_body(
is_foreach = name.startswith("_foreach")
is_inplace_foreach = is_foreach and inplace
if is_inplace_foreach:
-inplace_foreacharg2refarg: Dict[Argument, Argument] = {}
-refargname2inplace_foreacharg: Dict[str, Argument] = {}
+inplace_foreacharg2refarg: dict[Argument, Argument] = {}
+refargname2inplace_foreacharg: dict[str, Argument] = {}
base_name_and_overload_name = (f.func.name.name.base, f.func.name.overload_name)
if info is None:
assert (
@@ -1077,8 +1080,8 @@ def emit_body(
refargname2inplace_foreacharg[ref_arg.name] = foreach_arg

def gen_differentiable_input(
-arg: Union[Argument, SelfArgument, TensorOptionsArguments]
-) -> Optional[DifferentiableInput]:
+arg: Argument | SelfArgument | TensorOptionsArguments,
+) -> DifferentiableInput | None:
if isinstance(arg, TensorOptionsArguments):
return None
a: Argument = arg.argument if isinstance(arg, SelfArgument) else arg
@@ -1097,7 +1100,7 @@ def emit_body(
)

@with_native_function
-def gen_differentiable_inputs(f: NativeFunction) -> List[DifferentiableInput]:
+def gen_differentiable_inputs(f: NativeFunction) -> list[DifferentiableInput]:
arguments = list(f.func.arguments.non_out)
if is_inplace_foreach and info is not None:
for i, arg in enumerate(f.func.arguments.flat_non_out):
@@ -1115,8 +1118,8 @@ def emit_body(
return list(mapMaybe(gen_differentiable_input, arguments))

def find_args_with_derivatives(
-differentiable_inputs: List[DifferentiableInput],
-) -> List[DifferentiableInput]:
+differentiable_inputs: list[DifferentiableInput],
+) -> list[DifferentiableInput]:
"""Find arguments that have derivative definitions"""
if info is None or not info.has_derivatives:
return differentiable_inputs
@@ -1178,8 +1181,8 @@ def emit_body(
and (not returns_void)
)

-def emit_save_inputs() -> List[str]:
-setup: List[str] = []
+def emit_save_inputs() -> list[str]:
+setup: list[str] = []
if info is None or not info.has_derivatives:
return setup

@@ -1189,7 +1192,7 @@ def emit_body(

# We don't want to save tensors if we know that they will never be used
# when computing the derivative, so we add guards to those statements
-def guard_for(arg: SavedAttribute) -> Optional[str]:
+def guard_for(arg: SavedAttribute) -> str | None:
assert info is not None

# It's hard to determine the edge offset if we have TensorLists
@@ -1276,8 +1279,8 @@ def emit_body(
setup.append(f"grad_fn->{arg.name}_size_ = {arg.name}.size();")
return setup

-def setup_derivative(differentiable_inputs: List[DifferentiableInput]) -> List[str]:
-body: List[str] = []
+def setup_derivative(differentiable_inputs: list[DifferentiableInput]) -> list[str]:
+body: list[str] = []
if is_out_fn:
# For out functions, ensure that no input or output requires grad
body.append(DECLARE_GRAD_FN.substitute(op="Node"))
@@ -1343,8 +1346,8 @@ def emit_body(
body.append(SETUP_DERIVATIVE.substitute(setup=setup))
return body

-def emit_check_if_in_complex_autograd_allowlist() -> List[str]:
-body: List[str] = []
+def emit_check_if_in_complex_autograd_allowlist() -> list[str]:
+body: list[str] = []
if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX:
return body
for arg in differentiable_outputs:
@@ -1355,11 +1358,11 @@ def emit_body(
return body

def emit_check_no_requires_grad(
-tensor_args: List[DifferentiableInput],
-args_with_derivatives: List[DifferentiableInput],
-) -> List[str]:
+tensor_args: list[DifferentiableInput],
+args_with_derivatives: list[DifferentiableInput],
+) -> list[str]:
"""Checks that arguments without derivatives don't require grad"""
-body: List[str] = []
+body: list[str] = []
for arg in tensor_args:
if arg in args_with_derivatives:
continue
@@ -1373,8 +1376,8 @@ def emit_body(
body.append(f'check_no_requires_grad({arg_name}, "{arg_name}", "{name}");')
return body

-def emit_original_self_definition() -> List[str]:
-body: List[str] = []
+def emit_original_self_definition() -> list[str]:
+body: list[str] = []
if inplace:
if is_inplace_foreach:
body.append(
@@ -1412,17 +1415,17 @@ def emit_body(
def save_variables(
saved_variables: Sequence[SavedAttribute],
is_output: bool,
-guard_for: Callable[[SavedAttribute], Optional[str]] = lambda name: None,
+guard_for: Callable[[SavedAttribute], str | None] = lambda name: None,
) -> Sequence[str]:
# assign the saved variables to the generated grad_fn
-stmts: List[str] = []
+stmts: list[str] = []
for arg in sorted(saved_variables, key=lambda sa: str(sa.nctype.name)):
name = (
arg.nctype.name.name
if isinstance(arg.nctype.name, SpecialArgName)
else arg.nctype.name
)
-foreacharg: Optional[Argument] = None
+foreacharg: Argument | None = None
is_foreacharg_list_type: bool = False
type = arg.nctype.type
expr = arg.expr
@@ -1539,10 +1542,10 @@ def emit_body(
return call

def wrap_output(
-f: NativeFunction, unpacked_bindings: List[Binding], var: str
+f: NativeFunction, unpacked_bindings: list[Binding], var: str
) -> str:
call = ""
-rhs_value: Optional[str] = None
+rhs_value: str | None = None
if not any(r.type.is_tensor_like() for r in f.func.returns):
rhs_value = var
else:
@@ -1554,11 +1557,11 @@ def emit_body(
return call

def check_tensorimpl_and_storage(
-call: str, unpacked_bindings: List[Binding]
+call: str, unpacked_bindings: list[Binding]
) -> str:
# See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
-stmts_before_call: List[str] = []
-stmts_after_call: List[str] = []
+stmts_before_call: list[str] = []
+stmts_after_call: list[str] = []

if cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
return call
@@ -1665,7 +1668,7 @@ def emit_body(
return call

def emit_call(
-f: NativeFunction, unpacked_bindings: List[Binding], try_jit_decomposition: bool
+f: NativeFunction, unpacked_bindings: list[Binding], try_jit_decomposition: bool
) -> str:
# We only care about adding `at::AutoDispatchBelowAutograd` guard for non-variable dispatch
# (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure
@@ -1764,7 +1767,7 @@ def emit_body(
)
return ""

-def emit_any_requires_grad() -> List[str]:
+def emit_any_requires_grad() -> list[str]:
extra_condition = ""
if info and info.output_differentiability_conditions:
assert len(info.output_differentiability_conditions) == 1
@@ -1782,14 +1785,14 @@ def emit_body(
)
]

-def get_any_has_forward_grad_name(var_names: Tuple[str, ...]) -> str:
+def get_any_has_forward_grad_name(var_names: tuple[str, ...]) -> str:
if len(var_names) == 1:
return f"_any_has_forward_grad_{var_names[0]}"
else:
return f'_any_has_forward_grad_{"_".join(var_names)}'

-def emit_any_has_forward_grad() -> List[str]:
-content: List[str] = []
+def emit_any_has_forward_grad() -> list[str]:
+content: list[str] = []
if not is_foreach:
for derivative in fw_derivatives:
requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative)
@@ -1844,7 +1847,7 @@ def emit_body(
content.append("}")
return content

-def emit_check_inplace() -> List[str]:
+def emit_check_inplace() -> list[str]:
if not inplace:
return []
return [
@@ -1852,9 +1855,9 @@ def emit_body(
for arg in differentiable_outputs
]

-def emit_fw_derivatives() -> List[str]:
-content: List[str] = []
-fw_grad_setters: List[str] = []
+def emit_fw_derivatives() -> list[str]:
+content: list[str] = []
+fw_grad_setters: list[str] = []
for derivative in fw_derivatives:
res = derivative.var_names
if f.func.name.name.inplace:
@@ -2002,7 +2005,7 @@ def emit_body(
"(self.size(), c10::nullopt);"
)
foreach_forward_grad_formula = derivative.formula
-_foreach_arg: Union[Argument, DifferentiableInput]
+_foreach_arg: Argument | DifferentiableInput
if inplace:
for _foreach_arg, _ref_arg in inplace_foreacharg2refarg.items():
# note(crcrpar): Massage only Scalar and ArrayRef<Scalar> here.
@@ -2044,7 +2047,7 @@ def emit_body(
content.append("\n".join(fw_grad_setters))
return content

-def get_any_has_fw_grad_cond(derivative: Optional[ForwardDerivative]) -> str:
+def get_any_has_fw_grad_cond(derivative: ForwardDerivative | None) -> str:
#
# Produces a condition string (e.g, "isFwGradDefined(grad_output) || isFwGradDefined(output)")
#
@@ -2053,7 +2056,7 @@ def emit_body(
# - Used in the out_fn case when we want to forbid fw derivatives
# - Used in the case where the fw_derivative is not defined, but we want
# To check if there is a decomposition registered for jvp
-to_check: List[str] = []
+to_check: list[str] = []
for inp in list(
mapMaybe(
gen_differentiable_input,
@@ -2126,7 +2129,7 @@ def emit_body(
else ""
)

-body: List[str] = []
+body: list[str] = []
unpack_args_stats, unpacked_bindings = unpack_args(f)

body.extend(unpack_args_stats)
@@ -4,10 +4,11 @@
# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
# The fallback is expected to mimic this codegen, so we should keep the two in sync.

-from typing import List, Tuple
+from __future__ import annotations

+from typing import TYPE_CHECKING

import torchgen.api.dispatcher as dispatcher
-from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo
from torchgen.api.translate import translate
from torchgen.api.types import (
BaseCType,
@@ -29,6 +30,11 @@ from .gen_inplace_or_view_type import (
use_derived,
)


+if TYPE_CHECKING:
+from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo


FUNCTION_DECLARATION = CodeTemplate(
"""\
#define ${uppercase_op}_AVAILABLE
@@ -155,9 +161,9 @@ def returns_multi_tensor(fn: NativeFunction) -> bool:
# tuple: (list of getter logic strings, list of setter logic strings, string
# with num items expression)
def generate_state_getter_setter(
-bindings: List[Binding],
+bindings: list[Binding],
state_vec_type: NamedCType,
-) -> Tuple[List[str], List[str], str]:
+) -> tuple[list[str], list[str], str]:
getter_logic = []
setter_logic = []

@@ -302,7 +308,7 @@ def process_function(fn: NativeFunction, template: CodeTemplate) -> str:

def gen_view_funcs(
out: str,
-fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
+fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo],
template_path: str,
) -> None:
# don't need the info parts, just the function
@@ -2,14 +2,16 @@
#
# Each autograd function is represented by `DifferentiabilityInfo` containing
# a list of `Derivative`. See `torchgen.api.autograd` for the data models.

+from __future__ import annotations
+
import re
from collections import defaultdict
-from typing import Any, Counter, Dict, List, Match, Optional, Sequence, Set, Tuple
+from typing import Any, Counter, Dict, Sequence, Set, Tuple

import yaml

from torchgen.api import cpp

from torchgen.api.autograd import (
Derivative,
DifferentiabilityInfo,
@@ -50,9 +52,10 @@ from torchgen.model import (
from torchgen.utils import concatMap, IDENT_REGEX, split_name_params
from torchgen.yaml_utils import YamlLoader


DerivativeRet = Tuple[Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], Set[str]]

-_GLOBAL_LOAD_DERIVATIVE_CACHE: Dict[Tuple[str, str], DerivativeRet] = {}
+_GLOBAL_LOAD_DERIVATIVE_CACHE: dict[tuple[str, str], DerivativeRet] = {}

_VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS)

@@ -62,11 +65,11 @@ _VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS)
# we generate them here instead of duplicating them in the yaml.
# See Note [Codegen'd {view}_copy Operators]
def add_view_copy_derivatives(
-infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
-view_groups: List[NativeFunctionsViewGroup],
+infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
+view_groups: list[NativeFunctionsViewGroup],
) -> None:
# Get the map from each view op's name to its corresponding view group
-view_name_to_group: Dict[OperatorName, NativeFunctionsViewGroup] = {
+view_name_to_group: dict[OperatorName, NativeFunctionsViewGroup] = {
g.view.func.name: g for g in view_groups
}

@@ -125,10 +128,10 @@ def load_derivatives(
# function schema is the complete declaration including mutability annotation / default value and etc.
# signature is the canonical schema for a group of functions (in-place/out/functional variants)
# that are semantically related.
-functions_by_signature: Dict[
-FunctionSchema, List[NativeFunction]
+functions_by_signature: dict[
+FunctionSchema, list[NativeFunction]
] = defaultdict(list)
-functions_by_schema: Dict[str, NativeFunction] = {}
+functions_by_schema: dict[str, NativeFunction] = {}
for function in native_functions:
functions_by_signature[function.func.signature()].append(function)
assert str(function.func) not in functions_by_schema
@@ -141,8 +144,8 @@ def load_derivatives(
# infos is a dict that maps FunctionSchema -> a dict of per dispatch key DifferentiabilityInfos
# this is useful because in tools/autograd/gen_autograd.py:match_differentiability_info
# we ultimately need to categorize the DifferentiabilityInfos by FunctionSchema
-infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]] = {}
-used_dispatch_keys: Set[str] = set()
+infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]] = {}
+used_dispatch_keys: set[str] = set()
for defn_dict in definitions:
# Ensure that the old derivatives.yaml schema with no dispatch key can be loaded.
if "dispatch" not in defn_dict:
@@ -185,11 +188,11 @@ def cpp_arguments(f: NativeFunction) -> Sequence[Binding]:
def create_derivative(
f: NativeFunction,
formula: str,
-var_names: Tuple[str, ...],
+var_names: tuple[str, ...],
available_named_gradients: Sequence[str],
) -> Derivative:
original_formula = formula
-arguments: List[NamedCType] = [
+arguments: list[NamedCType] = [
a.nctype.remove_const_ref() for a in cpp_arguments(f)
]

@@ -230,10 +233,10 @@ def create_derivative(


def create_forward_derivative(
-f: NativeFunction, formula: str, names: Tuple[str, ...]
+f: NativeFunction, formula: str, names: tuple[str, ...]
) -> ForwardDerivative:
var_names = names
-var_types: Optional[Tuple[Type, ...]] = None
+var_types: tuple[Type, ...] | None = None
for r in f.func.returns:
if r.name in var_names:
if var_types is None:
@@ -269,12 +272,12 @@ def create_forward_derivative(
def postprocess_forward_derivatives(
f: NativeFunction,
defn_name: str,
-all_arg_names: List[str],
-derivatives: List[Derivative],
-forward_derivatives: List[ForwardDerivative],
+all_arg_names: list[str],
+derivatives: list[Derivative],
+forward_derivatives: list[ForwardDerivative],
args_with_derivatives: Sequence[Binding],
-) -> List[ForwardDerivative]:
-def find_required_inputs(formula: str, postfix: str) -> Tuple[str, ...]:
+) -> list[ForwardDerivative]:
+def find_required_inputs(formula: str, postfix: str) -> tuple[str, ...]:
is_foreach = f.func.name.name.base.startswith("_foreach_")
required_inputs = set()
for arg in args_with_derivatives:
@@ -300,7 +303,7 @@ def postprocess_forward_derivatives(

return tuple(required_inputs)

-updated_derivatives: List[ForwardDerivative] = []
+updated_derivatives: list[ForwardDerivative] = []

for defn in forward_derivatives:
formula = defn.formula
@@ -430,7 +433,7 @@ def postprocess_forward_derivatives(


def is_forward_derivative_definition(
-all_arg_names: List[str], names: Tuple[str, ...]
+all_arg_names: list[str], names: tuple[str, ...]
) -> bool:
for name in names:
if name not in all_arg_names:
@@ -441,12 +444,12 @@ def is_forward_derivative_definition(


def create_differentiability_info(
-defn_dict: Dict[Any, Any],
-functions_by_signature: Dict[FunctionSchema, List[NativeFunction]],
-functions_by_schema: Dict[str, NativeFunction],
+defn_dict: dict[Any, Any],
+functions_by_signature: dict[FunctionSchema, list[NativeFunction]],
+functions_by_schema: dict[str, NativeFunction],
op_counter: Counter[str],
-used_dispatch_keys: Set[str],
-) -> Tuple[FunctionSchema, Dict[str, DifferentiabilityInfo]]:
+used_dispatch_keys: set[str],
+) -> tuple[FunctionSchema, dict[str, DifferentiabilityInfo]]:
"""Processes a single entry `defn` in derivatives.yaml"""

def canonical_function(
@@ -463,7 +466,7 @@ def create_differentiability_info(
assert name + "_" == cpp.name(functions[0].func)
return functions[0]

-def split_names(raw_names: str) -> Tuple[str, ...]:
+def split_names(raw_names: str) -> tuple[str, ...]:
"""Given "foo, bar", return ["foo", "bar"]."""
return tuple(x.strip() for x in raw_names.split(","))

@@ -477,7 +480,7 @@ def create_differentiability_info(
uses_grad = False # true if any derivative uses "grad"
num_grads_uses = 0 # count of uses of "grads" or "grads[INDEX]"
uses_named_grads = False # true if any derivative uses "grad_{name}"
-used_grads_indices: List[int] = [] # which indices of grads are used
+used_grads_indices: list[int] = [] # which indices of grads are used
for d in derivatives:
formula = d.formula
uses_grad = uses_grad or bool(
@@ -521,7 +524,7 @@ def create_differentiability_info(
@with_native_function
def set_up_derivatives(
f: NativeFunction,
-) -> Tuple[
+) -> tuple[
Sequence[Derivative],
Sequence[ForwardDerivative],
Sequence[Binding],
@@ -529,10 +532,10 @@ def create_differentiability_info(
Sequence[str],
]:
# Set up the derivative information
-derivatives: List[Derivative] = []
-forward_derivatives: List[ForwardDerivative] = []
-non_differentiable_arg_names: List[str] = []
-args_with_derivatives_set: Set[str] = set()
+derivatives: list[Derivative] = []
+forward_derivatives: list[ForwardDerivative] = []
+non_differentiable_arg_names: list[str] = []
+args_with_derivatives_set: set[str] = set()

all_arg_names = [a.name for a in cpp_arguments(f)]
all_ret_names = [
@@ -699,7 +702,7 @@ def create_differentiability_info(
available_named_gradients,
) = set_up_derivatives(canonical)

-used_named_gradients: Set[str] = set()
+used_named_gradients: set[str] = set()
for d in derivatives:
used_named_gradients |= d.named_gradients

@@ -738,7 +741,7 @@ def create_differentiability_info(
GRAD_INDEX_REGEX = r"(?:^|\W)grads\[(\d+)\]"


-def used_gradient_indices(formula: str) -> List[int]:
+def used_gradient_indices(formula: str) -> list[int]:
"""Determine a list of gradient indices (the i in grads[i]) that
are used by the formula.

@@ -750,9 +753,9 @@ def used_gradient_indices(formula: str) -> List[int]:

def saved_variables(
formula: str,
-nctypes: List[NamedCType],
-var_names: Tuple[str, ...],
-) -> Tuple[str, Tuple[SavedAttribute, ...]]:
+nctypes: list[NamedCType],
+var_names: tuple[str, ...],
+) -> tuple[str, tuple[SavedAttribute, ...]]:
def stride_expr(name: str) -> str:
assert var_names == (name,), (
'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor '
@@ -760,7 +763,7 @@ def saved_variables(
)
return f'strides_or_error({name}, "{name}")'

-REPLACEMENTS: List[Tuple[str, Dict[str, Any]]] = [
+REPLACEMENTS: list[tuple[str, dict[str, Any]]] = [
# replace self.sym_sizes() with self_sym_sizes
(
r"{}.sym_sizes\(\)",
@@ -914,7 +917,7 @@ def saved_variables(
]
# find which arguments need to be saved
|
||||
saved: List[SavedAttribute] = []
|
||||
saved: list[SavedAttribute] = []
|
||||
|
||||
if ".sizes()" in formula or "->sizes()" in formula:
|
||||
raise RuntimeError(
|
||||
@ -941,7 +944,7 @@ def saved_variables(
|
||||
# when the autograd Function is created to avoid saving variables
|
||||
for regex, info in REPLACEMENTS:
|
||||
|
||||
def repl(m: Match[str]) -> str:
|
||||
def repl(m: re.Match[str]) -> str:
|
||||
suffix: str = (
|
||||
info["suffix"](m) if callable(info["suffix"]) else info["suffix"]
|
||||
)
|
||||
@ -999,8 +1002,8 @@ def _create_op_prefix(name: str) -> str:
|
||||
|
||||
|
||||
def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]:
|
||||
seen: Set[str] = set()
|
||||
saved: List[SavedAttribute] = []
|
||||
seen: set[str] = set()
|
||||
saved: list[SavedAttribute] = []
|
||||
for var in vars:
|
||||
name = (
|
||||
var.nctype.name.name
|
||||
|
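
The hunks above, covering the derivatives.yaml parsing helpers, are almost entirely mechanical: with `from __future__ import annotations` in place, `Dict`/`List`/`Set`/`Tuple` annotations become the built-in `dict`/`list`/`set`/`tuple`, which is still safe on the Python 3.8 these tools target because annotations are no longer evaluated at import time. A minimal, self-contained sketch of the idea (the function below is illustrative, not taken from the patch):

from __future__ import annotations


def group_by_key(pairs: list[tuple[str, int]]) -> dict[str, list[int]]:
    # The subscripted builtins in the signature are stored as strings and
    # never evaluated, so this module still imports fine on Python 3.8.
    grouped: dict[str, list[int]] = {}
    for key, value in pairs:
        grouped.setdefault(key, []).append(value)
    return grouped


print(group_by_key([("a", 1), ("a", 2), ("b", 3)]))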
@ -1,17 +1,17 @@
from __future__ import annotations

import os
import platform
import shutil
from glob import glob
from typing import Dict, Optional

from setuptools import distutils # type: ignore[import]

from .setup_helpers.cmake import CMake, USE_NINJA

from .setup_helpers.env import check_negative_env_flag, IS_64BIT, IS_WINDOWS


def _overlay_windows_vcvars(env: Dict[str, str]) -> Dict[str, str]:
def _overlay_windows_vcvars(env: dict[str, str]) -> dict[str, str]:
vc_arch = "x64" if IS_64BIT else "x86"

if platform.machine() == "ARM64":
@ -34,7 +34,7 @@ def _overlay_windows_vcvars(env: Dict[str, str]) -> Dict[str, str]:
"emulation is enabled!"
)

vc_env: Dict[str, str] = distutils._msvccompiler._get_vc_env(vc_arch)
vc_env: dict[str, str] = distutils._msvccompiler._get_vc_env(vc_arch)
# Keys in `_get_vc_env` are always lowercase.
# We turn them into uppercase before overlaying vcvars
# because OS environ keys are always uppercase on Windows.
@ -47,7 +47,7 @@ def _overlay_windows_vcvars(env: Dict[str, str]) -> Dict[str, str]:
return vc_env


def _create_build_env() -> Dict[str, str]:
def _create_build_env() -> dict[str, str]:
# XXX - our cmake file sometimes looks at the system environment
# and not cmake flags!
# you should NEVER add something to this list. It is bad practice to
@ -72,8 +72,8 @@ def _create_build_env() -> Dict[str, str]:


def build_caffe2(
version: Optional[str],
cmake_python_library: Optional[str],
version: str | None,
cmake_python_library: str | None,
build_python: bool,
rerun_cmake: bool,
cmake_only: bool,
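
The `build_caffe2` signature above also adopts PEP 604 unions: `Optional[str]` becomes `str | None`, again safe on Python 3.8 only because the annotation stays a string. A small sketch under that assumption (the helper and its fields are made up, not part of the build script):

from __future__ import annotations


def describe_build(version: str | None, cmake_python_library: str | None) -> dict[str, str]:
    # `str | None` is never evaluated at runtime here, so no Python 3.10 is needed.
    return {
        "version": version or "unknown",
        "cmake_python_library": cmake_python_library or "not set",
    }


print(describe_build("2.4.0", None))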
|
||||
|
@ -5,10 +5,13 @@
# - ninja -j1 -v -n torch_python | sed -e 's/-O[23]/-g/g' -e 's#\[[0-9]\+\/[0-9]\+\] \+##' |sh
# - Copy libs from build/lib to torch/lib folder

from __future__ import annotations

import subprocess
import sys
from pathlib import Path
from typing import Any, List, Optional, Tuple
from typing import Any


PYTORCH_ROOTDIR = Path(__file__).resolve().parent.parent
TORCH_DIR = PYTORCH_ROOTDIR / "torch"
@ -17,7 +20,7 @@ BUILD_DIR = PYTORCH_ROOTDIR / "build"
BUILD_LIB_DIR = BUILD_DIR / "lib"


def check_output(args: List[str], cwd: Optional[str] = None) -> str:
def check_output(args: list[str], cwd: str | None = None) -> str:
return subprocess.check_output(args, cwd=cwd).decode("utf-8")


@ -63,7 +66,7 @@ def is_devel_setup() -> bool:
return output.strip() == str(TORCH_DIR / "__init__.py")


def create_build_plan() -> List[Tuple[str, str]]:
def create_build_plan() -> list[tuple[str, str]]:
output = check_output(
["ninja", "-j1", "-v", "-n", "torch_python"], cwd=str(BUILD_DIR)
)
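
`check_output` above is one of the small wrappers these tools share; a sketch of how it is meant to be called (the git command is a hypothetical example, not from this script):

from __future__ import annotations

import subprocess


def check_output(args: list[str], cwd: str | None = None) -> str:
    # subprocess.check_output returns bytes; decode so callers get str.
    return subprocess.check_output(args, cwd=cwd).decode("utf-8")


print(check_output(["git", "ls-files"], cwd=".").splitlines()[:3])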
|
||||
|
@ -8,13 +8,15 @@ For custom build with static dispatch, the op dependency graph will be omitted,
and it will directly output root ops as the allowlist.
"""

import argparse
from __future__ import annotations

import argparse
from collections import defaultdict
from typing import Dict, List, Set
from typing import Dict, Set

import yaml


DepGraph = Dict[str, Set[str]]


@ -34,7 +36,7 @@ def load_op_dep_graph(fname: str) -> DepGraph:
return dict(result)


def load_root_ops(fname: str) -> List[str]:
def load_root_ops(fname: str) -> list[str]:
result = []
with open(fname) as stream:
for op in yaml.safe_load(stream):
@ -44,9 +46,9 @@ def load_root_ops(fname: str) -> List[str]:

def gen_transitive_closure(
dep_graph: DepGraph,
root_ops: List[str],
root_ops: list[str],
train: bool = False,
) -> List[str]:
) -> list[str]:
result = set(root_ops)
queue = root_ops.copy()

@ -73,7 +75,7 @@ def gen_transitive_closure(
return sorted(result)


def gen_transitive_closure_str(dep_graph: DepGraph, root_ops: List[str]) -> str:
def gen_transitive_closure_str(dep_graph: DepGraph, root_ops: list[str]) -> str:
return " ".join(gen_transitive_closure(dep_graph, root_ops))
|
||||
|
||||
|
||||
|
@ -1,8 +1,11 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from gen_op_registration_allowlist import (
|
||||
@ -17,6 +20,7 @@ from torchgen.selective_build.operator import (
|
||||
)
|
||||
from torchgen.selective_build.selector import merge_kernel_metadata
|
||||
|
||||
|
||||
# Generate YAML file containing the operators used for a specific PyTorch model.
|
||||
# ------------------------------------------------------------------------------
|
||||
#
|
||||
@ -84,17 +88,17 @@ from torchgen.selective_build.selector import merge_kernel_metadata
|
||||
#
|
||||
|
||||
|
||||
def canonical_opnames(opnames: List[str]) -> List[str]:
|
||||
def canonical_opnames(opnames: list[str]) -> list[str]:
|
||||
return [canonical_name(opname) for opname in opnames]
|
||||
|
||||
|
||||
def make_filter_from_options(
|
||||
model_name: str,
|
||||
model_versions: List[str],
|
||||
model_assets: Optional[List[str]],
|
||||
model_backends: Optional[List[str]],
|
||||
model_versions: list[str],
|
||||
model_assets: list[str] | None,
|
||||
model_backends: list[str] | None,
|
||||
):
|
||||
def is_model_included(model_info):
|
||||
def is_model_included(model_info) -> bool:
|
||||
model = model_info["model"]
|
||||
if model["name"] != model_name:
|
||||
return False
|
||||
@ -109,7 +113,7 @@ def make_filter_from_options(
|
||||
|
||||
|
||||
# Returns if a the specified rule is a new or old style pt_operator_library
|
||||
def is_new_style_rule(model_name: str, model_versions: Optional[List[str]]):
|
||||
def is_new_style_rule(model_name: str, model_versions: list[str] | None):
|
||||
return model_name is not None and model_versions is not None
|
||||
|
||||
|
||||
@ -117,13 +121,13 @@ def is_new_style_rule(model_name: str, model_versions: Optional[List[str]]):
|
||||
# appear in at least one model yaml. Throws if verification is failed,
|
||||
# returns None on success
|
||||
def verify_all_specified_present(
|
||||
model_assets: Optional[List[str]],
|
||||
model_versions: List[str],
|
||||
selected_models_yaml: List[Dict[str, Any]],
|
||||
model_assets: list[str] | None,
|
||||
model_versions: list[str],
|
||||
selected_models_yaml: list[dict[str, Any]],
|
||||
rule_name: str,
|
||||
model_name: str,
|
||||
new_style_rule: bool,
|
||||
):
|
||||
) -> None:
|
||||
def find_missing_items(model_items, key, selected_models_yaml):
|
||||
missing_items = []
|
||||
if not new_style_rule or not model_items:
|
||||
@ -179,10 +183,10 @@ def verify_all_specified_present(
|
||||
# Uses the selected models configs and then combines them into one dictionary,
|
||||
# formats them as a string, and places the string into output as a top level debug_info
|
||||
def create_debug_info_from_selected_models(
|
||||
output: Dict[str, object],
|
||||
selected_models: List[dict],
|
||||
output: dict[str, object],
|
||||
selected_models: list[dict],
|
||||
new_style_rule: bool,
|
||||
):
|
||||
) -> None:
|
||||
model_dict = {
|
||||
"asset_info": {}, # maps asset name -> dict of asset metadata like hashes
|
||||
"is_new_style_rule": new_style_rule,
|
||||
@ -201,7 +205,7 @@ def create_debug_info_from_selected_models(
|
||||
output["debug_info"] = [json.dumps(model_dict)]
|
||||
|
||||
|
||||
def fill_output(output: Dict[str, object], options: object):
|
||||
def fill_output(output: dict[str, object], options: object) -> None:
|
||||
"""Populate the output dict with the information required to serialize
|
||||
the YAML file used for selective build.
|
||||
"""
|
||||
@ -458,7 +462,7 @@ def fill_output(output: Dict[str, object], options: object):
|
||||
# END TRACING BASED BUILD OPS
|
||||
|
||||
# Merge dictionaries together to remove op duplication
|
||||
operators: Dict[str, SelectiveBuildOperator] = {}
|
||||
operators: dict[str, SelectiveBuildOperator] = {}
|
||||
for ops_dict in bucketed_ops:
|
||||
operators = merge_operator_dicts(operators, ops_dict)
|
||||
|
||||
|
@ -1,10 +1,13 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from functools import reduce
|
||||
from typing import Any, List, Set
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from tools.lite_interpreter.gen_selected_mobile_ops_header import (
|
||||
@ -17,11 +20,11 @@ from torchgen.selective_build.selector import (
|
||||
)
|
||||
|
||||
|
||||
def extract_all_operators(selective_builder: SelectiveBuilder) -> Set[str]:
|
||||
def extract_all_operators(selective_builder: SelectiveBuilder) -> set[str]:
|
||||
return set(selective_builder.operators.keys())
|
||||
|
||||
|
||||
def extract_training_operators(selective_builder: SelectiveBuilder) -> Set[str]:
|
||||
def extract_training_operators(selective_builder: SelectiveBuilder) -> set[str]:
|
||||
ops = []
|
||||
for op_name, op in selective_builder.operators.items():
|
||||
if op.is_used_for_training:
|
||||
@ -44,7 +47,7 @@ def throw_if_any_op_includes_overloads(selective_builder: SelectiveBuilder) -> N
|
||||
)
|
||||
|
||||
|
||||
def gen_supported_mobile_models(model_dicts: List[Any], output_dir: str) -> None:
|
||||
def gen_supported_mobile_models(model_dicts: list[Any], output_dir: str) -> None:
|
||||
supported_mobile_models_source = """/*
|
||||
* Generated by gen_oplist.py
|
||||
*/
|
||||
@ -87,7 +90,7 @@ SupportedMobileModelCheckerRegistry register_model_versions;
|
||||
out_file.write(source.encode("utf-8"))
|
||||
|
||||
|
||||
def main(argv: List[Any]) -> None:
|
||||
def main(argv: list[Any]) -> None:
|
||||
"""This binary generates 3 files:
|
||||
|
||||
1. selected_mobile_ops.h: Primary operators used by templated selective build and Kernel Function
|
||||
|
@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from typing import cast, List, Optional, Tuple
|
||||
from typing import cast
|
||||
|
||||
from ..util.setting import (
|
||||
CompilerType,
|
||||
@ -38,7 +40,7 @@ BLOCKED_PYTHON_TESTS = {
|
||||
}
|
||||
|
||||
|
||||
def initialization() -> Tuple[Option, TestList, List[str]]:
|
||||
def initialization() -> tuple[Option, TestList, list[str]]:
|
||||
# create folder if not exists
|
||||
create_folders()
|
||||
# add arguments
|
||||
@ -77,7 +79,7 @@ def add_arguments_oss(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
|
||||
|
||||
def parse_arguments(
|
||||
parser: argparse.ArgumentParser,
|
||||
) -> Tuple[Option, Optional[List[str]], Optional[List[str]], Optional[bool]]:
|
||||
) -> tuple[Option, list[str] | None, list[str] | None, bool | None]:
|
||||
# parse args
|
||||
args = parser.parse_args()
|
||||
# get option
|
||||
@ -85,9 +87,7 @@ def parse_arguments(
|
||||
return (options, args.interest_only, args.run_only, args.clean)
|
||||
|
||||
|
||||
def get_test_list_by_type(
|
||||
run_only: Optional[List[str]], test_type: TestType
|
||||
) -> TestList:
|
||||
def get_test_list_by_type(run_only: list[str] | None, test_type: TestType) -> TestList:
|
||||
test_list: TestList = []
|
||||
binary_folder = get_oss_binary_folder(test_type)
|
||||
g = os.walk(binary_folder)
|
||||
@ -106,7 +106,7 @@ def get_test_list_by_type(
|
||||
return test_list
|
||||
|
||||
|
||||
def get_test_list(run_only: Optional[List[str]]) -> TestList:
|
||||
def get_test_list(run_only: list[str] | None) -> TestList:
|
||||
test_list: TestList = []
|
||||
# add c++ test list
|
||||
test_list.extend(get_test_list_by_type(run_only, TestType.CPP))
|
||||
@ -122,7 +122,7 @@ def get_test_list(run_only: Optional[List[str]]) -> TestList:
|
||||
return test_list
|
||||
|
||||
|
||||
def empty_list_if_none(arg_interested_folder: Optional[List[str]]) -> List[str]:
|
||||
def empty_list_if_none(arg_interested_folder: list[str] | None) -> list[str]:
|
||||
if arg_interested_folder is None:
|
||||
return []
|
||||
# if this argument is specified, just return itself
|
||||
@ -134,7 +134,7 @@ def gcc_export_init() -> None:
|
||||
create_folder(JSON_FOLDER_BASE_DIR)
|
||||
|
||||
|
||||
def get_python_run_only(args_run_only: Optional[List[str]]) -> List[str]:
|
||||
def get_python_run_only(args_run_only: list[str] | None) -> list[str]:
|
||||
# if user specifies run-only option
|
||||
if args_run_only:
|
||||
return args_run_only
|
||||
@ -144,7 +144,7 @@ def get_python_run_only(args_run_only: Optional[List[str]]) -> List[str]:
|
||||
return ["run_test.py"]
|
||||
else:
|
||||
# for clang, some tests will result in too large intermediate files that can't be merged by llvm, we need to skip them
|
||||
run_only: List[str] = []
|
||||
run_only: list[str] = []
|
||||
binary_folder = get_oss_binary_folder(TestType.PY)
|
||||
g = os.walk(binary_folder)
|
||||
for _, _, file_list in g:
|
||||
|
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from typing import List, Optional
|
||||
|
||||
from ..util.setting import CompilerType, TestType, TOOLS_FOLDER
|
||||
from ..util.utils import print_error, remove_file
|
||||
@ -14,7 +15,7 @@ def get_oss_binary_folder(test_type: TestType) -> str:
|
||||
)
|
||||
|
||||
|
||||
def get_oss_shared_library() -> List[str]:
|
||||
def get_oss_shared_library() -> list[str]:
|
||||
lib_dir = os.path.join(get_pytorch_folder(), "build", "lib")
|
||||
return [
|
||||
os.path.join(lib_dir, lib)
|
||||
@ -48,7 +49,7 @@ def get_pytorch_folder() -> str:
|
||||
)
|
||||
|
||||
|
||||
def detect_compiler_type() -> Optional[CompilerType]:
|
||||
def detect_compiler_type() -> CompilerType | None:
|
||||
# check if user specifies the compiler type
|
||||
user_specify = os.environ.get("CXX", None)
|
||||
if user_specify:
|
||||
@ -76,7 +77,7 @@ def clean_up_gcda() -> None:
|
||||
remove_file(item)
|
||||
|
||||
|
||||
def get_gcda_files() -> List[str]:
|
||||
def get_gcda_files() -> list[str]:
|
||||
folder_has_gcda = os.path.join(get_pytorch_folder(), "build")
|
||||
if os.path.isdir(folder_has_gcda):
|
||||
# TODO use glob
|
||||
|
@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from ..util.setting import (
|
||||
JSON_FOLDER_BASE_DIR,
|
||||
@ -25,7 +26,7 @@ from .utils import get_tool_path_by_platform, run_cpp_test
|
||||
|
||||
|
||||
def create_corresponding_folder(
|
||||
cur_path: str, prefix_cur_path: str, dir_list: List[str], new_base_folder: str
|
||||
cur_path: str, prefix_cur_path: str, dir_list: list[str], new_base_folder: str
|
||||
) -> None:
|
||||
for dir_name in dir_list:
|
||||
relative_path = convert_to_relative_path(
|
||||
@ -70,7 +71,7 @@ def export_target(
|
||||
merged_file: str,
|
||||
json_file: str,
|
||||
binary_file: str,
|
||||
shared_library_list: List[str],
|
||||
shared_library_list: list[str],
|
||||
platform_type: TestPlatform,
|
||||
) -> None:
|
||||
if binary_file is None:
|
||||
|
@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
# gcc is only used in oss
|
||||
from ..oss.utils import get_gcda_files, run_oss_python_test
|
||||
@ -10,7 +11,7 @@ from ..util.utils import print_log, print_time
|
||||
from .utils import run_cpp_test
|
||||
|
||||
|
||||
def update_gzip_dict(gzip_dict: Dict[str, int], file_name: str) -> str:
|
||||
def update_gzip_dict(gzip_dict: dict[str, int], file_name: str) -> str:
|
||||
file_name = file_name.lower()
|
||||
gzip_dict[file_name] = gzip_dict.get(file_name, 0) + 1
|
||||
num = gzip_dict[file_name]
|
||||
@ -34,7 +35,7 @@ def export() -> None:
|
||||
# collect .gcda files
|
||||
gcda_files = get_gcda_files()
|
||||
# file name like utils.cpp may have same name in different folder
|
||||
gzip_dict: Dict[str, int] = {}
|
||||
gzip_dict: dict[str, int] = {}
|
||||
for gcda_item in gcda_files:
|
||||
# generate json.gz
|
||||
subprocess.check_call(["gcov", "-i", gcda_item])
|
||||
|
@ -1,12 +1,14 @@
import typing as t
from __future__ import annotations

from typing import Any, NamedTuple


class CoverageRecord(t.NamedTuple):
class CoverageRecord(NamedTuple):
filepath: str
covered_lines: t.List[int]
uncovered_lines: t.Optional[t.List[int]] = None
covered_lines: list[int]
uncovered_lines: list[int] | None = None

def to_dict(self) -> t.Dict[str, t.Any]:
def to_dict(self) -> dict[str, Any]:
return {
"filepath": self.filepath,
"covered_lines": self.covered_lines,
|
||||
|
@ -1,4 +1,6 @@
|
||||
from typing import Any, Dict, List, Set
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .coverage_record import CoverageRecord
|
||||
|
||||
@ -10,7 +12,7 @@ class GcovCoverageParser:
|
||||
of CoverageRecord(s).
|
||||
"""
|
||||
|
||||
def __init__(self, llvm_coverage: Dict[str, Any]) -> None:
|
||||
def __init__(self, llvm_coverage: dict[str, Any]) -> None:
|
||||
self._llvm_coverage = llvm_coverage
|
||||
|
||||
@staticmethod
|
||||
@ -24,17 +26,17 @@ class GcovCoverageParser:
|
||||
return True
|
||||
return False
|
||||
|
||||
def parse(self) -> List[CoverageRecord]:
|
||||
def parse(self) -> list[CoverageRecord]:
|
||||
# The JSON format is described in the gcov source code
|
||||
# https://gcc.gnu.org/onlinedocs/gcc/Invoking-Gcov.html
|
||||
records: List[CoverageRecord] = []
|
||||
records: list[CoverageRecord] = []
|
||||
for file_info in self._llvm_coverage["files"]:
|
||||
filepath = file_info["file"]
|
||||
if self._skip_coverage(filepath):
|
||||
continue
|
||||
# parse json file
|
||||
covered_lines: Set[int] = set()
|
||||
uncovered_lines: Set[int] = set()
|
||||
covered_lines: set[int] = set()
|
||||
uncovered_lines: set[int] = set()
|
||||
for line in file_info["lines"]:
|
||||
line_number = line["line_number"]
|
||||
count = line["count"]
|
||||
|
@ -1,4 +1,6 @@
|
||||
from typing import Any, Dict, List, Set, Tuple
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .coverage_record import CoverageRecord
|
||||
from .llvm_coverage_segment import LlvmCoverageSegment, parse_segments
|
||||
@ -12,7 +14,7 @@ class LlvmCoverageParser:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, llvm_coverage: Dict[str, Any]) -> None:
|
||||
def __init__(self, llvm_coverage: dict[str, Any]) -> None:
|
||||
self._llvm_coverage = llvm_coverage
|
||||
|
||||
@staticmethod
|
||||
@ -28,13 +30,13 @@ class LlvmCoverageParser:
|
||||
|
||||
@staticmethod
|
||||
def _collect_coverage(
|
||||
segments: List[LlvmCoverageSegment],
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
segments: list[LlvmCoverageSegment],
|
||||
) -> tuple[list[int], list[int]]:
|
||||
"""
|
||||
Stateful parsing of coverage segments.
|
||||
"""
|
||||
covered_lines: Set[int] = set()
|
||||
uncovered_lines: Set[int] = set()
|
||||
covered_lines: set[int] = set()
|
||||
uncovered_lines: set[int] = set()
|
||||
prev_segment = LlvmCoverageSegment(1, 0, 0, 0, 0, None)
|
||||
for segment in segments:
|
||||
covered_range, uncovered_range = segment.get_coverage(prev_segment)
|
||||
@ -45,10 +47,10 @@ class LlvmCoverageParser:
|
||||
uncovered_lines.difference_update(covered_lines)
|
||||
return sorted(covered_lines), sorted(uncovered_lines)
|
||||
|
||||
def parse(self, repo_name: str) -> List[CoverageRecord]:
|
||||
def parse(self, repo_name: str) -> list[CoverageRecord]:
|
||||
# The JSON format is described in the LLVM source code
|
||||
# https://github.com/llvm-mirror/llvm/blob/master/tools/llvm-cov/CoverageExporterJson.cpp
|
||||
records: List[CoverageRecord] = []
|
||||
records: list[CoverageRecord] = []
|
||||
for export_unit in self._llvm_coverage["data"]:
|
||||
for file_info in export_unit["files"]:
|
||||
filepath = file_info["filename"]
|
||||
|
@ -1,4 +1,6 @@
|
||||
from typing import List, NamedTuple, Optional, Tuple
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
class LlvmCoverageSegment(NamedTuple):
|
||||
@ -7,7 +9,7 @@ class LlvmCoverageSegment(NamedTuple):
|
||||
segment_count: int
|
||||
has_count: int
|
||||
is_region_entry: int
|
||||
is_gap_entry: Optional[int]
|
||||
is_gap_entry: int | None
|
||||
|
||||
@property
|
||||
def has_coverage(self) -> bool:
|
||||
@ -18,8 +20,8 @@ class LlvmCoverageSegment(NamedTuple):
|
||||
return self.has_count > 0
|
||||
|
||||
def get_coverage(
|
||||
self, prev_segment: "LlvmCoverageSegment"
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
self, prev_segment: LlvmCoverageSegment
|
||||
) -> tuple[list[int], list[int]]:
|
||||
# Code adapted from testpilot.testinfra.runners.gtestcoveragerunner.py
|
||||
if not prev_segment.is_executable:
|
||||
return [], []
|
||||
@ -32,12 +34,12 @@ class LlvmCoverageSegment(NamedTuple):
|
||||
return (lines_range, []) if prev_segment.has_coverage else ([], lines_range)
|
||||
|
||||
|
||||
def parse_segments(raw_segments: List[List[int]]) -> List[LlvmCoverageSegment]:
|
||||
def parse_segments(raw_segments: list[list[int]]) -> list[LlvmCoverageSegment]:
|
||||
"""
|
||||
Creates LlvmCoverageSegment from a list of lists in llvm export json.
|
||||
each segment is represented by 5-element array.
|
||||
"""
|
||||
ret: List[LlvmCoverageSegment] = []
|
||||
ret: list[LlvmCoverageSegment] = []
|
||||
for raw_segment in raw_segments:
|
||||
assert (
|
||||
len(raw_segment) == 5 or len(raw_segment) == 6
|
||||
|
@ -1,10 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from typing import Dict, IO, List, Set, Tuple
|
||||
from typing import IO, Tuple
|
||||
|
||||
from ..oss.utils import get_pytorch_folder
|
||||
from ..util.setting import SUMMARY_FOLDER_DIR, TestList, TestStatusType
|
||||
|
||||
|
||||
CoverageItem = Tuple[str, float, int, int]
|
||||
|
||||
|
||||
@ -16,7 +19,7 @@ def key_by_name(x: CoverageItem) -> str:
|
||||
return x[0]
|
||||
|
||||
|
||||
def is_intrested_file(file_path: str, interested_folders: List[str]) -> bool:
|
||||
def is_intrested_file(file_path: str, interested_folders: list[str]) -> bool:
|
||||
if "cuda" in file_path:
|
||||
return False
|
||||
if "aten/gen_aten" in file_path or "aten/aten_" in file_path:
|
||||
@ -27,7 +30,7 @@ def is_intrested_file(file_path: str, interested_folders: List[str]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def is_this_type_of_tests(target_name: str, test_set_by_type: Set[str]) -> bool:
|
||||
def is_this_type_of_tests(target_name: str, test_set_by_type: set[str]) -> bool:
|
||||
# tests are divided into three types: success / partial success / fail to collect coverage
|
||||
for test in test_set_by_type:
|
||||
if target_name in test:
|
||||
@ -36,7 +39,7 @@ def is_this_type_of_tests(target_name: str, test_set_by_type: Set[str]) -> bool:
|
||||
|
||||
|
||||
def print_test_by_type(
|
||||
tests: TestList, test_set_by_type: Set[str], type_name: str, summary_file: IO[str]
|
||||
tests: TestList, test_set_by_type: set[str], type_name: str, summary_file: IO[str]
|
||||
) -> None:
|
||||
print("Tests " + type_name + " to collect coverage:", file=summary_file)
|
||||
for test in tests:
|
||||
@ -48,8 +51,8 @@ def print_test_by_type(
|
||||
def print_test_condition(
|
||||
tests: TestList,
|
||||
tests_type: TestStatusType,
|
||||
interested_folders: List[str],
|
||||
coverage_only: List[str],
|
||||
interested_folders: list[str],
|
||||
coverage_only: list[str],
|
||||
summary_file: IO[str],
|
||||
summary_type: str,
|
||||
) -> None:
|
||||
@ -77,10 +80,10 @@ def print_test_condition(
|
||||
def line_oriented_report(
|
||||
tests: TestList,
|
||||
tests_type: TestStatusType,
|
||||
interested_folders: List[str],
|
||||
coverage_only: List[str],
|
||||
covered_lines: Dict[str, Set[int]],
|
||||
uncovered_lines: Dict[str, Set[int]],
|
||||
interested_folders: list[str],
|
||||
coverage_only: list[str],
|
||||
covered_lines: dict[str, set[int]],
|
||||
uncovered_lines: dict[str, set[int]],
|
||||
) -> None:
|
||||
with open(os.path.join(SUMMARY_FOLDER_DIR, "line_summary"), "w+") as report_file:
|
||||
print_test_condition(
|
||||
@ -119,13 +122,13 @@ def print_file_summary(
|
||||
|
||||
def print_file_oriented_report(
|
||||
tests_type: TestStatusType,
|
||||
coverage: List[CoverageItem],
|
||||
coverage: list[CoverageItem],
|
||||
covered_summary: int,
|
||||
total_summary: int,
|
||||
summary_file: IO[str],
|
||||
tests: TestList,
|
||||
interested_folders: List[str],
|
||||
coverage_only: List[str],
|
||||
interested_folders: list[str],
|
||||
coverage_only: list[str],
|
||||
) -> None:
|
||||
coverage_percentage = print_file_summary(
|
||||
covered_summary, total_summary, summary_file
|
||||
@ -155,10 +158,10 @@ def print_file_oriented_report(
|
||||
def file_oriented_report(
|
||||
tests: TestList,
|
||||
tests_type: TestStatusType,
|
||||
interested_folders: List[str],
|
||||
coverage_only: List[str],
|
||||
covered_lines: Dict[str, Set[int]],
|
||||
uncovered_lines: Dict[str, Set[int]],
|
||||
interested_folders: list[str],
|
||||
coverage_only: list[str],
|
||||
covered_lines: dict[str, set[int]],
|
||||
uncovered_lines: dict[str, set[int]],
|
||||
) -> None:
|
||||
with open(os.path.join(SUMMARY_FOLDER_DIR, "file_summary"), "w+") as summary_file:
|
||||
covered_summary = 0
|
||||
@ -193,7 +196,7 @@ def file_oriented_report(
|
||||
)
|
||||
|
||||
|
||||
def get_html_ignored_pattern() -> List[str]:
|
||||
def get_html_ignored_pattern() -> list[str]:
|
||||
return ["/usr/*", "*anaconda3/*", "*third_party/*"]
|
||||
|
||||
|
||||
|
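
The report helpers above all thread `covered_lines: dict[str, set[int]]` and `uncovered_lines: dict[str, set[int]]` through; a tiny sketch of the kind of per-file summary such maps support (the path and line numbers are invented):

from __future__ import annotations

covered_lines: dict[str, set[int]] = {"torch/utils.py": {1, 2, 3, 5}}
uncovered_lines: dict[str, set[int]] = {"torch/utils.py": {4, 6}}

for path, covered in covered_lines.items():
    total = len(covered) + len(uncovered_lines.get(path, set()))
    pct = 100.0 * len(covered) / total if total else 0.0
    print(f"{path}: {pct:.1f}% line coverage")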
@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, List, Set, Tuple
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from ..util.setting import (
|
||||
CompilerType,
|
||||
@ -16,7 +18,6 @@ from ..util.utils import (
|
||||
print_time,
|
||||
related_to_test_list,
|
||||
)
|
||||
from .parser.coverage_record import CoverageRecord
|
||||
from .parser.gcov_coverage_parser import GcovCoverageParser
|
||||
from .parser.llvm_coverage_parser import LlvmCoverageParser
|
||||
from .print_report import (
|
||||
@ -26,16 +27,20 @@ from .print_report import (
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .parser.coverage_record import CoverageRecord
|
||||
|
||||
|
||||
# coverage_records: Dict[str, LineInfo] = {}
|
||||
covered_lines: Dict[str, Set[int]] = {}
|
||||
uncovered_lines: Dict[str, Set[int]] = {}
|
||||
covered_lines: dict[str, set[int]] = {}
|
||||
uncovered_lines: dict[str, set[int]] = {}
|
||||
tests_type: TestStatusType = {"success": set(), "partial": set(), "fail": set()}
|
||||
|
||||
|
||||
def transform_file_name(
|
||||
file_path: str, interested_folders: List[str], platform: TestPlatform
|
||||
file_path: str, interested_folders: list[str], platform: TestPlatform
|
||||
) -> str:
|
||||
remove_patterns: Set[str] = {".DEFAULT.cpp", ".AVX.cpp", ".AVX2.cpp"}
|
||||
remove_patterns: set[str] = {".DEFAULT.cpp", ".AVX.cpp", ".AVX2.cpp"}
|
||||
for pattern in remove_patterns:
|
||||
file_path = file_path.replace(pattern, "")
|
||||
# if user has specified interested folder
|
||||
@ -54,7 +59,7 @@ def transform_file_name(
|
||||
|
||||
|
||||
def is_intrested_file(
|
||||
file_path: str, interested_folders: List[str], platform: TestPlatform
|
||||
file_path: str, interested_folders: list[str], platform: TestPlatform
|
||||
) -> bool:
|
||||
ignored_patterns = ["cuda", "aten/gen_aten", "aten/aten_", "build/"]
|
||||
if any(pattern in file_path for pattern in ignored_patterns):
|
||||
@ -77,7 +82,7 @@ def is_intrested_file(
|
||||
return True
|
||||
|
||||
|
||||
def get_json_obj(json_file: str) -> Tuple[Any, int]:
|
||||
def get_json_obj(json_file: str) -> tuple[Any, int]:
|
||||
"""
|
||||
Sometimes at the start of file llvm/gcov will complains "fail to find coverage data",
|
||||
then we need to skip these lines
|
||||
@ -102,7 +107,7 @@ def get_json_obj(json_file: str) -> Tuple[Any, int]:
|
||||
return None, 2
|
||||
|
||||
|
||||
def parse_json(json_file: str, platform: TestPlatform) -> List[CoverageRecord]:
|
||||
def parse_json(json_file: str, platform: TestPlatform) -> list[CoverageRecord]:
|
||||
print("start parse:", json_file)
|
||||
json_obj, read_status = get_json_obj(json_file)
|
||||
if read_status == 0:
|
||||
@ -117,7 +122,7 @@ def parse_json(json_file: str, platform: TestPlatform) -> List[CoverageRecord]:
|
||||
|
||||
cov_type = detect_compiler_type(platform)
|
||||
|
||||
coverage_records: List[CoverageRecord] = []
|
||||
coverage_records: list[CoverageRecord] = []
|
||||
if cov_type == CompilerType.CLANG:
|
||||
coverage_records = LlvmCoverageParser(json_obj).parse("fbcode")
|
||||
# print(coverage_records)
|
||||
@ -128,7 +133,7 @@ def parse_json(json_file: str, platform: TestPlatform) -> List[CoverageRecord]:
|
||||
|
||||
|
||||
def parse_jsons(
|
||||
test_list: TestList, interested_folders: List[str], platform: TestPlatform
|
||||
test_list: TestList, interested_folders: list[str], platform: TestPlatform
|
||||
) -> None:
|
||||
g = os.walk(JSON_FOLDER_BASE_DIR)
|
||||
|
||||
@ -152,8 +157,8 @@ def parse_jsons(
|
||||
|
||||
|
||||
def update_coverage(
|
||||
coverage_records: List[CoverageRecord],
|
||||
interested_folders: List[str],
|
||||
coverage_records: list[CoverageRecord],
|
||||
interested_folders: list[str],
|
||||
platform: TestPlatform,
|
||||
) -> None:
|
||||
for item in coverage_records:
|
||||
@ -187,8 +192,8 @@ def update_set() -> None:
|
||||
|
||||
def summarize_jsons(
|
||||
test_list: TestList,
|
||||
interested_folders: List[str],
|
||||
coverage_only: List[str],
|
||||
interested_folders: list[str],
|
||||
coverage_only: list[str],
|
||||
platform: TestPlatform,
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
|
@ -1,3 +1,5 @@
from __future__ import annotations

import os
from enum import Enum
from typing import Dict, List, Set
|
||||
|
@ -1,8 +1,10 @@
from __future__ import annotations

import os
import shutil
import sys
import time
from typing import Any, NoReturn, Optional
from typing import Any, NoReturn

from .setting import (
CompilerType,
@ -113,7 +115,7 @@ def get_test_name_from_whole_path(path: str) -> str:
return path[start + 1 : end]


def check_compiler_type(cov_type: Optional[CompilerType]) -> None:
def check_compiler_type(cov_type: CompilerType | None) -> None:
if cov_type is not None and cov_type in [CompilerType.GCC, CompilerType.CLANG]:
return
raise Exception( # noqa: TRY002
|
||||
|
@ -1,5 +1,6 @@
import setuptools # type: ignore[import]


with open("README.md", encoding="utf-8") as fh:
long_description = fh.read()
|
||||
|
||||
|
@ -22,6 +22,7 @@ from typing import Any

from coverage import CoverageData, CoveragePlugin # type: ignore[import]


# All coverage stats resulting from this plug-in will be in a separate .coverage file that should be merged later with
# `coverage combine`. The convention seems to be .coverage.dotted.suffix based on the following link:
# https://coverage.readthedocs.io/en/coverage-5.5/cmd.html#combining-data-files-coverage-combine
|
||||
|
@ -5,6 +5,7 @@ import sys
from urllib.error import URLError
from urllib.request import urlretrieve


MIRRORS = [
"http://yann.lecun.com/exdb/mnist/",
"https://ossci-datasets.s3.amazonaws.com/mnist/",
|
||||
|
@ -5,6 +5,7 @@ import sys
import traceback
import warnings


MIN_CUDA_VERSION = "11.6"
MIN_ROCM_VERSION = "5.4"
MIN_PYTHON_VERSION = (3, 8)
@ -141,7 +142,7 @@ def check_rocm():
return rocm_ver if torch.version.hip else "None"


def check_dynamo(backend, device, err_msg):
def check_dynamo(backend, device, err_msg) -> None:
import torch

if device == "cuda" and not torch.cuda.is_available():
@ -203,7 +204,7 @@ _SANITY_CHECK_ARGS = (
)


def main():
def main() -> None:
python_ver = check_python()
torch_ver = check_torch()
cuda_ver = check_cuda()
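
Beyond the typing swaps, several hunks above add explicit `-> None` return annotations to functions such as `check_dynamo` and `main`. A one-function sketch of why that is useful (the function is illustrative):

def report(msg: str) -> None:
    # With an explicit `-> None`, mypy flags callers that try to use the
    # (nonexistent) return value, e.g. `status = report("ok")`.
    print(msg)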
|
||||
|
@ -1,14 +1,17 @@
#!/usr/bin/env python3

from __future__ import annotations

import argparse
import re
import sys
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any, Dict
from typing_extensions import TypedDict # Python 3.11+

import yaml


Step = Dict[str, Any]


@ -17,7 +20,7 @@ class Script(TypedDict):
script: str


def extract(step: Step) -> Optional[Script]:
def extract(step: Step) -> Script | None:
run = step.get("run")

# https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#using-a-specific-shell
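
Here `Step = Dict[str, Any]` stays on the typing generic for the same runtime-alias reason as `DepGraph` earlier, while `extract` can return `Script | None`. A reduced sketch of that shape (only the `script` key is shown; the real extractor does more):

from __future__ import annotations

from typing import Any, Dict
from typing_extensions import TypedDict

Step = Dict[str, Any]  # runtime alias, so typing.Dict is still required on 3.8


class Script(TypedDict):
    script: str


def extract(step: Step) -> Script | None:
    # Sketch: only steps with an inline `run` block yield a script.
    run = step.get("run")
    if run is None:
        return None
    return {"script": run}


print(extract({"run": "echo hello"}))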
|
||||
|
@ -1,5 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import array
|
||||
import codecs
|
||||
@ -15,7 +17,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
import subprocess
|
||||
import textwrap
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from yaml.constructor import ConstructorError
|
||||
@ -29,14 +31,14 @@ except ImportError:
|
||||
CPP_H_NAME = "spv.h"
|
||||
CPP_SRC_NAME = "spv.cpp"
|
||||
|
||||
DEFAULT_ENV: Dict[str, Any] = {
|
||||
DEFAULT_ENV: dict[str, Any] = {
|
||||
"PRECISION": "highp",
|
||||
"FLOAT_IMAGE_FORMAT": "rgba16f",
|
||||
"INT_IMAGE_FORMAT": "rgba32i",
|
||||
"UINT_IMAGE_FORMAT": "rgba32ui",
|
||||
}
|
||||
|
||||
TYPES_ENV: Dict[str, Any] = {
|
||||
TYPES_ENV: dict[str, Any] = {
|
||||
"IMAGE_FORMAT": {
|
||||
"float": "rgba32f",
|
||||
"half": "rgba16f",
|
||||
@ -91,7 +93,7 @@ TYPES_ENV: Dict[str, Any] = {
|
||||
},
|
||||
}
|
||||
|
||||
FUNCS_ENV: Dict[str, Any] = {
|
||||
FUNCS_ENV: dict[str, Any] = {
|
||||
"GET_POS": {
|
||||
3: lambda pos: pos,
|
||||
2: lambda pos: f"{pos}.xy",
|
||||
@ -169,7 +171,7 @@ def escape(line: str) -> str:
|
||||
|
||||
# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
|
||||
def preprocess(
|
||||
input_text: str, variables: Dict[str, Any], input_path: str = "codegen"
|
||||
input_text: str, variables: dict[str, Any], input_path: str = "codegen"
|
||||
) -> str:
|
||||
input_lines = input_text.splitlines()
|
||||
python_lines = []
|
||||
@ -243,9 +245,9 @@ def preprocess(
|
||||
class SPVGenerator:
|
||||
def __init__(
|
||||
self,
|
||||
src_dir_paths: Union[str, List[str]],
|
||||
env: Dict[Any, Any],
|
||||
glslc_path: Optional[str],
|
||||
src_dir_paths: str | list[str],
|
||||
env: dict[Any, Any],
|
||||
glslc_path: str | None,
|
||||
) -> None:
|
||||
if isinstance(src_dir_paths, str):
|
||||
self.src_dir_paths = [src_dir_paths]
|
||||
@ -255,18 +257,18 @@ class SPVGenerator:
|
||||
self.env = env
|
||||
self.glslc_path = glslc_path
|
||||
|
||||
self.glsl_src_files: Dict[str, str] = {}
|
||||
self.template_yaml_files: List[str] = []
|
||||
self.glsl_src_files: dict[str, str] = {}
|
||||
self.template_yaml_files: list[str] = []
|
||||
|
||||
self.addSrcAndYamlFiles(self.src_dir_paths)
|
||||
self.shader_template_params: Dict[Any, Any] = {}
|
||||
self.shader_template_params: dict[Any, Any] = {}
|
||||
for yaml_file in self.template_yaml_files:
|
||||
self.parseTemplateYaml(yaml_file)
|
||||
|
||||
self.output_shader_map: Dict[str, Tuple[str, Dict[str, str]]] = {}
|
||||
self.output_shader_map: dict[str, tuple[str, dict[str, str]]] = {}
|
||||
self.constructOutputMap()
|
||||
|
||||
def addSrcAndYamlFiles(self, src_dir_paths: List[str]) -> None:
|
||||
def addSrcAndYamlFiles(self, src_dir_paths: list[str]) -> None:
|
||||
for src_path in src_dir_paths:
|
||||
# Collect glsl source files
|
||||
glsl_files = glob.glob(
|
||||
@ -285,9 +287,9 @@ class SPVGenerator:
|
||||
|
||||
def generateVariantCombinations(
|
||||
self,
|
||||
iterated_params: Dict[str, Any],
|
||||
exclude_params: Optional[Set[str]] = None,
|
||||
) -> List[Any]:
|
||||
iterated_params: dict[str, Any],
|
||||
exclude_params: set[str] | None = None,
|
||||
) -> list[Any]:
|
||||
if exclude_params is None:
|
||||
exclude_params = set()
|
||||
all_iterated_params = []
|
||||
@ -362,8 +364,8 @@ class SPVGenerator:
|
||||
)
|
||||
|
||||
def create_shader_params(
|
||||
self, variant_params: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, str]:
|
||||
self, variant_params: dict[str, Any] | None = None
|
||||
) -> dict[str, str]:
|
||||
if variant_params is None:
|
||||
variant_params = {}
|
||||
shader_params = copy.deepcopy(self.env)
|
||||
@ -409,7 +411,7 @@ class SPVGenerator:
|
||||
self.create_shader_params(),
|
||||
)
|
||||
|
||||
def generateSPV(self, output_dir: str) -> Dict[str, str]:
|
||||
def generateSPV(self, output_dir: str) -> dict[str, str]:
|
||||
output_file_map = {}
|
||||
for shader_name in self.output_shader_map:
|
||||
source_glsl = self.output_shader_map[shader_name][0]
|
||||
@ -457,11 +459,11 @@ class SPVGenerator:
|
||||
|
||||
@dataclass
|
||||
class ShaderInfo:
|
||||
tile_size: List[int]
|
||||
layouts: List[str]
|
||||
tile_size: list[int]
|
||||
layouts: list[str]
|
||||
weight_storage_type: str = ""
|
||||
bias_storage_type: str = ""
|
||||
register_for: Optional[Tuple[str, List[str]]] = None
|
||||
register_for: tuple[str, list[str]] | None = None
|
||||
|
||||
|
||||
def getName(filePath: str) -> str:
|
||||
@ -478,7 +480,7 @@ def isTileSizeLine(lineStr: str) -> bool:
|
||||
return re.search(tile_size_id, lineStr) is not None
|
||||
|
||||
|
||||
def findTileSizes(lineStr: str) -> List[int]:
|
||||
def findTileSizes(lineStr: str) -> list[int]:
|
||||
tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)"
|
||||
matches = re.search(tile_size_id, lineStr)
|
||||
if matches is None:
|
||||
@ -520,7 +522,7 @@ def isRegisterForLine(lineStr: str) -> bool:
|
||||
return re.search(register_for_id, lineStr) is not None
|
||||
|
||||
|
||||
def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
|
||||
def findRegisterFor(lineStr: str) -> tuple[str, list[str]]:
|
||||
register_for_pattern = r"'([A-Za-z0-9_]+)'"
|
||||
matches = re.findall(register_for_pattern, lineStr)
|
||||
if matches is None:
|
||||
@ -609,7 +611,7 @@ static const api::ShaderRegisterInit register_shaders(®ister_fn);
|
||||
"""
|
||||
|
||||
|
||||
def generateSpvBinStr(spvPath: str, name: str) -> Tuple[int, str]:
|
||||
def generateSpvBinStr(spvPath: str, name: str) -> tuple[int, str]:
|
||||
with open(spvPath, "rb") as fr:
|
||||
next_bin = array.array("I", fr.read())
|
||||
sizeBytes = 4 * len(next_bin)
|
||||
@ -665,7 +667,7 @@ def generateShaderDispatchStr(shader_info: ShaderInfo, name: str) -> str:
|
||||
|
||||
|
||||
def genCppFiles(
|
||||
spv_files: Dict[str, str], cpp_header_path: str, cpp_src_file_path: str
|
||||
spv_files: dict[str, str], cpp_header_path: str, cpp_src_file_path: str
|
||||
) -> None:
|
||||
spv_bin_strs = []
|
||||
register_shader_info_strs = []
|
||||
@ -705,7 +707,7 @@ def genCppFiles(
|
||||
##########
|
||||
|
||||
|
||||
def parse_arg_env(items: Dict[Any, Any]) -> Dict[Any, Any]:
|
||||
def parse_arg_env(items: dict[Any, Any]) -> dict[Any, Any]:
|
||||
d = {}
|
||||
if items:
|
||||
for item in items:
|
||||
@ -716,7 +718,7 @@ def parse_arg_env(items: Dict[Any, Any]) -> Dict[Any, Any]:
|
||||
return d
|
||||
|
||||
|
||||
def main(argv: List[str]) -> int:
|
||||
def main(argv: list[str]) -> int:
|
||||
parser = argparse.ArgumentParser(description="")
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
|
@ -1,9 +1,10 @@
from __future__ import annotations

import argparse
import os
import re
import subprocess
from pathlib import Path
from typing import Optional, Union

from setuptools import distutils # type: ignore[import]

@ -12,7 +13,7 @@ UNKNOWN = "Unknown"
RELEASE_PATTERN = re.compile(r"/v[0-9]+(\.[0-9]+)*(-rc[0-9]+)?/")


def get_sha(pytorch_root: Union[str, Path]) -> str:
def get_sha(pytorch_root: str | Path) -> str:
try:
rev = None
if os.path.exists(os.path.join(pytorch_root, ".git")):
@ -30,7 +31,7 @@ def get_sha(pytorch_root: Union[str, Path]) -> str:
return UNKNOWN


def get_tag(pytorch_root: Union[str, Path]) -> str:
def get_tag(pytorch_root: str | Path) -> str:
try:
tag = subprocess.run(
["git", "describe", "--tags", "--exact"],
@ -46,8 +47,8 @@ def get_tag(pytorch_root: Union[str, Path]) -> str:
return UNKNOWN


def get_torch_version(sha: Optional[str] = None) -> str:
pytorch_root = Path(__file__).parent.parent
def get_torch_version(sha: str | None = None) -> str:
pytorch_root = Path(__file__).absolute().parent.parent
version = open(pytorch_root / "version.txt").read().strip()

if os.getenv("PYTORCH_BUILD_VERSION"):
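
One hunk above changes behavior rather than annotations: `get_torch_version` now calls `Path(__file__).absolute()` before walking up to the repo root. A short illustration of the difference when the script starts from a relative path (the path below is an example, not the real file name):

from pathlib import Path

relative = Path("tools/version_script.py")
print(relative.parent.parent)             # "." -- may not point at the repo root
print(relative.absolute().parent.parent)  # absolute path two levels above the script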
|
||||
|
@ -1,10 +1,10 @@
|
||||
"""GitHub Utilities"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from typing import Any, Callable, cast, Dict, Optional, Tuple
|
||||
|
||||
from typing import Any, Callable, cast, Dict
|
||||
from urllib.error import HTTPError
|
||||
from urllib.parse import quote
|
||||
from urllib.request import Request, urlopen
|
||||
@ -13,11 +13,11 @@ from urllib.request import Request, urlopen
|
||||
def gh_fetch_url_and_headers(
|
||||
url: str,
|
||||
*,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
method: Optional[str] = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
method: str | None = None,
|
||||
reader: Callable[[Any], Any] = lambda x: x.read(),
|
||||
) -> Tuple[Any, Any]:
|
||||
) -> tuple[Any, Any]:
|
||||
if headers is None:
|
||||
headers = {}
|
||||
token = os.environ.get("GITHUB_TOKEN")
|
||||
@ -44,9 +44,9 @@ def gh_fetch_url_and_headers(
|
||||
def gh_fetch_url(
|
||||
url: str,
|
||||
*,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
method: Optional[str] = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
method: str | None = None,
|
||||
reader: Callable[[Any], Any] = lambda x: x.read(),
|
||||
) -> Any:
|
||||
return gh_fetch_url_and_headers(
|
||||
@ -56,8 +56,8 @@ def gh_fetch_url(
|
||||
|
||||
def _gh_fetch_json_any(
|
||||
url: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
headers = {"Accept": "application/vnd.github.v3+json"}
|
||||
if params is not None and len(params) > 0:
|
||||
@ -69,13 +69,13 @@ def _gh_fetch_json_any(
|
||||
|
||||
def gh_fetch_json_dict(
|
||||
url: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
params: dict[str, Any] | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return cast(Dict[str, Any], _gh_fetch_json_any(url, params, data))
|
||||
|
||||
|
||||
def gh_fetch_commit(org: str, repo: str, sha: str) -> Dict[str, Any]:
|
||||
def gh_fetch_commit(org: str, repo: str, sha: str) -> dict[str, Any]:
|
||||
return gh_fetch_json_dict(
|
||||
f"https://api.github.com/repos/{org}/{repo}/commits/{sha}"
|
||||
)
|
||||
|
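
The GitHub utilities above show the other spot where typing generics must survive: `cast(Dict[str, Any], ...)` passes the type as a real runtime value, so unlike an annotation it cannot use `dict[str, Any]` on Python 3.8. A small sketch of that asymmetry (the function is illustrative):

from __future__ import annotations

from typing import Any, cast, Dict


def as_json_dict(raw: object) -> dict[str, Any]:
    # The return annotation may use dict[...] (never evaluated), but the
    # cast() argument is evaluated, so it keeps typing.Dict for Python 3.8.
    return cast(Dict[str, Any], raw)


print(as_json_dict({"retries": 3}))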
@ -1,6 +1,7 @@
|
||||
import re
|
||||
import sys
|
||||
|
||||
|
||||
QUOTE_INCLUDE_RE = re.compile(r'^#include "(.*)"')
|
||||
ANGLE_INCLUDE_RE = re.compile(r"^#include <(.*)>")
|
||||
|
||||
|
@ -1,10 +1,13 @@
|
||||
# Generates RegisterCodegenUnboxedKernels.cpp, UnboxingFunctions.h and UnboxingFunctions.cpp.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Literal, Sequence, Union
|
||||
from pathlib import Path
|
||||
from typing import Literal, Sequence, TYPE_CHECKING
|
||||
|
||||
import yaml
|
||||
|
||||
@ -15,10 +18,13 @@ from torchgen.api.unboxing import convert_arguments
|
||||
from torchgen.context import method_with_native_function
|
||||
from torchgen.gen import cpp_string, get_custom_build_selector, parse_native_yaml
|
||||
from torchgen.model import Argument, NativeFunction, NativeFunctionsGroup, Variant
|
||||
from torchgen.selective_build.selector import SelectiveBuilder
|
||||
from torchgen.utils import FileManager, make_file_manager, mapMaybe, Target
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torchgen.selective_build.selector import SelectiveBuilder
|
||||
|
||||
|
||||
# Generates UnboxingFunctions.h & UnboxingFunctions.cpp.
|
||||
@dataclass(frozen=True)
|
||||
class ComputeUnboxingFunctions:
|
||||
@ -156,7 +162,7 @@ def gen_unboxing(
|
||||
cpu_fm: FileManager,
|
||||
selector: SelectiveBuilder,
|
||||
) -> None:
|
||||
def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str:
|
||||
def key_func(fn: NativeFunction | NativeFunctionsGroup) -> str:
|
||||
return fn.root_name
|
||||
|
||||
selected_op_num: int = len(selector.operators)
|
||||
@ -195,7 +201,7 @@ def gen_unboxing(
|
||||
)
|
||||
|
||||
|
||||
def main(args: List[str]) -> None:
|
||||
def main(args: list[str]) -> None:
|
||||
parser = argparse.ArgumentParser(description="Generate unboxing source files")
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
@ -272,7 +278,7 @@ def main(args: List[str]) -> None:
|
||||
gen_unboxing(native_functions=native_functions, cpu_fm=cpu_fm, selector=selector)
|
||||
|
||||
if options.output_dependencies:
|
||||
depfile_path = pathlib.Path(options.output_dependencies).resolve()
|
||||
depfile_path = Path(options.output_dependencies).resolve()
|
||||
depfile_name = depfile_path.name
|
||||
depfile_stem = depfile_path.stem
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import json
|
||||
@ -8,7 +10,7 @@ import subprocess
|
||||
import sys
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import List, NamedTuple, Optional, Pattern
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
LINTER_CODE = "ACTIONLINT"
|
||||
@ -22,18 +24,18 @@ class LintSeverity(str, Enum):
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: Optional[str]
|
||||
line: Optional[int]
|
||||
char: Optional[int]
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: Optional[str]
|
||||
replacement: Optional[str]
|
||||
description: Optional[str]
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
RESULTS_RE: Pattern[str] = re.compile(
|
||||
RESULTS_RE: re.Pattern[str] = re.compile(
|
||||
r"""(?mx)
|
||||
^
|
||||
(?P<file>.*?):
|
||||
@ -47,8 +49,8 @@ RESULTS_RE: Pattern[str] = re.compile(
|
||||
|
||||
|
||||
def run_command(
|
||||
args: List[str],
|
||||
) -> "subprocess.CompletedProcess[bytes]":
|
||||
args: list[str],
|
||||
) -> subprocess.CompletedProcess[bytes]:
|
||||
logging.debug("$ %s", " ".join(args))
|
||||
start_time = time.monotonic()
|
||||
try:
|
||||
@ -64,7 +66,7 @@ def run_command(
|
||||
def check_file(
|
||||
binary: str,
|
||||
file: str,
|
||||
) -> List[LintMessage]:
|
||||
) -> list[LintMessage]:
|
||||
try:
|
||||
proc = run_command(
|
||||
[
|
||||
|
@ -5,6 +5,9 @@ archive is downloaded from some sites like GitHub because it can change. Specifi
|
||||
GitHub gives no guarantee to keep the same value forever. Check for more details at
|
||||
https://github.com/community/community/discussions/46034.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
@ -13,7 +16,7 @@ import subprocess
|
||||
import sys
|
||||
import xml.etree.ElementTree as ET
|
||||
from enum import Enum
|
||||
from typing import List, NamedTuple, Optional, Set
|
||||
from typing import NamedTuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
@ -30,18 +33,18 @@ class LintSeverity(str, Enum):
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: Optional[str]
|
||||
line: Optional[int]
|
||||
char: Optional[int]
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: Optional[str]
|
||||
replacement: Optional[str]
|
||||
description: Optional[str]
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
def is_required_checksum(urls: List[Optional[str]]) -> bool:
|
||||
def is_required_checksum(urls: list[str | None]) -> bool:
|
||||
if not urls:
|
||||
return False
|
||||
|
||||
@ -58,7 +61,7 @@ def is_required_checksum(urls: List[Optional[str]]) -> bool:
|
||||
|
||||
def get_disallowed_checksums(
|
||||
binary: str,
|
||||
) -> Set[str]:
|
||||
) -> set[str]:
|
||||
"""
|
||||
Return the set of disallowed checksums from all http_archive rules
|
||||
"""
|
||||
@ -96,8 +99,8 @@ def get_disallowed_checksums(
|
||||
|
||||
def check_bazel(
|
||||
filename: str,
|
||||
disallowed_checksums: Set[str],
|
||||
) -> List[LintMessage]:
|
||||
disallowed_checksums: set[str],
|
||||
) -> list[LintMessage]:
|
||||
original = ""
|
||||
replacement = ""
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import json
|
||||
@ -7,7 +9,7 @@ import subprocess
|
||||
import sys
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, BinaryIO, List, NamedTuple, Optional
|
||||
from typing import Any, BinaryIO, NamedTuple
|
||||
|
||||
|
||||
IS_WINDOWS: bool = os.name == "nt"
|
||||
@ -25,15 +27,15 @@ class LintSeverity(str, Enum):
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: Optional[str]
|
||||
line: Optional[int]
|
||||
char: Optional[int]
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: Optional[str]
|
||||
replacement: Optional[str]
|
||||
description: Optional[str]
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
def as_posix(name: str) -> str:
|
||||
@ -41,11 +43,11 @@ def as_posix(name: str) -> str:
|
||||
|
||||
|
||||
def _run_command(
|
||||
args: List[str],
|
||||
args: list[str],
|
||||
*,
|
||||
stdin: BinaryIO,
|
||||
timeout: int,
|
||||
) -> "subprocess.CompletedProcess[bytes]":
|
||||
) -> subprocess.CompletedProcess[bytes]:
|
||||
logging.debug("$ %s", " ".join(args))
|
||||
start_time = time.monotonic()
|
||||
try:
|
||||
@ -63,12 +65,12 @@ def _run_command(
|
||||
|
||||
|
||||
def run_command(
|
||||
args: List[str],
|
||||
args: list[str],
|
||||
*,
|
||||
stdin: BinaryIO,
|
||||
retries: int,
|
||||
timeout: int,
|
||||
) -> "subprocess.CompletedProcess[bytes]":
|
||||
) -> subprocess.CompletedProcess[bytes]:
|
||||
remaining_retries = retries
|
||||
while True:
|
||||
try:
|
||||
@ -90,7 +92,7 @@ def check_file(
|
||||
filename: str,
|
||||
retries: int,
|
||||
timeout: int,
|
||||
) -> List[LintMessage]:
|
||||
) -> list[LintMessage]:
|
||||
try:
|
||||
with open(filename, "rb") as f:
|
||||
original = f.read()
|
||||
|
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import json
|
||||
@ -8,7 +10,7 @@ import sys
|
||||
import time
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, List, NamedTuple, Optional
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
|
||||
IS_WINDOWS: bool = os.name == "nt"
|
||||
@ -26,15 +28,15 @@ class LintSeverity(str, Enum):
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: Optional[str]
|
||||
line: Optional[int]
|
||||
char: Optional[int]
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: Optional[str]
|
||||
replacement: Optional[str]
|
||||
description: Optional[str]
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
def as_posix(name: str) -> str:
|
||||
@ -42,10 +44,10 @@ def as_posix(name: str) -> str:
|
||||
|
||||
|
||||
def _run_command(
|
||||
args: List[str],
|
||||
args: list[str],
|
||||
*,
|
||||
timeout: int,
|
||||
) -> "subprocess.CompletedProcess[bytes]":
|
||||
) -> subprocess.CompletedProcess[bytes]:
|
||||
logging.debug("$ %s", " ".join(args))
|
||||
start_time = time.monotonic()
|
||||
try:
|
||||
@ -62,11 +64,11 @@ def _run_command(
|
||||
|
||||
|
||||
def run_command(
|
||||
args: List[str],
|
||||
args: list[str],
|
||||
*,
|
||||
retries: int,
|
||||
timeout: int,
|
||||
) -> "subprocess.CompletedProcess[bytes]":
|
||||
) -> subprocess.CompletedProcess[bytes]:
|
||||
remaining_retries = retries
|
||||
while True:
|
||||
try:
|
||||
@ -89,7 +91,7 @@ def check_file(
|
||||
binary: str,
|
||||
retries: int,
|
||||
timeout: int,
|
||||
) -> List[LintMessage]:
|
||||
) -> list[LintMessage]:
|
||||
try:
|
||||
with open(filename, "rb") as f:
|
||||
original = f.read()
|
||||
|
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import json
|
||||
@ -11,7 +13,7 @@ import time
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from sysconfig import get_paths as gp
|
||||
from typing import Any, List, NamedTuple, Optional, Pattern
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
|
||||
# PyTorch directory root
|
||||
@ -49,15 +51,15 @@ class LintSeverity(str, Enum):
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: Optional[str]
|
||||
line: Optional[int]
|
||||
char: Optional[int]
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: Optional[str]
|
||||
replacement: Optional[str]
|
||||
description: Optional[str]
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
def as_posix(name: str) -> str:
|
||||
@ -65,7 +67,7 @@ def as_posix(name: str) -> str:
|
||||
|
||||
|
||||
# c10/core/DispatchKey.cpp:281:26: error: 'k' used after it was moved [bugprone-use-after-move]
|
||||
RESULTS_RE: Pattern[str] = re.compile(
|
||||
RESULTS_RE: re.Pattern[str] = re.compile(
|
||||
r"""(?mx)
|
||||
^
|
||||
(?P<file>.*?):
|
||||
@ -80,8 +82,8 @@ RESULTS_RE: Pattern[str] = re.compile(
|
||||
|
||||
|
||||
def run_command(
|
||||
args: List[str],
|
||||
) -> "subprocess.CompletedProcess[bytes]":
|
||||
args: list[str],
|
||||
) -> subprocess.CompletedProcess[bytes]:
|
||||
logging.debug("$ %s", " ".join(args))
|
||||
start_time = time.monotonic()
|
||||
try:
|
||||
@ -103,7 +105,7 @@ severities = {
|
||||
}
|
||||
|
||||
|
||||
def clang_search_dirs() -> List[str]:
|
||||
def clang_search_dirs() -> list[str]:
|
||||
# Compilers are ordered based on fallback preference
|
||||
# We pick the first one that is available on the system
|
||||
compilers = ["clang", "gcc", "cpp", "cc"]
|
||||
@ -152,7 +154,7 @@ def check_file(
|
||||
filename: str,
|
||||
binary: str,
|
||||
build_dir: Path,
|
||||
) -> List[LintMessage]:
|
||||
) -> list[LintMessage]:
|
||||
try:
|
||||
proc = run_command(
|
||||
[binary, f"-p={build_dir}", *include_args, filename],
|
||||
|
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import json
|
||||
@ -7,7 +9,7 @@ import re
|
||||
import subprocess
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import List, NamedTuple, Optional, Pattern
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
LINTER_CODE = "CMAKE"
|
||||
@ -21,19 +23,19 @@ class LintSeverity(str, Enum):
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: Optional[str]
|
||||
line: Optional[int]
|
||||
char: Optional[int]
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: Optional[str]
|
||||
replacement: Optional[str]
|
||||
description: Optional[str]
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
# CMakeLists.txt:901: Lines should be <= 80 characters long [linelength]
|
||||
RESULTS_RE: Pattern[str] = re.compile(
|
||||
RESULTS_RE: re.Pattern[str] = re.compile(
|
||||
r"""(?mx)
|
||||
^
|
||||
(?P<file>.*?):
|
||||
@ -46,8 +48,8 @@ RESULTS_RE: Pattern[str] = re.compile(
|
||||
|
||||
|
||||
def run_command(
|
||||
args: List[str],
|
||||
) -> "subprocess.CompletedProcess[bytes]":
|
||||
args: list[str],
|
||||
) -> subprocess.CompletedProcess[bytes]:
|
||||
logging.debug("$ %s", " ".join(args))
|
||||
start_time = time.monotonic()
|
||||
try:
|
||||
@ -63,7 +65,7 @@ def run_command(
|
||||
def check_file(
|
||||
filename: str,
|
||||
config: str,
|
||||
) -> List[LintMessage]:
|
||||
) -> list[LintMessage]:
|
||||
try:
|
||||
proc = run_command(
|
||||
["cmakelint", f"--config={config}", filename],
|
||||
|
@ -2,13 +2,15 @@
|
||||
CONSTEXPR: Ensures users don't use vanilla constexpr since it causes issues
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from enum import Enum
|
||||
from typing import NamedTuple, Optional
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
CONSTEXPR = "constexpr char"
|
||||
CONSTEXPR_MACRO = "CONSTEXPR_EXCEPT_WIN_CUDA char"
|
||||
@ -21,18 +23,18 @@ class LintSeverity(str, Enum):
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: Optional[str]
|
||||
line: Optional[int]
|
||||
char: Optional[int]
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: Optional[str]
|
||||
replacement: Optional[str]
|
||||
description: Optional[str]
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
def check_file(filename: str) -> Optional[LintMessage]:
|
||||
def check_file(filename: str) -> LintMessage | None:
|
||||
logging.debug("Checking file %s", filename)
|
||||
|
||||
with open(filename) as f:
|
||||
|
@ -1,14 +1,17 @@
|
||||
"""
|
||||
EXEC: Ensure that source files are not executable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from enum import Enum
|
||||
from typing import NamedTuple, Optional
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
LINTER_CODE = "EXEC"
|
||||
|
||||
@ -21,18 +24,18 @@ class LintSeverity(str, Enum):
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: Optional[str]
|
||||
line: Optional[int]
|
||||
char: Optional[int]
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: Optional[str]
|
||||
replacement: Optional[str]
|
||||
description: Optional[str]
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
def check_file(filename: str) -> Optional[LintMessage]:
|
||||
def check_file(filename: str) -> LintMessage | None:
|
||||
is_executable = os.access(filename, os.X_OK)
|
||||
if is_executable:
|
||||
return LintMessage(
|
||||
|
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
@ -7,7 +9,7 @@ import subprocess
|
||||
import sys
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Pattern, Set
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
|
||||
IS_WINDOWS: bool = os.name == "nt"
|
||||
@ -25,15 +27,15 @@ class LintSeverity(str, Enum):
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: Optional[str]
|
||||
line: Optional[int]
|
||||
char: Optional[int]
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: Optional[str]
|
||||
replacement: Optional[str]
|
||||
description: Optional[str]
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
def as_posix(name: str) -> str:
|
||||
@ -42,7 +44,7 @@ def as_posix(name: str) -> str:
|
||||
|
||||
# fmt: off
|
||||
# https://www.flake8rules.com/
|
||||
DOCUMENTED_IN_FLAKE8RULES: Set[str] = {
|
||||
DOCUMENTED_IN_FLAKE8RULES: set[str] = {
|
||||
"E101", "E111", "E112", "E113", "E114", "E115", "E116", "E117",
|
||||
"E121", "E122", "E123", "E124", "E125", "E126", "E127", "E128", "E129",
|
||||
"E131", "E133",
|
||||
@ -78,14 +80,14 @@ DOCUMENTED_IN_FLAKE8RULES: Set[str] = {
|
||||
}
|
||||
|
||||
# https://pypi.org/project/flake8-comprehensions/#rules
|
||||
DOCUMENTED_IN_FLAKE8COMPREHENSIONS: Set[str] = {
|
||||
DOCUMENTED_IN_FLAKE8COMPREHENSIONS: set[str] = {
|
||||
"C400", "C401", "C402", "C403", "C404", "C405", "C406", "C407", "C408", "C409",
|
||||
"C410",
|
||||
"C411", "C412", "C413", "C414", "C415", "C416",
|
||||
}
|
||||
|
||||
# https://github.com/PyCQA/flake8-bugbear#list-of-warnings
|
||||
DOCUMENTED_IN_BUGBEAR: Set[str] = {
|
||||
DOCUMENTED_IN_BUGBEAR: set[str] = {
|
||||
"B001", "B002", "B003", "B004", "B005", "B006", "B007", "B008", "B009", "B010",
|
||||
"B011", "B012", "B013", "B014", "B015",
|
||||
"B301", "B302", "B303", "B304", "B305", "B306",
|
||||
@ -98,7 +100,7 @@ DOCUMENTED_IN_BUGBEAR: Set[str] = {
|
||||
# stdin:3:6: T484 Name 'foo' is not defined
|
||||
# stdin:3:-100: W605 invalid escape sequence '\/'
|
||||
# stdin:3:1: E302 expected 2 blank lines, found 1
|
||||
RESULTS_RE: Pattern[str] = re.compile(
|
||||
RESULTS_RE: re.Pattern[str] = re.compile(
|
||||
r"""(?mx)
|
||||
^
|
||||
(?P<file>.*?):
|
||||
@ -134,10 +136,10 @@ def _test_results_re() -> None:
|
||||
|
||||
|
||||
def _run_command(
|
||||
args: List[str],
|
||||
args: list[str],
|
||||
*,
|
||||
extra_env: Optional[Dict[str, str]],
|
||||
) -> "subprocess.CompletedProcess[str]":
|
||||
extra_env: dict[str, str] | None,
|
||||
) -> subprocess.CompletedProcess[str]:
|
||||
logging.debug(
|
||||
"$ %s",
|
||||
" ".join(
|
||||
@ -158,11 +160,11 @@ def _run_command(
|
||||
|
||||
|
||||
def run_command(
|
||||
args: List[str],
|
||||
args: list[str],
|
||||
*,
|
||||
extra_env: Optional[Dict[str, str]],
|
||||
extra_env: dict[str, str] | None,
|
||||
retries: int,
|
||||
) -> "subprocess.CompletedProcess[str]":
|
||||
) -> subprocess.CompletedProcess[str]:
|
||||
remaining_retries = retries
|
||||
while True:
|
||||
try:
|
||||
@ -243,11 +245,11 @@ def get_issue_documentation_url(code: str) -> str:
|
||||
|
||||
|
||||
def check_files(
|
||||
filenames: List[str],
|
||||
flake8_plugins_path: Optional[str],
|
||||
severities: Dict[str, LintSeverity],
|
||||
filenames: list[str],
|
||||
flake8_plugins_path: str | None,
|
||||
severities: dict[str, LintSeverity],
|
||||
retries: int,
|
||||
) -> List[LintMessage]:
|
||||
) -> list[LintMessage]:
|
||||
try:
|
||||
proc = run_command(
|
||||
[sys.executable, "-mflake8", "--exit-zero"] + filenames,
|
||||
@ -351,7 +353,7 @@ def main() -> None:
|
||||
else os.path.realpath(args.flake8_plugins_path)
|
||||
)
|
||||
|
||||
severities: Dict[str, LintSeverity] = {}
|
||||
severities: dict[str, LintSeverity] = {}
|
||||
if args.severity:
|
||||
for severity in args.severity:
|
||||
parts = severity.split(":", 1)
|
||||
|
@ -2,6 +2,8 @@
|
||||
Generic linter that greps for a pattern and optionally suggests replacements.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
@ -10,7 +12,7 @@ import subprocess
|
||||
import sys
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, List, NamedTuple, Optional
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
|
||||
IS_WINDOWS: bool = os.name == "nt"
|
||||
@ -28,15 +30,15 @@ class LintSeverity(str, Enum):
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: Optional[str]
|
||||
line: Optional[int]
|
||||
char: Optional[int]
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: Optional[str]
|
||||
replacement: Optional[str]
|
||||
description: Optional[str]
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
def as_posix(name: str) -> str:
|
||||
@ -44,8 +46,8 @@ def as_posix(name: str) -> str:
|
||||
|
||||
|
||||
def run_command(
|
||||
args: List[str],
|
||||
) -> "subprocess.CompletedProcess[bytes]":
|
||||
args: list[str],
|
||||
) -> subprocess.CompletedProcess[bytes]:
|
||||
logging.debug("$ %s", " ".join(args))
|
||||
start_time = time.monotonic()
|
||||
try:
|
||||
@ -65,7 +67,7 @@ def lint_file(
|
||||
linter_name: str,
|
||||
error_name: str,
|
||||
error_description: str,
|
||||
) -> Optional[LintMessage]:
|
||||
) -> LintMessage | None:
|
||||
# matching_line looks like:
|
||||
# tools/linter/clangtidy_linter.py:13:import foo.bar.baz
|
||||
split = matching_line.split(":")
|
||||
|
@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from enum import Enum
|
||||
from typing import NamedTuple, Optional, Tuple
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
LINTER_CODE = "LINTRUNNER_VERSION"
|
||||
@ -16,18 +18,18 @@ class LintSeverity(str, Enum):
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: Optional[str]
|
||||
line: Optional[int]
|
||||
char: Optional[int]
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: Optional[str]
|
||||
replacement: Optional[str]
|
||||
description: Optional[str]
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
def toVersionString(version_tuple: Tuple[int, int, int]) -> str:
|
||||
def toVersionString(version_tuple: tuple[int, int, int]) -> str:
|
||||
return ".".join(str(x) for x in version_tuple)
|
||||
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
@ -8,7 +10,7 @@ import sys
|
||||
import time
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Pattern
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
|
||||
IS_WINDOWS: bool = os.name == "nt"
|
||||
@ -26,15 +28,15 @@ class LintSeverity(str, Enum):
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: Optional[str]
|
||||
line: Optional[int]
|
||||
char: Optional[int]
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: Optional[str]
|
||||
replacement: Optional[str]
|
||||
description: Optional[str]
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
def as_posix(name: str) -> str:
|
||||
@ -42,7 +44,7 @@ def as_posix(name: str) -> str:
|
||||
|
||||
|
||||
# tools/linter/flake8_linter.py:15:13: error: Incompatibl...int") [assignment]
|
||||
RESULTS_RE: Pattern[str] = re.compile(
|
||||
RESULTS_RE: re.Pattern[str] = re.compile(
|
||||
r"""(?mx)
|
||||
^
|
||||
(?P<file>.*?):
|
||||
@ -56,7 +58,7 @@ RESULTS_RE: Pattern[str] = re.compile(
|
||||
)
|
||||
|
||||
# torch/_dynamo/variables/tensor.py:363: error: INTERNAL ERROR
|
||||
INTERNAL_ERROR_RE: Pattern[str] = re.compile(
|
||||
INTERNAL_ERROR_RE: re.Pattern[str] = re.compile(
|
||||
r"""(?mx)
|
||||
^
|
||||
(?P<file>.*?):
|
||||
@ -69,11 +71,11 @@ INTERNAL_ERROR_RE: Pattern[str] = re.compile(
|
||||
|
||||
|
||||
def run_command(
|
||||
args: List[str],
|
||||
args: list[str],
|
||||
*,
|
||||
extra_env: Optional[Dict[str, str]],
|
||||
extra_env: dict[str, str] | None,
|
||||
retries: int,
|
||||
) -> "subprocess.CompletedProcess[bytes]":
|
||||
) -> subprocess.CompletedProcess[bytes]:
|
||||
logging.debug("$ %s", " ".join(args))
|
||||
start_time = time.monotonic()
|
||||
try:
|
||||
@ -94,7 +96,7 @@ severities = {
|
||||
}
|
||||
|
||||
|
||||
def check_mypy_installed(code: str) -> List[LintMessage]:
|
||||
def check_mypy_installed(code: str) -> list[LintMessage]:
|
||||
cmd = [sys.executable, "-mmypy", "-V"]
|
||||
try:
|
||||
subprocess.run(cmd, check=True, capture_output=True)
|
||||
@ -117,11 +119,11 @@ def check_mypy_installed(code: str) -> List[LintMessage]:
|
||||
|
||||
|
||||
def check_files(
|
||||
filenames: List[str],
|
||||
filenames: list[str],
|
||||
config: str,
|
||||
retries: int,
|
||||
code: str,
|
||||
) -> List[LintMessage]:
|
||||
) -> list[LintMessage]:
|
||||
# dmypy has a bug where it won't pick up changes if you pass it absolute
|
||||
# file names, see https://github.com/python/mypy/issues/16768
|
||||
filenames = [os.path.relpath(f) for f in filenames]
|
||||
@ -224,7 +226,7 @@ def main() -> None:
|
||||
|
||||
# Use a dictionary here to preserve order. mypy cares about order,
|
||||
# tragically, e.g. https://github.com/python/mypy/issues/2015
|
||||
filenames: Dict[str, bool] = {}
|
||||
filenames: dict[str, bool] = {}
|
||||
|
||||
# If a stub file exists, have mypy check it instead of the original file, in
|
||||
# accordance with PEP-484 (see https://www.python.org/dev/peps/pep-0484/#stub-files)
|
||||
|
@ -14,12 +14,14 @@ is simply to make sure that there is *some* configuration of ruamel that can rou
|
||||
the YAML, not to be prescriptive about it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from enum import Enum
|
||||
from io import StringIO
|
||||
from typing import NamedTuple, Optional
|
||||
from typing import NamedTuple
|
||||
|
||||
import ruamel.yaml # type: ignore[import]
|
||||
|
||||
@ -32,15 +34,15 @@ class LintSeverity(str, Enum):
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: Optional[str]
|
||||
line: Optional[int]
|
||||
char: Optional[int]
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: Optional[str]
|
||||
replacement: Optional[str]
|
||||
description: Optional[str]
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,13 +1,16 @@
|
||||
"""
|
||||
NEWLINE: Checks files to make sure there are no trailing newlines.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from enum import Enum
|
||||
from typing import List, NamedTuple, Optional
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
NEWLINE = 10 # ASCII "\n"
|
||||
CARRIAGE_RETURN = 13 # ASCII "\r"
|
||||
@ -22,18 +25,18 @@ class LintSeverity(str, Enum):
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: Optional[str]
|
||||
line: Optional[int]
|
||||
char: Optional[int]
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: Optional[str]
|
||||
replacement: Optional[str]
|
||||
description: Optional[str]
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
def check_file(filename: str) -> Optional[LintMessage]:
|
||||
def check_file(filename: str) -> LintMessage | None:
|
||||
logging.debug("Checking file %s", filename)
|
||||
|
||||
with open(filename, "rb") as f:
|
||||
@ -85,7 +88,7 @@ def check_file(filename: str) -> Optional[LintMessage]:
|
||||
description="Trailing newline found. Run `lintrunner --take NEWLINE -a` to apply changes.",
|
||||
)
|
||||
has_changes = False
|
||||
original_lines: Optional[List[bytes]] = None
|
||||
original_lines: list[bytes] | None = None
|
||||
for idx, line in enumerate(lines):
|
||||
if len(line) >= 2 and line[-1] == NEWLINE and line[-2] == CARRIAGE_RETURN:
|
||||
if not has_changes:
|
||||
|
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import json
|
||||
@ -5,7 +7,7 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
from enum import Enum
|
||||
from typing import Any, List, NamedTuple, Optional
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
|
||||
IS_WINDOWS: bool = os.name == "nt"
|
||||
@ -23,18 +25,18 @@ class LintSeverity(str, Enum):
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: Optional[str]
|
||||
line: Optional[int]
|
||||
char: Optional[int]
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: Optional[str]
|
||||
replacement: Optional[str]
|
||||
description: Optional[str]
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
def check_file(filename: str) -> List[LintMessage]:
|
||||
def check_file(filename: str) -> list[LintMessage]:
|
||||
with open(filename, "rb") as f:
|
||||
original = f.read().decode("utf-8")
|
||||
replacement = ""
|
||||
|
@ -1,6 +1,9 @@
|
||||
"""
|
||||
Initializer script that installs stuff to pip.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
@ -9,10 +12,8 @@ import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
from typing import List
|
||||
|
||||
|
||||
def run_command(args: List[str]) -> "subprocess.CompletedProcess[bytes]":
|
||||
def run_command(args: list[str]) -> subprocess.CompletedProcess[bytes]:
|
||||
logging.debug("$ %s", " ".join(args))
|
||||
start_time = time.monotonic()
|
||||
try:
|
||||
|
@ -14,6 +14,7 @@ import sys
|
||||
import time
|
||||
from typing import Any, BinaryIO
|
||||
|
||||
|
||||
LINTER_CODE = "RUFF"
|
||||
IS_WINDOWS: bool = os.name == "nt"
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
@ -6,7 +8,7 @@ import subprocess
|
||||
import sys
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import List, NamedTuple, Optional
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
LINTER_CODE = "SHELLCHECK"
|
||||
@ -20,20 +22,20 @@ class LintSeverity(str, Enum):
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: Optional[str]
|
||||
line: Optional[int]
|
||||
char: Optional[int]
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: Optional[str]
|
||||
replacement: Optional[str]
|
||||
description: Optional[str]
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
def run_command(
|
||||
args: List[str],
|
||||
) -> "subprocess.CompletedProcess[bytes]":
|
||||
args: list[str],
|
||||
) -> subprocess.CompletedProcess[bytes]:
|
||||
logging.debug("$ %s", " ".join(args))
|
||||
start_time = time.monotonic()
|
||||
try:
|
||||
@ -47,8 +49,8 @@ def run_command(
|
||||
|
||||
|
||||
def check_files(
|
||||
files: List[str],
|
||||
) -> List[LintMessage]:
|
||||
files: list[str],
|
||||
) -> list[LintMessage]:
|
||||
try:
|
||||
proc = run_command(
|
||||
["shellcheck", "--external-sources", "--format=json1"] + files
|
||||
|
@ -6,15 +6,19 @@ calls run_tests to ensure that the test will be run in OSS CI.
|
||||
|
||||
Takes ~2 minuters to run without the multiprocessing, probably overkill.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
from enum import Enum
|
||||
from typing import List, NamedTuple, Optional
|
||||
from typing import NamedTuple
|
||||
|
||||
import libcst as cst
|
||||
import libcst.matchers as m
|
||||
|
||||
|
||||
LINTER_CODE = "TEST_HAS_MAIN"
|
||||
|
||||
|
||||
@ -62,18 +66,18 @@ class LintSeverity(str, Enum):
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: Optional[str]
|
||||
line: Optional[int]
|
||||
char: Optional[int]
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: Optional[str]
|
||||
replacement: Optional[str]
|
||||
description: Optional[str]
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
def check_file(filename: str) -> List[LintMessage]:
|
||||
def check_file(filename: str) -> list[LintMessage]:
|
||||
lint_messages = []
|
||||
|
||||
with open(filename) as f:
|
||||
|
@ -8,10 +8,13 @@ has valid ownership information in a comment header. Valid means:
|
||||
- Each owner label actually exists in PyTorch
|
||||
- Each owner label starts with "module: " or "oncall: " or is in ACCEPTABLE_OWNER_LABELS
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, List, NamedTuple, Optional
|
||||
from typing import Any, NamedTuple
|
||||
from urllib.request import urlopen
|
||||
|
||||
|
||||
@ -26,15 +29,15 @@ class LintSeverity(str, Enum):
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: Optional[str]
|
||||
line: Optional[int]
|
||||
char: Optional[int]
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: Optional[str]
|
||||
replacement: Optional[str]
|
||||
description: Optional[str]
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
# Team/owner labels usually start with "module: " or "oncall: ", but the following are acceptable exceptions
|
||||
@ -58,8 +61,8 @@ GLOB_EXCEPTIONS = ["**/test/run_test.py"]
|
||||
|
||||
|
||||
def check_labels(
|
||||
labels: List[str], filename: str, line_number: int
|
||||
) -> List[LintMessage]:
|
||||
labels: list[str], filename: str, line_number: int
|
||||
) -> list[LintMessage]:
|
||||
lint_messages = []
|
||||
for label in labels:
|
||||
if label not in PYTORCH_LABELS:
|
||||
@ -104,7 +107,7 @@ def check_labels(
|
||||
return lint_messages
|
||||
|
||||
|
||||
def check_file(filename: str) -> List[LintMessage]:
|
||||
def check_file(filename: str) -> list[LintMessage]:
|
||||
lint_messages = []
|
||||
has_ownership_info = False
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import json
|
||||
@ -6,7 +8,7 @@ import os
|
||||
import sys
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, List, NamedTuple, Optional
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from ufmt.core import ufmt_string
|
||||
from ufmt.util import make_black_config
|
||||
@ -28,15 +30,15 @@ class LintSeverity(str, Enum):
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: Optional[str]
|
||||
line: Optional[int]
|
||||
char: Optional[int]
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: Optional[str]
|
||||
replacement: Optional[str]
|
||||
description: Optional[str]
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
def as_posix(name: str) -> str:
|
||||
@ -59,7 +61,7 @@ def format_error_message(filename: str, err: Exception) -> LintMessage:
|
||||
|
||||
def check_file(
|
||||
filename: str,
|
||||
) -> List[LintMessage]:
|
||||
) -> list[LintMessage]:
|
||||
with open(filename, "rb") as f:
|
||||
original = f.read().decode("utf-8")
|
||||
|
||||
|
@ -2,16 +2,20 @@
|
||||
|
||||
Any job with a specific `sync-tag` must match all other jobs with the same `sync-tag`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, NamedTuple, Optional
|
||||
from typing import Any, Iterable, NamedTuple
|
||||
|
||||
from yaml import dump, load
|
||||
|
||||
|
||||
# Safely load fast C Yaml loader/dumper if they are available
|
||||
try:
|
||||
from yaml import CSafeLoader as Loader
|
||||
@ -27,15 +31,15 @@ class LintSeverity(str, Enum):
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: Optional[str]
|
||||
line: Optional[int]
|
||||
char: Optional[int]
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: Optional[str]
|
||||
replacement: Optional[str]
|
||||
description: Optional[str]
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
def glob_yamls(path: Path) -> Iterable[Path]:
|
||||
@ -51,7 +55,7 @@ def is_workflow(yaml: Any) -> bool:
|
||||
return yaml.get("jobs") is not None
|
||||
|
||||
|
||||
def print_lint_message(path: Path, job: Dict[str, Any], sync_tag: str) -> None:
|
||||
def print_lint_message(path: Path, job: dict[str, Any], sync_tag: str) -> None:
|
||||
job_id = next(iter(job.keys()))
|
||||
with open(path) as f:
|
||||
lines = f.readlines()
|
||||
|
@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
|
||||
def run_cmd(cmd: List[str]) -> None:
|
||||
def run_cmd(cmd: list[str]) -> None:
|
||||
print(f"Running: {cmd}")
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
|
@ -1,13 +1,16 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from typing import Set
|
||||
|
||||
import yaml
|
||||
|
||||
from torchgen.code_template import CodeTemplate
|
||||
from torchgen.selective_build.selector import SelectiveBuilder
|
||||
|
||||
|
||||
# Safely load fast C Yaml loader/dumper if they are available
|
||||
try:
|
||||
from yaml import CSafeLoader as Loader
|
||||
@ -46,7 +49,7 @@ selected_mobile_ops_preamble = """#pragma once
|
||||
"""
|
||||
|
||||
|
||||
def extract_root_operators(selective_builder: SelectiveBuilder) -> Set[str]:
|
||||
def extract_root_operators(selective_builder: SelectiveBuilder) -> set[str]:
|
||||
ops = []
|
||||
for op_name, op in selective_builder.operators.items():
|
||||
if op.is_root_operator:
|
||||
@ -125,7 +128,7 @@ def write_selected_mobile_ops(
|
||||
# 2. All kernel dtypes
|
||||
def write_selected_mobile_ops_with_all_dtypes(
|
||||
output_file_path: str,
|
||||
root_ops: Set[str],
|
||||
root_ops: set[str],
|
||||
) -> None:
|
||||
with open(output_file_path, "wb") as out_file:
|
||||
body_parts = [selected_mobile_ops_preamble]
|
||||
|
@ -1,5 +1,6 @@
|
||||
import lldb # type: ignore[import]
|
||||
|
||||
|
||||
# load into lldb instance with:
|
||||
# command script import tools/lldb/deploy_debugger.py
|
||||
|
||||
|
@ -24,6 +24,9 @@ well. This can be done with
|
||||
Pulling will reinstalle the conda dependencies as well as the nightly binaries into
|
||||
the repo directory.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import datetime
|
||||
import functools
|
||||
@ -40,23 +43,10 @@ import time
|
||||
import uuid
|
||||
from argparse import ArgumentParser
|
||||
from ast import literal_eval
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
)
|
||||
from typing import Any, Callable, cast, Generator, Iterable, Iterator, Sequence, TypeVar
|
||||
|
||||
LOGGER: Optional[logging.Logger] = None
|
||||
|
||||
LOGGER: logging.Logger | None = None
|
||||
URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2"
|
||||
DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss"
|
||||
SHA1_RE = re.compile("([0-9a-fA-F]{40})")
|
||||
@ -68,9 +58,9 @@ SPECS_TO_INSTALL = ("pytorch", "mypy", "pytest", "hypothesis", "ipython", "sphin
|
||||
|
||||
|
||||
class Formatter(logging.Formatter):
|
||||
redactions: Dict[str, str]
|
||||
redactions: dict[str, str]
|
||||
|
||||
def __init__(self, fmt: Optional[str] = None, datefmt: Optional[str] = None):
|
||||
def __init__(self, fmt: str | None = None, datefmt: str | None = None) -> None:
|
||||
super().__init__(fmt, datefmt)
|
||||
self.redactions = {}
|
||||
|
||||
@ -192,7 +182,7 @@ def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, N
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def check_in_repo() -> Optional[str]:
|
||||
def check_in_repo() -> str | None:
|
||||
"""Ensures that we are in the PyTorch repo."""
|
||||
if not os.path.isfile("setup.py"):
|
||||
return "Not in root-level PyTorch repo, no setup.py found"
|
||||
@ -203,7 +193,7 @@ def check_in_repo() -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def check_branch(subcommand: str, branch: Optional[str]) -> Optional[str]:
|
||||
def check_branch(subcommand: str, branch: str | None) -> str | None:
|
||||
"""Checks that the branch name can be checked out."""
|
||||
if subcommand != "checkout":
|
||||
return None
|
||||
@ -259,7 +249,7 @@ def timed(prefix: str) -> Callable[[F], F]:
|
||||
def _make_channel_args(
|
||||
channels: Iterable[str] = ("pytorch-nightly",),
|
||||
override_channels: bool = False,
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
args = []
|
||||
for channel in channels:
|
||||
args.append("--channel")
|
||||
@ -271,11 +261,11 @@ def _make_channel_args(
|
||||
|
||||
@timed("Solving conda environment")
|
||||
def conda_solve(
|
||||
name: Optional[str] = None,
|
||||
prefix: Optional[str] = None,
|
||||
name: str | None = None,
|
||||
prefix: str | None = None,
|
||||
channels: Iterable[str] = ("pytorch-nightly",),
|
||||
override_channels: bool = False,
|
||||
) -> Tuple[List[str], str, str, bool, List[str]]:
|
||||
) -> tuple[list[str], str, str, bool, list[str]]:
|
||||
"""Performs the conda solve and splits the deps from the package."""
|
||||
# compute what environment to use
|
||||
if prefix is not None:
|
||||
@ -329,7 +319,7 @@ def conda_solve(
|
||||
|
||||
|
||||
@timed("Installing dependencies")
|
||||
def deps_install(deps: List[str], existing_env: bool, env_opts: List[str]) -> None:
|
||||
def deps_install(deps: list[str], existing_env: bool, env_opts: list[str]) -> None:
|
||||
"""Install dependencies to deps environment"""
|
||||
if not existing_env:
|
||||
# first remove previous pytorch-deps env
|
||||
@ -342,7 +332,7 @@ def deps_install(deps: List[str], existing_env: bool, env_opts: List[str]) -> No
|
||||
|
||||
|
||||
@timed("Installing pytorch nightly binaries")
|
||||
def pytorch_install(url: str) -> "tempfile.TemporaryDirectory[str]":
|
||||
def pytorch_install(url: str) -> tempfile.TemporaryDirectory[str]:
|
||||
"""Install pytorch into a temporary directory"""
|
||||
pytdir = tempfile.TemporaryDirectory()
|
||||
cmd = ["conda", "create", "--yes", "--no-deps", "--prefix", pytdir.name, url]
|
||||
@ -421,33 +411,33 @@ def pull_nightly_version(spdir: str) -> None:
|
||||
p = subprocess.run(cmd, check=True)
|
||||
|
||||
|
||||
def _get_listing_linux(source_dir: str) -> List[str]:
|
||||
def _get_listing_linux(source_dir: str) -> list[str]:
|
||||
listing = glob.glob(os.path.join(source_dir, "*.so"))
|
||||
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.so")))
|
||||
return listing
|
||||
|
||||
|
||||
def _get_listing_osx(source_dir: str) -> List[str]:
|
||||
def _get_listing_osx(source_dir: str) -> list[str]:
|
||||
# oddly, these are .so files even on Mac
|
||||
listing = glob.glob(os.path.join(source_dir, "*.so"))
|
||||
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dylib")))
|
||||
return listing
|
||||
|
||||
|
||||
def _get_listing_win(source_dir: str) -> List[str]:
|
||||
def _get_listing_win(source_dir: str) -> list[str]:
|
||||
listing = glob.glob(os.path.join(source_dir, "*.pyd"))
|
||||
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.lib")))
|
||||
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dll")))
|
||||
return listing
|
||||
|
||||
|
||||
def _glob_pyis(d: str) -> Set[str]:
|
||||
def _glob_pyis(d: str) -> set[str]:
|
||||
search = os.path.join(d, "**", "*.pyi")
|
||||
pyis = {os.path.relpath(p, d) for p in glob.iglob(search)}
|
||||
return pyis
|
||||
|
||||
|
||||
def _find_missing_pyi(source_dir: str, target_dir: str) -> List[str]:
|
||||
def _find_missing_pyi(source_dir: str, target_dir: str) -> list[str]:
|
||||
source_pyis = _glob_pyis(source_dir)
|
||||
target_pyis = _glob_pyis(target_dir)
|
||||
missing_pyis = [os.path.join(source_dir, p) for p in (source_pyis - target_pyis)]
|
||||
@ -455,7 +445,7 @@ def _find_missing_pyi(source_dir: str, target_dir: str) -> List[str]:
|
||||
return missing_pyis
|
||||
|
||||
|
||||
def _get_listing(source_dir: str, target_dir: str, platform: str) -> List[str]:
|
||||
def _get_listing(source_dir: str, target_dir: str, platform: str) -> list[str]:
|
||||
if platform.startswith("linux"):
|
||||
listing = _get_listing_linux(source_dir)
|
||||
elif platform.startswith("osx"):
|
||||
@ -510,12 +500,12 @@ def _move_single(
|
||||
mover(src, trg)
|
||||
|
||||
|
||||
def _copy_files(listing: List[str], source_dir: str, target_dir: str) -> None:
|
||||
def _copy_files(listing: list[str], source_dir: str, target_dir: str) -> None:
|
||||
for src in listing:
|
||||
_move_single(src, source_dir, target_dir, shutil.copy2, "Copying")
|
||||
|
||||
|
||||
def _link_files(listing: List[str], source_dir: str, target_dir: str) -> None:
|
||||
def _link_files(listing: list[str], source_dir: str, target_dir: str) -> None:
|
||||
for src in listing:
|
||||
_move_single(src, source_dir, target_dir, os.link, "Linking")
|
||||
|
||||
@ -537,7 +527,7 @@ def move_nightly_files(spdir: str, platform: str) -> None:
|
||||
_copy_files(listing, source_dir, target_dir)
|
||||
|
||||
|
||||
def _available_envs() -> Dict[str, str]:
|
||||
def _available_envs() -> dict[str, str]:
|
||||
cmd = ["conda", "env", "list"]
|
||||
p = subprocess.run(
|
||||
cmd,
|
||||
@ -559,7 +549,7 @@ def _available_envs() -> Dict[str, str]:
|
||||
|
||||
|
||||
@timed("Writing pytorch-nightly.pth")
|
||||
def write_pth(env_opts: List[str], platform: str) -> None:
|
||||
def write_pth(env_opts: list[str], platform: str) -> None:
|
||||
"""Writes Python path file for this dir."""
|
||||
env_type, env_dir = env_opts
|
||||
if env_type == "--name":
|
||||
@ -582,9 +572,9 @@ def install(
|
||||
*,
|
||||
logger: logging.Logger,
|
||||
subcommand: str = "checkout",
|
||||
branch: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
prefix: Optional[str] = None,
|
||||
branch: str | None = None,
|
||||
name: str | None = None,
|
||||
prefix: str | None = None,
|
||||
channels: Iterable[str] = ("pytorch-nightly",),
|
||||
override_channels: bool = False,
|
||||
) -> None:
|
||||
@ -673,7 +663,7 @@ def make_parser() -> ArgumentParser:
|
||||
return p
|
||||
|
||||
|
||||
def main(args: Optional[Sequence[str]] = None) -> None:
|
||||
def main(args: Sequence[str] | None = None) -> None:
|
||||
"""Main entry point"""
|
||||
global LOGGER
|
||||
p = make_parser()
|
||||
|
@ -13,13 +13,15 @@ CMAKE_CUDA_COMPILER_LAUNCHER="python;tools/nvcc_fix_deps.py;ccache"
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, TextIO
|
||||
from typing import TextIO
|
||||
|
||||
|
||||
def resolve_include(path: Path, include_dirs: List[Path]) -> Path:
|
||||
def resolve_include(path: Path, include_dirs: list[Path]) -> Path:
|
||||
for include_path in include_dirs:
|
||||
abs_path = include_path / path
|
||||
if abs_path.exists():
|
||||
@ -36,7 +38,7 @@ Tried the following paths, but none existed:
|
||||
)
|
||||
|
||||
|
||||
def repair_depfile(depfile: TextIO, include_dirs: List[Path]) -> None:
|
||||
def repair_depfile(depfile: TextIO, include_dirs: list[Path]) -> None:
|
||||
changes_made = False
|
||||
out = ""
|
||||
for line in depfile:
|
||||
@ -70,8 +72,8 @@ PRE_INCLUDE_ARGS = ["-include", "--pre-include"]
|
||||
POST_INCLUDE_ARGS = ["-I", "--include-path", "-isystem", "--system-include"]
|
||||
|
||||
|
||||
def extract_include_arg(include_dirs: List[Path], i: int, args: List[str]) -> None:
|
||||
def extract_one(name: str, i: int, args: List[str]) -> Optional[str]:
|
||||
def extract_include_arg(include_dirs: list[Path], i: int, args: list[str]) -> None:
|
||||
def extract_one(name: str, i: int, args: list[str]) -> str | None:
|
||||
arg = args[i]
|
||||
if arg == name:
|
||||
return args[i + 1]
|
||||
|
@ -24,6 +24,7 @@ import yaml
|
||||
from torchgen import utils as torchgen_utils
|
||||
from torchgen.yaml_utils import YamlLoader
|
||||
|
||||
|
||||
_RULES_GENERATED_COMMENT = """\
|
||||
GENERATED CODE - DO NOT EDIT DIRECTLY
|
||||
This file is generated by gen_diagnostics.py.
|
||||
|
@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import importlib
|
||||
import sys
|
||||
from pprint import pformat
|
||||
from typing import Dict, List, Sequence
|
||||
from typing import Sequence
|
||||
from unittest.mock import Mock, patch
|
||||
from warnings import warn
|
||||
|
||||
@ -220,7 +222,7 @@ to_py_type_ops = ("bool", "float", "complex", "long", "index", "int", "nonzero")
|
||||
all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops
|
||||
|
||||
|
||||
def sig_for_ops(opname: str) -> List[str]:
|
||||
def sig_for_ops(opname: str) -> list[str]:
|
||||
"""sig_for_ops(opname : str) -> List[str]
|
||||
|
||||
Returns signatures for operator special functions (__add__ etc.)"""
|
||||
@ -254,8 +256,8 @@ def sig_for_ops(opname: str) -> List[str]:
|
||||
raise Exception("unknown op", opname) # noqa: TRY002
|
||||
|
||||
|
||||
def generate_type_hints(sig_group: PythonSignatureGroup) -> List[str]:
|
||||
type_hints: List[str] = []
|
||||
def generate_type_hints(sig_group: PythonSignatureGroup) -> list[str]:
|
||||
type_hints: list[str] = []
|
||||
|
||||
# Some deprecated ops that are on the blocklist are still included in pyi
|
||||
if sig_group.signature.name in blocklist and not sig_group.signature.deprecated:
|
||||
@ -285,7 +287,7 @@ def generate_type_hints(sig_group: PythonSignatureGroup) -> List[str]:
|
||||
return type_hints
|
||||
|
||||
|
||||
def get_max_pool_dispatch(name: str, arg_list: List[str]) -> Dict[str, List[str]]:
|
||||
def get_max_pool_dispatch(name: str, arg_list: list[str]) -> dict[str, list[str]]:
|
||||
flag_pos = arg_list.index("{return_indices}")
|
||||
# If return_indices is positional arg, everything before should have no default
|
||||
arg_list_positional = (
|
||||
@ -329,7 +331,7 @@ def gen_nn_functional(fm: FileManager) -> None:
|
||||
)
|
||||
|
||||
# TODO the list for `torch._C._nn` is nonexhaustive
|
||||
unsorted_c_nn_function_hints: Dict[str, List[str]] = {}
|
||||
unsorted_c_nn_function_hints: dict[str, list[str]] = {}
|
||||
|
||||
for d in (2, 3):
|
||||
unsorted_c_nn_function_hints.update(
|
||||
@ -471,7 +473,7 @@ def gen_nn_functional(fm: FileManager) -> None:
|
||||
}
|
||||
)
|
||||
|
||||
c_nn_function_hints: List[str] = []
|
||||
c_nn_function_hints: list[str] = []
|
||||
for _, hints in sorted(unsorted_c_nn_function_hints.items()):
|
||||
if len(hints) > 1:
|
||||
hints = ["@overload\n" + h for h in hints]
|
||||
@ -528,7 +530,7 @@ def gen_nn_functional(fm: FileManager) -> None:
|
||||
)
|
||||
|
||||
# Functions generated by `torch._jit_internal.boolean_dispatch` in `nn.functional`
|
||||
unsorted_dispatched_hints: Dict[str, List[str]] = {}
|
||||
unsorted_dispatched_hints: dict[str, list[str]] = {}
|
||||
|
||||
for d in (1, 2, 3):
|
||||
unsorted_dispatched_hints.update(
|
||||
@ -563,7 +565,7 @@ def gen_nn_functional(fm: FileManager) -> None:
|
||||
# There's no fractional_max_pool1d
|
||||
del unsorted_dispatched_hints["fractional_max_pool1d"]
|
||||
|
||||
dispatched_hints: List[str] = []
|
||||
dispatched_hints: list[str] = []
|
||||
for _, hints in sorted(unsorted_dispatched_hints.items()):
|
||||
if len(hints) > 1:
|
||||
hints = ["@overload\n" + h for h in hints]
|
||||
@ -594,7 +596,7 @@ We gather the docstrings for torch with the following steps:
|
||||
"""
|
||||
|
||||
|
||||
def gather_docstrs() -> Dict[str, str]:
|
||||
def gather_docstrs() -> dict[str, str]:
|
||||
docstrs = {}
|
||||
|
||||
def mock_add_docstr(func: Mock, docstr: str) -> None:
|
||||
@ -648,12 +650,12 @@ def gen_pyi(
|
||||
# also needs to update the other file.
|
||||
|
||||
# Dictionary for NamedTuple definitions
|
||||
structseqs: Dict[str, str] = {}
|
||||
structseqs: dict[str, str] = {}
|
||||
|
||||
# Generate type signatures for top-level functions
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
unsorted_function_hints: Dict[str, List[str]] = collections.defaultdict(list)
|
||||
unsorted_function_hints: dict[str, list[str]] = collections.defaultdict(list)
|
||||
|
||||
for n, n1, n2 in [
|
||||
("csr", "crow", "col"),
|
||||
@ -1054,7 +1056,7 @@ def gen_pyi(
|
||||
# Generate type signatures for Tensor methods
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
unsorted_tensor_method_hints: Dict[str, List[str]] = collections.defaultdict(list)
|
||||
unsorted_tensor_method_hints: dict[str, list[str]] = collections.defaultdict(list)
|
||||
unsorted_tensor_method_hints.update(
|
||||
{
|
||||
"size": [
|
||||
|
@ -1,8 +1,11 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from typing import Any, List, Union
|
||||
from typing import Any
|
||||
|
||||
|
||||
try:
|
||||
from junitparser import ( # type: ignore[import]
|
||||
@ -23,8 +26,8 @@ except ImportError:
|
||||
print("rich not found, for color output use 'pip install rich'")
|
||||
|
||||
|
||||
def parse_junit_reports(path_to_reports: str) -> List[TestCase]: # type: ignore[no-any-unimported]
|
||||
def parse_file(path: str) -> List[TestCase]: # type: ignore[no-any-unimported]
|
||||
def parse_junit_reports(path_to_reports: str) -> list[TestCase]: # type: ignore[no-any-unimported]
|
||||
def parse_file(path: str) -> list[TestCase]: # type: ignore[no-any-unimported]
|
||||
try:
|
||||
return convert_junit_to_testcases(JUnitXml.fromfile(path))
|
||||
except Exception as err:
|
||||
@ -46,7 +49,7 @@ def parse_junit_reports(path_to_reports: str) -> List[TestCase]: # type: ignore
|
||||
return ret_xml
|
||||
|
||||
|
||||
def convert_junit_to_testcases(xml: Union[JUnitXml, TestSuite]) -> List[TestCase]: # type: ignore[no-any-unimported]
|
||||
def convert_junit_to_testcases(xml: JUnitXml | TestSuite) -> list[TestCase]: # type: ignore[no-any-unimported]
|
||||
testcases = []
|
||||
for item in xml:
|
||||
if isinstance(item, TestSuite):
|
||||
@ -56,7 +59,7 @@ def convert_junit_to_testcases(xml: Union[JUnitXml, TestSuite]) -> List[TestCase
|
||||
return testcases
|
||||
|
||||
|
||||
def render_tests(testcases: List[TestCase]) -> None: # type: ignore[no-any-unimported]
|
||||
def render_tests(testcases: list[TestCase]) -> None: # type: ignore[no-any-unimported]
|
||||
num_passed = 0
|
||||
num_skipped = 0
|
||||
num_failed = 0
|
||||
|
@ -1,9 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def which(thefile: str) -> Optional[str]:
|
||||
def which(thefile: str) -> str | None:
|
||||
path = os.environ.get("PATH", os.defpath).split(os.pathsep)
|
||||
for d in path:
|
||||
fname = os.path.join(d, thefile)
|
||||
|
@ -1,5 +1,6 @@
|
||||
"Manages CMake."
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
@ -8,7 +9,7 @@ import sys
|
||||
import sysconfig
|
||||
from distutils.version import LooseVersion
|
||||
from subprocess import CalledProcessError, check_call, check_output
|
||||
from typing import Any, cast, Dict, List, Optional
|
||||
from typing import Any, cast
|
||||
|
||||
from . import which
|
||||
from .cmake_utils import CMakeValue, get_cmake_cache_variables_from_file
|
||||
@ -77,7 +78,7 @@ class CMake:
|
||||
return cmake_command
|
||||
|
||||
@staticmethod
|
||||
def _get_version(cmd: Optional[str]) -> Any:
|
||||
def _get_version(cmd: str | None) -> Any:
|
||||
"Returns cmake version."
|
||||
|
||||
if cmd is None:
|
||||
@ -87,7 +88,7 @@ class CMake:
|
||||
return LooseVersion(line.strip().split(" ")[2])
|
||||
raise RuntimeError("no version found")
|
||||
|
||||
def run(self, args: List[str], env: Dict[str, str]) -> None:
|
||||
def run(self, args: list[str], env: dict[str, str]) -> None:
|
||||
"Executes cmake with arguments and an environment."
|
||||
|
||||
command = [self._cmake_command] + args
|
||||
@ -101,13 +102,13 @@ class CMake:
|
||||
sys.exit(1)
|
||||
|
||||
@staticmethod
|
||||
def defines(args: List[str], **kwargs: CMakeValue) -> None:
|
||||
def defines(args: list[str], **kwargs: CMakeValue) -> None:
|
||||
"Adds definitions to a cmake argument list."
|
||||
for key, value in sorted(kwargs.items()):
|
||||
if value is not None:
|
||||
args.append(f"-D{key}={value}")
|
||||
|
||||
def get_cmake_cache_variables(self) -> Dict[str, CMakeValue]:
|
||||
def get_cmake_cache_variables(self) -> dict[str, CMakeValue]:
|
||||
r"""Gets values in CMakeCache.txt into a dictionary.
|
||||
Returns:
|
||||
dict: A ``dict`` containing the value of cached CMake variables.
|
||||
@ -117,11 +118,11 @@ class CMake:
|
||||
|
||||
def generate(
|
||||
self,
|
||||
version: Optional[str],
|
||||
cmake_python_library: Optional[str],
|
||||
version: str | None,
|
||||
cmake_python_library: str | None,
|
||||
build_python: bool,
|
||||
build_test: bool,
|
||||
my_env: Dict[str, str],
|
||||
my_env: dict[str, str],
|
||||
rerun: bool,
|
||||
) -> None:
|
||||
"Runs cmake to generate native build files."
|
||||
@ -181,7 +182,7 @@ class CMake:
|
||||
_mkdir_p(self.build_dir)
|
||||
|
||||
# Store build options that are directly stored in environment variables
|
||||
build_options: Dict[str, CMakeValue] = {}
|
||||
build_options: dict[str, CMakeValue] = {}
|
||||
|
||||
# Build options that do not start with "BUILD_", "USE_", or "CMAKE_" and are directly controlled by env vars.
|
||||
# This is a dict that maps environment variables to the corresponding variable name in CMake.
|
||||
@ -340,7 +341,7 @@ class CMake:
|
||||
args.append(base_dir)
|
||||
self.run(args, env=my_env)
|
||||
|
||||
def build(self, my_env: Dict[str, str]) -> None:
|
||||
def build(self, my_env: dict[str, str]) -> None:
|
||||
"Runs cmake to build binaries."
|
||||
|
||||
from .env import build_type
|
||||
|
@ -3,8 +3,10 @@ This is refactored from cmake.py to avoid circular imports issue with env.py,
|
||||
which calls get_cmake_cache_variables_from_file
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Dict, IO, Optional, Union
|
||||
from typing import IO, Optional, Union
|
||||
|
||||
|
||||
CMakeValue = Optional[Union[bool, str]]
|
||||
@ -42,7 +44,7 @@ def convert_cmake_value_to_python_value(
|
||||
|
||||
def get_cmake_cache_variables_from_file(
|
||||
cmake_cache_file: IO[str],
|
||||
) -> Dict[str, CMakeValue]:
|
||||
) -> dict[str, CMakeValue]:
|
||||
r"""Gets values in CMakeCache.txt into a dictionary.
|
||||
|
||||
Args:
|
||||
|
@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import platform
|
||||
import struct
|
||||
import sys
|
||||
from itertools import chain
|
||||
from typing import cast, Iterable, List, Optional
|
||||
from typing import cast, Iterable
|
||||
|
||||
|
||||
IS_WINDOWS = platform.system() == "Windows"
|
||||
@ -30,11 +32,11 @@ def check_negative_env_flag(name: str, default: str = "") -> bool:
|
||||
return os.getenv(name, default).upper() in ["OFF", "0", "NO", "FALSE", "N"]
|
||||
|
||||
|
||||
def gather_paths(env_vars: Iterable[str]) -> List[str]:
|
||||
def gather_paths(env_vars: Iterable[str]) -> list[str]:
|
||||
return list(chain(*(os.getenv(v, "").split(os.pathsep) for v in env_vars)))
|
||||
|
||||
|
||||
def lib_paths_from_base(base_path: str) -> List[str]:
|
||||
def lib_paths_from_base(base_path: str) -> list[str]:
|
||||
return [os.path.join(base_path, s) for s in ["lib/x64", "lib", "lib64"]]
|
||||
|
||||
|
||||
@ -54,7 +56,7 @@ class BuildType:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, cmake_build_type_env: Optional[str] = None) -> None:
|
||||
def __init__(self, cmake_build_type_env: str | None = None) -> None:
|
||||
if cmake_build_type_env is not None:
|
||||
self.build_type_string = cmake_build_type_env
|
||||
return
|
||||
|
@ -2,9 +2,12 @@
|
||||
# and use the version numbers from there as substitutions for
|
||||
# an expand_template action. Since there isn't, this silly script exists.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from typing import cast, Dict, Tuple
|
||||
from typing import cast, Tuple
|
||||
|
||||
|
||||
Version = Tuple[int, int, int]
|
||||
|
||||
@ -30,7 +33,7 @@ def parse_version(version: str) -> Version:
|
||||
return cast(Version, tuple([int(n) for n in version_number_str.split(".")]))
|
||||
|
||||
|
||||
def apply_replacements(replacements: Dict[str, str], text: str) -> str:
|
||||
def apply_replacements(replacements: dict[str, str], text: str) -> str:
|
||||
"""
|
||||
Applies the given replacements within the text.
|
||||
|
||||
|
@ -1,8 +1,10 @@
from __future__ import annotations

import argparse
import os
import pathlib
import sys
from typing import Any, cast, Optional
from typing import Any, cast

import yaml

@ -18,10 +20,10 @@ TAGS_PATH = "aten/src/ATen/native/tags.yaml"

def generate_code(
gen_dir: pathlib.Path,
native_functions_path: Optional[str] = None,
tags_path: Optional[str] = None,
install_dir: Optional[str] = None,
subset: Optional[str] = None,
native_functions_path: str | None = None,
tags_path: str | None = None,
install_dir: str | None = None,
subset: str | None = None,
disable_autograd: bool = False,
force_schema_registration: bool = False,
operator_selector: Any = None,
@ -102,8 +104,8 @@ def get_selector_from_legacy_operator_selection_list(


def get_selector(
selected_op_list_path: Optional[str],
operators_yaml_path: Optional[str],
selected_op_list_path: str | None,
operators_yaml_path: str | None,
) -> Any:
# cwrap depends on pyyaml, so we can't import it earlier
root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

@ -1,10 +1,12 @@
from __future__ import annotations

import argparse
import json
import os
import xml.etree.ElementTree as ET
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, Generator, Tuple
from typing import Any, Generator

from tools.stats.upload_stats_lib import (
download_s3_artifacts,
@ -14,13 +16,14 @@ from tools.stats.upload_stats_lib import (
)
from tools.stats.upload_test_stats import process_xml_element


TESTCASE_TAG = "testcase"
SEPARATOR = ";"


def process_report(
report: Path,
) -> Dict[str, Dict[str, int]]:
) -> dict[str, dict[str, int]]:
"""
Return a list of disabled tests that should be re-enabled and those that are still
flaky (failed or skipped)
@ -36,7 +39,7 @@ def process_report(
# * Skipped tests from unittest
#
# We want to keep track of how many times the test fails (num_red) or passes (num_green)
all_tests: Dict[str, Dict[str, int]] = {}
all_tests: dict[str, dict[str, int]] = {}

for test_case in root.iter(TESTCASE_TAG):
parsed_test_case = process_xml_element(test_case)
@ -116,7 +119,7 @@ def get_test_reports(
yield from Path(".").glob("**/*.xml")


def get_disabled_test_name(test_id: str) -> Tuple[str, str, str, str]:
def get_disabled_test_name(test_id: str) -> tuple[str, str, str, str]:
"""
Follow flaky bot convention here, if that changes, this will also need to be updated
"""
@ -133,7 +136,7 @@ def prepare_record(
flaky: bool,
num_red: int = 0,
num_green: int = 0,
) -> Tuple[Any, Dict[str, Any]]:
) -> tuple[Any, dict[str, Any]]:
"""
Prepare the record to save onto S3
"""
@ -162,7 +165,7 @@ def prepare_record(
def save_results(
workflow_id: int,
workflow_run_attempt: int,
all_tests: Dict[str, Dict[str, int]],
all_tests: dict[str, dict[str, int]],
) -> None:
"""
Save the result to S3, so it can go to Rockset
@ -228,7 +231,7 @@ def main(repo: str, workflow_run_id: int, workflow_run_attempt: int) -> None:
Find the list of all disabled tests that should be re-enabled
"""
# Aggregated across all jobs
all_tests: Dict[str, Dict[str, int]] = {}
all_tests: dict[str, dict[str, int]] = {}

for report in get_test_reports(
args.repo, args.workflow_run_id, args.workflow_run_attempt

@ -1,17 +1,19 @@
#!/usr/bin/env python3

from __future__ import annotations

import datetime
import json
import os
import pathlib
import shutil
from typing import Any, Callable, cast, Dict, List, Optional, Union
from typing import Any, Callable, cast, Dict
from urllib.request import urlopen

REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent


def get_disabled_issues() -> List[str]:
def get_disabled_issues() -> list[str]:
reenabled_issues = os.getenv("REENABLED_ISSUES", "")
issue_numbers = reenabled_issues.split(",")
print("Ignoring disabled issues: ", issue_numbers)
@ -34,11 +36,11 @@ FILE_CACHE_LIFESPAN_SECONDS = datetime.timedelta(hours=3).seconds


def fetch_and_cache(
dirpath: Union[str, pathlib.Path],
dirpath: str | pathlib.Path,
name: str,
url: str,
process_fn: Callable[[Dict[str, Any]], Dict[str, Any]],
) -> Dict[str, Any]:
process_fn: Callable[[dict[str, Any]], dict[str, Any]],
) -> dict[str, Any]:
"""
This fetch and cache utils allows sharing between different process.
"""
@ -76,7 +78,7 @@ def fetch_and_cache(

def get_slow_tests(
dirpath: str, filename: str = SLOW_TESTS_FILE
) -> Optional[Dict[str, float]]:
) -> dict[str, float] | None:
url = "https://ossci-metrics.s3.amazonaws.com/slow-tests.json"
try:
return fetch_and_cache(dirpath, filename, url, lambda x: x)
@ -85,7 +87,7 @@ def get_slow_tests(
return {}


def get_test_times() -> Dict[str, Dict[str, float]]:
def get_test_times() -> dict[str, dict[str, float]]:
return get_from_test_infra_generated_stats(
"test-times.json",
TEST_TIMES_FILE,
@ -93,7 +95,7 @@ def get_test_times() -> Dict[str, Dict[str, float]]:
)


def get_test_class_times() -> Dict[str, Dict[str, float]]:
def get_test_class_times() -> dict[str, dict[str, float]]:
return get_from_test_infra_generated_stats(
"test-class-times.json",
TEST_CLASS_TIMES_FILE,
@ -103,8 +105,8 @@ def get_test_class_times() -> Dict[str, Dict[str, float]]:

def get_disabled_tests(
dirpath: str, filename: str = DISABLED_TESTS_FILE
) -> Optional[Dict[str, Any]]:
def process_disabled_test(the_response: Dict[str, Any]) -> Dict[str, Any]:
) -> dict[str, Any] | None:
def process_disabled_test(the_response: dict[str, Any]) -> dict[str, Any]:
# remove re-enabled tests and condense even further by getting rid of pr_num
disabled_issues = get_disabled_issues()
disabled_test_from_issues = dict()
@ -124,7 +126,7 @@ def get_disabled_tests(
return {}


def get_test_file_ratings() -> Dict[str, Any]:
def get_test_file_ratings() -> dict[str, Any]:
return get_from_test_infra_generated_stats(
"file_test_rating.json",
TEST_FILE_RATINGS_FILE,
@ -132,7 +134,7 @@ def get_test_file_ratings() -> Dict[str, Any]:
)


def get_test_class_ratings() -> Dict[str, Any]:
def get_test_class_ratings() -> dict[str, Any]:
return get_from_test_infra_generated_stats(
"file_test_class_rating.json",
TEST_CLASS_RATINGS_FILE,
@ -140,7 +142,7 @@ def get_test_class_ratings() -> Dict[str, Any]:
)


def get_td_heuristic_historial_edited_files_json() -> Dict[str, Any]:
def get_td_heuristic_historial_edited_files_json() -> dict[str, Any]:
return get_from_test_infra_generated_stats(
"td_heuristic_historical_edited_files.json",
TD_HEURISTIC_HISTORICAL_EDITED_FILES,
@ -148,7 +150,7 @@ def get_td_heuristic_historial_edited_files_json() -> Dict[str, Any]:
)


def get_td_heuristic_profiling_json() -> Dict[str, Any]:
def get_td_heuristic_profiling_json() -> dict[str, Any]:
return get_from_test_infra_generated_stats(
"td_heuristic_profiling.json",
TD_HEURISTIC_PROFILING_FILE,
@ -182,7 +184,7 @@ def copy_additional_previous_failures() -> None:

def get_from_test_infra_generated_stats(
from_file: str, to_file: str, failure_explanation: str
) -> Dict[str, Any]:
) -> dict[str, Any]:
url = f"https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/{from_file}"
try:
return fetch_and_cache(

@ -1,14 +1,17 @@
#!/usr/bin/env python3

from __future__ import annotations

import datetime
import json
import signal
import time
from typing import Any, Dict, List
from typing import Any

import psutil # type: ignore[import]


def get_processes_running_python_tests() -> List[Any]:
def get_processes_running_python_tests() -> list[Any]:
python_processes = []
for process in psutil.process_iter():
try:
@ -20,7 +23,7 @@ def get_processes_running_python_tests() -> List[Any]:
return python_processes


def get_per_process_cpu_info() -> List[Dict[str, Any]]:
def get_per_process_cpu_info() -> list[dict[str, Any]]:
processes = get_processes_running_python_tests()
per_process_info = []
for p in processes:
@ -49,7 +52,7 @@ def get_per_process_cpu_info() -> List[Dict[str, Any]]:
return per_process_info


def get_per_process_gpu_info(handle: Any) -> List[Dict[str, Any]]:
def get_per_process_gpu_info(handle: Any) -> list[dict[str, Any]]:
processes = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
per_process_info = []
for p in processes:
@ -58,7 +61,7 @@ def get_per_process_gpu_info(handle: Any) -> List[Dict[str, Any]]:
return per_process_info


def rocm_get_per_process_gpu_info(handle: Any) -> List[Dict[str, Any]]:
def rocm_get_per_process_gpu_info(handle: Any) -> list[dict[str, Any]]:
processes = amdsmi.amdsmi_get_gpu_process_list(handle)
per_process_info = []
for p in processes:

@ -1,3 +1,5 @@
from __future__ import annotations

import json
import os
import re
@ -6,7 +8,7 @@ from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, cast, Dict, List
from typing import Any, cast

import requests

@ -18,6 +20,7 @@ from tools.stats.upload_stats_lib import (
upload_workflow_stats_to_s3,
)


REGEX_JOB_INFO = r"(.*) \/ .*test \(([^,]*), .*\)"


@ -56,7 +59,7 @@ def get_test_config(job_name: str) -> str:

def get_td_exclusions(
workflow_run_id: int, workflow_run_attempt: int
) -> Dict[str, Any]:
) -> dict[str, Any]:
with TemporaryDirectory() as temp_dir:
print("Using temporary directory:", temp_dir)
os.chdir(temp_dir)
@ -68,7 +71,7 @@ def get_td_exclusions(
for path in s3_paths:
unzip(path)

grouped_tests: Dict[str, Any] = defaultdict(lambda: defaultdict(set))
grouped_tests: dict[str, Any] = defaultdict(lambda: defaultdict(set))
for td_exclusions in Path(".").glob("**/td_exclusions*.json"):
with open(td_exclusions) as f:
exclusions = json.load(f)
@ -85,9 +88,9 @@ def get_td_exclusions(
return grouped_tests


def group_test_cases(test_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
def group_test_cases(test_cases: list[dict[str, Any]]) -> dict[str, Any]:
start = time.time()
grouped_tests: Dict[str, Any] = defaultdict(
grouped_tests: dict[str, Any] = defaultdict(
lambda: defaultdict(
lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
)
@ -112,8 +115,8 @@ def group_test_cases(test_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
return grouped_tests


def get_reruns(grouped_tests: Dict[str, Any]) -> Dict[str, Any]:
reruns: Dict[str, Any] = defaultdict(
def get_reruns(grouped_tests: dict[str, Any]) -> dict[str, Any]:
reruns: dict[str, Any] = defaultdict(
lambda: defaultdict(
lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
)
@ -136,8 +139,8 @@ def get_reruns(grouped_tests: Dict[str, Any]) -> Dict[str, Any]:
return reruns


def get_invoking_file_summary(grouped_tests: Dict[str, Any]) -> Dict[str, Any]:
invoking_file_summary: Dict[str, Any] = defaultdict(
def get_invoking_file_summary(grouped_tests: dict[str, Any]) -> dict[str, Any]:
invoking_file_summary: dict[str, Any] = defaultdict(
lambda: defaultdict(lambda: defaultdict(lambda: {"count": 0, "time": 0.0}))
)
for build_name, build in grouped_tests.items():
@ -157,7 +160,7 @@ def get_invoking_file_summary(grouped_tests: Dict[str, Any]) -> Dict[str, Any]:


def upload_additional_info(
workflow_run_id: int, workflow_run_attempt: int, test_cases: List[Dict[str, Any]]
workflow_run_id: int, workflow_run_attempt: int, test_cases: list[dict[str, Any]]
) -> None:
grouped_tests = group_test_cases(test_cases)
reruns = get_reruns(grouped_tests)

@ -5,6 +5,7 @@ from tempfile import TemporaryDirectory

from tools.stats.upload_stats_lib import download_gha_artifacts, upload_file_to_s3


ARTIFACTS = [
"sccache-stats",
"test-jsons",

@ -1,10 +1,12 @@
from __future__ import annotations

import argparse
import csv
import os
import re
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, List
from typing import Any

from tools.stats.upload_stats_lib import download_s3_artifacts, unzip, upload_to_rockset

@ -23,7 +25,7 @@ def upload_dynamo_perf_stats_to_rockset(
workflow_run_attempt: int,
head_branch: str,
match_filename: str,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
match_filename_regex = re.compile(match_filename)
perf_stats = []
with TemporaryDirectory() as temp_dir:

@ -1,16 +1,18 @@
from __future__ import annotations

import argparse
import datetime
import json
import os

import time
import urllib.parse
from typing import Any, Callable, cast, Dict, List, Optional, Set
from typing import Any, Callable, cast, Dict, List
from urllib.error import HTTPError
from urllib.request import Request, urlopen

from tools.stats.upload_stats_lib import upload_to_s3


FILTER_OUT_USERS = {
"pytorchmergebot",
"facebook-github-bot",
@ -23,9 +25,9 @@ FILTER_OUT_USERS = {

def _fetch_url(
url: str,
headers: Dict[str, str],
data: Optional[Dict[str, Any]] = None,
method: Optional[str] = None,
headers: dict[str, str],
data: dict[str, Any] | None = None,
method: str | None = None,
reader: Callable[[Any], Any] = lambda x: x.read(),
) -> Any:
token = os.environ.get("GITHUB_TOKEN")
@ -49,9 +51,9 @@ def _fetch_url(

def fetch_json(
url: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
params: dict[str, Any] | None = None,
data: dict[str, Any] | None = None,
) -> list[dict[str, Any]]:
headers = {"Accept": "application/vnd.github.v3+json"}
if params is not None and len(params) > 0:
url += "?" + "&".join(
@ -65,16 +67,16 @@ def fetch_json(

def get_external_pr_data(
start_date: datetime.date, end_date: datetime.date, period_length: int = 1
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
pr_info = []
period_begin_date = start_date

pr_count = 0
users: Set[str] = set()
users: set[str] = set()
while period_begin_date < end_date:
period_end_date = period_begin_date + datetime.timedelta(days=period_length - 1)
page = 1
responses: List[Dict[str, Any]] = []
responses: list[dict[str, Any]] = []
while len(responses) > 0 or page == 1:
response = cast(
Dict[str, Any],

@ -1,13 +1,15 @@
from __future__ import annotations

import datetime
import inspect
import os
import time
import uuid

from decimal import Decimal
from typing import Any, Dict
from typing import Any
from warnings import warn


# boto3 is an optional dependency. If it's not installed,
# we'll just not emit the metrics.
# Keeping this logic here so that callers don't have to
@ -65,7 +67,7 @@ class EnvVarMetric:
return value


global_metrics: Dict[str, Any] = {}
global_metrics: dict[str, Any] = {}


def add_global_metric(metric_name: str, metric_value: Any) -> None:
@ -79,7 +81,7 @@ def add_global_metric(metric_name: str, metric_value: Any) -> None:

def emit_metric(
metric_name: str,
metrics: Dict[str, Any],
metrics: dict[str, Any],
) -> None:
"""
Upload a metric to DynamoDB (and from there, Rockset).
@ -174,7 +176,7 @@ def emit_metric(
print(f"Not emitting metrics for {metric_name}. Boto wasn't imported.")


def _convert_float_values_to_decimals(data: Dict[str, Any]) -> Dict[str, Any]:
def _convert_float_values_to_decimals(data: dict[str, Any]) -> dict[str, Any]:
# Attempt to recurse
def _helper(o: Any) -> Any:
if isinstance(o, float):

@ -1,9 +1,11 @@
from __future__ import annotations

import argparse
import json
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, List
from typing import Any

from tools.stats.upload_stats_lib import (
download_s3_artifacts,
@ -13,7 +15,7 @@ from tools.stats.upload_stats_lib import (

def get_sccache_stats(
workflow_run_id: int, workflow_run_attempt: int
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
with TemporaryDirectory() as temp_dir:
print("Using temporary directory:", temp_dir)
os.chdir(temp_dir)

@ -1,16 +1,18 @@
from __future__ import annotations

import gzip
import io
import json
import os
import zipfile

from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any

import boto3 # type: ignore[import]
import requests
import rockset # type: ignore[import]


PYTORCH_REPO = "https://api.github.com/repos/pytorch/pytorch"
S3_RESOURCE = boto3.resource("s3")

@ -21,14 +23,14 @@ MAX_RETRY_IN_NON_DISABLED_MODE = 3 * 3
BATCH_SIZE = 5000


def _get_request_headers() -> Dict[str, str]:
def _get_request_headers() -> dict[str, str]:
return {
"Accept": "application/vnd.github.v3+json",
"Authorization": "token " + os.environ["GITHUB_TOKEN"],
}


def _get_artifact_urls(prefix: str, workflow_run_id: int) -> Dict[Path, str]:
def _get_artifact_urls(prefix: str, workflow_run_id: int) -> dict[Path, str]:
"""Get all workflow artifacts with 'test-report' in the name."""
response = requests.get(
f"{PYTORCH_REPO}/actions/runs/{workflow_run_id}/artifacts?per_page=100",
@ -78,7 +80,7 @@ def _download_artifact(

def download_s3_artifacts(
prefix: str, workflow_run_id: int, workflow_run_attempt: int
) -> List[Path]:
) -> list[Path]:
bucket = S3_RESOURCE.Bucket("gha-artifacts")
objs = bucket.objects.filter(
Prefix=f"pytorch/pytorch/{workflow_run_id}/{workflow_run_attempt}/artifact/{prefix}"
@ -104,7 +106,7 @@ def download_s3_artifacts(

def download_gha_artifacts(
prefix: str, workflow_run_id: int, workflow_run_attempt: int
) -> List[Path]:
) -> list[Path]:
artifact_urls = _get_artifact_urls(prefix, workflow_run_id)
paths = []
for name, url in artifact_urls.items():
@ -114,7 +116,7 @@ def download_gha_artifacts(

def upload_to_rockset(
collection: str,
docs: List[Any],
docs: list[Any],
workspace: str = "commons",
client: Any = None,
) -> None:
@ -142,7 +144,7 @@ def upload_to_rockset(
def upload_to_s3(
bucket_name: str,
key: str,
docs: List[Dict[str, Any]],
docs: list[dict[str, Any]],
) -> None:
print(f"Writing {len(docs)} documents to S3")
body = io.StringIO()
@ -164,7 +166,7 @@ def upload_to_s3(
def read_from_s3(
bucket_name: str,
key: str,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
print(f"Reading from s3://{bucket_name}/{key}")
body = (
S3_RESOURCE.Object(
@ -182,7 +184,7 @@ def upload_workflow_stats_to_s3(
workflow_run_id: int,
workflow_run_attempt: int,
collection: str,
docs: List[Dict[str, Any]],
docs: list[dict[str, Any]],
) -> None:
bucket_name = "ossci-raw-job-status"
key = f"{collection}/{workflow_run_id}/{workflow_run_attempt}"
@ -220,7 +222,7 @@ def unzip(p: Path) -> None:
zip.extractall(unzipped_dir)


def is_rerun_disabled_tests(tests: Dict[str, Dict[str, int]]) -> bool:
def is_rerun_disabled_tests(tests: dict[str, dict[str, int]]) -> bool:
"""
Check if the test report is coming from rerun_disabled_tests workflow where
each test is run multiple times
@ -231,7 +233,7 @@ def is_rerun_disabled_tests(tests: Dict[str, Dict[str, int]]) -> bool:
)


def get_job_id(report: Path) -> Optional[int]:
def get_job_id(report: Path) -> int | None:
# [Job id in artifacts]
# Retrieve the job id from the report path. In our GHA workflows, we append
# the job id to the end of the report name, so `report` looks like:

@ -1,17 +1,19 @@
from __future__ import annotations

import argparse
import ast
import datetime
import json
import os
import re
from typing import Any, List, Union
from typing import Any

import rockset # type: ignore[import]

from tools.stats.upload_stats_lib import upload_to_s3


def get_oncall_from_testfile(testfile: str) -> Union[List[str], None]:
def get_oncall_from_testfile(testfile: str) -> list[str] | None:
path = f"test/{testfile}"
if not path.endswith(".py"):
path += ".py"

@ -1,3 +1,5 @@
from __future__ import annotations

import argparse
import os
import sys
@ -5,7 +7,7 @@ import xml.etree.ElementTree as ET
from multiprocessing import cpu_count, Pool
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, List
from typing import Any

from tools.stats.test_dashboard import upload_additional_info
from tools.stats.upload_stats_lib import (
@ -21,14 +23,14 @@ def parse_xml_report(
report: Path,
workflow_id: int,
workflow_run_attempt: int,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""Convert a test report xml file into a JSON-serializable list of test cases."""
print(f"Parsing {tag}s for test report: {report}")

job_id = get_job_id(report)
print(f"Found job id: {job_id}")

test_cases: List[Dict[str, Any]] = []
test_cases: list[dict[str, Any]] = []

root = ET.parse(report)
for test_case in root.iter(tag):
@ -53,9 +55,9 @@ def parse_xml_report(
return test_cases


def process_xml_element(element: ET.Element) -> Dict[str, Any]:
def process_xml_element(element: ET.Element) -> dict[str, Any]:
"""Convert a test suite element into a JSON-serializable dict."""
ret: Dict[str, Any] = {}
ret: dict[str, Any] = {}

# Convert attributes directly into dict elements.
# e.g.
@ -110,7 +112,7 @@ def process_xml_element(element: ET.Element) -> Dict[str, Any]:
return ret


def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> List[Dict[str, Any]]:
def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> list[dict[str, Any]]:
with TemporaryDirectory() as temp_dir:
print("Using temporary directory:", temp_dir)
os.chdir(temp_dir)
@ -146,7 +148,7 @@ def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> List[Dict[str,

def get_tests_for_circleci(
workflow_run_id: int, workflow_run_attempt: int
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
# Parse the reports and transform them to JSON
test_cases = []
for xml_report in Path(".").glob("**/test/test-reports/**/*.xml"):
@ -159,13 +161,13 @@ def get_tests_for_circleci(
return test_cases


def summarize_test_cases(test_cases: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
def summarize_test_cases(test_cases: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Group test cases by classname, file, and job_id. We perform the aggregation
manually instead of using the `test-suite` XML tag because xmlrunner does
not produce reliable output for it.
"""

def get_key(test_case: Dict[str, Any]) -> Any:
def get_key(test_case: dict[str, Any]) -> Any:
return (
test_case.get("file"),
test_case.get("classname"),
@ -176,7 +178,7 @@ def summarize_test_cases(test_cases: List[Dict[str, Any]]) -> List[Dict[str, Any
test_case["invoking_file"],
)

def init_value(test_case: Dict[str, Any]) -> Dict[str, Any]:
def init_value(test_case: dict[str, Any]) -> dict[str, Any]:
return {
"file": test_case.get("file"),
"classname": test_case.get("classname"),

@ -4,6 +4,7 @@ import sys
from tools.stats.test_dashboard import upload_additional_info
from tools.stats.upload_test_stats import get_tests


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Upload test stats to Rockset")
parser.add_argument(

@ -5,7 +5,6 @@ import argparse
import json
import unittest
from collections import defaultdict

from unittest.mock import Mock, patch

from gen_operators_yaml import (
@ -43,10 +42,10 @@ def _mock_load_op_dep_graph():


class GenOperatorsYAMLTest(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
pass

def test_filter_creation(self):
def test_filter_creation(self) -> None:
filter_func = make_filter_from_options(
model_name="abc",
model_versions=["100", "101"],
@ -99,7 +98,7 @@ class GenOperatorsYAMLTest(unittest.TestCase):
len(filtered_configs) == 2
), f"Expected 2 elements in filtered_configs, but got {len(filtered_configs)}"

def test_verification_success(self):
def test_verification_success(self) -> None:
filter_func = make_filter_from_options(
model_name="abc",
model_versions=["100", "101"],
@ -142,7 +141,7 @@ class GenOperatorsYAMLTest(unittest.TestCase):
"expected verify_all_specified_present to succeed instead it raised an exception"
)

def test_verification_fail(self):
def test_verification_fail(self) -> None:
config = [
{
"model": {
@ -229,7 +228,7 @@ class GenOperatorsYAMLTest(unittest.TestCase):
)
def test_fill_output_with_arguments_not_include_all_overloads(
self, mock_parse_options: Mock, mock_load_op_dep_graph: Mock
):
) -> None:
parser = argparse.ArgumentParser(description="Generate used operators YAML")
options = get_parser_options(parser)

@ -8,10 +8,10 @@ from tools.code_analyzer.gen_oplist import throw_if_any_op_includes_overloads


class GenOplistTest(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
pass

def test_throw_if_any_op_includes_overloads(self):
def test_throw_if_any_op_includes_overloads(self) -> None:
selective_builder = MagicMock()
selective_builder.operators = MagicMock()
selective_builder.operators.items.return_value = [

@ -1,10 +1,12 @@
# For testing specific heuristics
from __future__ import annotations

import io
import json
import pathlib
import sys
import unittest
from typing import Any, Dict, List, Set
from typing import Any
from unittest import mock

REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
@ -28,14 +30,14 @@ sys.path.remove(str(REPO_ROOT))
HEURISTIC_CLASS = "tools.testing.target_determination.heuristics.historical_class_failure_correlation."


def mocked_file(contents: Dict[Any, Any]) -> io.IOBase:
def mocked_file(contents: dict[Any, Any]) -> io.IOBase:
file_object = io.StringIO()
json.dump(contents, file_object)
file_object.seek(0)
return file_object


def gen_historical_class_failures() -> Dict[str, Dict[str, float]]:
def gen_historical_class_failures() -> dict[str, dict[str, float]]:
return {
"file1": {
"test1::classA": 0.5,
@ -80,8 +82,8 @@ class TestHistoricalClassFailureCorrelation(TestTD):
)
def test_get_prediction_confidence(
self,
historical_class_failures: Dict[str, Dict[str, float]],
changed_files: List[str],
historical_class_failures: dict[str, dict[str, float]],
changed_files: list[str],
) -> None:
tests_to_prioritize = ALL_TESTS

@ -113,7 +115,7 @@ class TestHistoricalClassFailureCorrelation(TestTD):
class TestParsePrevTests(TestTD):
@mock.patch("os.path.exists", return_value=False)
def test_cache_does_not_exist(self, mock_exists: Any) -> None:
expected_failing_test_files: Set[str] = set()
expected_failing_test_files: set[str] = set()

found_tests = get_previous_failures()

@ -122,7 +124,7 @@ class TestParsePrevTests(TestTD):
@mock.patch("os.path.exists", return_value=True)
@mock.patch("builtins.open", return_value=mocked_file({"": True}))
def test_empty_cache(self, mock_exists: Any, mock_open: Any) -> None:
expected_failing_test_files: Set[str] = set()
expected_failing_test_files: set[str] = set()

found_tests = get_previous_failures()

@ -1,7 +1,9 @@
from __future__ import annotations

import pathlib
import sys
import unittest
from typing import Any, Dict, List
from typing import Any

REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
sys.path.append(str(REPO_ROOT))
@ -13,7 +15,7 @@ sys.path.remove(str(REPO_ROOT))

class TestTD(unittest.TestCase):
def assert_test_scores_almost_equal(
self, d1: Dict[TestRun, float], d2: Dict[TestRun, float]
self, d1: dict[TestRun, float], d2: dict[TestRun, float]
) -> None:
# Check that dictionaries are the same, except for floating point errors
self.assertEqual(set(d1.keys()), set(d2.keys()))
@ -24,7 +26,7 @@ class TestTD(unittest.TestCase):
# Create a dummy heuristic class
class Heuristic(interface.HeuristicInterface):
def get_prediction_confidence(
self, tests: List[str]
self, tests: list[str]
) -> interface.TestPrioritizations:
# Return junk
return interface.TestPrioritizations([], {})
@ -259,9 +261,9 @@ class TestTestPrioritizations(TestTD):
class TestAggregatedHeuristics(TestTD):
def check(
self,
tests: List[str],
test_prioritizations: List[Dict[TestRun, float]],
expected: Dict[TestRun, float],
tests: list[str],
test_prioritizations: list[dict[TestRun, float]],
expected: dict[TestRun, float],
) -> None:
aggregated_heuristics = interface.AggregatedHeuristics(tests)
for i, test_prioritization in enumerate(test_prioritizations):
@ -429,7 +431,7 @@ class TestAggregatedHeuristicsTestStats(TestTD):
stats3 = aggregator.get_test_stats(TestRun("test3"))
stats5 = aggregator.get_test_stats(TestRun("test5::classA"))

def assert_valid_dict(dict_contents: Dict[str, Any]) -> None:
def assert_valid_dict(dict_contents: dict[str, Any]) -> None:
for key, value in dict_contents.items():
self.assertTrue(isinstance(key, str))
self.assertTrue(

@ -1,7 +1,9 @@
from __future__ import annotations

import pathlib
import sys
import unittest
from typing import Any, Dict
from typing import Any


REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent
@ -14,14 +16,14 @@ sys.path.remove(str(REPO_ROOT))

class TestHeuristicsUtils(unittest.TestCase):
def assertDictAlmostEqual(
self, first: Dict[TestRun, Any], second: Dict[TestRun, Any]
self, first: dict[TestRun, Any], second: dict[TestRun, Any]
) -> None:
self.assertEqual(first.keys(), second.keys())
for key in first.keys():
self.assertAlmostEqual(first[key], second[key])

def test_normalize_ratings(self) -> None:
ratings: Dict[TestRun, float] = {
ratings: dict[TestRun, float] = {
TestRun("test1"): 1,
TestRun("test2"): 2,
TestRun("test3"): 4,

@ -1,12 +1,13 @@
from __future__ import annotations

import contextlib
import os
import typing
import unittest
import unittest.mock
from typing import Iterator, Optional, Sequence
from typing import Iterator, Sequence

import tools.setup_helpers.cmake

import tools.setup_helpers.env # noqa: F401 unused but resolves circular import


@ -79,7 +80,7 @@ class TestCMake(unittest.TestCase):


@contextlib.contextmanager
def env_var(key: str, value: Optional[str]) -> Iterator[None]:
def env_var(key: str, value: str | None) -> Iterator[None]:
"""Sets/clears an environment variable within a Python context."""
# Get the previous value and then override it.
previous_value = os.environ.get(key)
@ -91,7 +92,7 @@ def env_var(key: str, value: Optional[str]) -> Iterator[None]:
set_env_var(key, previous_value)


def set_env_var(key: str, value: Optional[str]) -> None:
def set_env_var(key: str, value: str | None) -> None:
"""Sets/clears an environment variable."""
if value is None:
os.environ.pop(key, None)

@ -1,14 +1,13 @@
from __future__ import annotations

import dataclasses
import typing
import unittest
from collections import defaultdict
from typing import Dict, List

import yaml

from tools.autograd import gen_autograd_functions, load_derivatives

import torchgen.model
from torchgen import dest
from torchgen.api.types import CppSignatureGroup, DispatcherSignature
from torchgen.context import native_function_manager
@ -22,6 +21,7 @@ from torchgen.model import (
BackendIndex,
BackendMetadata,
DispatchKey,
FunctionSchema,
Location,
NativeFunction,
OperatorName,
@ -32,7 +32,7 @@ from torchgen.selective_build.selector import SelectiveBuilder

class TestCreateDerivative(unittest.TestCase):
def test_named_grads(self) -> None:
schema = torchgen.model.FunctionSchema.parse(
schema = FunctionSchema.parse(
"func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
@ -47,7 +47,7 @@ class TestCreateDerivative(unittest.TestCase):

def test_non_differentiable_output(self) -> None:
specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
schema = torchgen.model.FunctionSchema.parse(specification)
schema = FunctionSchema.parse(specification)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

_, differentiability_info = load_derivatives.create_differentiability_info(
@ -69,7 +69,7 @@ class TestCreateDerivative(unittest.TestCase):
)

def test_indexed_grads(self) -> None:
schema = torchgen.model.FunctionSchema.parse(
schema = FunctionSchema.parse(
"func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
@ -84,7 +84,7 @@ class TestCreateDerivative(unittest.TestCase):

def test_named_grads_and_indexed_grads(self) -> None:
specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
schema = torchgen.model.FunctionSchema.parse(specification)
schema = FunctionSchema.parse(specification)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

with self.assertRaisesRegex(
@ -112,7 +112,7 @@ class TestCreateDerivative(unittest.TestCase):
class TestGenAutogradFunctions(unittest.TestCase):
def test_non_differentiable_output_invalid_type(self) -> None:
specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
schema = torchgen.model.FunctionSchema.parse(specification)
schema = FunctionSchema.parse(specification)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

_, differentiability_info = load_derivatives.create_differentiability_info(
@ -141,7 +141,7 @@ class TestGenAutogradFunctions(unittest.TestCase):

def test_non_differentiable_output_output_differentiability(self) -> None:
specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y, Tensor z)"
schema = torchgen.model.FunctionSchema.parse(specification)
schema = FunctionSchema.parse(specification)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

_, differentiability_info = load_derivatives.create_differentiability_info(
@ -182,7 +182,7 @@ class TestGenAutogradFunctions(unittest.TestCase):

def test_register_bogus_dispatch_key(self) -> None:
specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
schema = torchgen.model.FunctionSchema.parse(specification)
schema = FunctionSchema.parse(specification)
native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

with self.assertRaisesRegex(
@ -213,17 +213,17 @@ class TestGenAutogradFunctions(unittest.TestCase):
class TestGenSchemaRegistration(unittest.TestCase):
def setUp(self) -> None:
self.selector = SelectiveBuilder.get_nop_selector()
self.custom_native_function, _ = torchgen.model.NativeFunction.from_yaml(
self.custom_native_function, _ = NativeFunction.from_yaml(
{"func": "custom::func() -> bool"},
loc=torchgen.model.Location(__file__, 1),
loc=Location(__file__, 1),
valid_tags=set(),
)
(
self.fragment_custom_native_function,
_,
) = torchgen.model.NativeFunction.from_yaml(
) = NativeFunction.from_yaml(
{"func": "quantized_decomposed::func() -> bool"},
loc=torchgen.model.Location(__file__, 1),
loc=Location(__file__, 1),
valid_tags=set(),
)

@ -285,9 +285,9 @@ TORCH_LIBRARY(custom, m) {
)

def test_3_namespaces_schema_registration_code_valid(self) -> None:
custom2_native_function, _ = torchgen.model.NativeFunction.from_yaml(
custom2_native_function, _ = NativeFunction.from_yaml(
{"func": "custom2::func() -> bool"},
loc=torchgen.model.Location(__file__, 1),
loc=Location(__file__, 1),
valid_tags=set(),
)
(
@ -320,7 +320,7 @@ class TestGenNativeFunctionDeclaration(unittest.TestCase):
def setUp(self) -> None:
self.op_1_native_function, op_1_backend_index = NativeFunction.from_yaml(
{"func": "op_1() -> bool", "dispatch": {"CPU": "kernel_1"}},
loc=torchgen.model.Location(__file__, 1),
loc=Location(__file__, 1),
valid_tags=set(),
)
self.op_2_native_function, op_2_backend_index = NativeFunction.from_yaml(
@ -328,11 +328,11 @@ class TestGenNativeFunctionDeclaration(unittest.TestCase):
"func": "op_2() -> bool",
"dispatch": {"CPU": "kernel_2", "QuantizedCPU": "custom::kernel_3"},
},
loc=torchgen.model.Location(__file__, 1),
loc=Location(__file__, 1),
valid_tags=set(),
)

backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = {
backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = {
DispatchKey.CPU: {},
DispatchKey.QuantizedCPU: {},
}
@ -382,9 +382,9 @@ TORCH_API bool kernel_1();
# Test for native_function_generation
class TestNativeFunctionGeneratrion(unittest.TestCase):
def setUp(self) -> None:
self.native_functions: List[NativeFunction] = []
self.backend_indices: Dict[
DispatchKey, Dict[OperatorName, BackendMetadata]
self.native_functions: list[NativeFunction] = []
self.backend_indices: dict[
DispatchKey, dict[OperatorName, BackendMetadata]
] = defaultdict(dict)
yaml_entry = """
- func: op(Tensor self) -> Tensor
@ -405,7 +405,7 @@ class TestNativeFunctionGeneratrion(unittest.TestCase):
"dispatch": {"CPU": "kernel_1"},
"autogen": "op_2.out",
},
loc=torchgen.model.Location(__file__, 1),
loc=Location(__file__, 1),
valid_tags=set(),
)
BackendIndex.grow_index(self.backend_indices, two_returns_backend_index)
@ -442,8 +442,8 @@ class TestNativeFunctionGeneratrion(unittest.TestCase):
# Test for static_dispatch
class TestStaticDispatchGeneratrion(unittest.TestCase):
def setUp(self) -> None:
self.backend_indices: Dict[
DispatchKey, Dict[OperatorName, BackendMetadata]
self.backend_indices: dict[
DispatchKey, dict[OperatorName, BackendMetadata]
] = defaultdict(dict)
yaml_entry = """
- func: op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
@ -500,9 +500,9 @@ class TestStaticDispatchGeneratrion(unittest.TestCase):

# Represents the most basic NativeFunction. Use dataclasses.replace()
# to edit for use.
DEFAULT_NATIVE_FUNCTION, _ = torchgen.model.NativeFunction.from_yaml(
DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml(
{"func": "func() -> bool"},
loc=torchgen.model.Location(__file__, 1),
loc=Location(__file__, 1),
valid_tags=set(),
)

@ -1,4 +1,6 @@
from typing import Any, List
from __future__ import annotations

from typing import Any
from unittest import main, TestCase

from tools.alerts.create_alerts import filter_job_names, JobStatus
@ -38,7 +40,7 @@ MOCK_TEST_DATA = [
class TestGitHubPR(TestCase):
# Should fail when jobs are ? ? Fail Fail
def test_alert(self) -> None:
modified_data: List[Any] = [{}]
modified_data: list[Any] = [{}]
modified_data.append({})
modified_data.extend(MOCK_TEST_DATA)
status = JobStatus(JOB_NAME, modified_data)

@ -1,6 +1,8 @@
from __future__ import annotations

import tempfile
import unittest
from typing import Any, Dict
from typing import Any
from unittest.mock import ANY, Mock, patch

import expecttest
@ -13,10 +15,11 @@ from torchgen.model import Location, NativeFunction
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import FileManager


SPACES = " "


def _get_native_function_from_yaml(yaml_obj: Dict[str, object]) -> NativeFunction:
def _get_native_function_from_yaml(yaml_obj: dict[str, object]) -> NativeFunction:
native_function, _ = NativeFunction.from_yaml(
yaml_obj,
loc=Location(__file__, 1),
@ -33,7 +36,7 @@ class TestComputeNativeFunctionStub(expecttest.TestCase):
"""

def _test_function_schema_generates_correct_kernel(
self, obj: Dict[str, Any], expected: str
self, obj: dict[str, Any], expected: str
) -> None:
func = _get_native_function_from_yaml(obj)

@ -1,13 +1,13 @@
from __future__ import annotations

import os
import tempfile
import unittest
from typing import Dict

import yaml

from torchgen.executorch.model import ETKernelIndex, ETKernelKey
from torchgen.gen import LineLoader

from torchgen.gen_executorch import (
ComputeCodegenUnboxedKernels,
gen_functions_declarations,
@ -24,6 +24,7 @@ from torchgen.model import (
)
from torchgen.selective_build.selector import SelectiveBuilder


TEST_YAML = """
- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -345,7 +346,7 @@ class TestGenFunctionsDeclarations(unittest.TestCase):
valid_tags=set(),
)

backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = {
backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = {
DispatchKey.CPU: {},
DispatchKey.QuantizedCPU: {},
}

@ -4,6 +4,7 @@ from torchgen.executorch.api.types import ExecutorchCppSignature
from torchgen.local import parametrize
from torchgen.model import Location, NativeFunction


DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml(
{"func": "foo.out(Tensor input, *, Tensor(a!) out) -> Tensor(a!)"},
loc=Location(__file__, 1),

@ -1,9 +1,10 @@
# Owner(s): ["module: codegen"]

from __future__ import annotations

import os
import tempfile
import unittest
from typing import Optional

import expecttest

@ -29,7 +30,7 @@ class TestGenBackendStubs(expecttest.TestCase):
run(fp.name, "", True)

def get_errors_from_gen_backend_stubs(
self, yaml_str: str, *, kernels_str: Optional[str] = None
self, yaml_str: str, *, kernels_str: str | None = None
) -> str:
with tempfile.NamedTemporaryFile(mode="w") as fp:
fp.write(yaml_str)

@ -1,7 +1,7 @@
import unittest

from torchgen.selective_build.operator import * # noqa: F403
from torchgen.model import Location, NativeFunction
from torchgen.selective_build.operator import * # noqa: F403
from torchgen.selective_build.selector import (
combine_selective_builders,
SelectiveBuilder,
@ -9,7 +9,7 @@ from torchgen.selective_build.selector import (


class TestSelectiveBuild(unittest.TestCase):
def test_selective_build_operator(self):
def test_selective_build_operator(self) -> None:
op = SelectiveBuildOperator(
"aten::add.int",
is_root_operator=True,
@ -21,7 +21,7 @@ class TestSelectiveBuild(unittest.TestCase):
self.assertFalse(op.is_used_for_training)
self.assertFalse(op.include_all_overloads)

def test_selector_factory(self):
def test_selector_factory(self) -> None:
yaml_config_v1 = """
debug_info:
- model1@v100
@ -132,7 +132,7 @@ operators:
selector_legacy_v1.is_operator_selected_for_training("aten::add.float")
)

def test_operator_combine(self):
def test_operator_combine(self) -> None:
op1 = SelectiveBuildOperator(
"aten::add.int",
is_root_operator=True,
@ -177,7 +177,7 @@ operators:

self.assertRaises(Exception, gen_new_op)

def test_training_op_fetch(self):
def test_training_op_fetch(self) -> None:
yaml_config = """
operators:
aten::add.int:
@ -194,7 +194,7 @@ operators:
self.assertTrue(selector.is_operator_selected_for_training("aten::add.int"))
self.assertTrue(selector.is_operator_selected_for_training("aten::add"))

def test_kernel_dtypes(self):
def test_kernel_dtypes(self) -> None:
yaml_config = """
kernel_metadata:
add_kernel:
@ -221,7 +221,7 @@ kernel_metadata:
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16"))
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32"))

def test_merge_kernel_dtypes(self):
def test_merge_kernel_dtypes(self) -> None:
yaml_config1 = """
kernel_metadata:
add_kernel:
@ -266,7 +266,7 @@ kernel_metadata:
self.assertTrue(selector.is_kernel_dtype_selected("mul_kernel", "int8"))
self.assertFalse(selector.is_kernel_dtype_selected("mul_kernel", "int32"))

def test_all_kernel_dtypes_selected(self):
def test_all_kernel_dtypes_selected(self) -> None:
yaml_config = """
include_all_non_op_selectives: True
"""
@ -279,7 +279,7 @@ include_all_non_op_selectives: True
self.assertTrue(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "float"))

def test_custom_namespace_selected_correctly(self):
def test_custom_namespace_selected_correctly(self) -> None:
yaml_config = """
operators:
aten::add.int:
@ -301,7 +301,7 @@ operators:


class TestExecuTorchSelectiveBuild(unittest.TestCase):
def test_et_kernel_selected(self):
def test_et_kernel_selected(self) -> None:
yaml_config = """
et_kernel_metadata:
aten::add.out:

Some files were not shown because too many files have changed in this diff