Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[3/N] Apply py39 ruff fixes (#142115)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142115
Approved by: https://github.com/ezyang
@@ -20,7 +20,7 @@ import argparse
 import os
 import textwrap
 from collections import defaultdict
-from typing import Any, Sequence, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING

 import torchgen.api.python as python
 from torchgen.context import with_native_function
@@ -39,6 +39,8 @@ from .gen_python_functions import (


 if TYPE_CHECKING:
+    from collections.abc import Sequence
+
     from torchgen.model import Argument, BaseOperatorName, NativeFunction

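Note: nearly every hunk in this PR applies the same py39 pattern. Abstract container types used only in annotations (Sequence, Iterable, Iterator, Mapping) move from typing to collections.abc, imported under an `if TYPE_CHECKING:` guard so the import never runs outside a type checker. A minimal sketch of the pattern, with an illustrative function that is not part of the PR:

from __future__ import annotations  # PEP 563: annotations are stored as strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers; skipped at runtime.
    from collections.abc import Sequence


def total_len(chunks: Sequence[str]) -> int:  # hypothetical example function
    return sum(len(c) for c in chunks)

Because of the __future__ import, `Sequence[str]` is never evaluated at runtime, so the guarded import is sufficient.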
@@ -7,7 +7,7 @@

 from __future__ import annotations

-from typing import Sequence
+from typing import TYPE_CHECKING

 from torchgen.api.autograd import (
     Derivative,
@@ -47,6 +47,10 @@ from torchgen.utils import FileManager
 from .gen_inplace_or_view_type import VIEW_FUNCTIONS


+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
 FUNCTION_DECLARATION = CodeTemplate(
     """\
 #ifdef _WIN32
@@ -36,7 +36,7 @@ from __future__ import annotations
 import itertools
 import re
 from collections import defaultdict
-from typing import Callable, Iterable, Sequence
+from typing import Callable, TYPE_CHECKING

 import yaml

@@ -76,6 +76,10 @@ from .gen_inplace_or_view_type import is_tensor_list_type
 from .gen_trace_type import should_trace


+if TYPE_CHECKING:
+    from collections.abc import Iterable, Sequence
+
+
 #
 # declarations blocklist
 # We skip codegen for these functions, for various reasons.
@@ -1,7 +1,7 @@
 from __future__ import annotations

 import itertools
-from typing import Sequence
+from typing import TYPE_CHECKING

 from torchgen.api import cpp
 from torchgen.api.types import DispatcherSignature
@@ -11,6 +11,10 @@ from torchgen.model import Argument, NativeFunction, SchemaKind, TensorOptionsAr
 from torchgen.utils import FileManager


+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
 # Note [Manual Backend kernels]
 # For these ops, we want to manually register to dispatch key Backend and
 # skip codegen-ed registeration to all keys before Backend.
@@ -29,7 +29,7 @@
 from __future__ import annotations

 import re
-from typing import Callable, Sequence
+from typing import Callable, TYPE_CHECKING

 from torchgen.api import cpp
 from torchgen.api.autograd import (
@@ -105,6 +105,10 @@ from .gen_trace_type import (
 )


+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
 # We don't set or modify grad_fn on these methods. Generally, they return
 # tensors that have requires_grad=False. In-place functions listed here will
 # not examine or modify requires_grad or grad_fn.
@@ -6,8 +6,8 @@
 from __future__ import annotations

 import re
-from collections import defaultdict
-from typing import Any, Counter, Dict, Sequence, Set, Tuple
+from collections import Counter, defaultdict
+from typing import Any, TYPE_CHECKING

 import yaml

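A detail in the hunk above: Counter moves from typing to collections, not to collections.abc. `typing.Counter` is a deprecated alias, while `collections.Counter` is the real class and, on Python 3.9+ (PEP 585), is itself subscriptable in annotations. A small sketch (the function is illustrative, not from the PR):

from collections import Counter


def word_counts(text: str) -> Counter[str]:  # Counter[str] is valid at runtime on 3.9+
    return Counter(text.split())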
@@ -53,7 +53,11 @@ from torchgen.utils import concatMap, IDENT_REGEX, split_name_params
 from torchgen.yaml_utils import YamlLoader


-DerivativeRet = Tuple[Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], Set[str]]
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
+DerivativeRet = tuple[dict[FunctionSchema, dict[str, DifferentiabilityInfo]], set[str]]

 _GLOBAL_LOAD_DERIVATIVE_CACHE: dict[tuple[str, str], DerivativeRet] = {}

@@ -631,7 +635,7 @@ def create_differentiability_info(
         raise RuntimeError(
             f"Not supported: for {specification},"
             f"output_differentiability must either be "
-            f"List[bool] or a List[str] where each str is a "
+            f"list[bool] or a list[str] where each str is a "
             f"condition. In the case where it is a condition, "
             f"we only support single-output functions. "
             f"Please file us an issue. "
@@ -12,12 +12,11 @@ from __future__ import annotations

 import argparse
 from collections import defaultdict
-from typing import Dict, Set

 import yaml


-DepGraph = Dict[str, Set[str]]
+DepGraph = dict[str, set[str]]


 def canonical_name(opname: str) -> str:
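All the Dict/List/Set/Tuple rewrites in this diff rest on PEP 585: from Python 3.9 the built-in dict, list, set, and tuple are generic, so they can be subscripted in real, runtime-evaluated positions such as type aliases and `cast(...)` calls, not just in annotations. A before/after sketch using the DepGraph alias from the hunk above (the sample value is illustrative):

# Before, Python < 3.9:
#     from typing import Dict, Set
#     DepGraph = Dict[str, Set[str]]

# After, Python >= 3.9 (PEP 585), no typing import needed:
DepGraph = dict[str, set[str]]

graph: DepGraph = {"aten::add": {"aten::mul"}}  # hypothetical dependency edge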
@@ -2,13 +2,13 @@ from __future__ import annotations

 import os
 import subprocess
-from typing import IO, Tuple
+from typing import IO

 from ..oss.utils import get_pytorch_folder
 from ..util.setting import SUMMARY_FOLDER_DIR, TestList, TestStatusType


-CoverageItem = Tuple[str, float, int, int]
+CoverageItem = tuple[str, float, int, int]


 def key_by_percentage(x: CoverageItem) -> float:
@@ -31,7 +31,7 @@ if TYPE_CHECKING:
     from .parser.coverage_record import CoverageRecord


-# coverage_records: Dict[str, LineInfo] = {}
+# coverage_records: dict[str, LineInfo] = {}
 covered_lines: dict[str, set[int]] = {}
 uncovered_lines: dict[str, set[int]] = {}
 tests_type: TestStatusType = {"success": set(), "partial": set(), "fail": set()}
@@ -2,7 +2,6 @@ from __future__ import annotations

 import os
 from enum import Enum
-from typing import Dict, List, Set


 # <project folder>
@@ -43,8 +42,8 @@ class Test:
         self.test_type = test_type


-TestList = List[Test]
-TestStatusType = Dict[str, Set[str]]
+TestList = list[Test]
+TestStatusType = dict[str, set[str]]


 # option
@@ -6,13 +6,13 @@ import argparse
 import re
 import sys
 from pathlib import Path
-from typing import Any, Dict
+from typing import Any
 from typing_extensions import TypedDict  # Python 3.11+

 import yaml


-Step = Dict[str, Any]
+Step = dict[str, Any]


 class Script(TypedDict):
@@ -7,7 +7,7 @@ import os
 import sys
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Literal, Sequence, TYPE_CHECKING
+from typing import Literal, TYPE_CHECKING

 import yaml

@@ -22,6 +22,8 @@ from torchgen.utils import FileManager, make_file_manager, mapMaybe, Target


 if TYPE_CHECKING:
+    from collections.abc import Sequence
+
     from torchgen.selective_build.selector import SelectiveBuilder


@@ -22,7 +22,7 @@ import os
 import re
 from enum import Enum
 from pathlib import Path
-from typing import Any, Callable, Dict, List, NamedTuple, Optional
+from typing import Any, Callable, NamedTuple, Optional

 from yaml import load

@@ -77,11 +77,11 @@ def gen_lint_message(
     )


-def check_file(filename: str) -> List[LintMessage]:
+def check_file(filename: str) -> list[LintMessage]:
     logging.debug("Checking file %s", filename)

     workflow = load_yaml(Path(filename))
-    bad_jobs: Dict[str, Optional[str]] = {}
+    bad_jobs: dict[str, Optional[str]] = {}
     if type(workflow) is not dict:
         return []

@@ -106,7 +106,7 @@ def check_file(filename: str) -> List[LintMessage]:
             pass
         else:
            if_statement = str(if_statement)
-            valid_checks: List[Callable[[str], bool]] = [
+            valid_checks: list[Callable[[str], bool]] = [
                 lambda x: "github.repository == 'pytorch/pytorch'" in x
                 and "github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'"
                 not in x,
@@ -11,11 +11,15 @@ import json
 from collections import defaultdict
 from enum import Enum
 from pathlib import Path
-from typing import Any, Iterable, NamedTuple
+from typing import Any, NamedTuple, TYPE_CHECKING

 from yaml import dump, load


+if TYPE_CHECKING:
+    from collections.abc import Iterable
+
+
 # Safely load fast C Yaml loader/dumper if they are available
 try:
     from yaml import CSafeLoader as Loader
@@ -50,16 +50,11 @@ from ast import literal_eval
 from datetime import datetime
 from pathlib import Path
 from platform import system as platform_system
-from typing import (
-    Any,
-    Callable,
-    cast,
-    Generator,
-    Iterable,
-    Iterator,
-    NamedTuple,
-    TypeVar,
-)
+from typing import Any, Callable, cast, NamedTuple, TYPE_CHECKING, TypeVar
+
+
+if TYPE_CHECKING:
+    from collections.abc import Generator, Iterable, Iterator


 try:
@@ -7,7 +7,7 @@ import subprocess
 import sys
 import tempfile
 import urllib.request
-from typing import cast, List, NoReturn, Optional
+from typing import cast, NoReturn, Optional


 def parse_arguments() -> argparse.Namespace:
@@ -74,7 +74,7 @@ def get_pytorch_path() -> str:
     try:
         import torch

-        torch_paths: List[str] = cast(List[str], torch.__path__)
+        torch_paths: list[str] = cast(list[str], torch.__path__)
         torch_path: str = torch_paths[0]
         parent_path: str = os.path.dirname(torch_path)
         print(f"PyTorch is installed at: {torch_path}")
@@ -114,9 +114,10 @@ def download_patch(pr_number: int, repo_url: str, download_dir: str) -> str:
     patch_file = os.path.join(download_dir, f"pr-{pr_number}.patch")
     print(f"Downloading PR #{pr_number} patch from {patch_url}...")
     try:
-        with urllib.request.urlopen(patch_url) as response, open(
-            patch_file, "wb"
-        ) as out_file:
+        with (
+            urllib.request.urlopen(patch_url) as response,
+            open(patch_file, "wb") as out_file,
+        ):
             shutil.copyfileobj(response, out_file)
         if not os.path.isfile(patch_file):
             print(f"Failed to download patch for PR #{pr_number}")
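The urllib hunk above uses a different py39 feature: parenthesized context managers. Chaining several managers in one `with` used to require either backslash continuations or the awkward split shown on the removed lines; CPython 3.9's PEG parser accepts parentheses around the manager list (the syntax became official in the 3.10 grammar). A standalone sketch of the same shape, with an illustrative URL and filename:

import shutil
import urllib.request

# Hypothetical inputs, for illustration only.
url = "https://example.com/data.bin"
with (
    urllib.request.urlopen(url) as response,
    open("data.bin", "wb") as out_file,
):
    shutil.copyfileobj(response, out_file)  # stream the response to disk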
@@ -17,7 +17,8 @@ import os
 import string
 import subprocess
 import textwrap
-from typing import Any, Mapping, Sequence
+from collections.abc import Mapping, Sequence
+from typing import Any

 import yaml

@@ -56,7 +57,7 @@ class _{pascal_case_name}(infra.Rule):
         self,
         level: infra.Level,
         {message_arguments}
-    ) -> Tuple[infra.Rule, infra.Level, str]:
+    ) -> tuple[infra.Rule, infra.Level, str]:
         \"\"\"Returns a tuple of (Rule, Level, message) for this Rule.

         Message template: {message_template}
@@ -26,7 +26,7 @@ import os
 import subprocess
 import sys
 from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Optional


 logging.basicConfig(
@@ -56,7 +56,7 @@ def requirements_installed() -> bool:
         return False


-def setup_py(cmd_args: List[str], extra_env: Optional[Dict[str, str]] = None) -> None:
+def setup_py(cmd_args: list[str], extra_env: Optional[dict[str, str]] = None) -> None:
     if extra_env is None:
         extra_env = {}
     cmd = [sys.executable, str(SETUP_PY_PATH), *cmd_args]
@@ -5,7 +5,7 @@ import collections
 import importlib
 import sys
 from pprint import pformat
-from typing import Sequence
+from typing import TYPE_CHECKING
 from unittest.mock import Mock, patch
 from warnings import warn

@@ -25,6 +25,10 @@ from torchgen.model import _TorchDispatchModeKey, DispatchKey, Variant
 from torchgen.utils import FileManager


+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
 """
 This module implements generation of type stubs for PyTorch,
 enabling use of autocomplete in IDEs like PyCharm, which otherwise
@@ -229,7 +233,7 @@ all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops


 def sig_for_ops(opname: str) -> list[str]:
-    """sig_for_ops(opname : str) -> List[str]
+    """sig_for_ops(opname : str) -> list[str]

     Returns signatures for operator special functions (__add__ etc.)"""

@@ -330,11 +334,11 @@ def get_max_pool_dispatch(name: str, arg_list: list[str]) -> dict[str, list[str]
             ),
             tmpl.format(name=name, args=", ".join(arg_list_positional)).format(
                 return_indices="return_indices: Literal[True]",
-                return_type="Tuple[Tensor, Tensor]",
+                return_type="tuple[Tensor, Tensor]",
             ),
             tmpl.format(name=name, args=", ".join(arg_list_keyword)).format(
                 return_indices="return_indices: Literal[True]",
-                return_type="Tuple[Tensor, Tensor]",
+                return_type="tuple[Tensor, Tensor]",
             ),
         ]
     }
@@ -380,13 +384,13 @@ def gen_nn_functional(fm: FileManager) -> None:
                     "_random_samples: Tensor",
                 ]
             ),
-            "Tuple[Tensor, Tensor]",
+            "tuple[Tensor, Tensor]",
         )
     ],
     f"adaptive_max_pool{d}d": [
         f"def adaptive_max_pool{d}d({{}}) -> {{}}: ...".format(
             ", ".join([f"{INPUT}", "output_size: Union[_int, _size]"]),
-            "Tuple[Tensor, Tensor]",
+            "tuple[Tensor, Tensor]",
         )
     ],
 }
@@ -690,9 +694,9 @@ def gen_pyi(
         f"def sparse_{n}_tensor({{}}) -> Tensor: ...".format(
             ", ".join(
                 [
-                    f"{n1}_indices: Union[Tensor, List]",
-                    f"{n2}_indices: Union[Tensor, List]",
-                    "values: Union[Tensor, List]",
+                    f"{n1}_indices: Union[Tensor, list]",
+                    f"{n2}_indices: Union[Tensor, list]",
+                    "values: Union[Tensor, list]",
                     "size: Optional[_size] = None",
                     "*",
                     "dtype: Optional[_dtype] = None",
@@ -767,7 +771,7 @@ def gen_pyi(
         ", ".join(
             [
                 "indices: Tensor",
-                "values: Union[Tensor, List]",
+                "values: Union[Tensor, list]",
                 "size: Optional[_size] = None",
                 "*",
                 "dtype: Optional[_dtype] = None",
@@ -783,9 +787,9 @@ def gen_pyi(
     "def sparse_compressed_tensor({}) -> Tensor: ...".format(
         ", ".join(
             [
-                "compressed_indices: Union[Tensor, List]",
-                "plain_indices: Union[Tensor, List]",
-                "values: Union[Tensor, List]",
+                "compressed_indices: Union[Tensor, list]",
+                "plain_indices: Union[Tensor, list]",
+                "values: Union[Tensor, list]",
                 "size: Optional[_size] = None",
                 "*",
                 "dtype: Optional[_dtype] = None",
@@ -973,7 +977,7 @@ def gen_pyi(
                 "size: _size",
                 "fill_value: Union[Number, _complex]",
                 "*",
-                "names: List[Union[str, None]]",
+                "names: list[Union[str, None]]",
                 "layout: _layout = strided",
                 FACTORY_PARAMS,
             ]
@@ -986,7 +990,7 @@ def gen_pyi(
         ],
         "nonzero": [
             "def nonzero(input: Tensor, *, as_tuple: Literal[False] = False, out: Optional[Tensor] = None) -> Tensor: ...",
-            "def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...",
+            "def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> tuple[Tensor, ...]: ...",
         ],
         "dsmm": ["def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
         "hsmm": ["def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
@@ -1087,7 +1091,7 @@ def gen_pyi(
             "def size(self, dim: _int) -> _int: ...",
         ],
         "stride": [
-            "def stride(self, dim: None = None) -> Tuple[_int, ...]: ...",
+            "def stride(self, dim: None = None) -> tuple[_int, ...]: ...",
             "def stride(self, dim: _int) -> _int: ...",
         ],
         "new_ones": [
@@ -1131,7 +1135,7 @@ def gen_pyi(
         "__setitem__": [
             f"def __setitem__(self, {INDICES}, val: Union[Tensor, Number]) -> None: ..."
         ],
-        "tolist": ["def tolist(self) -> List: ..."],
+        "tolist": ["def tolist(self) -> list: ..."],
         "requires_grad_": [
             "def requires_grad_(self, mode: _bool = True) -> Tensor: ..."
         ],
@@ -1140,7 +1144,7 @@ def gen_pyi(
         "dim": ["def dim(self) -> _int: ..."],
         "nonzero": [
             "def nonzero(self, *, as_tuple: Literal[False] = False) -> Tensor: ...",
-            "def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...",
+            "def nonzero(self, *, as_tuple: Literal[True]) -> tuple[Tensor, ...]: ...",
         ],
         "numel": ["def numel(self) -> _int: ..."],
         "ndimension": ["def ndimension(self) -> _int: ..."],
@@ -1233,7 +1237,7 @@ def gen_pyi(
         ],
         "split": [
             "def split(self, split_size: _int, dim: _int = 0) -> Sequence[Tensor]: ...",
-            "def split(self, split_size: Tuple[_int, ...], dim: _int = 0) -> Sequence[Tensor]: ...",
+            "def split(self, split_size: tuple[_int, ...], dim: _int = 0) -> Sequence[Tensor]: ...",
         ],
         "div": [
             "def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..."
@@ -5,7 +5,11 @@ import platform
 import struct
 import sys
 from itertools import chain
-from typing import cast, Iterable
+from typing import cast, TYPE_CHECKING


+if TYPE_CHECKING:
+    from collections.abc import Iterable
+
+
 IS_WINDOWS = platform.system() == "Windows"
@@ -6,10 +6,10 @@ from __future__ import annotations

 import argparse
 import os
-from typing import cast, Tuple
+from typing import cast


-Version = Tuple[int, int, int]
+Version = tuple[int, int, int]


 def parse_version(version: str) -> Version:
@@ -5,12 +5,15 @@ import os
 import typing
 import unittest
 import unittest.mock
-from typing import Iterator, Sequence

 import tools.setup_helpers.cmake
 import tools.setup_helpers.env  # noqa: F401 unused but resolves circular import


+if typing.TYPE_CHECKING:
+    from collections.abc import Iterator, Sequence
+
+
 T = typing.TypeVar("T")

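The test file above spells the guard as `if typing.TYPE_CHECKING:` rather than importing the constant; the two forms are interchangeable, since `typing.TYPE_CHECKING` is simply False at runtime and assumed True by type checkers. A minimal sketch (function and alias names are illustrative):

import typing

if typing.TYPE_CHECKING:
    # Never executed at runtime; visible to type checkers only.
    from collections.abc import Iterator

T = typing.TypeVar("T")


def first(items: "Iterator[T]") -> T:  # quoted, so no runtime name lookup occurs
    return next(items)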
@@ -6,7 +6,7 @@ import json
 import sys
 import unittest
 from pathlib import Path
-from typing import Any, Dict
+from typing import Any
 from unittest import mock

@@ -35,7 +35,7 @@ JOB_NAME = "some-job-name"

 @mock.patch("boto3.resource")
 class TestUploadStats(unittest.TestCase):
-    emitted_metric: Dict[str, Any] = {"did_not_emit": True}
+    emitted_metric: dict[str, Any] = {"did_not_emit": True}

     def mock_put_item(self, **kwargs: Any) -> None:
         # Utility for mocking putting items into s3. THis will save the emitted
@@ -1,7 +1,7 @@
 import json
 import os
 from functools import lru_cache
-from typing import Any, Dict, List
+from typing import Any

 import clickhouse_connect  # type: ignore[import]

@@ -25,12 +25,12 @@ def get_clickhouse_client() -> Any:
     )


-def query_clickhouse(query: str, params: Dict[str, Any]) -> List[Dict[str, Any]]:
+def query_clickhouse(query: str, params: dict[str, Any]) -> list[dict[str, Any]]:
     """
     Queries ClickHouse. Returns datetime in YYYY-MM-DD HH:MM:SS format.
     """

-    def convert_to_json_list(res: bytes) -> List[Dict[str, Any]]:
+    def convert_to_json_list(res: bytes) -> list[dict[str, Any]]:
         rows = []
         for row in res.decode().split("\n"):
             if row:
@@ -3,7 +3,7 @@ from __future__ import annotations
 import json
 import os
 from collections import defaultdict
-from typing import Any, cast, Dict
+from typing import Any, cast
 from warnings import warn

 from tools.stats.import_test_stats import (
@@ -45,7 +45,7 @@ def _get_historical_test_class_correlations() -> dict[str, dict[str, float]]:
         print(f"could not find path {path}")
         return {}
     with open(path) as f:
-        test_class_correlations = cast(Dict[str, Dict[str, float]], json.load(f))
+        test_class_correlations = cast(dict[str, dict[str, float]], json.load(f))
     return test_class_correlations


@@ -2,11 +2,15 @@ from __future__ import annotations

 from abc import abstractmethod
 from copy import copy
-from typing import Any, Iterable, Iterator
+from typing import Any, TYPE_CHECKING

 from tools.testing.test_run import TestRun


+if TYPE_CHECKING:
+    from collections.abc import Iterable, Iterator
+
+
 class TestPrioritizations:
     """
     Describes the results of whether heuristics consider a test relevant or not.
@@ -5,9 +5,9 @@ import os
 import re
 import subprocess
 from collections import defaultdict
-from functools import lru_cache
+from functools import cache
 from pathlib import Path
-from typing import cast, Dict, TYPE_CHECKING
+from typing import cast, TYPE_CHECKING
 from urllib.request import Request, urlopen
 from warnings import warn

@@ -26,7 +26,7 @@ def python_test_file_to_test_name(tests: set[str]) -> set[str]:
     return valid_tests


-@lru_cache(maxsize=None)
+@cache
 def get_pr_number() -> int | None:
     pr_number = os.environ.get("PR_NUMBER", "")
     if pr_number == "":
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@cache
|
||||||
def get_merge_base() -> str:
|
def get_merge_base() -> str:
|
||||||
pr_number = get_pr_number()
|
pr_number = get_pr_number()
|
||||||
if pr_number is not None:
|
if pr_number is not None:
|
||||||
@ -91,7 +91,7 @@ def query_changed_files() -> list[str]:
|
|||||||
return lines
|
return lines
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@cache
|
||||||
def get_git_commit_info() -> str:
|
def get_git_commit_info() -> str:
|
||||||
"""Gets the commit info since the last commit on the default branch."""
|
"""Gets the commit info since the last commit on the default branch."""
|
||||||
base_commit = get_merge_base()
|
base_commit = get_merge_base()
|
||||||
@ -105,7 +105,7 @@ def get_git_commit_info() -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@cache
|
||||||
def get_issue_or_pr_body(number: int) -> str:
|
def get_issue_or_pr_body(number: int) -> str:
|
||||||
"""Gets the body of an issue or PR"""
|
"""Gets the body of an issue or PR"""
|
||||||
github_token = os.environ.get("GITHUB_TOKEN")
|
github_token = os.environ.get("GITHUB_TOKEN")
|
||||||
@ -148,7 +148,7 @@ def get_ratings_for_tests(file: str | Path) -> dict[str, float]:
|
|||||||
print(f"could not find path {path}")
|
print(f"could not find path {path}")
|
||||||
return {}
|
return {}
|
||||||
with open(path) as f:
|
with open(path) as f:
|
||||||
test_file_ratings = cast(Dict[str, Dict[str, float]], json.load(f))
|
test_file_ratings = cast(dict[str, dict[str, float]], json.load(f))
|
||||||
try:
|
try:
|
||||||
changed_files = query_changed_files()
|
changed_files = query_changed_files()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -2,7 +2,11 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from functools import total_ordering
|
from functools import total_ordering
|
||||||
from typing import Any, Iterable
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
|
||||||
class TestRun:
|
class TestRun:
|
||||||
|
@ -4,12 +4,16 @@ import math
|
|||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Sequence
|
from typing import Callable, TYPE_CHECKING
|
||||||
|
|
||||||
from tools.stats.import_test_stats import get_disabled_tests
|
from tools.stats.import_test_stats import get_disabled_tests
|
||||||
from tools.testing.test_run import ShardedTest, TestRun
|
from tools.testing.test_run import ShardedTest, TestRun
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
|
||||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||||
|
|
||||||
IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"
|
IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"
|
||||||
|
@ -3,7 +3,7 @@ import os
|
|||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, cast, Dict, List, Optional, Tuple
|
from typing import Any, cast, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from clickhouse import query_clickhouse # type: ignore[import]
|
from clickhouse import query_clickhouse # type: ignore[import]
|
||||||
@ -96,7 +96,7 @@ PYTORCHBOT_TOKEN = os.environ["PYTORCHBOT_TOKEN"]
|
|||||||
|
|
||||||
|
|
||||||
def git_api(
|
def git_api(
|
||||||
url: str, params: Dict[str, Any], type: str = "get", token: str = UPDATEBOT_TOKEN
|
url: str, params: dict[str, Any], type: str = "get", token: str = UPDATEBOT_TOKEN
|
||||||
) -> Any:
|
) -> Any:
|
||||||
headers = {
|
headers = {
|
||||||
"Accept": "application/vnd.github.v3+json",
|
"Accept": "application/vnd.github.v3+json",
|
||||||
@ -122,7 +122,7 @@ def git_api(
|
|||||||
).json()
|
).json()
|
||||||
|
|
||||||
|
|
||||||
def make_pr(source_repo: str, params: Dict[str, Any]) -> int:
|
def make_pr(source_repo: str, params: dict[str, Any]) -> int:
|
||||||
response = git_api(f"/repos/{source_repo}/pulls", params, type="post")
|
response = git_api(f"/repos/{source_repo}/pulls", params, type="post")
|
||||||
print(f"made pr {response['html_url']}")
|
print(f"made pr {response['html_url']}")
|
||||||
return cast(int, response["number"])
|
return cast(int, response["number"])
|
||||||
@ -150,7 +150,7 @@ def make_comment(source_repo: str, pr_number: int, msg: str) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def add_labels(source_repo: str, pr_number: int, labels: List[str]) -> None:
|
def add_labels(source_repo: str, pr_number: int, labels: list[str]) -> None:
|
||||||
params = {"labels": labels}
|
params = {"labels": labels}
|
||||||
git_api(
|
git_api(
|
||||||
f"/repos/{source_repo}/issues/{pr_number}/labels",
|
f"/repos/{source_repo}/issues/{pr_number}/labels",
|
||||||
@ -161,7 +161,7 @@ def add_labels(source_repo: str, pr_number: int, labels: List[str]) -> None:
|
|||||||
|
|
||||||
def search_for_open_pr(
|
def search_for_open_pr(
|
||||||
source_repo: str, search_string: str
|
source_repo: str, search_string: str
|
||||||
) -> Optional[Tuple[int, str]]:
|
) -> Optional[tuple[int, str]]:
|
||||||
params = {
|
params = {
|
||||||
"q": f"is:pr is:open in:title author:pytorchupdatebot repo:{source_repo} {search_string}",
|
"q": f"is:pr is:open in:title author:pytorchupdatebot repo:{source_repo} {search_string}",
|
||||||
"sort": "created",
|
"sort": "created",
|
||||||
|
@ -4,7 +4,7 @@ import time
|
|||||||
import zipfile
|
import zipfile
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, List
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||||
@ -18,7 +18,7 @@ def get_s3_resource() -> Any:
|
|||||||
return boto3.client("s3")
|
return boto3.client("s3")
|
||||||
|
|
||||||
|
|
||||||
def zip_artifact(file_name: str, paths: List[str]) -> None:
|
def zip_artifact(file_name: str, paths: list[str]) -> None:
|
||||||
"""Zip the files in the paths listed into file_name. The paths will be used
|
"""Zip the files in the paths listed into file_name. The paths will be used
|
||||||
in a glob and should be relative to REPO_ROOT."""
|
in a glob and should be relative to REPO_ROOT."""
|
||||||
|
|
||||||
|