[3/N] Apply py39 ruff fixes (#142115)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142115
Approved by: https://github.com/ezyang
This commit is contained in:
cyy
2024-12-11 17:50:07 +00:00
committed by PyTorch MergeBot
parent f7e621c3ce
commit 82aaf64422
31 changed files with 137 additions and 91 deletions

View File

@ -20,7 +20,7 @@ import argparse
import os import os
import textwrap import textwrap
from collections import defaultdict from collections import defaultdict
from typing import Any, Sequence, TYPE_CHECKING from typing import Any, TYPE_CHECKING
import torchgen.api.python as python import torchgen.api.python as python
from torchgen.context import with_native_function from torchgen.context import with_native_function
@ -39,6 +39,8 @@ from .gen_python_functions import (
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Sequence
from torchgen.model import Argument, BaseOperatorName, NativeFunction from torchgen.model import Argument, BaseOperatorName, NativeFunction

View File

@ -7,7 +7,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Sequence from typing import TYPE_CHECKING
from torchgen.api.autograd import ( from torchgen.api.autograd import (
Derivative, Derivative,
@ -47,6 +47,10 @@ from torchgen.utils import FileManager
from .gen_inplace_or_view_type import VIEW_FUNCTIONS from .gen_inplace_or_view_type import VIEW_FUNCTIONS
if TYPE_CHECKING:
from collections.abc import Sequence
FUNCTION_DECLARATION = CodeTemplate( FUNCTION_DECLARATION = CodeTemplate(
"""\ """\
#ifdef _WIN32 #ifdef _WIN32

View File

@ -36,7 +36,7 @@ from __future__ import annotations
import itertools import itertools
import re import re
from collections import defaultdict from collections import defaultdict
from typing import Callable, Iterable, Sequence from typing import Callable, TYPE_CHECKING
import yaml import yaml
@ -76,6 +76,10 @@ from .gen_inplace_or_view_type import is_tensor_list_type
from .gen_trace_type import should_trace from .gen_trace_type import should_trace
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
# #
# declarations blocklist # declarations blocklist
# We skip codegen for these functions, for various reasons. # We skip codegen for these functions, for various reasons.

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import itertools import itertools
from typing import Sequence from typing import TYPE_CHECKING
from torchgen.api import cpp from torchgen.api import cpp
from torchgen.api.types import DispatcherSignature from torchgen.api.types import DispatcherSignature
@ -11,6 +11,10 @@ from torchgen.model import Argument, NativeFunction, SchemaKind, TensorOptionsAr
from torchgen.utils import FileManager from torchgen.utils import FileManager
if TYPE_CHECKING:
from collections.abc import Sequence
# Note [Manual Backend kernels] # Note [Manual Backend kernels]
# For these ops, we want to manually register to dispatch key Backend and # For these ops, we want to manually register to dispatch key Backend and
# skip codegen-ed registeration to all keys before Backend. # skip codegen-ed registeration to all keys before Backend.

View File

@ -29,7 +29,7 @@
from __future__ import annotations from __future__ import annotations
import re import re
from typing import Callable, Sequence from typing import Callable, TYPE_CHECKING
from torchgen.api import cpp from torchgen.api import cpp
from torchgen.api.autograd import ( from torchgen.api.autograd import (
@ -105,6 +105,10 @@ from .gen_trace_type import (
) )
if TYPE_CHECKING:
from collections.abc import Sequence
# We don't set or modify grad_fn on these methods. Generally, they return # We don't set or modify grad_fn on these methods. Generally, they return
# tensors that have requires_grad=False. In-place functions listed here will # tensors that have requires_grad=False. In-place functions listed here will
# not examine or modify requires_grad or grad_fn. # not examine or modify requires_grad or grad_fn.

View File

@ -6,8 +6,8 @@
from __future__ import annotations from __future__ import annotations
import re import re
from collections import defaultdict from collections import Counter, defaultdict
from typing import Any, Counter, Dict, Sequence, Set, Tuple from typing import Any, TYPE_CHECKING
import yaml import yaml
@ -53,7 +53,11 @@ from torchgen.utils import concatMap, IDENT_REGEX, split_name_params
from torchgen.yaml_utils import YamlLoader from torchgen.yaml_utils import YamlLoader
DerivativeRet = Tuple[Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], Set[str]] if TYPE_CHECKING:
from collections.abc import Sequence
DerivativeRet = tuple[dict[FunctionSchema, dict[str, DifferentiabilityInfo]], set[str]]
_GLOBAL_LOAD_DERIVATIVE_CACHE: dict[tuple[str, str], DerivativeRet] = {} _GLOBAL_LOAD_DERIVATIVE_CACHE: dict[tuple[str, str], DerivativeRet] = {}
@ -631,7 +635,7 @@ def create_differentiability_info(
raise RuntimeError( raise RuntimeError(
f"Not supported: for {specification}," f"Not supported: for {specification},"
f"output_differentiability must either be " f"output_differentiability must either be "
f"List[bool] or a List[str] where each str is a " f"list[bool] or a list[str] where each str is a "
f"condition. In the case where it is a condition, " f"condition. In the case where it is a condition, "
f"we only support single-output functions. " f"we only support single-output functions. "
f"Please file us an issue. " f"Please file us an issue. "

View File

@ -12,12 +12,11 @@ from __future__ import annotations
import argparse import argparse
from collections import defaultdict from collections import defaultdict
from typing import Dict, Set
import yaml import yaml
DepGraph = Dict[str, Set[str]] DepGraph = dict[str, set[str]]
def canonical_name(opname: str) -> str: def canonical_name(opname: str) -> str:

View File

@ -2,13 +2,13 @@ from __future__ import annotations
import os import os
import subprocess import subprocess
from typing import IO, Tuple from typing import IO
from ..oss.utils import get_pytorch_folder from ..oss.utils import get_pytorch_folder
from ..util.setting import SUMMARY_FOLDER_DIR, TestList, TestStatusType from ..util.setting import SUMMARY_FOLDER_DIR, TestList, TestStatusType
CoverageItem = Tuple[str, float, int, int] CoverageItem = tuple[str, float, int, int]
def key_by_percentage(x: CoverageItem) -> float: def key_by_percentage(x: CoverageItem) -> float:

View File

@ -31,7 +31,7 @@ if TYPE_CHECKING:
from .parser.coverage_record import CoverageRecord from .parser.coverage_record import CoverageRecord
# coverage_records: Dict[str, LineInfo] = {} # coverage_records: dict[str, LineInfo] = {}
covered_lines: dict[str, set[int]] = {} covered_lines: dict[str, set[int]] = {}
uncovered_lines: dict[str, set[int]] = {} uncovered_lines: dict[str, set[int]] = {}
tests_type: TestStatusType = {"success": set(), "partial": set(), "fail": set()} tests_type: TestStatusType = {"success": set(), "partial": set(), "fail": set()}

View File

@ -2,7 +2,6 @@ from __future__ import annotations
import os import os
from enum import Enum from enum import Enum
from typing import Dict, List, Set
# <project folder> # <project folder>
@ -43,8 +42,8 @@ class Test:
self.test_type = test_type self.test_type = test_type
TestList = List[Test] TestList = list[Test]
TestStatusType = Dict[str, Set[str]] TestStatusType = dict[str, set[str]]
# option # option

View File

@ -6,13 +6,13 @@ import argparse
import re import re
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Any, Dict from typing import Any
from typing_extensions import TypedDict # Python 3.11+ from typing_extensions import TypedDict # Python 3.11+
import yaml import yaml
Step = Dict[str, Any] Step = dict[str, Any]
class Script(TypedDict): class Script(TypedDict):

View File

@ -7,7 +7,7 @@ import os
import sys import sys
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Literal, Sequence, TYPE_CHECKING from typing import Literal, TYPE_CHECKING
import yaml import yaml
@ -22,6 +22,8 @@ from torchgen.utils import FileManager, make_file_manager, mapMaybe, Target
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Sequence
from torchgen.selective_build.selector import SelectiveBuilder from torchgen.selective_build.selector import SelectiveBuilder

View File

@ -22,7 +22,7 @@ import os
import re import re
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, NamedTuple, Optional from typing import Any, Callable, NamedTuple, Optional
from yaml import load from yaml import load
@ -77,11 +77,11 @@ def gen_lint_message(
) )
def check_file(filename: str) -> List[LintMessage]: def check_file(filename: str) -> list[LintMessage]:
logging.debug("Checking file %s", filename) logging.debug("Checking file %s", filename)
workflow = load_yaml(Path(filename)) workflow = load_yaml(Path(filename))
bad_jobs: Dict[str, Optional[str]] = {} bad_jobs: dict[str, Optional[str]] = {}
if type(workflow) is not dict: if type(workflow) is not dict:
return [] return []
@ -106,7 +106,7 @@ def check_file(filename: str) -> List[LintMessage]:
pass pass
else: else:
if_statement = str(if_statement) if_statement = str(if_statement)
valid_checks: List[Callable[[str], bool]] = [ valid_checks: list[Callable[[str], bool]] = [
lambda x: "github.repository == 'pytorch/pytorch'" in x lambda x: "github.repository == 'pytorch/pytorch'" in x
and "github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'" and "github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'"
not in x, not in x,

View File

@ -11,11 +11,15 @@ import json
from collections import defaultdict from collections import defaultdict
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Iterable, NamedTuple from typing import Any, NamedTuple, TYPE_CHECKING
from yaml import dump, load from yaml import dump, load
if TYPE_CHECKING:
from collections.abc import Iterable
# Safely load fast C Yaml loader/dumper if they are available # Safely load fast C Yaml loader/dumper if they are available
try: try:
from yaml import CSafeLoader as Loader from yaml import CSafeLoader as Loader

View File

@ -50,16 +50,11 @@ from ast import literal_eval
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from platform import system as platform_system from platform import system as platform_system
from typing import ( from typing import Any, Callable, cast, NamedTuple, TYPE_CHECKING, TypeVar
Any,
Callable,
cast, if TYPE_CHECKING:
Generator, from collections.abc import Generator, Iterable, Iterator
Iterable,
Iterator,
NamedTuple,
TypeVar,
)
try: try:

View File

@ -7,7 +7,7 @@ import subprocess
import sys import sys
import tempfile import tempfile
import urllib.request import urllib.request
from typing import cast, List, NoReturn, Optional from typing import cast, NoReturn, Optional
def parse_arguments() -> argparse.Namespace: def parse_arguments() -> argparse.Namespace:
@ -74,7 +74,7 @@ def get_pytorch_path() -> str:
try: try:
import torch import torch
torch_paths: List[str] = cast(List[str], torch.__path__) torch_paths: list[str] = cast(list[str], torch.__path__)
torch_path: str = torch_paths[0] torch_path: str = torch_paths[0]
parent_path: str = os.path.dirname(torch_path) parent_path: str = os.path.dirname(torch_path)
print(f"PyTorch is installed at: {torch_path}") print(f"PyTorch is installed at: {torch_path}")
@ -114,9 +114,10 @@ def download_patch(pr_number: int, repo_url: str, download_dir: str) -> str:
patch_file = os.path.join(download_dir, f"pr-{pr_number}.patch") patch_file = os.path.join(download_dir, f"pr-{pr_number}.patch")
print(f"Downloading PR #{pr_number} patch from {patch_url}...") print(f"Downloading PR #{pr_number} patch from {patch_url}...")
try: try:
with urllib.request.urlopen(patch_url) as response, open( with (
patch_file, "wb" urllib.request.urlopen(patch_url) as response,
) as out_file: open(patch_file, "wb") as out_file,
):
shutil.copyfileobj(response, out_file) shutil.copyfileobj(response, out_file)
if not os.path.isfile(patch_file): if not os.path.isfile(patch_file):
print(f"Failed to download patch for PR #{pr_number}") print(f"Failed to download patch for PR #{pr_number}")

View File

@ -17,7 +17,8 @@ import os
import string import string
import subprocess import subprocess
import textwrap import textwrap
from typing import Any, Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any
import yaml import yaml
@ -56,7 +57,7 @@ class _{pascal_case_name}(infra.Rule):
self, self,
level: infra.Level, level: infra.Level,
{message_arguments} {message_arguments}
) -> Tuple[infra.Rule, infra.Level, str]: ) -> tuple[infra.Rule, infra.Level, str]:
\"\"\"Returns a tuple of (Rule, Level, message) for this Rule. \"\"\"Returns a tuple of (Rule, Level, message) for this Rule.
Message template: {message_template} Message template: {message_template}

View File

@ -26,7 +26,7 @@ import os
import subprocess import subprocess
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional from typing import Optional
logging.basicConfig( logging.basicConfig(
@ -56,7 +56,7 @@ def requirements_installed() -> bool:
return False return False
def setup_py(cmd_args: List[str], extra_env: Optional[Dict[str, str]] = None) -> None: def setup_py(cmd_args: list[str], extra_env: Optional[dict[str, str]] = None) -> None:
if extra_env is None: if extra_env is None:
extra_env = {} extra_env = {}
cmd = [sys.executable, str(SETUP_PY_PATH), *cmd_args] cmd = [sys.executable, str(SETUP_PY_PATH), *cmd_args]

View File

@ -5,7 +5,7 @@ import collections
import importlib import importlib
import sys import sys
from pprint import pformat from pprint import pformat
from typing import Sequence from typing import TYPE_CHECKING
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from warnings import warn from warnings import warn
@ -25,6 +25,10 @@ from torchgen.model import _TorchDispatchModeKey, DispatchKey, Variant
from torchgen.utils import FileManager from torchgen.utils import FileManager
if TYPE_CHECKING:
from collections.abc import Sequence
""" """
This module implements generation of type stubs for PyTorch, This module implements generation of type stubs for PyTorch,
enabling use of autocomplete in IDEs like PyCharm, which otherwise enabling use of autocomplete in IDEs like PyCharm, which otherwise
@ -229,7 +233,7 @@ all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops
def sig_for_ops(opname: str) -> list[str]: def sig_for_ops(opname: str) -> list[str]:
"""sig_for_ops(opname : str) -> List[str] """sig_for_ops(opname : str) -> list[str]
Returns signatures for operator special functions (__add__ etc.)""" Returns signatures for operator special functions (__add__ etc.)"""
@ -330,11 +334,11 @@ def get_max_pool_dispatch(name: str, arg_list: list[str]) -> dict[str, list[str]
), ),
tmpl.format(name=name, args=", ".join(arg_list_positional)).format( tmpl.format(name=name, args=", ".join(arg_list_positional)).format(
return_indices="return_indices: Literal[True]", return_indices="return_indices: Literal[True]",
return_type="Tuple[Tensor, Tensor]", return_type="tuple[Tensor, Tensor]",
), ),
tmpl.format(name=name, args=", ".join(arg_list_keyword)).format( tmpl.format(name=name, args=", ".join(arg_list_keyword)).format(
return_indices="return_indices: Literal[True]", return_indices="return_indices: Literal[True]",
return_type="Tuple[Tensor, Tensor]", return_type="tuple[Tensor, Tensor]",
), ),
] ]
} }
@ -380,13 +384,13 @@ def gen_nn_functional(fm: FileManager) -> None:
"_random_samples: Tensor", "_random_samples: Tensor",
] ]
), ),
"Tuple[Tensor, Tensor]", "tuple[Tensor, Tensor]",
) )
], ],
f"adaptive_max_pool{d}d": [ f"adaptive_max_pool{d}d": [
f"def adaptive_max_pool{d}d({{}}) -> {{}}: ...".format( f"def adaptive_max_pool{d}d({{}}) -> {{}}: ...".format(
", ".join([f"{INPUT}", "output_size: Union[_int, _size]"]), ", ".join([f"{INPUT}", "output_size: Union[_int, _size]"]),
"Tuple[Tensor, Tensor]", "tuple[Tensor, Tensor]",
) )
], ],
} }
@ -690,9 +694,9 @@ def gen_pyi(
f"def sparse_{n}_tensor({{}}) -> Tensor: ...".format( f"def sparse_{n}_tensor({{}}) -> Tensor: ...".format(
", ".join( ", ".join(
[ [
f"{n1}_indices: Union[Tensor, List]", f"{n1}_indices: Union[Tensor, list]",
f"{n2}_indices: Union[Tensor, List]", f"{n2}_indices: Union[Tensor, list]",
"values: Union[Tensor, List]", "values: Union[Tensor, list]",
"size: Optional[_size] = None", "size: Optional[_size] = None",
"*", "*",
"dtype: Optional[_dtype] = None", "dtype: Optional[_dtype] = None",
@ -767,7 +771,7 @@ def gen_pyi(
", ".join( ", ".join(
[ [
"indices: Tensor", "indices: Tensor",
"values: Union[Tensor, List]", "values: Union[Tensor, list]",
"size: Optional[_size] = None", "size: Optional[_size] = None",
"*", "*",
"dtype: Optional[_dtype] = None", "dtype: Optional[_dtype] = None",
@ -783,9 +787,9 @@ def gen_pyi(
"def sparse_compressed_tensor({}) -> Tensor: ...".format( "def sparse_compressed_tensor({}) -> Tensor: ...".format(
", ".join( ", ".join(
[ [
"compressed_indices: Union[Tensor, List]", "compressed_indices: Union[Tensor, list]",
"plain_indices: Union[Tensor, List]", "plain_indices: Union[Tensor, list]",
"values: Union[Tensor, List]", "values: Union[Tensor, list]",
"size: Optional[_size] = None", "size: Optional[_size] = None",
"*", "*",
"dtype: Optional[_dtype] = None", "dtype: Optional[_dtype] = None",
@ -973,7 +977,7 @@ def gen_pyi(
"size: _size", "size: _size",
"fill_value: Union[Number, _complex]", "fill_value: Union[Number, _complex]",
"*", "*",
"names: List[Union[str, None]]", "names: list[Union[str, None]]",
"layout: _layout = strided", "layout: _layout = strided",
FACTORY_PARAMS, FACTORY_PARAMS,
] ]
@ -986,7 +990,7 @@ def gen_pyi(
], ],
"nonzero": [ "nonzero": [
"def nonzero(input: Tensor, *, as_tuple: Literal[False] = False, out: Optional[Tensor] = None) -> Tensor: ...", "def nonzero(input: Tensor, *, as_tuple: Literal[False] = False, out: Optional[Tensor] = None) -> Tensor: ...",
"def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...", "def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> tuple[Tensor, ...]: ...",
], ],
"dsmm": ["def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."], "dsmm": ["def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
"hsmm": ["def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."], "hsmm": ["def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
@ -1087,7 +1091,7 @@ def gen_pyi(
"def size(self, dim: _int) -> _int: ...", "def size(self, dim: _int) -> _int: ...",
], ],
"stride": [ "stride": [
"def stride(self, dim: None = None) -> Tuple[_int, ...]: ...", "def stride(self, dim: None = None) -> tuple[_int, ...]: ...",
"def stride(self, dim: _int) -> _int: ...", "def stride(self, dim: _int) -> _int: ...",
], ],
"new_ones": [ "new_ones": [
@ -1131,7 +1135,7 @@ def gen_pyi(
"__setitem__": [ "__setitem__": [
f"def __setitem__(self, {INDICES}, val: Union[Tensor, Number]) -> None: ..." f"def __setitem__(self, {INDICES}, val: Union[Tensor, Number]) -> None: ..."
], ],
"tolist": ["def tolist(self) -> List: ..."], "tolist": ["def tolist(self) -> list: ..."],
"requires_grad_": [ "requires_grad_": [
"def requires_grad_(self, mode: _bool = True) -> Tensor: ..." "def requires_grad_(self, mode: _bool = True) -> Tensor: ..."
], ],
@ -1140,7 +1144,7 @@ def gen_pyi(
"dim": ["def dim(self) -> _int: ..."], "dim": ["def dim(self) -> _int: ..."],
"nonzero": [ "nonzero": [
"def nonzero(self, *, as_tuple: Literal[False] = False) -> Tensor: ...", "def nonzero(self, *, as_tuple: Literal[False] = False) -> Tensor: ...",
"def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...", "def nonzero(self, *, as_tuple: Literal[True]) -> tuple[Tensor, ...]: ...",
], ],
"numel": ["def numel(self) -> _int: ..."], "numel": ["def numel(self) -> _int: ..."],
"ndimension": ["def ndimension(self) -> _int: ..."], "ndimension": ["def ndimension(self) -> _int: ..."],
@ -1233,7 +1237,7 @@ def gen_pyi(
], ],
"split": [ "split": [
"def split(self, split_size: _int, dim: _int = 0) -> Sequence[Tensor]: ...", "def split(self, split_size: _int, dim: _int = 0) -> Sequence[Tensor]: ...",
"def split(self, split_size: Tuple[_int, ...], dim: _int = 0) -> Sequence[Tensor]: ...", "def split(self, split_size: tuple[_int, ...], dim: _int = 0) -> Sequence[Tensor]: ...",
], ],
"div": [ "div": [
"def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..." "def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..."

View File

@ -5,7 +5,11 @@ import platform
import struct import struct
import sys import sys
from itertools import chain from itertools import chain
from typing import cast, Iterable from typing import cast, TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Iterable
IS_WINDOWS = platform.system() == "Windows" IS_WINDOWS = platform.system() == "Windows"

View File

@ -6,10 +6,10 @@ from __future__ import annotations
import argparse import argparse
import os import os
from typing import cast, Tuple from typing import cast
Version = Tuple[int, int, int] Version = tuple[int, int, int]
def parse_version(version: str) -> Version: def parse_version(version: str) -> Version:

View File

@ -5,12 +5,15 @@ import os
import typing import typing
import unittest import unittest
import unittest.mock import unittest.mock
from typing import Iterator, Sequence
import tools.setup_helpers.cmake import tools.setup_helpers.cmake
import tools.setup_helpers.env # noqa: F401 unused but resolves circular import import tools.setup_helpers.env # noqa: F401 unused but resolves circular import
if typing.TYPE_CHECKING:
from collections.abc import Iterator, Sequence
T = typing.TypeVar("T") T = typing.TypeVar("T")

View File

@ -6,7 +6,7 @@ import json
import sys import sys
import unittest import unittest
from pathlib import Path from pathlib import Path
from typing import Any, Dict from typing import Any
from unittest import mock from unittest import mock
@ -35,7 +35,7 @@ JOB_NAME = "some-job-name"
@mock.patch("boto3.resource") @mock.patch("boto3.resource")
class TestUploadStats(unittest.TestCase): class TestUploadStats(unittest.TestCase):
emitted_metric: Dict[str, Any] = {"did_not_emit": True} emitted_metric: dict[str, Any] = {"did_not_emit": True}
def mock_put_item(self, **kwargs: Any) -> None: def mock_put_item(self, **kwargs: Any) -> None:
# Utility for mocking putting items into s3. THis will save the emitted # Utility for mocking putting items into s3. THis will save the emitted

View File

@ -1,7 +1,7 @@
import json import json
import os import os
from functools import lru_cache from functools import lru_cache
from typing import Any, Dict, List from typing import Any
import clickhouse_connect # type: ignore[import] import clickhouse_connect # type: ignore[import]
@ -25,12 +25,12 @@ def get_clickhouse_client() -> Any:
) )
def query_clickhouse(query: str, params: Dict[str, Any]) -> List[Dict[str, Any]]: def query_clickhouse(query: str, params: dict[str, Any]) -> list[dict[str, Any]]:
""" """
Queries ClickHouse. Returns datetime in YYYY-MM-DD HH:MM:SS format. Queries ClickHouse. Returns datetime in YYYY-MM-DD HH:MM:SS format.
""" """
def convert_to_json_list(res: bytes) -> List[Dict[str, Any]]: def convert_to_json_list(res: bytes) -> list[dict[str, Any]]:
rows = [] rows = []
for row in res.decode().split("\n"): for row in res.decode().split("\n"):
if row: if row:

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import json import json
import os import os
from collections import defaultdict from collections import defaultdict
from typing import Any, cast, Dict from typing import Any, cast
from warnings import warn from warnings import warn
from tools.stats.import_test_stats import ( from tools.stats.import_test_stats import (
@ -45,7 +45,7 @@ def _get_historical_test_class_correlations() -> dict[str, dict[str, float]]:
print(f"could not find path {path}") print(f"could not find path {path}")
return {} return {}
with open(path) as f: with open(path) as f:
test_class_correlations = cast(Dict[str, Dict[str, float]], json.load(f)) test_class_correlations = cast(dict[str, dict[str, float]], json.load(f))
return test_class_correlations return test_class_correlations

View File

@ -2,11 +2,15 @@ from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
from copy import copy from copy import copy
from typing import Any, Iterable, Iterator from typing import Any, TYPE_CHECKING
from tools.testing.test_run import TestRun from tools.testing.test_run import TestRun
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator
class TestPrioritizations: class TestPrioritizations:
""" """
Describes the results of whether heuristics consider a test relevant or not. Describes the results of whether heuristics consider a test relevant or not.

View File

@ -5,9 +5,9 @@ import os
import re import re
import subprocess import subprocess
from collections import defaultdict from collections import defaultdict
from functools import lru_cache from functools import cache
from pathlib import Path from pathlib import Path
from typing import cast, Dict, TYPE_CHECKING from typing import cast, TYPE_CHECKING
from urllib.request import Request, urlopen from urllib.request import Request, urlopen
from warnings import warn from warnings import warn
@ -26,7 +26,7 @@ def python_test_file_to_test_name(tests: set[str]) -> set[str]:
return valid_tests return valid_tests
@lru_cache(maxsize=None) @cache
def get_pr_number() -> int | None: def get_pr_number() -> int | None:
pr_number = os.environ.get("PR_NUMBER", "") pr_number = os.environ.get("PR_NUMBER", "")
if pr_number == "": if pr_number == "":
@ -38,7 +38,7 @@ def get_pr_number() -> int | None:
return None return None
@lru_cache(maxsize=None) @cache
def get_merge_base() -> str: def get_merge_base() -> str:
pr_number = get_pr_number() pr_number = get_pr_number()
if pr_number is not None: if pr_number is not None:
@ -91,7 +91,7 @@ def query_changed_files() -> list[str]:
return lines return lines
@lru_cache(maxsize=None) @cache
def get_git_commit_info() -> str: def get_git_commit_info() -> str:
"""Gets the commit info since the last commit on the default branch.""" """Gets the commit info since the last commit on the default branch."""
base_commit = get_merge_base() base_commit = get_merge_base()
@ -105,7 +105,7 @@ def get_git_commit_info() -> str:
) )
@lru_cache(maxsize=None) @cache
def get_issue_or_pr_body(number: int) -> str: def get_issue_or_pr_body(number: int) -> str:
"""Gets the body of an issue or PR""" """Gets the body of an issue or PR"""
github_token = os.environ.get("GITHUB_TOKEN") github_token = os.environ.get("GITHUB_TOKEN")
@ -148,7 +148,7 @@ def get_ratings_for_tests(file: str | Path) -> dict[str, float]:
print(f"could not find path {path}") print(f"could not find path {path}")
return {} return {}
with open(path) as f: with open(path) as f:
test_file_ratings = cast(Dict[str, Dict[str, float]], json.load(f)) test_file_ratings = cast(dict[str, dict[str, float]], json.load(f))
try: try:
changed_files = query_changed_files() changed_files = query_changed_files()
except Exception as e: except Exception as e:

View File

@ -2,7 +2,11 @@ from __future__ import annotations
from copy import copy from copy import copy
from functools import total_ordering from functools import total_ordering
from typing import Any, Iterable from typing import Any, TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Iterable
class TestRun: class TestRun:

View File

@ -4,12 +4,16 @@ import math
import os import os
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from typing import Callable, Sequence from typing import Callable, TYPE_CHECKING
from tools.stats.import_test_stats import get_disabled_tests from tools.stats.import_test_stats import get_disabled_tests
from tools.testing.test_run import ShardedTest, TestRun from tools.testing.test_run import ShardedTest, TestRun
if TYPE_CHECKING:
from collections.abc import Sequence
REPO_ROOT = Path(__file__).resolve().parent.parent.parent REPO_ROOT = Path(__file__).resolve().parent.parent.parent
IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1" IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"

View File

@ -3,7 +3,7 @@ import os
import subprocess import subprocess
import time import time
from pathlib import Path from pathlib import Path
from typing import Any, cast, Dict, List, Optional, Tuple from typing import Any, cast, Optional
import requests import requests
from clickhouse import query_clickhouse # type: ignore[import] from clickhouse import query_clickhouse # type: ignore[import]
@ -96,7 +96,7 @@ PYTORCHBOT_TOKEN = os.environ["PYTORCHBOT_TOKEN"]
def git_api( def git_api(
url: str, params: Dict[str, Any], type: str = "get", token: str = UPDATEBOT_TOKEN url: str, params: dict[str, Any], type: str = "get", token: str = UPDATEBOT_TOKEN
) -> Any: ) -> Any:
headers = { headers = {
"Accept": "application/vnd.github.v3+json", "Accept": "application/vnd.github.v3+json",
@ -122,7 +122,7 @@ def git_api(
).json() ).json()
def make_pr(source_repo: str, params: Dict[str, Any]) -> int: def make_pr(source_repo: str, params: dict[str, Any]) -> int:
response = git_api(f"/repos/{source_repo}/pulls", params, type="post") response = git_api(f"/repos/{source_repo}/pulls", params, type="post")
print(f"made pr {response['html_url']}") print(f"made pr {response['html_url']}")
return cast(int, response["number"]) return cast(int, response["number"])
@ -150,7 +150,7 @@ def make_comment(source_repo: str, pr_number: int, msg: str) -> None:
) )
def add_labels(source_repo: str, pr_number: int, labels: List[str]) -> None: def add_labels(source_repo: str, pr_number: int, labels: list[str]) -> None:
params = {"labels": labels} params = {"labels": labels}
git_api( git_api(
f"/repos/{source_repo}/issues/{pr_number}/labels", f"/repos/{source_repo}/issues/{pr_number}/labels",
@ -161,7 +161,7 @@ def add_labels(source_repo: str, pr_number: int, labels: List[str]) -> None:
def search_for_open_pr( def search_for_open_pr(
source_repo: str, search_string: str source_repo: str, search_string: str
) -> Optional[Tuple[int, str]]: ) -> Optional[tuple[int, str]]:
params = { params = {
"q": f"is:pr is:open in:title author:pytorchupdatebot repo:{source_repo} {search_string}", "q": f"is:pr is:open in:title author:pytorchupdatebot repo:{source_repo} {search_string}",
"sort": "created", "sort": "created",

View File

@ -4,7 +4,7 @@ import time
import zipfile import zipfile
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import Any, List from typing import Any
REPO_ROOT = Path(__file__).resolve().parent.parent.parent REPO_ROOT = Path(__file__).resolve().parent.parent.parent
@ -18,7 +18,7 @@ def get_s3_resource() -> Any:
return boto3.client("s3") return boto3.client("s3")
def zip_artifact(file_name: str, paths: List[str]) -> None: def zip_artifact(file_name: str, paths: list[str]) -> None:
"""Zip the files in the paths listed into file_name. The paths will be used """Zip the files in the paths listed into file_name. The paths will be used
in a glob and should be relative to REPO_ROOT.""" in a glob and should be relative to REPO_ROOT."""